# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import tqdm
from data import dataset as dset
import torchvision.models as tmodels
import tqdm
from archs import models
import os
import itertools
import glob
import pdb
import math
import time
import sys
import random

import tensorboardX as tbx
import torch.backends.cudnn as cudnn

from flags import parser
from test_modular import test as full_test
from utils import utils

args = parser.parse_args()

# Create the per-experiment directory <cv_dir>/<name> before anything is
# written for this run.
os.makedirs('/'.join((args.cv_dir, args.name)), exist_ok=True)
utils.save_args(args)

# Echo the command line followed by every parsed argument.
print(' '.join(sys.argv))
print("")
print("Arguments:")
for arg_name, arg_value in vars(args).items():
    print("{} = {!r}".format(arg_name, arg_value))
print("", flush=True)

# Seed Python's, PyTorch's and NumPy's random number generators with the
# same user-supplied seed.
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(args.seed)


def adjust_learning_rate(args, optimizer, epoch):
    """Decay the learning rate of every parameter group at the first milestone.

    When ``epoch`` equals ``args.steps[0]`` the ``'lr'`` entry of each group in
    ``optimizer.param_groups`` is multiplied by 0.1; on every other epoch it is
    multiplied by 1.0. Only the first entry of ``args.steps`` is consulted.
    An empty or ``None`` ``args.steps`` means "never decay" instead of raising.

    Args:
        args: object exposing ``steps``, a sequence of epoch milestones.
        optimizer: object exposing ``param_groups``, an iterable of dicts that
            each hold an ``'lr'`` entry.
        epoch: number of the epoch that is about to run.
    """
    steps = args.steps
    has_milestone = steps is not None and len(steps) > 0
    factor = 0.1 if has_milestone and epoch == steps[0] else 1.0
    for group in optimizer.param_groups:
        group['lr'] *= factor


#----------------------------------------------------------------#
def train(epoch):
    """Train ``model`` for one epoch over ``trainloader``.

    Gradients are accumulated over ``args.batch_accumulation`` consecutive
    batches before each optimizer step. A trailing, incomplete group of
    batches at the end of the epoch is stepped as well, so its gradients are
    not thrown away by the ``zero_grad`` that opens the next group.

    Running averages of the main loss, the auxiliary loss (when the model
    reports one) and the accuracy (when the model reports one) are printed for
    the first three iterations and every ``args.print_every`` iterations.
    At the end of the epoch the mean per-batch total loss is written to
    ``logger`` as ``train_loss``, and every entry of the *last* batch's loss
    dict is written as ``train_<key>``.

    Args:
        epoch: epoch number, used in the printed output and as the logging step.

    Returns:
        Tuple ``(train_loss, acc)``: the total loss summed over batches and
        divided by the number of batches, and the running average accuracy.
    """
    model.train()
    lossmeter = utils.AverageMeter()
    lossauxmeter = utils.AverageMeter()
    accmeter = utils.AverageMeter()

    train_loss = 0.0
    n_batches_accumulated = 0
    for idx, data in enumerate(trainloader):
        data = [d.cuda() for d in data]
        batch_size = data[0].shape[0]
        loss, all_losses, acc, _ = model(data)
        if acc is not None:
            accmeter.update(acc, batch_size)

        lossmeter.update(all_losses['main_loss'].item(), batch_size)
        if 'aux_loss' in all_losses:
            lossauxmeter.update(all_losses['aux_loss'].item(), batch_size)

        # Gradients are only cleared at the start of an accumulation group.
        if n_batches_accumulated == 0:
            optimizer.zero_grad()

        loss.backward()
        n_batches_accumulated += 1

        if n_batches_accumulated >= args.batch_accumulation:
            optimizer.step()
            n_batches_accumulated = 0

        train_loss += loss.item()
        if idx % args.print_every == 0 or idx <= 2:
            print(
                'Epoch: {} Iter: {}/{} | Loss: {:.3f}, Loss Aux: {:.3f}, Acc: {:.2f}'.
                format(
                    epoch, idx, len(trainloader), lossmeter.avg, lossauxmeter.avg, accmeter.avg
                ),
            )
            print(
                ','.join(
                    ['{}: {:.02f}  '.format(k, v.item()) for k, v in all_losses.items()]
                ),
                flush=True,
            )

    # Apply gradients left over from an incomplete accumulation group;
    # otherwise the next epoch's first zero_grad() would silently drop them.
    if n_batches_accumulated > 0:
        optimizer.step()

    train_loss = train_loss / len(trainloader)
    logger.add_scalar('train_loss', train_loss, epoch)
    # NOTE: these are the loss components of the final batch only.
    for k, v in all_losses.items():
        logger.add_scalar('train_{}'.format(k), v.item(), epoch)
    print(
        'Epoch: {} | Loss: {} | Acc: {}'.format(epoch, lossmeter.avg, accmeter.avg),
        flush=True,
    )
    return train_loss, accmeter.avg


