import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import numpy as np
from sklearn.cluster import k_means
import umap
from sklearn.mixture import GaussianMixture
from collections import Counter
import training.utils as utils
import torch.nn.functional as F
from data_utils import EnvSampler, is_textdata
from torch.utils.data import DataLoader
from torch import autograd
from model_utils import get_model
import pickle
from sklearn import metrics


def compute_l2(XS, XQ):
    '''
        Compute the pairwise squared l2 distance
        @param XS (support x): support_size x ebd_dim
        @param XQ (query x): query_size x ebd_dim

        @return dist: query_size x support_size, where dist[q, s] is the
            squared Euclidean distance between XQ[q] and XS[s]

    '''
    # broadcast to query_size x support_size x ebd_dim
    diff = XS.unsqueeze(0) - XQ.unsqueeze(1)

    # Sum the squared differences directly rather than squaring torch.norm:
    # the sqrt would be undone immediately, costing precision
    # (sqrt(2) ** 2 != 2) and sending the backward pass through sqrt,
    # which is not differentiable at zero distance.
    return (diff ** 2).sum(dim=2)


def train_dro_loop(train_loaders, model, opt, ep, args):
    '''
        Train for one pass over the zipped group loaders, back-propagating
        only the loss of the worst (highest-loss) group at every step.

        @param train_loaders: list of loaders, one per group. They are zipped,
            so each step consumes one batch from every loader and the pass
            ends when the shortest loader is exhausted.
        @param model: dict holding the 'ebd' (encoder) and 'clf_all'
            (classifier) modules
        @param opt: optimizer stepped once per zipped batch
        @param ep: epoch number (not used here)
        @param args: only args.dataset is read, to pick the text/non-text path

        @return stats: dict with the per-step means of 'worst_loss',
            'avg_loss', 'worst_acc' and 'avg_acc'
    '''
    stats = {k: [] for k in ['worst_loss', 'avg_loss', 'worst_acc', 'avg_acc']}

    for batches in zip(*train_loaders):
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()

        x, y = [], []

        for batch in batches:
            batch = utils.to_cuda(utils.squeeze_batch(batch))
            x.append(batch['X'])
            y.append(batch['Y'])

        if is_textdata(args.dataset):
            # text models have varying length between batches
            pred = torch.cat(
                [model['clf_all'](model['ebd'](cur_x)) for cur_x in x], dim=0)
        else:
            pred = model['clf_all'](model['ebd'](torch.cat(x, dim=0)))

        cur_idx = 0

        avg_loss = 0
        avg_acc = 0

        # Track the worst group as (tensor, python float) so the comparison is
        # float-vs-float. Starting from None (instead of the int 0) guarantees
        # worst_loss is a tensor even when every group loss is exactly 0.
        worst_loss = None
        worst_loss_val = None
        worst_acc = 0

        for cur_true in y:
            # predictions are concatenated in loader order; slice this group's
            cur_pred = pred[cur_idx:cur_idx+len(cur_true)]
            cur_idx += len(cur_true)

            loss = F.cross_entropy(cur_pred, cur_true)
            acc = torch.mean((torch.argmax(cur_pred, dim=1) == cur_true).float()).item()
            loss_val = loss.item()

            avg_loss += loss_val
            avg_acc += acc

            if worst_loss is None or loss_val > worst_loss_val:
                worst_loss = loss
                worst_loss_val = loss_val
                worst_acc = acc

        # only the worst group's loss drives the update
        opt.zero_grad()
        worst_loss.backward()
        opt.step()

        avg_loss /= len(y)
        avg_acc /= len(y)

        stats['avg_acc'].append(avg_acc)
        stats['avg_loss'].append(avg_loss)
        stats['worst_acc'].append(worst_acc)
        stats['worst_loss'].append(worst_loss_val)

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats


def train_loop(train_loader, model, opt, ep, args):
    '''
        Run one pass of standard training over train_loader.

        @param train_loader: loader yielding batches with 'X' and 'Y' entries
        @param model: dict holding the 'ebd' (encoder) and 'clf_all'
            (classifier) modules
        @param opt: optimizer stepped once per batch
        @param ep: epoch number (not used here)
        @param args: run configuration (not used here)

        @return dict with the mean 'acc' and mean 'loss' over all batches
    '''
    acc_hist, loss_hist = [], []

    for raw_batch in train_loader:
        # make sure both modules are in training mode for this step
        model['ebd'].train()
        model['clf_all'].train()

        data = utils.to_cuda(utils.squeeze_batch(raw_batch))

        features = model['ebd'](data['X'])
        acc, loss = model['clf_all'](features, data['Y'], return_pred=False,
                                     grad_penalty=False)

        opt.zero_grad()
        loss.backward()
        opt.step()

        acc_hist.append(acc)
        loss_hist.append(loss.item())

    return {
        'acc': float(np.mean(np.array(acc_hist))),
        'loss': float(np.mean(np.array(loss_hist))),
    }


