import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import training.utils as utils
from torch.utils.data import DataLoader
from model_utils import get_model
from data_utils import EnvSampler


class GeneralizedCELoss(nn.Module):
    """Generalized cross-entropy (GCE) loss, returned per sample.

    The per-sample cross-entropy is multiplied by ``q * p_y ** q``, where
    ``p_y`` is the softmax probability assigned to the target class. That
    factor is detached, so it rescales the cross-entropy gradient without
    being differentiated itself.

    Args:
        q: Exponent and scale of the weighting factor (default 0.7).
    """

    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        """Return the unreduced GCE loss for each row of ``logits``.

        Args:
            logits: Unnormalised class scores; softmax is taken over dim 1.
            targets: Integer class indices, one per row of ``logits``.

        Returns:
            Tensor of per-sample losses (cross-entropy with
            ``reduction='none'`` times the detached weight).

        Raises:
            NameError: ``'GCE_p'`` if the softmax output contains NaN, or
                ``'GCE_Yg'`` if the gathered target probabilities do.
        """
        probs = F.softmax(logits, dim=1)
        if np.isnan(probs.mean().item()):
            raise NameError('GCE_p')

        target_probs = probs.gather(1, targets.unsqueeze(1))
        if np.isnan(target_probs.mean().item()):
            raise NameError('GCE_Yg')

        # Detached weight q * p_y**q: it scales the CE gradient but carries none.
        weight = self.q * target_probs.squeeze().detach().pow(self.q)
        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
        return per_sample_ce * weight


def _lff_debias_loss(bias_logits, logits, targets):
    """Return the relative-difficulty-weighted CE of the main model for one batch.

    Each sample is weighted by ``ce_bias / (ce_bias + ce_main)`` (detached).
    The weights are normalised to sum to one over the batch, and the weighted
    per-sample main-model cross-entropy is summed.
    """
    bias_ce = F.cross_entropy(bias_logits, targets, reduction='none')
    ce = F.cross_entropy(logits, targets, reduction='none')
    # Detached: the weights steer the main model's loss but receive no gradient.
    weight = (bias_ce / (bias_ce + ce + 1e-8)).detach()
    # Clamp guards against 0/0 -> NaN if every weight underflows to zero.
    weight = weight / torch.clamp(torch.sum(weight), min=1e-8)
    return torch.sum(weight * ce)


