# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import logging
logging.getLogger().setLevel(logging.INFO)
import copy
import utils
import argparse
import wandb
from configs.datasets_config import get_dataset_info
from os.path import join
from qm9 import dataset
from qm9.models import get_optim, get_model
from bond_type_prediction.initialize_pp_model import get_pp_model
from equivariant_diffusion import en_diffusion
from equivariant_diffusion.utils import assert_correctly_masked
from equivariant_diffusion import utils as flow_utils
import torch
import time
import pickle
from qm9.utils import prepare_context, compute_mean_mad
from train_test import train_epoch, test, analyze_and_save
from guacamol_evaluation.evaluator import GuacamolEvaluator

# Command-line interface of the training script.
#
# Convention used throughout: boolean (and list) options are parsed with
# ``type=eval`` so that ``--flag False`` yields the Python object ``False``.
# NOTE(review): ``eval`` executes whatever is typed on the command line; this is
# acceptable only because the CLI is operated by the person running the job.
parser = argparse.ArgumentParser(description='E3Diffusion')
parser.add_argument('--exp_name', type=str, default='debug_10')
parser.add_argument('--model', type=str, default='egnn_dynamics',
                    help='our_dynamics | schnet | simple_dynamics | '
                         'kernel_dynamics | egnn_dynamics |gnn_dynamics')
parser.add_argument('--probabilistic_model', type=str, default='diffusion',
                    help='diffusion')
parser.add_argument('--joint_training', type=eval, default=False,
                    help='whether to train the diffusion model and the edge model jointly')
parser.add_argument('--debug', type=eval, default=False,
                    help='If True, use only few batches and evaluate on training set')
parser.add_argument('--joint_space', type=str, default='z0_pred',
                    help='the space on which both the diffusion model and pp_model operate')
# The option name keeps its historical spelling so existing launch scripts and
# pickled argument namespaces keep working.
parser.add_argument('--guacamaol_eval', type=eval, default=True,
                    help='If True, will run Guacamol evaluation every few epochs')
parser.add_argument('--patience', type=int, default=100000, metavar='N',
                    help='number of epochs to wait before stopping the training if val accuracy does not improve '
                         '(default: 100000)')

# Training complexity is O(1) (unaffected), but sampling complexity is O(steps).
parser.add_argument('--diffusion_steps', type=int, default=500)
parser.add_argument('--diffusion_noise_schedule', type=str, default='polynomial_2',
                    help='learned, cosine')
parser.add_argument('--diffusion_noise_precision', type=float, default=1e-5)
parser.add_argument('--diffusion_loss_type', type=str, default='l2',
                    help='vlb, l2')

parser.add_argument('--n_epochs', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--brute_force', type=eval, default=False,
                    help='True | False')
parser.add_argument('--actnorm', type=eval, default=True,
                    help='True | False')
parser.add_argument('--break_train_epoch', type=eval, default=False,
                    help='True | False')
parser.add_argument('--dp', type=eval, default=False,
                    help='True | False')
parser.add_argument('--condition_time', type=eval, default=True,
                    help='True | False')
parser.add_argument('--clip_grad', type=eval, default=True,
                    help='True | False')
parser.add_argument('--trace', type=str, default='hutch',
                    help='hutch | exact')
# EGNN args -->
parser.add_argument('--n_layers', type=int, default=6,
                    help='number of layers')
parser.add_argument('--inv_sublayers', type=int, default=1,
                    help='number of GCL layers in each Equivariant block')
parser.add_argument('--nf', type=int, default=128,
                    help='number of hidden features')
parser.add_argument('--tanh', type=eval, default=True,
                    help='use tanh in the coord_mlp')
parser.add_argument('--attention', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--norm_constant', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--sin_embedding', type=eval, default=False,
                    help='whether using or not the sin embedding')
# <-- EGNN args
parser.add_argument('--ode_regularization', type=float, default=1e-3)
parser.add_argument('--dataset', type=str, default='qm9',
                    help='qm9 | zinc250k | qm9_second_half (train only on the last 50K samples of the training dataset)')
