import os
# Point the Hugging Face cache directory at /workspace/data/transformers_cache.
# NOTE(review): this is set before the esm / huggingface_hub imports below,
# presumably so those libraries pick the value up when first imported —
# confirm before moving this line.
os.environ["HF_HOME"] = '/workspace/data/transformers_cache'

import torch
from torch.nn import functional as F
from deli import load, save_json, save
from tqdm import tqdm
import numpy as np

# from biotite.database import rcsb
from biotite.structure import annotate_sse

from esm.utils.structure.protein_chain import ProteinChain
# from esm.sdk.api import ESMProtein
# from esm.utils.types import FunctionAnnotation
from esm.pretrained import ESM3_sm_open_v0, ESM3_structure_encoder_v0
# from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization import get_model_tokenizers

from huggingface_hub import login


def get_approximate_ss(protein_chain: ProteinChain):
    """Return a 3-state secondary-structure string for ``protein_chain``.

    The per-residue labels produced by biotite's ``annotate_sse`` on the
    chain's ``atom_array`` are joined into one string and re-spelled in the
    alphabet ESM3 uses:

        'a' (alpha helix) -> 'H'
        'b' (beta sheet)  -> 'E'
        'c' (coil)        -> 'C'

    Any other character in the biotite output is passed through unchanged.
    """
    # Character-for-character mapping from biotite labels to ESM3 labels.
    biotite_to_esm = str.maketrans('abc', 'HEC')

    sse_labels = annotate_sse(protein_chain.atom_array)
    return ''.join(sse_labels).translate(biotite_to_esm)


if __name__ == '__main__':
    # Script entry point: embed every structure of every split with ESM3
    # (sequence + secondary structure + structure tracks) and save the
    # embeddings together with a matching list of row names.

    # Hugging Face access token of your HF profile. It is read from the
    # HF_TOKEN environment variable rather than from a hard-coded name
    # (the previous `TOKEN` placeholder was undefined and raised NameError).
    access_token = os.environ.get('HF_TOKEN')
    if access_token:
        login(access_token)
    # NOTE(review): when HF_TOKEN is unset, login() is skipped and the model
    # download presumably relies on credentials cached by an earlier login —
    # confirm this matches how the script is deployed.

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizers = get_model_tokenizers()
    encoder = ESM3_structure_encoder_v0(device)
    model = ESM3_sm_open_v0(device)

    # Task configuration. `is_residue` selects per-residue embeddings (one
    # row per residue) versus one mean-pooled embedding per protein.
    TASK_NAME = 'secondary_structure_pdb'
    is_residue = True

    # Alternative configuration:
    # TASK_NAME = 'humanppi'
    # is_residue = False

    # Alternative inputs / outputs:
    # id2seq = load(f'/workspace/data/downstream_tasks/{TASK_NAME}/id2seq_real.json')
    # id2seq = load(f'/workspace/data/downstream_tasks/{TASK_NAME}/id2seq_real_nopairs.json')
    # save_emb_path = f'/workspace/data/downstream_tasks/downstream_datasets/{TASK_NAME}/esm3/'
    id2seq = load('/workspace/data/downstream_tasks/secondary_structure/id2seq_real_pdb.json')
    save_emb_path = f'/workspace/data/docking/downstream_tasks/downstream_datasets/{TASK_NAME}/esm3/'

    os.makedirs(save_emb_path, exist_ok=True)

    skipped_splits = ('casp12', 'cb513')

    for split, split_seqs in id2seq.items():
        if split in skipped_splits:
            continue

        embeddings = []
        # Row names, kept aligned with `embeddings`: one entry per residue
        # ("<uid>_<i>") in per-residue mode, one entry per protein otherwise.
        names = []

        for uid in tqdm(list(split_seqs)):
            # Alternative structure sources:
            # uid = uid.split('-')[1]
            # pdb_path = f"/workspace/data/downstream_tasks/downstream_structures/AF-{uid}-F1-model_v4.pdb"
            # pdb_path = f"/workspace/data/docking/downstream_tasks/outputs/{uid}.pdb"
            pdb_path = f"/workspace/data/downstream_tasks/secondary_structure/processed_structures/ss_structures_processed/{uid}.pdb"
            protein_chain = ProteinChain.from_pdb(pdb_path)

            # Inference only: no autograd graph is needed because the
            # embeddings are detached and moved to NumPy right away.
            with torch.no_grad():
                coords, plddt, residue_index = protein_chain.to_structure_encoder_inputs()
                coords = coords.to(device)
                plddt = plddt.to(device)
                residue_index = residue_index.to(device)
                _, structure_tokens = encoder.encode(coords, residue_index=residue_index)

                # The sequence is taken from the structure file itself, not
                # from id2seq.
                sequence = protein_chain.sequence
                tokens = tokenizers.sequence.encode(sequence)
                sequence_tokens = torch.tensor(tokens, dtype=torch.int64, device=device)

                ssp = get_approximate_ss(protein_chain)
                ssp_tokens = tokenizers.secondary_structure.encode(ssp)
                ssp_tokens = torch.tensor(ssp_tokens, dtype=torch.int64, device=device)

                # Add BOS/EOS padding: one extra position at each end of the
                # structure inputs.
                coords = F.pad(coords, (0, 0, 0, 0, 1, 1), value=torch.inf)
                plddt = F.pad(plddt, (1, 1), value=0)
                structure_tokens = F.pad(structure_tokens, (1, 1), value=0)
                # NOTE(review): 4098 / 4097 are presumably the structure
                # tokenizer's BOS / EOS ids — confirm against the tokenizer.
                structure_tokens[:, 0] = 4098
                structure_tokens[:, -1] = 4097

                output = model.forward(
                    sequence_tokens=sequence_tokens[None, ],
                    ss8_tokens=ssp_tokens[None, ],
                    structure_coords=coords,
                    per_res_plddt=plddt,
                    structure_tokens=structure_tokens
                )

            # Drop the first and last positions (the BOS/EOS slots added above).
            residue_embeddings = output.embeddings[0, 1:-1].detach().cpu().numpy()

            if is_residue:
                embeddings.append(residue_embeddings)
                # Names are derived from the rows actually produced, so the
                # names file always lines up with the embedding matrix.
                names.extend(f'{uid}_{i}' for i in range(len(residue_embeddings)))
            else:
                embeddings.append(residue_embeddings.mean(axis=0))
                names.append(uid)

        if not embeddings:
            # np.concatenate / np.stack raise on an empty list.
            print(split, 'has no entries, skipping')
            continue

        if is_residue:
            embeddings = np.concatenate(embeddings, axis=0)
        else:
            embeddings = np.stack(embeddings)

        print(split, type(embeddings), embeddings.shape)
        save_json(names, os.path.join(save_emb_path, f'{split}_names.json'))
        # NOTE(review): the file is called "*_avg_embeddings" in both modes,
        # although only the per-protein mode averages over residues.
        save(embeddings, os.path.join(save_emb_path, f'{split}_avg_embeddings.npy.gz'), compression=1)
