import torch
from qdit.quant import quantize_tensor_channel_group

import random

# ==== Fully random population initialization ====
def initialize_population(num_channels, pop_size, device):
    """Create a population of independent, uniformly random channel orderings.

    Args:
        num_channels: Length of each permutation (values ``0 .. num_channels - 1``).
        pop_size: Number of individuals to create.
        device: Device every individual is moved to.

    Returns:
        A list of ``pop_size`` tensors, each a ``torch.randperm(num_channels)``
        result placed on ``device``.
    """
    population = []
    for _ in range(pop_size):
        # Draw on the default CPU generator first and only then move the result.
        # The sequence of permutations therefore depends on the CPU seed,
        # whatever ``device`` is.
        ordering = torch.randperm(num_channels)
        population.append(ordering.to(device))
    return population


# ==== Tournament selection ====
def tournament_selection(population, fitnesses, k):
    """Pick one parent by tournament: sample ``k`` distinct entrants and keep the best.

    Fitness is minimized, so the entrant with the smallest fitness wins. The sort
    is stable, so among tied entrants the one sampled first wins.

    Args:
        population: Sequence of individuals (tensors).
        fitnesses: Fitness values aligned index-by-index with ``population``.
        k: Tournament size. ``random.sample`` raises ``ValueError`` if it exceeds
            the number of (individual, fitness) pairs.

    Returns:
        A clone of the winning individual. In-place edits to the result
        therefore do not affect ``population``.
    """
    entrants = random.sample(list(zip(population, fitnesses)), k)
    ranked = sorted(entrants, key=lambda entrant: entrant[1])
    winner, _ = ranked[0]
    return winner.clone()


# ==== Order Crossover (OX) ====
def order_crossover(parent1: torch.Tensor, parent2: torch.Tensor):
    """Combine two permutations with a variant of Order Crossover (OX).

    A random inclusive slice ``[cut1, cut2]`` is copied verbatim from ``parent1``.
    The remaining positions, scanned left to right, are filled with the genes of
    ``parent2`` that are not in that slice, kept in the order they appear in
    ``parent2``.

    Args:
        parent1: 1-D tensor holding a permutation; donates the copied slice.
        parent2: 1-D tensor holding a permutation of the same genes; donates the
            relative order of all remaining genes.

    Returns:
        A new int64 tensor on ``parent1.device``. Parents with fewer than two
        genes cannot be cut, so a copy of ``parent1`` is returned for them.

    Raises:
        ValueError: if the parents are not permutations of the same gene set,
            i.e. the number of fill genes does not match the number of free slots.
    """
    parent1 = parent1.to(dtype=torch.int64)
    # Put both parents on one device so the fill below is a single tensor op.
    parent2 = parent2.to(device=parent1.device, dtype=torch.int64)

    N = parent1.size(0)
    if N < 2:
        # random.sample(range(N), 2) needs at least two positions; nothing to cross.
        return parent1.clone()
    cut1, cut2 = sorted(random.sample(range(N), 2))

    # -1 marks slots that are still unfilled (int64 can hold it).
    child = torch.full((N,), -1, dtype=torch.int64, device=parent1.device)
    segment = parent1[cut1:cut2 + 1]
    child[cut1:cut2 + 1] = segment

    # Genes of parent2 not already present in the copied segment. Boolean-mask
    # indexing keeps their original order.
    fill_genes = parent2[~torch.isin(parent2, segment)]

    # Positions outside the segment, in ascending order.
    free_slots = torch.ones(N, dtype=torch.bool, device=parent1.device)
    free_slots[cut1:cut2 + 1] = False
    num_free = N - (cut2 - cut1 + 1)

    if fill_genes.numel() != num_free:
        raise ValueError(
            "order_crossover: parents must be permutations of the same genes "
            f"(got {fill_genes.numel()} fill genes for {num_free} free slots)"
        )
    # One vectorized scatter instead of N per-element writes (each a kernel on GPU).
    child[free_slots] = fill_genes
    return child


