from __future__ import print_function
import argparse
import pickle
import copy
import torch
import torch.utils.data
from torch import optim
import utils
import wandb
from tqdm import tqdm
import numpy as np

from configs.datasets_config import get_dataset_info
from qm9 import dataset
from bond_type_prediction.egnn_edge_model import EGNNEdgeModel
from bond_type_prediction.utils import compute_class_weight
from bond_type_prediction.edge_model_train_test import train_epoch, test_epoch
from equivariant_diffusion import utils as flow_utils


def _str2bool(value):
    """Convert a command-line token such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` is not usable for this: ``bool('False')`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with ``bool('')``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface: general experiment / logging options.
parser = argparse.ArgumentParser(description='Edge model training')
parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N',
                    help='experiment_name')
parser.add_argument('--model', type=str, default='ae_egnn', metavar='N',
                    help='available models: ae | ae_rf | ae_egnn | baseline')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--wandb_usr', type=str)
parser.add_argument('--no_wandb', action='store_true', help='Disable wandb')
# Parsed with _str2bool so that "--online False" really yields False.
parser.add_argument('--online', type=_str2bool, default=True, help='True = wandb online -- False = wandb offline')
# NOTE: type=eval evaluates the raw argument text as a Python expression;
# only pass trusted values (e.g. True / False) on the command line.
parser.add_argument('--save_model', type=eval, default=True,
                    help='save model')

# EGNN args -->
# Architecture options for the edge model. Flags declared with type=eval
# evaluate the raw argument text as a Python expression (e.g. True / False),
# so only trusted values should be passed on the command line.
parser.add_argument('--n_layers', type=int, default=4,
                    help='number of Equivariant blocks')
parser.add_argument('--inv_sublayers', type=int, default=2,
                    help='number of GCL layers in each Equivariant block')
# The help text used to read 'number of layers', which described --n_layers
# rather than this option.
parser.add_argument('--hidden_nf', type=int, default=64,
                    help='number of hidden features')
parser.add_argument('--tanh', type=eval, default=False,
                    help='use tanh in the coord_mlp')
parser.add_argument('--attention', type=eval, default=True,
                    help='use attention in the EGNN')
parser.add_argument('--norm_constant', type=float, default=1,
                    help='diff/(|diff| + norm_constant)')
parser.add_argument('--sin_embedding', type=eval, default=False,
                    help='whether using or not the sin embedding')
parser.add_argument('--normalization_factor', type=float, default=1,
                    help="Normalize the sum aggregation of EGNN")
parser.add_argument('--aggregation_method', type=str, default='sum',
                    help='"sum" or "mean"')
parser.add_argument('--encoder', type=str, default='egnn',
                    help='egnn or None')
parser.add_argument('--edge_head', type=str, default='mlp',
                    help='mlp or linear')
parser.add_argument('--edge_head_hidden_dim', type=int, default=32,
                    help='hidden dimensions of the edge head MLP')
parser.add_argument('--modify_h', action='store_true', default=False,
                    help='If True, the edge model will also modify the atom types and charges')
# <-- EGNN args

# Dataset args -->
# Options describing which dataset is loaded and how it is batched.
parser.add_argument(
    '--dataset',
    default='qm9',
    type=str,
    help='qm9 | zinc250k | qm9_second_half (train only on the last 50K samples of the training dataset)',
)
parser.add_argument(
    '--datadir',
    default='data/',
    type=str,
    help='data directory',
)
parser.add_argument(
    '--filter_n_atoms',
    default=None,
    type=int,
    help='When set to an integer value, QM9 will only contain molecules of that amount of atoms',
)
parser.add_argument(
    '--num_workers',
    default=0,
    type=int,
    help='Number of worker for the dataloader',
)
parser.add_argument(
    '--batch_size',
    default=128,
    type=int,
)
# Plain switch: False unless the flag is given.
parser.add_argument(
    '--remove_h',
    action='store_true',
)
# type=eval: the argument text is evaluated as a Python expression.
parser.add_argument(
    '--include_charges',
    default=True,
    type=eval,
    help='include atom charge or not',
)
# <-- Dataset args

# Optimization args -->
# Training-loop, optimizer and noise-augmentation options.
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--patience', type=int, default=100, metavar='N',
                    help='number of epochs to wait before stopping the training if val accuracy does not improve (default: 100)')
parser.add_argument('--lr', type=float, default=5e-4, metavar='N',
                    help='learning rate')
parser.add_argument('--reg', type=float, default=1e-3, metavar='N',
                    help='regularizer for the equivariant autoencoder')
# The help text used to be a copy of the --clamp description; this value is
# what gets passed as `weight_decay` when the optimizer is constructed.
parser.add_argument('--weight_decay', type=float, default=1e-16, metavar='N',
                    help='weight decay passed to the optimizer')
parser.add_argument('--ema_decay', type=float, default=0.999,
                    help='Amount of EMA decay, 0 means off. A reasonable value'
                         ' is 0.999.')
parser.add_argument('--train_on_noisy_graphs', action='store_true', default=False,
                    help='If True, use perturbed molecules for training')
parser.add_argument('--sigma', type=float, default=0.1, metavar='N',
                    help='standard deviation for Gaussian noise added to the molecules positions')
parser.add_argument('--p', type=float, default=0.2, metavar='N',
                    help='probability of flipping the atom type or formal charge')
# <-- Optimization args

# legacy args from AE code -->
# NOTE(review): the help strings of --emb_nf and --K both read 'learning rate'
# and --generate-interval repeats the --test_interval text; they look
# copy-pasted. Confirm the intended meaning before relying on them.
parser.add_argument(
    '--seed',
    metavar='S',
    default=1,
    type=int,
    help='random seed (default: 1)',
)
parser.add_argument(
    '--log_interval',
    metavar='N',
    default=100,
    type=int,
    help='how many batches to wait before logging training status',
)
parser.add_argument(
    '--test_interval',
    metavar='N',
    default=2,
    type=int,
    help='how many epochs to wait before logging test',
)
parser.add_argument(
    '--generate-interval',
    metavar='N',
    default=100,
    type=int,
    help='how many epochs to wait before logging test',
)
parser.add_argument(
    '--outf',
    metavar='N',
    default='outputs_ae',
    type=str,
    help='folder to output vae',
)
parser.add_argument(
    '--plots',
    metavar='N',
    default=0,
    type=int,
    help='Plot images of the graphs & adjacency matrices',
)
parser.add_argument(
    '--emb_nf',
    metavar='N',
    default=8,
    type=int,
    help='learning rate',
)
parser.add_argument(
    '--K',
    metavar='N',
    default=8,
    type=int,
    help='learning rate',
)
parser.add_argument(
    '--noise_dim',
    metavar='N',
    default=0,
    type=int,
    help='break the symmetry applying noise at the input of the AE',
)
parser.add_argument(
    '--clamp',
    metavar='N',
    default=1,
    type=int,
    help='clamp the output of the coords function if get too large (safe mechanism, it is not activated in practice)',
)

# Parse the command line once at import time; main() later mutates `args`.
args = parser.parse_args()
# CUDA is used only if it was not disabled with --no-cuda and is available.
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
device = torch.device("cuda") if args.cuda else torch.device("cpu")
dtype = torch.float32

# Weights & Biases project that the sweep and its runs are filed under.
project = 'noisy_edge_model_sweep'

# Sweep definitions.
# Alternative configuration for hyper-parameter optimization (kept for reference):
# sweep_configuration = {
#     'method': 'bayes', # 'random'
#     'name': 'edge_model_sweep_no_h',
#     'metric': {'goal': 'maximize', 'name': 'val_molecule_accuracy'},
#     'parameters': 
#     {
#         'lr': {'values': [5e-4, 1e-4]},
#         'n_layers': {'values': [0,2,4]},
#         'inv_sublayers': {'values': [1, 2, 3]},
#         'hidden_nf': {'values': [32, 64, 128]},
#         'attention': {'values': [True, False]},
#         'edge_head_hidden_dim': {'values': [32, 64, 128]},
#         'optimizer': {'values': ['adam', 'adamw']},
#         'use_lr_scheduler': {'values': [True, False]},
#     }
# }

# Active configuration: grid over the noise magnitudes `sigma` and `p`
# (method could also be 'random').
sweep_configuration = dict(
    method='grid',
    name='noisy_edge_model_sweep',
    metric=dict(goal='maximize', name='val_noisy_molecule_accuracy'),
    parameters=dict(
        sigma=dict(values=[0.1, 0.01, 0.05, 0.2, 0.3, 0.5]),
        p=dict(values=[0.2, 0.1, 0.3, 0.4, 0.5, 0.6]),
    ),
)


# Register a new sweep from the configuration above (runs at import time).
sweep_id = wandb.sweep(sweep=sweep_configuration, project=project)

# To attach to an already existing sweep instead, hard-code its id:
#sweep_id = 'la7r4tsq'
print(f'sweep_id: {sweep_id}')


def main():
    """Run a single sweep trial: train and validate the edge model.

    Intended to be invoked by ``wandb.agent``. The trial's ``sigma`` and ``p``
    are read from ``wandb.config`` and written into the module-level ``args``,
    together with a few hard-coded sweep settings (optimizer, epochs,
    patience). The model is then trained for up to ``args.epochs`` epochs with
    early stopping, and the best checkpoint (by validation molecule accuracy)
    is written under ``outputs/<exp_name>/`` when ``args.save_model`` is set.

    Model selection uses the noisy-graph validation result when
    ``args.train_on_noisy_graphs`` is set and the clean-graph result
    otherwise.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'adam' nor 'adamw'.
    """
    wandb.init(project=project)
    args.sigma = wandb.config.sigma
    args.p = wandb.config.p
    args.exp_name = f'noisy_edge_model_sweep_sigma_{args.sigma}_p_{args.p}'

    args.optimizer = 'adamw'
    args.use_lr_scheduler = False

    # Overrides used with the hyper-parameter sweep configuration:
    # args.lr = wandb.config.lr
    # args.n_layers = wandb.config.n_layers
    # args.inv_sublayers = wandb.config.inv_sublayers
    # args.hidden_nf = wandb.config.hidden_nf
    # args.attention = wandb.config.attention
    # args.edge_head_hidden_dim = wandb.config.edge_head_hidden_dim
    # args.optimizer = wandb.config.optimizer
    # args.use_lr_scheduler = wandb.config.use_lr_scheduler

    # hardcoded args specific to sweep runs
    args.epochs = 20
    args.patience = 3

    # The best checkpoint of every sweep run is saved; set this flag to False
    # to skip writing checkpoints.
    args.save_model = True
    # With zero equivariant blocks there is no encoder to run.
    args.encoder = None if args.n_layers == 0 else 'egnn'

    print(f'args: {args}')

    # create folder outputs/exp_name where we will dump everything
    utils.create_folders(args)

    # Wandb config
    # if args.no_wandb:
    #     mode = 'disabled'
    # else:
    #     mode = 'online' if args.online else 'offline'
    # kwargs = {'entity': args.wandb_usr, 'name': args.exp_name, 'project': project, 'config': args,
    #         'settings': wandb.Settings(_disable_stats=False), 'reinit': True, 'mode': mode}
    # wandb.init(**kwargs)
    # wandb.save('*.txt')

    # Dataset
    dataset_info = get_dataset_info(args.dataset, args.remove_h)
    print(f'dataset_info: {dataset_info}')
    dataloaders, _ = dataset.retrieve_dataloaders(args)

    # Model
    # 3 extra input features for the one-hot encoding of charges -1, 0, 1
    in_node_nf = len(dataset_info['atom_decoder']) + int(args.include_charges) * 3
    model = EGNNEdgeModel(in_node_nf=in_node_nf,
                          in_edge_nf=1,
                          hidden_nf=args.hidden_nf,
                          device=device,
                          act_fn=torch.nn.SiLU(),
                          n_layers=args.n_layers,
                          attention=args.attention,
                          norm_diff=True,
                          out_node_nf=None,
                          tanh=args.tanh,
                          coords_range=15,
                          norm_constant=args.norm_constant,
                          inv_sublayers=args.inv_sublayers,
                          sin_embedding=args.sin_embedding,
                          normalization_factor=args.normalization_factor,
                          aggregation_method=args.aggregation_method,
                          include_charges=args.include_charges,
                          encoder=args.encoder,
                          edge_head=args.edge_head,
                          edge_head_hidden_dim=args.edge_head_hidden_dim,
                          n_classes=5,  # include aromatic types
                          modify_h=args.modify_h,
                          )
    print(model)

    # Initialize model copy for exponential moving average of params.
    if args.ema_decay > 0:
        model_ema = copy.deepcopy(model)
        ema = flow_utils.EMA(args.ema_decay)
    else:
        ema = None
        model_ema = model

    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
    else:
        # Fail loudly instead of hitting an unbound `optimizer` further down.
        raise ValueError(f'Unknown optimizer: {args.optimizer!r}')

    if args.use_lr_scheduler:
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    else:
        lr_scheduler = None

    best_molecule_accuracy = 0
    best_epoch = 0

    print('Computing class weights...')
    class_weight_dict = compute_class_weight(dataloaders['train'], dataset_info, recompute_class_weight=False)
    class_weight_dict = {key: value.to(device) for key, value in class_weight_dict.items()}
    print(f'Computed class_weight: {class_weight_dict}')

    # pass over val data before starting training
    # The results are not used here; the calls are kept for whatever
    # logging test_epoch performs (assumption -- confirm in test_epoch).
    if args.train_on_noisy_graphs:
        test_epoch(args, model_ema, -1, dataloaders['valid'], device, class_weight_dict=class_weight_dict, test_on_noisy_graphs=True)
    test_epoch(args, model_ema, -1, dataloaders['valid'], device, class_weight_dict=class_weight_dict, test_on_noisy_graphs=False)

    for epoch in tqdm(range(0, args.epochs)):
        res_train = train_epoch(args, model, model_ema, ema, epoch, dataloaders['train'], optimizer, lr_scheduler, device, class_weight_dict=class_weight_dict)

        res_val_noisy = None
        if args.train_on_noisy_graphs:
            res_val_noisy = test_epoch(args, model_ema, epoch, dataloaders['valid'], device, class_weight_dict=class_weight_dict, test_on_noisy_graphs=True)
        res_val_clean = test_epoch(args, model_ema, epoch, dataloaders['valid'], device, class_weight_dict=class_weight_dict, test_on_noisy_graphs=False)

        # Select on noisy validation when it was computed, otherwise on the
        # clean validation. Without this fallback the selection metric was
        # undefined whenever --train_on_noisy_graphs was not passed.
        res_val = res_val_noisy if res_val_noisy is not None else res_val_clean

        # update best and save new best model
        if res_val['molecule_accuracy'] > best_molecule_accuracy:
            best_molecule_accuracy = res_val['molecule_accuracy']
            best_epoch = epoch

            if args.save_model:
                args.current_epoch = epoch + 1
                utils.save_model(optimizer, 'outputs/%s/optim.npy' % args.exp_name)
                utils.save_model(model, 'outputs/%s/bond_prediction_model.npy' % args.exp_name)
                if args.ema_decay > 0:
                    utils.save_model(model_ema, 'outputs/%s/bond_prediction_model_ema.npy' % args.exp_name)
                with open('outputs/%s/args.pickle' % args.exp_name, 'wb') as f:
                    pickle.dump(args, f)

        # patience: stop once more than `args.patience` epochs have passed
        # since the best epoch
        if epoch - best_epoch > args.patience:
            print(f'Stopping the training after waiting {args.patience} epochs and the val accuracy did not improve')
            break

        # nan
        if np.isnan(res_train['avg_loss']):
            print('Detected nan in training loss. Stopping the training')
            break


if __name__ == "__main__":
    # Start a sweep agent in this process: it calls main() once per trial
    # and executes at most `count` trials.
    wandb.agent(
        sweep_id,
        project=project,
        function=main,
        count=100,
    )