def test(epoch):
    """Evaluate ``model`` on ``valloader`` with a fixed bias and print a summary.

    The model is put in eval mode and run over every batch of ``valloader``
    with gradient tracking disabled. The per-batch prediction dicts and the
    attribute/object labels (``data[1]`` and ``data[2]``) are concatenated and
    passed to ``evaluator_val.score_model`` with a bias of ``1e3`` (also stored
    in ``args.bias``) and then to ``evaluator_val.evaluate_predictions``.
    The seven returned match statistics are averaged and printed on one line.

    Args:
        epoch: epoch number shown in the printed summary.

    Returns:
        An empty list; the metrics are only printed, not returned.
    """
    model.eval()

    all_attr_lab = []
    all_obj_lab = []
    all_pred = []
    # Inference only: without no_grad every stored prediction would keep its
    # autograd graph alive for the whole evaluation.
    with torch.no_grad():
        for idx, data in enumerate(valloader):
            data = [d.cuda() for d in data]
            attr_truth, obj_truth = data[1], data[2]
            _, _, _, predictions = model(data)
            # The last model output is a (predictions, features) pair; only
            # the predictions are needed here.
            predictions, _ = predictions
            all_pred.append(predictions)
            all_attr_lab.append(attr_truth)
            all_obj_lab.append(obj_truth)

            if idx % 100 == 0:
                print('Tested {}/{}'.format(idx, len(valloader)), flush=True)
    all_attr_lab = torch.cat(all_attr_lab)
    all_obj_lab = torch.cat(all_obj_lab)

    # Merge the per-batch prediction dicts into one dict of full tensors.
    all_pred_dict = {
        k: torch.cat([batch_pred[k] for batch_pred in all_pred])
        for k in all_pred[0]
    }

    bias = 1e3
    args.bias = bias
    results = evaluator_val.score_model(
        all_pred_dict, all_obj_lab, bias=args.bias)
    match_stats = evaluator_val.evaluate_predictions(results, all_attr_lab,
                                                     all_obj_lab)
    # Seven per-sample match tensors, reduced to their means.
    (
        attr_acc,
        obj_acc,
        closed_acc,
        open_acc,
        objoracle_acc,
        open_seen_acc,
        open_unseen_acc,
    ) = [torch.mean(stat) for stat in match_stats]

    meanAP = 0  # not computed here; printed as a placeholder
    open_hm = (open_seen_acc * open_unseen_acc)**0.5
    open_avg = 0.5 * (open_seen_acc + open_unseen_acc)
    print(
        '(%s) E: %d | A: %.3f | O: %.3f | Cl: %.3f | Op: %.4f | OpHM: %.4f | OpAvg: %.4f | OpSeen: %.4f | OpUnseen: %.4f  | OrO: %.4f | maP: %.4f | bias: %.3f'
        % (
            args.test_set,
            epoch,
            attr_acc,
            obj_acc,
            closed_acc,
            open_acc,
            open_hm,
            open_avg,
            open_seen_acc,
            open_unseen_acc,
            objoracle_acc,
            meanAP,
            bias,
        ),
        flush=True,
    )
    return []


#----------------------------------------------------------------#
print("Building datasets...", flush=True)

# Training split: batches are shuffled.
trainset = dset.CompositionDatasetActivations(
    phase='train',
    root=args.data_dir,
    split=args.splitname,
    pair_dropout=args.pair_dropout,
    num_negs=args.num_negs,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    num_workers=args.workers,
    batch_size=args.batch_size,
    shuffle=True,
)

