import random
import logging
import copy
from evaluation import get_result


def generate_initial_pop(pop_size, ori_head_num, target_head_num, pop_init_rate):
    """Build the starting population for a sub-EA.

    Every individual is produced by mutating the "keep all heads" selection
    ``[0, ..., ori_head_num - 1]`` with rate ``pop_init_rate`` and then forcing
    it to contain exactly ``target_head_num`` heads (see ``mutation``).

    Args:
        pop_size: Number of individuals to create.
        ori_head_num: Total number of heads in the unpruned model.
        target_head_num: Number of kept heads every individual must have.
        pop_init_rate: Per-head flip probability used for initialization.

    Returns:
        A list of ``pop_size`` individuals (each a sorted list of kept head
        indices), itself sorted lexicographically.
    """
    all_heads = list(range(ori_head_num))
    population = [
        mutation(
            all_heads,
            ori_head_num,
            mutation_rate=pop_init_rate,
            target_count=target_head_num,
        )
        for _ in range(pop_size)
    ]
    # Lists compare element-wise, so this orders individuals lexicographically.
    population.sort()
    return population


def mutation(indiv, ori_head_num, mutation_rate, target_count):
    """Return a mutated copy of ``indiv`` that keeps exactly ``target_count`` heads.

    The individual is expanded into one ``[head, flag]`` slot per head, the
    slots are shuffled, and each slot's flag is flipped with probability
    ``mutation_rate``. If the number of kept heads then differs from
    ``target_count``, randomly chosen kept heads are dropped (or randomly
    chosen dropped heads are added) until the count matches.

    Args:
        indiv: Kept head indices of the parent; not modified.
        ori_head_num: Total number of heads in the unpruned model.
        mutation_rate: Per-head probability of flipping kept/dropped.
        target_count: Required number of kept heads in the result.

    Returns:
        A new sorted list of kept head indices of length ``target_count``.

    Raises:
        ValueError: If the repaired individual still does not contain
            ``target_count`` heads (e.g. ``target_count`` exceeds the number
            of available heads).
    """
    # One [head_index, kept_flag] slot per head; flag 1 means "kept".
    slots = [[head, 0] for head in range(int(ori_head_num))]
    for head in indiv:
        slots[head][1] = 1
    random.shuffle(slots)

    # Random flips; the running count assumes ``indiv`` has no duplicates.
    kept = len(indiv)
    for pos in range(ori_head_num):
        if random.random() < mutation_rate:
            slot = slots[pos]
            slot[1] = 1 - slot[1]
            kept += 1 if slot[1] == 1 else -1

    # Repair step: bring the number of kept heads back to the target.
    if kept != target_count:
        kept_positions = [pos for pos, slot in enumerate(slots) if slot[1] == 1]
        random.shuffle(kept_positions)
        if kept > target_count:
            surplus = kept - target_count
            for pos in kept_positions[:surplus]:
                slots[pos][1] = 0
        else:
            dropped_positions = [
                pos for pos, slot in enumerate(slots) if slot[1] == 0
            ]
            random.shuffle(dropped_positions)
            shortfall = target_count - kept
            for pos in dropped_positions[:shortfall]:
                slots[pos][1] = 1

    new_indiv = sorted(slot[0] for slot in slots if slot[1] == 1)
    if len(new_indiv) != target_count:
        raise ValueError(
            "The number of heads in the new individual is not equal to the target number of heads"
        )
    return new_indiv


def get_fitness(bias, ppl, scalar, ori_bias, ori_ppl, fitness_mode="BIAS_PPL"):
    """Score a candidate by its relative bias/perplexity change from a baseline.

    The score is the negated relative change, so a candidate that lowers the
    measured quantity gets a positive fitness and the baseline itself scores 0.

    Only the ratio(s) required by ``fitness_mode`` are computed, so a zero
    baseline for a metric that is not part of the selected mode does not
    cause a division by zero.

    Args:
        bias: Bias score of the candidate.
        ppl: Perplexity of the candidate.
        scalar: Weight applied to the perplexity term in ``"BIAS_PPL"`` mode.
        ori_bias: Baseline bias score (divisor for the bias ratio).
        ori_ppl: Baseline perplexity (divisor for the perplexity ratio).
        fitness_mode: ``"BIAS_PPL"`` (bias + scalar * ppl), ``"BIAS"`` or
            ``"PPL"``.

    Returns:
        The fitness value; higher is better.

    Raises:
        ValueError: If ``fitness_mode`` is not one of the supported modes.
    """
    if fitness_mode == "BIAS_PPL":
        bias_change_ratio = (bias - ori_bias) / ori_bias
        ppl_change_ratio = (ppl - ori_ppl) / ori_ppl
        return -(bias_change_ratio + scalar * ppl_change_ratio)
    if fitness_mode == "BIAS":
        return -((bias - ori_bias) / ori_bias)
    if fitness_mode == "PPL":
        return -((ppl - ori_ppl) / ori_ppl)
    raise ValueError("The fitness mode is not supported")