def cluster_umap_loop(test_loader, model, args):
    '''
        use umap to reduce data dim
        then apply gmm

        The examples are first grouped by label ('Y'). Within each label the
        embeddings from model['ebd'] are reduced with UMAP and split into two
        clusters with a Gaussian mixture. The 'C' entry of each batch is only
        used to print how well the clusters agree with it.

        @param test_loader: loader yielding batches with 'X', 'Y', 'C', 'idx'
        @param model: dict holding the 'ebd' (encoder) module
        @param args: run configuration (not used here)

        @return clusters: list of lists of example idx, one list per
            (label, cluster) pair that received at least one example
    '''
    model['ebd'].eval()
    groups = {}

    # The embeddings are only exported to numpy, so never track gradients:
    # this keeps .numpy() valid even when the caller has not disabled
    # autograd, and avoids building a graph that is never used.
    with torch.no_grad():
        for batch in test_loader:
            batch = utils.to_cuda(utils.squeeze_batch(batch))

            x_s = model['ebd'](batch['X']).cpu().numpy()
            y_s = batch['Y'].cpu().numpy()
            c_s = batch['C'].cpu().numpy()
            idx_s = batch['idx'].cpu().numpy()

            for x, y, c, idx in zip(x_s, y_s, c_s, idx_s):
                group = groups.setdefault(int(y), {'x': [], 'c': [], 'idx': []})
                group['x'].append(x)
                group['c'].append(c)
                group['idx'].append(idx)

    clusters = []
    clustering_metrics = [metrics.homogeneity_score,
                          metrics.completeness_score,
                          metrics.v_measure_score,
                          metrics.adjusted_rand_score,
                          metrics.adjusted_mutual_info_score]
    clustering_results = []

    # cluster every label separately
    for v in groups.values():
        x = np.stack(v['x'], axis=0)
        print(x.shape)
        reduced_x = umap.UMAP().fit_transform(x)

        cur_clusters = {}
        cur_cs = {}

        label = GaussianMixture(n_components=2, random_state=0).fit_predict(reduced_x)

        # score the cluster assignment against the 'C' attribute
        metric_c = np.stack(v['c'], axis=0)
        clustering_results.append([m(metric_c, label) for m in
                                   clustering_metrics])

        for cluster_id, idx, c in zip(label, v['idx'], v['c']):
            cur_clusters.setdefault(cluster_id, []).append(idx)
            cur_cs.setdefault(cluster_id, []).append(c)

        for cluster_id, cluster in cur_clusters.items():
            clusters.append(cluster)

            # report the size of the cluster and its 'C' value distribution
            members = cur_cs[cluster_id]
            cnt = Counter(members)
            print('size: {}, color '.format(len(members)), end='')
            for c, cur_cnt in sorted(cnt.items()):
                print('{}={:.2f}, '.format(c, cur_cnt / len(members)),
                      end='')
            print()

    clustering_results = np.array(clustering_results)
    print('clustering metrics')
    print(clustering_results.shape)
    print(np.mean(clustering_results, axis=0))

    return clusters


