# Modified from DPLM series repo:
#     DPLM-2: https://github.com/bytedance/dplm

import os

from omegaconf import OmegaConf, DictConfig

import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from pytorch_lightning import seed_everything

from transformers import AutoConfig, AutoTokenizer

from peft import PeftModel, LoraConfig, get_peft_model
from byprot.models.dplm import DiffusionProteinLanguageModel
import esm

from models.construct import ConstructModel
from utils.dataset import ProteinDataset, ConstructCollater
from utils.args import parse_training_args
from utils.training import get_scheduler, eval_struct_embeds
from utils.generation import get_initial

def train_one_epoch(args, model, loader, optimizer, scheduler, scaler):
    """Run one mixed-precision training epoch and return the mean batch loss.

    Args:
        args: Parsed training arguments. Reads ``device``,
            ``complementary_masking``, ``self_mixup``, ``cfg_training``,
            ``seq_reweighting``, ``struct_reweighting``, ``add_orth_term``,
            ``seq_struct_ratio``, ``orth_term_scale`` and ``parallel``.
        model: Callable module whose forward pass returns a dict containing
            ``seq_ce_loss`` and ``struct_diff_loss``, plus
            ``mixup_seq_ce_loss`` / ``mixup_struct_diff_loss`` when
            ``args.self_mixup`` is set and ``orth_loss`` when
            ``args.add_orth_term`` is set.
        loader: Sized iterable of batch dicts holding ``input_ids``,
            ``struct_latent``, ``residue_index`` and ``input_mask`` tensors.
        optimizer: Optimizer stepped through ``scaler`` after every batch.
        scheduler: Learning-rate scheduler, stepped once at the end of the epoch.
        scaler: ``GradScaler`` used to scale the loss before ``backward``.

    Returns:
        The mean loss over the batches whose loss was computed, or ``0.0``
        when no batch contributed (empty loader or every batch skipped).
    """
    model.train()
    loss_total = 0.0
    # Only batches whose loss was actually accumulated are averaged, so that
    # skipped batches do not drag the reported mean towards zero.
    num_counted = 0
    num_steps = len(loader)

    for i, batch in enumerate(loader):

        input_seq_ids = batch.pop('input_ids').to(args.device)
        input_struct_embeds = batch.pop('struct_latent').to(args.device)
        position_ids = batch.pop('residue_index').to(args.device)

        try:
            with autocast():
                model_out = model(input_seq_ids=input_seq_ids, input_struct_embeds=input_struct_embeds,
                                  complementary_masking=args.complementary_masking, self_mixup=args.self_mixup, cfg_training=args.cfg_training,
                                  seq_reweighting=args.seq_reweighting, struct_reweighting=args.struct_reweighting,
                                  position_ids=position_ids, add_orth_term=args.add_orth_term)
                seq_ce_loss = model_out['seq_ce_loss']
                struct_diff_loss = model_out['struct_diff_loss']
                if args.self_mixup:
                    mixup_seq_ce_loss = model_out['mixup_seq_ce_loss']
                    mixup_struct_diff_loss = model_out['mixup_struct_diff_loss']
                    loss = args.seq_struct_ratio * (seq_ce_loss + mixup_seq_ce_loss) + (struct_diff_loss + mixup_struct_diff_loss)
                else:
                    loss = args.seq_struct_ratio * seq_ce_loss + struct_diff_loss
                if args.add_orth_term:
                    orth_loss = model_out['orth_loss']
                    loss += args.orth_term_scale * orth_loss

                if args.parallel:
                    # In parallel mode each loss is reduced with .mean() before use
                    # (presumably one value per replica — confirm against the model).
                    seq_ce_loss = seq_ce_loss.mean()
                    struct_diff_loss = struct_diff_loss.mean()
                    if args.self_mixup:
                        mixup_seq_ce_loss = mixup_seq_ce_loss.mean()
                        mixup_struct_diff_loss = mixup_struct_diff_loss.mean()
                    if args.add_orth_term:
                        orth_loss = orth_loss.mean()
                    loss = loss.mean()

                # Mean of the per-row sums of 'input_mask'; used for logging only.
                seq_len = torch.mean((batch['input_mask']).sum(axis=1).float())

                loss_total += loss.item()
                num_counted += 1

            optimizer.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if (i + 1) % 10 == 0:
                # Assemble the breakdown from the terms that are active, so every
                # enabled loss is reported and the parentheses always balance.
                terms = [f"ce_loss: {seq_ce_loss.item():.3f}"]
                if args.self_mixup:
                    terms.append(f"mixup_ce_loss: {mixup_seq_ce_loss.item():.3f}")
                terms.append(f"diff_loss: {struct_diff_loss.item():.3f}")
                if args.self_mixup:
                    terms.append(f"mixup_diff_loss: {mixup_struct_diff_loss.item():.3f}")
                if args.add_orth_term:
                    terms.append(f"orth_loss: {orth_loss.item():.3f}")
                print(f"step [{i+1}/{num_steps}], seq_len: {seq_len:.3f}, loss: {loss.item():.3f} ({', '.join(terms)})")

        except RuntimeError as e:
            if 'out of memory' in str(e):
                print(f'CUDA OOM, skipping batch [{i+1}/{num_steps}]')
                # Drop any partially accumulated gradients before releasing the cache.
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad = None
                torch.cuda.empty_cache()
                continue
            else:
                # Best-effort: other runtime errors are reported and the epoch goes on.
                print("Uncaught error " + str(e))

    scheduler.step()

    return loss_total / num_counted if num_counted else 0.0

