import pickle
import random
import time
import pandas as pd
from tqdm import tqdm
from torch.cuda import is_available
from model.utils.utils import load_config, set_seed, to_numpy, recursive_to
from model.MultiFlow import FlowModel
from model.utils.sabdab_onlyV import SAbDabDataset
from torch.utils.data import DataLoader
from model.utils.writers import save_traj
from model.utils.chemical import BBHeavyAtom, resindex_to_ressymb, CDR
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

"""
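Unconditionally design the variable regions of antibody heavy and light chains.
In this script only CDR residues are regenerated (only CDR-H3 when `only_hcdr3` is True), for every entry of the SAbDab test split.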
"""

# Select the compute device: the first CUDA GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if is_available() else torch.device("cpu")
print(device)  # report the chosen device at startup

def remove_module_prefix(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        new_key = k.replace("module.", "") if k.startswith("module.") else k
        new_state_dict[new_key] = v
    return new_state_dict

if __name__ == '__main__':
    import os  # only needed by this entry point (path joins / output directory creation)

    config, _ = load_config('configs/base_pretrain_onlyCDR_ddp.yaml')
    # Input sizes vary between complexes, so cuDNN autotuning is disabled.
    torch.backends.cudnn.benchmark = False
    # Force deterministic cuDNN kernels so a fixed seed gives reproducible samples.
    torch.backends.cudnn.deterministic = True

    # True: regenerate only CDR-H3 residues. False: regenerate every residue whose
    # `cdr_flag` is nonzero. NOTE(review): presumably 0 marks framework residues — confirm.
    only_hcdr3 = True
    num_samples = 8   # designs generated per test complex
    num_steps = 100   # integration steps passed to FlowModel.sample

    # Load the test split.
    test_dataset = SAbDabDataset(summary_path=config.dataset.summary_path_dir, chothia_dir=config.dataset.structure_dir,
                                 processed_dir=config.dataset.processed_dir, embedding_h5_dir=config.dataset.embedding_dir,
                                 split='test', reset=config.dataset.reset)
    # batch_size=1: the code below reads element [0] of the batched fields.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    time_str = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime())
    out_dir = os.path.join(config.sample.root_path, 'only_HCDR3')
    save_path = os.path.join(out_dir, 'pdb')

    save_sequence = config.sample.save_sequence
    if save_sequence:
        sequences_dict = {'id': [], 'H_seq': [], 'L_seq': []}
        seqs_records = []
    else:
        sequences_dict = None
        seqs_records = None

    # Load the trained flow model.
    model = FlowModel(config).to(device)
    # model.load_state_dict(remove_module_prefix(torch.load('./save_model/model.pt')))
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    checkpoint = torch.load('./save_model/onlyCDR_checkpoint.ckpt', map_location=device)
    model.load_state_dict(checkpoint['last_model_state_dict'])
    model.eval()

    for seed, batch in enumerate(test_loader):
        batch = recursive_to(batch, device)
        # Distinct, reproducible seed for each test entry.
        set_seed(config.train.seed + seed)
        entry_id = batch['id'][0]

        # A residue is usable when its first three heavy-atom slots are all present
        # (presumably the backbone N, CA, C — confirm against BBHeavyAtom).
        batch['res_mask'] = (batch['mask_heavyatom'][:, :, :3].sum(dim=-1) == 3.0)
        if only_hcdr3:
            cdr_flag = (batch['cdr_flag'] == CDR.H3)
        else:
            cdr_flag = batch['cdr_flag']
        batch['generate_mask'] = torch.logical_and(cdr_flag, batch['res_mask'])

        # Chain masks are invariant across samples of the same entry.
        # chain_id == 0 selects the heavy chain; nonzero selects the light chain.
        light_mask = batch['chain_id'].cpu().bool()
        heavy_mask = ~light_mask

        for i in tqdm(range(num_samples)):
            sample_name = 'sample_%s' % (i + 1)
            with torch.no_grad():
                prot_traj, _ = model.sample(config, batch, num_steps=num_steps, only_cdr=True)
            final = prot_traj[-1]  # last state of the sampling trajectory

            save_traj(
                bb_prot_traj=to_numpy(final['bb_pos']),
                x0_traj=to_numpy(final['bb_pos'][:, BBHeavyAtom.CA]),
                diffuse_mask=to_numpy(batch['generate_mask'][0]),
                chain_idx=to_numpy(batch['chain_id'][0]),
                plm_embed=to_numpy(final['seqs_emb']),
                output_dir='%s/%s_%s' % (save_path, entry_id, sample_name),
            )

            if save_sequence:
                seqs_aa = final['seqs_aa']
                heavy_seq = ''.join(resindex_to_ressymb[token] for token in seqs_aa[heavy_mask].tolist())
                light_seq = ''.join(resindex_to_ressymb[token] for token in seqs_aa[light_mask].tolist())
                # Prefix with the entry id so CSV rows from different complexes stay distinguishable.
                sequences_dict['id'].append('%s_%s' % (entry_id, sample_name))
                sequences_dict['H_seq'].append(heavy_seq)
                sequences_dict['L_seq'].append(light_seq)
                seqs_records.append(SeqRecord(Seq(heavy_seq), id='%s_heavy_%s' % (entry_id, sample_name), description=""))
                seqs_records.append(SeqRecord(Seq(light_seq), id='%s_light_%s' % (entry_id, sample_name), description=""))

    if save_sequence:
        # Make sure the output directory exists before writing the summary files.
        os.makedirs(out_dir, exist_ok=True)
        df = pd.DataFrame(sequences_dict)
        df.to_csv(os.path.join(out_dir, '%s_seqs.csv' % time_str))
        SeqIO.write(seqs_records, os.path.join(out_dir, '%s_seqs.fasta' % time_str), "fasta")
