from __future__ import print_function
import argparse
import pickle
import copy
from os.path import join

import torch
import torch.utils.data
from torch import optim
import utils
import wandb
from tqdm import tqdm

from configs.datasets_config import get_dataset_info
from qm9 import dataset
from bond_type_prediction.egnn_edge_model import EGNNEdgeModel
from bond_type_prediction.utils import compute_class_weight
from bond_type_prediction.edge_model_train_test import train_epoch, test_epoch
from equivariant_diffusion import utils as flow_utils


parser = argparse.ArgumentParser(description='Edge model training')
parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N',
                    help='experiment_name')
parser.add_argument('--model', type=str, default='ae_egnn', metavar='N',
                    help='available models: ae | ae_rf | ae_egnn | baseline')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--wandb_usr', type=str)
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
# type=eval (like the other boolean options of this parser) so that `--online False`
# actually yields False; with type=bool every non-empty string, including 'False',
# was parsed as True.
parser.add_argument('--online', type=eval, default=True, help='True = wandb online -- False = wandb offline')
parser.add_argument('--save_model', type=eval, default=True,
                    help='save model')
parser.add_argument('--debug', action='store_true', default=False,
                    help='If True, use only few batches and evaluate on training set')
parser.add_argument('--resume', type=str, default=None,
                    help='set to the checkpoint path to resume training. If None, will start a new training')

# EGNN args -->
parser.add_argument('--n_layers', type=int, default=4,
                    help='number of Equivariant blocks')
parser.add_argument('--inv_sublayers', type=int, default=2,
                    help='number of GCL layers in each Equivariant block')
parser.add_argument('--hidden_nf', type=int, default=64,
                    help='number of hidden features of the EGNN')
parser.add_argument('--tanh', type=eval, default=False,
                    help='use tanh in the coord_mlp')
parser.add_argument('--attention', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--norm_constant', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--sin_embedding', type=eval, default=False,
                    help='whether using or not the sin embedding')
parser.add_argument('--normalization_factor', type=float, default=1,
                    help="Normalize the sum aggregation of EGNN")
parser.add_argument('--aggregation_method', type=str, default='sum',
                    help='"sum" or "mean"')
parser.add_argument('--encoder', type=str, default='egnn',
                    help='egnn or None')
parser.add_argument('--edge_head', type=str, default='mlp',
                    help='mlp or linear')
parser.add_argument('--edge_head_hidden_dim', type=int, default=32,
                    help='hidden dimensions of the edge head MLP')
parser.add_argument('--modify_h', action='store_true', default=False,
                    help='If True, the edge model will also modify the atom types and charges')
# <-- EGNN args

# Dataset args -->
parser.add_argument('--dataset', type=str, default='qm9',
                    help='qm9 | zinc250k | qm9_second_half (train only on the last 50K samples of the training dataset)')
parser.add_argument('--datadir', type=str, default='data/',
                    help='data directory')
parser.add_argument('--filter_n_atoms', type=int, default=None,
                    help='When set to an integer value, QM9 will only contain molecules of that amount of atoms')
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker for the dataloader')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--remove_h', action='store_true')
parser.add_argument('--include_charges', type=eval, default=True,
                    help='include atom charge or not')
# <-- Dataset args

# Optimization args -->
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--patience', type=int, default=100, metavar='N',
                    help='number of epochs to wait before stopping the training if val accuracy does not improve (default: 100)')
parser.add_argument('--lr', type=float, default=5e-4, metavar='N',
                    help='learning rate')
parser.add_argument('--reg', type=float, default=1e-3, metavar='N',
                    help='regularizer for the equivariant autoencoder')
parser.add_argument('--weight_decay', type=float, default=1e-16, metavar='N',
                    help='weight decay (L2 penalty) of the Adam optimizer')
parser.add_argument('--ema_decay', type=float, default=0.999,
                    help='Amount of EMA decay, 0 means off. A reasonable value'
                         ' is 0.999.')
parser.add_argument('--recompute_class_weight', action='store_true', default=False,
                    help='Forces recomputation of class weights')
parser.add_argument('--train_on_noisy_graphs', action='store_true', default=False,
                    help='If True, use perturbed molecules for training')
parser.add_argument('--sigma', type=float, default=0.1, metavar='N',
                    help='standard deviation for Gaussian noise added to the molecules positions')
