import torch
from qdit.quant import quantize_tensor_channel_group

import random

# ==== Fully random population initialization ====
def initialize_population(num_channels, pop_size, device):
    """Create ``pop_size`` independent random permutations of ``range(num_channels)``.

    Each individual is produced by ``torch.randperm(num_channels)`` and then
    moved to ``device`` with ``.to(device)``.

    Args:
        num_channels: length of every permutation.
        pop_size: number of individuals to create.
        device: target device for the returned tensors.

    Returns:
        list[torch.Tensor]: ``pop_size`` permutation tensors on ``device``.
    """
    population = []
    for _ in range(pop_size):
        individual = torch.randperm(num_channels)
        population.append(individual.to(device))
    return population


# ==== Tournament selection ====
def tournament_selection(population, fitnesses, k):
    """Pick the fittest of ``k`` randomly sampled individuals.

    ``k`` (individual, fitness) pairs are drawn without replacement with
    ``random.sample``. The one with the smallest fitness value wins, so lower
    is better. Ties go to the pair drawn first, because the ranking sort is
    stable.

    Returns:
        A ``clone()`` of the winning individual, so the caller can mutate it
        without touching the population.
    """
    pairs = list(zip(population, fitnesses))
    contestants = random.sample(pairs, k)
    ranked = sorted(contestants, key=lambda pair: pair[1])
    winner, _fitness = ranked[0]
    return winner.clone()


# ==== Order Crossover (OX) ====
def order_crossover(parent1: torch.Tensor, parent2: torch.Tensor):
    """Order Crossover (OX) of two permutation tensors.

    The inclusive slice ``[cut1, cut2]`` is chosen at random and copied
    verbatim from ``parent1``. The remaining positions are then filled left
    to right: first those before ``cut1``, then those after ``cut2``. They
    receive the genes of ``parent2`` that are not in the copied slice, in the
    order they appear in ``parent2``.

    Args:
        parent1: 1-D tensor holding a permutation.
        parent2: 1-D tensor holding a permutation of the same genes.

    Returns:
        A new 1-D int64 tensor on ``parent1``'s device. Parents with fewer
        than two genes cannot be cut, so a copy of ``parent1`` is returned.

    Raises:
        ValueError: if ``parent2`` does not supply exactly one gene for each
            free position (e.g. parents of different lengths or gene sets).
    """
    parent1 = parent1.to(dtype=torch.int64)
    parent2 = parent2.to(dtype=torch.int64)

    n = parent1.size(0)
    if n < 2:
        # random.sample(range(n), 2) needs at least two positions to cut.
        return parent1.clone()

    cut1, cut2 = sorted(random.sample(range(n), 2))
    segment = parent1[cut1:cut2 + 1]
    segment_genes = set(segment.tolist())
    fill_genes = [gene for gene in parent2.tolist() if gene not in segment_genes]

    if len(fill_genes) != n - segment.numel():
        raise ValueError(
            "parent1 and parent2 must be permutations of the same genes"
        )

    # Assemble the child with one concatenation instead of writing one
    # element at a time (each scalar write is a separate tensor op).
    device = parent1.device
    head = torch.tensor(fill_genes[:cut1], dtype=torch.int64, device=device)
    tail = torch.tensor(fill_genes[cut1:], dtype=torch.int64, device=device)
    return torch.cat([head, segment, tail])


# ==== Mutation: swap two random positions ====
def mutate(ind, rate):
    """Swap-mutate an individual in place with probability ``rate``.

    When the mutation fires, two distinct positions are drawn uniformly at
    random and their genes are exchanged. Individuals with fewer than two
    genes have nothing to swap and are returned untouched.

    Args:
        ind: 1-D tensor (the individual); modified in place.
        rate: probability in ``[0, 1]`` that the swap happens.

    Returns:
        ``ind`` itself.
    """
    # random.random() is always drawn first, so the RNG stream does not
    # depend on the individual's length.
    if random.random() < rate and len(ind) >= 2:
        i, j = random.sample(range(len(ind)), 2)
        # The right-hand side uses advanced indexing, which makes a copy, so
        # this swap is safe. The tuple swap ``ind[i], ind[j] = ind[j], ind[i]``
        # is not, because its 0-d views alias ``ind``. This form also avoids
        # a host round-trip through ``.item()``.
        ind[[i, j]] = ind[[j, i]]
    return ind


def evaluate_fitness(W, bias, sorted_index, inps, outs, args, act_quant):
    """Score a channel ordering by the output MSE of the quantized linear layer.

    The input channels of ``inps`` (last dim) and ``W`` (dim 1) are permuted
    by ``sorted_index``. Both are quantized, and the output of the quantized
    ``F.linear`` is compared against ``outs`` with mean-squared error. Lower
    is better.

    Args:
        W: weight matrix whose dim 1 is indexed by ``sorted_index``
            (presumably ``(out_features, in_features)`` as ``F.linear``
            expects — confirm).
        bias: bias forwarded unchanged to ``F.linear``.
        sorted_index: 1-D integer index tensor, presumably a permutation of
            the input channels.
        inps: activations whose last dim holds the input channels.
        outs: reference outputs to compare against (presumably the
            full-precision layer output — confirm with the caller).
        args: mapping providing ``'n_bits'``, ``'sym'`` and ``'group_size'``
            for weight quantization.
        act_quant: callable applied to the reordered activations.

    Returns:
        float: the MSE between the quantized and reference outputs.
    """
    # The result is reduced to a Python float, so no gradient can flow back.
    # Skip recording an autograd graph on every fitness evaluation.
    with torch.no_grad():
        # F.linear contracts the last dim of its input, so that is the
        # input-channel axis. For 3-D inputs -1 is the same as dim 2, and it
        # also handles 2-D inputs.
        inps_reorder = torch.index_select(inps, -1, sorted_index)
        W_reorder = torch.index_select(W, 1, sorted_index)

        inps_quant = act_quant(inps_reorder)
        W_quant = quantize_tensor_channel_group(
            W_reorder,
            n_bits=args['n_bits'],
            sym=args['sym'],
            group_size=args['group_size'],
            clip_ratio=1.0,
            tiling=0,
            quant_type='int'
        )

        quant_outps = torch.nn.functional.linear(inps_quant, W_quant, bias)
        return torch.nn.functional.mse_loss(quant_outps, outs).item()