"""
Resume training of the VAE dynamics model from its last saved checkpoint, then reload the saved weights and produce evaluation sample plots.
"""

import torch
import argparse
import os
from torch.utils.data import DataLoader, TensorDataset
import yaml

from VAE_model import VAE
from trainer_VAE_dynamics import Trainer
from utils import (
    initialize,
    process_data,
)

############### Arguments ###############

parser = argparse.ArgumentParser(description='VAE_dynamics_resume_training')

# The config file is mandatory: every model/training hyperparameter used below
# (model_name, lr, batch_size, ...) is read from it rather than from the CLI.
parser.add_argument('--config_file_path', type=str, required=True, help='Path to config yaml file')
parser.add_argument('--extra_epochs', type=int, default=1000, help='Number of additional epochs')
parser.add_argument('--wandb_user', type=str, default='', help='Wandb user name')
parser.add_argument('--wandb_project', type=str, default='VAE_dynamics', help='Wandb project name')
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')

args = parser.parse_args()

# Merge the saved run configuration into `args`. Each top-level entry must be a
# mapping with a 'value' field (NOTE(review): this looks like the layout of an
# exported wandb config -- confirm against whatever writes these files).
# Keys containing "wandb" are skipped, so they never override the --wandb_* /
# --no_wandb options parsed above.
with open(args.config_file_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
if not isinstance(config, dict):
    # An empty file loads as None; fail with a clear CLI error instead of an
    # AttributeError on `.items()`.
    parser.error(f"config file {args.config_file_path!r} does not contain a YAML mapping")
for key, value in config.items():
    if "wandb" not in key:
        setattr(args, key, value['value'])

############### Initializing ###############

# Resolve the compute device and the folder where checkpoints/plots live.
# `use_gpu` is returned by utils.initialize but not used in this script.
(device, use_gpu, save_folder) = initialize(args)

############### Data ###############

# utils.process_data returns 13 items; unpack them in the order it yields them.
(
    kappa,
    kappa_train,
    kappa_val,
    coords,
    coords_train,
    coords_val,
    coords_pNeRF,
    bond_lengths_pNeRF,
    kappa_prior,
    coords_ref,
    top,
    top_ref,
    ind_train,
) = process_data(args)

############### Model ###############

# Model parameters
latent_features = args.latent_features
encoder_sizes = args.encoder_list
decoder_sizes = args.decoder_list

# Training parameters
num_epochs = args.extra_epochs  # "--extra_epochs": number of additional epochs
lr = args.lr
batch_size = args.batch_size
num_warm_up_KL = args.num_warm_up_KL
num_mean_only = args.num_mean_only

# Auxiliary losses: when an auxiliary loss is set to 'none', its start weight is
# forced to 0.0 and its fix-weight flag is forced to True.
fluctuation_aux_weight_start = args.fluctuation_aux_weight_start if args.fluctuation_aux != 'none' else 0.0
fix_fluctuation_aux_weight = args.fix_fluctuation_aux_weight if args.fluctuation_aux != 'none' else True
lambda_aux_weight_start = args.lambda_aux_weight_start if args.lambda_aux != 'none' else 0.0
fix_lambda_aux_weight = args.fix_lambda_aux_weight if args.lambda_aux != 'none' else True

# Per-auxiliary-loss settings, always ordered [fluctuation, lambda].
aux_weight_start = [fluctuation_aux_weight_start, lambda_aux_weight_start]
fix_aux_weight = [fix_fluctuation_aux_weight, fix_lambda_aux_weight]
aux_loss = [args.fluctuation_aux, args.lambda_aux]

# Model
vae = VAE(latent_features, encoder_sizes, decoder_sizes, coords.shape[1], bond_lengths_pNeRF, kappa_prior,
          args.predict_prior, a_start=args.a_start, fix_a=args.fix_a, scale_prior=args.scale_prior, ll=args.ll,
          aux_loss=aux_loss, aux_weight_start=aux_weight_start, fix_aux_weight=fix_aux_weight,
          allow_negative_lambda=args.allow_negative_lambda, steps=args.fluctuation_steps,
          ll_every_layer=args.ll_every_layer, superpose=False, coords_ref=coords_ref, top_ref=top_ref,
          sin_cos_in=args.sin_cos_input, sin_cos_out=args.sin_cos_output)
# Move the model to its target device BEFORE the optimizer is built (and before
# the optimizer state is restored from the checkpoint in the Training section):
# Optimizer.load_state_dict places the restored state tensors on the device of
# the parameters at load time, so moving the model afterwards would leave the
# Adam moments on the wrong device.
vae = vae.to(device)
if args.fluctuation_aux == 'cm_prior':
    # Variance of `coords` over dim 0, averaged over the last dim.
    # NOTE(review): this uses the full `coords` array, while the evaluation
    # section uses only `coords[ind_train]`; confirm which the model should see.
    vae.fluct_prior = coords.var(dim=0).mean(dim=-1).to(device)
opt = torch.optim.Adam(vae.parameters(), lr=lr)
print(f"Model: {args.model_name}")
print(f"Auxiliary fluctuation loss: {args.fluctuation_aux}")
print(f"Auxiliary lambda loss: {args.lambda_aux}")
print(vae)


############### Training ###############

# Data. drop_last is enabled only when the final batch would hold exactly one
# sample; any other partial final batch is kept.
drop_last_train = len(kappa_train) % batch_size == 1
data_train = TensorDataset(kappa_train, coords_train)
data_train = DataLoader(data_train, batch_size=batch_size, shuffle=True, drop_last=drop_last_train)

drop_last_val = len(kappa_val) % batch_size == 1
data_val = TensorDataset(kappa_val, coords_val)
data_val = DataLoader(data_val, batch_size=batch_size, shuffle=False, drop_last=drop_last_val)

# Train from checkpoint: restore model and optimizer state and continue from the
# saved epoch. The same path is handed to Trainer as its save location.
save_loc = os.path.join(save_folder, args.model_name + ".pt")
checkpoint = torch.load(save_loc, map_location=torch.device(device))
trainer = Trainer(vae, opt, data_train, data_val, num_epochs, num_warm_up_KL, num_mean_only, save_loc, start_epoch=checkpoint['epoch'])
trainer.vae.load_state_dict(checkpoint['model_state_dict'])
trainer.opt.load_state_dict(checkpoint['optimizer_state_dict'])
# load_state_dict copies the tensors, so the checkpoint dict (already mapped to
# `device`) is no longer needed; release it before the long training run.
del checkpoint

trainer.train()

# Evaluate: rebuild a fresh VAE with the trained model's hyperparameters, load
# the weights stored at `save_loc`, and generate sample plots.
with torch.no_grad():
    trainer.opt = None
    old_vae = trainer.vae
    # Pass everything after the first seven arguments by keyword, using the same
    # names as the constructor call in the Model section, so the rebuild does not
    # depend on the positional parameter order of VAE.__init__.
    trainer.vae = VAE(old_vae.latent_features, old_vae.encoder_sizes, old_vae.decoder_sizes,
                      old_vae.length, old_vae.bond_lengths, old_vae.prior, old_vae.predict_prior,
                      a_start=old_vae.a_start, fix_a=old_vae.fix_a, scale_prior=old_vae.scale_prior,
                      ll=old_vae.ll, aux_loss=old_vae.aux_loss, aux_weight_start=old_vae.aux_weight_start,
                      fix_aux_weight=old_vae.fix_aux_weight, allow_negative_lambda=old_vae.allow_negative_lambda,
                      steps=old_vae.steps, ll_every_layer=old_vae.ll_every_layer, superpose=old_vae.superpose,
                      coords_ref=old_vae.coords_ref, top_ref=old_vae.top_ref,
                      sin_cos_in=old_vae.sin_cos_in, sin_cos_out=old_vae.sin_cos_out)
    del old_vae
    if args.fluctuation_aux == 'cm_prior':
        # NOTE(review): computed from the training subset coords[ind_train] here,
        # but from the full `coords` array in the Model section; confirm which
        # one the model is meant to use.
        trainer.vae.fluct_prior = coords[ind_train, :, :].var(dim=0).mean(dim=-1).to(device)
    # NOTE(review): presumably disables the mean-only mode used early in
    # training (cf. num_mean_only) -- confirm in VAE.
    trainer.vae.mean_only = False
    save_loc = os.path.join(save_folder, args.model_name + ".pt")
    checkpoint = torch.load(save_loc, map_location=torch.device('cpu'))
    trainer.vae.load_state_dict(checkpoint['model_state_dict'])
    trainer.vae = trainer.vae.to(device)

# Named run_eval to avoid shadowing the builtin `eval`.
run_eval = not args.no_eval
print(f"Eval mode set to {run_eval}")
trainer.get_plots(kappa, coords, coords_pNeRF, num_samples_z=args.num_samples_z, num_samples_k=args.num_samples_k, batch_size=batch_size,
                  topology=top, save=os.path.join(save_folder, args.model_name + "_samples.png"),
                  model_name=args.model_name, eval=run_eval)