# Modified from DPLM-2:
#     DPLM-2: https://github.com/bytedance/dplm/blob/main/run/scaffold_generate_dplm2.py

import argparse
from omegaconf import OmegaConf

import numpy as np
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

import torch
import os
from byprot import utils
from byprot.models.dplm import DiffusionProteinLanguageModel

from peft import PeftModel, LoraConfig, get_peft_model

from models.construct import ConstructModel
from utils.scaffolding import motif_name_mapping, get_initial_prot

from biotite.sequence.io.fasta import FastaFile

# Generated structure embeddings are divided by this factor before being saved.
# NOTE(review): presumably reverses a latent scaling applied at training time — confirm.
_STRUCT_LATENT_SCALE = 0.1875


def _seed_everything(seed):
    """Seed Python's, NumPy's and torch's RNGs so repeated runs sample identically."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # torch.manual_seed also seeds every CUDA device


def _load_model(args, tokenizer):
    """Build the ConstructModel described next to the checkpoint and load its weights.

    The config is read from ``construct_150m.yaml`` in the directory of
    ``args.checkpoint_path``; weights are read from ``<checkpoint_path>.pth``.

    Returns:
        The model in eval mode, moved to the CUDA device selected by ``args.device``.
    """
    # os.path.dirname returns '' for a bare file name, so the config then resolves
    # relative to the CWD instead of the filesystem root ('/construct_150m.yaml').
    config_path = os.path.join(os.path.dirname(args.checkpoint_path), 'construct_150m.yaml')
    config = OmegaConf.load(config_path)
    OmegaConf.set_struct(config, False)

    dplm_config = AutoConfig.from_pretrained(args.pretrained_model)
    torch.cuda.set_device(args.device)

    model = ConstructModel(config.model, dplm_config, tokenizer)
    # Map to CPU first so a checkpoint saved from a different GPU is not
    # materialised on that GPU before being copied into the model.
    model_state = torch.load(args.checkpoint_path + '.pth', map_location='cpu')
    model.load_state_dict(model_state)
    del model_state
    return model.eval().cuda()


def _scaffold_motif(model, tokenizer, args, pdb, ori_pdb, device):
    """Generate scaffolds for one motif and write every output file for it.

    Writes, under ``args.saveto``:
      * ``<pdb>.fasta`` -- one decoded sequence per generated sample;
      * ``struct_latent/<pdb>/struct_<idx>.npz`` -- rescaled structure latent and sequence;
      * ``start_end_scaffold/<pdb>.npz`` -- the index lists returned by ``get_initial_prot``.
    """
    # Context manager closes the .npz handle; dict() reads every array eagerly first.
    with np.load(f'{args.scaffolding_pdb_dir}/latents/{ori_pdb}_reference.npz') as npz:
        prot_dict = dict(npz)
    batches, start_idxs_list, end_idxs_list, scaffold_length_list = get_initial_prot(
        prot_dict, tokenizer, pdb, ori_pdb, args.prot_num, bool(args.cover_ori_motif), device
    )

    # Created once, up front: this also guarantees args.saveto exists for the FASTA
    # write below even when no batches are produced.
    latent_dir = os.path.join(args.saveto, 'struct_latent', pdb)
    os.makedirs(latent_dir, exist_ok=True)

    seq_fasta = FastaFile()
    for batch in batches:
        # Pure inference: no autograd graph is needed for sampling.
        with torch.no_grad():
            output_seq_tokens, _, output_struct_embeds = model.generate(
                batch=batch,
                seq_temp=float(args.seq_temp),
                struct_temp=float(args.struct_temp),
                sampling_strategy=args.sampling_strategy,
                unmasking_strategy=args.unmasking_strategy,
                seq_cfg=float(args.seq_cfg), seq_cfg_schedule=args.seq_cfg_schedule,
                struct_cfg=float(args.struct_cfg), struct_cfg_schedule=args.struct_cfg_schedule,
                cover_ori_motif=bool(args.cover_ori_motif),
            )
        for i in range(output_seq_tokens.shape[0]):
            idx = batch['idxs'][i]
            decoded = tokenizer.decode(output_seq_tokens[i], skip_special_tokens=True)
            seq = decoded.replace(' ', '')  # strip the spaces decode() leaves between tokens
            seq_fasta[f'SEQUENCE_{idx}_L={len(seq)}'] = seq

            struct_embed = output_struct_embeds[i] / _STRUCT_LATENT_SCALE
            # [1:-1] drops the first and last positions (presumably BOS/EOS — confirm).
            np.savez(
                os.path.join(latent_dir, f'struct_{idx}.npz'),
                latent=struct_embed[1:-1].detach().cpu().numpy(),
                seq=seq,
            )
    seq_fasta.write(os.path.join(args.saveto, f'{pdb}.fasta'))

    index_dir = os.path.join(args.saveto, 'start_end_scaffold')
    os.makedirs(index_dir, exist_ok=True)
    np.savez(
        os.path.join(index_dir, f'{pdb}.npz'),
        start_idxs_list=start_idxs_list,
        end_idxs_list=end_idxs_list,
        scaffold_length_list=scaffold_length_list,
    )


def generate(args):
    """Run motif-scaffolding generation for every motif in ``motif_name_mapping``.

    Args:
        args: parsed command-line namespace. Reads checkpoint_path, pretrained_model,
            device, scaffolding_pdb_dir, prot_num, saveto, the sampling options and,
            when present, seed (applied to all RNGs before anything is sampled).
    """
    seed = getattr(args, 'seed', None)
    if seed is not None:
        _seed_everything(seed)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
    model = _load_model(args, tokenizer)
    device = next(model.parameters()).device

    print('Generating ...')
    print(f'Key Settings: seq_sampling_strategy: {args.sampling_strategy},', 
          f'unmasking_strategy: {args.unmasking_strategy}, seq_temp: {args.seq_temp},', 
          f'struct_temp: {args.struct_temp}, seq_cfg: {args.seq_cfg}, seq_cfg_schedule: {args.seq_cfg_schedule},',
          f'struct_cfg: {args.struct_cfg}, struct_cfg_schedule: {args.struct_cfg_schedule}, cover_ori_motif: {bool(args.cover_ori_motif)}')

    for pdb, ori_pdb in motif_name_mapping.items():
        print(f'Motif-Scaffolding for {pdb}...')
        _scaffold_motif(model, tokenizer, args, pdb, ori_pdb, device)
    print('Motif-Scaffolding completed!')
    
def main():
    """Build the command-line interface, parse ``sys.argv`` and launch generation."""
    # (flag, type, default) for each option, listed in --help order. The sampling
    # hyper-parameters are deliberately strings; generate() converts them with float().
    option_table = (
        ('--seed', int, 42),
        ('--pretrained_model', str, ''),
        ('--checkpoint_path', str, ''),
        ('--scaffolding_pdb_dir', str, ''),
        ('--prot_num', int, 100),
        ('--saveto', str, ''),
        ('--seq_temp', str, '1.0'),
        ('--struct_temp', str, '0.35'),
        ('--sampling_strategy', str, 'vanilla'),
        ('--unmasking_strategy', str, 'deterministic'),
        ('--seq_cfg', str, '1.0'),
        ('--seq_cfg_schedule', str, 'constant'),
        ('--struct_cfg', str, '1.0'),
        ('--struct_cfg_schedule', str, 'constant'),
        ('--cover_ori_motif', int, 0),
        ('--device', int, 0),
    )

    cli = argparse.ArgumentParser()
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)

    generate(cli.parse_args())


if __name__ == '__main__':
    main()
