import os
import argparse
import torch
import torch.utils.tensorboard
from torch.nn.utils import clip_grad_norm_
from torch_geometric.data import Batch, DataLoader
from tqdm.auto import tqdm

from models.edgecnf import *
from models.ebm import *
from datasets import *
from utils.transforms import *
from utils.misc import *
from models.common import GradualWarmupScheduler


# Arguments
parser = argparse.ArgumentParser()
# BEGIN Arguments
# EBM arguments
parser.add_argument('--ebm_activation', type=str, default='ssp')
parser.add_argument('--ebm_hidden_dim', type=int, default=128)
parser.add_argument('--ebm_num_layers', type=int, default=6)
parser.add_argument('--ebm_cutoff', type=float, default=10.0)
parser.add_argument('--ebm_num_gaussians', type=int, default=50)
parser.add_argument('--ebm_alpha', type=float, default=.01,
                    help='Recommended: 0.01, Coefficient of the L2-regularization on the magnitude of energy.')
parser.add_argument('--ebm_lambda', type=float, default=0,
                    help='Coefficient of the L2-regularization on the magnitude of score function.')
# Pre-trained EdgeCNF
parser.add_argument('--edgecnf_ckpt', type=str, default='logs/ECNF_2020_08_21__13_31_32_B128N0.1')
# Dataset
parser.add_argument('--train_dataset', type=str, default='./data/qm9/QM9_train.pkl')
# parser.add_argument('--test_dataset', type=str, default='./data/qm9/QM9_test.pkl')
# parser.add_argument('--val_dataset', type=str, default='./data/qm9/QM9_val.pkl')
parser.add_argument('--aux_edge_order', type=int, default=3)
parser.add_argument('--train_batch_size', type=int, default=384)
# parser.add_argument('--val_batch_size', type=int, default=384)
parser.add_argument('--num_workers', type=int, default=4)
# Scale of the Gaussian noise added to the EdgeCNF latent when building positive
# samples. Currently 0.5 seems good; smaller values are worse.
parser.add_argument('--pos_sample_noise', type=float, default=0.5)
# Optimizing
# NOTE: the training step always optimizes the NCE objective; 'cd' and 'mle'
# are accepted here but are not used by this script.
parser.add_argument('--loss', type=str, default='nce', choices=['nce', 'cd', 'mle'])
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=0)
# The scheduler options below have no effect while the LR scheduler is disabled.
parser.add_argument('--warmup_iter', type=int, default=0)
parser.add_argument('--sched_step_size', type=int, default=1000)
parser.add_argument('--sched_factor', type=float, default=0.6)
# A learning rate is fractional, so it must be parsed as a float (an int parser
# rejects command-line values such as 1e-6).
parser.add_argument('--sched_min_lr', type=float, default=5e-6)
parser.add_argument('--max_grad_norm', type=float, default=50)
parser.add_argument('--beta1', type=float, default=0.99)
parser.add_argument('--beta2', type=float, default=0.999)
# Training
parser.add_argument('--seed', type=int, default=2020)
# `eval` turns the strings 'True' / 'False' into booleans for the choices check.
parser.add_argument('--logging', type=eval, default=True, choices=[True, False])
parser.add_argument('--device', type=str, default='cuda')
# An iteration count feeds range() in the main loop, which requires an int;
# parsing it as a float crashed the loop whenever the flag was given on the CLI.
parser.add_argument('--max_iters', type=int, default=1000*1000)
# A checkpoint is saved every `val_freq` iterations.
parser.add_argument('--val_freq', type=float, default=100)
# Misc
parser.add_argument('--log_root', type=str, default='./logs_EBM')
parser.add_argument('--tag', type=str, default='')
# END Arguments
args = parser.parse_args()
seed_all(args.seed)


# Logging: with --logging False every sink is a no-op BlackHole and no log
# directory is created; otherwise logs, TensorBoard events and checkpoints
# all go to a fresh directory under --log_root.
if not args.logging:
    logger = get_logger('train', None)
    writer = BlackHole()
    ckpt_mgr = BlackHole()
