import argparse
import os
import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
from utils.fairness_cal import *
import torch_pruning as tp
import pandas as pd


def _positive_rate(y_bar, mask):
    """Return the fraction of samples selected by ``mask`` that have ``y_bar == 1``.

    The result is a 0-dim float tensor. It is ``nan`` (0/0) when ``mask``
    selects no samples, exactly like the inline expressions this replaces.
    """
    return torch.sum((y_bar == 1) & mask).float() / torch.sum(mask).float()


def _deo_gap(y_bar, y, z):
    """DEO: |FPR(z=0) - FPR(z=1)| + |TPR(z=0) - TPR(z=1)| as a 0-dim tensor."""
    # Rate of predicting 1 among true negatives (false-positive rate) per group.
    fpr_gap = _positive_rate(y_bar, (y == 0) & (z == 0)) - _positive_rate(
        y_bar, (y == 0) & (z == 1)
    )
    # Rate of predicting 1 among true positives (true-positive rate) per group.
    tpr_gap = _positive_rate(y_bar, (y == 1) & (z == 0)) - _positive_rate(
        y_bar, (y == 1) & (z == 1)
    )
    return abs(fpr_gap) + abs(tpr_gap)


def _disparate_impact(y_bar, z):
    """DI: the smaller of the two ratios of positive-prediction rates between groups."""
    rate_z0 = _positive_rate(y_bar, z == 0)
    rate_z1 = _positive_rate(y_bar, z == 1)
    # Both ratios are computed directly (not as 1/ratio) to keep the original rounding.
    return min(rate_z0 / rate_z1, rate_z1 / rate_z0)


def fitness_loss(
    output: torch.Tensor,
    labels: torch.Tensor,
    target_idx,
    sensitive_idx,
    args,
    mode=None,
):
    """Compute a group-fairness metric of the arg-max predictions in ``output``.

    Args:
        output: Per-class scores; predictions are ``argmax(output, dim=1)``.
        labels: Label matrix indexed as ``labels[:, target_idx]`` (true target)
            and ``labels[:, sensitive_idx]`` (sensitive attribute, compared
            against 0 and 1).
        target_idx: Column of ``labels`` holding the target label.
        sensitive_idx: Column of ``labels`` holding the sensitive attribute.
        args: Namespace providing ``gpu`` (device index or ``None``) and
            ``fitness`` (default metric name when ``mode`` is ``None``).
        mode: Optional override of ``args.fitness``: ``"DEO"``, ``"DI"`` or
            ``"ALL"``.

    Returns:
        A Python float for ``"DEO"`` or ``"DI"``; the tuple ``(DEO, DI)`` for
        ``"ALL"``. Groups with no samples yield ``nan``/``inf`` values.

    Raises:
        ValueError: If the metric name is not one of the supported settings.
    """
    if args.gpu is not None:
        output = output.cuda(args.gpu)
        labels = labels.cuda(args.gpu)
    else:
        # No explicit GPU: keep labels on the scores' device instead of forcing
        # CUDA, so the metric also works on CPU-only machines.
        labels = labels.to(output.device)

    y_bar = torch.argmax(output.detach(), dim=1)
    y = labels[:, target_idx]
    z = labels[:, sensitive_idx]

    fitness_setting = args.fitness if mode is None else mode

    if fitness_setting == "DEO":
        return _deo_gap(y_bar, y, z).item()
    if fitness_setting == "DI":
        return _disparate_impact(y_bar, z).item()
    if fitness_setting == "ALL":
        return _deo_gap(y_bar, y, z).item(), _disparate_impact(y_bar, z).item()
    raise ValueError(
        f"Unknown fitness setting {fitness_setting!r}; expected 'DEO', 'DI' or 'ALL'"
    )

