# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import logging
logging.getLogger().setLevel(logging.INFO)
import copy
import random

import utils
import wandb
from configs.datasets_config import get_dataset_info
from os.path import join
from qm9 import dataset
from qm9.models import get_optim, get_model
from bond_type_prediction.initialize_pp_model import get_pp_model
from equivariant_diffusion import en_diffusion
from equivariant_diffusion.utils import assert_correctly_masked
from equivariant_diffusion import utils as flow_utils
import torch
import numpy as np
import time
import pickle
from qm9.utils import prepare_context, compute_mean_mad, compute_properties_upper_bounds
from train_test import train_epoch, test, analyze_and_save, eval_vae_reconstruction, eval_regression_model
from guacamol_evaluation.evaluator import GuacamolEvaluator
from geo_ldm.init_vae import get_vae
from geo_ldm.init_latent_diffuser import get_latent_diffusion
from geo_ldm.init_diffusion_guidance import get_diffusion_guidance
from geo_ldm.args import get_args, setup_args
from conditional_generation.prop_encoding import PropertyEncoding
from qm9.prodigy import Prodigy


def _build_optimizer_and_scheduler(args, model, n_epochs):
    """Create the optimizer and the optional per-epoch LR scheduler.

    ``args.prodigy_setting`` selects one of the hard-coded experiment recipes
    in the table below. Any value that is not listed falls back to
    ``get_optim(args, model, pp_model=None)``, combined with cosine annealing
    when ``args.prodigyopt`` is truthy.

    Exactly one recipe is applied. Earlier revisions used several independent
    ``if`` statements whose trailing ``else`` replaced the optimizer chosen for
    settings 0-11 with the ``get_optim`` fallback; the table lookup removes
    that fall-through.

    Args:
        args: Namespace providing ``prodigy_setting`` and, depending on the
            selected recipe, ``d_coef`` and ``prodigyopt``.
        model: Module whose ``parameters()`` are optimized.
        n_epochs: Used as ``T_max`` of the cosine annealing schedules.

    Returns:
        Tuple ``(optim, scheduler)``; ``scheduler`` is ``None`` when the
        recipe trains with a constant learning rate.
    """
    lr_sched = torch.optim.lr_scheduler

    def cosine(optim):
        return lr_sched.CosineAnnealingLR(optim, T_max=n_epochs)

    def warmup_then_cosine(optim):
        warmup_epochs = 10
        warmup = lr_sched.LinearLR(optim, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)
        annealing = lr_sched.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=1e-4)
        return lr_sched.ChainedScheduler([warmup, annealing])

    def adamw(lr):
        return torch.optim.AdamW(model.parameters(), lr=lr, amsgrad=True, weight_decay=1e-12)

    def prodigy(**kwargs):
        return Prodigy(model.parameters(), lr=1., **kwargs)

    safeguarded = dict(safeguard_warmup=True, use_bias_correction=True)

    # setting -> (optimizer factory, scheduler factory or None).
    # Factories are lazy so that only the selected optimizer is constructed
    # and ``args.d_coef`` is only read by recipes that need it.
    recipes = {
        # Default
        0: (lambda: prodigy(weight_decay=0.01, d_coef=args.d_coef, **safeguarded), cosine),
        1: (lambda: prodigy(d_coef=args.d_coef), cosine),
        2: (lambda: prodigy(d_coef=args.d_coef), None),
        3: (lambda: prodigy(d_coef=args.d_coef, **safeguarded), cosine),
        4: (lambda: prodigy(d_coef=0.05, **safeguarded), cosine),
        5: (lambda: adamw(2e-4), None),
        # Directly diverged!
        6: (lambda: adamw(4e-4), None),
        # Also diverged
        7: (lambda: adamw(3e-4), None),
        # Introduced together with changes elsewhere in the training code:
        # skipping training batches with loss >= 1000, initializing the
        # gradient-norm queue with 10 instead of 3000, a try/except around
        # loss.backward to catch NaNs in the model, and always clipping
        # gradient values above 1.0.
        8: (lambda: prodigy(weight_decay=0.01, d_coef=0.2, **safeguarded), cosine),
        9: (lambda: prodigy(weight_decay=0.01, d_coef=0.1, **safeguarded), cosine),
        10: (lambda: prodigy(weight_decay=1e-12, d_coef=0.1, **safeguarded), cosine),
        11: (lambda: prodigy(weight_decay=0.01, d_coef=0.05, **safeguarded), cosine),
        # Did not work
        12: (lambda: prodigy(weight_decay=0.01, d_coef=0.05, **safeguarded), None),
        13: (lambda: adamw(2e-4), None),
        14: (lambda: adamw(2e-4), warmup_then_cosine),
        15: (lambda: adamw(2.5e-4), warmup_then_cosine),
        # Default without scheduler
        16: (lambda: prodigy(weight_decay=0.01, d_coef=args.d_coef, **safeguarded), None),
    }

    setting = args.prodigy_setting
    recipe = recipes.get(setting)
    if recipe is None:
        optim = get_optim(args, model, pp_model=None)
        scheduler = cosine(optim) if args.prodigyopt else None
        return optim, scheduler

    if setting != 0:
        print(f'Prodigy Setting {setting}')
    make_optim, make_scheduler = recipe
    optim = make_optim()
    scheduler = make_scheduler(optim) if make_scheduler is not None else None
    return optim, scheduler


