import functools
import html
import json
import os
import sys
import shutil
import tempfile
import traceback

import gradio as gr
import torch
from omegaconf import DictConfig

# Put ./atomflow/ first on the import path; the project-local modules below
# are resolved from there.
sys.path.insert(0,'./atomflow/')
from gm import (
    FlowMatching,
)
from model.denoise import Denoise
from data.data_pipeline import (
    AllAtomDataPipeline,
    make_complex_feature
)
from data.feature_pipeline import AllAtomFeaturePipeline
from atomflow_utils import get_generation_pdb

cfg=DictConfig({'model': {'c_s': 256, 'c_z': 128, 'c_t': 64, 'c_m': 256, 'dropout': 0.0, 'enable_sc': True, 'enable_ligand_hint': False, 'enable_sequence': False, 'core_stack_sel': 'original', 'distogram': {'input': {'splits': [1, 3.25, 50.75], 'steps': [46, 39]}, 'output': {'splits': [1.0125, 2.3125, 21.6875], 'steps': [27, 63]}}, 'feature_embedder': {'t_emb_dim': 64, 'seq_dim': 147, 'pair_dim': 5, 'c_z': '${model.c_z}', 'c_m': '${model.c_m}', 'relpos_k': 32, 'distogram_params': '${model.distogram.input}', 'enable_sc': '${model.enable_sc}', 'enable_sequence': '${model.enable_sequence}'}, 'input_embedder': {'c_t': '${model.c_t}', 'distogram_params': '${model.distogram.input}'}, 'input_pair_stack': {'c_t': '${model.c_t}', 'c_hidden_tri_att': 16, 'c_hidden_tri_mul': 64, 'no_blocks': 2, 'no_heads': 4, 'pair_transition_n': 2, 'dropout_rate': 0.25, 'blocks_per_ckpt': 1}, 'input_pointwise_attn': {'c_t': '${model.c_t}', 'c_z': '${model.c_z}', 'c_hidden': 16, 'no_heads': 4, 'inf': 100000}, 'z_pair_stack': {'c_t': '${model.c_z}', 'c_hidden_tri_att': 16, 'c_hidden_tri_mul': 64, 'no_blocks': 16, 'no_heads': 4, 'pair_transition_n': 2, 'dropout_rate': 0.25, 'blocks_per_ckpt': 1}, 'transformer_stack': {'c_m': '${model.c_s}', 'c_z': '${model.c_z}', 'c_hidden_msa_att': 32, 'c_hidden_opm': 32, 'c_hidden_mul': 128, 'c_hidden_pair_att': 32, 'c_s': '${model.c_s}', 'no_heads_msa': 8, 'no_heads_pair': 4, 'no_blocks': 14, 'transition_n': 4, 'msa_dropout': 0.15, 'pair_dropout': 0.25, 'blocks_per_ckpt': 1, 'no_column_attention': True, 'inf': 1000000000.0, 'eps': 1e-10}, 'ipa_structure': {'c_s': '${model.c_s}', 'c_z': '${model.c_z}', 'c_ipa': 16, 'c_resnet': 128, 'no_heads_ipa': 12, 'no_qk_points': 4, 'no_v_points': 8, 'dropout_rate': 0.1, 'no_blocks': 4, 'no_transition_layers': 1, 'no_resnet_blocks': 2, 'no_angles': 1, 'trans_scale_factor': 10, 'epsilon': 1e-13, 'inf': 100000, 'shared_weight': True}, 'distogram_head': {'c_z': '${model.c_z}', 'no_bins': 90}}, 'gm': {'a': 
0.20775623268698062, 'loss': {'violation': {'violation_tolerance_factor': 0.0, 'clash_overlap_tolerance': 0.5, 'weight': 0.0}, 'fape': {'weight': 0.0, 'mol_weight': 0.0, 'prot_weight': 0.0, 'inter_weight': 0.0}, 'distogram': {'distogram_params': '${model.distogram.output}', 'weight': 0.0, 'mol_weight': 0.0, 'prot_weight': 0.0, 'inter_weight': 0.0}, 'dist': {'max_cutoff': 0, 'weight': 0.0}, 'mse': {'weight': 0.0}, 'mse_star': {'weight': 0.0}, 'smoothlddt': {'weight': 0.0}, 't_max': 0.0}, 'no_noise': False}, 'data': {'common': {'feat': {'token_type': ['NUM_TOKEN'], 'target_feat': ['NUM_TOKEN', None], 'entity_type': ['NUM_TOKEN', None], 'all_atom_mask': ['NUM_TOKEN', None], 'all_atom_positions': ['NUM_TOKEN', None, None], 'alt_chi_angles': ['NUM_TOKEN', None], 'atom14_alt_gt_exists': ['NUM_TOKEN', None], 'atom14_alt_gt_positions': ['NUM_TOKEN', None, None], 'atom14_atom_exists': ['NUM_TOKEN', None], 'atom14_atom_is_ambiguous': ['NUM_TOKEN', None], 'atom14_gt_exists': ['NUM_TOKEN', None], 'atom14_gt_positions': ['NUM_TOKEN', None, None], 'atomFull_atom_exists': ['NUM_TOKEN', None], 'backbone_rigid_mask': ['NUM_TOKEN'], 'backbone_rigid_tensor': ['NUM_TOKEN', None, None], 'is_distillation': [], 'no_recycling_iters': [], 'pseudo_beta': ['NUM_TOKEN', None], 'pseudo_beta_mask': ['NUM_TOKEN'], 'token_index': ['NUM_TOKEN'], 'residx_atom14_to_atomFull': ['NUM_TOKEN', None], 'residx_atomFull_to_atom14': ['NUM_TOKEN', None], 'resolution': [], 'rigidgroups_alt_gt_frames': ['NUM_TOKEN', None, None, None], 'rigidgroups_group_exists': ['NUM_TOKEN', None], 'rigidgroups_group_is_ambiguous': ['NUM_TOKEN', None], 'rigidgroups_gt_exists': ['NUM_TOKEN', None], 'rigidgroups_gt_frames': ['NUM_TOKEN', None, None, None], 'seq_length': [], 'seq_mask': ['NUM_TOKEN'], 'extra_feat': ['NUM_TOKEN', None], 'pair_feat': ['NUM_TOKEN', 'NUM_TOKEN', None], 'torsion_angles_sin_cos': ['NUM_TOKEN', 7, 2], 'alt_torsion_angles_sin_cos': ['NUM_TOKEN', 7, 2], 'torsion_angles_mask': ['NUM_TOKEN', 7], 
'fape_frame_idx': ['NUM_TOKEN', 3], 'use_clamped_fape': [], 'edges': ['NUM_TOKEN', 'NUM_TOKEN', None]}, 'unsupervised_features': ['token_type', 'target_feat', 'token_index', 'entity_type', 'seq_length', 'fape_frame_idx', 'edges']}, 'supervised': {'supervised_features': ['all_atom_mask', 'all_atom_positions', 'extra_feat', 'pair_feat', 'torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask', 'use_clamped_fape'], 'clamp_prob': 0.9}, 'train': {'supervised': True, 'crop_size': 512, 'distillation_prob': 0.0}, 'eval': {'supervised': True, 'crop_size': 512, 'distillation_prob': 0.0}, 'predict': {'supervised': True, 'crop_size': 512, 'distillation_prob': 0.0}}, 'data_module': {'data_loaders': {'batch_size': 1, 'num_workers': 0}, 'options': {'train_epoch_len': 10000, 'pickle_dir': 'xxx/cache_data/entity_feature_list/', 'train_data_index_list': [{'path': 'data_index/test_data_index/train_data_index.json', 'data_dir': 'xxx', 'prob': 0.0}], 'eval_data_index_path': 'xxx', 'eval_data_dir': 'xxx', 'predict_data_index_path': 'xxx', 'predict_data_dir': 'xxx', 'batch_seed': 0}}, 'diffuse_mol': True, 'learning_rate': 0.0, 'scheduler': None, 'project_name': 'fape_fm_binder_design', 'version': 'none', 'log_root': 'xxx', 'ckpt_file': 'none', 'lr': 'none', 'use_ema': False, 'noise_type': 'gaussian', 'fast_dev_run': False, 'pickle_version': 'p_00007', 'accumulate_grad_batches': 2, 'enable_scale': 'noise', 'sigma_data': 10, 'use_eval_as_predict': False, 'result_folder': 'xxx', 'use_deepspeed_evo_attention': False})