def test_loop(test_loader, model, ep, args, return_idx=False, att_idx_dict=None):
    '''
        Evaluate the model on every batch of test_loader.

        @param test_loader: loader yielding batches with 'X', 'Y', 'C' (and
            'idx' when return_idx is set or att_idx_dict is given)
        @param model: dict holding the 'ebd' (encoder) and 'clf_all'
            (classifier) modules
        @param ep: epoch number (not used here)
        @param args: run configuration (not used here)
        @param return_idx: if True, also return the example idx, 'C' values
            and labels split into correctly / wrongly predicted examples.
            Takes precedence over att_idx_dict.
        @param att_idx_dict: if given (and return_idx is False), the result
            of utils.get_worst_acc(true, pred, idx, loss, att_idx_dict) is
            returned instead

        @return dict with 'acc' and 'loss' (mean of the per-batch losses),
            extended with 'correct_*' / 'wrong_*' lists when return_idx is
            True
    '''
    collect_idx = (att_idx_dict is not None) or return_idx

    loss_list = []
    true, pred, cor = [], [], []
    idx = []

    model['ebd'].eval()
    model['clf_all'].eval()

    # Evaluation never needs gradients: disable autograd here so no graph is
    # built even when the caller does not wrap the call in torch.no_grad().
    with torch.no_grad():
        for batch in test_loader:
            # work on each batch
            batch = utils.to_cuda(utils.squeeze_batch(batch))

            x = model['ebd'](batch['X'])
            y = batch['Y']
            c = batch['C']

            y_hat, loss = model['clf_all'](x, y, return_pred=True)

            true.append(y)
            pred.append(y_hat)
            cor.append(c)

            if collect_idx:
                idx.append(batch['idx'])

            loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)

    acc = torch.mean((true == pred).float()).item()
    loss = np.mean(np.array(loss_list))

    if return_idx:
        cor = torch.cat(cor).tolist()
        true = true.tolist()
        pred = pred.tolist()
        idx = torch.cat(idx).tolist()

        # split correct and wrong idx
        correct_idx, wrong_idx = [], []

        # compute correlation between cor and y for analysis
        correct_cor, wrong_cor = [], []
        correct_y, wrong_y = [], []

        for i, y, y_hat, c in zip(idx, true, pred, cor):
            if y == y_hat:
                correct_idx.append(i)
                correct_cor.append(c)
                correct_y.append(y)
            else:
                wrong_idx.append(i)
                wrong_cor.append(c)
                wrong_y.append(y)

        return {
            'acc': acc,
            'loss': loss,
            'correct_idx': correct_idx,
            'correct_cor': correct_cor,
            'correct_y': correct_y,
            'wrong_idx': wrong_idx,
            'wrong_cor': wrong_cor,
            'wrong_y': wrong_y,
        }

    if att_idx_dict is not None:
        # idx is passed as the list of per-batch idx tensors
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {
        'acc': acc,
        'loss': loss,
    }


def print_partition_res(train_res, val_res, ep):
    '''
        Print a one-line summary of a partition-training epoch.

        @param train_res: dict with 'loss', 'dis_pos', 'dis_neg', 'dis_cross'
        @param val_res: dict with 'loss'
        @param ep: epoch number shown at the start of the line
    '''
    template = ("epoch {epoch}, train {loss} {train_loss:>10.7f} pos {dis_pos:>10.5f} "
                "neg {dis_neg:>10.5f} cross {dis_cross:>10.5f} "
                "val {loss} {val_loss:>10.7f}")

    fields = {
        'epoch': ep,
        'loss': colored("loss", "yellow"),
        'train_loss': train_res["loss"],
        'val_loss': val_res["loss"],
        'dis_pos': train_res["dis_pos"],
        'dis_neg': train_res["dis_neg"],
        'dis_cross': train_res["dis_cross"],
    }

    print(template.format(**fields), flush=True)


def print_res(train_res, val_res, ep):
    '''
        Print a one-line summary of a DRO training epoch.

        @param train_res: dict with 'avg_acc', 'worst_acc', 'avg_loss',
            'worst_loss' (as returned by train_dro_loop)
        @param val_res: dict with 'acc' and 'loss'
        @param ep: epoch number shown at the start of the line
    '''
    # every keyword below is referenced by the template; the former
    # `regret` argument was never used and has been removed
    print(("epoch {epoch}, train {acc} {train_acc:>7.4f} {train_worst_acc:>7.4f} "
           "{loss} {train_loss:>10.7f} {train_worst_loss:>10.7f} "
           "val {acc} {val_acc:>10.7f}, {loss} {val_loss:>10.7f}").format(
               epoch=ep,
               acc=colored("acc", "blue"),
               loss=colored("loss", "yellow"),
               train_acc=train_res["avg_acc"],
               train_worst_acc=train_res["worst_acc"],
               train_loss=train_res["avg_loss"],
               train_worst_loss=train_res["worst_loss"],
               val_acc=val_res["acc"],
               val_loss=val_res["loss"]), flush=True)


def print_pretrain_res(train_res, test_res, ep, i):
    '''
        Print a one-line summary of a pre-training epoch.

        @param train_res: dict with 'acc' and 'loss'
        @param test_res: dict with 'acc' and 'loss' (printed under "val")
        @param ep: epoch number
        @param i: identifier of the pre-training run shown at the start of
            the line
    '''
    # "petrain" in the banner was a typo for "pretrain"; the unused `ebd`
    # format argument has been removed
    print(("pretrain {i}, epoch {epoch}, train {acc} {train_acc:>7.4f} "
           "{loss} {train_loss:>7.4f}, "
           "val {acc} {test_acc:>7.4f}, {loss} {test_loss:>7.4f} ").format(
               epoch=ep,
               i=i,
               acc=colored("acc", "blue"),
               loss=colored("loss", "yellow"),
               train_acc=train_res["acc"],
               train_loss=train_res["loss"],
               test_acc=test_res["acc"],
               test_loss=test_res["loss"]), flush=True)