def _save_checkpoint(args, optim, model, model_ema, model_name, prefix='', suffix=''):
    """Write optimizer, model, EMA model and ``args`` below ``outputs/<exp_name>/``.

    The files are named ``{prefix}optim_{model_name}{suffix}.npy``,
    ``{prefix}{model_name}{suffix}.npy``, ``{prefix}{model_name}{suffix}_ema.npy``
    and ``{prefix}args_{model_name}{suffix}.pickle``. The EMA model is only
    written when ``args.ema_decay > 0``.
    """
    out_dir = f'outputs/{args.exp_name}'
    stem = f'{model_name}{suffix}'
    utils.save_model(optim, f'{out_dir}/{prefix}optim_{stem}.npy')
    utils.save_model(model, f'{out_dir}/{prefix}{stem}.npy')
    if args.ema_decay > 0:
        utils.save_model(model_ema, f'{out_dir}/{prefix}{stem}_ema.npy')
    with open(f'{out_dir}/{prefix}args_{stem}.pickle', 'wb') as f:
        pickle.dump(args, f)


def main_training_loop(args, first_stage, dataloaders, dataset_info, device, dtype, 
                        property_norms, property_norms_regression, prop_encoder, n_epochs):
    """Train a VAE, a latent diffusion model or a regression (guidance) model.

    The model type is chosen in this order: ``first_stage`` (VAE),
    ``args.train_diffusion`` (latent diffusion), ``args.train_regressor``
    (regression model). After every epoch the latest state is written to
    ``outputs/<args.exp_name>/``; every ``args.test_epochs`` epochs the EMA
    model is evaluated and, if the combined score did not get worse, the
    "best" checkpoint is updated. Training stops early once more than
    ``args.patience`` epochs have passed since the best epoch.

    Args:
        args: Experiment configuration namespace. ``args.current_epoch`` is
            updated after every epoch and ``args.start_epoch`` is overwritten
            when resuming from ``args.resume``.
        first_stage: If truthy, train the VAE.
        dataloaders: Mapping with ``'train'`` and ``'valid'`` loaders, plus
            ``'vocab'`` when ``args.use_vocab_data`` is set.
        dataset_info: Dataset description forwarded to the model factories,
            training and sampling routines.
        device: Device the model is moved to.
        dtype: Data type forwarded to the training/evaluation routines.
        property_norms: Normalizer given to ``prop_dist.set_normalizer`` and
            forwarded to training/evaluation.
        property_norms_regression: Normalizer forwarded to training and to the
            regression evaluation.
        prop_encoder: Property encoder forwarded to training/evaluation.
        n_epochs: Exclusive upper bound of the epoch loop; also the horizon of
            the cosine schedules.

    Raises:
        ValueError: If none of the three training modes is selected.
        AssertionError: If ``args.ema_decay`` is not positive.
    """
    print(f"Using device: {device}")
    print(f'Training using {torch.cuda.device_count()} GPUs')

    # Create Latent Diffusion Model, Autoencoder or regression model
    if first_stage:
        # Create VAE
        model, nodes_dist, prop_dist = get_vae(args, device, dataset_info, dataloaders['train'])
        model_name = 'vae'
    elif args.train_diffusion:
        # Create LDM
        model, nodes_dist, prop_dist = get_latent_diffusion(args, device, dataset_info, dataloaders['train'])
        model_name = 'diffusion_model'
    elif args.train_regressor:
        # Create diffusion_guidance
        model, nodes_dist, prop_dist = get_diffusion_guidance(args, device, dataset_info, dataloaders['train'])
        model_name = 'regression_model'
        for prop in args.regression_target:
            model_name += '_' + prop
    else:
        # Without this guard the function would fail a few lines below with
        # an UnboundLocalError on ``prop_dist``.
        raise ValueError(
            'Nothing to train: expected first_stage, args.train_diffusion or '
            'args.train_regressor to be set.')

    if prop_dist is not None:
        prop_dist.set_normalizer(property_norms)
    model = model.to(device)
    print('model:', model)
    num_params = sum(p.numel() for p in model.parameters())
    print('Num of parameters of model:', num_params)

    optim, scheduler = _build_optimizer_and_scheduler(args, model, n_epochs)

    gradnorm_queue = utils.Queue()
    # Seed value for the gradient-norm queue; 3000 was considered too high.
    gradnorm_queue.add(10)

    if args.train_diffusion and args.guacamaol_eval:
        guacamol_evaluator = GuacamolEvaluator()
        best_fcd_score = 0.0

    # Initialize model copy for exponential moving average of params.
    assert args.ema_decay > 0, "Got ema_decay <= 0"
    model_ema = copy.deepcopy(model)
    ema = flow_utils.EMA(args.ema_decay)

    if args.resume is not None:
        if args.dp and torch.cuda.device_count() > 1:
            # Wrapping is necessary to match the keys of the stored state.
            model = torch.nn.DataParallel(model.cpu())
            model = utils.load_model(model, folder=args.resume, filename=f'last_{model_name}.npy')
            model = model.module.to(device)
        else:
            model = utils.load_model(model, folder=args.resume, filename=f'last_{model_name}.npy')

        model_ema = utils.load_model(model_ema, folder=args.resume, filename=f'last_{model_name}_ema.npy')
        if args.train_diffusion:
            # when resuming vae training during robust training, start over
            optim = utils.load_model(optim, folder=args.resume, filename=f'last_optim_{model_name}.npy')
        print(f'Loaded all models checkpoints from {args.resume}')

        # start from where we left off
        args.start_epoch = args.current_epoch

        if scheduler is not None:
            # create a fresh scheduler with the remaining number of epochs
            # NOTE(review): this also replaces the warm-up + cosine schedule of
            # settings 14/15 with plain cosine annealing - confirm intended.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs-args.start_epoch)

    # Initialize dataparallel if enabled and possible.
    if args.dp and torch.cuda.device_count() > 1:
        print('Doing parallel training')
        print(f'Training using {torch.cuda.device_count()} GPUs')
        model_dp = torch.nn.DataParallel(model.cpu())
        model_dp = model_dp.cuda()
        model_ema_dp = torch.nn.DataParallel(model_ema)

        # From here on ``model`` is the DataParallel wrapper, so checkpoints
        # are written with the wrapper's keys (see the resume branch above).
        model = model_dp
    else:
        model_dp = model
        model_ema_dp = model_ema

    best_combined_score = float('-inf')
    best_epoch = args.start_epoch
    for epoch in range(args.start_epoch, n_epochs):
        start_epoch = time.time()
        train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model, model_dp=model_dp,
                    model_ema=model_ema, ema=ema, device=device, dtype=dtype, property_norms=property_norms,
                    property_norms_regression=property_norms_regression, nodes_dist=nodes_dist, dataset_info=dataset_info,
                    gradnorm_queue=gradnorm_queue, optim=optim, prop_dist=prop_dist, prop_encoder=prop_encoder)
        print(f"Epoch took {time.time() - start_epoch:.1f} seconds.")
        if scheduler is not None:
            scheduler.step()
            last_lr = scheduler.get_last_lr()
            wandb.log({f'{model_name}/lr': last_lr[0]}, commit=False)

        # Always save last model
        args.current_epoch = epoch + 1
        _save_checkpoint(args, optim, model, model_ema, model_name, prefix='last_')

        if epoch % args.test_epochs == 0:
            if isinstance(model, en_diffusion.EnVariationalDiffusion):
                wandb.log(model.log_info(), commit=True)

            if args.train_diffusion:
                rdkit_tuple, unique_valid_smiles = analyze_and_save(args=args, epoch=epoch, model_sample=model_ema, 
                                    nodes_dist=nodes_dist,
                                    dataset_info=dataset_info, device=device,
                                    prop_dist=prop_dist, n_samples=args.n_stability_samples,
                                    prop_encoder=prop_encoder)

                validity, uniqueness, novelty = rdkit_tuple
                # to avoid setting the best metric at the first epoch where validity is very high but uniqueness is low
                combined_score = validity * uniqueness * novelty

                if unique_valid_smiles is not None:
                    print("Some generated valid SMILES:")
                    print(random.sample(unique_valid_smiles, k=min(len(unique_valid_smiles), args.n_stability_samples // 10)))

                if args.guacamaol_eval:
                    if unique_valid_smiles is not None:
                        guacamol_evaluator.add_smiles(unique_valid_smiles)
                    if guacamol_evaluator.get_smiles_count() > 10000:
                        print(f'Accumulated {guacamol_evaluator.get_smiles_count()} valid & unique smiles over the last epochs. Now running Guacamol Evaluation.')
                        # run eval
                        guacamol_results = guacamol_evaluator.evaluate(training_smiles_path=f'data/{args.dataset}/smiles/train.txt',
                                                                    number_samples=10000)
                        # The score of the last benchmark result is tracked as the FCD score.
                        new_fcd_score = guacamol_results[-1].score
                        if new_fcd_score >= best_fcd_score:
                            best_fcd_score = new_fcd_score
                            if args.save_model:
                                _save_checkpoint(args, optim, model, model_ema, model_name, prefix='best_fcd_')

                        # log
                        for result in guacamol_results:
                            wandb.log({f"diffusion_model/Guacamol {result.benchmark_name}": result.score}, commit=False)
                        # clear
                        guacamol_evaluator.clear_smiles()

            nll_val = test(args=args, loader=dataloaders['valid'], epoch=epoch, eval_model=model_ema_dp,
                           partition='Val', device=device, dtype=dtype, nodes_dist=nodes_dist,
                           property_norms=property_norms, property_norms_regression=property_norms_regression,
                           prop_encoder=prop_encoder)

            if first_stage:
                results = eval_vae_reconstruction(args=args, property_norms=property_norms, vae_model=model_ema_dp, 
                        loader=dataloaders['valid'], device=device, dtype=dtype, mode='Val', prop_encoder=prop_encoder)
                # NOTE(review): torch.nn.DataParallel does not forward custom
                # attributes/methods (noise_sigma, is_encoder_frozen,
                # freeze_encoder) to the wrapped module - confirm first-stage
                # training is never run with args.dp on several GPUs.
                if model.noise_sigma is not None:
                    # TODO: might change the combined score to the corrupted mol_accuracy!
                    eval_vae_reconstruction(args=args, property_norms=property_norms, vae_model=model_ema_dp, 
                                    loader=dataloaders['valid'], device=device, dtype=dtype, mode='Corrupted_val', 
                                    prop_encoder=prop_encoder, inject_noise=True)
                if args.use_vocab_data:
                    eval_vae_reconstruction(args=args, property_norms=property_norms, vae_model=model_ema_dp, 
                                loader=dataloaders['vocab'], device=device, dtype=dtype, mode='Vocab', prop_encoder=prop_encoder)

                combined_score = results['molecule_accuracy']
                if args.encoder_early_stopping:
                    if combined_score > 99.0 and not model.is_encoder_frozen:
                        print('Freezing Encoder')
                        model.freeze_encoder()
                        model_dp.freeze_encoder()
                        model_ema.freeze_encoder()

                        # And start optimization from beginning
                        optim = get_optim(args, model, pp_model=None)
                        if args.prodigyopt:
                            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs-(epoch+1))
                        else:
                            scheduler = None

            if args.train_regressor:
                avg_mae = eval_regression_model(args=args, property_norms=property_norms, 
                        property_norms_regression=property_norms_regression, regression_model=model_ema_dp, 
                        loader=dataloaders['valid'], device=device, dtype=dtype, mode='Val', prop_encoder=prop_encoder)
                # NOTE(review): args.regression_target is iterated like a
                # collection when model_name is built, but is tested here for
                # membership in a list of strings; if it is a list this test
                # is always False - confirm the intended type.
                if args.regression_target in ['morgan_fingerprint']:
                    combined_score = avg_mae
                else:
                    # Negated so that a larger combined score is better.
                    combined_score = -avg_mae
                if epoch % 10 == 0:
                    # evaluate model on a fixed time step across whole dataset
                    for t in np.arange(0, args.max_step_regressor + 1, 100):
                        eval_regression_model(args=args, property_norms=property_norms, 
                                property_norms_regression=property_norms_regression, regression_model=model_ema_dp, 
                                loader=dataloaders['valid'], device=device, dtype=dtype, mode='Val', prop_encoder=prop_encoder,
                                t_lower=t, t_upper=t)

            # The test split is deliberately not evaluated here.

            # The best model is tracked by the combined score, not by the
            # validation loss; epoch 0 is never recorded as best.
            if epoch > 0 and combined_score >= best_combined_score:
                best_combined_score = combined_score
                best_epoch = epoch

                if args.save_model:
                    _save_checkpoint(args, optim, model, model_ema, model_name)

                if args.save_model_history:
                    _save_checkpoint(args, optim, model, model_ema, model_name, suffix=f'_{epoch}')
            print('Val loss: %.4f' % nll_val)
            wandb.log({f"{model_name}/Val loss ": nll_val}, commit=True)

        # patience
        if epoch - best_epoch > args.patience:
            print(f'Stopping the training after waiting {args.patience} epochs and the combined score did not improve')
            break

        # if first_stage and combined_score > 99.0 and args.encoder_early_stopping:
        #     print('VAE early stopping')
        #     break

        # turning on robust decoder training once clean training has converged
        # encoder will be fixed from now on
        # if first_stage and epoch - best_epoch > 10 and noise_sigma_vae is not None:
        #     print('Starting robust decoder training.')
        #     model.noise_sigma = noise_sigma_vae
        #     if args.ema_decay > 0:
        #         model_ema.noise_sigma = noise_sigma_vae
