import argparse
import os
import random
import numpy as np
import torch
import json
import wandb
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer, HfArgumentParser
import dataclasses  # Added for dataclasses.replace
import datetime  # Added for timestamp in output filename

from utils.permutation_utils import (
    get_permutations,
    generate_all_permutation_matrices,
    generate_random_permutation,
    generate_next_level_permutations_v1,
)
from loader.data_collator import PermutationExperimentDataCollator

# loader.data._load_data is not directly used, TextContinuationDataset is
from trainer.permutation_loss_logging_trainer import (
    PermutationLossLoggingTrainer,
    PermutationLossLoggingTrainingArguments,
)
from main_permutation_loss_analysis import ScriptArguments, TextContinuationDataset  # Reusing from existing script
from dataclasses import dataclass, field  # Already imported via ScriptArguments but good for clarity
from typing import Optional, List, Dict, Any, Tuple

import logging
import math
import itertools
import random
import sys

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Module-level cache populated by setup_evaluation_environment(): tokenizer,
# datasets, model config and argument templates, so they are built only once.
EVAL_SETUP: Dict[str, Any] = {}


def sample_block_reverse(n: int, max_ops: int, num: int = 10000) -> List[List[int]]:
    """Sample ``num`` distinct permutations of ``range(n)`` built from random segment reversals.

    Every sample starts from the identity ``[0, 1, ..., n - 1]`` and applies a random
    number (1..``max_ops``) of reversals. Each reversal flips a contiguous slice
    ``p[i..j]`` with ``i < j``. The identity itself is never returned and no permutation
    is returned twice. Randomness comes from the global ``random`` module.

    Args:
        n: Permutation length; must be at least 2 so that a segment can be chosen.
        max_ops: Maximum number of reversals applied per sample; must be at least 1.
        num: Number of distinct permutations to return.

    Returns:
        ``num`` distinct non-identity permutations, each a ``list[int]``.

    Raises:
        ValueError: If ``n < 2`` or ``max_ops < 1``. Also raised if ``num`` exceeds an
            upper bound on the reachable non-identity permutations: ``n * (n - 1) // 2``
            when ``max_ops == 1``, otherwise ``n! - 1``. Without this check the
            rejection loop below could never terminate.

    Note:
        For ``1 < max_ops < n - 1`` the reachable set can still be smaller than
        ``n! - 1``. Requesting more than it contains will not terminate.
    """
    if n < 2:
        raise ValueError("n must be at least 2 to choose a segment to reverse")
    if max_ops < 1:
        raise ValueError("max_ops must be at least 1")
    # A single reversal is determined by its (i, j) pair, giving exactly C(n, 2)
    # distinct results; otherwise bound by all non-identity permutations.
    reachable = n * (n - 1) // 2 if max_ops == 1 else math.factorial(n) - 1
    if num > reachable:
        raise ValueError(
            f"num={num} exceeds the number of distinct non-identity permutations "
            f"reachable with max_ops={max_ops} (at most {reachable})."
        )

    base = list(range(n))
    seen, out = {tuple(base)}, []  # seeding with the identity excludes it from the output
    while len(out) < num:
        p = base.copy()
        k = random.randint(1, max_ops)
        for _ in range(k):
            i, j = sorted(random.sample(range(n), 2))
            p[i : j + 1] = p[i : j + 1][::-1]
        t = tuple(p)
        if t not in seen:
            seen.add(t)
            out.append(p)
    return out

import math
import random
from typing import List

def sample_block_permutations(
    L: int,
    k: int,
    n_samples: int,
    include_reverse: bool = True,
    seed: int | None = None,
) -> List[List[int]]:
    if L % k != 0:
        raise ValueError("k must divide L (L % k == 0)")

    n_blocks = L // k
    max_unique = math.factorial(n_blocks) * (2 if include_reverse else 1)
    if n_samples > max_unique:
        raise ValueError(
            f"n_samples={n_samples} exceeds the maximum possible "
            f"unique permutations ({max_unique})."
        )

    if seed is not None:
        random.seed(seed)

    forward = list(range(L))
    reverse = forward[::-1]

    def split_blocks(seq):
        return [seq[i * k:(i + 1) * k] for i in range(n_blocks)]

    f_blocks = split_blocks(forward)
    r_blocks = split_blocks(reverse)

    sampled: list[list[int]] = []
    seen: set[tuple[int, ...]] = set()

    while len(sampled) < n_samples:
        perm_indices = random.sample(range(n_blocks), n_blocks)

        use_reverse = include_reverse and random.choice((False, True))
        base_blocks = r_blocks if use_reverse else f_blocks

        seq = [idx for b in perm_indices for idx in base_blocks[b]]
        key = tuple(seq)

        if key not in seen: 
            seen.add(key)
            sampled.append(seq)

    return sampled