# ==== Mutation: randomly swap two positions ====
def mutate(ind, rate):
    """Swap mutation: with probability ``rate``, exchange two distinct random positions.

    The individual is modified in place and also returned.

    Args:
        ind: 1-D tensor (a permutation) to mutate in place.
        rate: Probability that a swap is performed (``random.random() < rate``).

    Returns:
        ``ind`` itself (the same object), possibly with two entries swapped.
        Individuals with fewer than two entries are returned unchanged.
    """
    # random.random() is drawn first on every call, before the length check.
    # This keeps the RNG stream identical to the unguarded version.
    if random.random() < rate and len(ind) >= 2:
        i, j = random.sample(range(len(ind)), 2)
        # Fancy indexing on the right-hand side produces a copy, so this swap is
        # safe on tensors. A plain ``ind[i], ind[j] = ind[j], ind[i]`` would read
        # aliased 0-d views and write the same value twice. Unlike two
        # ``.item()`` calls, it needs no device-to-host sync.
        ind[[i, j]] = ind[[j, i]]
    return ind


# # ==== Mutation: randomly swap two positions (with a minimum-distance constraint) ====
# def mutate(ind, rate, min_distance=128):
#     """
#     Args:
#         ind: a permutation (list or torch.Tensor)
#         rate: mutation probability
#         min_distance: minimum position gap; swap only when |i-j| >= min_distance
#     """
#     if random.random() < rate:
#         n = len(ind)
#         # randomly search for an (i, j) pair that satisfies the constraint
#         for _ in range(10 * n):  # cap the number of attempts to avoid an infinite loop
#             i, j = random.sample(range(n), 2)
#             if abs(i - j) >= min_distance:
#                 i_data = ind[i].item()
#                 j_data = ind[j].item()
#                 ind[i] = j_data
#                 ind[j] = i_data
#                 break
#     return ind


def evaluate_fitness(W, bias, sorted_index, inps, outs, args, act_quant):
    """Score a channel ordering by the output error of the quantized linear layer.

    The permutation ``sorted_index`` is applied to dim 2 of ``inps`` and dim 1 of
    ``W``, which is the input-feature axis for ``torch.nn.functional.linear``.
    The permuted activations go through ``act_quant`` and the permuted weight
    through ``quantize_tensor_channel_group``. The layer output is then compared
    with ``outs``.

    Args:
        W: Weight matrix as consumed by ``F.linear`` (``(out_features, in_features)``).
        bias: Bias passed straight to ``F.linear`` (may be ``None``).
        sorted_index: 1-D integer index tensor selecting the input-channel order.
        inps: Activations; assumed ``(batch, tokens, in_features)`` so that dim 2
            is the feature axis — TODO confirm against callers.
        outs: Reference outputs the quantized result is compared against.
        args: Config object supplying ``wbits``, ``exponential``, ``w_sym``,
            ``weight_group_size`` (its first element is used), ``weight_channel_group``,
            ``w_clip_ratio``, ``tiling``, ``quant_type`` and ``quant_method``.
        act_quant: Callable applied to the permuted activations.

    Returns:
        The mean-squared error as a Python float. Smaller means the quantized
        layer tracks ``outs`` more closely.
    """
    # Permute activations and weight columns consistently. Before quantization
    # the product x @ W.T is unchanged; only how channels fall into groups differs.
    permuted_inputs = inps.index_select(2, sorted_index)
    permuted_weight = W.index_select(1, sorted_index)

    quantized_inputs = act_quant(permuted_inputs)
    quantized_weight = quantize_tensor_channel_group(
        permuted_weight,
        n_bits=args.wbits,
        exponential=args.exponential,
        sym=args.w_sym,
        group_size=args.weight_group_size[0],
        channel_group=args.weight_channel_group,
        clip_ratio=args.w_clip_ratio,
        tiling=args.tiling,
        quant_type=args.quant_type,
        quant_method=args.quant_method,
    )

    quantized_outputs = torch.nn.functional.linear(quantized_inputs, quantized_weight, bias)
    loss = torch.nn.functional.mse_loss(quantized_outputs, outs)
    return loss.item()