def accuracy(output, label, topk=(1,)):
    """Return top-k accuracy percentages of ``output`` against ``label``.

    Args:
        output: Score matrix of shape ``(batch, num_classes)``; classes are
            ranked along dim 1.
        label: Ground-truth class indices, one per row of ``output``.
        topk: The ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``: the percentage of
        rows whose label is among the ``k`` highest-scoring classes.
    """
    largest_k = max(topk)
    n_samples = label.size(0)

    _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
    ranked = ranked.t()  # row r now holds every sample's (r+1)-th best class
    hits = ranked.eq(label.reshape(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

def fairness_validate(
    val_test_loader,
    model,
    criterion,
    args,
    target_idx,
    sensitive_idx,
    print_result=True,
    mode="Valid",
):
    """Evaluate ``model`` on a loader and report accuracy plus DI/DEO fairness.

    Runs under ``torch.no_grad()`` with the model in eval mode. Loss and top-1
    accuracy are averaged per sample. DI and DEO are then obtained from the
    ``DI``/``DEO`` helpers (not defined in this file; presumably from
    ``utils.fairness_cal``), which are given the loader and model.

    Args:
        val_test_loader: Iterable of ``(inputs, labels)`` batches.
        model: Network to evaluate.
        criterion: Loss applied to ``(model output, labels[:, target_idx])``.
        args: Namespace providing ``gpu`` and ``print_freq``.
        target_idx: Label column used as the prediction target.
        sensitive_idx: Label column holding the sensitive attribute.
        print_result: When true, log progress lines and a final summary.
        mode: Prefix used in log messages (e.g. "Valid" or "Test").

    Returns:
        ``(top1.avg, di, deo)``.
    """
    logger = logging.getLogger("train_logger")
    batch_time = AverageMeter()
    time_data_indexing = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Only this template is passed through str.format; ``mode`` is prepended verbatim.
    progress_template = (
        " [{0}/{1}]\t"
        "Data_index Time {time_data_indexing.val:.3f} ({time_data_indexing.avg:.3f})\t"
        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
        "Loss {loss.val:.4f}({loss.avg:.4f})\t"
        "Prec@1 {top1.val:.3f}%({top1.avg:.3f}%)\t"
    )

    model.eval()
    tick = time.time()
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(val_test_loader):
            # Times only the device transfer / label slicing, not the loader fetch.
            index_start = time.time()
            if args.gpu is None:
                inputs_var, labels_var = inputs, labels[:, target_idx]
            else:
                labels = labels.cuda(args.gpu)
                inputs_var = inputs.cuda(args.gpu)
                labels_var = labels[:, target_idx].cuda(args.gpu)
            time_data_indexing.update(time.time() - index_start)

            output = model(inputs_var).float()
            loss = criterion(output, labels_var).float()
            prec1 = accuracy(output.data, labels_var)[0]

            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0 and print_result == True:  # noqa: E712
                logger.info(
                    mode
                    + progress_template.format(
                        step,
                        len(val_test_loader),
                        time_data_indexing=time_data_indexing,
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )

        di = DI(val_test_loader, model, target_idx, sensitive_idx, gpu=args.gpu)
        deo = DEO(val_test_loader, model, target_idx, sensitive_idx, gpu=args.gpu)
        if print_result:
            logger.info(
                "{0} acc: {top1.avg}%\t {0} DI: {1}\t {0} DEO: {2}".format(
                    mode, di, deo, top1=top1
                )
            )
    return top1.avg, di, deo

def train(
    train_loader,
    model,
    criterion,
    optimizer,
    epoch,
    target_idx,
    sensitive_idx,
    args,
    cal_DI=True,
):
    """Run one training epoch and log loss, accuracy and fairness statistics.

    Each batch runs one optimizer step on ``criterion(model(x), labels[:, target_idx])``.
    After the step, the batch's top-1 accuracy and the ``args.fitness``
    fairness metric are computed from the (pre-step) outputs via
    ``fitness_loss``, plus DI when ``cal_DI`` is true.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` batches.
        model: Network being trained.
        criterion: Training loss.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only in log messages.
        target_idx: Label column used as the prediction target.
        sensitive_idx: Label column holding the sensitive attribute.
        args: Namespace providing ``gpu``, ``print_freq`` and ``fitness``.
        cal_DI: Whether to also compute the per-batch DI metric.

    Returns:
        ``(top1.avg, fitnesses_DI.avg, fitnesses.avg)``; the DI average stays
        0.0 when ``cal_DI`` is false.
    """
    logger = logging.getLogger("train_logger")
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1 = AverageMeter(), AverageMeter()
    fitnesses, fitnesses_DI = AverageMeter(), AverageMeter()

    model.train()

    tick = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        data_time.update(time.time() - tick)
        if args.gpu is None:
            inputs_var = inputs
            labels_var = labels[:, target_idx]
        else:
            labels = labels.cuda(args.gpu)
            inputs_var = inputs.cuda(args.gpu)
            labels_var = labels[:, target_idx].cuda(args.gpu)

        output = model(inputs_var)
        loss = criterion(output, labels_var)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detached float32 view of the scores, shared by the metric helpers.
        scores = output.float().data
        batch_size = inputs.size(0)

        fitness = fitness_loss(scores, labels, target_idx, sensitive_idx, args)
        prec1 = accuracy(scores, labels_var)[0]
        if cal_DI:
            fitnesses_DI.update(
                fitness_loss(
                    scores, labels, target_idx, sensitive_idx, args, mode="DI"
                ),
                batch_size,
            )

        losses.update(loss.float().item(), batch_size)
        top1.update(prec1.item(), batch_size)
        fitnesses.update(fitness, batch_size)
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress = (
                "Epoch: [{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f}({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f}({loss.avg:.4f})\t"
                "Prec@1 {top1.val:.3f}({top1.avg:.3f})\t"
                "DI {fitnesses_DI.val:.5f}({fitnesses_DI.avg:.5f})\t"
                "{3}_fitness_loss {fitnesses.val:.5f}({fitnesses.avg:.5f})"
            )
            logger.info(
                progress.format(
                    epoch,
                    step,
                    len(train_loader),
                    args.fitness,
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses,
                    top1=top1,
                    fitnesses_DI=fitnesses_DI,
                    fitnesses=fitnesses,
                )
            )

    logger.info(
        " *Prec@1 {top1.avg:.3f} *Fairness {fitnesses.avg:.3f}".format(
            top1=top1, fitnesses=fitnesses
        )
    )
    return top1.avg, fitnesses_DI.avg, fitnesses.avg


# Most recent tensor captured by ``hook_fn``; stays ``None`` until the hook fires.
# ``adversarial_debias_train_f`` reads it right after each predictor forward pass.
hidden_features = None


def hook_fn(module, input, output):
    """Forward hook that stashes ``output`` in the module-level ``hidden_features``.

    ``module`` and ``input`` are required by the forward-hook calling
    convention but are not used. Nothing is returned, so under PyTorch's
    forward-hook contract the hooked module's output is left unmodified.
    """
    global hidden_features  # rebind the module attribute, not a local
    hidden_features = output


def _debias_predictor_grads(predictor, protect_grad, w, use_projection):
    """Rewrite predictor gradients in place so an update works against the adversary.

    For each parameter named in ``protect_grad`` (its recorded adversary-loss
    gradient ``a``), the current gradient ``g`` becomes
    ``g - (g . a_hat) a_hat - w * a`` when ``use_projection`` is true and
    ``g - w * a`` otherwise. Here ``a_hat = a / ||a||``. Parameters without a
    gradient or without a recorded adversary gradient are left untouched.
    """
    with torch.no_grad():
        for name, param in predictor.named_parameters():
            adv_grad = protect_grad.get(name)
            if adv_grad is None or param.grad is None:
                continue
            if use_projection:
                # Clamp the norm away from zero: an all-zero adversary gradient
                # would otherwise produce 0/0 = nan and permanently poison the
                # parameter. For any normal norm the value is unchanged.
                norm = torch.linalg.norm(adv_grad).clamp_min(
                    torch.finfo(adv_grad.dtype).tiny
                )
                unit_protect = adv_grad / norm
                param.grad -= (param.grad * unit_protect).sum() * unit_protect
            param.grad -= w * adv_grad


def adversarial_debias_train_f(
    train_loader,
    predictor,
    adversary,
    criterion,
    predictor_optimizer,
    adversary_optimizer,
    epoch,
    target_idx,
    sensitive_idx,
    args,
    cal_DI=True,
    w_decay=False,
):
    """Run one epoch of adversarial debiasing.

    A forward hook on ``predictor.avgpool`` captures its output. ``adversary``
    is trained with cross-entropy to predict ``labels[:, sensitive_idx]`` from
    the flattened features. The predictor is trained on
    ``criterion(pred, labels[:, target_idx])``, but for every parameter whose
    name does not contain ``"fc"`` its gradient is adjusted by
    ``_debias_predictor_grads``. That helper subtracts ``w`` times the
    adversary-loss gradient, plus, when ``args.use_projection`` is set, the
    projection onto that gradient. With a gradient-descent optimizer this
    pushes the predictor toward features the adversary cannot exploit.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` batches.
        predictor: Classifier exposing an ``avgpool`` submodule.
        adversary: Network mapping flattened ``avgpool`` features to
            sensitive-attribute logits.
        criterion: Prediction loss for the target label.
        predictor_optimizer: Optimizer over ``predictor``'s parameters.
        adversary_optimizer: Optimizer over ``adversary``'s parameters.
        epoch: Current epoch (used for logging and ``w_decay``).
        target_idx: Label column used as the prediction target.
        sensitive_idx: Label column holding the sensitive attribute.
        args: Namespace providing ``gpu``, ``w``, ``epochs``,
            ``use_projection``, ``fitness`` and ``print_freq``.
        cal_DI: Whether to also compute the per-batch DI metric.
        w_decay: If true, ramp the adversarial weight linearly as
            ``args.w * min((epoch + 1) / args.epochs, 1.0)``.

    Returns:
        ``(top1.avg, fitnesses.avg)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    protect_losses = AverageMeter()
    top1 = AverageMeter()
    fitnesses = AverageMeter()
    fitnesses_DI = AverageMeter()

    logger = logging.getLogger("train_logger")

    predictor.train()
    adversary.train()

    # Both depend only on the call arguments, so build them once per epoch.
    if w_decay:
        w = args.w * min(((epoch + 1) / args.epochs), 1.0)
    else:
        w = args.w
    adversarial_loss_func = nn.CrossEntropyLoss()

    # Keep the handle and remove the hook on exit; previously a new hook was
    # stacked on every call (i.e. every epoch) and never removed.
    hook_handle = predictor.avgpool.register_forward_hook(hook_fn)
    try:
        end = time.time()
        for i, (inputs, labels) in enumerate(train_loader):
            data_time.update(time.time() - end)

            if args.gpu is not None:
                inputs_var = inputs.cuda(args.gpu)
                labels_var = labels[:, target_idx].cuda(args.gpu)
                sensitive_labels = labels[:, sensitive_idx].cuda(args.gpu)
            else:
                inputs_var = inputs
                labels_var = labels[:, target_idx]
                sensitive_labels = labels[:, sensitive_idx]

            predictor_optimizer.zero_grad()
            adversary_optimizer.zero_grad()

            pred = predictor(inputs_var)

            # ``hidden_features`` was just set by ``hook_fn`` during the forward pass.
            flat_hidden_features = torch.flatten(hidden_features, start_dim=1)
            protect_pred = adversary(flat_hidden_features)

            pred_loss = criterion(pred, labels_var)
            protect_loss = adversarial_loss_func(protect_pred, sensitive_labels)

            # The graph is retained because pred_loss.backward() reuses the
            # predictor part of it below.
            protect_loss.backward(retain_graph=True)
            protect_grad = {
                name: param.grad.clone()
                for name, param in predictor.named_parameters()
                if "fc" not in name and param.grad is not None
            }

            adversary_optimizer.step()
            predictor_optimizer.zero_grad()
            pred_loss.backward()
            _debias_predictor_grads(predictor, protect_grad, w, args.use_projection)
            predictor_optimizer.step()

            pred = pred.float()
            pred_loss = pred_loss.float()
            fitness = fitness_loss(pred.data, labels, target_idx, sensitive_idx, args)
            prec1 = accuracy(pred.data, labels_var)[0]
            if cal_DI:
                fitness_DI = fitness_loss(
                    pred.data, labels, target_idx, sensitive_idx, args, mode="DI"
                )
                fitnesses_DI.update(fitness_DI, inputs.size(0))
            protect_losses.update(protect_loss.item(), inputs.size(0))
            losses.update(pred_loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            fitnesses.update(fitness, inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                logger.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f}({data_time.avg:.3f})\t"
                    "Loss {loss.val:.4f}({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f}({top1.avg:.3f})\t"
                    "Protect_loss {protect_loss.val:.4f}({protect_loss.avg:.4f})\t"
                    "DI {fitnesses_DI.val:.5f}({fitnesses_DI.avg:.5f})\t"
                    "{3}_fitness_loss {fitnesses.val:.5f}({fitnesses.avg:.5f})".format(
                        epoch,
                        i,
                        len(train_loader),
                        args.fitness,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top1=top1,
                        protect_loss=protect_losses,
                        fitnesses_DI=fitnesses_DI,
                        fitnesses=fitnesses,
                    )
                )
    finally:
        hook_handle.remove()

    logger.info(
        " *Prec@1 {top1.avg:.3f}".format(top1=top1)
        + " *Fairness {fitnesses.avg:.3f}".format(fitnesses=fitnesses)
    )
    return top1.avg, fitnesses.avg


def save_checkpoint(state, is_best, file_name="checkpoint.pth.tar"):
    """Serialize ``state`` to ``file_name`` with ``torch.save``.

    ``is_best`` is accepted for call-site compatibility but is not used; the
    checkpoint is always written to ``file_name`` only.
    """
    torch.save(obj=state, f=file_name)


class AverageMeter:
    """Track the latest value and an ``n``-weighted running average.

    Attributes (all public, read by the logging code):
        val: Most recently recorded value.
        sum: Weighted sum of all recorded values.
        count: Total weight recorded so far.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


