-
Notifications
You must be signed in to change notification settings - Fork 270
Expand file tree
/
Copy pathrun_esmfold.py
More file actions
111 lines (90 loc) · 4.53 KB
/
run_esmfold.py
File metadata and controls
111 lines (90 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script is based on https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb
import os
import habana_frameworks.torch.core as htcore
import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
from transformers.models.esm.openfold_utils.protein import to_pdb
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import HabanaGenerationTime
# HPU runtime tuning: disable host-to-device dynamic slicing and enable
# refined dynamic shapes.
# NOTE(review): these are set after the habana_frameworks import above —
# presumably the runtime reads them lazily; confirm before reordering.
os.environ["PT_HPU_ENABLE_H2D_DYNAMIC_SLICE"] = "0"
os.environ["PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES"] = "1"

try:
    from optimum.habana.utils import check_optimum_habana_min_version
except ImportError:

    def check_optimum_habana_min_version(*a, **b):
        # Fallback no-op for Optimum Habana builds that do not expose the
        # version-check helper; keeps the script importable in that case.
        return ()


# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks.
check_optimum_habana_min_version("1.19.0.dev0")
def convert_outputs_to_pdb(outputs):
    """
    Converts the model outputs to a PDB file.
    This code comes from the original ESMFold repo, and uses some functions from openfold that have been ported to Transformers.
    """
    # Expand the 14-atom representation to the full 37-atom layout, then move
    # everything to host memory as NumPy arrays for string serialization.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {key: tensor.to("cpu").numpy() for key, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = outputs["atom37_atom_exists"]

    pdb_strings = []
    for idx in range(outputs["aatype"].shape[0]):
        # Build one OFProtein per structure in the batch; residue indices are
        # shifted to the 1-based numbering PDB files use.
        protein = OFProtein(
            aatype=outputs["aatype"][idx],
            atom_positions=atom37_positions[idx],
            atom_mask=atom37_mask[idx],
            residue_index=outputs["residue_index"][idx] + 1,
            b_factors=outputs["plddt"][idx],
            chain_index=outputs["chain_index"][idx] if "chain_index" in outputs else None,
        )
        pdb_strings.append(to_pdb(protein))
    return pdb_strings
# Patch the Transformers library in place so models can run on Gaudi (HPU).
adapt_transformers_to_gaudi()

# Number of timed forward passes; early iterations on HPU also act as
# warm-up / graph-compilation steps.
steps = 4
device = torch.device("hpu")

# This is the sequence for human GNAT1.
# Feel free to substitute your own peptides of interest
# Depending on memory constraints you may wish to use shorter sequences.
test_protein = "MGAGASAEEKHSRELEKKLKEDAEKDARTVKLLLLGAGESGKSTIVKQMKIIHQDGYSLEECLEFIAIIYGNTLQSILAIVRAMTTLNIQYGDSARQDDARKLMHMADTIEEGTMPKEMSDIIQRLWKDSGIQACFERASEYQLNDSAGYYLSDLERLVTPGYVPTEQDVLRSRVKTTGIIETQFSFKDLNFRMFDVGGQRSERKKWIHCFEGVTCIIFIAALSAYDMVLVEDDEVNRMHESLHLFNSICNHRYFATTSIVLFLNKKDVFFEKIKKAHLSICFPDYDGPNTYEDAGNYIKVQFLELNMRRDVKEIYSHMTCATDTQNVKFVFDAVTDIIIKENLKDCGLF"  # len = 350

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

# Set _supports_param_buffer_assignment to False since facebook/esmfold_v1's encoder weights are float16.
# Without this fix, we will have the weights loaded with float16 on gaudi2,gaudi3 and runtime error on gaudi1
EsmForProteinFolding._supports_param_buffer_assignment = False
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=False)
model = model.to(device)

# Chunked trunk attention reduces peak memory; the upstream notebook suggests
# enabling this for longer (over 600 or so) sequences, and it is active here.
model.trunk.set_chunk_size(64)

with torch.no_grad():
    # Tokenize without special tokens: ESMFold consumes raw residue ids only.
    tk = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)
    tokenized_input = tk["input_ids"]
    print(f"ESMFOLD: input shape = {tokenized_input.shape}")
    tokenized_input = tokenized_input.to(device)

    for batch in range(steps):
        print(f"ESMFOLD: step {batch} start ...")
        with HabanaGenerationTime() as timer:
            output = model(tokenized_input)
            # Flush pending HPU graph execution so the timer measures real work.
            htcore.mark_step()
        print(f"ESMFOLD: step {batch} duration: {timer.last_duration:.03f} seconds")

# Serialize the last step's outputs to PDB and write the first (only)
# structure in the batch to disk.
pdb = convert_outputs_to_pdb(output)
pdb_file = "save-hpu.pdb"
with open(pdb_file, "w") as fout:
    fout.write(pdb[0])
print(f"pdb file saved in {pdb_file}")