def validate_by_generation(args, tokenizer, model, epoch):
    """Generate samples for each evaluation length, score them and save them.

    For every entry of ``args.eval_seq_lens`` this calls ``model.generate``,
    scores the generated tokens with an ESM model loaded from
    ``args.ckpt_root``/``args.esm_path``, evaluates the generated structure
    embeddings with ``eval_struct_embeds``, and writes a FASTA file plus one
    ``.npz`` file per sample under ``args.saveto``.

    Args:
        args: Parsed arguments. Reads the generation settings (``max_iter``,
            ``seq_temp``, ``struct_temp``, ``sampling_strategy``,
            ``unmasking_strategy``, ``seq_cfg``, ``struct_cfg`` and their
            schedules) as well as ``device``, ``date``, ``saveto`` and
            ``num_seqs``.
        tokenizer: Tokenizer used to build the initial batch and to decode
            the generated tokens.
        model: Model exposing ``generate``; it is switched to eval mode.
        epoch: Epoch index, used only in the output directory name.

    Returns:
        A tuple ``(val_seq_scores, val_esm_scores, val_fids,
        val_pos_wise_fids)`` of lists with one entry per evaluated length.
    """
    torch.cuda.empty_cache()

    model.eval()
    esm_model, _ = esm.pretrained.load_model_and_alphabet(f'{args.ckpt_root}/{args.esm_path}')
    esm_model = esm_model.to(args.device)
    esm_model.eval()

    val_seq_scores = []
    val_esm_scores = []
    val_fids = []
    val_pos_wise_fids = []

    # Pre-bind the generation outputs so the cleanup at the end stays valid
    # even when args.eval_seq_lens is empty.
    output_seq_tokens = output_seq_scores = output_struct_embeds = None

    for len_idx, seq_len in enumerate(args.eval_seq_lens):
        max_iter = args.max_iter[len_idx]
        batch = get_initial(args, seq_len, tokenizer, args.device)

        # Gradients are not needed for validation sampling
        # (assumes model.generate does not rely on autograd — TODO confirm).
        with torch.no_grad(), torch.cuda.amp.autocast():
            (
                output_seq_tokens,
                output_seq_scores,
                output_struct_embeds
            ) = model.generate(
                batch=batch,
                max_iter=max_iter,
                seq_temp=args.seq_temp,
                struct_temp=args.struct_temp,
                sampling_strategy=args.sampling_strategy,
                unmasking_strategy=args.unmasking_strategy,
                seq_cfg=args.seq_cfg, seq_cfg_schedule=args.seq_cfg_schedule,
                struct_cfg=args.struct_cfg, struct_cfg_schedule=args.struct_cfg_schedule
            )

        # Drop the first and last positions (presumably the BOS/EOS special
        # tokens — confirm against the tokenizer).
        output_seq_scores = output_seq_scores[:, 1:-1]
        output_struct_embeds = output_struct_embeds[:, 1:-1, :]

        with torch.no_grad():
            esm_scores = esm_model(output_seq_tokens, repr_layers=[])['logits']
        esm_scores = torch.nn.functional.softmax(esm_scores, dim=-1)
        # Probability that ESM assigns to each generated token.
        esm_scores = torch.gather(esm_scores, dim=-1, index=output_seq_tokens.unsqueeze(-1))
        # Squeeze only the trailing gather dimension: a bare squeeze() would also
        # collapse the batch dimension when a single sequence is generated and
        # break the per-sample indexing below.
        esm_scores = esm_scores[:, 1:-1].mean(dim=1).squeeze(-1)
        esm_scores = esm_scores.cpu().numpy()

        fid, pos_wise_fid = eval_struct_embeds(args, output_struct_embeds)

        print(f'generated {args.num_seqs} sequences of length {seq_len} with max_iter {max_iter}')
        print(f'mean seq_scores: {output_seq_scores.mean().item():.3f}; mean esm_scores: {esm_scores.mean().item():.3f}')
        val_seq_scores.append(output_seq_scores.mean().item())
        val_esm_scores.append(esm_scores.mean().item())
        print(f'mean fid: {fid:.3f}, mean position-wise fid: {pos_wise_fid:.3f}')
        val_fids.append(fid)
        val_pos_wise_fids.append(pos_wise_fid)

        # The tokenizer decodes with spaces between tokens; remove them.
        output_seq_results = [''.join(seq.split(' ')) for seq in tokenizer.batch_decode(output_seq_tokens, skip_special_tokens=True)]

        dir_name = f'{args.date}_grid_search/{args.date}-epoch={epoch}-{args.sampling_strategy}-seq_temp={args.seq_temp}-seq_cfg={args.seq_cfg}-{args.unmasking_strategy}-struct_temp={args.struct_temp}-struct_cfg={args.struct_cfg}'
        save_path = os.path.join(args.saveto, dir_name)
        os.makedirs(save_path, exist_ok=True)
        saveto_name = os.path.join(save_path, f"iter_{max_iter}_L_{seq_len}")
        # Context manager so the file is closed even if a write fails.
        with open(f"{saveto_name}.fasta", 'w') as fp_save:
            for sample_idx, seq in enumerate(output_seq_results):
                fp_save.write(f">SEQUENCE_{sample_idx}_L={seq_len}_S={esm_scores[sample_idx]:.3f}\n")
                fp_save.write(f"{seq}\n")

        saveto_name = os.path.join(save_path, 'struct_latent', f"iter_{max_iter}_L_{seq_len}")
        os.makedirs(saveto_name, exist_ok=True)
        for sample_idx, struct_embed in enumerate(output_struct_embeds):
            # NOTE(review): 0.1875 looks like a latent scaling factor being undone
            # before saving — confirm against the structure encoder.
            struct_embed = struct_embed / 0.1875
            np.savez(f"{saveto_name}/struct_{sample_idx}.npz", latent=struct_embed.cpu().detach().numpy(), seq=output_seq_results[sample_idx])

    # Release the references before emptying the CUDA cache.
    del esm_model
    del output_seq_scores
    del output_seq_tokens
    del output_struct_embeds
    torch.cuda.empty_cache()

    return val_seq_scores, val_esm_scores, val_fids, val_pos_wise_fids


