from base.base_trainer import BaseTrainer
from sklearn.metrics.cluster import normalized_mutual_info_score
from utils.evaluation import acc
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from sklearn.cluster import KMeans
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
import logging
import time
import torch
import torch.optim as optim
import numpy as np
import random
import time


def compute_balance(n_clusters, y_pred, yt):
    """Return the worst-case group balance over all non-empty clusters.

    For each cluster ``k`` in ``range(n_clusters)`` the members' attribute
    values are histogrammed over the distinct values found in ``yt``; the
    cluster's balance is ``min(count) / max(count)``. The function returns
    the smallest balance among clusters that have at least one member.

    Args:
        n_clusters: number of cluster ids to inspect (``0 .. n_clusters-1``).
        y_pred: predicted cluster id per sample (array-like).
        yt: group/attribute value per sample (array-like); the values may be
            arbitrary labels, they do not have to be ``0 .. n_groups-1``.

    Returns:
        The minimum per-cluster balance as a NumPy float in ``[0, 1]``.

    Raises:
        ValueError: if none of the inspected clusters has any member.
    """
    y_pred = np.asarray(y_pred)
    yt = np.asarray(yt)
    # Histogram over the values that actually occur instead of assuming
    # that groups are coded as consecutive integers starting at zero.
    groups = np.unique(yt)
    print(np.size(np.unique(y_pred)))

    balance = []
    for k in range(n_clusters):
        # Boolean masking keeps a 1-D result even for a single member, so a
        # singleton cluster never degenerates into a 0-d array.
        members = np.ravel(yt[y_pred == k])
        if members.size == 0:
            continue
        hist = np.array([np.count_nonzero(members == g) for g in groups], dtype=float)
        # hist.max() > 0 here: every member's value is one of `groups`.
        balance.append(hist.min() / hist.max())

    if not balance:
        raise ValueError("compute_balance: no sample was assigned to any of the %d clusters" % n_clusters)
    return np.amin(balance)