def sample_intrablock_permutations(
    L: int,
    k: int,
    n_samples: int,
    include_reverse: bool = True,
    seed: int | None = None,
) -> List[List[int]]:
    if L % k != 0:
        raise ValueError("k must divide L")

    n_blocks = L // k
    max_unique = pow(math.factorial(k), n_blocks) * (2 if include_reverse else 1)
    if n_samples > max_unique:
        raise ValueError(
            f"n_samples={n_samples} exceeds the maximum unique permutations "
            f"({max_unique})."
        )

    if seed is not None:
        random.seed(seed)

    # 正順・逆順の基底列
    forward = list(range(L))
    reverse = forward[::-1]

    # ブロック分割
    def split_blocks(seq):
        return [seq[i * k:(i + 1) * k] for i in range(n_blocks)]

    f_blocks_orig = split_blocks(forward)
    r_blocks_orig = split_blocks(reverse)

    sampled: list[list[int]] = []
    seen: set[tuple[int, ...]] = set()

    while len(sampled) < n_samples:
        use_rev = include_reverse and random.choice((False, True))
        base_blocks = r_blocks_orig if use_rev else f_blocks_orig

        intrablock_perm = [
            random.sample(block, k) if k > 1 else block
            for block in base_blocks
        ]

        seq = [idx for block in intrablock_perm for idx in block]
        key = tuple(seq)

        if key not in seen:
            seen.add(key)
            sampled.append(seq)

    return sampled

def create_init_permutation(num_perm: int, target_len: int) -> List[List[int]]:
    """Create ``num_perm`` independently shuffled permutations of ``range(target_len)``.

    Each permutation starts as ``[0, 1, ..., target_len - 1]`` and is shuffled in
    place with ``random.shuffle``, which uses the global ``random`` state.

    Args:
        num_perm: Number of permutations to create (``0`` yields ``[]``).
        target_len: Length of every permutation.

    Returns:
        A list of ``num_perm`` permutations, each a ``list[int]``. Duplicates are
        possible, since no uniqueness check is performed.
    """
    perms_list: List[List[int]] = []
    for _ in range(num_perm):
        perm = list(range(target_len))
        random.shuffle(perm)
        perms_list.append(perm)
    return perms_list