def roulette_wheel_selection(fitness_list):
    """Pick an index with probability proportional to (shifted) fitness.

    Fitness is read from position 1 of every record. Values are shifted so
    that all weights are strictly positive: when the minimum is negative it
    is subtracted first, and a small epsilon (1e-8) is always added.

    Args:
        fitness_list: Non-empty sequence of records whose element ``[1]`` is
            the fitness value.

    Returns:
        The index of the selected record in ``fitness_list``.
    """
    scores = [record[1] for record in fitness_list]
    lowest = min(scores)
    if lowest < 0:
        weights = [score - lowest + 1e-8 for score in scores]
    else:
        weights = [score + 1e-8 for score in scores]

    total = sum(weights)

    try:
        shares = [weight / total for weight in weights]
    except ZeroDivisionError:
        # Degenerate weights: fall back to a uniform pick.
        return random.randint(0, len(fitness_list) - 1)

    # Walk the cumulative distribution until it reaches the random draw.
    draw = random.random()
    running = 0.0
    for position, share in enumerate(shares):
        running += share
        if draw <= running:
            return position
    # Rounding can leave the final cumulative value slightly below the draw.
    return len(fitness_list) - 1


def non_dominated_sorting(population):
    """Partition a population into Pareto fronts (NSGA-II fast non-dominated sort).

    Dominance is decided by ``dominates``. Front 0 holds the individuals that
    nobody dominates; each following front holds the individuals that become
    non-dominated once all earlier fronts are removed.

    Individuals are tracked by their position in ``population`` rather than
    being re-located by equality, so the routine does not depend on record
    equality and handles duplicated records directly.

    Args:
        population: Iterable of individual records accepted by ``dominates``.

    Returns:
        A list of fronts, best first; each front is a list of the original
        record objects. An empty population yields ``[]``.
    """
    individuals = list(population)
    size = len(individuals)

    # beaten_by[i]: positions of the individuals that i dominates.
    # pending[i]: number of not-yet-ranked individuals that dominate i.
    beaten_by = [set() for _ in range(size)]
    pending = [0] * size

    for i in range(size):
        for j in range(size):
            if i == j:
                continue
            if dominates(individuals[i], individuals[j]):
                beaten_by[i].add(j)
                pending[j] += 1

    fronts = []
    current = [idx for idx in range(size) if pending[idx] == 0]

    while current:
        fronts.append([individuals[idx] for idx in current])
        upcoming = []
        for idx in current:
            # Removing idx from the pool releases everything it dominated.
            for beaten in beaten_by[idx]:
                pending[beaten] -= 1
                if pending[beaten] == 0:
                    upcoming.append(beaten)
        current = upcoming

    return fronts


def dominates(ind1, ind2):
    """Tell whether ``ind1`` Pareto-dominates ``ind2``.

    Both objectives (record positions 4 and 5) are minimized: ``ind1``
    dominates when it is no worse on both and strictly better on at least one.
    """
    no_worse = ind1[4] <= ind2[4] and ind1[5] <= ind2[5]
    if not no_worse:
        return no_worse
    return ind1[4] < ind2[4] or ind1[5] < ind2[5]


