"""
This script can be used to evaluate a model
"""

import torch
import argparse
import os
import yaml

from VAE_model import VAE
from trainer_VAE_dynamics import Trainer
from utils import (
    initialize,
    process_data,
)

############### Arguments ###############

# Command-line interface of the evaluation script. Only run-specific options
# are defined here; all remaining settings are copied from the YAML config.
parser = argparse.ArgumentParser(description='VAE_dynamics_evaluation')

# Required: the config is opened unconditionally below, so a missing path should
# produce a clear argparse error rather than a TypeError from open(None).
parser.add_argument('--config_file_path', type=str, required=True, help='Path to config yaml file')
parser.add_argument('--wandb_user', type=str, default='', help='Wandb user name')
parser.add_argument('--wandb_project', type=str, default='VAE_dynamics', help='Wandb project name')
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
parser.add_argument('--no_eval', action='store_true', help='Disable eval mode for sampling')
parser.add_argument('--num_samples_z_eval', type=int, default=None, help='Number of latent space samples')
parser.add_argument('--num_samples_k_eval', type=int, default=None, help='Number of kappa distribution samples')

args = parser.parse_args()

# Options whose command-line value must win over an equally named config entry.
cli_only_keys = ('batch_size', 'no_eval')

with open(args.config_file_path, "r") as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file and may return a scalar or list for
# other content; only a mapping can be merged into the arguments.
if not isinstance(config, dict):
    parser.error(f"Config file {args.config_file_path} does not contain a YAML mapping")

# Copy every config entry onto `args`, skipping keys that contain "wandb" and
# the CLI-only options. Each copied entry is indexed with ['value']
# (presumably the layout of a wandb-exported config -- TODO confirm).
for key, value in config.items():
    if "wandb" not in key and key not in cli_only_keys:
        setattr(args, key, value['value'])

# When no evaluation-specific sample counts are given, fall back to the
# num_samples_z / num_samples_k values that came from the config.
if args.num_samples_z_eval is None:
    args.num_samples_z_eval = args.num_samples_z
if args.num_samples_k_eval is None:
    args.num_samples_k_eval = args.num_samples_k

############### Initializing ###############

# initialize() is a project helper from utils; it yields three values, of which
# `device` and `save_folder` are used further down in this script.
(device, use_gpu, save_folder) = initialize(args)

############### Data ###############

# process_data() is a project helper from utils returning thirteen values.
# The fifth and sixth are discarded; what each one contains is defined in
# utils and is not visible from this script.
(
    kappa, kappa_train, kappa_val, coords, _, _, coords_pNeRF, bond_lengths_pNeRF,
    kappa_prior, coords_ref, top, top_ref, ind_train,
) = process_data(args)


############### Model ###############

# Network layout, taken from the loaded configuration
latent_features = args.latent_features
encoder_sizes = args.encoder_list
decoder_sizes = args.decoder_list

# Settings forwarded to the Trainer; the epoch count is fixed to zero here
num_epochs = 0
lr = args.lr
batch_size = args.batch_size
num_warm_up_KL = args.num_warm_up_KL
num_mean_only = args.num_mean_only

# Auxiliary losses. When a loss is configured as 'none', its start weight is
# set to 0.0 and marked as fixed instead of being read from the config.
if args.fluctuation_aux != 'none':
    fluctuation_aux_weight_start = args.fluctuation_aux_weight_start
    fix_fluctuation_aux_weight = args.fix_fluctuation_aux_weight
else:
    fluctuation_aux_weight_start = 0.0
    fix_fluctuation_aux_weight = True

if args.lambda_aux != 'none':
    lambda_aux_weight_start = args.lambda_aux_weight_start
    fix_lambda_aux_weight = args.fix_lambda_aux_weight
else:
    lambda_aux_weight_start = 0.0
    fix_lambda_aux_weight = True

# All three lists are ordered as [fluctuation, lambda]
aux_loss = [args.fluctuation_aux, args.lambda_aux]
aux_weight_start = [fluctuation_aux_weight_start, lambda_aux_weight_start]
fix_aux_weight = [fix_fluctuation_aux_weight, fix_lambda_aux_weight]

# Keyword arguments for the VAE constructor, collected in one place
vae_options = dict(
    a_start=args.a_start,
    fix_a=args.fix_a,
    scale_prior=args.scale_prior,
    ll=args.ll,
    aux_loss=aux_loss,
    aux_weight_start=aux_weight_start,
    fix_aux_weight=fix_aux_weight,
    allow_negative_lambda=args.allow_negative_lambda,
    steps=args.fluctuation_steps,
    ll_every_layer=args.ll_every_layer,
    superpose=False,
    coords_ref=coords_ref,
    top_ref=top_ref,
    sin_cos_in=args.sin_cos_input,
    sin_cos_out=args.sin_cos_output,
)
vae = VAE(
    latent_features, encoder_sizes, decoder_sizes, coords.shape[1], bond_lengths_pNeRF,
    kappa_prior, args.predict_prior, **vae_options
)
vae.mean_only = False

if args.fluctuation_aux == 'cm_prior':
    # Variance over dim 0 of the training-index subset of coords, then averaged
    # over the last dimension, stored on the model as its fluctuation prior.
    train_coords = coords[ind_train, :, :]
    vae.fluct_prior = train_coords.var(dim=0).mean(dim=-1).to(device)

# No optimizer is created in this script
opt = None
print(f"Model: {args.model_name}")
print(f"Auxiliary fluctuation loss: {args.fluctuation_aux}")
print(f"Auxiliary lambda loss: {args.lambda_aux}")
print(vae)

# Restore the weights from <save_folder>/<model_name>.pt, mapping tensors to CPU.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded. The model is not moved to `device` in this script --
# presumably the Trainer takes care of that; confirm.
checkpoint_name = args.model_name + ".pt"
save_loc = os.path.join(save_folder, checkpoint_name)
cpu = torch.device('cpu')
checkpoint = torch.load(save_loc, map_location=cpu)
vae.load_state_dict(checkpoint['model_state_dict'])


############### Evaluation ###############

# The Trainer is built without data sets and with the optimizer placeholder
# defined above; in this script only its get_plots method is called.
data_train, data_val = None, None
trainer = Trainer(vae, opt, data_train, data_val, num_epochs, num_warm_up_KL, num_mean_only, save_loc)

# Named eval_mode (not eval) so the builtin eval() is not shadowed at module
# level; --no_eval switches eval mode off.
eval_mode = not args.no_eval
print(f"Eval mode set to {eval_mode}")

# The plot is saved next to the checkpoint as <model_name>_samples.png
samples_path = os.path.join(save_folder, args.model_name + "_samples.png")
trainer.get_plots(kappa, coords, coords_pNeRF, num_samples_z=args.num_samples_z_eval,
                  num_samples_k=args.num_samples_k_eval, batch_size=batch_size,
                  topology=top, save=samples_path,
                  model_name=args.model_name, eval=eval_mode)