else:
    log_dir = get_new_log_dir(root=args.log_root, prefix='EbmEf', tag=args.tag)
    logger = get_logger('train', log_dir)
    writer = torch.utils.tensorboard.SummaryWriter(log_dir)
    ckpt_mgr = CheckpointManager(log_dir)
    log_hyperparams(writer, args)
logger.info(args)


# Datasets and loaders
logger.info('Loading dataset...')
tf = get_standard_transforms(order=args.aux_edge_order)
train_dset = MoleculeDataset(args.train_dataset, transform=tf)
# val_dset = MoleculeDataset(args.val_dataset, transform=tf)
# Forward --num_workers to the loader; previously the flag was parsed but
# ignored, so loading always ran in the main process.
# NOTE: with num_workers > 0 on platforms that spawn worker processes
# (Windows, macOS), PyTorch requires an `if __name__ == '__main__':` guard,
# which this script does not have; pass --num_workers 0 there.
train_iterator = get_data_iterator(DataLoader(
    train_dset,
    batch_size=args.train_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
))
# val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, shuffle=False, drop_last=True)
logger.info('TrainSet %d' % (len(train_dset), ))


# Model: restore the pretrained EdgeCNF from the best checkpoint in
# --edgecnf_ckpt. It is rebuilt from the hyper-parameters stored in that
# checkpoint, not from this script's arguments.
logger.info('Loading pre-trained EdgeCNF model...')
ckpt_mgr_ef = CheckpointManager(args.edgecnf_ckpt)
ckpt = ckpt_mgr_ef.load_best()
cnf_hparams = ckpt['args']
model_cnf = EdgeCNF(cnf_hparams)
model_cnf = model_cnf.to(args.device)
# Spectral norm is applied before load_state_dict, presumably so the module's
# parameter layout matches the one the checkpoint was saved with.
if cnf_hparams.spectral_norm:
    add_spectral_norm(model_cnf, logger=logger)
model_cnf.load_state_dict(ckpt['state_dict'])


# Model: build the EBM trained by this script from the parsed arguments.
logger.info('Building EBM model...')
model_ebm = EBM(args)
model_ebm = model_ebm.to(args.device)
# Spectral normalization of the EBM is currently disabled:
# model_ebm.enable_spectral_norm(logger=logger)
logger.info(repr(model_ebm))