# Evaluation split (chosen by --test_set): iterated in a fixed order.
valset = dset.CompositionDatasetActivations(
    phase=args.test_set,
    root=args.data_dir,
    split=args.splitname,
    subset=args.subset,
)
valloader = torch.utils.data.DataLoader(
    valset,
    num_workers=args.workers,
    batch_size=args.test_batch_size,
    shuffle=False,
)

print("Building model...", flush=True)

# Only one architecture is supported by this script.
if args.model != 'modularpretrained':
    raise NotImplementedError
model = models.GatedGeneralNN(
    trainset,
    args,
    num_layers=args.nlayers,
    num_modules_per_layer=args.nmods,
    gater_type=args.gater_type,
)

evaluator_train = models.Evaluator(trainset, model)
evaluator_val = models.Evaluator(valset, model)

if 'modular' in args.model:
    # Split the trainable parameters into the gating network and everything
    # else, so the gating network can use its own learning rate (args.lrg).
    trainable_named = [
        (pname, p) for pname, p in model.named_parameters() if p.requires_grad
    ]
    gating_params = [p for pname, p in trainable_named if 'gating_network' in pname]
    network_params = [p for pname, p in trainable_named if 'gating_network' not in pname]
    optim_params = [
        {'params': network_params},
        {'params': gating_params, 'lr': args.lrg},
    ]
    optimizer = optim.Adam(optim_params, lr=args.lr, weight_decay=args.wd)
else:
    # Single parameter group containing every trainable parameter.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)

model.cuda()

# Defaults for a run that starts from scratch.
start_epoch = 0
best_epoch = -1
best_auc_val = -1
if args.load is None:
    # No checkpoint requested: look in this experiment's directory for a
    # numbered checkpoint, scanning epoch numbers from 1000 down to 0, and
    # resume from the first match.
    for epoch in range(1000, -1, -1):
        ckpt_files = glob.glob(args.cv_dir +
                               '/{}/ckpt_E_{}*.t7'.format(args.name, epoch))
        if ckpt_files:
            args.load = ckpt_files[-1]
            break
if args.load and os.path.isfile(args.load):
    print("loading checkpoint {}".format(args.load), flush=True)
    checkpoint = torch.load(args.load)
    model.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    best_auc_val = checkpoint['best_auc_val']
    # Use start_epoch as the default so early stopping will still work if
    # best_epoch was not stored in the checkpoint.
    best_epoch = checkpoint.get("best_epoch", start_epoch)
    print(
        "loaded checkpoint from {} (epoch {})".format(args.load, checkpoint["epoch"]),
        flush=True,
    )
elif args.load:
    # Report the path that was actually checked (args.load).
    print("no checkpoint found at '{}'".format(args.load))

# Print the architecture followed by parameter counts for the whole model.
print("Model architecture:")
print(model)
print()
n_params_total = utils.count_parameters(model, only_trainable=False)
print("Number of model parameters (total):     {}".format(n_params_total))
n_params_trainable = utils.count_parameters(model, only_trainable=True)
print("Number of model parameters (trainable): {}".format(n_params_trainable))

# Per-submodule breakdown, for whichever of these attributes the model has.
for submodel_name in ("attr_embedder", "obj_embedder", "comp_network", "gating_network"):
    if hasattr(model, submodel_name):
        submodel = getattr(model, submodel_name)
        submodel_total = utils.count_parameters(submodel, only_trainable=False)
        submodel_trainable = utils.count_parameters(submodel, only_trainable=True)
        print(
            "model.{} has {} parameters, of which {} are trainable".format(
                submodel_name, submodel_total, submodel_trainable
            )
        )
print()

logger = tbx.SummaryWriter('logs/{}'.format(args.name))
if args.test_only:
    # Evaluation needs no autograd graph; this matches how validation is run
    # inside the training loop below.
    with torch.no_grad():
        out = test(start_epoch)