class DFDC_Trainer(BaseTrainer):

    def __init__(self, dataset_name, objective, supervised_weight, optimizer_name: str = 'adam', lr: float = 0.001, n_epochs: int = 150, lr_milestones: tuple = (),
                 batch_size: int = 128, weight_decay: float = 1e-6, device: str = 'cuda', n_jobs_dataloader: int = 0):
        """Set up the trainer.

        The optimisation settings (optimizer name, learning rate, epochs,
        milestones, batch size, weight decay, device, loader workers) are
        forwarded to the base trainer together with ``dataset_name``.

        Args:
            dataset_name: dataset identifier; also stored on the instance.
            objective: not referenced in this constructor.
            supervised_weight: factor stored as ``self.supervised_weight``.
        """
        super().__init__(dataset_name, optimizer_name, lr, n_epochs,
                         lr_milestones, batch_size, weight_decay, device,
                         n_jobs_dataloader)
        self.dataset_name = dataset_name
        self.supervised_weight = supervised_weight
        # Result slots; nothing has been measured yet.
        self.train_time = self.test_auc = self.test_time = self.test_scores = None

    def get_ideal_ratios(self, y, attr):
        attr = np.squeeze(np.array(attr))
        uniq_vals = np.unique(attr)
        uniq_vals.sort()
        nbins = np.size(uniq_vals)
        ideal_r = []
        for i in range(nbins):
            val = uniq_vals[i]
            inds = np.squeeze(np.where(attr == val))
            ideal_r.append(np.size(y[inds]) / np.size(y))
        return np.array(ideal_r)

    def get_current_ratios(self, y, attr):
        attr = np.squeeze(np.array(attr))
        uniq_vals = np.unique(attr)
        uniq_vals.sort()
        abins = np.size(uniq_vals)
        y_vals = np.unique(y)
        y_vals.sort()
        ybins = np.size(y_vals)
        fair_count = np.zeros((abins, ybins))
        for i in range(abins):
            val = uniq_vals[i]
            inds = np.squeeze(np.where(attr==val))
            selected_y = y[inds]
            for j in range(ybins):
                v = y_vals[j]
                selected_inds = np.squeeze(np.where(selected_y == v))
                fair_count[i][j] = np.size(selected_inds)
        return fair_count

    def fair_goal(self, q, r, ideal, y, attr, ml_cons):
        """Compute a fairness-constrained soft assignment with Gurobi.

        A continuous variable ``x[i, j]`` in ``[0, 1]`` is created for every
        sample ``i`` and column ``j`` of ``q``; each row of ``x`` must sum to
        one. The objective minimises ``sum(1 - x[i, j] * q[i, j])``. For the
        first ``nclusters`` columns (``nclusters`` = number of distinct values
        in ``y``) and every attribute bin ``j`` the amount of bin-``j`` mass in
        the column is bounded between ``int(ideal[j] * row_sum[i])`` and that
        value plus one, where ``row_sum`` is the column sum of ``r``.

        Args:
            q: array of shape ``(n, k)`` with the current scores.
            r: count table whose column sums are used as cluster sizes.
            ideal: target fraction per attribute bin.
            y: current hard assignment; only its number of distinct values
                is used.
            attr: attribute per sample; used as a column index of a one-hot
                matrix, so values are assumed to be ``0 .. nbin-1``.
            ml_cons: iterable of index pairs forced to share the same row
                of ``x``.

        Returns:
            Array of shape ``(n, k)`` holding the solver's value for ``x``.
        """
        import gurobipy as gp
        from gurobipy import GRB
        from gurobipy import quicksum

        attr = np.squeeze(np.array(attr))
        nbin = np.size(np.unique(attr))
        nclusters = np.size(np.unique(y))
        n, k = q.shape

        # One-hot membership of every sample in its attribute bin.
        u = np.zeros((n, nbin))
        for row in range(n):
            u[row][int(attr[row])] = 1

        m = gp.Model("mip1")
        m.setParam('OutputFlag', False)
        x = m.addVars(n, k, vtype=GRB.CONTINUOUS, name="x")
        modification = quicksum(1 - x[i, j] * q[i, j] for i in range(n) for j in range(k))
        m.setObjective(modification, GRB.MINIMIZE)

        row_sum = np.sum(r, axis=0)
        # The reported time covers constraint construction and the solve.
        start_time = time.time()
        # NOTE(review): column i of x is paired with row_sum[i]; this lines up
        # only if every cluster id 0..k-1 occurs in `y` - confirm with caller.
        m.addConstrs(quicksum(x[t, i] * u[t, j] for t in range(n)) <= int((ideal[j]) * row_sum[i]) + 1
                     for i in range(nclusters) for j in range(nbin))
        m.addConstrs(quicksum(x[t, i] * u[t, j] for t in range(n)) >= int((ideal[j]) * row_sum[i])
                     for i in range(nclusters) for j in range(nbin))
        for pair in ml_cons:
            first, second = pair[0], pair[1]
            m.addConstrs(x[first, j] == x[second, j] for j in range(k))
        m.addConstrs(x[i, j] >= 0 for i in range(n) for j in range(k))
        m.addConstrs(x[i, j] <= 1 for i in range(n) for j in range(k))
        m.addConstrs(quicksum(x[i, j] for j in range(k)) == 1 for i in range(n))
        m.optimize()
        second_loop = time.time()
        print("ILP solver's time:", second_loop - start_time)

        # NOTE(review): the solver status is not checked; reading `.x` fails
        # if no solution is available.
        # Variables were added row by row, so the flat order is row-major.
        assignment = np.zeros((n, k))
        for flat, var in enumerate(m.getVars()):
            row, col = divmod(flat, k)
            assignment[row][col] = float(var.x)
        return assignment

    def upload_nearest_dist(self, dataset):
        # Import the range of local perturbation for VAT
        nearest_dist = np.loadtxt(dataset + '/' + '10th_neighbor.txt').astype(np.float32)
        return nearest_dist

    def calculate_nearest_dist(self, dataset:BaseADDataset, address):
        """Compute and cache per-sample neighbour distances for the train set.

        Does nothing when ``<address>/10th_neighbor.txt`` already exists.
        Otherwise the training data is requested as a single batch of size
        ``dataset.train_size``; for every sample the Euclidean distances to
        all samples are sorted (the sample's own entry is overwritten with
        1000 first) and the value at sorted position 10 is recorded. The
        recorded values are written to ``<address>/10th_neighbor.txt``.

        Args:
            dataset: object providing ``loaders(...)`` and ``train_size``.
            address: directory in which the text file is looked up/written.
        """
        import os.path

        if os.path.isfile(address + '/10th_neighbor.txt'):
            return  # already cached on disk

        neighbor_ranks = [10]
        train_loader, _ = dataset.loaders(batch_size=dataset.train_size, shuffle_train=False,
                                          num_workers=self.n_jobs_dataloader)
        with torch.no_grad():
            for data_t in train_loader:
                # One result list per requested rank, rebuilt for each batch.
                per_rank = [[] for _ in neighbor_ranks]
                inputs_t, labels_t, idx_t, psv_t = data_t
                inputs_t = inputs_t.to(self.device)
                labels_t = labels_t.to(self.device)
                idx_t = idx_t.to(self.device)
                psv_t = psv_t.to(self.device)
                for i in range(len(inputs_t)):
                    if i % 1000 == 0:
                        print(i)
                    diff = inputs_t - inputs_t[i]
                    squared = torch.mul(diff, diff)
                    if squared.dim() > 2:
                        # Inputs with more than two dims: reduce dims 2 and 3,
                        # then keep only the first entry along dim 1.
                        dist_m = torch.sqrt(torch.sum(squared, dim=(2, 3)))[:, 0]
                    else:
                        dist_m = torch.sqrt(torch.sum(squared, dim=1))
                    # Replace the zero self-distance with a large constant so
                    # that it does not occupy the first sorted position.
                    dist_m[i] = 1000
                    ordered = np.sort(dist_m.cpu().numpy())
                    for bucket, rank in zip(per_rank, neighbor_ranks):
                        bucket.append(ordered[rank])
            for bucket, rank in zip(per_rank, neighbor_ranks):
                np.savetxt(address + '/' + str(rank) + 'th_neighbor.txt', np.array(bucket))


    def train(self, dataset: BaseADDataset, ae_net: BaseNet, address):
        """Train ``ae_net`` with the unlabeled, mutual-information and fairness losses.

        Per batch the loss is ``loss_unlabeled + lam * (aver_entropy - mu *
        entropy_aver) + supervised_loss``. The supervised term is a
        cross-entropy against the fair target assignment produced by
        ``fair_goal`` at the end of the previous epoch; it is only active
        once ``epoch > pretrain_epoch``. After every epoch accuracy, NMI and
        balance on the training data are printed/logged.

        Args:
            dataset: object providing ``loaders(...)`` that yield
                ``(inputs, labels, idx, psvs)`` batches.
            ae_net: network providing ``compute_entropy`` and
                ``loss_unlabeled`` in addition to the forward pass.
            address: directory used to cache/load the neighbour distances.

        Returns:
            The trained network (moved to ``self.device``).

        Raises:
            ValueError: if the train loader yields no batches.
        """
        # lam scales the mutual-information term, mu the second entropy term.
        lam, mu = 0.1, 0.4
        if self.dataset_name == 'har':
            lam, mu = 1, 6
        elif self.dataset_name == 'mnist_usps':
            lam, mu = 0.1, 4

        self.calculate_nearest_dist(dataset, address)
        nearest_dist = torch.from_numpy(self.upload_nearest_dist(address)).to(self.device)

        logger = logging.getLogger()
        ae_net = ae_net.to(self.device)
        train_loader, _ = dataset.loaders(batch_size=self.batch_size, shuffle_train=True,
                                          num_workers=self.n_jobs_dataloader)
        optimizer = optim.Adam(ae_net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for element in ae_net.parameters():
            print(element.shape)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)
        # Built once instead of once per batch.
        supervised_criterion = nn.CrossEntropyLoss()

        logger.info('Starting training...')
        start_time = time.time()
        ae_net.train()

        # One pass over the loader to collect the protected attribute of
        # every sample; needed to derive the index offset below.
        psvs_list = []
        for _, _, _, psvs in train_loader:
            psvs_list.extend(psvs.tolist())
        psvs_list = np.array(psvs_list)

        if self.dataset_name in ['adult', 'credit', 'har', 'bank']:
            offset = 0
            pretrain_epoch = 20
        else:
            # Samples with psvs == 1 are shifted by the number of psvs == 0
            # samples so that both groups share one global index range.
            offset = int(np.count_nonzero(psvs_list == 0))
            pretrain_epoch = 30

        p = []  # fair target per global sample index, filled after pre-training
        for epoch in range(self.n_epochs):
            if epoch in self.lr_milestones:
                # Read the rate from the optimizer: always the value in use.
                logger.info('  LR scheduler: new learning rate is %g' % float(optimizer.param_groups[0]['lr']))
            loss_epoch = rim_loss = sat_loss = sup_loss = 0.0
            n_batches = 0
            epoch_start_time = time.time()
            idx_seen, label_seen, score_seen, psv_seen = [], [], [], []

            for inputs, labels, idx, psvs in train_loader:
                if offset != 0:
                    inputs = inputs.view(-1, 28 * 28)
                    idx[psvs == 1] += offset  # in place: idx is now global
                inputs = inputs.to(self.device)

                scores = F.softmax(ae_net(inputs), dim=1)
                aver_entropy, entropy_aver = ae_net.compute_entropy(inputs)
                r_mutual_i = aver_entropy - mu * entropy_aver
                loss_ul = ae_net.loss_unlabeled(inputs, nearest_dist[idx])
                if epoch > pretrain_epoch:
                    p_now = torch.as_tensor(p[idx.cpu().numpy()], dtype=torch.long, device=self.device)
                    # NOTE(review): `scores` are already softmax outputs while
                    # nn.CrossEntropyLoss normalises its input itself; confirm
                    # this is intended before switching to raw logits.
                    supervised_loss = supervised_criterion(scores, p_now) * self.supervised_weight
                else:
                    supervised_loss = torch.zeros((), device=self.device)
                loss = loss_ul + lam * r_mutual_i + supervised_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_epoch += loss.item()
                rim_loss += lam * r_mutual_i.item()
                sat_loss += loss_ul.item()
                sup_loss += supervised_loss.item()
                n_batches += 1

                idx_seen.extend(idx.tolist())
                label_seen.extend(labels.tolist())
                score_seen.extend(scores.detach().cpu().tolist())
                psv_seen.extend(psvs.tolist())

            if n_batches == 0:
                raise ValueError('train loader produced no batches')

            # Step the schedule after the epoch's optimizer updates.
            scheduler.step()

            labels = np.array(label_seen)
            scores = np.array(score_seen)
            # The stored indices already carry the offset (applied in place
            # per batch), so it must not be added a second time here.
            idx_list = np.array(idx_seen)
            psvs_list = np.array(psv_seen)

            epoch_train_time = time.time() - epoch_start_time
            logger.info('  Epoch %d/%d\t Time: %.3f\t Loss: %.8f (sat: %.8f, rim: %.8f, sup: %.8f)',
                        epoch + 1, self.n_epochs, epoch_train_time, loss_epoch / n_batches,
                        sat_loss / n_batches, rim_loss / n_batches, sup_loss / n_batches)

            y_pred = np.argmax(scores, axis=1)
            y = labels
            if epoch >= pretrain_epoch:
                r = self.get_current_ratios(y_pred, psvs_list)
                ideal_r = self.get_ideal_ratios(labels, psvs_list)
                p_target = self.fair_goal(scores, r, ideal_r, y_pred, psvs_list, [])
                p_t = np.argmax(p_target, axis=1)
                # Reorder from visiting order to global index order so that
                # p[idx] can be looked up during the next epoch.
                p = p_t[np.argsort(idx_list)]

            n_clusters = len(np.unique(y))
            balance_result = compute_balance(n_clusters, y_pred, psvs_list)
            final_acc = acc(y, y_pred)
            final_nmi = normalized_mutual_info_score(y, y_pred)
            print("==============================================================================")
            print("Epoch: %d, acc: %.5f, nmi: %.5f" % (epoch, final_acc, final_nmi))
            logger.info('Train set nmi: {:.2f}%'.format(100. * final_nmi))
            logger.info('Train set acc: {:.2f}%'.format(100. * final_acc))
            logger.info('Train set balance: %.8f', balance_result)

        self.train_time = time.time() - start_time
        logger.info('Training time: %.3f', self.train_time)
        return ae_net

    def test(self, dataset: BaseADDataset, ae_net: BaseNet):
        """Evaluate ``ae_net`` on the test loader and log accuracy and NMI.

        The softmax of the network output is collected for every test
        sample; the arg-max is taken as the predicted cluster. The collected
        ``(idx, label, scores, psv)`` tuples are stored in
        ``self.test_scores`` and the elapsed time in ``self.test_time``.

        Args:
            dataset: object providing ``loaders(...)`` that yield
                ``(inputs, labels, idx, psv)`` batches.
            ae_net: network to evaluate.

        Raises:
            ValueError: if the test loader yields no samples.
        """
        logger = logging.getLogger()

        ae_net = ae_net.to(self.device)
        _, test_loader = dataset.loaders(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)

        logger.info('Testing')
        start_time = time.time()
        idx_label_score = []
        ae_net.eval()
        with torch.no_grad():
            for inputs, labels, idx, psv in test_loader:
                # NOTE(review): unlike train(), inputs are not flattened
                # here for image datasets - confirm the network accepts both.
                inputs = inputs.to(self.device)
                scores = F.softmax(ae_net(inputs), dim=1)
                # Save (idx, label, score, psv) for every sample of the batch.
                idx_label_score += list(zip(idx.tolist(),
                                            labels.tolist(),
                                            scores.cpu().tolist(),
                                            psv.tolist()))

        if not idx_label_score:
            raise ValueError('test loader produced no samples')

        _, labels, scores, _ = zip(*idx_label_score)
        y = np.array(labels)
        scores = np.array(scores)
        print("scores:", scores)
        self.test_scores = idx_label_score
        y_pred = np.argmax(scores, axis=1)
        print(y_pred, y_pred.shape)

        # Evaluate each metric once and reuse it for printing and logging.
        final_acc = acc(y, y_pred)
        final_nmi = normalized_mutual_info_score(y, y_pred)
        print("acc: %.5f, nmi: %.5f" % (final_acc, final_nmi))

        logger.info('Test set nmi: {:.2f}%'.format(100. * final_nmi))
        logger.info('Test set acc: {:.2f}%'.format(100. * final_acc))

        self.test_time = time.time() - start_time
        logger.info('Deep_RIM testing time: %.3f' % self.test_time)