def load_data(smiles, length, cfg, fm):
    """Build model-ready features for one protein/ligand design request.

    The request is a two-entity complex:
    - the protein to be designed, flagged ``is_target``, with ``length`` residues;
    - the ligand, handed to the data pipeline as the path ``"<smiles>.smiles"``.

    The complex is processed in ``"predict"`` mode by
    ``AllAtomDataPipeline.process_data``, then ``make_complex_feature``, then
    ``AllAtomFeaturePipeline.process_features``. Finally it goes through
    ``fm.data_postprocess``.

    Args:
        smiles: Ligand SMILES string.
        length: Residue count of the protein entity.
        cfg: Configuration; only ``cfg.get("data")`` is read here.
        fm: FlowMatching instance whose ``data_postprocess`` is applied last.

    Returns:
        Whatever ``fm.data_postprocess`` returns. ``predict`` indexes it by key
        and moves each value to the GPU, so presumably a dict of tensors.
    """
    # Both pipelines are constructed up front, in the same order as before.
    raw_pipeline = AllAtomDataPipeline()
    feat_pipeline = AllAtomFeaturePipeline(config=cfg.get("data"))

    protein = {"is_target": True, "length": length, "type": "protein"}
    ligand = {"path": [f"{smiles}.smiles"], "type": "molecule"}
    request = {
        "entities": [protein, ligand],
        "is_distillation": False,
        "resolution": 1.0,
    }

    per_entity = raw_pipeline.process_data(
        name="EXP", current_data=request, data_dir="", mode="predict"
    )
    complex_feats = make_complex_feature(entity_feat_list=per_entity)
    model_feats = feat_pipeline.process_features(complex_feats, "predict")
    # The 0 positional argument is passed through unchanged from the original call.
    return fm.data_postprocess(model_feats, 0, "predict")