parser.add_argument('--datadir', type=str, default='data/',
                    help='data directory')
parser.add_argument('--filter_n_atoms', type=int, default=None,
                    help='When set to an integer value, QM9 will only contain molecules of that amount of atoms')
parser.add_argument('--dequantization', type=str, default='argmax_variational',
                    help='uniform | variational | argmax_variational | deterministic')
parser.add_argument('--n_report_steps', type=int, default=1)
parser.add_argument('--wandb_usr', type=str)
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
# ``type=bool`` would turn every non-empty string (including "False") into True,
# so this option is parsed like the other boolean options of this script.
parser.add_argument('--online', type=eval, default=True, help='True = wandb online -- False = wandb offline')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--save_model', type=eval, default=True,
                    help='save model')
parser.add_argument('--save_model_history', type=eval, default=False,
                    help='save model every new best val epoch')
# NOTE(review): the help text below looks copy-pasted from --save_model; the
# option is not read anywhere in this file, so its meaning could not be confirmed.
parser.add_argument('--generate_epochs', type=int, default=1,
                    help='save model')
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker for the dataloader')
parser.add_argument('--test_epochs', type=int, default=10)
parser.add_argument('--data_augmentation', type=eval, default=False, help='whether to use data augmentation')
parser.add_argument("--conditioning", nargs='+', default=[],
                    help='arguments : homo | lumo | alpha | gap | mu | Cv' )
parser.add_argument('--resume', type=str, default=None,
                    help='')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='')
parser.add_argument('--ema_decay', type=float, default=0.999,
                    help='Amount of EMA decay, 0 means off. A reasonable value'
                         ' is 0.999.')
parser.add_argument('--augment_noise', type=float, default=0)
parser.add_argument('--n_stability_samples', type=int, default=500,
                    help='Number of samples to compute the stability')
parser.add_argument('--normalize_factors', type=eval, default=[1, 4, 1],
                    help='normalize factors for [x, categorical, integer]')
parser.add_argument('--remove_h', action='store_true')
parser.add_argument('--include_charges', type=eval, default=True,
                    help='include atom charge or not')
# argparse does not pass non-string defaults through ``type``, so the default
# must already be an int (the literal 1e8 is a float).
parser.add_argument('--visualize_every_batch', type=int, default=int(1e8),
                    help="Can be used to visualize multiple times per epoch")
parser.add_argument('--normalization_factor', type=float, default=1,
                    help="Normalize the sum aggregation of EGNN")
parser.add_argument('--aggregation_method', type=str, default='sum',
                    help='"sum" or "mean"')

# pp_model args -->
# ALL args related to the pp_model end with _pp
parser.add_argument('--n_layers_pp', type=int, default=4,
                    help='number of Equivariant blocks')
parser.add_argument('--inv_sublayers_pp', type=int, default=2,
                    help='number of GCL layers in each Equivariant block')
parser.add_argument('--hidden_nf_pp', type=int, default=64,
                    help='number of hidden features')
parser.add_argument('--tanh_pp', type=eval, default=False,
                    help='use tanh in the coord_mlp')
parser.add_argument('--attention_pp', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--norm_constant_pp', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--sin_embedding_pp', type=eval, default=False,
                    help='whether using or not the sin embedding')
parser.add_argument('--normalization_factor_pp', type=float, default=1,
                    help="Normalize the sum aggregation of EGNN")
parser.add_argument('--aggregation_method_pp', type=str, default='sum',
                    help='"sum" or "mean"')
parser.add_argument('--encoder_pp', type=str, default='egnn',
                    help='egnn or None')
parser.add_argument('--edge_head_pp', type=str, default='mlp',
                    help='mlp or linear')
parser.add_argument('--edge_head_hidden_dim_pp', type=int, default=32,
                    help='hidden dimensions of the edge head MLP')
parser.add_argument('--modify_h_pp', type=eval, default=True,
                    help='If True, the edge model will also modify the atom types and charges')
