
from base.base_trainer import BaseTrainer
from base.base_dataset import BaseADDataset
from base.base_net import BaseNet
from sklearn.metrics import roc_auc_score, f1_score
from optim.sinkhorn import SinkhornDistance
import time
import torch
import torch.optim as optim
import numpy as np
from tqdm import trange
from tensorboardX import SummaryWriter as sw
from scipy.stats.distributions import chi2
from math import sqrt

# self-defined library


class Trainer(BaseTrainer):
    """Train and evaluate an autoencoder with a latent-distribution loss and a group-fairness loss.

    Two training modes are selected by ``Fairway``:

    * ``'explicit'``: reconstruction loss + Sinkhorn distance between the latent-norm
      scores of the two protected groups (scaled by ``_lambda``) + Sinkhorn distance
      between the latent codes and samples from ``target_distribution_sampling``.
    * ``'implicit'``: reconstruction loss + one Sinkhorn distance per protected group
      between that group's latent codes and freshly sampled targets.

    The anomaly score of a sample is the Euclidean norm of its latent code.
    """

    def __init__(self, optimizer_name: str = 'adam', lr: float = 0.001, begin_epoch: int = 1, end_epoch: int = 100,
                 batch_size: int = 128, device: str = 'cuda', print=None, dataset_name=None, results_dir=None,
                 _lambda=None, latent_dimension=None, Fairway=None, stop_threshold=None, entropy_reg_coe=None,
                 beta=None, milestones=None, threshold_ratio=0.9):
        """Store the training configuration and build the Sinkhorn distance module.

        Args:
            optimizer_name: ``'adam'`` or ``'sgd'`` (checked in :meth:`train`).
            lr: learning rate handed to the optimizer.
            begin_epoch, end_epoch: bounds of the epoch range iterated by :meth:`train`
                (``end_epoch`` itself is excluded).
            batch_size: batch size requested from ``dataset.loaders``.
            device: device the network and tensors are moved to.
            print: callable used for some log lines; falls back to the built-in
                ``print`` when omitted.
            dataset_name: used in TensorBoard tags; ``'celeba'`` switches the explicit
                mode to an absolute-error reconstruction loss.
            results_dir: stored as-is; not used inside this class.
            _lambda: weight of the explicit fairness loss.
            latent_dimension: dimension of the sampled target vectors.
            Fairway: ``'explicit'`` or ``'implicit'``.
            stop_threshold, entropy_reg_coe: forwarded to ``SinkhornDistance`` as
                ``thresh`` and ``eps``.
            beta: weight of the reconstruction loss.
            milestones: stored as ``lr_milestones``; not consulted by :meth:`train`.
            threshold_ratio: fraction of the sorted training scores at which the
                decision threshold is read off.
        """
        # optimizer_name, lr and device are read back later as attributes;
        # presumably BaseTrainer.__init__ stores them -- TODO confirm.
        super().__init__(optimizer_name, lr, end_epoch - begin_epoch, None, batch_size, device)

        self.begin_epoch = begin_epoch
        self.end_epoch = end_epoch
        self._lambda = _lambda
        self.beta = beta

        # The parameter shadows the built-in, so the fallback has to come from
        # ``builtins``; without it ``self.print(...)`` in test() fails on None.
        if print is None:
            import builtins
            print = builtins.print
        self.print = print
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.results_dir = results_dir
        self.lr_milestones = milestones
        self.latent_dimension = latent_dimension
        self.target_distribution_sampling_epoch = 1
        self.fairway = Fairway
        self.test_epoch = 1  # evaluate every ``test_epoch`` epochs
        self.results = {
            'auc': 0.0,
            "normal_adpd": 0.0,
            "all_adpd": 0.0,
            "normal_fr": 0.0,
            "all_fr": 0.0,
            "Time": None
        }

        self.emd_optimizer = SinkhornDistance(eps=entropy_reg_coe, max_iter=int(5 * 1e3), thresh=stop_threshold, device=self.device)
        self.threshold = None  # set at the end of each evaluated training epoch
        self.threshold_ratio = threshold_ratio

    def train(self, dataset: BaseADDataset, ae_net: BaseNet):
        """Train ``ae_net`` on the training loader and evaluate it after each test epoch.

        After every epoch selected by ``test_epoch`` the decision threshold is set to
        the ``threshold_ratio`` quantile of the latent norms seen during that epoch,
        then :meth:`test` is run. The total wall-clock time is stored in
        ``self.results['Time']``.

        Raises:
            ValueError: if ``Fairway`` is neither ``'explicit'`` nor ``'implicit'``.
            Exception: if the optimizer name is unknown.
        """
        # Fail early: an unsupported mode used to surface as an unbound ``loss``.
        if self.fairway == 'explicit':
            batch_loss = self._explicit_loss
        elif self.fairway == 'implicit':
            batch_loss = self._implicit_loss
        else:
            raise ValueError(f'Unknown fairway [{self.fairway}].')

        # Set device for network
        ae_net = ae_net.to(self.device)
        # Set optimizer
        if self.optimizer_name == 'adam':
            optimizer = optim.Adam(ae_net.parameters(), lr=self.lr)
        elif self.optimizer_name == 'sgd':
            optimizer = optim.SGD(ae_net.parameters(), lr=self.lr)
        else:
            raise Exception(f'Unknown optimizer name [{self.optimizer_name}].')
        # NOTE: the scheduler is constructed but never stepped, so the learning
        # rate stays constant; ``self.lr_milestones`` is not consulted either.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1)
        # Get train data loader
        train_loader, _ = dataset.loaders(batch_size=self.batch_size)

        # tensorboardX saving loss
        loss_writer = sw()
        try:
            print('Starting training...')
            ae_net.train()
            step = 1

            start = time.time()
            with trange(self.begin_epoch, self.end_epoch) as pbar:
                for epoch in pbar:
                    loss_epoch = 0.0
                    n_batches = 0
                    recon_loss_batch = 0.0
                    dist_loss_batch = 0.0
                    fair_loss_batch = 0.0

                    # Latent codes of the epoch, detached so that the autograd
                    # graph of each batch can be released after backward().
                    epoch_mid_repre = []

                    for data in train_loader:
                        inputs, _, _, pvs = data
                        inputs = inputs.to(self.device)

                        optimizer.zero_grad()
                        loss, mid_repre, recon_value, fair_value, dist_value = batch_loss(
                            ae_net, inputs, pvs, loss_writer, step)
                        epoch_mid_repre.append(mid_repre.detach())

                        loss.backward()
                        optimizer.step()

                        loss_value = loss.item()
                        loss_writer.add_scalar('%s_%s_loss' % (self.dataset_name, str(self.latent_dimension)), loss_value, step)
                        step += 1
                        loss_epoch += loss_value
                        recon_loss_batch += recon_value
                        fair_loss_batch += fair_value
                        dist_loss_batch += dist_value
                        n_batches += 1

                    pbar.set_description(
                        'Loss: {:.4f}  distri_loss: {:.4f}  recon_loss: {:.4f}  fair_loss: {:.4f}'.format(loss_epoch / n_batches, dist_loss_batch / n_batches, recon_loss_batch / n_batches, fair_loss_batch / n_batches)
                    )

                    if epoch % self.test_epoch == 0:
                        # Threshold = latent norm at the requested rank of the
                        # sorted training scores of this epoch.
                        all_mid_repre = torch.cat(epoch_mid_repre)
                        dist = torch.sqrt(torch.sum(all_mid_repre ** 2, dim=tuple(range(1, all_mid_repre.dim()))))
                        dist_list = sorted(dist.cpu().numpy().tolist())
                        # Clamp so that threshold_ratio == 1.0 picks the largest
                        # score instead of indexing past the end.
                        rank = min(int(self.threshold_ratio * len(dist_list)), len(dist_list) - 1)
                        self.threshold = dist_list[rank]
                        self.test(dataset, ae_net)
                        ae_net.train()  # test() switches the network to eval mode

            self.results['Time'] = time.time() - start
            print(f'using time: {self.results["Time"]}')
            print('Finished training.')
        finally:
            # Flush and release the event file even if training aborts.
            loss_writer.close()

    def _explicit_loss(self, ae_net, inputs, pvs, loss_writer, step):
        """Build the explicit-mode loss for one batch and log its parts.

        Returns:
            ``(loss, mid_repre, recon_value, fair_value, dist_value)`` where ``loss``
            is the tensor to back-propagate, ``mid_repre`` the latent codes of the
            batch and the remaining entries are Python floats (``fair_value`` is 0.0
            when the fairness term was skipped).
        """
        outputs, mid_repre = ae_net(inputs)

        # reconstruction loss: absolute error for celeba, squared error otherwise
        feature_dims = tuple(range(1, outputs.dim()))
        if self.dataset_name in ['celeba']:
            recon_loss = torch.sum(torch.abs((outputs - inputs)), dim=feature_dims) * self.beta
        else:
            recon_loss = torch.sum((outputs - inputs) ** 2, dim=feature_dims) * self.beta

        loss = torch.mean(recon_loss)
        recon_value = loss.item()
        loss_writer.add_scalar('%s_%s_recon_loss' % (self.dataset_name, str(self.latent_dimension)), recon_value, step)

        # fairness loss: distance between the two groups' latent-norm scores,
        # only when both groups contribute more than one sample
        fair_value = 0.0
        mid_repre_pv_a = mid_repre[pvs == 1]
        mid_repre_pv_b = mid_repre[pvs != 1]
        if len(mid_repre_pv_a) > 1 and len(mid_repre_pv_b) > 1:
            score_pv_a = torch.sqrt(torch.sum(mid_repre_pv_a ** 2, dim=tuple(range(1, mid_repre_pv_a.dim())))).unsqueeze(1)
            score_pv_b = torch.sqrt(torch.sum(mid_repre_pv_b ** 2, dim=tuple(range(1, mid_repre_pv_b.dim())))).unsqueeze(1)

            fair_loss = self.emd_optimizer(score_pv_a, score_pv_b) * self._lambda
            fair_value = fair_loss.item()
            loss_writer.add_scalar('%s_fair_loss' % self.dataset_name, fair_value, step)
            loss = loss + fair_loss

        # distribution distance loss term; targets are always drawn at the
        # configured batch size, independent of the actual batch length
        targets = target_distribution_sampling(self.batch_size, self.latent_dimension)
        targets = targets.to(self.device)
        dist_loss = self.emd_optimizer(mid_repre, targets)
        dist_value = dist_loss.item()
        loss_writer.add_scalar('%s_%s_dist_loss' % (self.dataset_name, str(self.latent_dimension)), dist_value, step)
        loss = loss + dist_loss

        return loss, mid_repre, recon_value, fair_value, dist_value

    def _implicit_loss(self, ae_net, inputs, pvs, loss_writer, step):
        """Build the implicit-mode loss for one batch and log its fairness part.

        Each protected group is encoded separately and, when it holds more than one
        sample, matched against its own freshly sampled targets.

        Returns:
            Same tuple layout as :meth:`_explicit_loss`; ``dist_value`` is always 0.0
            and ``fair_value`` is 0.0 when no group was large enough.
        """
        inputs_pv_a = inputs[pvs == 1]
        inputs_pv_b = inputs[pvs != 1]

        # reconstruction loss term
        outputs_pv_a, mid_repre_pv_a = ae_net(inputs_pv_a)
        outputs_pv_b, mid_repre_pv_b = ae_net(inputs_pv_b)
        mid_repre = torch.cat((mid_repre_pv_a, mid_repre_pv_b))

        recon_loss = torch.sum((torch.cat((outputs_pv_a, outputs_pv_b)) - torch.cat((inputs_pv_a, inputs_pv_b))) ** 2, dim=tuple(range(1, outputs_pv_a.dim()))) * self.beta
        loss = torch.mean(recon_loss)
        recon_value = loss.item()

        # Kept as None until a group contributes, so that the "no group large
        # enough" case never calls .item() on a plain float.
        fair_loss = None
        for group_inputs, group_repre in ((inputs_pv_a, mid_repre_pv_a), (inputs_pv_b, mid_repre_pv_b)):
            if len(group_inputs) > 1:
                targets = target_distribution_sampling(len(group_inputs), self.latent_dimension)
                targets = targets.to(self.device)
                term = self.emd_optimizer(group_repre, targets)
                fair_loss = term if fair_loss is None else fair_loss + term

        fair_value = 0.0
        if fair_loss is not None:
            fair_value = fair_loss.item()
            loss_writer.add_scalar('%s_fair_loss' % self.dataset_name, fair_value, step)
            loss = loss + fair_loss

        return loss, mid_repre, recon_value, fair_value, 0.0

    def test(self, dataset: BaseADDataset, ae_net: BaseNet):
        """Score the test loader and record detection and fairness metrics.

        Stores the latest ``auc``, ``all_adpd`` and ``normal_adpd`` in
        ``self.results``. The fairness ratios need ``self.threshold`` (set by
        :meth:`train`); when it is still ``None`` they are skipped and the
        corresponding entries are left untouched.
        """
        # Set device for network
        ae_net = ae_net.to(self.device)

        # Get test data loader
        _, test_loader = dataset.loaders(batch_size=self.batch_size)

        labels, pvs, scores = [], [], []
        ae_net.eval()

        with torch.no_grad():
            for data in test_loader:
                inputs, batch_labels, _, batch_pvs = data
                inputs = inputs.to(self.device)

                _, mid_repre = ae_net(inputs)

                # anomaly score = Euclidean norm of the latent code
                batch_scores = torch.sqrt(torch.sum(mid_repre ** 2, dim=tuple(range(1, mid_repre.dim()))))

                labels.extend(batch_labels.tolist())
                pvs.extend(batch_pvs.tolist())
                scores.extend(batch_scores.tolist())

        labels_arr = np.array(labels)
        pvs_arr = np.array(pvs)
        scores_arr = np.array(scores)

        # detection performance =========================================
        # AUC
        auc = roc_auc_score(labels, scores)
        self.print(f'Test set AUC: [{auc * 100:.2f}%]')
        self.results['auc'] = auc

        # Fairness ===================================================
        # Average Demographic Parity Difference
        all_adpd, normal_adpd, _ = average_demographic_parity_difference(scores=scores_arr, labels=labels_arr, pvs=pvs_arr)
        print(f'Test set ADPD (normal): [{normal_adpd * 100:.2f}%]')
        print(f'Test set ADPD (all): [{all_adpd * 100:.2f}%]')
        self.results['all_adpd'] = all_adpd
        self.results['normal_adpd'] = normal_adpd

        # Fairness Ratio (threshold dependent)
        if self.threshold is not None:
            all_fr, normal_fr = fairness_ratio(scores=scores_arr, labels=labels_arr, pvs=pvs_arr, threshold=self.threshold)
            self.results['all_fr'] = all_fr
            self.results['normal_fr'] = normal_fr