def train_loop(train_loaders, model, bias_model, opt, bias_opt, ep, args,
               train_ebd):
    """Run one epoch of joint bias/debiased training and return mean stats.

    Batches are drawn in lockstep from ``train_loaders[0]`` and
    ``train_loaders[1]``. In ``lff`` these hold the label >= 0.5 and
    label < 0.5 samples respectively, so the per-batch weight normalisation
    in ``_lff_debias_loss`` acts per label. Each step:

    * the bias model's loss is the ``GeneralizedCELoss`` averaged per batch
      and then over the two batches;
    * the main model's loss is its cross-entropy re-weighted by relative
      difficulty (``_lff_debias_loss``), averaged over the two batches;
    * ``backward()`` runs on the sum of both losses, then ``opt`` and
      ``bias_opt`` each take a step.

    Args:
        train_loaders: Sequence of two iterables of batches (dicts with
            'X' and 'Y' after ``utils.squeeze_batch``).
        model: Dict with 'ebd' and 'clf_all' modules (the debiased model).
        bias_model: Dict with 'ebd' and 'clf_all' modules (the bias model).
        opt: Optimizer stepped for ``model``.
        bias_opt: Optimizer stepped for ``bias_model``.
        ep, args: Unused here; kept for interface compatibility.
        train_ebd: Unused. NOTE(review): the embedding is updated whenever
            ``opt`` holds its parameters, regardless of this flag — confirm
            this is intended.

    Returns:
        Dict mapping 'acc', 'loss', 'regret' and 'loss_train' to floats
        averaged over steps. 'loss' is the main model's debiased loss only.
        'regret' and 'loss_train' are never recorded here and come back as
        NaN, as does any stat when the loaders yield no batches.
    """
    stats = {k: [] for k in ('acc', 'loss', 'regret', 'loss_train')}
    bias_criteria = GeneralizedCELoss()

    # Nothing inside the loop switches modes, so enter train mode once.
    for net in (model, bias_model):
        net['ebd'].train()
        net['clf_all'].train()

    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))
        y_0, y_1 = batch_0['Y'], batch_1['Y']

        # Forward passes keep a fixed order (bias 0, bias 1, main 0, main 1)
        # so stochastic layers consume RNG reproducibly.
        # return_logit=True is assumed to yield unnormalised logits.
        bias_pred_0 = bias_model['clf_all'](bias_model['ebd'](batch_0['X']),
                                            return_logit=True)
        bias_pred_1 = bias_model['clf_all'](bias_model['ebd'](batch_1['X']),
                                            return_logit=True)
        pred_0 = model['clf_all'](model['ebd'](batch_0['X']), return_logit=True)
        pred_1 = model['clf_all'](model['ebd'](batch_1['X']), return_logit=True)

        # Average within each batch first so the two batches may differ in size.
        loss_gce = (torch.mean(bias_criteria(bias_pred_0, y_0)) +
                    torch.mean(bias_criteria(bias_pred_1, y_1))) / 2.0
        loss_debias = (_lff_debias_loss(bias_pred_0, pred_0, y_0) +
                       _lff_debias_loss(bias_pred_1, pred_1, y_1)) / 2.0

        acc_0 = torch.mean((torch.argmax(pred_0, dim=1) == y_0).float()).item()
        acc_1 = torch.mean((torch.argmax(pred_1, dim=1) == y_1).float()).item()

        loss = loss_debias + loss_gce

        opt.zero_grad()
        bias_opt.zero_grad()
        loss.backward()
        opt.step()
        bias_opt.step()

        stats['acc'].append((acc_0 + acc_1) / 2.0)
        stats['loss'].append(loss_debias.item())

    # Empty lists become NaN without numpy's "Mean of empty slice" warning.
    return {k: float(np.mean(v)) if v else float('nan')
            for k, v in stats.items()}


