# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import logging
logging.getLogger().setLevel(logging.INFO)
from sys import exit
import copy
import random

import utils
import wandb
from configs.datasets_config import get_dataset_info
from os.path import join
from qm9 import dataset
from qm9.models import get_optim, get_model
from bond_type_prediction.initialize_pp_model import get_pp_model
from equivariant_diffusion import en_diffusion
from equivariant_diffusion.utils import assert_correctly_masked
from equivariant_diffusion import utils as flow_utils
import torch
import numpy as np
import time
import pickle
from qm9.utils import prepare_context, compute_mean_mad, compute_properties_upper_bounds
from train_test import train_epoch, test, analyze_and_save, eval_vae_reconstruction
from guacamol_evaluation.evaluator import GuacamolEvaluator
from geo_ldm.init_vae import get_vae
from geo_ldm.init_latent_diffuser import get_latent_diffusion
from geo_ldm.args import get_args, setup_args
from conditional_generation.prop_encoding import PropertyEncoding
from prodigyopt import Prodigy
from train_loop import main_training_loop


# --> SET UP ARGS
# Parse the arguments, then pass them through setup_args before anything else reads them.
args = get_args()
args = setup_args(args)
print('args:', args)

# Seed every RNG the script relies on (python, numpy, torch CPU and CUDA).
# A missing seed *or* a seed of None both mean "do not seed": torch.manual_seed
# does not accept None, so checking only that the attribute exists would crash
# when the attribute is present but unset.
if getattr(args, 'seed', None) is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

# --> SET UP WANDB
# Logging mode: disabled outright, otherwise online/offline depending on args.online
if args.no_wandb:
    mode = 'disabled'
elif args.online:
    mode = 'online'
else:
    mode = 'offline'

# Project name: a conditional run takes precedence over a regressor run,
# and anything else goes to the default project.
if len(args.conditioning) > 0:
    wandb_project = 'GeoLDM_conditional_exps_for_paper'
elif args.train_regressor:
    wandb_project = 'GeoLDM_regressor_exps_for_paper'
else:
    wandb_project = 'GeoLDM_exps_for_paper'

kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project=wandb_project,
    config=args,
    settings=wandb.Settings(_disable_stats=False),
    reinit=True,
    mode=mode,
)
wandb.init(**kwargs)
wandb.save('*.txt')

# --> SET UP FOLDER FOR CKPTS, DEVICE, AND DTYPE
# Checkpoints are saved in a folder named after the experiment
utils.create_folders(args.exp_name)
# Pick the device from the --cuda flag; dtype is fixed to float32
if args.cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.float32

# --> SET UP DATASET_INFO
dataset_info = get_dataset_info(args.dataset, args.remove_h)
if args.use_ghost_nodes:
    # Put the 'Ghost' type at index 0 (in place), then rebuild the
    # symbol -> index map so it matches the shifted decoder list.
    atom_decoder = dataset_info['atom_decoder']
    atom_decoder.insert(0, 'Ghost')
    dataset_info['atom_encoder'] = dict(zip(atom_decoder, range(len(atom_decoder))))
atom_encoder = dataset_info['atom_encoder']
print(dataset_info)

# --> SET UP DATALOADERS
dataloaders, charge_scale = dataset.retrieve_dataloaders(args, args.debug)
if args.use_vocab_data:
    # Add the atom counts of the 'vocab' split to the n_nodes occurrence counts
    n_nodes_counts = dataset_info['n_nodes']
    for num_atoms in dataloaders['vocab'].dataset.data['num_atoms']:
        size = num_atoms.item()
        if size not in n_nodes_counts:
            n_nodes_counts[size] = 0
        n_nodes_counts[size] += 1

# One training batch, used below to read off feature sizes
data_dummy = next(iter(dataloaders['train']))
# set up extra features: their count is dim 2 of 'atomic_features', or 0 when disabled
args.n_extra_atomic_features = (
    data_dummy['atomic_features'].size(2) if args.use_extra_atomic_features else 0
)