def calculate_crowding_distance(front):
    """Compute NSGA-II crowding distances for the individuals of one front.

    For each objective (record positions 4 and 5) the front is sorted, the two
    boundary individuals get an infinite distance, and every interior
    individual accumulates the gap between its neighbours normalized by the
    objective's range. When all individuals share the same value for an
    objective the range is zero; that objective then contributes nothing to
    interior individuals instead of raising ``ZeroDivisionError``.

    Note:
        ``front`` is sorted in place (by objective 4, then by objective 5), so
        the caller's list is reordered.

    Args:
        front: List of individual records; element ``[0]`` is used as the key
            of the returned mapping.

    Returns:
        Dict mapping each record's ``[0]`` value to its crowding distance;
        ``{}`` for an empty front.
    """
    if not front:
        return {}
    distances = {ind[0]: 0 for ind in front}
    for m in (4, 5):  # objectives: bias and ppl
        front.sort(key=lambda ind: ind[m])
        distances[front[0][0]] = float("inf")
        distances[front[-1][0]] = float("inf")
        span = front[-1][m] - front[0][m]
        if span == 0:
            # Every individual has the same value on this objective, so it
            # cannot discriminate between them.
            continue
        for i in range(1, len(front) - 1):
            distances[front[i][0]] += (front[i + 1][m] - front[i - 1][m]) / span
    return distances


def get_front_level(ind, fronts):
    """Return the index of the first front containing ``ind``.

    Membership is tested with ``in`` (equality). If no front contains the
    individual, ``float("inf")`` is returned so it ranks behind every front.
    """
    return next(
        (level for level, front in enumerate(fronts) if ind in front),
        float("inf"),
    )


def get_crowding_distance(ind, crowding_dict):
    """Look up the crowding distance stored for ``ind``.

    The lookup key is the record's first element; records without an entry
    in ``crowding_dict`` are given a distance of 0.
    """
    record_key = ind[0]
    return crowding_dict.get(record_key, 0)


def binary_tournament_selection(population, fronts, crowding_dict):
    """Run one NSGA-II binary tournament and return the winning record.

    Two distinct records are sampled from ``population``. The one in the
    better (lower-index) front wins; if both are in the same front, the one
    with the larger crowding distance wins, and the first sampled record wins
    a complete tie.

    Args:
        population: Sequence of records with at least two entries.
        fronts: Fronts as returned by ``non_dominated_sorting``.
        crowding_dict: Mapping from record id (``record[0]``) to crowding
            distance; missing ids count as 0.

    Returns:
        The winning record from ``population``.
    """
    first, second = random.sample(population, 2)

    rank_first = get_front_level(first, fronts)
    rank_second = get_front_level(second, fronts)

    if rank_first < rank_second:
        return first
    if rank_second < rank_first:
        return second

    # Same front: prefer the less crowded record.
    crowd_first = crowding_dict.get(first[0], 0)
    crowd_second = crowding_dict.get(second[0], 0)
    return second if crowd_second > crowd_first else first