# Optimizer for the EBM. Only model_ebm's parameters are registered, so the
# pretrained EdgeCNF is never updated by optimizer.step().
optimizer = torch.optim.Adam(
    params=model_ebm.parameters(),
    betas=(args.beta1, args.beta2),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# The LR scheduler is disabled, so --warmup_iter, --sched_step_size and
# --sched_factor are not consumed here. To re-enable it:
# scheduler = GradualWarmupScheduler(
#     optimizer,
#     multiplier=1,
#     total_epoch=args.warmup_iter,
#     after_scheduler=torch.optim.lr_scheduler.StepLR(
#         optimizer,
#         step_size=args.sched_step_size,
#         gamma=args.sched_factor
#     )
# )


# Sampling
def negative_sample(batch):
    """Draw one negative conformation per molecule from the pretrained EdgeCNF.

    Sampling goes through ``simple_generate_batch`` with an ``Embed3D``
    embedder. The ground-truth ``batch.pos`` is passed as ``dg_init_pos``
    (presumably the starting point of the 3D embedding; confirm in
    ``simple_generate_batch``).

    Args:
        batch: A PyG batch that has at least ``pos`` and ``edge_length``.

    Returns:
        A clone of ``batch`` whose ``pos`` and ``edge_length`` are replaced by
        the generated values, reshaped to the original attribute sizes.
    """
    model_cnf.eval()
    device_batch = batch.to(args.device)
    # Small step size for the Embed3D position optimisation.
    embedder = Embed3D(mu=0.25, step_size=.05, num_steps=200, logger=logger)
    # NOTE(review): unlike positive_sample, this runs outside torch.no_grad().
    # If the generated lengths carry autograd history, loss.backward() in the
    # training step will also backpropagate into model_cnf. Confirm this is
    # intended.
    gen_pos, gen_len = simple_generate_batch(
        model_cnf,
        device_batch,
        num_samples=1,
        embedder=embedder,
        dg_init_pos=batch.pos,
    )

    negative = batch.clone()
    negative.pos = gen_pos.view(batch.pos.size())
    negative.edge_length = gen_len.view(batch.edge_length.size())
    return negative

def positive_sample(batch):
    """Build a perturbed "positive" sample around the data in ``batch``.

    The edge lengths are passed through ``model_cnf.get_z``. Gaussian noise
    with standard deviation ``args.pos_sample_noise`` is added, and the
    result is mapped back with ``model_cnf.get_d``; this part runs without
    autograd. Positions are then re-embedded with ``Embed3D``, starting from
    ``batch.pos``.

    Returns:
        A clone of ``batch`` with ``pos`` and ``edge_length`` replaced by the
        perturbed values, reshaped to the original attribute sizes.
    """
    embedder = Embed3D(mu=0.25, step_size=0.05, num_steps=200, logger=logger)
    model_cnf.eval()
    with torch.no_grad():
        latent = model_cnf.get_z(batch, batch.edge_length)
        noise = torch.randn_like(latent) * args.pos_sample_noise
        lengths = model_cnf.get_d(batch, latent + noise)
    # The embedding step is deliberately left outside no_grad (presumably
    # Embed3D relies on autograd internally; verify before moving it).
    coords, _ = embedder(lengths, batch.edge_index, batch.pos, batch.edge_order)

    positive = batch.clone()
    positive.pos = coords.view(batch.pos.size())
    positive.edge_length = lengths.view(batch.edge_length.size())
    return positive


# Training step
def train_ebm(it):
    """Run one optimisation step of the EBM and log its statistics.

    A training batch is drawn and paired with a positive sample
    (``positive_sample``) and a negative sample (``negative_sample``). The
    NCE loss is minimised with gradient-norm clipping at
    ``args.max_grad_norm``. Loss, gradient norm, learning rate and every entry
    returned by ``get_loss_nce`` are then written to the logger and
    TensorBoard.

    Args:
        it: Iteration index, used as the logging step.
    """
    model_ebm.train()
    optimizer.zero_grad()

    data = next(train_iterator).to(args.device)
    # NOTE(review): the objective is always NCE here; `--loss cd/mle` are
    # accepted by the argument parser but not honoured by this function.
    loss, stats = model_ebm.get_loss_nce(
        positive_sample(data),
        negative_sample(data),
        alpha=args.ebm_alpha,
    )

    loss.backward()
    grad_norm = clip_grad_norm_(model_ebm.parameters(), args.max_grad_norm)
    optimizer.step()
    # scheduler.step()  # LR scheduler is currently disabled

    e_pos = stats['ener_pos'].mean().item()
    e_neg = stats['ener_neg'].mean().item()
    # Ener(D) is the energy gap: mean negative energy minus mean positive energy.
    logger.info(
        '[Train] Iter %04d | Loss %.6f | Grad %.6f | Ener(+) %.6f | Ener(-) %.6f | Ener(D) %.6f' % (
            it, loss.item(), grad_norm, e_pos, e_neg, e_neg - e_pos,
        )
    )

    headline = (
        ('loss', loss),
        ('grad_norm', grad_norm),
        ('lr', optimizer.param_groups[0]['lr']),
    )
    for name, value in headline:
        writer.add_scalar('train_ebm/%s' % name, value, it)
    for key, value in stats.items():
        # Tensors with more than one element are reduced to their mean.
        is_vector = isinstance(value, torch.Tensor) and value.numel() > 1
        writer.add_scalar('train_ebm/%s' % key, value.mean() if is_vector else value, it)
    writer.flush()


# Main loop: train for --max_iters iterations and save a checkpoint every
# --val_freq iterations. Ctrl-C stops training early.
logger.info('Start training...')
try:
    # Cast explicitly: range() rejects floats, and --max_iters may have been
    # parsed as one.
    for it in range(1, int(args.max_iters) + 1):
        train_ebm(it)
        if it % args.val_freq == 0:
            ckpt_mgr.save(model_ebm, args, 0, it)
except KeyboardInterrupt:
    logger.info('Terminating...')
finally:
    # Flush pending TensorBoard events and release the event file on any exit
    # path. When logging is disabled, `writer` is a BlackHole sink.
    writer.close()