@functools.lru_cache(maxsize=1)
def _load_flow_matching():
    """Build the Denoise network, wrap it in FlowMatching and load the weights.

    Cached with a single entry, so the network is constructed and
    ``checkpoint/atomflow.pt`` is read once per process instead of on every
    "Design!" click.

    NOTE(review): reusing one instance across requests assumes
    ``fm.data_postprocess``/``fm.sample`` keep no per-request state on ``fm``.

    Returns:
        tuple: ``(model, fm)``, the ``Denoise`` module and the ``FlowMatching``
        wrapper built around it. Neither has been moved to the GPU yet.
    """
    model = Denoise(cfg)
    fm = FlowMatching(model=model, cfg=cfg)
    # Deserialize onto the CPU regardless of the device the weights were saved
    # from. predict() moves the network to the GPU afterwards.
    state_dict = torch.load("checkpoint/atomflow.pt", map_location="cpu")
    fm.load_state_dict(state_dict, strict=True)
    return model, fm


def predict(smiles, length):
    """Design a protein of ``length`` residues that binds the ligand ``smiles``.

    Args:
        smiles: Ligand SMILES. Surrounding whitespace is stripped, e.g. a
            trailing newline carried over from the examples file, which
            would otherwise end up in the ``"<smiles>.smiles"`` path.
        length: Residue count of the binder. The UI slider may deliver it as
            a float, so it is converted to ``int``.

    Returns:
        The result of ``get_generation_pdb(batch, result)``; the UI treats it
        as PDB text.
    """
    torch.set_printoptions(sci_mode=False)

    model, fm = _load_flow_matching()

    batch = load_data(smiles.strip(), int(length), cfg, fm)
    # Prepend a batch axis of size 1 to every feature and move it to the GPU.
    for key in batch:
        batch[key] = batch[key][None, ...].to('cuda:0')
    with torch.no_grad():
        # Presumably Denoise is a torch Module, so .to()/.eval() act in place.
        # After the first request these calls are effectively no-ops.
        model = model.to('cuda:0')
        model.eval()
        result = fm.sample(batch)

    return get_generation_pdb(batch, result)