def setup_evaluation_environment(script_args: ScriptArguments, training_args: PermutationLossLoggingTrainingArguments):
    """
    Initializes tokenizer, model config, datasets, and trainer args templates once.
    Stores them in the global EVAL_SETUP dictionary.

    Keys written to EVAL_SETUP: "script_args", "device", "tokenizer",
    "model_config_params", "target_len", "train_dataset", "eval_dataset",
    "training_permutations_list", "full_training_args_template" and
    "minimal_eval_args_template".

    Args:
        script_args: Parsed script options. Fields read here: max_seq_length,
            gpt2_n_embd, gpt2_n_layer, gpt2_n_head, target_len,
            dataset_path_prefix, permutation_type and permutation_select_num.
        training_args: Trainer arguments. A copy is stored as the template for
            the internal training runs. Its device, output_dir,
            per_device_eval_batch_size, dataloader_num_workers and fp16 are
            also read.

    Raises:
        ValueError: If the "<dataset_path_prefix>.train" or
            "<dataset_path_prefix>.test" dataset is empty (or falsy).

    Side effects:
        May add "[UNK]" / "[PAD]" special tokens to the tokenizer, and creates
        the directory "<training_args.output_dir>/ga_internal_evals_temp".

    NOTE(review): "model_config_params" is stored before the datasets are
    loaded. If dataset loading raises, a later call takes the early-return
    path below with EVAL_SETUP only partially populated.
    """
    # Skip re-initialisation once step 2 of a previous call has completed.
    if EVAL_SETUP and "model_config_params" in EVAL_SETUP:
        logger.info("Evaluation environment already set up.")
        return

    logger.info("Setting up evaluation environment...")
    EVAL_SETUP["script_args"] = script_args
    EVAL_SETUP["device"] = training_args.device

    # 1. Initialize Tokenizer
    # The vocabulary options are hard-coded here rather than read from script_args.
    from data.tokenizers import set_tokenizer, set_vocab

    vocab = set_vocab(
        0, field="ZZ", max_coeff=500, max_degree=1, continuous_coefficient=False, continuous_exponent=False
    )
    tokenizer = set_tokenizer(vocab)
    # Ensure an unknown token exists.
    if not hasattr(tokenizer, "unk_token") or tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": "[UNK]"})
    # Padding falls back to EOS; only when there is no EOS is a new "[PAD]" token registered.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "[PAD]"
        if tokenizer.pad_token == "[PAD]":  # if it was newly added
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    EVAL_SETUP["tokenizer"] = tokenizer

    # 2. Store Model Configuration Parameters
    # We'll create fresh models in evaluate_permutation, so store config here.
    # vocab_size is taken after the special tokens above may have been added.
    model_config_params = {
        "vocab_size": len(tokenizer),
        "n_positions": script_args.max_seq_length,
        "n_ctx": script_args.max_seq_length,
        "n_embd": script_args.gpt2_n_embd,
        "n_layer": script_args.gpt2_n_layer,
        "n_head": script_args.gpt2_n_head,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    EVAL_SETUP["model_config_params"] = model_config_params
    EVAL_SETUP["target_len"] = script_args.target_len

    # 3. Load Datasets (Train and Eval)
    # Files are "<prefix>.train" and "<prefix>.test"; both must be non-empty.
    train_file_path = f"{script_args.dataset_path_prefix}.train"
    train_dataset = TextContinuationDataset(
        tokenizer, train_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    if not train_dataset or len(train_dataset) == 0:
        raise ValueError("Training dataset is empty or could not be loaded.")
    EVAL_SETUP["train_dataset"] = train_dataset

    eval_file_path = f"{script_args.dataset_path_prefix}.test"
    eval_dataset = TextContinuationDataset(
        tokenizer, eval_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    if not eval_dataset or len(eval_dataset) == 0:
        raise ValueError("Evaluation dataset is empty or could not be loaded.")
    EVAL_SETUP["eval_dataset"] = eval_dataset

    # 4. Generate Permutations for Internal Training
    # These are the general permutations used in the training phase within evaluate_permutation,
    # mimicking main_permutation_loss_analysis.py's training.
    # NOTE(review): "training_permutations_list" is not read by any other function in this
    # file as far as visible; train_model_for_generation uses the per-generation permutations.
    if script_args.permutation_type == "all":
        training_perms_tensor = generate_all_permutation_matrices(script_args.target_len)
    elif script_args.permutation_type == "random":
        training_perms_tensor = generate_random_permutation(
            N=script_args.target_len, num_samples=script_args.permutation_select_num
        )
    elif script_args.permutation_type == "family":
        training_perms_tensor = get_permutations(
            target_len=script_args.target_len, permutation_select_num=script_args.permutation_select_num
        )
    else:  # Includes "identity" or if permutation_select_num is 0 or 1 for "family"
        logger.info(
            f"Using identity permutation or limited set for internal training based on type: {script_args.permutation_type}, num: {script_args.permutation_select_num}"
        )
        # Default to identity or whatever get_permutations returns for num=1
        training_perms_tensor = get_permutations(
            target_len=script_args.target_len, permutation_select_num=max(1, script_args.permutation_select_num)
        )

    # Split the stacked result along dim 0 into a list of tensors; presumably each is
    # one permutation matrix (TODO confirm against utils.permutation_utils).
    EVAL_SETUP["training_permutations_list"] = (
        list(torch.unbind(training_perms_tensor))
        if training_perms_tensor is not None and training_perms_tensor.nelement() > 0
        else []
    )
    logger.info(f"Generated {len(EVAL_SETUP['training_permutations_list'])} permutations for internal training phases.")

    # 5. Store Training Arguments
    # Full training_args for the internal training loop
    # dataclasses.replace with no changes makes a new (shallow) copy of the arguments.
    EVAL_SETUP["full_training_args_template"] = dataclasses.replace(training_args)

    # Minimal training_args for the final evaluation step within evaluate_permutation
    # Ensure output_dir is set, even if not fully training.
    minimal_eval_output_dir = os.path.join(training_args.output_dir, "ga_internal_evals_temp")
    os.makedirs(minimal_eval_output_dir, exist_ok=True)

    minimal_eval_args_template = PermutationLossLoggingTrainingArguments(
        output_dir=minimal_eval_output_dir,
        per_device_eval_batch_size=training_args.per_device_eval_batch_size,
        dataloader_num_workers=training_args.dataloader_num_workers,
        report_to=[],
        fp16=training_args.fp16,
        remove_unused_columns=False,  # Important: keep columns like permutation_idx
        # Other necessary fields from TrainingArguments like device are handled by Trainer
    )
    EVAL_SETUP["minimal_eval_args_template"] = minimal_eval_args_template

    logger.info("Evaluation environment setup complete.")


def train_model_for_generation(
    generation_permutations_as_lists: List[List[int]], generation_num: int | str
) -> Tuple[GPT2LMHeadModel, int]:
    """Train a freshly initialised GPT-2 model on the permutations of one generation.

    Reads the tokenizer, model config, training dataset and training-argument template
    prepared by ``setup_evaluation_environment`` from ``EVAL_SETUP``. Duplicate
    permutations are removed before training. The unique ones are sorted, converted
    to permutation matrices, and sampled per example by the data collator.

    Args:
        generation_permutations_as_lists: Permutations of this generation, each a list
            of indices of length ``target_len``.
        generation_num: Identifier used in log messages and in the output directory
            name. The search in this file passes strings such as ``"h_v1_k1"``.

    Returns:
        ``(model, n_unique)``: the trained model and the number of unique permutations
        used for training.

    Raises:
        Any exception raised by ``trainer.train()``. It is logged and then re-raised
        so that the search halts.
    """
    logger.info(
        f"Starting model training for generation {generation_num} using {len(generation_permutations_as_lists)} permutations."
    )
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    device = EVAL_SETUP["device"]
    model_config_params = EVAL_SETUP["model_config_params"]
    tokenizer = EVAL_SETUP["tokenizer"]
    train_dataset = EVAL_SETUP["train_dataset"]
    training_args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["full_training_args_template"]
    target_len = EVAL_SETUP["target_len"]

    # Every generation starts from a newly initialised model.
    fresh_model_config = GPT2Config(**model_config_params)
    fresh_model = GPT2LMHeadModel(fresh_model_config)
    if fresh_model.config.vocab_size != len(tokenizer):
        fresh_model.resize_token_embeddings(len(tokenizer))
    fresh_model.to(device)

    # Random suffix to separate runs; with exist_ok=True a suffix collision reuses the directory.
    gen_train_output_dir = os.path.join(
        training_args_template.output_dir, f"gen_{generation_num}_training_{random.randint(1000,9999)}"
    )
    os.makedirs(gen_train_output_dir, exist_ok=True)

    # Disable pin_memory (for compatibility with CPU training) on the copy, rather than
    # mutating the shared template stored in EVAL_SETUP.
    current_train_args = dataclasses.replace(training_args_template, dataloader_pin_memory=False)
    current_train_args.output_dir = gen_train_output_dir
    current_train_args.logging_dir = os.path.join(gen_train_output_dir, "logs")
    current_train_args.report_to = []
    current_train_args.remove_unused_columns = False

    # Deduplicate via tuples (hashable). Sorting gives a stable order across runs.
    unique_perms_as_tuples = sorted({tuple(p) for p in generation_permutations_as_lists})
    unique_perms_as_lists = [list(p) for p in unique_perms_as_tuples]

    logger.info(
        f"Generation {generation_num}: Original {len(generation_permutations_as_lists)} permutations, "
        f"using {len(unique_perms_as_lists)} unique permutations for training."
    )

    unique_perms_as_tensors = perms_to_tensor_list(unique_perms_as_lists, target_len)
    train_data_collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=unique_perms_as_tensors,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        # Per-sample permutation only makes sense when there is at least one permutation.
        per_sample_permutation=bool(unique_perms_as_tensors),
    )

    trainer = PermutationLossLoggingTrainer(
        model=fresh_model,
        args=current_train_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=train_data_collator,
    )

    try:
        logger.info(f"Training model for generation {generation_num} (output: {trainer.args.output_dir})...")
        trainer.train()
    except Exception as e_train:
        logger.error(f"Error during model training for generation {generation_num}: {e_train}", exc_info=True)
        raise  # Re-raise the exception to halt GA if generation training fails
    logger.info(f"Model training completed for generation {generation_num}.")
    return trainer.model, len(unique_perms_as_lists)


def evaluate_individual_on_trained_model(
    trained_model: GPT2LMHeadModel, individual_permutation_as_list: List[int], individual_id_str: str
) -> Tuple[float]:
    """Score one permutation by its evaluation loss on an already-trained model.

    The evaluation dataset from ``EVAL_SETUP`` is collated with
    ``individual_permutation_as_list`` as the single, fixed permutation (index 0,
    applied to targets only). It is then run through
    ``PermutationLossLoggingTrainer.evaluate``; no training happens here.

    Args:
        trained_model: The model to evaluate, e.g. the one returned by
            ``train_model_for_generation``.
        individual_permutation_as_list: The permutation to score, as a list of indices.
        individual_id_str: Identifier used in the metric prefix, in log messages and in
            the per-individual output directory name.

    Returns:
        A 1-tuple ``(loss,)``. ``float("inf")`` is returned when the loss metric is
        missing, is NaN, or evaluation raises, so that such individuals sort last.
    """
    logger.debug(f"Evaluating individual {individual_id_str} ({individual_permutation_as_list}) on pre-trained model.")
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    tokenizer = EVAL_SETUP["tokenizer"]
    eval_dataset = EVAL_SETUP["eval_dataset"]
    eval_args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["minimal_eval_args_template"]
    target_len = EVAL_SETUP["target_len"]

    individual_perm_tensor = perms_to_tensor_list([individual_permutation_as_list], target_len)

    eval_data_collator_individual = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=individual_perm_tensor,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=False,
        fixed_permutation_index=0,
    )

    current_eval_args = dataclasses.replace(eval_args_template)
    current_eval_args.dataloader_pin_memory = False
    eval_output_dir = os.path.join(
        current_eval_args.output_dir, f"ind_{individual_id_str}_eval_{random.randint(1000,9999)}"
    )
    os.makedirs(eval_output_dir, exist_ok=True)
    current_eval_args.output_dir = eval_output_dir
    current_eval_args.logging_dir = os.path.join(eval_output_dir, "logs")

    eval_trainer = PermutationLossLoggingTrainer(
        model=trained_model,
        args=current_eval_args,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=eval_data_collator_individual,
    )

    try:
        metric_key = f"ga_eval_ind_{individual_id_str}"
        metrics = eval_trainer.evaluate(metric_key_prefix=metric_key)
        loss_key_to_find = f"{metric_key}_loss"
        loss = metrics.get(loss_key_to_find)

        if loss is None:
            logger.error(f"Loss key '{loss_key_to_find}' not found. Metrics: {metrics}")
            return (float("inf"),)
        # NaN compares false against everything and would corrupt the caller's
        # sort by loss; rank it last instead.
        if math.isnan(loss):
            logger.warning(f"Individual {individual_id_str}: loss is NaN; treating it as inf.")
            return (float("inf"),)
        logger.debug(f"Individual {individual_id_str}, Loss: {loss}")
        return (loss,)
    except Exception as e_eval:
        logger.error(f"Error during evaluation of individual {individual_id_str}: {e_eval}", exc_info=True)
        return (float("inf"),)


# Helper function to convert list of individual permutations to list of PyTorch tensors
def perms_to_tensor_list(perms_list_of_lists: List[List[int]], target_len: int) -> List[torch.Tensor]:
    """Convert permutations into dense ``target_len x target_len`` float32 matrices.

    For a permutation ``perm``, the matrix holds ``1.0`` at ``(row, perm[row])`` for
    every row index covered by ``perm``, and ``0.0`` everywhere else.

    Args:
        perms_list_of_lists: Permutations, each given as a list of column indices.
        target_len: Side length of each square matrix.

    Returns:
        One float32 matrix per input permutation, in input order.
    """

    def _as_matrix(perm: List[int]) -> torch.Tensor:
        mat = torch.zeros((target_len, target_len), dtype=torch.float32)
        for row, col in enumerate(perm):
            mat[row, col] = 1.0
        return mat

    return [_as_matrix(perm) for perm in perms_list_of_lists]


def factorial(n):
    """Return ``n!``, with ``factorial(0) == 1``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers")
    return 1 if n == 0 else math.factorial(n)


def select_top_permutations(evaluated_results: List[Dict[str, Any]], num_to_select: int) -> List[List[int]]:
    """Return the ``"permutation"`` of each of the first ``num_to_select`` results.

    No sorting happens here. The caller is expected to pass ``evaluated_results``
    already ordered by ascending ``"loss"``, so that the leading entries are the best.

    Args:
        evaluated_results: Dicts that each carry at least a ``"permutation"`` key.
        num_to_select: How many leading entries to keep; a non-positive value yields ``[]``.

    Returns:
        The kept permutations in input order. The list is shorter than
        ``num_to_select`` when fewer results are available.
    """
    if not evaluated_results or num_to_select <= 0:
        return []

    chosen: List[List[int]] = []
    for entry in evaluated_results[:num_to_select]:
        chosen.append(entry["permutation"])
    return chosen


def _layer_selection_count(n_total: int, k_layer_idx: int) -> int:
    """Return how many permutations to keep after evaluating layer ``k_layer_idx``.

    The count is ``n_total // (k_layer_idx + 1)! // 2``. It is clamped to at least 1
    (when ``n_total > 0``) *after* the halving, so a layer always keeps its best
    candidate rather than ending with an empty selection.
    """
    num_to_select = (n_total // factorial(k_layer_idx + 1)) // 2
    if n_total > 0:
        num_to_select = max(1, num_to_select)
    return num_to_select


def hierarchical_permutation_search_v1(M, initial_elements, script_args, block_size: int = 5, seed: int = 42):
    """Run the layer-wise (v1) hierarchical permutation search.

    The initial pool holds ``N = 2 * M!`` permutations of ``range(script_args.target_len)``
    from ``sample_intrablock_permutations``. At every layer ``k = 1..M``:

    1. A fresh model is trained on the current pool (``train_model_for_generation``).
    2. Every pool member is scored by its evaluation loss on that model.
    3. The best ``max(1, N // (k + 1)! // 2)`` are kept.
    4. Unless ``k == M``, the kept ones are expanded with
       ``generate_next_level_permutations_v1`` (using ``k + 1`` blocks) into the next
       layer's pool.

    Args:
        M: Number of layers; must be positive.
        initial_elements: Labels whose length must equal ``M``. They are only
            validated and echoed back in the results.
        script_args: Parsed script options; ``script_args.target_len`` is the
            permutation length.
        block_size: Block size for the initial intra-block sampling.
            ``target_len`` must be divisible by it.
        seed: Seed passed to ``sample_intrablock_permutations``, which reseeds the
            global ``random`` module.

    Returns:
        ``None`` if ``M`` or ``initial_elements`` is invalid. Otherwise a dict with
        ``best_permutation_overall``, ``lowest_loss_overall``, ``m_param``,
        ``initial_elements`` and ``all_layers_data`` (one record per processed layer,
        including the last one).

    Raises:
        ValueError: From ``sample_intrablock_permutations`` if ``N`` exceeds the number
            of distinct initial permutations, or if ``block_size`` does not divide
            ``target_len``.
        Exception: Any training error, re-raised by ``train_model_for_generation``.
    """
    if M <= 0:
        print("M must be a positive integer.")
        return None
    if len(initial_elements) != M:
        print(f"Length of initial_elements ({len(initial_elements)}) must be equal to M ({M}).")
        return None

    N_total_permutations = factorial(M) * 2
    print(f"Starting hierarchical permutation search (v1) with M={M}")
    print(f"Total target permutations per level (N = 2 * M!): {N_total_permutations}")
    current_permutations = sample_intrablock_permutations(
        L=script_args.target_len,
        k=block_size,
        n_samples=N_total_permutations,
        include_reverse=True,
        seed=seed,
    )
    print(f"Generated {len(current_permutations)} initial random permutations for level k=0.")

    best_permutation_overall = None
    lowest_loss_overall = float("inf")

    all_layers_data = []

    for k_layer_idx in range(1, M + 1):
        logger.info(f"\n--- H-V1 Processing Layer k = {k_layer_idx} ---")

        current_layer_data = {
            "k_layer_idx": k_layer_idx,
            "permutations_at_start_of_layer": [p[:] for p in current_permutations],  # Copy each permutation
            "evaluated_results": [],
            "selected_permutations": [],
            "generated_for_next_layer": [],
        }

        if not current_permutations:
            logger.info("H-V1: No permutations to process. Stopping.")
            all_layers_data.append(current_layer_data)  # Record the (empty) layer
            break

        logger.info(f"H-V1 k={k_layer_idx}: Evaluating {len(current_permutations)} permutations...")

        # One model per layer, trained on the whole pool; every member is then scored on it.
        current_generation_id_for_training = f"h_v1_k{k_layer_idx}"
        trained_model, num_unique_perms_trained = train_model_for_generation(
            current_permutations, current_generation_id_for_training
        )
        logger.info(f"H-V1 k={k_layer_idx}: Trained model using {num_unique_perms_trained} unique permutations.")

        evaluated_results_list = []
        for i, perm_list in enumerate(current_permutations):
            perm_id_str = f"h_v1_k{k_layer_idx}_perm{i}"
            loss_tuple = evaluate_individual_on_trained_model(trained_model, perm_list, perm_id_str)
            evaluated_results_list.append({"permutation": perm_list, "loss": loss_tuple[0]})

        evaluated_results_list.sort(key=lambda x: x["loss"])
        current_layer_data["evaluated_results"] = evaluated_results_list[:]  # Shallow copy; result dicts are shared

        if not evaluated_results_list:
            logger.warning(f"H-V1 k={k_layer_idx}: Evaluation returned no results after processing. Stopping.")
            all_layers_data.append(current_layer_data)
            break

        current_best_in_layer = evaluated_results_list[0]
        logger.info(
            f"H-V1 k={k_layer_idx}: Best loss in current layer: {current_best_in_layer['loss']:.4f} for permutation {current_best_in_layer['permutation']}"
        )
        if current_best_in_layer["loss"] < lowest_loss_overall:
            lowest_loss_overall = current_best_in_layer["loss"]
            best_permutation_overall = current_best_in_layer["permutation"]
            logger.info(
                f"H-V1 k={k_layer_idx}: New overall best found: Loss {lowest_loss_overall:.4f}, Permutation {best_permutation_overall}"
            )
            if wandb.run:
                wandb.log(
                    {
                        f"h_v1_overall_best_loss_at_k{k_layer_idx}": lowest_loss_overall,
                        f"h_v1_overall_best_permutation_at_k{k_layer_idx}": best_permutation_overall,
                        "h_v1_current_k_for_best": k_layer_idx,
                    }
                )

        # Clamp to >= 1 *after* halving; the old order clamped first and then halved
        # back to 0, which emptied the final layer's selection.
        num_to_select = _layer_selection_count(N_total_permutations, k_layer_idx)
        logger.info(
            f"H-V1 k={k_layer_idx}: Selecting top {num_to_select} permutations "
            f"(max(1, N/({k_layer_idx + 1})!/2) with N={N_total_permutations}, "
            f"({k_layer_idx + 1})!={factorial(k_layer_idx + 1)}) out of {len(evaluated_results_list)}."
        )

        selected_permutations = select_top_permutations(evaluated_results_list, num_to_select)
        current_layer_data["selected_permutations"] = [p[:] for p in selected_permutations]  # Copy each permutation

        if not selected_permutations:
            logger.info(f"H-V1 k={k_layer_idx}: No permutations selected. Stopping.")
            all_layers_data.append(current_layer_data)
            break

        if wandb.run:
            wandb.log(
                {
                    f"h_v1_k{k_layer_idx}_selected_count": len(selected_permutations),
                    f"h_v1_k{k_layer_idx}_best_loss_in_layer": current_best_in_layer["loss"],
                    f"h_v1_k{k_layer_idx}_num_to_select_calculated": num_to_select,
                }
            )

        if k_layer_idx == M:
            logger.info(f"H-V1 k={k_layer_idx} (M). Search concludes after this layer's selection.")
            all_layers_data.append(current_layer_data)  # Keep the final layer's record in the results
            break

        num_blocks_for_generation = k_layer_idx + 1
        logger.info(
            f"H-V1 k={k_layer_idx}: Generating next level permutations from {len(selected_permutations)} selected ones. Using {num_blocks_for_generation} blocks for generation."
        )

        current_permutations = generate_next_level_permutations_v1(
            selected_permutations, num_blocks_for_generation, len(selected_permutations[0])
        )
        current_layer_data["generated_for_next_layer"] = [p[:] for p in current_permutations]  # Copy each permutation

        logger.info(
            f"H-V1 k={k_layer_idx}: Generated {len(current_permutations)} permutations for next layer (k={k_layer_idx+1})."
        )
        if wandb.run:
            wandb.log({f"h_v1_k{k_layer_idx}_generated_for_next_count": len(current_permutations)})

        all_layers_data.append(current_layer_data)

        # k_layer_idx < M is guaranteed here because the final layer broke out above.
        if not current_permutations:
            logger.warning(
                f"H-V1 k={k_layer_idx}: Generated 0 permutations for the next layer, but not at final layer. Stopping."
            )
            break

    # End of H-V1 search loop
    print("\n--- Search Finished (v1) ---")
    final_results = {
        "best_permutation_overall": best_permutation_overall,
        "lowest_loss_overall": lowest_loss_overall,
        "m_param": M,
        "initial_elements": initial_elements,
        "all_layers_data": all_layers_data,
    }

    if best_permutation_overall is not None:
        print(f"Overall best permutation (v1): {best_permutation_overall}")
        print(f"Corresponding lowest loss (v1): {lowest_loss_overall:.4f}")
    else:
        print("No solution found or search was interrupted (v1).")

    return final_results


def _to_json_serializable(obj: Any) -> Any:
    """Recursively convert numpy/torch values and tuples into plain JSON-compatible objects."""
    if isinstance(obj, dict):
        return {key: _to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()  # 0-d tensors/arrays become Python scalars
    return obj


def main():
    """Parse arguments, run the hierarchical search (v1) and save its results as JSON.

    The results are written to
    ``<output_dir>/hierarchical_search_v1_results_M<M>_<timestamp>.json``.
    Serialization happens fully in memory before the file is opened, so a
    failure never leaves a half-written file behind.
    """
    # --- Argument Parsing ---
    hf_parser = HfArgumentParser((PermutationLossLoggingTrainingArguments, ScriptArguments))
    training_args, script_args = hf_parser.parse_args_into_dataclasses()

    # --- Setup Evaluation Environment (Tokenizer, Model, Data) ---
    # Critical to call this early. It populates EVAL_SETUP.
    setup_evaluation_environment(script_args, training_args)
    M_param = script_args.m_param
    initial_elements_param = [f"item_v1_{i}" for i in range(M_param)]

    print(f"Running v1 with M = {M_param}, elements = {initial_elements_param}")

    # --- Start Hierarchical Permutation Search ---
    search_results_v1 = hierarchical_permutation_search_v1(M_param, initial_elements_param, script_args)

    # --- Save Results ---
    if not search_results_v1:
        print("Search did not return any results to save.")
        return

    output_dir = training_args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"hierarchical_search_v1_results_M{M_param}_{timestamp}.json"
    filepath = os.path.join(output_dir, filename)

    try:
        payload = json.dumps(_to_json_serializable(search_results_v1), indent=4)
    except (TypeError, ValueError) as e_serial:
        print(f"Could not serialize and save results: {e_serial}")
        return

    try:
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)
    except OSError as e_general:
        print(f"An error occurred while saving results: {e_general}")
        return
    print(f"Successfully saved detailed results to {filepath}")


if __name__ == "__main__":
    main()
