import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import numpy as np
from sklearn.cluster import k_means
from collections import Counter
import training.utils as utils
import torch.nn.functional as F
from data_utils import EnvSampler, is_textdata
from torch.utils.data import DataLoader
from model_utils import get_model
import pickle
import torch.nn as nn
from sklearn import metrics


def compute_l2(XS, XQ):
    '''
        Compute the pairwise squared l2 distance
        @param XS (support x): support_size x ebd_dim
        @param XQ (query x): query_size x ebd_dim

        @return dist: query_size x support_size, where dist[i, j] is the
            squared l2 distance between XQ[i] and XS[j]

    '''
    # broadcast to query_size x support_size x ebd_dim
    diff = XS.unsqueeze(0) - XQ.unsqueeze(1)

    # Sum the squared coordinates directly rather than taking the norm and
    # squaring it again: this skips the sqrt / square round trip (and its
    # rounding error), and the gradient is simply 2 * diff.
    return diff.pow(2).sum(dim=2)


def compute_mmd(x, y):
    '''
        Discrepancy between the first two moments of two sample sets,
        following
        https://github.com/facebookresearch/DomainBed/blob/main/domainbed/algorithms.py

        @param x: n_x x ebd_dim samples (one sample per row)
        @param y: n_y x ebd_dim samples (one sample per row)

        @return scalar tensor: mean squared difference of the feature
            means plus mean squared difference of the covariance matrices

        Note: the covariance is normalised by (number of rows - 1), so each
        input needs at least two rows to avoid a division by zero.
    '''
    def first_two_moments(samples):
        # feature mean (1 x ebd_dim) and unbiased covariance
        # (ebd_dim x ebd_dim) of the rows of `samples`
        mu = samples.mean(0, keepdim=True)
        centered = samples - mu
        sigma = (centered.t() @ centered) / (len(samples) - 1)
        return mu, sigma

    mu_x, sigma_x = first_two_moments(x)
    mu_y, sigma_y = first_two_moments(y)

    mean_gap = (mu_x - mu_y).pow(2).mean()
    cova_gap = (sigma_x - sigma_y).pow(2).mean()

    return mean_gap + cova_gap


def train_loop(train_loaders, model, opt, ep, args):
    '''
        Run one training epoch with the classifier loss plus an MMD
        penalty between the embeddings of two environments

        Batches of the two loaders are consumed in lockstep (zip), so the
        epoch ends as soon as either loader is exhausted.

        @param train_loaders: indexable; entries 0 and 1 are iterated and
            each batch must provide 'X' and 'Y'
        @param model: dict holding the modules 'ebd' and 'clf_all'
        @param opt: optimizer, stepped once per pair of batches
        @param ep: not used in this function
        @param args: not used in this function

        @return stats: dict with the epoch mean (python float) of
            'acc'        accuracy returned by clf_all
            'loss'       loss_ce + loss_mmd
            'regret'     the MMD penalty alone
            'loss_train' the classifier loss alone
    '''
    stats = {k: [] for k in ['acc', 'loss', 'regret', 'loss_train']}

    model['ebd'].train()
    model['clf_all'].train()

    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        # work on each pair of batches
        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))

        # embed both environments in a single forward pass
        x = torch.cat([batch_0['X'], batch_1['X']], dim=0)
        y = torch.cat([batch_0['Y'], batch_1['Y']], dim=0)

        x_ebd = model['ebd'](x)

        acc, loss_ce = model['clf_all'](x_ebd, y, return_pred=False,
                                        grad_penalty=False)

        # the first n_0 rows of x_ebd come from batch_0, the rest from
        # batch_1; penalise the gap between the two groups
        n_0 = len(batch_0['X'])
        loss_mmd = compute_mmd(x_ebd[:n_0], x_ebd[n_0:])

        loss = loss_ce + loss_mmd

        opt.zero_grad()
        loss.backward()
        opt.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss.item())
        stats['regret'].append(loss_mmd.item())
        stats['loss_train'].append(loss_ce.item())

    # average every statistic over the epoch
    return {k: float(np.mean(np.array(v))) for k, v in stats.items()}


def test_loop(test_loader, model, ep, args, att_idx_dict=None):
    '''
        Evaluate the model on every batch of test_loader

        This function does not disable gradient tracking itself; callers
        that only want metrics should wrap the call in torch.no_grad().

        @param test_loader: iterable of batches; each batch must provide
            'X' and 'Y', and also 'idx' when att_idx_dict is given
        @param model: dict holding the modules 'ebd' and 'clf_all'
        @param ep: not used in this function
        @param args: not used in this function
        @param att_idx_dict: optional; when given, the per-batch 'idx'
            entries are collected and the result is delegated to
            utils.get_worst_acc

        @return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)
            when att_idx_dict is not None, otherwise a dict with
            'acc'  fraction of predictions equal to the labels
            'loss' mean of the per-batch losses
    '''
    # switching to evaluation mode does not depend on the batch, so do it
    # once instead of on every iteration
    model['ebd'].eval()
    model['clf_all'].eval()

    loss_list = []
    true, pred = [], []
    # per-batch 'idx' entries are only needed for the worst-group metrics
    idx = [] if att_idx_dict is not None else None

    for batch in test_loader:
        # work on each batch
        batch = utils.to_cuda(utils.squeeze_batch(batch))

        x = model['ebd'](batch['X'])
        y = batch['Y']

        y_hat, loss = model['clf_all'](x, y, return_pred=True)

        true.append(y)
        pred.append(y_hat)

        if idx is not None:
            idx.append(batch['idx'])

        loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)
    loss = np.mean(np.array(loss_list))

    if att_idx_dict is not None:
        # idx is handed over as the list of per-batch entries
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {
        'acc': torch.mean((true == pred).float()).item(),
        'loss': loss,
    }


