# Modified from DPLM:
#     DPLM: https://github.com/bytedance/dplm/blob/main/generate_dplm.py

import argparse
from omegaconf import OmegaConf, DictConfig

import numpy as np

from transformers import AutoConfig, AutoTokenizer

from pprint import pprint
import torch
import os
from byprot import utils
from byprot.models.dplm import DiffusionProteinLanguageModel

from peft import PeftModel, LoraConfig, get_peft_model

from models.construct import ConstructModel
from utils.generation import get_initial

import esm


def _score_with_esm(esm_model, seq_tokens):
    """Score each sequence by the mean ESM probability of its own tokens.

    Args:
        esm_model: ESM model called as ``esm_model(tokens, repr_layers=[])``;
            its output dict must contain ``'logits'``.
        seq_tokens: 2-D tensor of token ids, one row per sequence, on the same
            device as ``esm_model``. The first and last column are left out of
            the average (presumably BOS/EOS - TODO confirm).

    Returns:
        1-D numpy array holding one score per row of ``seq_tokens``.
    """
    with torch.no_grad():
        logits = esm_model(seq_tokens, repr_layers=[])['logits']
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Probability assigned to the token that is actually present at each position.
    token_probs = torch.gather(probs, dim=-1, index=seq_tokens.unsqueeze(-1))
    # squeeze(-1) rather than squeeze(): for a batch of one sequence a bare
    # squeeze() collapses the result to a 0-d array that cannot be indexed.
    scores = token_probs[:, 1:-1].mean(dim=1).squeeze(-1)
    return scores.cpu().numpy()


def _write_fasta(path, sequences, seq_len, scores):
    """Write ``sequences`` to the FASTA file ``path``.

    Each header records the sequence index, the requested length ``seq_len``
    and the matching entry of ``scores`` (formatted to 3 decimals).
    """
    # ``with`` guarantees the file is closed even if a write fails.
    with open(path, 'w') as fp_save:
        for seq_idx, seq in enumerate(sequences):
            fp_save.write(f">SEQUENCE_{seq_idx}_L={seq_len}_S={scores[seq_idx]:.3f}\n")
            fp_save.write(f"{seq}\n")


def _save_struct_latents(out_dir, struct_embeds, sequences, scale=0.1875):
    """Save one ``struct_{i}.npz`` per structure embedding under ``out_dir``.

    Each file holds ``latent`` (the embedding divided by ``scale`` with its
    first and last position dropped) and ``seq`` (the matching decoded
    sequence).
    """
    os.makedirs(out_dir, exist_ok=True)
    for emb_idx, struct_embed in enumerate(struct_embeds):
        # NOTE(review): 0.1875 presumably undoes a latent scaling factor used
        # during training - confirm against the training config.
        latent = (struct_embed / scale)[1:-1]
        np.savez(
            f"{out_dir}/struct_{emb_idx}.npz",
            latent=latent.cpu().detach().numpy(),
            seq=sequences[emb_idx],
        )