else:
    run_dir = os.path.join(args.cv_dir, args.name)
    train_log_fname = os.path.join(run_dir, "train.csv")
    # Write a fresh CSV log (header only) for a new run or when the log file
    # is missing; a resumed run with an existing log keeps appending to it.
    if start_epoch == 0 or not os.path.isfile(train_log_fname):
        with open(train_log_fname, "w") as f:
            f.write("data_dir,splitname,model,nparams,compose_type,module_type,nlayers,nmods,module_emb,module_hidden,module_actfun,gater_type,gater_emb,gater_actfun,lr,lrg,wd,nnegs,pdrop,pdropep,epoch,seed,k,loss_train,acc_train,auc_{}\n".format(args.test_set))
    for epoch in range(start_epoch + 1, args.max_epochs + 1):
        adjust_learning_rate(args, optimizer, epoch)
        if epoch % args.pair_dropout_epoch == 0:
            trainloader.dataset.reset_dropout()
        loss_train, acc_train = train(epoch)
        with torch.no_grad():
            _, auc_val = full_test(epoch, model, args, valloader, evaluator_val)

        # Ties count as a new best, so the most recent epoch wins.
        is_best = auc_val >= best_auc_val
        if is_best:
            best_auc_val = auc_val
            best_epoch = epoch
            print("  New best validation AUC: {} (seen on epoch {})".format(best_auc_val, epoch))
        else:
            print("  ({}) Val AUC {} < {} (ep {})".format(epoch, auc_val, best_auc_val, best_epoch))

        if epoch % args.save_every == 0 or is_best or epoch >= args.max_epochs:
            state = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'auc_val': auc_val,
                'best_auc_val': best_auc_val,
                'best_epoch': best_epoch,
            }
            ckpt_path = os.path.join(run_dir, "ckpt_E_{}.t7".format(epoch))
            print("  Writing checkpoint to {}".format(ckpt_path))
            torch.save(state, ckpt_path)
            if is_best:
                # Write to a temporary file and rename it over ckpt_best.t7 so
                # a partially written file never carries the final name.
                tmp_name = ckpt_path + ".{}.tmp".format(time.time())
                torch.save(state, tmp_name)
                os.rename(tmp_name, os.path.join(run_dir, "ckpt_best.t7"))
                print("  Overwriting best checkpoint with current model")
            # Repoint ckpt_latest.t7 (a relative symlink) the same way:
            # create the link under a temporary name, then rename it.
            tmp_link = ckpt_path + ".{}.tmp".format(time.time())
            os.symlink(os.path.basename(ckpt_path), tmp_link)
            os.rename(tmp_link, os.path.join(run_dir, "ckpt_latest.t7"))
            # Drop the numbered checkpoint that is num_checkpoints epochs old.
            if args.num_checkpoints is not None and args.num_checkpoints > 0 and epoch > args.num_checkpoints:
                ckpt_path_old = os.path.join(
                    run_dir,
                    "ckpt_E_{}.t7".format(epoch - args.num_checkpoints),
                )
                if os.path.exists(ckpt_path_old):
                    os.remove(ckpt_path_old)

        # One CSV row per epoch, in the same column order as the header.
        row = (
            args.data_dir,
            args.splitname,
            args.model,
            utils.count_parameters(model, only_trainable=False),
            args.compose_type,
            args.module_type,
            args.nlayers,
            args.nmods,
            args.emb_dim,
            args.module_hidden,
            args.module_actfun,
            args.gater_type,
            args.embed_rank,
            args.gater_actfun,
            args.lr,
            args.lrg,
            args.wd,
            args.num_negs,
            args.pair_dropout,
            args.pair_dropout_epoch,
            epoch,
            args.seed,
            args.topk,
            loss_train,
            acc_train,
            auc_val,
        )
        with open(train_log_fname, "a") as f:
            f.write(",".join("{}".format(value) for value in row) + "\n")

        # Early stopping: give up once more than args.early_stopping epochs
        # have passed since the best epoch (a negative value disables it).
        if (
            not is_best
            and best_epoch >= 0
            and args.early_stopping >= 0
            and epoch - best_epoch > args.early_stopping
        ):
            print(
                "Stopping early at epoch {} because validation performance ({})"
                " has not increased since epoch {}, where it was {}.".format(
                    epoch, auc_val, best_epoch, best_auc_val
                )
            )
            break
