
import copy
import logging
import math
import os
from functools import partial

import torch
from openmmtools.testsystems import (AlanineDipeptideExplicit,
                                     AlanineDipeptideImplicit,
                                     AlanineDipeptideVacuum)
from tqdm import tqdm

import wandb
from datasets.aldp import construct_loader
from datasets.aldpSequenceDataset import construct_sequence_loader
from utils.parsing import parse_training_args
from utils.training import loss_function, test_epoch, train_epoch
from utils.utils import (ExponentialMovingAverage, get_model, get_optimizer_and_scheduler, save_yaml_file , get_single_conformation_energy_fn , save_prediction_pdb)
from visualizations.make_latent_space_tica import make_tica_plot

# Configure the root logger at import time (INFO level) and create the
# module-wide logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('Training')


def main_function():
    """Script entry point: build data, model and optimizer, then run training.

    Steps, in order:
      1. Pick the compute device and parse the command-line arguments.
      2. Optionally start a wandb run.
      3. Replace ``args.testsystem`` (a name) with the instantiated
         alanine-dipeptide test system.
      4. Build the data loaders (single-frame or sequence loader depending
         on ``args.no_propagator``), the model, the EMA tracker, the
         optimizer/scheduler and the loss function.
      5. Write the run parameters to ``<log_dir>/<run_name>/model_parameters.yml``.
      6. Hand everything to :func:`train`.

    Raises:
        KeyError: if ``args.testsystem`` is not one of ``vacuum``,
            ``implicit`` or ``explicit``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device: {device}')

    args = parse_training_args()

    # Initialize wandb
    if args.wandb:
        wandb.init(
            entity='entity',
            project=args.wandb_project,
            name=args.run_name,
            config=args,
        )

    # Swap the test-system name for an instance of the matching class.
    testsystems = {
        "vacuum": AlanineDipeptideVacuum,
        "implicit": AlanineDipeptideImplicit,
        "explicit": AlanineDipeptideExplicit,
    }
    args.testsystem = testsystems[args.testsystem](constraints=None)

    if args.no_propagator:
        train_loader, val_loader, coordinate_transform = construct_loader(args)
        # get reverse transform , here train_loader.dataset is a subset -> dataset.dataset is the actual dataset
        reverse_transform = train_loader.dataset.dataset.inverse_transform
    else:
        train_loader, val_loader, coordinate_transform, reverse_transform = construct_sequence_loader(args)

    args.feature_size = args.ns
    model = get_model(args, coordinate_transform=coordinate_transform)
    model.to(device)
    args.device = device

    numel = sum(p.numel() for p in model.parameters())
    logger.info(f'Model with {numel} parameters')

    if args.wandb:
        wandb.log({'num_parameters': numel})

    ema_weights = ExponentialMovingAverage(model.parameters(), decay=args.ema_rate)

    optimizer, scheduler = get_optimizer_and_scheduler(args, model)

    loss_fn = partial(
        loss_function,
        bond_width_weight=args.bond_width_weight,
        bond_angles_weight=args.bond_angles_weight,
        torsion_angles_weight=args.torsion_angles_weight,
        kl_weight=args.kl_weight,
        propagator_weight=args.propagator_weight,
        rescale_transform=reverse_transform,
    )

    # save parameters and model
    run_dir = os.path.join(args.log_dir, args.run_name)
    # Make sure the run directory exists before anything is written into it;
    # the checkpoints written during training need it as well.
    os.makedirs(run_dir, exist_ok=True)
    yaml_file_name = os.path.join(run_dir, 'model_parameters.yml')
    # 'testsystem' and 'device' hold live objects rather than plain values,
    # so they are left out of the YAML dump.
    excluded_keys = {'testsystem', 'device'}
    save_yaml_file(yaml_file_name, {k: v for k, v in vars(args).items() if k not in excluded_keys})

    if args.log_energy:
        # expensive initialization of energy computation, only do it if we need it
        args._compute_single_energy = get_single_conformation_energy_fn(args.testsystem.system, args.data_temperature)

    # train
    train(args, model, scheduler, train_loader, val_loader, optimizer, ema_weights, loss_fn, run_dir,
          coordinate_transform, reverse_transform)

def _log_epoch_losses(epoch, split, losses):
    """Log one line with every loss component for the given epoch and split name."""
    logger.info(
        f"Epoch {epoch} | {split} loss: {losses['loss']:.4f} | "
        f"Bond width loss: {losses['bond_width_loss']:.4f} | "
        f"Bond angles loss: {losses['bond_angles_loss']:.4f} | "
        f"Torsion angles loss: {losses['torsion_angles_loss']:.4f} | "
        f"KL divergence: {losses['kl_div']:.4f} | "
        f"Propagator loss: {losses['propagator_loss']:.4f} | "
        f"VAE rmsd: {losses['rmsd']:.4f} "
    )


def train(args, model, scheduler, train_loader, val_loader, optimizer, ema_weights, loss_fn, run_dir,
          coordinate_transform, inverse_feature_scaling):
    """Run the train/validation loop, write checkpoints and optional diagnostics.

    For every epoch: train, validate (on the EMA weights if ``args.use_ema``),
    log to wandb if enabled, save ``best_model.pt`` / ``best_ema_model.pt`` when
    the validation loss does not exceed the best seen so far, step the
    scheduler with the validation loss, and overwrite ``last_model.pt``.

    After the loop, optionally produce a TICA plot of the latent space
    (``args.log_latent_tica``) and PDB predictions for one validation batch
    (``args.save_pdb``).

    Args:
        args: parsed training arguments; reads ``num_epochs``, ``device``,
            ``use_ema``, ``wandb``, ``log_latent_tica``, ``no_vae`` and
            ``save_pdb``.
        model: the network being trained.
        scheduler: learning-rate scheduler stepped with the validation loss,
            or a falsy value to disable scheduling.
        train_loader: loader for the training split.
        val_loader: loader for the validation split.
        optimizer: optimizer whose state is stored in ``last_model.pt``.
        ema_weights: exponential-moving-average tracker for the parameters.
        loss_fn: loss callable forwarded to ``train_epoch`` / ``test_epoch``.
        run_dir: directory that receives all checkpoints.
        coordinate_transform: forwarded to the epoch functions and PDB export.
        inverse_feature_scaling: forwarded to the epoch functions and PDB export.
    """
    best_val_loss = math.inf
    best_epoch = 0

    # torch.save does not create missing directories.
    os.makedirs(run_dir, exist_ok=True)
    best_model_path = os.path.join(run_dir, 'best_model.pt')
    best_ema_model_path = os.path.join(run_dir, 'best_ema_model.pt')
    last_model_path = os.path.join(run_dir, 'last_model.pt')

    logger.info(f"Starting training for {args.num_epochs} epochs")
    for epoch in range(args.num_epochs):

        # Train Epoch
        train_losses = train_epoch(model, train_loader, optimizer, loss_fn, args.device, coordinate_transform,
                                   inverse_feature_scaling, args)
        _log_epoch_losses(epoch, 'Train', train_losses)

        # NOTE(review): ema_weights is not passed to train_epoch and no update
        # call is made in this function. Unless the average is updated
        # elsewhere, the EMA weights stay at their initial values - confirm.
        ema_weights.store(model.parameters())
        if args.use_ema:
            # load ema parameters into model for running validation
            ema_weights.copy_to(model.parameters())

        # Val Epoch
        val_losses = test_epoch(model, val_loader, loss_fn, args.device, coordinate_transform,
                                inverse_feature_scaling, args)
        _log_epoch_losses(epoch, 'Val', val_losses)

        # Snapshot the EMA weights (deep copy, because restore() below
        # overwrites the parameters in place), then put the raw training
        # weights back into the model.
        if not args.use_ema:
            ema_weights.copy_to(model.parameters())
        ema_state_dict = copy.deepcopy(model.state_dict())
        ema_weights.restore(model.parameters())

        # Raw (non-EMA) weights; written to disk below before training resumes.
        state_dict = model.state_dict()

        # log to wandb
        if args.wandb:
            # log train losses
            wandb.log({"train_" + k: v for k, v in train_losses.items() if not math.isnan(v)}, step=epoch + 1)
            # log val losses
            wandb.log({"val_" + k: v for k, v in val_losses.items() if not math.isnan(v)}, step=epoch + 1)

            wandb.log({"current_lr": optimizer.param_groups[0]['lr']}, step=epoch + 1)

        # save best model (a NaN validation loss never compares as better)
        if val_losses['loss'] <= best_val_loss:
            best_val_loss = val_losses['loss']
            best_epoch = epoch
            torch.save(state_dict, best_model_path)
            torch.save(ema_state_dict, best_ema_model_path)

        # make lr scheduler step
        if scheduler:
            scheduler.step(val_losses['loss'])

        torch.save({
            'epoch': epoch,
            'model': state_dict,
            'optimizer': optimizer.state_dict(),
            'ema_weights': ema_weights.state_dict(),
        }, last_model_path)

    logger.info(f"Best val loss: {best_val_loss:.4f} at epoch {best_epoch}")

    if args.log_latent_tica:
        logger.info('Logging latent tica')
        if not os.path.exists(best_model_path):
            # Happens when no epoch produced a best checkpoint (zero epochs
            # or a NaN validation loss in every epoch).
            logger.warning(f"No best model checkpoint at {best_model_path}; skipping latent tica plot")
        else:
            # The latent space is computed on the CPU, so move the model there
            # and load the checkpoint straight into CPU memory.
            model = model.to('cpu')
            model.load_state_dict(torch.load(best_model_path, map_location='cpu'))
            # Inference only: make sure the model is in eval mode.
            model.eval()

            latent_space_batches = []
            with torch.no_grad():
                for data, *_ in tqdm(val_loader):
                    latent_space = model.encoder(data)
                    if args.no_vae:
                        # take mu output (if --no_vae is set this is the latent space - logvar is not trained)
                        latent_space = latent_space[0]
                    else:
                        # reparameterize if vae
                        latent_space = model.reparameterize(latent_space[0], latent_space[1])

                    latent_space_batches.append(latent_space)
                latent_space = torch.cat(latent_space_batches, dim=0)

            make_tica_plot(latent_space.numpy(), 1, 2, args)

    if args.save_pdb:
        logger.info('Saving pdb trajectories for 1 batch')

        # NOTE(review): the model passed here holds the best checkpoint on the
        # CPU if the tica branch above ran, otherwise the last-epoch weights on
        # args.device - confirm save_prediction_pdb copes with both.
        first_val_batch = next(iter(val_loader))
        save_prediction_pdb(model, first_val_batch[0], coordinate_transform, args, run_dir, inverse_feature_scaling)


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main_function()

