"""
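Main training script for the VAE dynamics model: parses command-line arguments,
trains the VAE, then reloads the saved checkpoint and writes sample plots.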
"""

import torch
import argparse
import os
from torch.utils.data import DataLoader, TensorDataset
import warnings

from VAE_model import VAE
from trainer_VAE_dynamics import Trainer
from utils import (
    initialize,
    process_data, 
)

############### Arguments ###############

# Command-line interface. Every option below is read from `args` further down
# in this script.
parser = argparse.ArgumentParser(description='VAE_dynamics_main')

# --- I/O and experiment tracking ---
parser.add_argument('--save_folder', type=str, default='./models', help='Folder to save results')
parser.add_argument('--model_name', type=str, default='test', help='Supply name for current model')
parser.add_argument('--wandb_user', type=str, default='', help='Wandb user name')
parser.add_argument('--wandb_project', type=str, default='VAE_dynamics', help='Wandb project name')
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')

# --- Data ---
parser.add_argument('--data_file_path', type=str, default="", help='Path to data file (pandas dataframe, .pickle or DE Shaw data), not needed for nmr data')
parser.add_argument('--protein', type=str, default='1unc', help='Protein name')
parser.add_argument('--pdb_file_path', type=str, default='./data/1unc.pdb', help='Path to pdb file for this protein')
parser.add_argument('--crop', type=int, default=None, help='Optionally crop protein to first specified number of amino acids')
parser.add_argument('--subsample_data', type=int, default=1, help='Interval at which to subsample trajectories')
parser.add_argument('--sin_cos_input', action='store_true', help='Use sines and cosines of kappa as input to the encoder')
parser.add_argument('--sin_cos_output', action='store_true', help='Use sines and cosines of kappa as output of the decoder')

# --- Architecture ---
parser.add_argument('--latent_features', type=int, default=8, help='Number of latent features')
parser.add_argument('--encoder_list', nargs="+", type=int, default=[60, 30, 15], help='Encoder sizes (large to small)')
parser.add_argument('--decoder_list', nargs="+", type=int, default=[15, 30, 60], help='Decoder sizes (small to large)')

# --- Optimisation schedule ---
parser.add_argument('--epochs', type=int, default=1, help='Number of epochs')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--num_warm_up_KL', type=int, default=1000, help='Number of epochs to warm up KL-divergence in ELBO')
parser.add_argument('--num_mean_only', type=int, default=200, help='Number of epochs to train only mean kappa')
parser.add_argument('--num_samples_z', type=int, default=5, help='Number of latent space samples')
parser.add_argument('--num_samples_k', type=int, default=5, help='Number of kappa distribution samples (at each step!)')

# --- Prior ---
parser.add_argument('--a_start', type=float, default=100., help='Starting value for precomputed prior scaling factor a')
parser.add_argument('--fix_a', action='store_true', help='Fix a at a_start during training')
parser.add_argument('--scale_prior', type=float, default=1.0, help='Value to scale the prior with')
parser.add_argument('--predict_prior', action='store_true', help='Predict prior instead of precomputing and scaling with a')

# --- Likelihood, evaluation and auxiliary losses ---
parser.add_argument('--no_eval', action='store_true', help='Disable eval mode for sampling')
# nargs='+' combined with choices: every supplied value must be 'kappa' or 'x'.
parser.add_argument('--ll', choices=['kappa', 'x'], type=str, default=['x'], help="Choose one or multiple loglikelihoods from ['kappa', 'x']", nargs='+')
parser.add_argument('--ll_every_layer', action='store_true', help='NLL (weighted) on every layer')
parser.add_argument('--superpose', action='store_true', help='Superpose structures before training (experimental feature)')
parser.add_argument('--lambda_aux', choices=['mae', 'mse', 'none'], type=str, default='none', help="Choose auxiliary loss from ['mae', 'mse', 'none']")
parser.add_argument('--lambda_aux_weight_start', type=float, default=0.01, help='Starting value weight on auxiliary loss')
parser.add_argument('--fix_lambda_aux_weight', action='store_true', help='Fix auxiliary loss weight during training')
parser.add_argument('--fluctuation_aux', choices=['cm_prior', 'cm_low', 'none'], type=str, default='none', help="Choose auxiliary loss from ['cm_prior', 'cm_low', 'none']")
parser.add_argument('--fluctuation_aux_weight_start', type=float, default=0.01, help='Starting value weight on auxiliary loss')
parser.add_argument('--fix_fluctuation_aux_weight', action='store_true', help='Fix auxiliary loss weight during training')
parser.add_argument('--allow_negative_lambda', action='store_true', help='Allow lagrange multipliers to become negative')
parser.add_argument('--fluctuation_steps', type=int, default=1, help='Number of steps in hierarchical VAE')

args = parser.parse_args()

# Superposition is only applied to the input structures. The VAE itself is
# always built with superpose=False further below.
if args.superpose:
    warnings.warn("Superposition only done for input structures, not during training (backprop challenges).")

############### Initializing ###############

# Resolve the compute device, a GPU flag and the output folder from the CLI
# arguments. (`use_gpu` is not referenced elsewhere in this script.)
(device, use_gpu, save_folder) = initialize(args)

############### Data ###############

# process_data yields 13 values. Each is bound to a module-level name that the
# model, training and evaluation sections below read.
(
    kappa, kappa_train, kappa_val,
    coords, coords_train, coords_val,
    coords_pNeRF, bond_lengths_pNeRF, kappa_prior,
    coords_ref, top, top_ref, ind_train,
) = process_data(args)

############### Model ###############

# Model parameters
latent_features = args.latent_features
encoder_sizes = args.encoder_list
decoder_sizes = args.decoder_list

# Training parameters
num_epochs = args.epochs
lr = args.lr
batch_size = args.batch_size
num_warm_up_KL = args.num_warm_up_KL
num_mean_only = args.num_mean_only

# Model
# argparse `choices` already restricts these values. The asserts guard against
# the two lists drifting apart if either is edited.
assert args.fluctuation_aux in ['cm_prior', 'cm_low', 'none'], "Invalid fluctuation auxiliary loss"
assert args.lambda_aux in ['mae', 'mse', 'none'], "Invalid lambda auxiliary loss"

# A disabled auxiliary loss ('none') always gets a starting weight of 0.0 and is
# marked as fixed, regardless of the corresponding CLI flags.
fluctuation_aux_weight_start = args.fluctuation_aux_weight_start if args.fluctuation_aux != 'none' else 0.0
fix_fluctuation_aux_weight = args.fix_fluctuation_aux_weight if args.fluctuation_aux != 'none' else True
lambda_aux_weight_start = args.lambda_aux_weight_start if args.lambda_aux != 'none' else 0.0
fix_lambda_aux_weight = args.fix_lambda_aux_weight if args.lambda_aux != 'none' else True

# The three lists are index-aligned: position 0 = fluctuation loss, position 1 = lambda loss.
aux_weight_start = [fluctuation_aux_weight_start, lambda_aux_weight_start]
fix_aux_weight = [fix_fluctuation_aux_weight, fix_lambda_aux_weight]
aux_loss = [args.fluctuation_aux, args.lambda_aux]

# NB: superpose set to False manually (superposition is only applied to the
# input structures, see the --superpose warning above).
vae = VAE(latent_features, encoder_sizes, decoder_sizes, coords.shape[1], bond_lengths_pNeRF, kappa_prior,
          args.predict_prior, a_start=args.a_start, fix_a=args.fix_a, scale_prior=args.scale_prior,
          ll=args.ll, aux_loss=aux_loss, aux_weight_start=aux_weight_start, fix_aux_weight=fix_aux_weight,
          allow_negative_lambda=args.allow_negative_lambda, steps=args.fluctuation_steps,
          ll_every_layer=args.ll_every_layer, superpose=False, coords_ref=coords_ref, top_ref=top_ref,
          # Decoder output flag comes from --sin_cos_output. It was previously
          # wired to --sin_cos_input, which made --sin_cos_output a no-op.
          sin_cos_in=args.sin_cos_input, sin_cos_out=args.sin_cos_output)
if args.fluctuation_aux == 'cm_prior':
    # Variance over the first axis of the training coordinates, averaged over the
    # last axis. Assumes coords is (frames, atoms, 3) -- TODO confirm.
    vae.fluct_prior = coords[ind_train, :, :].var(dim=0).mean(dim=-1).to(device)
# NOTE(review): `vae` itself is not moved to `device` here. Presumably VAE or
# Trainer handles placement -- confirm.
opt = torch.optim.Adam(vae.parameters(), lr=lr)
print(f"Model: {args.model_name}")
print(f"Auxiliary fluctuation loss: {args.fluctuation_aux}")
print(f"Auxiliary lambda loss: {args.lambda_aux}")
print(vae)


############### Training ###############

# Data
# Drop the final batch only when it would contain exactly one sample. A size-1
# batch is presumably problematic for the model -- TODO confirm the reason.
drop_last_train = len(kappa_train) % batch_size == 1
dataset_train = TensorDataset(kappa_train, coords_train)
data_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, drop_last=drop_last_train)

drop_last_val = len(kappa_val) % batch_size == 1
dataset_val = TensorDataset(kappa_val, coords_val)
data_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, drop_last=drop_last_val)

# Train. The checkpoint path is handed to Trainer. The evaluation section
# below reloads 'model_state_dict' from this same path.
save_loc = os.path.join(save_folder, args.model_name + ".pt")
trainer = Trainer(vae, opt, data_train, data_val, num_epochs, num_warm_up_KL, num_mean_only, save_loc)
trainer.train()

# Evaluate: rebuild a fresh VAE with the trained model's configuration, load the
# saved checkpoint into it, then generate sample plots.
# no_grad only wraps the rebuild/load. get_plots below runs outside it.
with torch.no_grad():
    trainer.opt = None  # optimizer is not used past this point in the script
    prev = trainer.vae
    # Trailing options are passed by keyword (the same names used for the
    # original construction) so the rebuild cannot silently mis-map arguments
    # if the VAE signature order changes.
    trainer.vae = VAE(prev.latent_features, prev.encoder_sizes, prev.decoder_sizes,
                      prev.length, prev.bond_lengths, prev.prior, prev.predict_prior,
                      a_start=prev.a_start, fix_a=prev.fix_a, scale_prior=prev.scale_prior,
                      ll=prev.ll, aux_loss=prev.aux_loss, aux_weight_start=prev.aux_weight_start,
                      fix_aux_weight=prev.fix_aux_weight, allow_negative_lambda=prev.allow_negative_lambda,
                      steps=prev.steps, ll_every_layer=prev.ll_every_layer, superpose=prev.superpose,
                      coords_ref=prev.coords_ref, top_ref=prev.top_ref,
                      sin_cos_in=prev.sin_cos_in, sin_cos_out=prev.sin_cos_out)
    del prev
    if args.fluctuation_aux == 'cm_prior':
        trainer.vae.fluct_prior = coords[ind_train, :, :].var(dim=0).mean(dim=-1).to(device)
    trainer.vae.mean_only = False
    save_loc = os.path.join(save_folder, args.model_name + ".pt")
    # Load onto CPU first, then move the whole model to the target device.
    checkpoint = torch.load(save_loc, map_location=torch.device('cpu'))
    trainer.vae.load_state_dict(checkpoint['model_state_dict'])
    trainer.vae = trainer.vae.to(device)

# Named eval_mode (not `eval`) to avoid shadowing the builtin.
eval_mode = not args.no_eval
print(f"Eval mode set to {eval_mode}")
trainer.get_plots(kappa, coords, coords_pNeRF, num_samples_z=args.num_samples_z, num_samples_k=args.num_samples_k, batch_size=batch_size,
                  topology=top, save=os.path.join(save_folder, args.model_name + "_samples.png"),
                  model_name=args.model_name, eval=eval_mode)