def eiil_partition(test_loader, model, args):
    '''
    partition the data based on the pre-trained erm classifier

    A real-valued weight is learned for every example; its sigmoid is the
    soft assignment to environment "a" (and one minus it to environment "b").
    The weights are trained to maximise the mean, over the two environments,
    of the squared gradient of the weighted loss w.r.t. a dummy scale on the
    logits. Examples are finally split by thresholding the sigmoid at 0.5
    and grouped by (environment, label).

    @param test_loader: loader yielding batches with 'X', 'Y' and 'idx'
    @param model: dict holding the 'ebd' (encoder) and 'clf_all'
        (classifier) modules
    @param args: run configuration (not used here)

    @return the values view of a dict mapping '<env>_<label>' to the list
        of example idx (0-d tensors) that fall in that group
    '''
    model['ebd'].eval()
    model['clf_all'].eval()

    # collect logits, idx and labels of the whole loader on the cpu
    all_logits, all_idx, all_y = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            batch = utils.to_cuda(utils.squeeze_batch(batch))

            feat = model['ebd'](batch['X'])
            y = batch['Y']
            sample_idx = batch['idx']

            batch_logit = model['clf_all'](feat, y, return_logit=True)
            all_logits.append(batch_logit.cpu())
            all_idx.append(sample_idx.cpu())
            all_y.append(y.cpu())

    logits = torch.cat(all_logits, dim=0).cuda()
    idx_list = torch.cat(all_idx)
    labels = torch.cat(all_y).cuda()

    print(logits.size())
    print(idx_list.size())

    # per-example loss as a function of the dummy scale; it is built once
    # and its graph is reused by every optimisation step below
    scale = torch.tensor(1.).cuda().requires_grad_()
    loss = F.cross_entropy(logits * scale, labels, reduction='none')

    env_w = torch.randn(len(logits)).cuda().requires_grad_()
    optimizer = torch.optim.Adam([env_w], lr=0.001)

    def env_penalty(weights):
        # squared gradient of the weighted mean loss w.r.t. the dummy scale
        env_loss = (loss.squeeze() * weights).mean()
        grad = autograd.grad(env_loss, [scale], create_graph=True)[0]
        return torch.sum(grad**2)

    print('learning soft environment assignments')
    n_steps = 5000
    for _ in tqdm(range(n_steps)):
        penalty_a = env_penalty(env_w.sigmoid())
        penalty_b = env_penalty(1-env_w.sigmoid())

        # negate: the optimizer minimises, but the penalty is to be maximised
        npenalty = - torch.stack([penalty_a, penalty_b]).mean()

        optimizer.zero_grad()
        # the graph of `loss` is shared across steps, so it must be kept
        npenalty.backward(retain_graph=True)
        optimizer.step()

    # split envs based on env_w threshold
    partition_dict = {}
    for sample_idx, prob, y in zip(idx_list, env_w.sigmoid().cpu(), labels.cpu()):
        key = '{}_{}'.format(int(prob > 0.5), y)
        partition_dict.setdefault(key, []).append(sample_idx)

    print(partition_dict.keys())
    for key, members in partition_dict.items():
        print(key, ' ', len(members))

    return partition_dict.values()


def _fit_dro_with_early_stopping(train_loaders, val_loader, model, opt, args):
    '''
        Train `model` with train_dro_loop for up to args.num_epochs epochs
        and restore the weights of the best epoch.

        An epoch is "best" when min(train worst_acc, val acc) improves.
        Training stops early once args.patience consecutive epochs bring no
        improvement.

        @return (best_train_res, best_val_res, ep): the train / val results
            of the best epoch and the index of the last epoch that was run
    '''
    best_acc = -1
    best_train_res = None
    best_val_res = None
    best_model = {}
    cycle = 0
    for ep in range(args.num_epochs):
        train_res = train_dro_loop(train_loaders, model, opt, ep, args)

        with torch.no_grad():
            # validation
            val_res = test_loop(val_loader, model, ep, args)

        print_res(train_res, val_res, ep)

        cur_acc = min(train_res['worst_acc'], val_res['acc'])
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # save best ebd
            for k in 'ebd', 'clf_all':
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    # load best model
    for k in 'ebd', 'clf_all':
        model[k].load_state_dict(best_model[k])

    return best_train_res, best_val_res, ep


