import os

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
import datetime
import pickle
from torch.utils.data import DataLoader, ConcatDataset, TensorDataset
from tqdm import tqdm

from datasets import load_dataset
from data_loader import DynaSent2
from model import load_backbone, Classifier_pref, Classifier_pref_ensemble
from common import parse_args
from sampling import disagreement_sampling, inconsistency_sampling
from utils import Logger, set_seed, set_model_path, save_model, AverageMeter, cut_input, ECE

# Default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Checkpoint directory path (not referenced in the visible part of this file).
CKPT_PATH = './checkpoint'

def main():
    """Script entry point: train a classifier with pairwise-preference supervision.

    Parses command-line arguments and builds the backbone plus classifier head
    (an ensemble of preference heads when ``args.ensemble`` is set). For each
    epoch it rebuilds the training loader with ``set_loader``, trains one epoch
    with ``train_preference`` and runs validation-based model selection with
    ``eval_func``. Finally it logs the test accuracy measured at the
    best-validation epoch.
    """
    args = parse_args()

    # Set seed
    set_seed(args)
    log_name = f"{args.dataset}_{args.train_type}_{args.n_layers}{args.activation}_B{args.batch_size}_S{args.seed}"

    logger = Logger(log_name)
    log_dir = logger.logdir

    logger.log(args)
    logger.log(log_name)

    logger.log('Loading pre-trained backbone network... ({})'.format(args.backbone))
    backbone, tokenizer = load_backbone(args.backbone)

    logger.log('Initializing model and optimizer...')
    # DynaSent variants are 3-class; every other dataset is treated as binary.
    args.n_class = 3 if 'dynasent' in args.dataset else 2

    model_cls = Classifier_pref_ensemble if args.ensemble else Classifier_pref
    model = model_cls(args, args.backbone, backbone, args.n_class, args.train_type).to(device)

    if args.pre_ckpt is not None:
        logger.log('Loading from pre-trained model')
        # map_location lets a checkpoint saved on GPU be restored when `device` is the CPU.
        # NOTE: torch.load unpickles the file; only point --pre_ckpt at trusted checkpoints.
        model.load_state_dict(torch.load(args.pre_ckpt, map_location=device))

    # Set optimizer (1) fixed learning rate and (2) no weight decay
    optimizer = optim.Adam(model.parameters(), lr=args.model_lr, weight_decay=0)

    # Wrapping after the optimizer is built is fine: DataParallel reuses the same parameters.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    logger.log('Initializing dataset...')
    dataset = DynaSent2(args.dataset, tokenizer)

    # Plain loader over the training split; set_loader returns it unchanged when no
    # pre-generated preference pairs are configured.
    orig_loader = DataLoader(dataset.train_dataset, shuffle=True, drop_last=True, batch_size=args.batch_size, num_workers=4)
    val_loader = DataLoader(dataset.val_dataset, shuffle=False, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(dataset.test_dataset, shuffle=False, batch_size=args.batch_size, num_workers=4)

    logger.log('==========> Start training ({})'.format(args.train_type))
    best_acc, final_acc = 0, 0

    # First label column of every training example; used by inconsistency sampling.
    train_labels = dataset.train_dataset[:][1][:, 0]
    # Per-sample preference scores / class probabilities from the previous epoch
    # (None before the first epoch); they drive the sampling strategies.
    pref_train, prob_train = None, None

    for epoch in range(1, args.epochs + 1):
        train_loader, pair_idx = set_loader(args, dataset, orig_loader, epoch, pref_train, prob_train, train_labels)
        pref_train, prob_train = train_preference(args, train_loader, pair_idx, model, optimizer, epoch, logger)
        best_acc, final_acc = eval_func(args, model, val_loader, test_loader, logger, log_dir, epoch,
                                        best_acc, final_acc)

    logger.log('===========>>>>> Final Test Accuracy: {}'.format(final_acc))

def set_loader(args, dataset, orig_loader, epoch, prefs, probs, labels):
    """Return the training loader (and pairing indices) to use for ``epoch``.

    Without ``args.pre_pref`` the original loader is returned unchanged with
    ``None`` as the pairing. Otherwise the table
    ``./pre_gen/{dataset}_idx_pref_{pre_pref}.npy`` is loaded. It is indexed as
    ``[epoch_row, sample, (partner_idx, preference)]``, and ``args.n_samples``
    is set to its second dimension as a side effect.

    - Epoch 1, or no ``args.sampling``: use row ``epoch - 1`` of the table.
    - Later epochs with ``args.sampling``: pairs come from
      ``disagreement_sampling`` or ``inconsistency_sampling``, fed with the
      previous epoch's ``prefs`` / ``probs`` and the training ``labels``.

    Each training example is paired with its partner, and the preference-pair
    loader yields ``(tokens, partner_tokens, labels, preference, indices)``.

    Returns:
        (train_loader, pair_idx)
    """
    if args.pre_pref is None:
        return orig_loader, None

    table_path = './pre_gen/{}_idx_pref_{}.npy'.format(args.dataset, args.pre_pref)
    pair_table = torch.LongTensor(np.load(table_path))
    args.n_samples = pair_table.size()[1]

    if epoch == 1 or args.sampling is None:
        # Epoch 1 always uses row 0 (== epoch - 1); without a sampler every epoch uses its own row.
        epoch_rows = pair_table[epoch - 1]
        pair_idx, preference = epoch_rows[:, 0], epoch_rows[:, 1]
    elif args.sampling == 'disagreement':
        pair_idx, preference = disagreement_sampling(args, prefs)
    else:
        pair_idx, preference = inconsistency_sampling(args, prefs, probs, labels)

    train_tensors = dataset.train_dataset[:]
    tokens = train_tensors[0]
    paired = TensorDataset(tokens, tokens[pair_idx], train_tensors[1], preference, train_tensors[2])
    train_loader = DataLoader(paired, shuffle=True, drop_last=False, batch_size=args.batch_size, num_workers=4)
    return train_loader, pair_idx

def _preference_probs(out_pref1, out_pref2):
    """Convert the preference scores of a pair into probabilities.

    Softmax over ``cat([out_pref2, out_pref1], dim=-1)``. This is algebraically
    identical to ``exp(s) / sum(exp(s))`` but does not overflow to inf/NaN for
    large scores. With (batch, 1) scores (assumed — TODO confirm), column 1 is
    P(x1 preferred over x2), matching ``pref == 1 if x1 > x2``.
    """
    return torch.cat([out_pref2, out_pref1], dim=-1).softmax(dim=-1)


def _preference_targets(pref):
    """Build soft pairwise targets: pref 0 -> [1, 0], 1 -> [0, 1], 2 (tie) -> [0.5, 0.5]."""
    batch_size = pref.size(0)
    rows = torch.arange(batch_size, device=pref.device)
    lower, upper = pref.clone(), pref.clone()
    lower[lower == 2] = 0
    upper[upper == 2] = 1
    target = torch.zeros(batch_size, 2, device=pref.device)
    target[rows, lower] += 0.5
    target[rows, upper] += 0.5
    return target


def _soft_cross_entropy(probs, target):
    """Return the batch mean of -sum(target * log(probs + 1e-8)) over the last dim."""
    return (-target * torch.log(probs + 1e-8)).sum(dim=-1).mean()


def _consistency_loss(args, probs1, probs2, labels, pref, soft_labels, indices, pair_idx):
    """Compute a hinge penalty when classifier probabilities contradict the pair ordering.

    - ``args.cons_margin``: each class-probability gap ``probs1 - probs2``
      must reach the pre-computed soft-label gap of the pair, in the same
      direction.
    - ``args.cons_order``: the gap on the true class must have the sign given
      by ``pref``, and it is pushed to zero for ties.
    - Otherwise the penalty is a zero scalar.
    """
    if args.cons_margin:
        soft_delta = soft_labels[indices] - soft_labels[pair_idx[indices]]
        prob_delta = probs1 - probs2
        pos, neg = (soft_delta >= 0).float(), (soft_delta < 0).float()
        loss = (pos * (soft_delta - prob_delta).clamp(min=0)).sum(dim=-1)
        loss = loss + (neg * (prob_delta - soft_delta).clamp(min=0)).sum(dim=-1)
        return loss.mean()
    if args.cons_order:
        rows = torch.arange(probs1.size(0), device=probs1.device)
        prob_delta = (probs1 - probs2)[rows, labels]
        loss = (pref == 1).float() * (-prob_delta).clamp(min=0)
        loss = loss + (pref == 0).float() * prob_delta.clamp(min=0)
        loss = loss + (pref == 2).float() * prob_delta.abs()
        return loss.mean()
    return torch.zeros((), device=probs1.device)


def train_preference(args, loader, pair_idx, model, optimizer, epoch=0, logger=None):
    """Train ``model`` for one epoch on input pairs with preference labels.

    Each batch from ``loader`` is ``(tokens1, tokens2, labels, pref, indices)``.
    ``pref`` is 1 if x1 > x2, 0 if x2 > x1, and 2 for a tie. ``indices`` are
    the dataset positions of the first inputs. The optimised loss is:

        CE(classifier(x1), label) + preference CE - diversity (ensemble only)
        + args.lambda_cons * consistency penalty

    Args:
        args: reads ensemble, n_samples, n_class, dataset, cons_margin,
            cons_order and lambda_cons. ``args.n_samples`` is assumed to
            have been set by ``set_loader``.
        pair_idx: partner index per training sample. Used only by the
            ``cons_margin`` penalty.
        logger: if None, the epoch summary is printed instead.

    Returns:
        (prefs_all, probs_all): the first-input preference score per training
        sample (a 3-row tensor, one row per head, when ``args.ensemble``), and
        the classifier's class probabilities per sample.
    """
    model.train()

    losses = {name: AverageMeter()
              for name in ('cls', 'cls_acc', 'pref', 'pref_acc', 'div', 'cons', 'consistency')}

    criterion = nn.CrossEntropyLoss(reduction='none')
    # Soft labels are only consumed by the margin-based consistency penalty.
    soft_labels = None
    if args.cons_margin:
        soft_labels = torch.Tensor(np.load('./pre_gen/{}_soft_label.npy'.format(args.dataset))).to(device)

    if args.ensemble:
        prefs_all = torch.zeros(3, args.n_samples)
    else:
        prefs_all = torch.zeros(args.n_samples)
    probs_all = torch.zeros(args.n_samples, args.n_class)

    for tokens1, tokens2, labels, pref, indices in tqdm(loader):
        batch_size = tokens1.size(0)
        tokens1 = cut_input(tokens1.to(device))
        tokens2 = cut_input(tokens2.to(device))
        labels = labels[:, 0].to(device)
        pref = pref.to(device)
        rows = torch.arange(batch_size, device=device)

        pref_label = _preference_targets(pref)

        out_cls, out_pref1 = model(tokens1, y=labels, pref=True)
        out_cls2, out_pref2 = model(tokens2, y=labels, pref=True)

        loss_cls = criterion(out_cls, labels).mean()
        probs1, probs2 = out_cls.softmax(dim=-1), out_cls2.softmax(dim=-1)

        loss_cons = _consistency_loss(args, probs1, probs2, labels, pref, soft_labels, indices, pair_idx)

        if args.ensemble:
            # Multi-head preference learning: average the per-head losses and probabilities.
            head_probs = [_preference_probs(p1, p2) for p1, p2 in zip(out_pref1, out_pref2)]
            loss_pref = sum(_soft_cross_entropy(p, pref_label) for p in head_probs) / len(head_probs)
            pref_probs = sum(head_probs) / len(head_probs)
            loss_div = diversity_loss(out_pref1, out_pref2)
        else:
            pref_probs = _preference_probs(out_pref1, out_pref2)
            loss_pref = _soft_cross_entropy(pref_probs, pref_label)
            loss_div = torch.zeros((), device=device)

        # loss_div is subtracted: heads are rewarded for disagreeing with each other.
        loss = loss_cls + loss_pref - loss_div + args.lambda_cons * loss_cons

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            _, pred_cls = out_cls.max(dim=1)
            acc_cls = (pred_cls == labels).float().sum() / batch_size

            _, pred_pref = pref_probs.max(dim=1)
            decisive = pref != 2  # ties carry no ordering and are excluded from pair metrics
            n_pref = int(decisive.sum().item())

            losses['cls'].update(loss_cls.item(), batch_size)
            losses['cls_acc'].update(acc_cls.item(), batch_size)
            losses['pref'].update(loss_pref.item(), batch_size)
            losses['div'].update(loss_div.item(), batch_size)
            losses['cons'].update(loss_cons.item(), batch_size)

            # An all-tie batch would give the mean of an empty tensor (NaN) and poison the
            # running averages, so these pair metrics are weighted by decisive pairs only.
            if n_pref > 0:
                acc_pref = (pred_pref == pref)[decisive].float().mean()
                p_delta = ((probs1 - probs2)[rows, labels] >= 0).long()
                consistency = (p_delta == pref)[decisive].float().mean()
                losses['pref_acc'].update(acc_pref.item(), n_pref)
                losses['consistency'].update(consistency.item(), n_pref)

            # Save this epoch's first-input preference scores and class probabilities per sample.
            if args.ensemble:
                for head in range(prefs_all.size(0)):
                    prefs_all[head, indices] = out_pref1[head].detach().cpu()[:, 0]
            else:
                prefs_all[indices] = out_pref1.detach().cpu()[:, 0]
            probs_all[indices] = probs1.detach().cpu()

    msg = '[Epoch %2d] [AccC %.3f] [LossC %.3f] [AccP %.3f] [LossP %.3f] [LossCons %.3f] [Consist %.3f]' \
          % (epoch, losses['cls_acc'].average, losses['cls'].average, losses['pref_acc'].average,
             losses['pref'].average, losses['cons'].average, losses['consistency'].average)

    if logger:
        logger.log(msg)
    else:
        print(msg)

    return prefs_all, probs_all

def diversity_loss(out_pref1, out_pref2):
    """Compute the mean pairwise cross-entropy between preference heads of an ensemble.

    For every head i, the pair probabilities are
    ``softmax(cat([out_pref2[i], out_pref1[i]], dim=-1))``. The loss averages
    ``H(p_i, p_j) = -sum(p_i * log(p_j + 1e-8))`` over all ordered pairs
    i != j. ``p_i`` is detached, so gradients flow only through the "predicted"
    side ``p_j``. ``train_preference`` subtracts this term, so heads are pushed
    apart.

    Args:
        out_pref1, out_pref2: per-head preference scores for the first and
            second input of each pair. They are sequences of equal length,
            with one tensor per head.

    Returns:
        A 0-dim tensor. It is zero when fewer than two heads are given, since
        there is no pair to compare and the average would divide by zero.
    """
    n_ensemble = len(out_pref1)
    if n_ensemble < 2:
        ref_device = out_pref1[0].device if n_ensemble else None
        return torch.zeros((), device=ref_device)

    # Softmax equals exp(s) / sum(exp(s)) but does not overflow for large scores.
    head_probs = [torch.cat([p2, p1], dim=-1).softmax(dim=-1)  # pref: 1 if x1 > x2, 0 else
                  for p1, p2 in zip(out_pref1, out_pref2)]

    pref_sim = 0
    for i, p_i in enumerate(head_probs):
        target = p_i.detach()
        for j, p_j in enumerate(head_probs):
            if i != j:
                pref_sim = pref_sim + (-target * torch.log(p_j + 1e-8)).sum(dim=-1).mean()

    return pref_sim / (n_ensemble * (n_ensemble - 1))

def eval_func(args, model, val_loader, test_loader, logger, log_dir, epoch, best_acc, final_acc):
    """Perform validation-based model selection for one epoch.

    Accuracy is measured on the validation split. Only when it beats
    ``best_acc`` is the test split evaluated as well. Both numbers are then
    logged, and the model is saved via ``save_model`` if ``args.save_ckpt`` is
    set. The reported test accuracy is therefore always the one at the
    best-validation epoch.

    Returns:
        (best_acc, final_acc): updated on improvement, otherwise the inputs.
    """
    val_acc = test_acc(args, val_loader, model, logger)
    if not val_acc > best_acc:
        # No improvement: keep the previously selected accuracies.
        return best_acc, final_acc

    # New best on validation: evaluate the test split for this model.
    test_accuracy = test_acc(args, test_loader, model, logger)
    best_acc, final_acc = val_acc, test_accuracy

    logger.log('========== Val Acc ==========')
    logger.log('Val acc: {:.3f}'.format(best_acc))
    logger.log('========== Test Acc ==========')
    logger.log('Test acc: {:.3f}'.format(final_acc))

    if args.save_ckpt:
        logger.log('Save model...')
        save_model(args, model, log_dir, epoch)

    return best_acc, final_acc

def test_acc(args, loader, model, logger=None):
    """Return the percentage accuracy of ``model`` over ``loader``.

    Batches are ``(tokens, labels, indices)``; only the first label column is
    used. For every dataset except ``'stsb'``, the argmax of the model output
    is compared with the label. For ``'stsb'``, raw outputs are compared
    directly. The model is left in eval mode.

    Returns:
        A 0-dim tensor: 100 * correct / number of predictions.
    """
    if logger is not None:
        logger.log('Compute test accuracy...')
    model.eval()

    output_chunks, label_chunks = [], []
    with torch.no_grad():
        for tokens, labels, _ in loader:
            output_chunks.append(model(tokens.long().to(device)))
            label_chunks.append(labels.to(device)[:, 0])

    predictions = torch.cat(output_chunks, dim=0).cpu()
    targets = torch.cat(label_chunks, dim=0).cpu()
    if args.dataset != 'stsb':
        predictions = predictions.max(1)[1]

    n_correct = (predictions == targets).float().sum()
    return 100.0 * n_correct / len(predictions)

if __name__ == "__main__":
    # Start training only when run as a script, not on import.
    main()