parser.add_argument('--p', type=float, default=0.2, metavar='N',
                    help='probability of flipping the atom type or formal charge')
# <-- Optimization args

# legacy args from AE code -->
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log_interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--test_interval', type=int, default=2, metavar='N',
                    help='how many epochs to wait before logging test')
parser.add_argument('--generate-interval', type=int, default=100, metavar='N',
                    help='legacy AE argument: generation interval in epochs')
parser.add_argument('--outf', type=str, default='outputs_ae', metavar='N',
                    help='folder to output vae')
parser.add_argument('--plots', type=int, default=0, metavar='N',
                    help='Plot images of the graphs & adjacency matrices')
parser.add_argument('--emb_nf', type=int, default=8, metavar='N',
                    help='legacy AE argument: embedding size')
parser.add_argument('--K', type=int, default=8, metavar='N',
                    help='legacy AE argument')
parser.add_argument('--noise_dim', type=int, default=0, metavar='N',
                    help='break the symmetry applying noise at the input of the AE')
parser.add_argument('--clamp', type=int, default=1, metavar='N',
                    help='clamp the output of the coords function if get too large (safe mechanism, it is not activated in practice)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
dtype = torch.float32

# Settings taken from the current command line when resuming; everything else is
# restored from the pickled args of the original run. `cuda`/`no_cuda` are kept from
# this invocation so that `args.cuda` stays consistent with `device`, which was
# already chosen above for the current machine.
_RESUME_OVERRIDES = ('resume', 'debug', 'recompute_class_weight', 'epochs', 'patience',
                     'cuda', 'no_cuda')

if args.resume is not None:
    _overrides = {name: getattr(args, name) for name in _RESUME_OVERRIDES}

    # NOTE(review): pickle.load can execute arbitrary code -- only resume from
    # trusted checkpoint directories.
    with open(join(args.resume, 'args.pickle'), 'rb') as f:
        args = pickle.load(f)

    for _name, _value in _overrides.items():
        setattr(args, _name, _value)
    args.exp_name += '_resume'

print(args)

# Create outputs/<exp_name>; checkpoints and the pickled args of this run go there.
utils.create_folders(args)

# --- Weights & Biases ---
# --no_wandb takes precedence; otherwise --online selects online vs offline logging.
mode = 'disabled' if args.no_wandb else ('online' if args.online else 'offline')
kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project='bond_type_prediction',
    config=args,
    settings=wandb.Settings(_disable_stats=False),
    reinit=True,
    mode=mode,
)
wandb.init(**kwargs)
wandb.save('*.txt')

# --- Dataset ---
dataset_info = get_dataset_info(args.dataset, args.remove_h)
print(dataset_info)
# The debug flag is forwarded to retrieve_dataloaders; in debug mode validation also
# reuses the training loader so that overfitting can be checked.
dataloaders, charge_scale = dataset.retrieve_dataloaders(args, args.debug)
if args.debug:
    dataloaders['valid'] = dataloaders['train']

# --- Model ---
# Node input size: one feature per atom type, plus 3 more (one-hot of the charges
# -1, 0, 1) when include_charges is set.
in_node_nf = len(dataset_info['atom_decoder']) + 3 * int(args.include_charges)
model = EGNNEdgeModel(
    # feature sizes
    in_node_nf=in_node_nf,
    in_edge_nf=1,
    hidden_nf=args.hidden_nf,
    out_node_nf=None,
    # EGNN backbone
    device=device,
    act_fn=torch.nn.SiLU(),
    n_layers=args.n_layers,
    inv_sublayers=args.inv_sublayers,
    attention=args.attention,
    norm_diff=True,
    tanh=args.tanh,
    coords_range=15,
    norm_constant=args.norm_constant,
    sin_embedding=args.sin_embedding,
    normalization_factor=args.normalization_factor,
    aggregation_method=args.aggregation_method,
    include_charges=args.include_charges,
    encoder=args.encoder,
    # edge classification head
    edge_head=args.edge_head,
    edge_head_hidden_dim=args.edge_head_hidden_dim,
    n_classes=5,  # includes the aromatic type
    modify_h=args.modify_h,
)
print(model)