def target_distribution_sampling(size, sample_dim):
    """Draw ``size`` standard-normal vectors restricted to a ball around the origin.

    Vectors are drawn one at a time with ``torch.randn`` and kept only when their
    Euclidean norm is strictly between 0 and ``get_radius(d=sample_dim, p=0.9, sigma=1)``
    (rejection sampling).

    :param size: number of samples to return
    :param sample_dim: the dimension of the sample from the restricted distribution
    :return: tensor of shape ``(size, sample_dim)``; an empty ``(0, sample_dim)``
        tensor when ``size`` is not positive, so callers can always call ``.to(...)``
    """
    r_min = 0.0
    r = get_radius(d=sample_dim, p=0.9, sigma=1)

    # Accepted samples are gathered in a list and stacked once at the end;
    # growing a tensor with torch.cat inside the loop copies it on every hit.
    accepted = []
    while len(accepted) < size:
        sample = torch.randn(sample_dim)
        sample_norm = torch.sqrt(torch.sum(sample ** 2))

        if r_min < sample_norm < r:
            accepted.append(sample)

    if not accepted:
        return torch.empty((0, sample_dim))
    return torch.stack(accepted)


def get_radius(d, p=0.9, sigma=1):
    """Return ``sigma`` times the square root of the chi-square ``p``-quantile with ``d`` degrees of freedom.

    :param d: degrees of freedom passed to ``chi2.ppf``
    :param p: probability level, must lie strictly between 0 and 1
    :param sigma: scale factor applied to the result
    """
    assert p > 0 and p < 1

    squared_radius = chi2.ppf(p, d)
    return sigma * sqrt(squared_radius)
        

