import torch
from FPVE_config import *
from thop import profile
import copy
import os

import torchvision
import random
import numpy as np
import logging
from utils.result_logger import Result_Logger
from utils.train_eval import fairness_validate
from prune_tp.prune_finetune import Finetune
import torch_pruning as tp


class FPVE_fairness:
    """Fairness-aware evolutionary filter pruning (FPVE) for ResNet-18 variants.

    :meth:`run` performs ``run_epoch`` pruning rounds. In every round a small
    evolutionary search (:meth:`evolution_step`) chooses, layer by layer, which
    output channels of each target convolution to keep. The chosen channels are
    then removed with ``torch_pruning``'s dependency graph, and the pruned model
    is fine-tuned. Candidates are scored by :meth:`fitness`, either on accuracy
    alone or on relative accuracy change traded off against relative DEO change.

    An *individual* is a sorted list of filter indices to KEEP in one layer.
    """

    def __init__(
        self,
        model,
        train_loader,
        valid_loader,
        test_loader,
        args,
        target_idx,
        sensitive_idx,
        FPVE_fitness_loader=None,
    ):
        # Dataloaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        self.args = args

        self.target_idx = target_idx
        self.sensitive_idx = sensitive_idx

        # Per-round records, appended to by run() after each fine-tune.
        self.acc = []
        self.DEO = []
        self.DI = []
        self.valid_acc = []
        self.valid_DEO = []
        self.valid_DI = []
        self.params = []
        self.FLOPS = []

        self.FILTER_NUMS = []
        self.FILTER_NUM = []

        # Reference model for FLOPs/params ratios. It is the same object as
        # self.model until run() rebinds self.model to the fine-tuned model.
        self.ori_model = model
        self.model = model

        self.criterion = torch.nn.CrossEntropyLoss()
        # Not read inside this class; kept for external users (TODO confirm).
        self.h64 = True

        self.pop_size = self.args.pop_size
        self.pop_init_rate = self.args.pop_init_rate
        self.pruning_ratio = self.args.pruning_ratio
        self.mutation_rate = self.args.mutation_rate
        self.evolution_epoch = self.args.evolution_epoch

        self.result_logger = Result_Logger(self.args.save_dir)

        self.fitness_mode = self.args.fitness_mode

        self.use_FPVE_fitness_loader = False

        if FPVE_fitness_loader is not None:
            self.result_logger.log(
                "Use FPVE fitness loader, ratio: {}, num: {}".format(
                    self.args.FPVE_fitness_data_ratio, len(FPVE_fitness_loader)
                )
            )
            self.use_FPVE_fitness_loader = True
            self.FPVE_fitness_loader = FPVE_fitness_loader

        self.adversary_for_fitness = None

        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Dummy input used to trace the model when building dependency graphs.
        self.example_inputs = torch.randn(1, 3, 224, 224).cuda(self.args.gpu)

        # Baselines for the ACC_FAIRNESS fitness; set by run().
        self.init_acc = None
        self.init_deo = None
        self.scalar = self.args.scalar

    # ------------------------------------------------------------------ helpers

    def _fitness_dataloader(self):
        """Return the loader used for fitness: the FPVE subset if given, else validation."""
        if self.use_FPVE_fitness_loader:
            return self.FPVE_fitness_loader
        return self.valid_loader

    def _validate(self, loader, model, **validate_kwargs):
        """Call ``fairness_validate`` with this run's criterion/args/indices; returns (acc, DI, DEO)."""
        return fairness_validate(
            loader,
            model,
            self.criterion,
            self.args,
            self.target_idx,
            self.sensitive_idx,
            **validate_kwargs,
        )

    def _prune_layer(self, model, layer_name, filter_num, keep_idxs):
        """Prune, in place, every output channel of ``layer_name`` not listed in ``keep_idxs``.

        Coupled layers are pruned as well, through a ``torch_pruning`` dependency graph.

        Returns:
            The list of pruned channel indices.

        Raises:
            KeyError: if ``layer_name`` is not a submodule of ``model``.
        """
        keep = set(keep_idxs)
        pruning_idxs = [idx for idx in range(filter_num) if idx not in keep]
        modules = dict(model.named_modules())
        if layer_name not in modules:
            raise KeyError(f"Layer {layer_name!r} not found in model")
        DG = tp.DependencyGraph()
        DG.build_dependency(model, example_inputs=self.example_inputs)
        pruning_group = DG.get_pruning_group(
            modules[layer_name], tp.prune_conv_out_channels, idxs=pruning_idxs
        )
        pruning_group.prune()
        return pruning_idxs

    def _evaluate_candidate(self, index, indiv, group_name, filter_num, logger):
        """Prune a copy of ``self.model`` to ``indiv``, then score, profile and log it.

        Returns:
            The population record ``[index, fitness, indiv, len(indiv), details]``.
        """
        test_model = copy.deepcopy(self.model)
        pruning_idxs = self._prune_layer(test_model, group_name, filter_num, indiv)
        score, details = self.fitness(test_model)
        self.check_model_profile(test_model)
        logger.info([index, score, pruning_idxs, len(indiv), details])
        return [index, score, indiv, len(indiv), details]

    def _layer_plan(self):
        """Return ``(layers_p2, filter_num_p2, layers_p1, filter_num_p1)`` for the arch.

        P2 layers are ``layer{k}.0.conv2`` (one per stage), and P1 layers are
        every block's ``conv1``. Widths start at 64 (``resnet18``) or 32
        (``resnet18_half``) and double at each stage.

        Raises:
            NotImplementedError: for any other ``args.arch``.
        """
        base_width = {"resnet18": 64, "resnet18_half": 32}.get(self.args.arch)
        if base_width is None:
            raise NotImplementedError("Not supported arch")
        block_num = [2, 2, 2, 2]
        layers_p2 = [f"layer{stage + 1}.0.conv2" for stage in range(len(block_num))]
        filter_num_p2 = [base_width * (2**stage) for stage in range(len(block_num))]
        layers_p1 = [
            f"layer{stage + 1}.{block}.conv1"
            for stage, blocks in enumerate(block_num)
            for block in range(blocks)
        ]
        filter_num_p1 = [
            base_width * (2**stage)
            for stage, blocks in enumerate(block_num)
            for _ in range(blocks)
        ]
        return layers_p2, filter_num_p2, layers_p1, filter_num_p1

    def _evaluate_split(
        self, logger, split_name, loader, model, stage, round_idx, **validate_kwargs
    ):
        """Evaluate ``model`` on one split and log it as 'before'/'after' fine-tune.

        Returns:
            ``(acc, DI, DEO)``.
        """
        logger.info(f"{split_name} set:")
        self.result_logger.log(f"{split_name} set:")
        acc, di, deo = self._validate(loader, model, **validate_kwargs)
        self.result_logger.log(
            "epoch:{} {} fine-tune... ACC: {}, DI: {}, DEO: {}".format(
                round_idx, stage, acc, di, deo
            )
        )
        return acc, di, deo

    def _summary_lines(self):
        """Return the per-round metric history as printable lines."""
        return [
            f"Test set ACC:{self.acc}",
            f"Test set DEO:{self.DEO}",
            f"Test set DI:{self.DI}",
            f"Valid set ACC:{self.valid_acc}",
            f"Valid set DEO:{self.valid_DEO}",
            f"Valid set DI:{self.valid_DI}",
            f"FLOPS:{self.FLOPS}",
            f"Params:{self.params}",
        ]

    def _log_filter_history(self, logger):
        """Log the P1 filter counts recorded after each pruning round."""
        for round_no, filter_nums in enumerate(self.FILTER_NUMS, start=1):
            logger.info(f"FILTER_NUM at epoch {round_no}:{filter_nums}")

    # --------------------------------------------------------------- public API

    def fitness(self, test_model=None, test_adversary=None):
        """Score a (pruned) model on the fitness loader.

        Args:
            test_model: Model to evaluate. ``self.model`` is used when ``None``.
            test_adversary: Unused; kept for interface compatibility.

        Returns:
            ``(score, details)``.

            - ``"ACC"`` mode: ``score`` is the accuracy and ``details`` is
              ``[acc, di, deo]``.
            - ``"ACC_FAIRNESS"`` mode: ``score`` is
              ``acc_change_ratio - scalar * deo_change_ratio``, with both ratios
              relative to ``init_acc``/``init_deo``. ``details`` additionally
              carries ``(acc_change_ratio, deo_change_ratio)``.

        Raises:
            RuntimeError: In ``"ACC_FAIRNESS"`` mode before :meth:`run` has set
                the baselines.
            NotImplementedError: For an unknown ``fitness_mode``.
        """
        model = self.model if test_model is None else test_model
        cur_acc, cur_di, cur_deo = self._validate(
            self._fitness_dataloader(), model, print_result=False
        )

        if self.fitness_mode == "ACC":
            print("ACC...")
            return cur_acc, [cur_acc, cur_di, cur_deo]
        if self.fitness_mode == "ACC_FAIRNESS":
            print("ACC_FAIRNESS...")
            if self.init_acc is None or self.init_deo is None:
                raise RuntimeError(
                    "ACC_FAIRNESS fitness needs init_acc/init_deo; call run() first"
                )
            # NOTE(review): a baseline DEO of exactly 0 raises ZeroDivisionError here.
            acc_change_ratio = (cur_acc - self.init_acc) / self.init_acc
            deo_change_ratio = (cur_deo - self.init_deo) / self.init_deo
            return acc_change_ratio - self.scalar * deo_change_ratio, [
                cur_acc,
                cur_di,
                cur_deo,
                (acc_change_ratio, deo_change_ratio),
            ]
        raise NotImplementedError

    def evolution_step(self, group_name, filter_num, target_filter_num=None):
        """Run one layer's sub-EA and return the best individual (indices to keep).

        The population is ranked by fitness, with fewer kept filters breaking
        ties. Each generation creates ``pop_size`` children by mutating random
        parents. Parents and children are then merged and the top ``pop_size``
        survive (elitist selection).

        Args:
            group_name: Module name of the conv layer whose output channels are searched.
            filter_num: Current number of output channels of that layer.
            target_filter_num: Exact number of channels each individual keeps,
                or ``None`` to leave the count to mutation.
        """
        pop = self.generate_initial_pop(filter_num, target_filter_num=target_filter_num)
        logger = logging.getLogger("train_logger")
        logger.info("Group:{0} | filter_num:{1}\n".format(group_name, filter_num))
        initial_fitness, initial_tri_fitness = self.fitness()
        logger.info(f"Initial fitness:{initial_fitness} | {initial_tri_fitness}")
        logger.info("Initial population")

        def rank_key(record):
            # Higher fitness first; on ties, fewer kept filters first.
            return (record[1], -record[3])

        parent_fitness = [
            self._evaluate_candidate(i, pop[i], group_name, filter_num, logger)
            for i in range(self.pop_size)
        ]
        parent_fitness.sort(key=rank_key, reverse=True)
        # Defined up front so evolution_epoch == 0 still returns a result.
        best_ind = parent_fitness[0]

        for generation in range(self.evolution_epoch):
            logger.info(f"Population at round {generation}")
            child_fitness = []
            for j in range(self.pop_size):
                parent = pop[random.randint(0, self.pop_size - 1)]
                child_indiv = self.mutation(
                    parent,
                    filter_num,
                    self.mutation_rate,
                    target_count=target_filter_num,
                )
                child_fitness.append(
                    self._evaluate_candidate(
                        j, child_indiv, group_name, filter_num, logger
                    )
                )
            logger.info("\n\n")

            # Parents first, so the stable sort keeps incumbents ahead on exact ties.
            merged = parent_fitness + child_fitness
            merged.sort(key=rank_key, reverse=True)
            parent_fitness = merged[: self.pop_size]
            pop = [record[2] for record in parent_fitness]

            logger.info(f"Population at epoch {generation}:")
            for record in parent_fitness:
                kept = set(record[2])
                logger.info(
                    [
                        record[0],
                        record[1],
                        [idx for idx in range(filter_num) if idx not in kept],
                        len(record[2]),
                        record[4],
                    ]
                )
            logger.info("\n\n")
            best_ind = parent_fitness[0]
            logger.info(
                f"Best so far {best_ind[1]}, Initial fitness: {initial_fitness}, Filter now: {best_ind[3]}, Pruning ratio: {1 - best_ind[3] / filter_num}"
            )

        best_kept = set(best_ind[2])
        logger.info(
            f"Pruned filters {[idx for idx in range(filter_num) if idx not in best_kept]}"
        )
        return best_ind[2]

    def generate_initial_pop(
        self, filter_num, add_unpruned_indiv=False, target_filter_num=None
    ):
        """Build the initial population for one layer's sub-EA.

        Every individual is a mutation, at rate ``pop_init_rate``, of the
        unpruned layer ``[0, ..., filter_num - 1]``. With
        ``add_unpruned_indiv``, the unpruned layer itself takes one of the
        ``pop_size`` slots.

        Returns:
            ``pop_size`` individuals, sorted lexicographically.
        """
        unpruned = list(range(filter_num))
        population = [unpruned] if add_unpruned_indiv else []
        for _ in range(self.pop_size - len(population)):
            population.append(
                self.mutation(
                    unpruned,
                    filter_num,
                    mutation_rate=self.pop_init_rate,
                    target_count=target_filter_num,
                )
            )
        print(population)
        population.sort()
        return population

    def mutation(
        self, indiv, filter_num, mutation_rate=0.1, mode="shuffle", target_count=None
    ):
        """Return a mutated copy of ``indiv`` (a sorted list of kept filter indices).

        The keep/drop bit of each filter is flipped with probability
        ``mutation_rate``. The bits are visited in a shuffled order, so the
        later repair step also picks filters at random. When ``target_count``
        is given, randomly chosen filters are then added or removed until
        exactly ``target_count`` are kept. If the result would keep no filter,
        ``indiv`` is returned unchanged.

        Raises:
            NotImplementedError: For any ``mode`` other than ``"shuffle"``.
        """
        if mode != "shuffle":
            raise NotImplementedError(f"Unsupported mutation mode: {mode}")

        n_filters = int(filter_num)
        genes = [[i, 0] for i in range(n_filters)]
        for idx in indiv:
            genes[idx][1] = 1

        random.shuffle(genes)

        current_count = len(indiv)

        for i in range(n_filters):
            if random.random() < mutation_rate:
                genes[i][1] = 1 - genes[i][1]
                current_count += 1 if genes[i][1] == 1 else -1

        # Without a target count there is nothing to repair (previously a TypeError).
        if target_count is not None and current_count != target_count:
            selected = [pos for pos, gene in enumerate(genes) if gene[1] == 1]
            random.shuffle(selected)
            if current_count > target_count:
                for pos in selected[: current_count - target_count]:
                    genes[pos][1] = 0
            else:
                unselected = [pos for pos, gene in enumerate(genes) if gene[1] == 0]
                random.shuffle(unselected)
                for pos in unselected[: target_count - current_count]:
                    genes[pos][1] = 1

        new_indiv = sorted(gene[0] for gene in genes if gene[1] == 1)
        print(new_indiv)
        return new_indiv if new_indiv else indiv

    def check_model_profile(self, model=None):
        """Log FLOPs/params of ``model`` relative to the original model.

        Profiling uses ``thop.profile`` on a random ``1x3x224x224`` input.

        Args:
            model: Model to profile. ``self.model`` (moved to ``args.gpu``) is
                used when ``None``.

        Returns:
            ``(flops_percent, params_percent)`` of the original model.
        """
        logger = logging.getLogger("train_logger")
        if model is None:
            model = self.model.cuda(self.args.gpu)

        model_input = torch.randn(1, 3, 224, 224).cuda(self.args.gpu)
        # For "vgg" the reference is profiled through `.module` (presumably a
        # DataParallel-style wrapper -- confirm).
        reference = self.ori_model.module if self.args.arch == "vgg" else self.ori_model
        i_flops, i_params = profile(reference, inputs=(model_input,), verbose=False)
        logger.info("initial model: FLOPs: {0}, params: {1}".format(i_flops, i_params))
        p_flops, p_params = profile(model, inputs=(model_input,), verbose=False)
        flops_percent = (p_flops / i_flops) * 100
        params_percent = (p_params / i_params) * 100
        logger.info(
            "pruned model: FLOPs: {0}({1:.2f}%), params: {2}({3:.2f}%)".format(
                p_flops, flops_percent, p_params, params_percent
            )
        )

        self.model.cuda(self.args.gpu)
        return flops_percent, params_percent

    def run(self, run_epoch):
        """Run ``run_epoch`` rounds of evolutionary search, pruning and fine-tuning.

        Round ``k`` (1-based) asks each searched layer to keep
        ``int(original_width * (1 - k * pruning_ratio / run_epoch))`` filters.
        The P2 and P1 searches both score candidates against ``self.model``.
        Their solutions are then applied to a copy, which is fine-tuned and
        becomes the new ``self.model``. Test/valid metrics and FLOPs/params
        ratios are appended to the record lists after every round.

        Raises:
            NotImplementedError: If ``args.arch`` is not a supported ResNet-18 variant.
        """
        self.model.cuda(self.args.gpu)
        logger = logging.getLogger("train_logger")

        # Fails for unsupported archs before anything is logged.
        layers_p2, FILTER_NUM_P2, layers_p1, FILTER_NUM_P1 = self._layer_plan()
        # Current (already pruned) widths; FILTER_NUM_* keep the original ones.
        filter_num_p2 = FILTER_NUM_P2.copy()
        filter_num_p1 = FILTER_NUM_P1.copy()
        sol_p2 = [[] for _ in layers_p2]
        sol_p1 = [[] for _ in layers_p1]

        self.result_logger.log(self.args.__dict__)

        self.check_model_profile()
        logger.info(f"use_FPVE_fitness_loader:{self.use_FPVE_fitness_loader}")
        ori_top1_acc, ori_DI, ori_DEO = self._validate(
            self._fitness_dataloader(), self.model
        )
        # Baselines for the ACC_FAIRNESS relative-change fitness.
        self.init_acc = ori_top1_acc
        self.init_deo = ori_DEO
        fitness_set_name = (
            "FPVE fitness set" if self.use_FPVE_fitness_loader else "Valid set"
        )
        self.result_logger.log(
            "{} | Original model ACC: {}, Original model DI: {}, DEO: {}".format(
                fitness_set_name, ori_top1_acc, ori_DI, ori_DEO
            )
        )

        for split_name, loader in (
            ("Valid", self.valid_loader),
            ("Test", self.test_loader),
        ):
            logger.info(f"{split_name} set:")
            ori_top1_acc, ori_DI, ori_DEO = self._validate(loader, self.model)
            self.result_logger.log(
                "{} set | Original model ACC: {}, Original model DI: {}, DEO: {}".format(
                    split_name, ori_top1_acc, ori_DI, ori_DEO
                )
            )

        for epoch in range(run_epoch):
            round_idx = epoch + 1
            cur_model = copy.deepcopy(self.model)
            logger.info(cur_model)

            # Fraction of each layer's ORIGINAL width to keep after this round.
            keep_ratio = 1 - round_idx * self.pruning_ratio / run_epoch

            # Search phase (P2, then P1).
            for layer, name in enumerate(layers_p2):
                sol_p2[layer] = self.evolution_step(
                    name,
                    filter_num_p2[layer],
                    target_filter_num=int(FILTER_NUM_P2[layer] * keep_ratio),
                )
            for layer, name in enumerate(layers_p1):
                sol_p1[layer] = self.evolution_step(
                    name,
                    filter_num_p1[layer],
                    target_filter_num=int(FILTER_NUM_P1[layer] * keep_ratio),
                )

            # Pruning phase: apply every searched solution to cur_model.
            for layer, name in enumerate(layers_p2):
                logger.info(
                    f"{layer}, {name}, {filter_num_p2[layer]}, {sol_p2[layer]}"
                )
                self._prune_layer(cur_model, name, filter_num_p2[layer], sol_p2[layer])
                filter_num_p2[layer] = len(sol_p2[layer])
            for layer, name in enumerate(layers_p1):
                logger.info(
                    f"{layer}, {name}, {filter_num_p1[layer]}, {sol_p1[layer]}"
                )
                self._prune_layer(cur_model, name, filter_num_p1[layer], sol_p1[layer])
                filter_num_p1[layer] = len(sol_p1[layer])

            logger.info("Pruned arch: ")
            logger.info(cur_model)
            logger.info(f"filter_num_p1 {filter_num_p1}")
            logger.info(f"filter_num_p2 {filter_num_p2}")
            logger.info("epoch:{0} before fine-tune...".format(round_idx))
            self.check_model_profile(cur_model)
            self.FILTER_NUMS.append(filter_num_p1[:])

            self._evaluate_split(
                logger, "Valid", self.valid_loader, cur_model, "before", round_idx
            )
            self._evaluate_split(
                logger, "Test", self.test_loader, cur_model, "before", round_idx
            )

            # finetuning
            finetune_config = {
                "num_class": self.args.num_class,
                "num_sensitive_class": self.args.num_sensitive_class,
                "num_features": filter_num_p2[-1],
                "gpu": self.args.gpu,
                "epochs": self.args.ft_epochs,
                "p_lr": self.args.p_lr,
                "a_lr": self.args.a_lr,
                "adv_mode": self.args.adv_mode,
                "save_every": 5,
                "arch": self.args.arch,
                "noadv_add_schedular": self.args.noadv_add_schedular,
                "args": self.args,
            }
            cur_save_dir = self.args.save_dir + "/iter_{}/".format(round_idx)
            os.makedirs(cur_save_dir, exist_ok=True)
            finetune_worker = Finetune(
                self.train_loader,
                self.valid_loader,
                self.test_loader,
                self.target_idx,
                self.sensitive_idx,
                finetune_config,
            )

            self.model, _ = finetune_worker.do(
                cur_model, cur_save_dir, return_best_valid=False, FPVE_flag=True
            )

            logger.info("epoch:{0} after fine-tune...".format(round_idx))
            flops, params = self.check_model_profile(self.model)

            acc_val_top1, DI_val, DEO_val = self._evaluate_split(
                logger,
                "Valid",
                self.valid_loader,
                self.model,
                "after",
                round_idx,
                mode="Valid",
            )
            acc_test_top1, DI_test, DEO_test = self._evaluate_split(
                logger,
                "Test",
                self.test_loader,
                self.model,
                "after",
                round_idx,
                mode="Test",
            )

            self.acc.append(acc_test_top1)
            self.FLOPS.append(flops)
            self.params.append(params)
            self.DEO.append(DEO_test)
            self.DI.append(DI_test)

            self.valid_acc.append(acc_val_top1)
            self.valid_DEO.append(DEO_val)
            self.valid_DI.append(DI_val)

            summary = self._summary_lines()
            for line in summary:
                logger.info(line)
            for line in summary:
                self.result_logger.log(line)
            self._log_filter_history(logger)

        for line in self._summary_lines():
            logger.info(line)
        self._log_filter_history(logger)