def evolution_step_nsgaii(
    model,
    tokenizer,
    ori_head_num,
    target_head_num,
    pop_size,
    evo_epoch,
    pop_init_mutation_rate,
    mutation_rate,
    scalar,
    args,
    ori_valid_bias,
    ori_valid_ppl,
):
    """Search for head subsets with an NSGA-II style evolutionary loop.

    Each individual is a sorted list of kept head indices and is stored as a
    record ``[id, fitness, kept_heads, head_count, bias, ppl]``. Every epoch
    creates ``pop_size`` children by mutation, merges them with the parents,
    optionally filters by perplexity, and keeps ``pop_size`` survivors by
    non-dominated rank and crowding distance (bias and ppl are minimized).
    After the last epoch every survivor is evaluated once more with
    ``test_only=True`` and the results are logged.

    Record ids are unique within the pool handed to the crowding-distance
    computation: parents use ``0 .. pop_size - 1``, children use
    ``pop_size .. 2 * pop_size - 1``, and survivors are renumbered at the end
    of every epoch. Crowding distances are keyed by id, so colliding ids would
    silently merge the distances of unrelated individuals.

    Args:
        model: Model to prune; a deep copy is passed to every evaluation.
        tokenizer: Tokenizer forwarded to ``get_result``.
        ori_head_num: Total number of heads in the unpruned model.
        target_head_num: Number of heads every individual keeps.
        pop_size: Population size.
        evo_epoch: Number of generations.
        pop_init_mutation_rate: Mutation rate used to build the initial
            population.
        mutation_rate: Mutation rate used to create children.
        scalar: Weight of the perplexity term in the fitness.
        args: Run configuration; this function reads ``fitness_mode``,
            ``parent_selection`` and ``ppl_threshold`` and forwards the object
            to ``get_result``.
        ori_valid_bias: Bias of the unpruned model (baseline).
        ori_valid_ppl: Perplexity of the unpruned model (baseline).

    Returns:
        None. Results are reported through the ``"evo_result"`` logger.

    Raises:
        ValueError: If ``args.parent_selection`` is not ``"binary_tournament"``
            or ``"random"``.
    """
    logger = logging.getLogger("evo_result")
    ori_fitness_valid = get_fitness(
        ori_valid_bias,
        ori_valid_ppl,
        scalar,
        ori_valid_bias,
        ori_valid_ppl,
        fitness_mode=args.fitness_mode,
    )
    logger.info("Original model fitness: {}".format(ori_fitness_valid))

    def evaluate(label, position, record_id, kept_heads):
        """Evaluate one head subset with ``evo=True`` and build its record."""
        kept = set(kept_heads)
        pruned_idxs = [head for head in range(ori_head_num) if head not in kept]
        logger.info(f"{label} {position} pruned idx: {pruned_idxs}")
        # NOTE(review): only the first element of each returned sequence is
        # used; presumably these are the validation bias/ppl - confirm
        # against get_result.
        bias, ppl = get_result(
            model=copy.deepcopy(model),
            tokenizer=tokenizer,
            idx_pruned_heads=pruned_idxs,
            args=args,
            save_csv=False,
            evo=True,
        )
        cand_bias, cand_ppl = bias[0], ppl[0]
        cand_fitness = get_fitness(
            cand_bias,
            cand_ppl,
            scalar,
            ori_valid_bias,
            ori_valid_ppl,
            fitness_mode=args.fitness_mode,
        )
        logger.info(
            f"{label} {position} bias: {cand_bias}, ppl: {cand_ppl}, fitness: {cand_fitness}"
        )
        logger.info("\n")
        return [
            record_id,
            cand_fitness,
            kept_heads,
            len(kept_heads),
            cand_bias,
            cand_ppl,
        ]

    fathers = generate_initial_pop(
        pop_size, ori_head_num, target_head_num, pop_init_mutation_rate
    )

    logger.info("Population initialization")
    fathers_fitness_list = [
        evaluate("Father", i, i, fathers[i]) for i in range(pop_size)
    ]

    logger.info("Starting NSGA-II evolution")
    for epoch in range(evo_epoch):
        logger.info(f"Evolution epoch {epoch + 1}/{evo_epoch}")
        logger.info("\n")

        children_fitness_list = []

        # Rank and crowding information of the parents drives the tournament.
        fronts = non_dominated_sorting(fathers_fitness_list)
        crowding_dict = {}
        for front in fronts:
            crowding_dict.update(calculate_crowding_distance(front))

        for i in range(pop_size):
            if args.parent_selection == "binary_tournament":
                parent = binary_tournament_selection(
                    fathers_fitness_list, fronts, crowding_dict
                )
                father = parent[2]
            elif args.parent_selection == "random":
                father = fathers[random.randint(0, pop_size - 1)]
            elif args.parent_selection == "roulette_wheel":
                raise ValueError("The roulette_wheel selection method is not supported")
            else:
                raise ValueError("This parent selection method is not supported")

            child = mutation(
                father,
                ori_head_num=ori_head_num,
                mutation_rate=mutation_rate,
                target_count=target_head_num,
            )
            # Offset child ids so they never collide with parent ids.
            children_fitness_list.append(evaluate("Child", i, pop_size + i, child))

        combined_pop = fathers_fitness_list + children_fitness_list

        candidates = combined_pop
        if args.ppl_threshold != 0:
            ppl_threshold = args.ppl_threshold * ori_valid_ppl
            filtered_pop = [ind for ind in combined_pop if ind[5] <= ppl_threshold]
            if not filtered_pop:
                # Nothing to duplicate from: random.choices would raise on an
                # empty pool, so keep the unfiltered pool for this epoch.
                logger.warning(
                    f"No individual satisfies the ppl threshold {ppl_threshold}; "
                    "using the unfiltered population for this epoch"
                )
            else:
                if len(filtered_pop) < pop_size:
                    # Pad with random duplicates of feasible individuals.
                    remaining = pop_size - len(filtered_pop)
                    filtered_pop += random.choices(filtered_pop, k=remaining)
                candidates = filtered_pop

        fronts = non_dominated_sorting(candidates)

        # Fill the next generation front by front; the first front that does
        # not fit entirely is truncated by descending crowding distance.
        new_pop = []
        for front in fronts:
            if len(new_pop) + len(front) <= pop_size:
                new_pop.extend(front)
            else:
                crowding_distances = calculate_crowding_distance(front)
                front_sorted = sorted(
                    front, key=lambda x: crowding_distances[x[0]], reverse=True
                )
                new_pop.extend(front_sorted[: pop_size - len(new_pop)])
                break

        # Renumber survivors (as fresh records) so ids stay unique next epoch.
        fathers_fitness_list = [
            [new_id, *ind[1:]] for new_id, ind in enumerate(new_pop[:pop_size])
        ]
        fathers = [ind[2] for ind in fathers_fitness_list]
        logger.info(f"Population at evolution epoch {epoch + 1}/{evo_epoch}:")
        for idx, ind in enumerate(fathers_fitness_list):
            logger.info(
                f"Individual {idx} | "
                f"Bias: {ind[4]:.4f}, "
                f"PPL: {ind[5]:.2f}, "
                f"Heads num: {ind[3]}, "
                f"pruned head idx list: {[_ for _ in range(ori_head_num) if _ not in ind[2]]}"
            )
        logger.info("\n")
        logger.info(
            f"Best fitness at evolution epoch {epoch + 1}/{evo_epoch}: {max(father[1] for father in fathers_fitness_list)}, original model valid fitness: {ori_fitness_valid}"
        )
        logger.info("\n")

    logger.info("Final evaluation on test set")
    for i in range(pop_size):
        tmp_model = copy.deepcopy(model)
        pruned_idxs = [_ for _ in range(ori_head_num) if _ not in fathers[i]]
        bias, ppl = get_result(
            model=tmp_model,
            tokenizer=tokenizer,
            idx_pruned_heads=pruned_idxs,
            args=args,
            save_csv=False,
            test_only=True,
        )
        father_bias, father_ppl = bias[0], ppl[0]
        logger.info(
            f"Individual {i} | test bias: {father_bias}, test ppl: {father_ppl}"
        )
        logger.info("\n")
    logger.info("\n")