# Set up conditioning
# Three cases: no conditioning, conditioning on morgan fingerprints, or
# conditioning on properties (which need normalisation statistics).
prop_encoder = None
if len(args.conditioning) == 0:
    context_node_nf = 0
    property_norms = None
elif 'morgan_fingerprint' in args.conditioning:
    print(f'Conditioning on {args.conditioning}')
    # conditioning on morgan_fingerprint: no property norms are computed
    property_norms = None
    context_dummy = prepare_context(args.conditioning, data_dummy, None, prop_encoder=prop_encoder)
    context_node_nf = context_dummy.size(2)
else:
    print(f'Conditioning on {args.conditioning}')
    # compute mean and average dev from mean (mad)
    property_norms = compute_mean_mad(dataloaders, args.conditioning, args.dataset)
    if args.encode_prop:
        print('Using target property encodings (sine and cosine)')
        prop_encoder = PropertyEncoding(nf=16)
    # Build a context from the dummy batch only to read its feature width (dim 2)
    context_dummy = prepare_context(
        args.conditioning,
        data_dummy,
        property_norms,
        prop_encoder=prop_encoder,
        condition_dropout=args.condition_dropout,
    )
    context_node_nf = context_dummy.size(2)
    property_upperbounds = compute_properties_upper_bounds(dataloaders['train'], args.conditioning)
args.context_node_nf = context_node_nf

# Set up regression
# Normalisation statistics for the regression targets; None when there are no targets
property_norms_regression = None
if len(args.regression_target) > 0:
    print(f'Using regression target: {args.regression_target}')
    property_norms_regression = compute_mean_mad(dataloaders, args.regression_target, args.dataset)

# noise_sigma_vae = None
# if args.noise_sigma_vae is not None:
#     noise_sigma_vae = args.noise_sigma_vae
#     args.noise_sigma_vae = None

# first stage or second stage training
# Without a pretrained autoencoder path, the VAE is trained first (first stage).
first_stage = args.ae_path is None
if first_stage:
    print('Starting first stage training...')
    # Stash the second-stage settings; they are overwritten for the VAE run
    # and restored once it has finished.
    train_diffusion = args.train_diffusion
    train_regressor = args.train_regressor
    test_epochs = args.test_epochs
    # run vae training
    args.train_diffusion = False
    args.train_regressor = False
    args.patience = 10
    args.test_epochs = 1
    # TODO: add arg for n_epochs_vae
    n_epochs_vae = 2 if args.debug else 100
    main_training_loop(
        args, first_stage, dataloaders, dataset_info, device, dtype,
        property_norms, property_norms_regression, prop_encoder,
        n_epochs=n_epochs_vae,
    )
    # in case we just want to train vae
    if not (train_diffusion or train_regressor):
        exit()
    # restore the stashed settings and point ae_path at this run's outputs
    # for second-stage training
    args.train_diffusion, args.train_regressor = train_diffusion, train_regressor
    args.test_epochs = test_epochs
    args.ae_path = 'outputs/' + args.exp_name
    first_stage = False

# TODO: cache activations of encoder to speed up second-stage training
# ...

# If we arrive here, either ae_path was provided and we only want to run second-stage training,
# or first-stage training has just finished and we also want to run second-stage training.
# Second-stage training may be diffusion, conditional diffusion, property regression, etc.
# run diffusion training
print('Starting second stage training...')
assert not first_stage
# Exactly one of the two second-stage objectives must be selected
assert bool(args.train_diffusion) != bool(args.train_regressor), \
    "Choose either diffusion or regressor to train"
# no early stopping for diffusion: patience is set to the full epoch budget
args.patience = args.n_epochs
# debug runs are capped at 2 epochs
n_epochs_second_stage = 2 if args.debug else args.n_epochs
main_training_loop(
    args, first_stage, dataloaders, dataset_info, device, dtype,
    property_norms, property_norms_regression, prop_encoder,
    n_epochs=n_epochs_second_stage,
)



# def check_mask_correct(variables, node_mask):
#     for variable in variables:
#         if len(variable) > 0:
#             assert_correctly_masked(variable, node_mask)




# if __name__ == "__main__":
#     main(model, optim)
