import argparse
import os
import random
import numpy as np
import torch
import json
import wandb
from deap import base, creator, tools
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer, HfArgumentParser
import dataclasses  # Added for dataclasses.replace

from utils.permutation_utils import get_permutations, generate_all_permutation_matrices, generate_random_permutation
from loader.data_collator import PermutationExperimentDataCollator

# loader.data._load_data is not directly used, TextContinuationDataset is
from trainer.permutation_loss_logging_trainer import (
    PermutationLossLoggingTrainer,
    PermutationLossLoggingTrainingArguments,
)
from main_permutation_loss_analysis import ScriptArguments, TextContinuationDataset  # Reusing from existing script
from dataclasses import dataclass, field  # Already imported via ScriptArguments but good for clarity
from typing import Optional, List, Dict, Any, Tuple

import logging

# Module logger. basicConfig runs at import time and sets the root logging level to INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# DEAP setup (executed at import time; creator.create registers the new classes on deap.creator).
# weights=(-1.0,) makes this a single-objective fitness that is minimized (the fitness is a loss).
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
# An Individual is a plain list (the permutation as integer indices) carrying a FitnessMin attribute.
creator.create("Individual", list, fitness=creator.FitnessMin)

# Shared toolbox; main() registers the individual/population factories and the GA operators on it.
toolbox = base.Toolbox()

# Process-wide cache filled once by setup_evaluation_environment() and read by the training and
# evaluation helpers, so the tokenizer, datasets and argument templates are not rebuilt per call.
EVAL_SETUP = {}