def evolution_step(
    model,
    tokenizer,
    ori_head_num,
    target_head_num,
    pop_size,
    evo_epoch,
    pop_init_mutation_rate,
    mutation_rate,
    scalar,
    args,
    ori_valid_bias,
    ori_valid_ppl,
):
    """Search for a head subset with a single-objective (mu + lambda) GA.

    Each individual is a sorted list of kept head indices and is stored as a
    record ``[id, fitness, kept_heads, head_count, bias, ppl]``. Every epoch
    creates ``pop_size`` children by mutation, merges them with the parents
    and keeps the ``pop_size`` records with the highest fitness.

    Args:
        model: Model to prune; a deep copy is passed to every evaluation.
        tokenizer: Tokenizer forwarded to ``get_result``.
        ori_head_num: Total number of heads in the unpruned model.
        target_head_num: Number of heads every individual keeps.
        pop_size: Population size.
        evo_epoch: Number of generations. With 0 the best individual of the
            initial population is returned.
        pop_init_mutation_rate: Mutation rate used to build the initial
            population.
        mutation_rate: Mutation rate used to create children.
        scalar: Weight of the perplexity term in the fitness.
        args: Run configuration; this function reads ``fitness_mode`` and
            ``parent_selection`` and forwards the object to ``get_result``.
        ori_valid_bias: Bias of the unpruned model (baseline).
        ori_valid_ppl: Perplexity of the unpruned model (baseline).

    Returns:
        The list of head indices pruned by the best individual.

    Raises:
        ValueError: If ``args.parent_selection`` is not ``"roulette_wheel"``
            or ``"random"``, or if the best individual does not keep exactly
            ``target_head_num`` heads.
    """
    logger = logging.getLogger("evo_result")
    ori_fitness_valid = get_fitness(
        ori_valid_bias,
        ori_valid_ppl,
        scalar,
        ori_valid_bias,
        ori_valid_ppl,
        fitness_mode=args.fitness_mode,
    )

    fathers = generate_initial_pop(
        pop_size, ori_head_num, target_head_num, pop_init_mutation_rate
    )
    fathers_fitness_list = []

    logger.info("Population initialization")
    for i in range(pop_size):
        tmp_model = copy.deepcopy(model)
        pruned_idxs = [_ for _ in range(ori_head_num) if _ not in fathers[i]]
        print(pruned_idxs)
        logger.info(f"Father {i} pruned idx: {pruned_idxs}")
        # NOTE(review): only the first element of each returned sequence is
        # used; presumably these are the validation bias/ppl - confirm
        # against get_result.
        bias, ppl = get_result(
            model=tmp_model,
            tokenizer=tokenizer,
            idx_pruned_heads=pruned_idxs,
            args=args,
            save_csv=False,
            evo=True,
        )
        father_bias, father_ppl = bias[0], ppl[0]
        father_fitness = get_fitness(
            father_bias,
            father_ppl,
            scalar,
            ori_valid_bias,
            ori_valid_ppl,
            fitness_mode=args.fitness_mode,
        )
        logger.info(
            f"Father {i} bias: {father_bias}, ppl: {father_ppl}, fitness: {father_fitness}"
        )
        fathers_fitness_list.append(
            [i, father_fitness, fathers[i], len(fathers[i]), father_bias, father_ppl]
        )
        logger.info("\n")

    # Seed the incumbent from the initial population so that a run with
    # evo_epoch == 0 still returns a result instead of hitting an unbound name.
    best_ind = max(fathers_fitness_list, key=lambda record: record[1])

    logger.info("Starting evolution")
    for epoch in range(evo_epoch):
        print(f"Evolution epoch {epoch+1}/{evo_epoch}")
        logger.info(f"Evolution epoch {epoch+1}/{evo_epoch}")
        logger.info("\n")

        children_fitness_list = []

        for i in range(pop_size):
            if args.parent_selection == "roulette_wheel":
                selected_idx = roulette_wheel_selection(fathers_fitness_list)
                father = fathers_fitness_list[selected_idx][2]
            elif args.parent_selection == "random":
                father = fathers[random.randint(0, pop_size - 1)]
            elif args.parent_selection == "binary_tournament":
                raise ValueError(
                    "The binary_tournament selection method is not supported"
                )
            else:
                raise ValueError("This parent selection method is not supported")
            child = mutation(
                father,
                ori_head_num=ori_head_num,
                mutation_rate=mutation_rate,
                target_count=target_head_num,
            )

            pruned_idxs = [_ for _ in range(ori_head_num) if _ not in child]
            logger.info(f"Child {i} pruned idx: {pruned_idxs}")

            bias, ppl = get_result(
                model=copy.deepcopy(model),
                tokenizer=tokenizer,
                idx_pruned_heads=pruned_idxs,
                args=args,
                save_csv=False,
                evo=True,
            )
            child_bias, child_ppl = bias[0], ppl[0]
            child_fitness = get_fitness(
                child_bias,
                child_ppl,
                scalar,
                ori_valid_bias,
                ori_valid_ppl,
                fitness_mode=args.fitness_mode,
            )

            logger.info(
                f"Child {i} bias: {child_bias}, ppl: {child_ppl}, fitness: {child_fitness}"
            )
            children_fitness_list.append(
                [i, child_fitness, child, len(child), child_bias, child_ppl]
            )
            logger.info("\n")

        # Elitist survivor selection: parents and children compete together
        # and the pop_size fittest records survive (stable on ties).
        all_list = fathers_fitness_list + children_fitness_list
        all_list.sort(key=lambda x: x[1], reverse=True)

        logger.info(f"Population at evolution epoch {epoch+1}/{evo_epoch}")
        for j in range(pop_size):
            fathers[j] = all_list[j][2]
            fathers_fitness_list[j] = all_list[j]
            logger.info(
                f"Individual {j} | "
                f"Bias: {fathers_fitness_list[j][4]:.4f}, "
                f"PPL: {fathers_fitness_list[j][5]:.2f}, "
                f"Heads num: {fathers_fitness_list[j][3]}, "
                f"pruned head idx list: {[_ for _ in range(ori_head_num) if _ not in fathers_fitness_list[j][2]]}"
            )
        logger.info("\n")

        # The list is sorted by descending fitness, so the head is the best.
        best_ind = fathers_fitness_list[0]
        logger.info(
            f"Best fitness at evolution epoch {epoch+1}/{evo_epoch}: {best_ind[1]}, original model valid fitness: {ori_fitness_valid}, head num now: {best_ind[3]}"
        )
        logger.info("\n")

    if best_ind[3] != target_head_num:
        raise ValueError(
            "The number of heads in the best individual is not equal to the target number of heads"
        )

    pruning_idxs = [_ for _ in range(ori_head_num) if _ not in best_ind[2]]
    logger.info(f"Final pruned idx: {pruning_idxs}")
    return pruning_idxs
