import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import numpy as np
import torch.nn.functional as F
import training.utils as utils
from torch.utils.data import DataLoader
from data_utils import EnvSampler
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.cluster import k_means
from sklearn.preprocessing import StandardScaler

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; rescales the gradient in the backward pass.

    With ``reverse=True`` the incoming gradient is multiplied by ``-lambd``,
    otherwise by ``lambd``.
    """

    @staticmethod
    def forward(ctx, x, lambd, reverse=True):
        # Stash the scaling settings on the context for use in backward().
        ctx.reverse = reverse
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One return slot per forward() input; lambd and reverse get no
        # gradient, hence the two trailing None values.
        scale = -ctx.lambd if ctx.reverse else ctx.lambd
        return grad_output * scale, None, None

class HLoss(nn.Module):
    """Mean entropy of the softmax distribution taken along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the entropy of ``softmax(x, dim=1)`` averaged over rows.

        ``x`` is treated as logits with the class axis at dim 1.
        """
        probs = F.softmax(x, dim=1)
        log_probs = F.log_softmax(x, dim=1)
        # sum_c p_c * log p_c per row is the negative entropy; flip the sign
        # after averaging over the remaining elements.
        neg_entropy = (probs * log_probs).sum(dim=1).mean()
        return -1.0 * neg_entropy

def grad_reverse(x, lambd=1.0, reverse=True):
    """Pass ``x`` through ``GradReverse`` and return the result.

    The forward value equals ``x``; in the backward pass the gradient is
    scaled by ``-lambd`` when ``reverse`` is true and by ``lambd`` otherwise.
    """
    out = GradReverse.apply(x, lambd, reverse)
    return out


class Discriminator(nn.Module):
    """Three-layer MLP classifier, optionally preceded by gradient scaling.

    Args:
        dims: sequence of exactly four sizes
            ``(input, hidden_1, hidden_2, output)`` for the linear layers.
        grl: when true, ``forward`` routes its input through
            ``grad_reverse`` before the MLP.
        reverse: forwarded to ``grad_reverse``; true negates the gradient,
            false only scales it.

    Raises:
        ValueError: if ``dims`` does not have exactly four entries.
    """

    def __init__(self, dims, grl=True, reverse=True):
        if len(dims) != 4:
            # The message used to say "three dim", contradicting the check.
            raise ValueError(
                "Discriminator dims should have four entries "
                "(input, hidden, hidden, output)!")
        super().__init__()
        self.grl = grl
        self.reverse = reverse
        self.model = nn.Sequential(
            nn.Linear(dims[0], dims[1]),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dims[1], dims[2]),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dims[2], dims[3]),
        )
        # Gradient scaling factor; starts at 0 and is updated via set_lambd().
        self.lambd = 0.0

    def set_lambd(self, lambd):
        """Set the gradient scaling factor used by ``forward`` when ``grl`` is on."""
        self.lambd = lambd

    def forward(self, x):
        """Return the MLP output for ``x`` (after gradient scaling if enabled)."""
        if self.grl:
            x = grad_reverse(x, self.lambd, self.reverse)
        x = self.model(x)
        return x


def train_loop(train_loader, model, discriminator, opt, ep, args):
    """Run one training epoch with classification, domain and entropy losses.

    Args:
        train_loader: iterable of batches. After ``utils.squeeze_batch`` and
            ``utils.to_cuda`` each batch is indexed with ``'X'``, ``'Y'`` and
            ``'D'``.
        model: mapping holding the ``'ebd'`` and ``'clf_all'`` modules.
        discriminator: module with a ``set_lambd`` method; it receives the
            output of ``model['ebd']``.
        opt: optimizer; zeroed and stepped once per batch.
        ep: current epoch index, used for the weight ramp-up.
        args: namespace providing ``num_epochs``.

    Returns:
        dict with float epoch means under ``'acc'``, ``'loss'``, ``'regret'``
        and ``'loss_train'``. Nothing is ever recorded for ``'regret'`` and
        ``'loss_train'`` here, so those two are NaN.
    """
    stats = {k: [] for k in ['acc', 'loss', 'regret', 'loss_train']}

    # Ramp 2 / (1 + exp(-10 p)) - 1: equals 0 at p = 0 and approaches 1 as
    # p = ep / num_epochs grows.
    p = ep / args.num_epochs
    ramp = 2. / (1. + np.exp(-10 * p)) - 1
    grl_weight = 1.0
    entropy_weight = 1.0
    alpha = ramp * grl_weight
    beta = ramp * entropy_weight
    discriminator.set_lambd(alpha)
    entropy_criterion = HLoss()

    # Training mode does not change between batches, so set it once.
    model['ebd'].train()
    model['clf_all'].train()
    discriminator.train()

    for batch in train_loader:
        batch = utils.to_cuda(utils.squeeze_batch(batch))
        x = model['ebd'](batch['X'])
        y = batch['Y']
        d = batch['D']  # latent domain generated by clustering

        logit = model['clf_all'](x, y, return_logit=True)
        loss_class = F.cross_entropy(logit, y)
        acc = torch.mean((torch.argmax(logit, dim=1) == y).float()).item()

        logit_domain = discriminator(x)
        loss_domain = F.cross_entropy(logit_domain, d)

        loss_entropy = entropy_criterion(logit)

        loss = loss_class + loss_domain + loss_entropy * beta

        opt.zero_grad()
        loss.backward()
        opt.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss.item())

    # Keys with no recorded values stay NaN (as before), but without the
    # "mean of empty slice" RuntimeWarning from numpy.
    return {
        k: float(np.mean(np.array(v))) if v else float('nan')
        for k, v in stats.items()
    }


def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    Args:
        feat: 4-D tensor. The first two dims (N, C) are kept; all remaining
            elements are flattened and reduced over.
        eps: small value added to the variance before the square root, so the
            returned std is never exactly zero (avoids divide-by-zero in code
            that divides by it).

    Returns:
        Tuple ``(feat_mean, feat_std)``, each of shape ``(N, C)``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape instead of view: view raises on non-contiguous inputs (for
    # example permuted or channels_last tensors), reshape copies when needed.
    flat = feat.reshape(N, C, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C)
    feat_mean = flat.mean(dim=2).view(N, C)
    return feat_mean, feat_std


def cluster_loop(test_loader, model, args):
    """Cluster samples by the channel statistics of their conv features.

    For every batch, ``model['ebd'].conv_features`` is called and the
    per-channel mean and std of each returned feature map are concatenated
    into one vector per sample. The vectors are standardized, reduced with
    PCA and grouped with k-means.

    Args:
        test_loader: iterable of batches. After ``utils.squeeze_batch`` and
            ``utils.to_cuda`` each batch is indexed with ``'X'`` and ``'idx'``.
        model: mapping holding the ``'ebd'`` module; it is put in eval mode.
        args: namespace providing ``num_clusters``.

    Returns:
        Tuple ``(idx_list, domain_list)``: the stacked ``'idx'`` values and
        the k-means label for each of them, in the same order.
    """
    model['ebd'].eval()
    groups = {'x': [], 'idx': []}

    for batch in test_loader:
        batch = utils.to_cuda(utils.squeeze_batch(batch))

        conv_feats = model['ebd'].conv_features(batch['X'])
        # One (mean, std) block per feature map, joined along dim 1 in layer
        # order. Building it in one call also means an empty conv_feats raises
        # instead of silently reusing the previous batch's features.
        aux = torch.cat(
            [torch.cat(calc_mean_std(feats), 1) for feats in conv_feats],
            dim=1)

        for i, x in zip(batch['idx'].cpu().numpy(), aux.cpu().numpy()):
            groups['idx'].append(i)
            groups['x'].append(x)

    # start pca
    x = np.stack(groups['x'], axis=0)
    # NOTE: the reference implementation
    # https://github.com/mil-tokyo/dg_mmld/blob/aef26b2745beabc6356accd183ff3e17f71657ce/clustering/clustering.py
    # does not normalize the features before PCA; here they are standardized
    # first.
    x = StandardScaler().fit_transform(x)
    # PCA rejects n_components larger than min(n_samples, n_features), so
    # clamp the target of 256 for small datasets or narrow feature vectors.
    n_components = min(256, *x.shape)
    x = PCA(n_components=n_components).fit_transform(x)

    # k_means returns (centroids, labels, inertia); only the labels are used.
    _, domain_list, _ = k_means(x, args.num_clusters)
    print(np.mean(domain_list))

    idx_list = np.stack(groups['idx'], axis=0)

    return idx_list, domain_list


def test_loop(test_loader, model, ep, args, att_idx_dict=None):
    """Evaluate ``model`` on ``test_loader`` and return accuracy and loss.

    Args:
        test_loader: iterable of batches. After ``utils.squeeze_batch`` and
            ``utils.to_cuda`` each batch is indexed with ``'X'`` and ``'Y'``
            (and ``'idx'`` when ``att_idx_dict`` is given).
        model: mapping holding the ``'ebd'`` and ``'clf_all'`` modules; both
            are put in eval mode.
        ep: unused here; kept for a uniform loop signature.
        args: unused here; kept for a uniform loop signature.
        att_idx_dict: when not None, the result is delegated to
            ``utils.get_worst_acc`` together with the collected ``'idx'``
            values.

    Returns:
        ``{'acc': ..., 'loss': ...}`` with the overall accuracy and the mean
        of the per-batch losses, or whatever ``utils.get_worst_acc`` returns
        when ``att_idx_dict`` is given.
    """
    loss_list = []
    true, pred = [], []

    if att_idx_dict is not None:
        idx = []

    # Eval mode does not change between batches, so set it once.
    model['ebd'].eval()
    model['clf_all'].eval()

    # No gradients are needed for evaluation; disabling them here keeps
    # callers that forget the no_grad context from building autograd graphs.
    with torch.no_grad():
        for batch in test_loader:
            batch = utils.to_cuda(utils.squeeze_batch(batch))

            x = model['ebd'](batch['X'])
            y = batch['Y']

            y_hat, loss = model['clf_all'](x, y, return_pred=True)

            true.append(y)
            pred.append(y_hat)

            if att_idx_dict is not None:
                idx.append(batch['idx'])

            loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)

    acc = torch.mean((true == pred).float()).item()
    loss = np.mean(np.array(loss_list))

    if att_idx_dict is not None:
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {
        'acc': acc,
        'loss': loss,
    }


def dg_mmld(train_data, test_data, model, opt, args, train_ebd=True):
    """Train with per-epoch clustering-based domain labels and evaluate.

    Each epoch: cluster the samples of environment 0 to obtain domain labels,
    store them with ``train_data.set_domain_label``, train for one epoch with
    the discriminator, then validate on environment 2. The model state with
    the highest ``min(train acc, val acc)`` is kept and training stops early
    after ``args.patience`` epochs without improvement. Finally the best
    state is restored and evaluated on environment 3.

    Args:
        train_data: dataset exposing ``envs`` (entries 0, 2 and 3 are used,
            each with an ``'idx_list'``) and ``set_domain_label``.
        test_data: provides ``test_att_idx_dict`` for the final evaluation.
        model: mapping holding the ``'ebd'`` and ``'clf_all'`` modules.
        opt: ignored; replaced by a new SGD optimizer that also covers the
            discriminator parameters.
        args: namespace providing ``batch_size``, ``hidden_dim``,
            ``num_clusters``, ``lr``, ``weight_decay``, ``num_epochs``,
            ``num_batches`` and ``patience``.
        train_ebd: unused.

    Returns:
        dict with the best epoch's training stats under ``'train'``, the
        matching validation stats under ``'val'`` and the final evaluation
        under ``'test'``.
    """
    val_loader = DataLoader(
        train_data,
        sampler=EnvSampler(-1, args.batch_size, 2,
                           train_data.envs[2]['idx_list']),
        num_workers=10)

    test_loader = DataLoader(
        train_data,
        sampler=EnvSampler(-1, args.batch_size, 3,
                           train_data.envs[3]['idx_list']),
        num_workers=10)

    discriminator = Discriminator([model['ebd'].out_dim, args.hidden_dim,
                                   args.hidden_dim, args.num_clusters]).cuda()

    # override opt to include parameters for discriminator
    opt = torch.optim.SGD(list(model['ebd'].parameters()) +
                          list(model['clf_all'].parameters()) +
                          list(discriminator.parameters()),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.weight_decay,
                          nesterov=True)
    scheduler = StepLR(optimizer=opt, step_size=24, gamma=0.1)

    # start training
    best_acc = -1
    best_val_res = None
    best_train_res = None
    best_model = {}
    cycle = 0
    for ep in range(args.num_epochs):
        # run train data in test mode for clustering
        test_train_loader = DataLoader(
            train_data,
            sampler=EnvSampler(-1, args.batch_size, 0,
                               train_data.envs[0]['idx_list']),
            num_workers=10)

        # assign domain label by clustering
        with torch.no_grad():
            idx_list, domain_list = cluster_loop(test_train_loader, model, args)
        train_data.set_domain_label(0, idx_list, domain_list)

        # load new train data
        train_loader = DataLoader(
            train_data,
            sampler=EnvSampler(args.num_batches, args.batch_size, 0,
                               train_data.envs[0]['idx_list']),
            num_workers=10)

        # train
        train_res = train_loop(train_loader, model, discriminator, opt, ep, args)

        with torch.no_grad():
            val_res = test_loop(val_loader, model, ep, args, None)

        utils.print_res(train_res, val_res, ep)

        # Model selection uses the weaker of train and val accuracy.
        cur_acc = min(train_res['acc'], val_res['acc'])
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # save best ebd
            for k in 'ebd', 'clf_all':
                best_model[k] = copy.deepcopy(model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

        scheduler.step()

    # load best model
    for k in 'ebd', 'clf_all':
        model[k].load_state_dict(best_model[k])

    # get the results (no gradients needed, same as for validation)
    with torch.no_grad():
        test_res = test_loop(test_loader, model, ep, args,
                             test_data.test_att_idx_dict)

    # Report the stats of the selected (best) epoch; previously the last
    # epoch's train_res was printed and returned under the "best" label.
    train_res = best_train_res
    val_res = best_val_res
    print('Best train')
    print(train_res)
    print('Best val')
    print(val_res)
    print('Test')
    print(test_res)

    return {
        'train': train_res,
        'val': val_res,
        'test': test_res,
    }