def setup_evaluation_environment(script_args: ScriptArguments, training_args: PermutationLossLoggingTrainingArguments):
    """
    Populate the module-level EVAL_SETUP cache; a second call is a no-op.

    The following keys are written, in this order: ``script_args``, ``device``,
    ``tokenizer``, ``model_config_params``, ``target_len``, ``train_dataset``,
    ``eval_dataset``, ``training_permutations_list``,
    ``full_training_args_template`` and ``minimal_eval_args_template``.

    Args:
        script_args: Experiment settings (sequence length, GPT-2 sizes, dataset
            path prefix, permutation settings, target length).
        training_args: Trainer arguments; a copy becomes the training template
            and its ``output_dir`` hosts the temporary evaluation directory.

    Raises:
        ValueError: If the ``.train`` or ``.test`` dataset is empty.
    """
    # The presence of this key marks a previous call that got at least as far as the model config.
    if "model_config_params" in EVAL_SETUP:
        logger.info("Evaluation environment already set up.")
        return

    logger.info("Setting up evaluation environment...")
    EVAL_SETUP["script_args"] = script_args
    EVAL_SETUP["device"] = training_args.device

    # --- Tokenizer (imported lazily, only when the environment is actually built) ---
    from data.tokenizers import set_tokenizer, set_vocab

    tokenizer = set_tokenizer(
        set_vocab(0, field="ZZ", max_coeff=500, max_degree=1, continuous_coefficient=False, continuous_exponent=False)
    )
    if not hasattr(tokenizer, "unk_token") or tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": "[UNK]"})
    if tokenizer.pad_token is None:
        # Reuse EOS as padding when available; otherwise register a dedicated [PAD] token.
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "[PAD]"
        if tokenizer.pad_token == "[PAD]":
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    EVAL_SETUP["tokenizer"] = tokenizer

    # --- Model configuration (only the parameters are kept; models are built fresh per generation) ---
    EVAL_SETUP["model_config_params"] = dict(
        vocab_size=len(tokenizer),
        n_positions=script_args.max_seq_length,
        n_ctx=script_args.max_seq_length,
        n_embd=script_args.gpt2_n_embd,
        n_layer=script_args.gpt2_n_layer,
        n_head=script_args.gpt2_n_head,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    EVAL_SETUP["target_len"] = script_args.target_len

    # --- Datasets ---
    def _load_split(suffix: str, empty_message: str):
        """Load ``<dataset_path_prefix>.<suffix>`` and reject an empty result."""
        split = TextContinuationDataset(
            tokenizer,
            f"{script_args.dataset_path_prefix}.{suffix}",
            max_length=script_args.max_seq_length,
            data_has_colon_separator=True,
        )
        if not split or len(split) == 0:
            raise ValueError(empty_message)
        return split

    EVAL_SETUP["train_dataset"] = _load_split("train", "Training dataset is empty or could not be loaded.")
    EVAL_SETUP["eval_dataset"] = _load_split("test", "Evaluation dataset is empty or could not be loaded.")

    # --- Permutations for the internal training phases, chosen by permutation_type ---
    perm_kind = script_args.permutation_type
    if perm_kind == "all":
        perm_stack = generate_all_permutation_matrices(script_args.target_len)
    elif perm_kind == "random":
        perm_stack = generate_random_permutation(
            N=script_args.target_len, num_samples=script_args.permutation_select_num
        )
    elif perm_kind == "family":
        perm_stack = get_permutations(
            target_len=script_args.target_len, permutation_select_num=script_args.permutation_select_num
        )
    else:
        # Any other type (e.g. "identity") falls back to get_permutations with at least one permutation.
        logger.info(
            f"Using identity permutation or limited set for internal training based on type: {script_args.permutation_type}, num: {script_args.permutation_select_num}"
        )
        perm_stack = get_permutations(
            target_len=script_args.target_len, permutation_select_num=max(1, script_args.permutation_select_num)
        )

    # Split the stacked tensor along its first dimension; a missing or empty tensor yields an empty list.
    if perm_stack is not None and perm_stack.nelement() > 0:
        perm_list = list(torch.unbind(perm_stack))
    else:
        perm_list = []
    EVAL_SETUP["training_permutations_list"] = perm_list
    logger.info(f"Generated {len(EVAL_SETUP['training_permutations_list'])} permutations for internal training phases.")

    # --- Trainer argument templates ---
    # Copy of the full arguments, used as the starting point for each generation's training run.
    EVAL_SETUP["full_training_args_template"] = dataclasses.replace(training_args)

    # Reduced arguments for evaluation-only trainers, rooted in a scratch directory that must exist.
    scratch_dir = os.path.join(training_args.output_dir, "ga_internal_evals_temp")
    os.makedirs(scratch_dir, exist_ok=True)
    EVAL_SETUP["minimal_eval_args_template"] = PermutationLossLoggingTrainingArguments(
        output_dir=scratch_dir,
        per_device_eval_batch_size=training_args.per_device_eval_batch_size,
        dataloader_num_workers=training_args.dataloader_num_workers,
        report_to=[],
        fp16=training_args.fp16,
        remove_unused_columns=False,  # keep every dataset column for the collator
    )

    logger.info("Evaluation environment setup complete.")


def train_model_for_generation(
    generation_permutations_as_lists: List[List[int]], generation_num: int
) -> Tuple[GPT2LMHeadModel, int]:
    """
    Train a freshly initialised GPT-2 model on the permutations of one generation.

    Duplicate permutations are removed (and the remainder sorted) before they are
    handed to the data collator. All shared objects are read from EVAL_SETUP, so
    setup_evaluation_environment() must have been called first.

    Args:
        generation_permutations_as_lists: The generation's permutations, each a
            list of integer indices.
        generation_num: Generation index, used in log messages and in the name
            of the per-generation output directory.

    Returns:
        A tuple ``(trained_model, number_of_unique_permutations)``.

    Raises:
        KeyError: If EVAL_SETUP has not been populated.
        Exception: Any error raised by ``trainer.train()`` is logged and re-raised.
    """
    logger.info(
        f"Starting model training for generation {generation_num} using {len(generation_permutations_as_lists)} permutations."
    )
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    device = EVAL_SETUP["device"]
    model_config_params = EVAL_SETUP["model_config_params"]
    tokenizer = EVAL_SETUP["tokenizer"]
    train_dataset = EVAL_SETUP["train_dataset"]
    training_args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["full_training_args_template"]
    target_len = EVAL_SETUP["target_len"]

    # Every generation starts from randomly initialised weights.
    fresh_model = GPT2LMHeadModel(GPT2Config(**model_config_params))
    if fresh_model.config.vocab_size != len(tokenizer):
        fresh_model.resize_token_embeddings(len(tokenizer))
    fresh_model.to(device)

    # NOTE: the suffix comes from the module-level `random` generator that main() also draws from,
    # so this call advances the same RNG stream as the GA loop.
    gen_train_output_dir = os.path.join(
        training_args_template.output_dir, f"gen_{generation_num}_training_{random.randint(1000,9999)}"
    )
    os.makedirs(gen_train_output_dir, exist_ok=True)

    # All per-run overrides go on a private copy; the shared template in EVAL_SETUP stays untouched.
    current_train_args = dataclasses.replace(training_args_template)
    current_train_args.output_dir = gen_train_output_dir
    current_train_args.logging_dir = os.path.join(gen_train_output_dir, "logs")
    current_train_args.report_to = []
    current_train_args.remove_unused_columns = False
    current_train_args.dataloader_pin_memory = False  # Disable pin_memory for compatibility with CPU training

    # Tuples are hashable, so a set removes duplicates; sorting gives a stable order across runs.
    unique_perms_as_lists = [list(p) for p in sorted({tuple(p) for p in generation_permutations_as_lists})]

    logger.info(
        f"Generation {generation_num}: Original {len(generation_permutations_as_lists)} permutations, "
        f"using {len(unique_perms_as_lists)} unique permutations for training."
    )

    unique_perms_as_tensors = perms_to_tensor_list(unique_perms_as_lists, target_len)
    train_data_collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=unique_perms_as_tensors,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        # Per-sample permutation is only requested when there is at least one permutation.
        per_sample_permutation=bool(unique_perms_as_tensors),
    )

    trainer = PermutationLossLoggingTrainer(
        model=fresh_model,
        args=current_train_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=train_data_collator,
    )

    try:
        logger.info(f"Training model for generation {generation_num} (output: {trainer.args.output_dir})...")
        trainer.train()
        logger.info(f"Model training completed for generation {generation_num}.")
        return trainer.model, len(unique_perms_as_lists)
    except Exception as e_train:
        logger.error(f"Error during model training for generation {generation_num}: {e_train}", exc_info=True)
        raise  # Re-raise the exception to halt GA if generation training fails


