import copy
import os
import sys
import torch
import torch.nn as nn
import logging
from utils.train_eval import *
import torch.optim.lr_scheduler as lr_scheduler

# Directory two levels above this file (the parent of this file's directory),
# appended to sys.path so modules under it can be imported.
# NOTE(review): this runs *after* `from utils.train_eval import *` above, so it
# does not help resolve that import; presumably it is meant for imports made
# later (by this module or by modules it loads) — confirm, or move it above the
# project-local import if it was meant to enable it.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

class Adversarial_f_h64(nn.Module):
    """Adversary head that scores the sensitive attribute from input features.

    Layout: ``Linear(num_features -> 64) -> Linear(64 -> num_sensitive_class)
    -> Sigmoid``. Submodules are registered as ``h``, ``l`` and ``sigmoid``;
    those names are the state_dict keys.

    NOTE(review): there is no activation between ``h`` and ``l``, so the two
    linear layers compose to a single affine map. Confirm that this is intended.
    """

    def __init__(self, num_features, num_sensitive_class):
        super().__init__()
        self.num_features = num_features
        self.num_sensitive_class = num_sensitive_class
        hidden_width = 64
        # Creation order (h before l) is kept so that seeded initialisation and
        # parameter ordering are unchanged.
        self.h = nn.Linear(num_features, hidden_width)
        self.l = nn.Linear(hidden_width, num_sensitive_class)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class sigmoid scores (values in [0, 1]) for features ``x``."""
        hidden = self.h(x)
        logits = self.l(hidden)
        return self.sigmoid(logits)


class VirtualArgs:
    """Attribute-style view of a config dict, standing in for parsed CLI args.

    The keys ``w``, ``gpu``, ``print_freq``, ``fitness``, ``epochs`` and
    ``arch`` are required; a missing one raises ``KeyError``. Any additional
    keys (e.g. ``use_projection``) are also exposed as attributes rather than
    being silently dropped, so callers can forward optional settings.
    """

    _REQUIRED_KEYS = ("w", "gpu", "print_freq", "fitness", "epochs", "arch")

    def __init__(self, configs):
        # Required keys first, in the original order, so a missing key fails
        # exactly as before.
        for key in self._REQUIRED_KEYS:
            setattr(self, key, configs[key])
        # Assumes ``configs`` is a dict-like mapping with ``items()``.
        for key, value in configs.items():
            if key not in self._REQUIRED_KEYS:
                setattr(self, key, value)


class Finetune:
    """Fine-tune a model, optionally with adversarial debiasing, tracking fairness.

    On every epoch the model is trained. Top-1 / DI / DEO are then measured on
    the train, valid and test loaders via ``fairness_validate``, checkpoints are
    written into ``save_dir``, and the full metric history is logged.

    Required ``configs`` keys: gpu, epochs, w, p_lr, a_lr, adv_mode, save_every,
    num_class, num_sensitive_class, num_features, arch, noadv_add_schedular,
    args.

    Optional keys:
        use_projection: default False.
        w_decay: forwarded to ``adversarial_debias_train_f`` only when present.
    """

    _SPLITS = ("train", "valid", "test")
    _METRICS = ("top1", "DI", "DEO")
    # (log label, history key) pairs for adversarial mode. DI is labelled
    # "fairness" for the valid/test splits.
    _ADV_LOG_ORDER = (
        ("train_top1", "train_top1"),
        ("train_DI", "train_DI"),
        ("train_DEO", "train_DEO"),
        ("valid_top1", "valid_top1"),
        ("valid_fairness", "valid_DI"),
        ("valid_DEO", "valid_DEO"),
        ("test_top1", "test_top1"),
        ("test_fairness", "test_DI"),
        ("test_DEO", "test_DEO"),
    )
    # History keys logged as "key=values" in normal fine-tune mode.
    _NORMAL_LOG_ORDER = (
        "train_DI", "train_top1", "train_DEO",
        "valid_DI", "valid_top1", "valid_DEO",
        "test_DI", "test_top1", "test_DEO",
    )
    _REPORT_FMT = "Valid acc: {} | Valid DEO: {} | Test acc: {} | Test DEO: {}"

    def __init__(
        self,
        train_loader,
        valid_loader,
        test_loader,
        target_idx,
        sentitive_idx,
        configs,
    ):
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.target_idx = target_idx
        self.sensitive_idx = sentitive_idx
        self.gpu = configs["gpu"]
        self.epochs = configs["epochs"]
        self.w = configs["w"]
        self.p_lr = configs["p_lr"]
        self.a_lr = configs['a_lr']
        self.adv_mode = configs["adv_mode"]
        self.save_every = configs["save_every"]
        self.num_class = configs["num_class"]
        self.num_sensitive_class = configs["num_sensitive_class"]
        self.num_features = configs["num_features"]
        self.arch = configs["arch"]
        self.noadv_add_schedular = configs["noadv_add_schedular"]
        self.args = configs["args"]
        # These two were referenced by the adversarial branch but never
        # assigned, which made adversarial mode fail with AttributeError.
        self.use_projection = configs.get("use_projection", False)
        self.w_decay = configs.get("w_decay")

    def _make_virtual_args(self, **extra):
        """Build the ``VirtualArgs`` object passed to train/validate helpers."""
        return VirtualArgs(
            {
                "w": self.w,
                "gpu": self.gpu,
                "print_freq": 50,
                "fitness": "DEO",
                "epochs": self.epochs,
                "arch": self.arch,
                **extra,
            }
        )

    def _setup_adversarial(self, model, criterion, logger):
        """Create the adversary, Adam optimizers and cosine schedulers.

        Returns ``(virtual_args, run_epoch)``. ``run_epoch(epoch)`` performs one
        epoch of adversarial debiasing and then steps both schedulers.
        """
        logger.info("Adversarial training mode")
        adversary = Adversarial_f_h64(self.num_features, self.num_sensitive_class)
        adversary.cuda(self.gpu)
        model.cuda(self.gpu)

        predictor_optimizer = torch.optim.Adam(model.parameters(), lr=self.p_lr)
        adversary_optimizer = torch.optim.Adam(adversary.parameters(), lr=self.a_lr)
        predictor_scheduler = lr_scheduler.CosineAnnealingLR(
            predictor_optimizer, T_max=self.epochs
        )
        adversary_scheduler = lr_scheduler.CosineAnnealingLR(
            adversary_optimizer, T_max=self.epochs
        )
        virtual_args = self._make_virtual_args(use_projection=self.use_projection)
        # Only forward w_decay when configured, so the callee's own default
        # applies otherwise.
        extra_kwargs = {} if self.w_decay is None else {"w_decay": self.w_decay}

        def run_epoch(epoch):
            logger.info(
                "adv_mode: take features before the classification layer as input, 64 hidden layer"
            )
            adversarial_debias_train_f(
                self.train_loader,
                model,
                adversary,
                criterion,
                predictor_optimizer,
                adversary_optimizer,
                epoch,
                self.target_idx,
                self.sensitive_idx,
                virtual_args,
                **extra_kwargs,
            )
            predictor_scheduler.step()
            adversary_scheduler.step()

        return virtual_args, run_epoch

    def _setup_normal(self, model, criterion, logger):
        """Create the SGD optimizer and optional cosine scheduler.

        Returns ``(virtual_args, run_epoch)``. ``run_epoch(epoch)`` performs one
        ``train`` epoch and steps the scheduler when one is enabled.
        """
        logger.info("Normal finetune")
        model.cuda(self.gpu)
        optimizer = torch.optim.SGD(
            model.parameters(), lr=self.p_lr, momentum=0.9, weight_decay=0.0001
        )
        scheduler = None
        if self.noadv_add_schedular:
            logger.info("Mode noadv_add_schedular, use CosineAnnealingLR")
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)
        virtual_args = self._make_virtual_args()

        def run_epoch(epoch):
            train(
                self.train_loader,
                model,
                criterion,
                optimizer,
                epoch,
                self.target_idx,
                self.sensitive_idx,
                virtual_args,
            )
            if scheduler is not None:
                scheduler.step()

        return virtual_args, run_epoch

    def _evaluate_splits(self, model, criterion, virtual_args, history):
        """Validate on train/valid/test and append the results to ``history``.

        Returns ``{split: (top1, DI, DEO)}``.
        """
        loaders = (
            ("train", self.train_loader),
            ("valid", self.valid_loader),
            ("test", self.test_loader),
        )
        results = {}
        for split, loader in loaders:
            top1, di, deo = fairness_validate(
                loader,
                model,
                criterion,
                virtual_args,
                self.target_idx,
                self.sensitive_idx,
            )
            history[f"{split}_top1"].append(top1)
            history[f"{split}_DI"].append(di)
            history[f"{split}_DEO"].append(deo)
            results[split] = (top1, di, deo)
        return results

    def _save_checkpoints(
        self, model, save_dir, epoch, is_best_val, is_best_test, best_prec_val, best_prec_test
    ):
        """Write the best-valid, best-test, periodic and last-model checkpoints."""
        if is_best_val:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec_val,
                },
                is_best_val,
                file_name=os.path.join(save_dir, "best_model_valid.pth"),
            )
        if is_best_test:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec_test,
                },
                is_best_test,
                file_name=os.path.join(save_dir, "best_model_test.pth"),
            )
        if epoch > 0 and (epoch + 1) % self.save_every == 0:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec_val,
                },
                is_best_val,
                file_name=os.path.join(
                    save_dir, "checkpoint_ep_{}.pth".format(epoch + 1)
                ),
            )
        save_checkpoint(
            {
                "state_dict": model.state_dict(),
                "best_prec1": best_prec_val,
            },
            is_best_val,
            file_name=os.path.join(save_dir, "last_model.pth"),
        )

    def _log_history(self, logger, history):
        """Log the full per-epoch metric history in this mode's format."""
        if self.adv_mode:
            for label, key in self._ADV_LOG_ORDER:
                logger.info(f"{label}:{history[key]}")
        else:
            for key in self._NORMAL_LOG_ORDER:
                logger.info(f"{key}={history[key]}")

    def do(
        self, model, save_dir, return_best_valid=False, FPVE_flag=False
    ):
        """Train for ``self.epochs`` epochs and return the selected model.

        Args:
            model: network to fine-tune. It is moved to ``self.gpu`` in place.
            save_dir: directory receiving best_model_valid.pth,
                best_model_test.pth, checkpoint_ep_{N}.pth and last_model.pth.
            return_best_valid: if True, return the snapshot from the epoch with
                the highest validation top-1. Otherwise return a copy of the
                final model.
            FPVE_flag: if True, the second return value is always None.

        Returns:
            ``(model_copy, state_dict_or_None)``. The state dict comes from
            ``tp.state_dict``; ``tp`` is assumed to be provided by the star
            import of ``utils.train_eval`` (TODO confirm). When
            ``return_best_valid`` is set and no epoch ran, the untrained copy is
            returned and the logged metrics are None.
        """
        logger = logging.getLogger("train_logger")
        history = {
            f"{split}_{metric}": []
            for split in self._SPLITS
            for metric in self._METRICS
        }

        # Starting at -inf (not 0) guarantees the first evaluated epoch counts
        # as best, so the best-epoch values are always bound when epochs >= 1.
        best_prec_val = float("-inf")
        best_prec_test = float("-inf")
        # Fallback copy, returned only if no epoch ever improves on validation.
        best_model = copy.deepcopy(model) if return_best_valid else None
        best_model_dict = None
        best_report = (None, None, None, None)
        last_report = (None, None, None, None)

        criterion = nn.CrossEntropyLoss()
        criterion.cuda(self.gpu)
        if self.adv_mode:
            virtual_args, run_epoch = self._setup_adversarial(model, criterion, logger)
        else:
            virtual_args, run_epoch = self._setup_normal(model, criterion, logger)

        for epoch in range(self.epochs):
            run_epoch(epoch)

            results = self._evaluate_splits(model, criterion, virtual_args, history)
            prec_val, _, deo_val = results["valid"]
            prec_test, _, deo_test = results["test"]
            last_report = (prec_val, deo_val, prec_test, deo_test)

            is_best_val = prec_val > best_prec_val
            best_prec_val = max(prec_val, best_prec_val)
            is_best_test = prec_test > best_prec_test
            best_prec_test = max(prec_test, best_prec_test)

            if is_best_val and return_best_valid:
                best_report = last_report
                # Clear grads so the stored copy does not carry gradient buffers.
                model.zero_grad()
                best_model = copy.deepcopy(model)
                if not FPVE_flag:  # the dict is discarded when FPVE_flag is set
                    best_model_dict = tp.state_dict(model)

            self._save_checkpoints(
                model, save_dir, epoch, is_best_val, is_best_test,
                best_prec_val, best_prec_test,
            )
            self._log_history(logger, history)

        if return_best_valid:
            chosen_model, chosen_dict, report = best_model, best_model_dict, best_report
        else:
            model.zero_grad()
            chosen_model = copy.deepcopy(model)
            chosen_dict = None if FPVE_flag else tp.state_dict(model)
            logger.info("use last model")
            report = last_report

        logger.info(self._REPORT_FMT.format(*report))
        return chosen_model, (None if FPVE_flag else chosen_dict)