parser.add_argument('--lambda_pp_loss', type=float, default=1.0,
                    help='coefficient for the pp_model loss in the joint training')
parser.add_argument('--condition_time_pp', type=eval, default=False,
                    help='True | False')
parser.add_argument('--lr_pp', type=float, default=5e-4, metavar='N',
                    help='learning rate')
parser.add_argument('--use_eps_correction', type=eval, default=False,
                    help='True | False')
parser.add_argument('--use_pp_model', type=eval, default=True,
                    help='whether to use the edge model')
# <-- pp_model args

# Parse the command line into the module-level run configuration.
args = parser.parse_args()

# Debug runs draw far fewer samples for the stability analysis.
if args.debug:
    args.n_stability_samples = 10

# Dataset metadata is looked up from the dataset name and the remove_h flag.
dataset_info = get_dataset_info(args.dataset, args.remove_h)
print(dataset_info)

atom_encoder = dataset_info['atom_encoder']

# args, unparsed_args = parser.parse_known_args()
# Replace the (possibly missing) --wandb_usr value with whatever the helper returns.
args.wandb_usr = utils.get_wandb_username(args.wandb_usr)

# CUDA is used only when it was not disabled via --no-cuda and is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
dtype = torch.float32

if args.resume is not None:
    # Values that must win over whatever the checkpointed namespace contains,
    # captured (in assignment order) before ``args`` is replaced below.
    forced_values = (
        ('resume', args.resume),
        ('break_train_epoch', False),
        ('exp_name', args.exp_name + '_resume'),
        ('start_epoch', args.start_epoch),
        ('wandb_usr', args.wandb_usr),
    )
    # Careful with this -->
    # Values used only when the checkpointed namespace does not define them.
    fallback_values = (
        ('normalization_factor', args.normalization_factor),
        ('aggregation_method', args.aggregation_method),
    )

    checkpoint_args_path = join(args.resume, f'args_{args.start_epoch}.pickle')
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # resume from checkpoint directories you trust.
    with open(checkpoint_args_path, 'rb') as args_file:
        args = pickle.load(args_file)

    for attr_name, attr_value in forced_values:
        setattr(args, attr_name, attr_value)

    for attr_name, attr_value in fallback_values:
        if not hasattr(args, attr_name):
            setattr(args, attr_name, attr_value)

    print(args)

# Create the output folders for this run (delegated to the utils helper).
utils.create_folders(args)

# Wandb config: --no_wandb disables logging entirely, otherwise --online
# selects between online and offline mode.
wandb_mode = 'disabled' if args.no_wandb else ('online' if args.online else 'offline')
wandb.init(
    entity=args.wandb_usr,
    name=args.exp_name,
    project='zinc_joint_training',
    config=args,
    settings=wandb.Settings(_disable_stats=False),
    reinit=True,
    mode=wandb_mode,
)
wandb.save('*.txt')

# Retrieve the dataloaders for the configured dataset.
dataloaders, charge_scale = dataset.retrieve_dataloaders(args, args.debug)

# One training batch, used below to size the conditioning context.
data_dummy = next(iter(dataloaders['train']))

# Unconditional defaults; overridden when --conditioning lists properties.
context_node_nf = 0
property_norms = None
if len(args.conditioning) > 0:
    print(f'Conditioning on {args.conditioning}')
    property_norms = compute_mean_mad(dataloaders, args.conditioning, args.dataset)
    context_dummy = prepare_context(args.conditioning, data_dummy, property_norms)
    # The size of dim 2 of the context is stored as the per-node context width.
    context_node_nf = context_dummy.size(2)

args.context_node_nf = context_node_nf

# Build the generative model together with its node-count / property distributions.
model, nodes_dist, prop_dist = get_model(args, device, dataset_info, dataloaders['train'])
if prop_dist is not None:
    prop_dist.set_normalizer(property_norms)
model = model.to(device)

