#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import copy

import argparse
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import ot
import torch
import torch.nn as nn
import torch.optim as optim
from pytorch_pretrained_bert import BertAdam
import torch.fft
from src.data.helpers import get_data_loaders
from src.models import get_model
from src.utils.logger import create_logger
from src.utils.utils import *
import geomloss
from scipy.stats import wasserstein_distance 

def get_args(parser):
    """Register every command-line option of the training script on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` that is extended in place; nothing
            is returned.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string -- including
        # "False" and "0" -- to True, so boolean spellings are parsed here.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value)
        )

    parser.add_argument("--batch_sz", type=int, default=128)
    parser.add_argument("--bert_model", type=str, default="../bert-base-uncased")
    parser.add_argument("--data_path", type=str, default="../../../../dataset")
    parser.add_argument("--drop_img_percent", type=float, default=0.0)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--embed_sz", type=int, default=300)
    parser.add_argument("--freeze_img", type=int, default=0)
    parser.add_argument("--freeze_txt", type=int, default=0)
    parser.add_argument("--glove_path", type=str, default="./datasets/glove_embeds/glove.840B.300d.txt")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=24)
    parser.add_argument("--hidden", nargs="*", type=int, default=[])
    parser.add_argument("--hidden_sz", type=int, default=768)
    parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
    parser.add_argument("--img_hidden_sz", type=int, default=2048)
    parser.add_argument("--include_bn", type=int, default=True)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_factor", type=float, default=0.5)
    parser.add_argument("--lr_patience", type=int, default=2)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--max_seq_len", type=int, default=512)
    parser.add_argument("--model", type=str, default="bow", choices=["bow", "img", "bert", "concatbow", "concatbert", "mmbt", "latefusion", "intermediate"])
    parser.add_argument("--n_workers", type=int, default=12)
    parser.add_argument("--name", type=str, default="nameless")
    parser.add_argument("--num_image_embeds", type=int, default=1)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--savedir", type=str, default="/path/to/save_dir/")
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--task", type=str, default="mmimdb", choices=["mmimdb", "vsnli", "food101", "MVSA_Single", "HFM", "MVSA_Multiple"])
    parser.add_argument("--task_type", type=str, default="multilabel", choices=["multilabel", "classification"])
    parser.add_argument("--warmup", type=float, default=0.1)
    parser.add_argument("--weight_classes", type=int, default=1)
    parser.add_argument("--df", type=_str2bool, default=False)
    parser.add_argument("--noise", type=float, default=0.0)
    parser.add_argument(
        "--hyperparameter",
        type=float,
        nargs='+',
        help="loss weights used in training: the OT weight, then the mutual-information weight",
    )


def cs(x1, target_shape=64):
    """Filter ``x1`` in the frequency domain, then randomly project its last dim.

    The input is FFT-transformed over its last dimension. Every coefficient
    whose magnitude is below 1 is zeroed. The real part of the inverse FFT is
    then multiplied by a freshly drawn ``(x1.shape[-1], target_shape)`` matrix
    of uniform [0, 1) entries.

    Args:
        x1: real-valued tensor; its last dimension is projected.
        target_shape: size of the projected last dimension.

    Returns:
        Tensor with ``x1``'s leading dimensions and last dimension
        ``target_shape``, on ``x1``'s device and with ``x1``'s dtype.

    Note:
        Every call draws a new projection matrix from the global (CPU) torch
        RNG, so two calls project with different matrices.
    """
    spectrum = torch.fft.fft(x1)
    spectrum[torch.abs(spectrum) < 1] = 0
    # The inverse FFT is complex. Take the real part explicitly; a
    # complex->real cast would drop the imaginary part anyway, but with a warning.
    filtered = torch.fft.ifft(spectrum).real.to(dtype=x1.dtype)
    # Draw on the CPU exactly as before so the RNG stream is unchanged. Then
    # follow the input's device/dtype instead of assuming CUDA and float32.
    projection = torch.rand(x1.shape[-1], target_shape, dtype=torch.float32)
    projection = projection.to(device=filtered.device, dtype=filtered.dtype)
    return torch.matmul(filtered, projection)

def rand_projections(dim, num_projections=1000):
    """Sample ``num_projections`` random unit-length directions in ``dim`` dims.

    Rows are standard-normal draws from the global torch RNG, each divided by
    its own L2 norm.

    Returns:
        Tensor of shape ``(num_projections, dim)`` on the default tensor device.
    """
    gaussian = torch.randn((num_projections, dim))
    row_norms = gaussian.pow(2).sum(dim=1, keepdim=True).sqrt()
    return gaussian / row_norms

def sliced_wasserstein_distance(first_samples,
                                second_samples,
                                num_projection=1000,
                                p=2,
                                device='cuda'):
    """Monte-Carlo sliced Wasserstein-p distance between two point sets.

    Both sample matrices are projected onto ``num_projection`` random unit
    directions from ``rand_projections`` (moved to ``device``). Along each
    direction the sorted projections are subtracted element-wise, so the two
    sets are expected to hold the same number of samples.

    Per direction, the p-norm of those gaps is taken; this is a sum over
    samples, not a mean. The result is the p-th root of the mean over
    directions of those norms raised to the p.
    """
    directions = rand_projections(second_samples.size(1), num_projection).to(device).transpose(0, 1)
    slices_a, _ = torch.sort(first_samples.matmul(directions).transpose(0, 1), dim=1)
    slices_b, _ = torch.sort(second_samples.matmul(directions).transpose(0, 1), dim=1)
    gaps = (slices_a - slices_b).abs()
    per_direction = gaps.pow(p).sum(dim=1).pow(1. / p)
    return per_direction.pow(p).mean().pow(1. / p)

def get_criterion(args):
    """Build the loss module matching ``args.task_type``.

    Multilabel tasks use ``BCEWithLogitsLoss``. When ``args.weight_classes`` is
    truthy, each label's positive weight is ``train_data_len / label_freq``
    (the weights are moved to CUDA). Every other task type uses
    ``CrossEntropyLoss``.
    """
    if args.task_type != "multilabel":
        return nn.CrossEntropyLoss()
    if not args.weight_classes:
        return nn.BCEWithLogitsLoss()
    counts = torch.FloatTensor([args.label_freqs[label] for label in args.labels])
    pos_weight = (counts / args.train_data_len) ** -1
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight.cuda())


def get_optimizer(model, args):
    """Create the optimizer for ``model``.

    BERT-based models ("bert", "concatbert", "mmbt") get ``BertAdam`` with
    ``warmup=args.warmup`` and ``t_total`` set to the number of optimizer steps
    in the whole run. They also get weight decay 0.01 on every parameter whose
    name contains none of the no-decay tokens (biases, LayerNorm weights). All
    other models get plain ``optim.Adam`` at ``args.lr``.
    """
    if args.model not in ["bert", "concatbert", "mmbt"]:
        return optim.Adam(model.parameters(), lr=args.lr)

    # One optimizer step per group of accumulated batches, over all epochs.
    total_steps = (
        args.train_data_len / args.batch_sz / args.gradient_accumulation_steps * args.max_epochs
    )
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(token in name for token in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    grouped = [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return BertAdam(grouped, lr=args.lr, warmup=args.warmup, t_total=total_steps)


def get_scheduler(optimizer, args):
    """Return a ``ReduceLROnPlateau`` scheduler for a metric that should increase.

    Uses ``mode="max"``, ``patience=args.lr_patience`` and
    ``factor=args.lr_factor``.

    ``verbose`` is intentionally not passed. It has been deprecated for LR
    schedulers since PyTorch 2.2 and is removed in later releases, where
    passing it makes construction fail.
    """
    return optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=args.lr_patience, factor=args.lr_factor
    )


def model_eval(i_epoch, data, model, args, criterion, store_preds=False):
    """Compute validation/test metrics for ``model`` over the batches in ``data``.

    Every batch runs through ``model_forward(..., mode='eval')`` under
    ``torch.no_grad()``; ``criterion`` is forwarded in ``model_forward``'s
    ``OT`` position. This function does not call ``model.eval()`` itself.

    Args:
        i_epoch: epoch index forwarded to ``model_forward``.
        data: iterable of batches.
        model: network to evaluate.
        args: parsed CLI arguments; ``args.task_type`` selects the metrics.
        criterion: forwarded to ``model_forward``.
        store_preds: when True, targets and predictions are passed to
            ``store_preds_to_disk``.

    Returns:
        dict with the mean batch ``"loss"``, plus ``"macro_f1"`` and
        ``"micro_f1"`` for multilabel tasks (predictions are
        ``sigmoid(logits) > 0.5``). For other tasks it adds ``"acc"``
        (predictions are the argmax class).
    """
    multilabel = args.task_type == "multilabel"
    batch_losses, batch_preds, batch_tgts = [], [], []
    with torch.no_grad():
        for batch in data:
            loss, logits, target = model_forward(i_epoch, model, args, criterion, batch, mode='eval')
            batch_losses.append(loss.item())
            if multilabel:
                batch_preds.append(torch.sigmoid(logits).cpu().detach().numpy() > 0.5)
            else:
                probs = torch.nn.functional.softmax(logits, dim=1)
                batch_preds.append(probs.argmax(dim=1).cpu().detach().numpy())
            batch_tgts.append(target.cpu().detach().numpy())

    metrics = {"loss": np.mean(batch_losses)}
    if multilabel:
        tgts, preds = np.vstack(batch_tgts), np.vstack(batch_preds)
        metrics["macro_f1"] = f1_score(tgts, preds, average="macro")
        metrics["micro_f1"] = f1_score(tgts, preds, average="micro")
    else:
        # Flatten the per-batch arrays into flat label lists.
        tgts = [t for chunk in batch_tgts for t in chunk]
        preds = [p for chunk in batch_preds for p in chunk]
        metrics["acc"] = accuracy_score(tgts, preds)

    if store_preds:
        store_preds_to_disk(tgts, preds, args)

    return metrics


def rank_loss(confidence, idx, history):
    """Margin ranking loss pairing each batch sample with its right-hand neighbour.

    Sample ``i`` is compared with sample ``i + 1``, wrapping around via
    ``torch.roll``. ``history.get_target_margin(idx, idx_neighbour)`` supplies
    a target and a margin per pair. ``margin / target`` (with a zero target
    treated as 1) is added to the neighbour's confidence. The result is then
    scored with ``MarginRankingLoss(margin=0)`` against the negated target.

    Args:
        confidence: per-sample confidences; presumably shaped (batch, 1), given
            the ``reshape((-1, 1))`` below -- TODO confirm.
        idx: dataset indices of the batch samples.
        history: object exposing ``get_target_margin(idx1, idx2)`` that returns
            ``(target, margin)`` tensors.

    Returns:
        Scalar ranking loss.
    """
    partner_conf = torch.roll(confidence, -1)
    partner_idx = torch.roll(idx, -1)

    target, margin = history.get_target_margin(idx, partner_idx)
    # Avoid dividing by zero where the pair has no ordering target.
    safe_target = torch.where(target == 0, torch.ones_like(target), target)
    shifted_partner = partner_conf + (margin / safe_target).reshape((-1, 1))

    criterion = nn.MarginRankingLoss(margin=0.0)
    return criterion(confidence, shifted_partner, -target.reshape(-1, 1))

def loss_function(ce_loss, y_pred, y, mu, std, kl_weight=1e-3):
    """Classification loss plus a weighted Gaussian KL regulariser.

    Returns ``ce_loss(y_pred, y) + kl_weight * KL``, where
    ``KL = 0.5 * mean(mu**2 + std**2 - 2*log(std) - 1)`` is the element-wise
    mean of KL(N(mu, std**2) || N(0, 1)).

    Args:
        ce_loss: classification criterion, called as ``ce_loss(y_pred, y)``.
        y_pred: predictions passed straight to ``ce_loss``.
        y: targets passed straight to ``ce_loss``.
        mu: mean tensor of the Gaussian.
        std: standard-deviation tensor; must be strictly positive because its
            log is taken.
        kl_weight: weight of the KL term. Defaults to 1e-3, the value that was
            previously hard-coded.
    """
    ce = ce_loss(y_pred, y)
    kl = 0.5 * torch.mean(mu.pow(2) + std.pow(2) - 2 * std.log() - 1)
    return kl_weight * kl + ce

def model_forward(i_epoch, model, args, OT, batch, txt_history=None, img_history=None, mode='eval'):
    """Run one batch through a fusion model and compute its loss.

    Only the ``latefusion`` and ``intermediate`` models are supported: every
    loss term needs the per-modality logits and confidences they return.
    Training (``mode='train'``) additionally needs the features and
    distributions that only ``intermediate`` returns.

    Args:
        i_epoch: current epoch index; unused, kept so callers' calls still match.
        model: fusion network, called as ``model(txt, mask, segment, img)``.
        args: parsed CLI arguments. Reads ``model``; in training it also reads
            ``hyperparameter`` as ``[ot_weight, mi_weight]``.
        OT: callable ``OT(x, y)`` returning the transport cost between the
            projected text and image features; only used in training.
        batch: tuple ``(txt, segment, mask, img, tgt, idx)``.
        txt_history, img_history: objects updated via ``correctness_update``
            and read by ``rank_loss``; required in training.
        mode: ``'train'`` adds the ranking, OT and mutual-information terms;
            any other value returns the classification loss only.

    Returns:
        ``(loss, txt_img_logits, tgt)``, with ``tgt`` moved to CUDA.

    Raises:
        NotImplementedError: for a model other than ``latefusion`` or
            ``intermediate``, or for ``latefusion`` in training mode.
        ValueError: in training, when ``args.hyperparameter`` does not supply
            the two loss weights.
    """
    txt, segment, mask, img, tgt, idx = batch
    is_train = mode == 'train'

    # Validate before any GPU work. Previously the single-output models (bow,
    # img, bert, concatbow, concatbert, mmbt) ran their forward pass and then
    # failed with an UnboundLocalError on txt_logits.
    if args.model not in ("latefusion", "intermediate"):
        raise NotImplementedError(
            "model_forward supports only the 'latefusion' and 'intermediate' "
            "models, got {!r}".format(args.model)
        )
    if is_train and args.model != "intermediate":
        raise NotImplementedError(
            "training requires the 'intermediate' model, which also returns the "
            "features and distributions used by the OT and mutual-information "
            "losses; got {!r}".format(args.model)
        )
    if is_train and (args.hyperparameter is None or len(args.hyperparameter) < 2):
        raise ValueError(
            "training needs --hyperparameter OT_WEIGHT MI_WEIGHT, "
            "got {!r}".format(args.hyperparameter)
        )

    txt, img = txt.cuda(), img.cuda()
    mask, segment = mask.cuda(), segment.cuda()
    if args.model == "latefusion":
        # QMF uses the same late-fusion backbone; only the fusion weights are dynamic.
        txt_img_logits, txt_logits, img_logits, txt_conf, img_conf = model(txt, mask, segment, img)
    else:
        (txt_img_logits, txt_logits, img_logits, txt_conf, img_conf,
         txt_feature_hd, img_feature_hd, txt_distribution, img_distribution,
         multimodal) = model(txt, mask, segment, img)

    tgt = tgt.cuda()

    ce = nn.CrossEntropyLoss()
    # Text-only + image-only + fused classification losses.
    clf_loss = ce(txt_logits, tgt) + ce(img_logits, tgt) + ce(txt_img_logits, tgt)

    if not is_train:
        return clf_loss, txt_img_logits, tgt

    # Detached per-sample losses feed the confidence histories that rank_loss reads.
    per_sample_ce = nn.CrossEntropyLoss(reduction='none')
    txt_loss = per_sample_ce(txt_logits, tgt).detach()
    img_loss = per_sample_ce(img_logits, tgt).detach()
    txt_history.correctness_update(idx, txt_loss, txt_conf.squeeze())
    img_history.correctness_update(idx, img_loss, img_conf.squeeze())

    # Each distribution tuple is passed to loss_function as (y_pred, mu, std).
    mutual_information_loss = (
        loss_function(ce, multimodal[0], tgt, multimodal[1], multimodal[2])
        + loss_function(ce, txt_distribution[0], tgt, txt_distribution[1], txt_distribution[2])
        + loss_function(ce, img_distribution[0], tgt, img_distribution[1], img_distribution[2])
    )

    # cs() draws a fresh random projection on every call, so the txt -> img
    # order is kept to preserve the RNG stream.
    # NOTE(review): the two feature sets are therefore projected with
    # independent random matrices before OT compares them; confirm intended.
    txt_feature_ld = cs(txt_feature_hd)
    img_feature_ld = cs(img_feature_hd)
    wasserstein_distance_ot = OT(txt_feature_ld, img_feature_ld)

    crl_loss = rank_loss(txt_conf, idx, txt_history) + rank_loss(img_conf, idx, img_history)

    ot_weight, mi_weight = args.hyperparameter[0], args.hyperparameter[1]
    loss = torch.mean(clf_loss + crl_loss) + ot_weight * wasserstein_distance_ot + mi_weight * mutual_information_loss
    return loss, txt_img_logits, tgt


def train(args):
    """Train the configured model, keep the best validation checkpoint, then test it.

    Outputs go under ``args.savedir/[str(args.hyperparameter)/]args.name``: the
    log file, ``args.pt``, and the best checkpoint via ``save_checkpoint``
    (presumably as ``best_checkpoint.pt``, which the resume check below looks
    for -- TODO confirm). If that file already exists, model, optimizer,
    scheduler and early-stopping counters are resumed from it.

    Raises:
        RuntimeError: if no epoch improved the validation metric and no
            checkpoint was resumed, so there is no best model to evaluate.
    """
    set_seed(args.seed)

    if args.hyperparameter:
        args.savedir = os.path.join(os.path.join(args.savedir, str(args.hyperparameter)), args.name)
    else:
        args.savedir = os.path.join(args.savedir, args.name)

    os.makedirs(args.savedir, exist_ok=True)

    train_loader, val_loader, test_loaders = get_data_loaders(args)

    model = get_model(args)

    criterion = get_criterion(args)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    logger = create_logger("%s/logfile.log" % args.savedir, args)
    logger.info(model)
    model.cuda()

    torch.save(args, os.path.join(args.savedir, "args.pt"))

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf
    p, entreg = 1, 0.1
    # Stays None until an epoch improves on best_metric or a checkpoint is
    # resumed; previously it was an unbound name in that case.
    best_checkpoint = None

    checkpoint_path = os.path.join(args.savedir, "best_checkpoint.pt")
    if os.path.exists(checkpoint_path):
        # NOTE(review): torch.load defaults to weights_only=True on PyTorch >= 2.6;
        # confirm the saved checkpoint (scheduler state, metrics) loads under that.
        best_checkpoint = torch.load(checkpoint_path)
        start_epoch = best_checkpoint["epoch"]
        n_no_improve = best_checkpoint["n_no_improve"]
        best_metric = best_checkpoint["best_metric"]
        model.load_state_dict(best_checkpoint["state_dict"])
        optimizer.load_state_dict(best_checkpoint["optimizer"])
        scheduler.load_state_dict(best_checkpoint["scheduler"])

    logger.info("Training..")
    txt_history = History(len(train_loader.dataset))
    img_history = History(len(train_loader.dataset))

    # Sinkhorn OT loss between projected text/image features. It was rebuilt
    # with identical arguments every epoch; it is loop-invariant, so build it once.
    OTLoss = geomloss.SamplesLoss(
        loss='sinkhorn', p=p,
        cost=geomloss.utils.distances,
        blur=entreg**(1/p), backend='tensorized')

    # NOTE(review): n_no_improve is tracked and checkpointed, but args.patience
    # is never consulted, so training always runs to args.max_epochs.
    for i_epoch in range(start_epoch, args.max_epochs):
        logger.info("This is the {}-th epoch:".format(i_epoch + 1))
        train_losses = []
        model.train()
        # NOTE(review): global_step carries across epochs, so gradients
        # accumulated since the last optimizer.step() in the previous epoch are
        # discarded here whenever len(train_loader) is not a multiple of
        # args.gradient_accumulation_steps.
        optimizer.zero_grad()

        for batch in tqdm(train_loader, total=len(train_loader)):
            loss, _, _ = model_forward(i_epoch, model, args, OTLoss, batch, txt_history, img_history, mode='train')

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            train_losses.append(loss.item())
            loss.backward()
            global_step += 1
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        model.eval()
        metrics = model_eval(i_epoch, val_loader, model, args, criterion)
        logger.info("Train Loss: {:.4f}".format(np.mean(train_losses)))
        log_metrics("Val", metrics, args, logger)

        tuning_metric = (
            metrics["micro_f1"] if args.task_type == "multilabel" else metrics["acc"]
        )
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric

        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

        current_epoch = {
            "epoch": i_epoch + 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "n_no_improve": n_no_improve,
            "best_metric": best_metric,
        }

        logger.info("is_improvement: {}".format(is_improvement))

        if is_improvement:
            # state_dict() tensors share storage with the live model/optimizer,
            # so snapshot them before training continues.
            best_checkpoint = copy.deepcopy(current_epoch)

        # Periodically persist the best checkpoint found so far (every 15th epoch index).
        if i_epoch % 15 == 0 and best_checkpoint is not None:
            save_checkpoint(best_checkpoint, args.savedir)

    if best_checkpoint is None:
        raise RuntimeError(
            "No best checkpoint: no epoch improved the validation metric "
            "(is it NaN?) and no checkpoint was resumed from {}".format(checkpoint_path)
        )

    logger.info("Save the best point.")

    save_checkpoint(best_checkpoint, args.savedir)

    logger.info("The best point has been saved.")
    logger.info("This is the best result: {}".format(best_checkpoint['best_metric']))
    logger.info("The best epoch is : {}".format(best_checkpoint["epoch"]))

    model.load_state_dict(best_checkpoint["state_dict"])
    model.eval()

    for test_name, test_loader in test_loaders.items():
        test_metrics = model_eval(
            np.inf, test_loader, model, args, criterion, store_preds=True
        )
        log_metrics(f"Test - {test_name}", test_metrics, args, logger)


def cli_main():
    """Parse the command line and start training; unrecognised flags abort the run."""
    parser = argparse.ArgumentParser(description="Train Models")
    get_args(parser)
    parsed_args, unknown = parser.parse_known_args()
    assert not unknown, unknown
    train(parsed_args)


if __name__ == "__main__":
    from warnings import filterwarnings

    # Suppress every warning category for the whole run before training starts.
    filterwarnings("ignore")
    cli_main()