def evaluate_individual_on_trained_model(
    trained_model: GPT2LMHeadModel, individual_permutation_as_list: List[int], individual_id_str: str
) -> Tuple[float,]:
    """
    Score one permutation by the evaluation loss of an already trained model.

    An evaluation-only trainer is built around ``trained_model`` with a collator
    that holds just this permutation, and the loss reported under the
    ``ga_eval_ind_<id>_loss`` metric key is returned.

    Args:
        trained_model: Model produced by the current generation's training run.
        individual_permutation_as_list: The permutation as a list of indices.
        individual_id_str: Identifier used in log messages, the metric prefix
            and the per-evaluation output directory name.

    Returns:
        A one-element tuple with the loss, or ``(inf,)`` when the loss key is
        missing or the evaluation raises.
    """
    logger.debug(f"Evaluating individual {individual_id_str} ({individual_permutation_as_list}) on pre-trained model.")
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    tokenizer = EVAL_SETUP["tokenizer"]
    eval_dataset = EVAL_SETUP["eval_dataset"]
    base_eval_args: PermutationLossLoggingTrainingArguments = EVAL_SETUP["minimal_eval_args_template"]

    # The collator receives a single-entry permutation list and is pinned to index 0 of it.
    collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=perms_to_tensor_list([individual_permutation_as_list], EVAL_SETUP["target_len"]),
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=False,
        fixed_permutation_index=0,
    )

    # Private copy of the template, redirected to a directory of its own for this evaluation.
    run_args = dataclasses.replace(base_eval_args)
    run_args.dataloader_pin_memory = False
    run_dir = os.path.join(run_args.output_dir, f"ind_{individual_id_str}_eval_{random.randint(1000,9999)}")
    os.makedirs(run_dir, exist_ok=True)
    run_args.output_dir = run_dir
    run_args.logging_dir = os.path.join(run_dir, "logs")

    evaluator = PermutationLossLoggingTrainer(
        model=trained_model,
        args=run_args,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    try:
        prefix = f"ga_eval_ind_{individual_id_str}"
        expected_key = f"{prefix}_loss"
        metrics = evaluator.evaluate(metric_key_prefix=prefix)
        loss = metrics.get(expected_key)

        if loss is not None:
            logger.debug(f"Individual {individual_id_str}, Loss: {loss}")
            return (loss,)
        logger.error(f"Loss key '{expected_key}' not found. Metrics: {metrics}")
    except Exception as e_eval:
        # Best effort: a failed evaluation is logged and scored as the worst possible fitness.
        logger.error(f"Error during evaluation of individual {individual_id_str}: {e_eval}", exc_info=True)
    return (float("inf"),)


# Helper function to convert list of individual permutations to list of PyTorch tensors
def perms_to_tensor_list(perms_list_of_lists: List[List[int]], target_len: int) -> List[torch.Tensor]:
    """
    Turn index-list permutations into ``target_len x target_len`` float32 matrices.

    For a permutation ``p`` the resulting matrix has a 1.0 at ``[i, p[i]]`` for
    every position ``i`` of ``p`` and 0.0 everywhere else.

    Args:
        perms_list_of_lists: Permutations, each given as a list of column indices.
        target_len: Side length of every produced square matrix.

    Returns:
        One matrix per input permutation, in input order.
    """

    def _to_matrix(perm: List[int]) -> torch.Tensor:
        """Build the matrix for a single permutation."""
        grid = torch.zeros((target_len, target_len), dtype=torch.float32)
        for row, col in enumerate(perm):
            grid[row, col] = 1.0
        return grid

    return [_to_matrix(perm) for perm in perms_list_of_lists]


def create_individual_permutation():
    """
    Create one DEAP individual: the indices ``0 .. target_len - 1`` in shuffled order.

    ``target_len`` is read from EVAL_SETUP, so setup_evaluation_environment()
    has to run before this factory is used.

    Returns:
        A ``creator.Individual`` wrapping the shuffled index list.

    Raises:
        ValueError: If EVAL_SETUP holds no ``target_len``.
    """
    if EVAL_SETUP.get("target_len") is None:
        # This case should ideally be prevented by ensuring setup_evaluation_environment is called first.
        raise ValueError(
            "target_len not set in EVAL_SETUP. Call setup_evaluation_environment before toolbox registration."
        )
    genes = [*range(EVAL_SETUP["target_len"])]
    random.shuffle(genes)  # in-place shuffle using the module-level RNG
    return creator.Individual(genes)


def main():
    """
    Run the permutation-search genetic algorithm end to end.

    Steps, in order:
      1. Parse PermutationLossLoggingTrainingArguments, ScriptArguments and
         GaArguments from the command line.
      2. Populate EVAL_SETUP via setup_evaluation_environment().
      3. Optionally start a wandb run (when ``wandb_project_ga`` is set).
      4. Seed ``random``, NumPy and torch with ``ga_seed``.
      5. Register the DEAP operators (ordered crossover, index-shuffle mutation,
         tournament selection) and create the initial population.
      6. Generation 0: train a model on the initial population's permutations and
         evaluate every individual on it.
      7. For each later generation: select, clone, mate, mutate, train a new model
         on the offspring's permutations, evaluate the offspring whose fitness was
         invalidated, and replace the population.
      8. Write the best individual, the ranked final population and the logbook to
         ``<output_dir>/ga_optimization_results_tl<target_len>.json`` and, when a
         wandb run is active, log the same information there.
    """
    # --- Argument Parsing ---
    hf_parser = HfArgumentParser((PermutationLossLoggingTrainingArguments, ScriptArguments, GaArguments))
    training_args, script_args, ga_args = hf_parser.parse_args_into_dataclasses()

    # --- Setup Evaluation Environment (Tokenizer, Model, Data) ---
    # Critical to call this early. It populates EVAL_SETUP.
    # NOTE(review): this runs before the seeding below, so any random draws made during setup
    # are not governed by ga_seed — confirm that this is intended.
    setup_evaluation_environment(script_args, training_args)

    # --- WandB Setup ---
    if ga_args.wandb_project_ga:  # a falsy project name disables wandb for this run
        # Fall back to a descriptive run name built from the main experiment settings.
        run_name = (
            ga_args.wandb_run_name_ga
            if ga_args.wandb_run_name_ga
            else f"ga_opt_{script_args.dataset_name}_tl{script_args.target_len}_pop{ga_args.population_size}_gen{ga_args.num_generations}"
        )
        wandb.init(
            project=ga_args.wandb_project_ga,
            entity=(
                ga_args.wandb_entity_ga if hasattr(ga_args, "wandb_entity_ga") else None
            ),  # GaArguments in this file always defines wandb_entity_ga, so the hasattr guard is defensive only
            name=run_name,
            config={**vars(script_args), **vars(ga_args), **vars(training_args)},
        )
        logger.info(f"Initialized wandb for project '{ga_args.wandb_project_ga}', run '{run_name}'")
    else:
        logger.info("wandb_project_ga not set, wandb logging for GA disabled.")

    # --- Seed for reproducibility ---
    # The GA seed has its own name (ga_seed) so it cannot clash with a 'seed' field of the other
    # argument dataclasses. GaArguments in this file always defines it; 42 is a defensive fallback.
    seed_value = ga_args.ga_seed if hasattr(ga_args, "ga_seed") else 42
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    logger.info(f"Set random seed to {seed_value}")

    # --- DEAP Toolbox Registration ---
    # create_individual_permutation reads target_len from EVAL_SETUP, which was populated above.
    toolbox.register("individual", create_individual_permutation)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    # Note: toolbox.evaluate is NOT registered here as it depends on the generation's trained model.

    toolbox.register("mate", tools.cxOrdered)
    toolbox.register("mutate", tools.mutShuffleIndexes, indpb=ga_args.mutation_indpb)
    toolbox.register("select", tools.selTournament, tournsize=ga_args.tournament_size)

    # --- Genetic Algorithm Execution ---
    population = toolbox.population(n=ga_args.population_size)

    # Per-generation statistics over the fitness values of the whole population.
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    # Keeps the single best individual seen across all generations.
    hof = tools.HallOfFame(1)

    logger.info(
        f"Starting custom GA with population_size={ga_args.population_size}, num_generations={ga_args.num_generations}"
    )
    logger.info(f"Initial Crossover P_CX={ga_args.initial_crossover_prob}, Final P_CX={ga_args.final_crossover_prob}")
    logger.info(f"Initial Mutation P_MUT={ga_args.initial_mutation_prob}, Final P_MUT={ga_args.final_mutation_prob}")
    logger.info(
        f"Decay starts at gen {ga_args.decay_start_generation}, ends at {ga_args.decay_end_generation_ratio*100}% of total generations."
    )

    logbook = tools.Logbook()
    logbook.header = (
        ["gen", "nevals"] + (stats.fields if stats else []) + ["crossover_prob", "mutation_prob"]
    )  # the decayed probabilities are recorded next to the fitness statistics

    # Convert the decay-end ratio into an absolute generation number (truncated towards zero).
    decay_end_generation_absolute = int(ga_args.num_generations * ga_args.decay_end_generation_ratio)

    # --- Custom GA Loop ---
    # Initial population evaluation (Generation 0)
    logger.info("Generation 0: Training initial model and evaluating population...")
    # The probabilities for generation 0 are computed only for logging; no operators are applied yet.
    current_crossover_prob_g0 = linear_decay(
        0,
        ga_args.num_generations,
        ga_args.decay_start_generation,
        decay_end_generation_absolute,
        ga_args.initial_crossover_prob,
        ga_args.final_crossover_prob,
    )
    current_mutation_prob_g0 = linear_decay(
        0,
        ga_args.num_generations,
        ga_args.decay_start_generation,
        decay_end_generation_absolute,
        ga_args.initial_mutation_prob,
        ga_args.final_mutation_prob,
    )
    logger.info(
        f"Generation 0: Effective Crossover P_CX={current_crossover_prob_g0:.4f}, Mutation P_MUT={current_mutation_prob_g0:.4f}"
    )

    # Train one model on the whole initial population, then score every individual on it.
    initial_perms_as_lists = [list(ind) for ind in population]
    gen0_trained_model, num_unique_perms_gen0 = train_model_for_generation(initial_perms_as_lists, 0)

    fitnesses = []
    for i, ind in enumerate(population):
        ind_id_str = f"gen0_ind{i}"
        fitnesses.append(evaluate_individual_on_trained_model(gen0_trained_model, list(ind), ind_id_str))
    for ind, fit in zip(population, fitnesses):
        ind.fitness.values = fit

    hof.update(population)
    record = stats.compile(population)
    logbook.record(
        gen=0,
        nevals=len(population),
        crossover_prob=current_crossover_prob_g0,
        mutation_prob=current_mutation_prob_g0,
        **record,
    )
    if wandb.run:
        wandb.log(
            {
                "generation": 0,
                "nevals": len(population),
                "num_unique_permutations_in_training": num_unique_perms_gen0,
                "current_crossover_prob": current_crossover_prob_g0,
                "current_mutation_prob": current_mutation_prob_g0,
                **{f"ga_stats_{k}": v for k, v in record.items()},
            }
        )
    logger.info(logbook.stream)

    # Generational process
    for gen in range(1, ga_args.num_generations + 1):
        # Crossover and mutation probabilities follow the linear decay schedule for this generation.
        current_crossover_prob = linear_decay(
            gen,
            ga_args.num_generations,
            ga_args.decay_start_generation,
            decay_end_generation_absolute,
            ga_args.initial_crossover_prob,
            ga_args.final_crossover_prob,
        )
        current_mutation_prob = linear_decay(
            gen,
            ga_args.num_generations,
            ga_args.decay_start_generation,
            decay_end_generation_absolute,
            ga_args.initial_mutation_prob,
            ga_args.final_mutation_prob,
        )
        logger.info(
            f"Generation {gen}: Effective Crossover P_CX={current_crossover_prob:.4f}, Mutation P_MUT={current_mutation_prob:.4f}"
        )
        logger.info(f"Generation {gen}: Starting selection, crossover, mutation...")
        # Select as many individuals as the population holds and clone them so the operators
        # below never modify members of the current population in place.
        offspring = toolbox.select(population, len(population))
        offspring = list(map(toolbox.clone, offspring))

        # Mate neighbouring pairs (0,1), (2,3), ...; with an odd count the last individual is unpaired.
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < current_crossover_prob:  # Use decayed probability
                toolbox.mate(child1, child2)
                # Deleting the values marks the fitness invalid, which schedules a re-evaluation.
                del child1.fitness.values
                del child2.fitness.values

        for mutant in offspring:
            if random.random() < current_mutation_prob:  # Use decayed probability
                toolbox.mutate(mutant)
                del mutant.fitness.values

        # The generation's model is trained on all offspring permutations, changed or not.
        logger.info(f"Generation {gen}: Training generation-specific model...")
        current_gen_perms_as_lists = [list(ind) for ind in offspring]
        current_gen_trained_model, num_unique_perms_current_gen = train_model_for_generation(
            current_gen_perms_as_lists, gen
        )

        # NOTE(review): only individuals changed by crossover/mutation are scored on this
        # generation's model. Unchanged offspring keep a loss measured on an earlier generation's
        # model, so the statistics and the hall of fame mix losses from different models —
        # confirm that this is intended.
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        logger.info(f"Generation {gen}: Evaluating {len(invalid_ind)} new individuals...")
        fitnesses_offspring = []
        for i, ind_offspring in enumerate(invalid_ind):
            offspring_id_str = f"gen{gen}_ind{i}"
            fitnesses_offspring.append(
                evaluate_individual_on_trained_model(current_gen_trained_model, list(ind_offspring), offspring_id_str)
            )
        for ind_offspring, fit_offspring in zip(invalid_ind, fitnesses_offspring):
            ind_offspring.fitness.values = fit_offspring

        # Full generational replacement: the offspring become the new population.
        population[:] = offspring
        hof.update(population)
        record = stats.compile(population)
        logbook.record(
            gen=gen,
            nevals=len(invalid_ind),
            crossover_prob=current_crossover_prob,
            mutation_prob=current_mutation_prob,
            **record,
        )
        if wandb.run:
            wandb.log(
                {
                    "generation": gen,
                    "nevals": len(invalid_ind),
                    "num_unique_permutations_in_training": num_unique_perms_current_gen,
                    "current_crossover_prob": current_crossover_prob,
                    "current_mutation_prob": current_mutation_prob,
                    **{f"ga_stats_{k}": v for k, v in record.items()},
                }
            )
        logger.info(logbook.stream)

    # --- Results ---
    best_individual = hof[0]
    best_fitness = best_individual.fitness.values[0]

    logger.info(f"Best individual found: {list(best_individual)}")
    logger.info(f"Best fitness (loss): {best_fitness}")

    # --- Process all individuals in the final population ---
    final_population_ranked = []
    for ind in population:
        final_population_ranked.append({"permutation": list(ind), "fitness": ind.fitness.values[0]})

    # Sort by fitness (ascending, as lower loss is better)
    final_population_ranked.sort(key=lambda x: x["fitness"])

    logger.info("Final population ranked by fitness (top 5):")
    for i, entry in enumerate(final_population_ranked[:5]):
        logger.info(f"  Rank {i+1}: Permutation: {entry['permutation']}, Fitness: {entry['fitness']:.4f}")

    results = {
        "best_individual_permutation": list(best_individual),
        "best_fitness_loss": best_fitness,
        "final_population_ranked": final_population_ranked,  # whole final population, best first
        "ga_args": vars(ga_args),
        "script_args": vars(script_args),
    }

    # Process logbook for JSON and WandB
    # Any Fitness object among the logbook values is reduced to its first value.
    processed_logbook = []
    for gen_data in logbook:
        entry = {k: (v.values[0] if isinstance(v, base.Fitness) else v) for k, v in gen_data.items()}
        processed_logbook.append(entry)
    results["logbook"] = processed_logbook

    # Save results to JSON
    output_dir_main = training_args.output_dir  # Main output dir from original args
    if not os.path.exists(output_dir_main):
        os.makedirs(output_dir_main)

    results_file_path = os.path.join(output_dir_main, f"ga_optimization_results_tl{script_args.target_len}.json")
    try:
        with open(results_file_path, "w") as f:
            json.dump(results, f, indent=4)
        logger.info(f"Results saved to {results_file_path}")
    except TypeError as e:
        # Only serialization errors are handled here; they are logged and the run continues.
        logger.error(f"Failed to serialize results to JSON: {e}. Results were: {results}")

    if wandb.run:
        wandb.log(
            {
                "best_permutation_list": list(best_individual),  # logged as a plain list
                "best_final_loss": best_fitness,
                # Only the five best individuals are logged in full
                "top_5_final_permutations": [
                    {"permutation": entry["permutation"], "fitness": entry["fitness"]}
                    for entry in final_population_ranked[:5]
                ],
                "full_final_population_summary_fitnesses": [
                    entry["fitness"] for entry in final_population_ranked
                ],  # fitness values of the whole final population, best first
            }
        )
        # Replay the logbook, one wandb.log call per generation, with every key prefixed by "ga_".
        for record in processed_logbook:
            wandb_log_record = {}
            for key, value in record.items():
                wandb_log_record[f"ga_{key}"] = value
            wandb.log(wandb_log_record)

        # Upload the JSON file only if it was actually written.
        if os.path.exists(results_file_path):
            wandb.save(results_file_path)
        wandb.finish()

    logger.info("Custom GA optimization finished.")


def linear_decay(
    current_gen: int,
    total_gens: int,
    decay_start_gen: int,
    decay_end_gen_abs: int,  # Absolute generation number for decay end
    initial_val: float,
    final_val: float,
) -> float:
    """
    Interpolate linearly from ``initial_val`` to ``final_val`` over a generation window.

    Args:
        current_gen: Generation for which the value is requested.
        total_gens: Total number of generations; accepted but not used in the computation.
        decay_start_gen: First generation of the decay window.
        decay_end_gen_abs: Generation at which ``final_val`` is reached.
        initial_val: Value returned before the window starts.
        final_val: Value returned once the window has ended.

    Returns:
        ``initial_val`` before the window, ``final_val`` from its end onwards (or as
        soon as the window starts when it is empty or inverted), and the linear
        interpolation between the two inside the window.
    """
    # Before the window opens nothing has decayed yet.
    if current_gen < decay_start_gen:
        return initial_val

    # An empty/inverted window or a generation past its end both yield the final value.
    window = decay_end_gen_abs - decay_start_gen
    if window <= 0 or current_gen >= decay_end_gen_abs:
        return final_val

    # Fraction of the window that has elapsed, in [0, 1).
    fraction = (current_gen - decay_start_gen) / float(window)
    return initial_val - fraction * (initial_val - final_val)


def _ga_option(default, help_text: str):
    """Return a dataclass field with the given default and a ``help`` metadata entry."""
    return field(default=default, metadata={"help": help_text})


@dataclass
class GaArguments:
    """
    Command-line options that configure the genetic algorithm.

    Every field carries its help text in the ``help`` metadata entry.
    """

    # --- Population and run length ---
    population_size: int = _ga_option(10, "Number of individuals in the population.")
    num_generations: int = _ga_option(3, "Number of generations to run the GA.")

    # --- Operator probabilities: start values, end values and the decay window ---
    initial_crossover_prob: float = _ga_option(0.7, "Initial probability of mating two individuals.")
    initial_mutation_prob: float = _ga_option(0.2, "Initial probability of mutating an individual.")
    final_crossover_prob: float = _ga_option(0.1, "Final crossover probability after decay.")
    final_mutation_prob: float = _ga_option(0.05, "Final mutation probability after decay.")
    decay_start_generation: int = _ga_option(0, "Generation at which probability decay starts.")
    decay_end_generation_ratio: float = _ga_option(
        0.8,
        "Ratio of total generations at which decay ends and final probability is reached. e.g. 0.8 means at 80% of total generations.",
    )

    # --- Operator details ---
    mutation_indpb: float = _ga_option(
        0.05, "Independent probability for each attribute to be mutated if mutation_prob hits."
    )
    tournament_size: int = _ga_option(3, "Size of the tournament for selection.")

    # --- Reproducibility ---
    ga_seed: int = _ga_option(42, "Random seed for GA reproducibility.")

    # --- Weights & Biases logging ---
    wandb_project_ga: Optional[str] = _ga_option("ga_perm_opt_gen_trained_v3", "WandB project name for GA runs.")
    wandb_entity_ga: Optional[str] = _ga_option(None, "WandB entity for GA runs.")
    wandb_run_name_ga: Optional[str] = _ga_option(None, "WandB run name for this GA run.")


# Run the GA only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