def _load_pretrained_dplm_weights(model, dplm_path):
    """Load the pre-trained DPLM weights found at ``dplm_path`` into ``model`` (non-strict)."""
    dplm_state_dict = DiffusionProteinLanguageModel.from_pretrained(dplm_path).state_dict()
    # Every 'net.' occurrence is removed from the DPLM keys before loading.
    model.load_state_dict({k.replace('net.', ''): v for k, v in dplm_state_dict.items()}, strict=False)
    del dplm_state_dict


def _save_lora_checkpoint(peft_model, ckpt_dir):
    """Save the LoRA adapter and the separately stored custom modules into ``ckpt_dir``."""
    os.makedirs(ckpt_dir, exist_ok=True)
    peft_model.save_pretrained(ckpt_dir)
    wrapped = peft_model.base_model.model
    torch.save({
        'struct_mask_token': wrapped.struct_mask_token.data,
        'z_proj': wrapped.esm.z_proj.state_dict(),
        'diffloss': wrapped.diffloss.state_dict(),
    }, f'{ckpt_dir}/custom_params.pth')


def _save_trainer_state(optimizer, scheduler, epoch, path):
    """Save the epoch index together with the optimizer and scheduler states to ``path``."""
    torch.save({
        "epoch": epoch,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }, path)


