import os
import pickle
import random
import time
import pandas as pd
from tqdm import tqdm
from torch.cuda import is_available
from model.utils.utils import load_config, set_seed, to_numpy
from model.MultiFlow import FlowModel
from model.utils.writers import save_traj
from model.utils.chemical import BBHeavyAtom, resindex_to_ressymb
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

"""
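Unconditional design of antibody variable regions (heavy and light chains).

For each (heavy, light) length pair in length_combines.pkl, sample structures
from a trained FlowModel and write them (and, optionally, the designed
sequences) under config.sample.root_path.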
"""

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)  # report which device this run will use

def remove_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* stripped from each key.

    State dicts saved from a model wrapped in ``DataParallel`` /
    ``DistributedDataParallel`` typically have every key prefixed with
    ``"module."``. This strips that prefix so the weights load into the
    unwrapped model.

    Only a *leading* occurrence is removed. Later occurrences of the prefix
    inside a key (e.g. a nested submodule attribute literally named ``module``)
    are preserved. Keys without the prefix are copied unchanged.

    Args:
        state_dict: Mapping of parameter names to values (tensors or anything else).
        prefix: The leading string to strip. Defaults to ``"module."``.

    Returns:
        A new ``dict`` with the same values and the rewritten keys.
    """
    cut = len(prefix)
    # Slice instead of str.replace so that only the leading prefix is removed.
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

if __name__ == '__main__':
    config, config_name = load_config('configs/base_pretrain_ddp.yaml')
    # 输入尺寸会变，因此设置为False
    torch.backends.cudnn.benchmark = False
    # 固定cuda的随机数种子
    torch.backends.cudnn.deterministic = True

    save_path = config.sample.root_path + '/pdb'

    length_combines = pickle.load(open('./length_combines.pkl', 'rb'))

    if config.sample.save_sequence:
        sequences_dict = {'id': [], 'H_seq': [], 'L_seq': []}
        seqs_records = []
    else:
        sequences_dict = None
        seqs_records = None

    # 加载训练好的流生成模型
    model = FlowModel(config).to(device)
    # model.load_state_dict(remove_module_prefix(torch.load('./save_model/model.pt')))
    checkpoint = torch.load('./save_model/checkpoint.ckpt')
    model.load_state_dict(checkpoint['last_model_state_dict'])
    model.eval()

    pbar = tqdm(range(100))
    used_length_list = []

    for i in pbar:
        heavy_L, light_L = length_combines[i]
        set_seed(i)

        for j in range(8):
            batch = {}
            batch['generate_mask'] = torch.ones((1, heavy_L + light_L), device=device).bool()
            batch['res_mask'] = batch['generate_mask']
            batch['chain_id'] = torch.concat([torch.tensor([[0. for _ in range(heavy_L)]], device=device),
                                              torch.tensor([[1. for _ in range(light_L)]], device=device)], dim=-1).to(torch.long)
            batch['chain_name'] = [['H' for _ in range(heavy_L)] + ['L' for _ in range(light_L)]]
            batch['res_nb'] = torch.concat([torch.tensor([list(range(1, heavy_L + 1))], device=device),
                                            torch.tensor([list(range(heavy_L + 50, heavy_L + 50 + light_L))], device=device)], dim=-1)
            with torch.no_grad():
                prot_traj, clean_traj = model.sample(config, batch, num_steps=100)

            traj_paths = save_traj(
                bb_prot_traj=to_numpy(prot_traj[-1]['bb_pos']),
                x0_traj=to_numpy(prot_traj[-1]['bb_pos'][:, BBHeavyAtom.CA]),
                diffuse_mask=to_numpy(batch['generate_mask'][0]),
                chain_idx=to_numpy(batch['chain_id'][0]),
                plm_embed=to_numpy(prot_traj[-1]['seqs_emb']),
                output_dir='%s/length_%s_%s_sample_%s' % (save_path, heavy_L, light_L, (j + 1))
            )

            if config.sample.save_sequence:
                sequences_dict['id'].append('length_%s_%s_sample_%s' % (heavy_L, light_L, (j + 1)))
                heavy_seq = ''.join([resindex_to_ressymb[token] for token in prot_traj[-1]['seqs_aa'][~batch['chain_id'].cpu().bool()].tolist()])
                light_seq = ''.join([resindex_to_ressymb[token] for token in prot_traj[-1]['seqs_aa'][batch['chain_id'].cpu().bool()].tolist()])
                sequences_dict['H_seq'].append(heavy_seq)
                sequences_dict['L_seq'].append(light_seq)
                seqs_records.append(SeqRecord(Seq(heavy_seq), id='length_%s_%s_heavy_sample_%s' % (heavy_L, light_L, (j + 1)), description=""))
                seqs_records.append(SeqRecord(Seq(light_seq), id='length_%s_%s_light_sample_%s' % (heavy_L, light_L, (j + 1)), description=""))
    if config.sample.save_sequence:
        df = pd.DataFrame(sequences_dict)
        df.to_csv(config.sample.root_path + '/seqs.csv', index=False)
        SeqIO.write(seqs_records, config.sample.root_path + '/seqs.fasta', "fasta")