def molecule(mol):
    """Return an ``<iframe>`` that renders PDB text with the PDBe Mol* viewer.

    The viewer page is passed through the iframe's single-quoted ``srcdoc``
    attribute. Two layers of escaping keep arbitrary PDB text from breaking
    it:

    * The structure is embedded in the page script as a JSON string literal,
      so quotes, backslashes, backticks and ``${`` cannot end or alter the JS
      string. ``</`` is rewritten as ``<\\/`` so the data cannot close the
      ``<script>`` element.
    * The whole page is HTML-escaped before it is placed in ``srcdoc``, so an
      apostrophe (e.g. in an atom name such as ``C1'``) cannot terminate the
      attribute.

    Args:
        mol: PDB-format structure text.

    Returns:
        str: HTML for an iframe containing the viewer.
    """
    # json.dumps yields a valid JS string literal; "<\/" is "</" to JavaScript.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")
    page = (
        """<!DOCTYPE html><head>
<meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0" />

    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/npm/pdbe-molstar@3.2.0/build/pdbe-molstar.css" />
    <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/pdbe-molstar@3.2.0/build/pdbe-molstar-plugin.js" ></script>

    <style>
      #myViewer {
        float: left;
        width: 400px;
        height: 300px;
        position: relative;
        margin: 20px;
      }
    </style>
  </head>

  <body>
    <div id="myViewer"></div>

    <script>
      const viewerInstance = new PDBeMolstarPlugin();
      const blob = new Blob([""" + pdb_literal + """], { type: "text/plain" });
      const options = {
        customData: {
          url: URL.createObjectURL(blob),
          format: "pdb",
          binary: false,
        },
        bgColor: { r: 255, g: 255, b: 255 },
        hideControls: true,
      };

      const viewerContainer = document.getElementById("myViewer");

      viewerInstance.render(viewerContainer, options);
    </script>
  </body>
"""
    )
    # The browser entity-decodes the attribute value, so the iframe receives
    # `page` unchanged.
    srcdoc = html.escape(page, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""

def update(smiles, length):
    """Gradio callback for the "Design!" button.

    Args:
        smiles: Ligand SMILES from the textbox.
        length: Residue count from the slider.

    Returns:
        tuple: ``(viewer_html, pdb_text)`` for the HTML viewer and the hidden
        PDB textbox. On failure, ``viewer_html`` is an HTML-escaped error
        message and ``pdb_text`` is ``""``.
    """
    try:
        pdb = predict(smiles, length)
        return molecule(pdb), pdb
    except Exception as err:  # UI boundary: report the failure instead of crashing the event
        # Keep the full traceback in the server log for debugging.
        traceback.print_exc()
        # An Exception object is not valid content for the gr.HTML output.
        # Render its message as escaped text instead.
        message = html.escape(f"{type(err).__name__}: {err}")
        return f"<pre>{message}</pre>", ""
    
def get_down_link(pdb):
    """Write PDB text to a new temporary ``.pdb`` file and return its path.

    Used as the callable ``value`` of the download button. The file is
    created with ``delete=False`` so it still exists after this function
    returns and can be served for download. This function never removes it.

    Args:
        pdb: PDB-format text. ``None`` is written as an empty file.

    Returns:
        str: Path of the written file (named ``AtomFlow_*.pdb``).
    """
    with tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", prefix="AtomFlow_", suffix=".pdb", delete=False
    ) as handle:
        handle.write(pdb or "")
    return handle.name

demo = gr.Blocks()

# Example ligands: each non-blank line of example_smiles.txt is
# "<name>,<ignored field>,<SMILES>".
with open("example_smiles.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()

examples = {}
smiles_to_example = {}
example_names = []
for line in lines:
    if not line.strip():
        continue  # tolerate blank lines, e.g. an empty last line
    # readlines() keeps the trailing newline. Without strip() it leaked into
    # the SMILES shown in the textbox, into the reverse-lookup keys, and into
    # the SMILES sent to the model. maxsplit=2: everything after the second
    # comma is the SMILES.
    name, _, smiles = (field.strip() for field in line.split(',', 2))
    examples[name] = smiles
    smiles_to_example[smiles] = name
    example_names.append(name)

# read() preserves the file's own line breaks. '\n'.join(readlines()) doubled
# every newline, inserting a blank line after each PDB record.
with open("example_pdb.pdb", "r", encoding="utf-8") as f:
    example_result = f.read()

with demo:
    gr.Markdown("# AtomFlow: Design a ligand-binding protein from SMILES")
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(label="Example Ligands", choices=example_names, value="OQO", interactive=True)
            inp = gr.Textbox(
                placeholder="SMILES", label="Input Ligand SMILES", value=examples[radio.value]
            )
            inp_l = gr.Slider(
                100, 225, value=100, label="Residue Count of the Designed Binder", info="Ligands with more atoms may need a longer protein to bind. Longer protein takes more time to design, and may fail due to the GPU memory limit of your device."
            )
            # Hidden holder for the latest PDB text; it feeds the download button.
            outp = gr.Textbox(label="Output", value=example_result, visible=False)
            with gr.Row():
                btn = gr.Button("Design!", variant="primary")
                # The file to download is produced by get_down_link from outp's value.
                down_btn = gr.DownloadButton("Download Result", value=get_down_link, inputs=[outp])
        with gr.Column():
            mol = gr.HTML(value=molecule(example_result), label="Generation Result", show_label=True)
    btn.click(fn=update, inputs=[inp, inp_l], outputs=[mol, outp])
    # Picking an example fills in its SMILES.
    radio.input(fn=lambda x: examples[x], inputs=[radio], outputs=[inp])
    # Typing a known example SMILES selects that example. Any other text
    # yields None, so no example is selected.
    inp.input(fn=lambda x: smiles_to_example.get(x), inputs=[inp], outputs=[radio])

# Start the web server only when run as a script (`python app.py`). Importing
# the module, e.g. in Gradio reload mode, then builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(debug=True)

