import numpy as np

from evolution.utils import *



# Default hyper-parameters read by ``evolution_search``.
evo_args = dict(
    # Number of random individuals from ``initialize_population``;
    # ``evolution_search`` additionally appends the identity permutation.
    population_size=50,
    # Number of generations the main loop runs.
    generations=20,
    # Mutation rate handed to ``mutate``.
    mutation_rate_init=0.5,
    # How many top-ranked individuals are copied unchanged into the next generation.
    elitism_k=3,
    # Passed to ``tournament_selection`` (presumably the tournament size — confirm).
    tournament_k=5,
    # Multiplicative mutation-rate decay factor; not currently applied by
    # ``evolution_search`` (the rate stays at ``mutation_rate_init``).
    mutation_decay=0.998,
)


# ==== Genetic algorithm main loop ====
def evolution_search(ins_list, outs_list, act_quant, W, b, args, prefix, device, config=None):
    """Search for a good input-channel permutation of ``W`` with a genetic algorithm.

    The initial population is produced by ``initialize_population`` plus the
    identity permutation ``arange(C_in)``, so "no reordering" is always one of
    the evaluated candidates.

    Each generation:

    - every individual is scored with ``evaluate_fitness`` on every calibration
      batch, and the scores are averaged (lower is better; it is logged as MSE);
    - the ``elitism_k`` best-ranked individuals are carried over unchanged;
    - the rest of the next generation is bred with tournament selection, order
      crossover and mutation until the population is back to its initial size.

    Args:
        ins_list: Per-batch cached inputs; ``ins_list[i]`` pairs with ``outs_list[i]``.
        outs_list: Per-batch cached reference outputs.
        act_quant: Forwarded unchanged to ``evaluate_fitness``.
        W: 2-D weight; its second dimension (``C_in``) is the permuted axis.
        b: Bias, forwarded unchanged to ``evaluate_fitness``.
        args: Forwarded unchanged to ``evaluate_fitness``.
        prefix: String prepended to the per-generation log line.
        device: Device passed to ``initialize_population`` and used for the
            identity permutation.
        config: Optional dict whose entries override the module-level
            ``evo_args`` for this call only.

    Returns:
        A clone of the best individual seen over all generations, or ``None``
        when ``generations`` is 0 or no individual ever produced a finite score.

    Raises:
        ValueError: If ``ins_list`` is empty or its length differs from
            ``outs_list``.
    """
    cfg = dict(evo_args)
    if config is not None:
        cfg.update(config)

    batches = len(ins_list)
    if batches == 0:
        raise ValueError("evolution_search needs at least one calibration batch")
    if len(outs_list) != batches:
        raise ValueError(
            f"ins_list and outs_list must have the same length "
            f"({batches} != {len(outs_list)})"
        )

    # Random initial population, plus the identity permutation (no reordering).
    _, C_in = W.shape
    population = initialize_population(C_in, cfg['population_size'], device)
    population.append(torch.arange(C_in, device=device))
    # Keep every generation the same size as the initial one
    # (population_size random individuals + the identity).
    pop_target = len(population)

    best_individual = None
    best_fitness = float('inf')
    # NOTE: cfg['mutation_decay'] is currently not applied, so the mutation
    # rate stays at mutation_rate_init for the whole search.
    mut_rate = cfg['mutation_rate_init']

    for gen in range(cfg['generations']):
        # Average each individual's fitness over all calibration batches.
        # float64 accumulation avoids precision loss when summing many scores.
        avg_fitnesses = np.zeros(len(population), dtype=np.float64)
        for cached_inps, cached_outs in zip(ins_list, outs_list):
            avg_fitnesses += np.array([
                evaluate_fitness(W, b, individual, cached_inps, cached_outs, args, act_quant)
                for individual in population
            ])
        avg_fitnesses /= batches

        # argsort sorts NaN scores to the end, so ranked[0] is the best finite
        # score (np.argmin would instead return the index of a NaN). A stable
        # sort keeps ties in first-occurrence order.
        ranked = np.argsort(avg_fitnesses, kind='stable')
        best_idx = int(ranked[0])

        if avg_fitnesses[best_idx] < best_fitness:
            best_fitness = float(avg_fitnesses[best_idx])
            best_individual = population[best_idx].clone()

        # Elitism: the top-ranked individuals survive unchanged.
        new_population = [population[i] for i in ranked[:cfg['elitism_k']]]

        # Breed the remaining individuals.
        while len(new_population) < pop_target:
            parent1 = tournament_selection(population, avg_fitnesses, cfg['tournament_k'])
            parent2 = tournament_selection(population, avg_fitnesses, cfg['tournament_k'])
            child = order_crossover(parent1, parent2)
            new_population.append(mutate(child, mut_rate))

        population = new_population
        print(prefix + f" Generation {gen+1}: Best MSE = {best_fitness:.6f}")

    return best_individual