def main():
    """Train the model end to end: build data and model, train, checkpoint, validate.

    Depending on the parsed arguments the model is (a) fine-tuned with LoRA on
    top of pre-trained DPLM weights, (b) trained from scratch, or (c) fully
    fine-tuned from pre-trained DPLM weights or a previous checkpoint. After
    every epoch the latest model, optimizer and scheduler states are saved;
    every ``args.ckpt_period`` epochs an extra checkpoint is written and
    ``validate_by_generation`` is run. With LoRA, the adapter is finally merged
    into the base model and the merged weights are saved.
    """
    args = parse_training_args()
    config = OmegaConf.load(args.yaml_config)
    OmegaConf.set_struct(config, False)
    seed_everything(42, workers=True)

    # create dataset and dataloader
    dataset = ProteinDataset(args)

    pretrained_path = f'{args.ckpt_root}/{args.pretrained_model}'
    prev_ckpt = f'{args.ckpt_root}/{args.prev_ckpt_path}'

    tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
    construct_collate_fn = ConstructCollater(tokenizer)
    train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True,
                                  num_workers=args.num_workers, collate_fn=construct_collate_fn)

    # load pre-trained model states and create our model
    dplm_config = AutoConfig.from_pretrained(pretrained_path)

    model = ConstructModel(config.model, dplm_config, tokenizer)
    if args.lora:
        print('Loading pre-trained DPLM model ...')
        _load_pretrained_dplm_weights(model, pretrained_path)

        if not args.continuous_training:
            # initialize LoRA and set specific parameters trainable
            print('Initializing LoRA ...')
            lora_config = LoraConfig(
                r=config.model.lora.lora_rank,
                target_modules=config.model.lora.lora_target_module.split(','),
                modules_to_save=config.model.lora.modules_to_save.split(','),
                inference_mode=False,
                lora_alpha=config.model.lora.lora_alpha,
                lora_dropout=config.model.lora.lora_dropout,
            )
            model = get_peft_model(model, lora_config)
        else:
            print(f'Loading pre-trained LoRA modules from {prev_ckpt} ...')
            # The model is still on the CPU here (it is moved to args.device
            # below), so load onto the CPU like the other checkpoint loads.
            custom_params_dict = torch.load(f'{prev_ckpt}/custom_params.pth', map_location='cpu')
            model.struct_mask_token.data = custom_params_dict['struct_mask_token']
            model.esm.z_proj.load_state_dict(custom_params_dict['z_proj'], strict=True)
            model.diffloss.load_state_dict(custom_params_dict['diffloss'], strict=True)
            del custom_params_dict
            model = PeftModel.from_pretrained(model, prev_ckpt)

        print('Setting LoRA modules and our modules trainable ...')
        # set trainable parameters again: only the LoRA weights, the PEFT
        # 'modules_to_save' copies and our own modules are trained
        trainable_markers = ('lora', 'modules_to_save', 'struct_mask_token', 'z_proj', 'diffloss')
        for name, param in model.named_parameters():
            param.requires_grad = any(marker in name for marker in trainable_markers)
    elif args.from_scratch:
        print('Training from scratch, no pre-trained model loaded ...')
    else:
        if not args.continuous_training:
            print(f'Loading pre-trained DPLM model from {pretrained_path} ...')
            _load_pretrained_dplm_weights(model, pretrained_path)
        else:
            print(f'Loading fine-tuned DPLM model from {prev_ckpt}.pth ...')
            model_state = torch.load(f'{prev_ckpt}.pth', map_location='cpu')
            model.load_state_dict(model_state, strict=True)
            del model_state
        print('Preserving the full model trainable ...')

    model = model.to(args.device)

    # Keep a handle on the unwrapped model for saving and generation.
    model_without_dp = model
    if args.parallel:
        # NOTE(review): the device ids are hard-coded to GPUs 0 and 1.
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
        model_without_dp = model.module

    total_params = sum(p.numel() for p in model.parameters()) / 1e6
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    trainable_prop = trainable_params / total_params
    print(f"total params: {total_params:.1f} M; trainable params: {trainable_params:.1f} M; proportion of trainable params: {trainable_prop*100:.1f}%")

    # train the model
    # NOTE: this call used to pass a `filter(...)` object to `model.parameters`,
    # where it was taken as the (truthy) `recurse` flag and otherwise ignored,
    # so the optimizer has always received every parameter. That behaviour is
    # kept on purpose: frozen parameters never receive a gradient, and
    # optimizer states saved by earlier runs still match the parameter group
    # when resumed via `continuous_training_all`.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_init, betas=(0.9, 0.95), weight_decay=1e-2)
    scheduler = get_scheduler(args, optimizer)
    scaler = GradScaler(enabled=True)

    if args.continuous_training_all:
        print('Loading optimizer and scheduler states ...')
        prev_ckpt_dir = '/'.join(args.prev_ckpt_path.split('/')[:-1])
        other_state = torch.load(f'{args.ckpt_root}/{prev_ckpt_dir}/latest_other_state.pth', map_location='cpu')
        optimizer.load_state_dict(other_state['optimizer_state_dict'])
        scheduler.load_state_dict(other_state['scheduler_state_dict'])
        current_epoch = other_state['epoch'] + 1
        del other_state
    else:
        print('Initializing optimizer and scheduler states ...')
        current_epoch = 0

    print(f"Key settings: complementary_masking: {args.complementary_masking}, self_mixup: {args.self_mixup}, cfg_training: {args.cfg_training}, crop_longer_prot: {args.crop_longer_prot}, seq_reweighting: {args.seq_reweighting}, struct_reweighting: {args.struct_reweighting}, add_orth_term: {args.add_orth_term}")

    # Resolved up front so it is defined for every later use, including the
    # final export when the epoch loop below does not run at all.
    save_path = os.path.join(args.ckpt_root, 'construct', args.date)
    last_epoch = None

    for epoch in range(current_epoch, args.total_epochs):
        last_epoch = epoch

        current_lr = optimizer.param_groups[0]['lr']
        print(f'epoch {epoch} training start, current learning rate is {current_lr}')
        train_loss = train_one_epoch(args, model, train_dataloader, optimizer, scheduler, scaler)
        print(f"epoch {epoch}, train_loss: {train_loss:.3f}")

        # Always refresh the "latest" checkpoint and the trainer state.
        try:
            os.makedirs(save_path, exist_ok=True)

            if args.lora:
                _save_lora_checkpoint(model_without_dp, f'{save_path}/latest')
            else:
                torch.save(model_without_dp.state_dict(), save_path + '/latest.pth')
            _save_trainer_state(optimizer, scheduler, epoch, save_path + '/latest_other_state.pth')
            print('Latest Model states saved!')

        except RuntimeError as e:
            print("Uncaught error " + str(e))

        if (epoch + 1) % args.ckpt_period == 0:

            # Periodic, epoch-tagged checkpoint.
            try:
                if args.lora:
                    _save_lora_checkpoint(model_without_dp, f'{save_path}/epoch_{epoch}')
                else:
                    torch.save(model_without_dp.state_dict(), save_path + f'/epoch_{epoch}.pth')
                print(f'Model states saved at epoch {epoch}!')

            except RuntimeError as e:
                print("Uncaught error " + str(e))

            print(f'epoch {epoch} validation start')
            try:
                validate_by_generation(args, tokenizer, model_without_dp, epoch)
            except RuntimeError as e:
                print("Uncaught error " + str(e))

    if args.lora:
        if last_epoch is None:
            # Nothing was trained in this run (e.g. resumed past total_epochs),
            # so there is no epoch to tag the merged weights with.
            print('No epoch was run; skipping the merged LoRA export.')
        else:
            # Fold the LoRA adapter into the base weights and save the result.
            model_without_dp = model_without_dp.merge_and_unload()
            torch.save(model_without_dp.state_dict(), save_path + f'/epoch_{last_epoch}.pth')

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