def george(train_data, test_data, model, opt, args, partition_model=None,
         train_partition_loaders=None, val_partition_loaders=None,
         train_ebd=None):
    '''
        Three-step training:
          1. train a reference ERM classifier (a fresh model from get_model)
             on environment 0 of train_data, using one loader per label side
             (y < 0.5 vs. the rest) and worst-group updates;
          2. cluster its embeddings with cluster_umap_loop to obtain groups;
          3. train `model` with train_dro_loop over those groups.
        Validation uses environment 2 and testing environment 3 of test_data.

        @param train_data: dataset providing envs[0]['idx_list'] and
            get_all_y(0) (and optionally `vocab`, forwarded to get_model)
        @param test_data: dataset providing envs[2], envs[3] and
            test_att_idx_dict
        @param model: dict holding the 'ebd' and 'clf_all' modules trained in
            step 3; left loaded with the best weights found
        @param opt: optimizer for `model`
        @param args: reads num_batches, batch_size, num_epochs, patience
            (and whatever get_model / train_dro_loop read)
        @param partition_model, train_partition_loaders,
            val_partition_loaders, train_ebd: not used here

        @return dict with the 'train' and 'val' results of the best epoch of
            step 3, the 'test' result, and 'partition' set to None
    '''
    ########################
    # Step 1 train a ERM classifier on train data
    ########################
    print('George')
    print('Step 1: train an reference ERM classifier on the train data')
    # preparing the train loader
    pos_idx = []
    neg_idx = []
    for i, y in zip(train_data.envs[0]['idx_list'],
                    train_data.get_all_y(0)):
        if y < 0.5:
            neg_idx.append(i)
        else:
            pos_idx.append(i)

    train_loaders = []
    for label_idx in (pos_idx, neg_idx):
        train_loaders.append(DataLoader(
            train_data,
            sampler=EnvSampler(args.num_batches, args.batch_size, 0,
                               label_idx),
            num_workers=10))

    val_loader = DataLoader(
        test_data, sampler=EnvSampler(-1, args.batch_size, 2,
                                      test_data.envs[2]['idx_list']),
        num_workers=10)

    test_loader = DataLoader(
        test_data, sampler=EnvSampler(-1, args.batch_size, 3,
                                      test_data.envs[3]['idx_list']),
        num_workers=10)

    # train the erm model
    if hasattr(train_data, 'vocab'):
        erm_model, erm_opt = get_model(args, train_data.vocab)
    else:
        erm_model, erm_opt = get_model(args)

    train_res, val_res, ep = _fit_dro_with_early_stopping(
        train_loaders, val_loader, erm_model, erm_opt, args)

    # get the results (evaluation only, so no autograd graph is needed)
    with torch.no_grad():
        test_res = test_loop(test_loader, erm_model, ep, args,
                             att_idx_dict=test_data.test_att_idx_dict)
    print('Best erm train')
    print(train_res)
    print('Best erm val')
    print(val_res)
    print('ERM Test')
    print(test_res)

    ########################
    # Step 2 dimensionality reduction on the embedding + clustering
    ########################
    print('Step 2: learn splits from the ERM classifier')
    # loading data in testing mode; the idx list must come from the same
    # dataset the loader reads from (train_data), as in step 1
    test_train_loader = DataLoader(
        train_data, sampler=EnvSampler(-1, args.batch_size, 0,
                                       train_data.envs[0]['idx_list']),
        num_workers=10)

    with torch.no_grad():
        partition_res = cluster_umap_loop(test_train_loader, erm_model, args)

    ########################
    # Step 3 learn robust model using dro
    ########################
    train_loaders = []
    for group in partition_res:
        train_loaders.append(DataLoader(
            train_data, sampler=EnvSampler(args.num_batches,
                                           args.batch_size, 0, group),
            num_workers=int(10 / 4)))

    train_res, val_res, ep = _fit_dro_with_early_stopping(
        train_loaders, val_loader, model, opt, args)

    # get the results (evaluation only, so no autograd graph is needed)
    with torch.no_grad():
        test_res = test_loop(test_loader, model, ep, args,
                             att_idx_dict=test_data.test_att_idx_dict)
    print('Best train')
    print(train_res)
    print('Best val')
    print(val_res)
    print('Test')
    print(test_res)

    # this is inference on target task, do not need to retrain the partition
    # model
    res = {
        'train': train_res,
        'val': val_res,
        'test': test_res,
        'partition': None,
    }

    return res
