import numpy as np
import geatpy as ea

"""================ ensure each chromosome has at most M errors ==================="""
def repair_variable(variable, M, num_error_types=3):
    """
    Repair decoded chromosomes so that none carries more than M errors.

    A non-zero entry in a row is treated as an error; its value is the error
    type.

    - variable: numpy.ndarray of shape (num_chromosomes, N), one decoded
      chromosome per row. The input array is not modified.
    - M: int >= 0, the maximum number of errors a chromosome may keep.
      * Rows with more than M non-zero entries have randomly chosen non-zero
        entries set to 0 until exactly M remain.
      * Rows with no non-zero entries are seeded with min(M, N) errors at
        random distinct positions.
      * All other rows are left unchanged.
    - num_error_types: int, error types used when seeding empty rows are
      drawn uniformly from 1..num_error_types (default 3).

    Returns a repaired copy of `variable`.
    Raises ValueError if M is negative.
    """
    if M < 0:
        raise ValueError(f"M must be non-negative, got {M}")

    num_chromosomes, N = variable.shape
    variable_repaired = variable.copy()
    num_errors = np.count_nonzero(variable_repaired, axis=1)  # per row

    # Trim chromosomes with too many errors.
    for idx in np.flatnonzero(num_errors > M):
        chromosome = variable_repaired[idx]  # a view: edits write through
        error_positions = np.flatnonzero(chromosome)
        num_to_zero = num_errors[idx] - M
        positions_to_zero = np.random.choice(error_positions, size=num_to_zero, replace=False)
        chromosome[positions_to_zero] = 0

    # Seed chromosomes that have no errors at all.
    # Capping at N keeps choice(..., replace=False) from raising when M > N.
    num_to_add = min(M, N)
    for idx in np.flatnonzero(num_errors == 0):
        chromosome = variable_repaired[idx]
        positions_to_add = np.random.choice(N, size=num_to_add, replace=False)
        chromosome[positions_to_add] = np.random.randint(1, num_error_types + 1, size=num_to_add)

    return variable_repaired


def OneGen_task(N, M, K, NIND, selS, recS, mutS, FieldD, model, inputs, targets, aim, curChrom, pc=0.8, Encoding ='BG', cm_ratio=(0.3, 0.6)):
    """
    Advance a binary-encoded population by one GA generation.

    The current population is decoded with ``ea.bs2ri`` and scored with
    ``aim``. A larger objective value is treated as better. The next
    population consists of:
      * the K individuals with the highest objective (elitism),
      * N_c = int(NIND * cm_ratio[0]) offspring produced by selection and
        recombination,
      * NIND - N_c - K offspring produced by selection and mutation.
    Both offspring groups are decoded, passed through ``repair_variable`` to
    cap each at M errors, and re-encoded with ``ea.ri2bs``.

    - N: not used by this function (kept for interface compatibility).
    - M: maximum errors per chromosome, forwarded to ``repair_variable``.
    - K: number of elites copied unchanged into the next population.
    - NIND: target population size.
    - selS, recS, mutS: geatpy selection / recombination / mutation operators.
    - FieldD: geatpy field descriptor. The sum of its first row is used as the
      chromosome length (presumably per-variable bit lengths; confirm).
    - model, inputs, targets: forwarded unchanged to ``aim``.
    - aim: callable ``aim(phen_row, model, inputs, targets)`` returning a
      scalar objective.
    - curChrom: current population, one chromosome per row.
    - pc: recombination probability passed to ``ea.recombin``.
    - Encoding: encoding name passed to ``ea.mutate``.
    - cm_ratio: only cm_ratio[0] (share of recombined offspring) is used;
      mutated offspring fill the remaining NIND - N_c - K slots.

    Returns (new_chrom, best_pop): the next population and the decoded best
    individual of the *current* population.
    Raises ValueError if K + N_c exceeds NIND.
    """
    maxormins = np.array([-1])  # geatpy convention: -1 marks maximization
    Lind = int(np.sum(FieldD[0, :]))
    pm = 2 / Lind  # roughly two expected bit flips per chromosome

    N_c = int(NIND * cm_ratio[0])
    N_m = NIND - N_c - K
    if N_m < 0:
        raise ValueError(
            f"K ({K}) + crossover offspring ({N_c}) exceeds population size NIND ({NIND})"
        )

    # Evaluate the current population.
    Phen = ea.bs2ri(curChrom, FieldD)
    ObjV = np.array([aim(phen, model, inputs, targets) for phen in Phen]).reshape(-1, 1)
    best_pop = Phen[np.argmax(ObjV)]

    FitnV = ea.ranking(maxormins * ObjV)

    # Recombination and mutation offspring.
    SelCh_c = curChrom[ea.selecting(selS, FitnV, N_c), :]
    SelCh_c = ea.recombin(recS, SelCh_c, pc)
    SelCh_m = curChrom[ea.selecting(selS, FitnV, N_m), :]
    SelCh_m = ea.mutate(mutS, Encoding, SelCh_m, pm)

    # Repair offspring in decoded space, then re-encode.
    rep_SelCh_c = ea.ri2bs(repair_variable(ea.bs2ri(SelCh_c, FieldD), M), FieldD)
    rep_SelCh_m = ea.ri2bs(repair_variable(ea.bs2ri(SelCh_m, FieldD), M), FieldD)

    # Elitism: deterministically keep the K best individuals. A stochastic
    # selection operator would not guarantee the best ones survive.
    elit_ind = np.argsort(-ObjV.ravel(), kind='stable')[:K]
    curChrom = np.vstack([curChrom[elit_ind, :], rep_SelCh_c, rep_SelCh_m])

    return curChrom, best_pop