def test_loop(test_loader, model, ep, args, att_idx_dict=None):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Both sub-modules are put in eval mode and autograd is disabled for the
    whole pass, so callers need not wrap this in ``torch.no_grad()``.

    Args:
        test_loader: Iterable of batches (dicts with 'X', 'Y' and, when
            ``att_idx_dict`` is given, 'idx').
        model: Dict with 'ebd' and 'clf_all' modules. ``clf_all`` is called as
            ``clf_all(x, y, return_pred=True)`` and must return
            ``(predictions, loss)``.
        ep, args: Unused here; kept for interface compatibility.
        att_idx_dict: Optional attribute/index mapping forwarded to
            ``utils.get_worst_acc``.

    Returns:
        ``{'acc': float, 'loss': float}`` when ``att_idx_dict`` is None.
        Otherwise, whatever ``utils.get_worst_acc`` returns. 'loss' is the
        unweighted mean of the per-batch loss values.
    """
    model['ebd'].eval()
    model['clf_all'].eval()

    loss_list, true, pred = [], [], []
    idx = [] if att_idx_dict is not None else None

    with torch.no_grad():
        for batch in test_loader:
            batch = utils.to_cuda(utils.squeeze_batch(batch))
            y = batch['Y']

            y_hat, loss = model['clf_all'](model['ebd'](batch['X']), y,
                                           return_pred=True)

            true.append(y)
            pred.append(y_hat)
            if idx is not None:
                idx.append(batch['idx'])
            loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)
    loss = np.mean(np.array(loss_list))

    if att_idx_dict is not None:
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {
        'acc': torch.mean((true == pred).float()).item(),
        'loss': loss,
    }


def _env_loader(data, num_batches, batch_size, env_id, idx_list,
                num_workers=10):
    """Return a DataLoader that draws ``idx_list`` through an ``EnvSampler``.

    ``env_id`` is passed as EnvSampler's third argument. At the call sites it
    matches the ``envs`` entry that ``idx_list`` came from.
    """
    return DataLoader(
        data,
        sampler=EnvSampler(num_batches, batch_size, env_id, idx_list),
        num_workers=num_workers)


def lff(train_data, test_data, model, opt, args, train_ebd=True):
    """Train ``model`` with a jointly trained bias model, then evaluate it.

    Env 0 of ``train_data`` is split by label (``y < 0.5`` vs. the rest) into
    two loaders, so every step sees one batch of each label. A bias model is
    built with ``get_model`` and trained alongside ``model`` by
    ``train_loop``. After each epoch ``model`` is scored on env 2 of
    ``test_data``. The weights with the highest ``min(train acc, val acc)``
    are kept, and training stops after ``args.patience`` consecutive epochs
    without improvement. The kept weights are then evaluated on env 3.

    Args:
        train_data: Dataset exposing ``envs[0]['idx_list']``,
            ``get_all_y(0)`` and optionally ``vocab``.
        test_data: Dataset exposing ``envs[2]``/``envs[3]`` index lists and
            ``test_att_idx_dict``.
        model: Dict with 'ebd' and 'clf_all' modules.
        opt: Optimizer for ``model``.
        args: Namespace providing ``num_batches``, ``batch_size``,
            ``num_epochs`` and ``patience`` (and whatever ``get_model`` reads).
        train_ebd: Forwarded to ``train_loop``.

    Returns:
        Dict with 'train' and 'val' (the selected epoch's results) and 'test'.
        If no epoch ever improved on the initial score (e.g. NaN accuracies),
        the last epoch's weights and results are used instead.
    """
    # Balanced sampling: one loader per label.
    pos_idx, neg_idx = [], []
    for i, y in zip(train_data.envs[0]['idx_list'], train_data.get_all_y(0)):
        if y < 0.5:
            neg_idx.append(i)
        else:
            pos_idx.append(i)

    train_loaders = [
        _env_loader(train_data, args.num_batches, args.batch_size, 0, pos_idx),
        _env_loader(train_data, args.num_batches, args.batch_size, 0, neg_idx),
    ]
    # num_batches=-1 presumably means "one full pass" — TODO confirm in EnvSampler.
    val_loader = _env_loader(test_data, -1, args.batch_size, 2,
                             test_data.envs[2]['idx_list'])
    test_loader = _env_loader(test_data, -1, args.batch_size, 3,
                              test_data.envs[3]['idx_list'])

    # Define a bias model.
    if hasattr(train_data, 'vocab'):
        bias_model, bias_opt = get_model(args, train_data.vocab)
    else:
        bias_model, bias_opt = get_model(args)

    best_acc = -1
    best_train_res = None
    best_val_res = None
    best_model = {}
    cycle = 0
    train_res = val_res = None
    ep = -1  # keeps `ep` defined even if num_epochs == 0
    for ep in range(args.num_epochs):
        train_res = train_loop(train_loaders, model, bias_model, opt,
                               bias_opt, ep, args, train_ebd)

        with torch.no_grad():
            val_res = test_loop(val_loader, model, ep, args, None)

        utils.print_res(train_res, val_res, ep)

        # Model selection on the worse of train and val accuracy.
        score = min(train_res['acc'], val_res['acc'])
        if score > best_acc:
            best_acc = score
            best_train_res = train_res
            best_val_res = val_res
            cycle = 0
            for k in ('ebd', 'clf_all'):
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    # Restore the selected weights and report that epoch's results.
    if best_model:
        for k in ('ebd', 'clf_all'):
            model[k].load_state_dict(best_model[k])
        train_res, val_res = best_train_res, best_val_res

    with torch.no_grad():
        test_res = test_loop(test_loader, model, ep, args,
                             test_data.test_att_idx_dict)

    print('Best train')
    print(train_res)
    print('Best val')
    print(val_res)
    print('Test')
    print(test_res)

    return {
        'train': train_res,
        'val': val_res,
        'test': test_res,
    }