def average_demographic_parity_difference(scores, labels, pvs):
    """Average, over all score thresholds, the gap in flag rates between the two groups.

    Every value in ``scores`` is used once as a threshold ``t``. For each ``t`` the
    fraction of samples with ``score > t`` is computed separately for group a
    (``pvs == 1``) and group b (everything else), and the absolute difference of the
    two fractions is averaged over all thresholds.

    :param scores: anomaly scores, one per sample
    :param labels: ``0`` marks a normal sample, any other value an abnormal one
    :param pvs: protected-variable values; ``1`` selects group a
    :return: ``(all, normal, abnormal)`` averages, computed on all samples, on the
        normal samples only and on the abnormal samples only. A component is NaN
        when one of its two groups is empty, because the rate is undefined there.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    pvs = np.asarray(pvs)

    thresholds = np.sort(scores)
    in_group_a = pvs == 1
    is_normal = labels == 0

    def mean_rate_gap(subset):
        # Mean |rate_a - rate_b| over all thresholds, restricted to ``subset``.
        group_a = np.sort(scores[subset & in_group_a])
        group_b = np.sort(scores[subset & ~in_group_a])
        if group_a.size == 0 or group_b.size == 0:
            return float('nan')
        # In a sorted group, the samples strictly above t are those to the
        # right of the last position holding a value <= t.
        above_a = group_a.size - np.searchsorted(group_a, thresholds, side='right')
        above_b = group_b.size - np.searchsorted(group_b, thresholds, side='right')
        return np.mean(np.abs(above_a / group_a.size - above_b / group_b.size))

    everything = np.ones(scores.shape, dtype=bool)
    return mean_rate_gap(everything), mean_rate_gap(is_normal), mean_rate_gap(~is_normal)


def fairness_ratio(scores, labels, pvs, threshold):
    """Ratio of the two groups' flag rates at a fixed threshold, smaller over larger.

    A sample is flagged when ``score > threshold``. Group a is ``pvs == 1``, group b
    is everything else; normal samples are those with ``label == 0``.

    :return: ``(all_fr, normal_fr)`` -- the ratio on the whole set and on the normal
        samples only. Each is ``0.0`` when either group is empty or has no flagged
        sample.
    """
    in_group_a = pvs == 1
    in_group_b = ~in_group_a
    flagged = scores > threshold

    def rate_ratio(subset):
        # min(rate_a / rate_b, rate_b / rate_a) inside ``subset``, 0.0 if undefined.
        size_a = int((subset & in_group_a).sum())
        size_b = int((subset & in_group_b).sum())
        hits_a = int((subset & in_group_a & flagged).sum())
        hits_b = int((subset & in_group_b & flagged).sum())
        if not (size_a > 0 and size_b > 0 and hits_a > 0 and hits_b > 0):
            return 0.0
        rate_a = hits_a / size_a
        rate_b = hits_b / size_b
        return min(rate_a / rate_b, rate_b / rate_a)

    # fairness ratio of all the test set
    all_fr = rate_ratio(np.ones(in_group_a.shape, dtype=bool))

    # fairness ratio of normal samples in test set
    normal_fr = rate_ratio(labels == 0)

    return all_fr, normal_fr
 