def print_res(train_res, val_res, ep):
    '''
        Print a one-line summary of an epoch

        @param train_res: dict providing 'avg_acc', 'worst_acc', 'avg_loss'
            and 'worst_loss'
        @param val_res: dict providing 'acc' and 'loss'
        @param ep: epoch number shown at the start of the line
    '''
    # coloured tags shared by the train and the val part of the line
    acc_tag = colored("acc", "blue")
    loss_tag = colored("loss", "yellow")

    train_part = "train {} {:>7.4f} {:>7.4f} {} {:>10.7f} {:>10.7f}".format(
        acc_tag, train_res["avg_acc"], train_res["worst_acc"],
        loss_tag, train_res["avg_loss"], train_res["worst_loss"])

    val_part = "val {} {:>10.7f}, {} {:>10.7f}".format(
        acc_tag, val_res["acc"], loss_tag, val_res["loss"])

    print("epoch {}, {} {}".format(ep, train_part, val_part), flush=True)


def print_pretrain_res(train_res, test_res, ep, i):
    '''
        Print a one-line summary of a pretraining epoch

        @param train_res: dict providing 'acc' and 'loss'
        @param test_res: dict providing 'acc' and 'loss'; shown under the
            "val" label
        @param ep: epoch number
        @param i: index shown right after "pretrain"
    '''
    # the message previously started with the misspelling "petrain"
    print(("pretrain {i}, epoch {epoch}, train {acc} {train_acc:>7.4f} "
           "{loss} {train_loss:>7.4f}, "
           "val {acc} {test_acc:>7.4f}, {loss} {test_loss:>7.4f} ").format(
               epoch=ep,
               i=i,
               acc=colored("acc", "blue"),
               loss=colored("loss", "yellow"),
               train_acc=train_res["acc"],
               train_loss=train_res["loss"],
               test_acc=test_res["acc"],
               test_loss=test_res["loss"]), flush=True)


def _make_env_loader(data, env_id, num_batches, batch_size):
    # build a DataLoader that samples batches from one environment of data
    return DataLoader(
        data,
        sampler=EnvSampler(num_batches, batch_size, env_id,
                           data.envs[env_id]['idx_list']),
        num_workers=10)


def mmd(train_data, test_data, model, opt, args, partition_model=None,
        train_partition_loaders=None, val_partition_loaders=None,
        train_ebd=None):
    '''
        Train the model with the classifier loss plus an MMD penalty
        between environments 0 and 1 of train_data (see train_loop),
        select the best epoch using environment 2 and evaluate that
        checkpoint on environment 3

        An epoch is the new best when min(train acc, val acc) improves.
        Training stops after args.num_epochs epochs, or earlier once
        args.patience consecutive epochs bring no improvement.

        @param train_data: dataset exposing envs[i]['idx_list'] for
            i = 0..3
        @param test_data: only its test_att_idx_dict is read here
        @param model: dict holding the modules 'ebd' and 'clf_all';
            on return they carry the weights of the best epoch
        @param opt: optimizer passed to train_loop
        @param args: needs num_batches, batch_size, num_epochs, patience
        @param partition_model, train_partition_loaders,
            val_partition_loaders, train_ebd: not used in this function

        @return (results, None, None) where results is a dict with
            'train'     training statistics of the best epoch
            'val'       validation statistics of the best epoch
            'test'      result of test_loop on the test loader
            'partition' None
    '''
    # load train envs
    train_loaders = [
        _make_env_loader(train_data, i, args.num_batches, args.batch_size)
        for i in range(2)
    ]

    # load val / test env
    val_loader = _make_env_loader(train_data, 2, -1, args.batch_size)
    # NOTE(review): the test loader reads env 3 of train_data while the
    # attribute index below comes from test_data -- confirm this pairing
    # is intended.
    test_loader = _make_env_loader(train_data, 3, -1, args.batch_size)

    # start training
    best_acc = -1
    best_val_res = None
    best_train_res = None
    best_model = {}
    cycle = 0
    for ep in range(args.num_epochs):
        # train
        train_res = train_loop(train_loaders, model, opt, ep, args)

        with torch.no_grad():
            val_res = test_loop(val_loader, model, ep, args, None)

        utils.print_res(train_res, val_res, ep)

        # model selection on the weaker of the train / val accuracies
        current_acc = min(train_res['acc'], val_res['acc'])
        if current_acc > best_acc:
            best_acc = current_acc
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # save best ebd
            for k in 'ebd', 'clf_all':
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    # load best model
    for k in 'ebd', 'clf_all':
        model[k].load_state_dict(best_model[k])

    # get the results; evaluation needs no gradients, same as for the
    # validation pass above
    with torch.no_grad():
        test_res = test_loop(test_loader, model, ep, args,
                             att_idx_dict=test_data.test_att_idx_dict)

    # report the statistics of the selected epoch, not of the last one
    train_res = best_train_res
    val_res = best_val_res

    print('Best train')
    print(train_res)
    print('Best val')
    print(val_res)
    print('Test')
    print(test_res)

    return {
        'train': train_res,
        'val': val_res,
        'test': test_res,
        'partition': None
    }, None, None
