import pdb
import numpy as np 
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.utils.data import RandomSampler

from dataloader import CustomizedDataset
from core import _pgd_whitebox
from eval_metrics import compute_accurate_metrics, compute_fairness_metrics

class GaussianDomain(object):
    """Sampler of synthetic Gaussian inputs in a dataset's feature space.

    ``args`` must provide ``dataset`` (a key of ``_DATA_DIMS``) and ``mu`` /
    ``std``, which are broadcast across the feature dimension (presumably
    scalars or length-``data_dim`` sequences — TODO confirm against config).
    """

    # Number of input features for each supported dataset.
    _DATA_DIMS = {'income': 90, 'compas': 11}

    def __init__(self, args):
        self.args = args

    def inter_domain_noise(self, batch_size):
        """Draw ``batch_size`` rows of N(mu, std^2) noise for the current dataset.

        Args:
            batch_size: number of rows to sample.

        Returns:
            A float32 CUDA tensor of shape ``(batch_size, data_dim)``.

        Raises:
            ValueError: if ``args.dataset`` is not one of the supported datasets
                (previously this surfaced as an opaque UnboundLocalError).
        """
        dataset = self.args.dataset
        if dataset not in self._DATA_DIMS:
            raise ValueError(
                "Unsupported dataset for Gaussian noise: {!r} (expected one of {})".format(
                    dataset, sorted(self._DATA_DIMS)))
        data_dim = self._DATA_DIMS[dataset]

        # Standard normal noise, then scale by std and shift by mu; unsqueeze(0)
        # lets mu/std broadcast over the batch dimension.
        noise = torch.randn(size=(batch_size, data_dim))
        batch_global = noise * torch.tensor(self.args.std).unsqueeze(0) + torch.tensor(self.args.mu).unsqueeze(0)
        return batch_global.float().cuda()

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of ``kernel_num`` Gaussian (RBF) kernels over the pooled samples.

    ``source`` and ``target`` (2-D, rows are samples) are stacked along dim 0.
    With D the matrix of pairwise squared Euclidean distances between all rows,
    the result is ``sum_i exp(-D / (b * kernel_mul**(i - kernel_num // 2)))``
    for ``i in range(kernel_num)``. The base bandwidth ``b`` is ``fix_sigma``
    when truthy; otherwise it is the mean off-diagonal entry of D (treated as a
    constant for autograd).

    Args:
        source: tensor of shape (n_src, features).
        target: tensor of shape (n_tgt, features).
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of kernels summed.
        fix_sigma: optional fixed base bandwidth.

    Returns:
        Tensor of shape (n_src + n_tgt, n_src + n_tgt).
    """
    total = torch.cat([source, target], dim=0)
    n_samples = int(total.size(0))

    # Broadcasting (1, n, d) - (n, 1, d) yields every pairwise difference.
    diff = total.unsqueeze(0) - total.unsqueeze(1)
    l2_distance = (diff ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean over the n*(n-1) off-diagonal pairs; guard n == 1.
        num_pairs = max(n_samples * n_samples - n_samples, 1)
        bandwidth = torch.sum(l2_distance.detach()) / num_pairs
        # If all points coincide the bandwidth would be 0 and exp(-0/0) = NaN;
        # keep it strictly positive so the kernel stays finite.
        bandwidth = torch.clamp(bandwidth, min=1e-12)

    # Centre the geometric bandwidth ladder on the base bandwidth.
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    return sum(torch.exp(-l2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list)

def mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Biased (V-statistic) squared Maximum Mean Discrepancy between two samples.

    The joint kernel matrix from ``guassian_kernel`` is split into its
    source/source, target/target and cross blocks. The result is
    mean(SS) + mean(TT) - mean(ST) - mean(TS); diagonal self-similarities are
    included in the within-sample means.

    Args:
        source: first sample, rows are observations.
        target: second sample, rows are observations.
        kernel_mul, kernel_num, fix_sigma: forwarded to ``guassian_kernel``.

    Returns:
        Scalar tensor.
    """
    n_src = int(source.size()[0])
    joint = guassian_kernel(source, target, kernel_mul=kernel_mul,
                            kernel_num=kernel_num, fix_sigma=fix_sigma)

    within_source = joint[:n_src, :n_src].mean()
    within_target = joint[n_src:, n_src:].mean()
    source_to_target = joint[:n_src, n_src:].mean()
    target_to_source = joint[n_src:, :n_src].mean()

    return within_source + within_target - source_to_target - target_to_source

class LocalUpdate(object):
    """Client-side trainer for one federated-learning participant.

    Holds the client's 80/20 train/validation split and runs local training.
    Each step minimises a group-balanced cross-entropy (the mean loss of each
    of the four (target, sensitive attribute) groups, summed) plus
    ``args.lambda_robust`` times an MMD penalty between the model's outputs on
    the real batch and on Gaussian noise inputs drawn by ``GaussianDomain``.
    """

    # Gaussian-noise modes this trainer accepts.
    _GAUSSIAN_MODES = ('default', 'manual')

    def __init__(self, args, user_data, user_id):
        """
        Args:
            args: experiment configuration; fields read in this class include
                gaussian_mode, batch_size, lr, momentum, weight_decay,
                local_ep and lambda_robust.
            user_data: indexable whose first three items are the inputs,
                sensitive attributes and targets of this client.
            user_id: identifier printed in validation logs.

        Raises:
            ValueError: if ``args.gaussian_mode`` is not supported.
        """
        self.args = args
        self.client_id = user_id
        # Fail fast on a bad configuration before building any datasets.
        self._check_gaussian_mode()
        self.train_loader, self.val_loader = self.train_val_split(user_data)
        self.gaussian_loader = GaussianDomain(args)

    def _check_gaussian_mode(self):
        """Raise ValueError unless ``args.gaussian_mode`` is a supported mode."""
        if self.args.gaussian_mode not in self._GAUSSIAN_MODES:
            raise ValueError("Wrong gaussian mode: {!r}".format(self.args.gaussian_mode))

    def train_val_split(self, user_data):
        """Split this client's data 80/20 into train and validation loaders.

        The split is seeded with 0 and is therefore reproducible. A private
        ``RandomState`` is used so building a client does not reseed NumPy's
        global RNG; the indices equal those of ``np.random.seed(0)`` followed
        by ``np.random.choice``.

        Args:
            user_data: indexable of (inputs, sensitive_attributes, targets),
                each convertible with ``np.array`` (equal lengths assumed).

        Returns:
            (train_loader, val_loader): the train loader samples randomly,
            the validation loader iterates in order.
        """
        inputs, sensitive_attributes, targets = user_data[0], user_data[1], user_data[2]
        num_samples = len(inputs)

        idxs = np.random.RandomState(0).choice(num_samples, num_samples, replace=False)
        num_train = int(0.8 * num_samples)
        self.idxs_train = idxs[:num_train]
        self.idxs_val = idxs[num_train:]

        self.inputs_arr = np.array(inputs)
        self.sensitive_attributes_arr = np.array(sensitive_attributes)
        self.targets_arr = np.array(targets)

        self.train_dataset = CustomizedDataset(inputs_rows=self.inputs_arr[self.idxs_train],
                                               sensitive_rows=self.sensitive_attributes_arr[self.idxs_train],
                                               target_rows=self.targets_arr[self.idxs_train])

        self.val_dataset = CustomizedDataset(inputs_rows=self.inputs_arr[self.idxs_val],
                                             sensitive_rows=self.sensitive_attributes_arr[self.idxs_val],
                                             target_rows=self.targets_arr[self.idxs_val])

        train_loader = DataLoader(self.train_dataset, batch_size=self.args.batch_size,
                                  sampler=RandomSampler(self.train_dataset))
        val_loader = DataLoader(self.val_dataset, batch_size=self.args.batch_size, shuffle=False)

        return train_loader, val_loader

    @staticmethod
    def _group_masks(sensitive_attributes, targets):
        """Masks for (y=1,s=1), (y=1,s=0), (y=0,s=1), (y=0,s=0), in that order."""
        return [
            (targets == 1) * (sensitive_attributes == 1),
            (targets == 1) * (sensitive_attributes == 0),
            (targets == 0) * (sensitive_attributes == 1),
            (targets == 0) * (sensitive_attributes == 0),
        ]

    @staticmethod
    def _group_balanced_ce(loss_ce_vec, groups):
        """Sum over groups of the mean per-sample loss within each group."""
        # The 1e-6 keeps the division defined even for an empty mask.
        return sum(torch.sum(loss_ce_vec * group) / (torch.sum(group) + 1e-6) for group in groups)

    def _local_train(self, model, global_round, adversarial):
        """Shared local-training loop behind both public update methods.

        Args:
            model: network trained in place (assumed to already be on the GPU).
            global_round: current federated round; used only for logging.
            adversarial: when True, each batch is replaced by the output of
                ``_pgd_whitebox`` before the forward pass.

        Returns:
            (model.state_dict(), mean of the per-epoch average losses). Epochs
            in which every batch was skipped are left out; if no epoch trained
            on any batch the loss is 0.0.
        """
        self._check_gaussian_mode()

        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr, momentum=self.args.momentum,
                                    weight_decay=self.args.weight_decay)
        criterion = nn.CrossEntropyLoss(reduction='none')
        epoch_loss = []

        for epoch in range(self.args.local_ep):
            # validate() at the end of the previous epoch put the model in eval
            # mode; switch back so later epochs really train in train mode.
            model.train()
            batch_loss = []
            for batch_idx, batch in enumerate(self.train_loader):
                inputs, sensitive_attributes, targets = batch[0].cuda(), batch[1].cuda(), batch[2].cuda()
                batch_size = inputs.shape[0]

                groups = self._group_masks(sensitive_attributes, targets)
                # The group-balanced loss needs every group present in the batch.
                if any(group.sum().item() == 0 for group in groups):
                    continue

                if adversarial:
                    inputs = _pgd_whitebox(model, inputs, sensitive_attributes, targets, self.args, epoch)
                    # _pgd_whitebox presumably leaves the model in eval mode.
                    model.train()

                inputs_global = self.gaussian_loader.inter_domain_noise(batch_size)

                model.zero_grad()
                outputs = model(inputs)
                loss_ce = self._group_balanced_ce(criterion(outputs, targets), groups)

                outputs_global = model(inputs_global)
                loss_global = mmd(outputs, outputs_global)

                loss = loss_ce + self.args.lambda_robust * loss_global
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

                if batch_idx % 30 == 0:
                    print('Global Round : {}, Local Epoch : {} [{}/{}]\tLoss: {:.4f}\tLoss_reg: {:.4f}'.format(
                        global_round, epoch, batch_idx, len(self.train_loader), loss.item(), loss_global.item()))

            # Every batch may have been skipped; avoid ZeroDivisionError.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

            self.validate(model)

        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else 0.0
        return model.state_dict(), mean_loss

    def update_weights_global_only(self, model, global_round):
        """Train ``model`` locally on the clean inputs.

        Returns:
            (state_dict, average local training loss).
        """
        return self._local_train(model, global_round, adversarial=False)

    def update_weights_ins_adv(self, model, global_round):
        """Train ``model`` locally on inputs perturbed by ``_pgd_whitebox``.

        Returns:
            (state_dict, average local training loss).
        """
        return self._local_train(model, global_round, adversarial=True)

    def validate(self, model):
        """Evaluate ``model`` on the validation split and print its metrics.

        Prints the DP / EO values from ``compute_fairness_metrics`` and the
        accuracy from ``compute_accurate_metrics``. Leaves the model in eval
        mode and returns None.
        """
        model.eval()

        targets_his = []
        predictions_his = []
        sensitive_attributes_his = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch in self.val_loader:
                inputs, sensitive_attributes, targets = batch[0].cuda(), batch[1].cuda(), batch[2].cuda()
                outputs = model(inputs)
                pred_labels = torch.argmax(outputs, dim=1)

                targets_his += targets.tolist()
                predictions_his += pred_labels.tolist()
                sensitive_attributes_his += sensitive_attributes.tolist()

        predictions_arr = np.array(predictions_his)
        sensitive_arr = np.array(sensitive_attributes_his)
        targets_arr = np.array(targets_his)
        acc_mean, acc, acc_std = compute_accurate_metrics(predictions_arr, sensitive_arr, targets_arr)
        demo_parity, eq_odds = compute_fairness_metrics(predictions_arr, sensitive_arr, targets_arr)

        print("Client ID: {}, DP {:.3f}, EO {:.3f}, ACC {:.3f}".format(self.client_id, demo_parity, eq_odds, acc))
        print("=" * 50)