# --- EMA ---
# With ema_decay > 0 a separate deep copy of the model is kept alongside an EMA helper
# (presumably applied inside train_epoch -- confirm); otherwise `model_ema` is merely
# an alias of `model` and `ema` is None.
if args.ema_decay > 0:
    ema = flow_utils.EMA(args.ema_decay)
    model_ema = copy.deepcopy(model)
else:
    model_ema, ema = model, None

# --- Optimization ---
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Cosine annealing with T_max = args.epochs; the scheduler is handed to train_epoch,
# which presumably steps it once per epoch -- TODO confirm.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

if args.resume is not None:
    # map_location=device lets a checkpoint written on a GPU be restored on a CPU-only
    # machine (and vice versa); load_state_dict then copies the values into the
    # already-constructed model/optimizer.
    model_state_dict = torch.load(join(args.resume, 'bond_prediction_model.npy'),
                                  map_location=device)
    optimizer_state_dict = torch.load(join(args.resume, 'optim.npy'), map_location=device)

    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

    if args.ema_decay > 0:
        model_ema_state_dict = torch.load(join(args.resume, 'bond_prediction_model_ema.npy'),
                                          map_location=device)
        model_ema.load_state_dict(model_ema_state_dict)

    # NOTE(review): the lr_scheduler state is neither saved nor restored, so the cosine
    # schedule restarts from the initial learning rate on resume.


def _save_checkpoint(epoch):
    """Write optimizer, model (and EMA model) weights plus the run args to outputs/<exp_name>/."""
    args.current_epoch = epoch + 1
    utils.save_model(optimizer, 'outputs/%s/optim.npy' % args.exp_name)
    utils.save_model(model, 'outputs/%s/bond_prediction_model.npy' % args.exp_name)
    if args.ema_decay > 0:
        utils.save_model(model_ema, 'outputs/%s/bond_prediction_model_ema.npy' % args.exp_name)
    with open('outputs/%s/args.pickle' % args.exp_name, 'wb') as f:
        pickle.dump(args, f)


def _validate(epoch, class_weight_dict):
    """Evaluate the EMA model on the validation loader and return the selection metrics.

    When --train_on_noisy_graphs is set, the model is evaluated on noisy graphs first and
    those metrics are returned; the clean evaluation always runs as well. Otherwise the
    clean metrics are returned.
    """
    res_val_noisy = None
    if args.train_on_noisy_graphs:
        res_val_noisy = test_epoch(args, model_ema, epoch, dataloaders['valid'], device,
                                   class_weight_dict=class_weight_dict, test_on_noisy_graphs=True)
    res_val_clean = test_epoch(args, model_ema, epoch, dataloaders['valid'], device,
                               class_weight_dict=class_weight_dict, test_on_noisy_graphs=False)
    return res_val_noisy if args.train_on_noisy_graphs else res_val_clean


def main():
    """Train the edge model and keep the checkpoint with the best validation molecule accuracy.

    Model selection uses the noisy-graph validation metrics when --train_on_noisy_graphs is
    set and the clean validation metrics otherwise (previously the clean-only case crashed
    with a NameError). Training stops early once more than `args.patience` epochs pass
    without improvement.
    """
    best_molecule_accuracy = 0
    best_epoch = 0

    print('Computing class weights...')
    class_weight_dict = compute_class_weight(dataloaders['train'], dataset_info, recompute_class_weight=args.recompute_class_weight)
    class_weight_dict = {key: value.to(device) for key, value in class_weight_dict.items()}
    print(f'Computed class_weight: {class_weight_dict}')

    # Pass over the validation data before training; it uses the same (EMA) model as the
    # per-epoch validation below.
    _validate(-1, class_weight_dict)

    for epoch in tqdm(range(args.epochs)):
        train_epoch(args, model, model_ema, ema, epoch, dataloaders['train'], optimizer, lr_scheduler, device, class_weight_dict=class_weight_dict)
        res_val = _validate(epoch, class_weight_dict)

        # Update the best score and save the new best model.
        if res_val['molecule_accuracy'] > best_molecule_accuracy:
            best_molecule_accuracy = res_val['molecule_accuracy']
            best_epoch = epoch
            if args.save_model:
                _save_checkpoint(epoch)

        # Early stopping.
        if epoch - best_epoch > args.patience:
            print(f'Stopping the training after waiting {args.patience} epochs and the val accuracy did not improve')
            break


if __name__ == '__main__':
    # Start training only when run as a script (the module-level setup above still
    # executes on import).
    main()