# Optional edge (pp) model; stays None when --use_pp_model is False.
pp_model = (
    get_pp_model(args, dataset_info, device, dataloaders['train'], recompute_class_weight=False)
    if args.use_pp_model
    else None
)

# A single optimizer is built from both models (pp_model may be None).
optim = get_optim(args, model, pp_model)

# Gradient-norm histories, one per model.
gradnorm_queue = utils.Queue()
gradnorm_queue.add(3000)  # Add large value that will be flushed.

gradnorm_queue_pp = utils.Queue()
gradnorm_queue_pp.add(3000)  # Same large initial value as above.

# The evaluator only exists when Guacamol evaluation is enabled.
if args.guacamaol_eval:
    guacamol_evaluator = GuacamolEvaluator()

def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every non-empty entry of ``variables``.

    Entries whose ``len`` is 0 are skipped; every other entry is passed to
    ``assert_correctly_masked`` together with ``node_mask``. Returns ``None``.
    """
    non_empty = (candidate for candidate in variables if len(candidate) > 0)
    for candidate in non_empty:
        assert_correctly_masked(candidate, node_mask)


def _save_checkpoint(model_ema, pp_model_ema, suffix=''):
    """Write optimizer, models, EMA copies and args under ``outputs/<exp_name>/``.

    ``suffix`` is inserted before the file extension (e.g. ``'_12'`` for the
    per-epoch history files, ``''`` for the "latest best" files).
    The pp_model files are skipped when no pp_model is in use.
    """
    out_dir = 'outputs/%s' % args.exp_name
    utils.save_model(optim, '%s/optim%s.npy' % (out_dir, suffix))
    utils.save_model(model, '%s/generative_model%s.npy' % (out_dir, suffix))
    if pp_model is not None:
        utils.save_model(pp_model, '%s/pp_model%s.npy' % (out_dir, suffix))
    if args.ema_decay > 0:
        utils.save_model(model_ema, '%s/generative_model_ema%s.npy' % (out_dir, suffix))
        if pp_model_ema is not None:
            utils.save_model(pp_model_ema, '%s/pp_model_ema%s.npy' % (out_dir, suffix))
    with open('%s/args%s.pickle' % (out_dir, suffix), 'wb') as f:
        pickle.dump(args, f)


def main():
    """Train the model(s), periodically evaluate, checkpoint and early-stop.

    Uses the module-level configuration (``args``), models, optimizer and
    dataloaders. Every ``args.test_epochs`` epochs it samples molecules,
    computes ``validity * uniqueness * novelty`` as the model-selection score,
    optionally feeds the Guacamol evaluator, computes the validation loss and
    saves a checkpoint whenever the score is at least as good as the best so
    far. Training stops early once ``args.patience`` epochs pass without a new
    best score.
    """
    print(f"Using device: {device}")
    print(f'Training using {torch.cuda.device_count()} GPUs')

    if args.resume is not None:
        # map_location lets a checkpoint written on another device be restored here.
        flow_state_dict = torch.load(join(args.resume, f'generative_model_{args.start_epoch}.npy'),
                                     map_location=device)
        optim_state_dict = torch.load(join(args.resume, f'optim_{args.start_epoch}.npy'),
                                      map_location=device)
        model.load_state_dict(flow_state_dict)
        optim.load_state_dict(optim_state_dict)
        # TODO: load ema_model
        if args.use_pp_model:
            # TODO: load pp_model
            pass

    use_data_parallel = args.dp and torch.cuda.device_count() > 1

    # Initialize dataparallel if enabled and possible.
    if use_data_parallel:
        print(f'Training using {torch.cuda.device_count()} GPUs')
        model_dp = torch.nn.DataParallel(model.cpu())
        model_dp = model_dp.cuda()
    else:
        model_dp = model
    # The pp_model is not wrapped in DataParallel; binding the name in every
    # branch avoids the NameError the multi-GPU path used to hit.
    pp_model_dp = pp_model

    # Initialize model copy for exponential moving average of params.
    if args.ema_decay > 0:
        model_ema = copy.deepcopy(model)
        pp_model_ema = copy.deepcopy(pp_model)  # deepcopy(None) is None
        # same ema can work for both
        ema = flow_utils.EMA(args.ema_decay)

        if use_data_parallel:
            model_ema_dp = torch.nn.DataParallel(model_ema)
        else:
            model_ema_dp = model_ema
        # Bound in both branches for the same reason as pp_model_dp above.
        pp_model_ema_dp = pp_model_ema
    else:
        ema = None
        model_ema = model
        model_ema_dp = model_dp

        pp_model_ema = pp_model
        pp_model_ema_dp = pp_model_dp

    best_combined_score = 0
    # Start the patience window at the first trained epoch so that a resumed
    # run is not stopped immediately because of its large epoch index.
    best_epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.n_epochs):
        epoch_start_time = time.time()
        # pp_model is None when the edge model is disabled.
        train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model, model_dp=model_dp,
                    model_ema=model_ema, ema=ema, device=device, dtype=dtype, property_norms=property_norms,
                    nodes_dist=nodes_dist, dataset_info=dataset_info,
                    gradnorm_queue=gradnorm_queue, gradnorm_queue_pp=gradnorm_queue_pp, optim=optim, prop_dist=prop_dist,
                    pp_model=pp_model, pp_model_dp=pp_model_dp, pp_model_ema=pp_model_ema)
        print(f"Epoch took {time.time() - epoch_start_time:.1f} seconds.")

        if epoch % args.test_epochs == 0:
            if isinstance(model, en_diffusion.EnVariationalDiffusion):
                wandb.log(model.log_info(), commit=True)

            rdkit_tuple, unique_valid_smiles = analyze_and_save(
                args=args, epoch=epoch, model_sample=model_ema,
                pp_model=pp_model_ema, nodes_dist=nodes_dist,
                dataset_info=dataset_info, device=device,
                prop_dist=prop_dist, n_samples=args.n_stability_samples)

            validity, uniqueness, novelty = rdkit_tuple
            # to avoid setting the best metric at the first epoch where validity is very high but uniqueness is low
            combined_score = validity * uniqueness * novelty

            if args.guacamaol_eval:
                if unique_valid_smiles is not None:
                    guacamol_evaluator.add_smiles(unique_valid_smiles)
                if guacamol_evaluator.get_smiles_count() > 10000:
                    print(f'Accumulated {guacamol_evaluator.get_smiles_count()} valid & unique smiles over the last epochs. Now running Guacamol Evaluation.')
                    # run eval
                    guacamol_results = guacamol_evaluator.evaluate(training_smiles_path=f'data/{args.dataset}/smiles/train.txt',
                                                                   number_samples=10000)
                    # log
                    for result in guacamol_results:
                        wandb.log({f"Guacamol {result.benchmark_name}": result.score}, commit=False)
                    # clear
                    guacamol_evaluator.clear_smiles()

            # Only the validation split is evaluated; the test split is not used here.
            nll_val = test(args=args, loader=dataloaders['valid'], epoch=epoch, eval_model=model_ema_dp,
                           pp_model=pp_model_ema_dp, partition='Val', device=device, dtype=dtype, nodes_dist=nodes_dist,
                           property_norms=property_norms)

            # Model selection is based on the combined sampling score, not on nll_val.
            if combined_score >= best_combined_score:
                best_combined_score = combined_score
                best_epoch = epoch

                if args.save_model:
                    # Stored on args before saving so it is part of the pickled namespace.
                    args.current_epoch = epoch + 1
                    _save_checkpoint(model_ema, pp_model_ema)

                if args.save_model_history:
                    _save_checkpoint(model_ema, pp_model_ema, suffix='_%d' % epoch)
            print('Val loss: %.4f' % nll_val)
            wandb.log({"Val loss ": nll_val}, commit=True)

        # patience
        if epoch - best_epoch > args.patience:
            print(f'Stopping the training after waiting {args.patience} epochs and the combined score did not improve')
            break


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