def generate(args):
    """Sample sequences and structure latents from a ConstructModel checkpoint.

    For every requested length this runs ``model.generate``, scores the sampled
    sequences with an ESM model, and writes:

    * ``{saveto}/iter_{max_iter}_L_{seq_len}.fasta`` - sequences with scores;
    * ``{saveto}/struct_latent/iter_{max_iter}_L_{seq_len}/struct_{i}.npz`` -
      rescaled structure latents.

    Args:
        args: Parsed CLI namespace. ``checkpoint_path`` is a path *prefix*: the
            weights are read from ``checkpoint_path + '.pth'`` and the model
            config from ``construct_150m.yaml`` in the same directory.
            ``max_iter`` supplies one value per entry of ``seq_lens``, or a
            single value that is reused for every length.

    Raises:
        ValueError: if ``max_iter`` has more than one value but fewer values
            than ``seq_lens``.
    """
    seq_lens = list(args.seq_lens)
    max_iters = list(args.max_iter)
    if len(max_iters) == 1:
        # A single --max_iter applies to every requested length.
        max_iters = max_iters * len(seq_lens)
    elif len(max_iters) < len(seq_lens):
        # Fail before loading any model instead of an IndexError mid-run.
        raise ValueError(
            f'--max_iter needs 1 or {len(seq_lens)} values, got {len(max_iters)}'
        )

    # --seed is parsed on the command line; seed the global RNGs with it so
    # sampling is reproducible.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # os.path.dirname keeps a bare file name relative to the CWD (the previous
    # string split turned it into '/construct_150m.yaml' at the filesystem root).
    config_path = os.path.join(os.path.dirname(args.checkpoint_path), 'construct_150m.yaml')
    config = OmegaConf.load(config_path)
    OmegaConf.set_struct(config, False)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
    dplm_config = AutoConfig.from_pretrained(args.pretrained_model)

    torch.cuda.set_device(args.device)

    model = ConstructModel(config.model, dplm_config, tokenizer)

    # Load onto the CPU first; the model is moved to the GPU right after, so
    # the checkpoint is not materialised on the GPU it was saved from.
    model_state = torch.load(args.checkpoint_path + '.pth', map_location='cpu')
    model.load_state_dict(model_state)
    del model_state

    model = model.eval().cuda()
    device = next(model.parameters()).device

    esm_model, _ = esm.pretrained.load_model_and_alphabet(args.esm_path)
    esm_model = esm_model.to(args.device)
    esm_model.eval()

    os.makedirs(args.saveto, exist_ok=True)

    print('Generating ...')
    print(f'Key Settings: seq_sampling_strategy: {args.sampling_strategy},',
          f'unmasking_strategy: {args.unmasking_strategy}, seq_temp: {args.seq_temp},',
          f'struct_temp: {args.struct_temp}, seq_cfg: {args.seq_cfg}, seq_cfg_schedule: {args.seq_cfg_schedule},',
          f'struct_cfg: {args.struct_cfg}, struct_cfg_schedule: {args.struct_cfg_schedule}')
    for seq_len, max_iter in zip(seq_lens, max_iters):
        batch = get_initial(args, seq_len, tokenizer, device)

        output_seq_tokens, _, output_struct_embeds = model.generate(
            batch=batch,
            max_iter=max_iter,
            seq_temp=args.seq_temp,
            struct_temp=float(args.struct_temp),
            sampling_strategy=args.sampling_strategy,
            unmasking_strategy=args.unmasking_strategy,
            seq_cfg=float(args.seq_cfg),
            seq_cfg_schedule=args.seq_cfg_schedule,
            struct_cfg=float(args.struct_cfg),
            struct_cfg_schedule=args.struct_cfg_schedule,
        )

        esm_scores = _score_with_esm(esm_model, output_seq_tokens)
        print(f'seq_len: {seq_len}, max_iter: {max_iter}, esm score: {esm_scores.mean().item():.3f}')

        # Drop the spaces present in the decoded strings.
        output_seq_results = [
            seq.replace(' ', '')
            for seq in tokenizer.batch_decode(output_seq_tokens, skip_special_tokens=True)
        ]
        for seq_idx, seq in enumerate(output_seq_results):
            print(f"{esm_scores[seq_idx]}: {seq}")

        run_name = f"iter_{max_iter}_L_{seq_len}"
        _write_fasta(
            os.path.join(args.saveto, f"{run_name}.fasta"),
            output_seq_results, seq_len, esm_scores,
        )
        _save_struct_latents(
            os.path.join(args.saveto, 'struct_latent', run_name),
            output_struct_embeds, output_seq_results,
        )

        torch.cuda.empty_cache()
    
    
def main():
    """Command-line entry point: parse options and hand them to ``generate``."""
    cli = argparse.ArgumentParser()

    # Options are registered in this exact order so ``--help`` is unchanged.
    # struct_temp / seq_cfg / struct_cfg stay strings here; ``generate``
    # converts them with float().
    option_table = (
        ('--seed', {'type': int, 'default': 42}),
        ('--pretrained_model', {'type': str, 'default': ''}),
        ('--esm_path', {'type': str, 'default': ''}),
        ('--checkpoint_path', {'type': str, 'default': ''}),
        ('--num_seqs', {'type': int, 'default': 100}),
        ('--seq_lens', {'default': [100], 'nargs': '*', 'type': int}),
        ('--saveto', {'type': str, 'default': ''}),
        ('--seq_temp', {'type': float, 'default': 1.0}),
        ('--struct_temp', {'type': str, 'default': '0.35'}),
        ('--sampling_strategy', {'type': str, 'default': 'vanilla'}),
        ('--unmasking_strategy', {'type': str, 'default': 'deterministic'}),
        ('--seq_cfg', {'type': str, 'default': '1.0'}),
        ('--seq_cfg_schedule', {'type': str, 'default': 'constant'}),
        ('--struct_cfg', {'type': str, 'default': '1.0'}),
        ('--struct_cfg_schedule', {'type': str, 'default': 'constant'}),
        ('--device', {'type': int, 'default': 0}),
        ('--max_iter', {'default': [100], 'nargs': '*', 'type': int}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    generate(cli.parse_args())


if __name__ == '__main__':
    main()
