import argparse
import os
import random
import numpy as np
import torch
import json
import wandb
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer, HfArgumentParser
import dataclasses
import datetime

from utils.permutation_utils import (
    get_permutations,
    generate_all_permutation_matrices,
    generate_random_permutation,
)
from loader.data_collator import PermutationExperimentDataCollator

from trainer.permutation_loss_logging_trainer import (
    PermutationLossLoggingTrainer,
    PermutationLossLoggingTrainingArguments,
)
from main_permutation_loss_analysis import ScriptArguments, TextContinuationDataset  # Reusing from existing script
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple

import logging
import math
import itertools  # For permutations
import types
import inspect

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
# Configure logging at import time with INFO as the level.
logging.basicConfig(level=logging.INFO)


# Shared registry filled in by setup_evaluation_environment() and read by the
# training/evaluation helpers below. Keys written there: "script_args",
# "device", "tokenizer", "model_config_params", "target_len", "train_dataset",
# "eval_dataset", "full_training_args_template", "minimal_eval_args_template".
EVAL_SETUP = {}


def wrapped_depth(fn):
    """Count how many ``__wrapped__`` links can be followed starting at ``fn``.

    Returns 0 when ``fn`` has no ``__wrapped__`` attribute at all.
    """
    missing = object()  # sentinel: distinguishes "no attribute" from any real value
    layers = 0
    current = fn
    while True:
        inner = getattr(current, "__wrapped__", missing)
        if inner is missing:
            return layers
        layers += 1
        current = inner


def strip_forward_wrappers(model):
    """
    Peel every ``__wrapped__`` layer off ``model.forward``.

    The innermost callable is bound to ``model`` when it carries no
    ``__self__`` (or a ``None`` one). ``model.forward`` is only reassigned
    when at least one wrapper layer was actually removed.

    Returns the number of wrapper layers that were followed.
    """
    missing = object()
    innermost = model.forward
    removed = 0

    # Walk the wrapper chain down to the innermost callable.
    while True:
        nxt = getattr(innermost, "__wrapped__", missing)
        if nxt is missing:
            break
        innermost = nxt
        removed += 1

    # A plain (unbound) function must be turned into a method of this instance.
    if getattr(innermost, "__self__", None) is None:
        innermost = types.MethodType(innermost, model)

    # Leave the model untouched when there was nothing to strip.
    if removed:
        model.forward = innermost

    return removed


def create_identity_permutation(target_len: int) -> List[int]:
    """Return the identity ordering ``[0, 1, ..., target_len - 1]`` as a new list."""
    return [*range(target_len)]


def setup_evaluation_environment(script_args: ScriptArguments, training_args: PermutationLossLoggingTrainingArguments) -> None:
    """
    Populate the module-level EVAL_SETUP registry once.

    Builds the tokenizer, the GPT-2 config parameters, the train/eval datasets
    and two training-argument templates (a copy of ``training_args`` for full
    training runs and a minimal one for internal evaluations), and stores them
    in EVAL_SETUP. Subsequent calls return immediately once
    "model_config_params" is present in EVAL_SETUP.

    Args:
        script_args: Script-level settings; this function reads
            ``max_seq_length``, ``gpt2_n_embd``, ``gpt2_n_layer``,
            ``gpt2_n_head``, ``target_len`` and ``dataset_path_prefix``.
        training_args: Trainer arguments; this function reads ``device``,
            ``output_dir``, ``per_device_eval_batch_size``,
            ``dataloader_num_workers`` and ``fp16``.

    Raises:
        ValueError: If the training or evaluation dataset is empty.
    """
    # Idempotence guard: "model_config_params" is written below, so its
    # presence means a previous call already got at least that far.
    if EVAL_SETUP and "model_config_params" in EVAL_SETUP:
        logger.info("Evaluation environment already set up.")
        return

    logger.info("Setting up evaluation environment...")
    EVAL_SETUP["script_args"] = script_args
    EVAL_SETUP["device"] = training_args.device  # Note: Trainer itself handles device placement based on args.

    # Project-local import performed lazily inside the function.
    from data.tokenizers import set_tokenizer, set_vocab

    # Vocabulary parameters are hard-coded here; their exact meaning is defined
    # by data.tokenizers.set_vocab (not visible in this file).
    vocab = set_vocab(
        0, field="ZZ", max_coeff=500, max_degree=1, continuous_coefficient=False, continuous_exponent=False
    )
    tokenizer = set_tokenizer(vocab)
    # Make sure an unknown-token is defined.
    if not hasattr(tokenizer, "unk_token") or tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": "[UNK]"})
    # Make sure a pad token is defined: reuse EOS when available, otherwise
    # fall back to a "[PAD]" token that is then registered as a special token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "[PAD]"
        if tokenizer.pad_token == "[PAD]":  # if it was newly added
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    EVAL_SETUP["tokenizer"] = tokenizer

    # Keyword arguments later expanded into GPT2Config(**model_config_params).
    # len(tokenizer) is taken after the special tokens above were added.
    model_config_params = {
        "vocab_size": len(tokenizer),
        "n_positions": script_args.max_seq_length,
        "n_ctx": script_args.max_seq_length,
        "n_embd": script_args.gpt2_n_embd,
        "n_layer": script_args.gpt2_n_layer,
        "n_head": script_args.gpt2_n_head,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    EVAL_SETUP["model_config_params"] = model_config_params
    EVAL_SETUP["target_len"] = script_args.target_len  # N

    # Training split lives at "<dataset_path_prefix>.train".
    train_file_path = f"{script_args.dataset_path_prefix}.train"
    train_dataset = TextContinuationDataset(
        tokenizer, train_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    if not train_dataset or len(train_dataset) == 0:
        raise ValueError("Training dataset is empty or could not be loaded.")
    EVAL_SETUP["train_dataset"] = train_dataset

    # Evaluation split lives at "<dataset_path_prefix>.test".
    eval_file_path = f"{script_args.dataset_path_prefix}.test"
    eval_dataset = TextContinuationDataset(
        tokenizer, eval_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    if not eval_dataset or len(eval_dataset) == 0:
        raise ValueError("Evaluation dataset is empty or could not be loaded.")
    EVAL_SETUP["eval_dataset"] = eval_dataset

    # Store a copy of the caller's arguments as the template for training runs.
    EVAL_SETUP["full_training_args_template"] = dataclasses.replace(training_args)

    # Adjusted output_dir name for clarity
    minimal_eval_output_dir = os.path.join(training_args.output_dir, "h_v5_internal_evals_temp")
    os.makedirs(minimal_eval_output_dir, exist_ok=True)

    # Stripped-down arguments used for the per-permutation evaluations; only
    # the settings listed here are carried over from training_args.
    minimal_eval_args_template = PermutationLossLoggingTrainingArguments(
        output_dir=minimal_eval_output_dir,
        per_device_eval_batch_size=training_args.per_device_eval_batch_size,
        dataloader_num_workers=training_args.dataloader_num_workers,
        report_to=[],  # Disable WandB for these internal, frequent evaluations
        fp16=training_args.fp16,
        remove_unused_columns=False,
    )
    EVAL_SETUP["minimal_eval_args_template"] = minimal_eval_args_template
    logger.info("Evaluation environment setup complete.")


def train_model_for_generation(
    generation_permutations_as_lists: List[List[int]], generation_num_str: str
) -> Tuple[GPT2LMHeadModel, int]:
    """
    Train a freshly initialised GPT-2 model on the given set of permutations.

    Duplicate permutations are collapsed (and sorted) before training. All
    shared objects (tokenizer, dataset, argument template, ...) are read from
    EVAL_SETUP, so setup_evaluation_environment() must have run first.

    Args:
        generation_permutations_as_lists: Candidate permutations, each a list of ints.
        generation_num_str: Label for this training step; used in log messages
            and in the name of the run's output directory.

    Returns:
        ``(model, n_unique)`` where ``model`` is the unwrapped trained model and
        ``n_unique`` the number of distinct permutations trained on. When there
        is nothing to train on, the untrained model and 0 are returned.

    Raises:
        Exception: Any error raised by training is logged and re-raised.
    """
    n_given = len(generation_permutations_as_lists)
    logger.info(
        f"Starting model training for step/gen '{generation_num_str}' using {n_given} permutations."
    )
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    config_kwargs = EVAL_SETUP["model_config_params"]
    tokenizer = EVAL_SETUP["tokenizer"]
    train_dataset = EVAL_SETUP["train_dataset"]
    args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["full_training_args_template"]
    target_len = EVAL_SETUP["target_len"]
    # NOTE: this writes to the shared template stored in EVAL_SETUP, not to a copy.
    args_template.dataloader_pin_memory = False

    # Brand-new model with randomly initialised weights for this step.
    model = GPT2LMHeadModel(GPT2Config(**config_kwargs))
    vocab_len = len(tokenizer)
    if model.config.vocab_size != vocab_len:
        model.resize_token_embeddings(vocab_len)

    # Each run gets its own directory; the random suffix keeps repeated labels apart.
    dir_suffix = random.randint(1000, 9999)
    run_dir = os.path.join(args_template.output_dir, f"train_step_{generation_num_str}_{dir_suffix}")
    os.makedirs(run_dir, exist_ok=True)

    run_args = dataclasses.replace(args_template)
    run_args.output_dir = run_dir
    run_args.logging_dir = os.path.join(run_dir, "logs")
    run_args.report_to = []  # keep these inner runs out of WandB
    run_args.remove_unused_columns = False

    # De-duplicate via hashable tuples, then go back to sorted lists.
    distinct_perms = [list(perm) for perm in sorted({tuple(p) for p in generation_permutations_as_lists})]
    n_distinct = len(distinct_perms)

    logger.info(
        f"Step/Gen '{generation_num_str}': Original {n_given} perms, "
        f"using {n_distinct} unique for training."
    )

    if not distinct_perms:
        logger.warning(
            f"Step/Gen '{generation_num_str}': No unique permutations to train on. Returning un-trained model."
        )
        return model, 0

    perm_matrices = perms_to_tensor_list(distinct_perms, target_len)
    collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=perm_matrices,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=bool(perm_matrices),
    )

    trainer = PermutationLossLoggingTrainer(
        model=model,
        args=run_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    try:
        logger.info(f"Training model for step/gen '{generation_num_str}' (output: {trainer.args.output_dir})...")
        trainer.train()
        logger.info(f"Model training completed for step/gen '{generation_num_str}'.")
        # Hand back the bare model, placed on the trainer's device.
        result_model = trainer.accelerator.unwrap_model(trainer.model)
        result_model.to(trainer.args.device)
        return result_model, n_distinct
    except Exception as e_train:
        logger.error(f"Error during model training for step/gen '{generation_num_str}': {e_train}", exc_info=True)
        raise  # a failed training run must stop the search


def evaluate_individual_on_trained_model(
    trained_model: GPT2LMHeadModel,
    individual_permutation_as_list: List[int],
    individual_id_str: str,
) -> Tuple[float,]:
    """
    Evaluate one permutation on an already-trained model and return its loss.

    Shared objects (tokenizer, eval dataset, argument template, target length)
    are read from EVAL_SETUP, so setup_evaluation_environment() must have run
    first. After evaluation, any ``__wrapped__`` layers around the model's
    ``forward`` are stripped so the model can be reused for the next candidate.

    Args:
        trained_model: Model to evaluate.
        individual_permutation_as_list: The permutation to evaluate, as a list of ints.
        individual_id_str: Label for this evaluation; used in the output
            directory name and as part of the metric key prefix.

    Returns:
        A 1-tuple ``(loss,)``. ``(inf,)`` is returned when no loss metric can
        be found or when evaluation raises; the error is logged, not re-raised.
    """
    logger.debug(f"Evaluating individual {individual_id_str} ({individual_permutation_as_list}) on pre-trained model.")
    script_args: ScriptArguments = EVAL_SETUP["script_args"]
    tokenizer = EVAL_SETUP["tokenizer"]
    eval_dataset = EVAL_SETUP["eval_dataset"]
    eval_args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["minimal_eval_args_template"]
    target_len = EVAL_SETUP["target_len"]

    # A single-element list: the collator is pinned to index 0 below.
    individual_perm_tensor = perms_to_tensor_list([individual_permutation_as_list], target_len)

    eval_data_collator_individual = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=individual_perm_tensor,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=False,
        fixed_permutation_index=0,  # Evaluate with this specific permutation
    )

    current_eval_args = dataclasses.replace(eval_args_template)
    current_eval_args.dataloader_pin_memory = False
    # Unique output dir for each evaluation to avoid conflicts
    eval_output_dir = os.path.join(
        current_eval_args.output_dir, f"eval_{individual_id_str}_{random.randint(1000,9999)}"
    )
    os.makedirs(eval_output_dir, exist_ok=True)
    current_eval_args.output_dir = eval_output_dir
    current_eval_args.logging_dir = os.path.join(eval_output_dir, "logs")
    current_eval_args.report_to = []  # Ensure no sub-WandB runs here
    eval_trainer = PermutationLossLoggingTrainer(
        model=trained_model,
        args=current_eval_args,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=eval_data_collator_individual,
    )

    try:
        # Use a clear metric key prefix for this specific evaluation
        metric_key = f"h_v5_eval_ind_{individual_id_str}"
        metrics = eval_trainer.evaluate(metric_key_prefix=metric_key)

        # 1. Try the default key from trainer.evaluate().
        loss = metrics.get("eval_loss")

        # Undo any forward() wrapping added while evaluating, so the same
        # model object can be handed to the next evaluation.
        unwrapped_model = eval_trainer.accelerator.unwrap_model(eval_trainer.model)
        strip_forward_wrappers(unwrapped_model)

        # 2. Fallback for the prefixed key, e.g. "h_v5_eval_ind_..._loss"
        if loss is None:
            loss = metrics.get(f"{metric_key}_loss")

        # 3. Generic fallback: first key ending with '_loss'
        if loss is None:
            logger.warning(f"'eval_loss' and custom loss keys not found. Metrics: {metrics}. Trying generic fallback.")
            for k_metric in metrics:
                if k_metric.endswith("_loss"):
                    loss = metrics[k_metric]
                    logger.info(f"Found alternative loss key: {k_metric} with value {loss}")
                    break

        if loss is None:
            logger.error(f"No loss key could be found in metrics: {metrics}. Returning inf.")
            return (float("inf"),)

        logger.debug(f"Individual {individual_id_str}, Loss: {loss:.4f}")
        return (loss,)
    except Exception as e_eval:
        # Best-effort: a failed evaluation is logged with its traceback and
        # scored as inf. No interactive breakpoint here, so unattended runs
        # keep going instead of blocking on a debugger prompt.
        logger.error(f"Error during evaluation of individual {individual_id_str}: {e_eval}", exc_info=True)
        return (float("inf"),)


def perms_to_tensor_list(perms_list_of_lists: List[List[int]], target_len: int) -> List[torch.Tensor]:
    """
    Turn each permutation (a list of ints) into a float32 matrix of shape
    ``(target_len, target_len)`` with a 1.0 at ``[i, perm[i]]`` for every i.

    Raises:
        ValueError: If a permutation's length differs from ``target_len`` or
            one of its entries lies outside ``[0, target_len)``.
    """

    def _to_matrix(perm: List[int]) -> torch.Tensor:
        # Build the matrix for one permutation, validating as we go.
        if len(perm) != target_len:
            raise ValueError(f"Permutation length {len(perm)} does not match target_len {target_len}")
        mat = torch.zeros((target_len, target_len), dtype=torch.float32)
        for row, col in enumerate(perm):
            if not (0 <= col < target_len):
                raise ValueError(f"Permutation element {col} at index {row} is out of bounds for target_len {target_len}")
            mat[row, col] = 1.0  # row i gets its single 1 in column perm[i]
        return mat

    return [_to_matrix(perm) for perm in perms_list_of_lists]


# --- Helper functions for v5 (Original versions - will be superseded by V2) ---
def get_block_from_permutation(permutation: List[int], block_idx: int, block_length: int) -> List[int]:
    """Return the ``block_idx``-th slice of ``permutation`` when it is cut into chunks of ``block_length``."""
    first = block_idx * block_length
    return permutation[first : first + block_length]


def generate_inter_block_permutations(
    base_permutation: List[int], block_length: int, num_elements: int  # N
) -> List[List[int]]:
    """
    Enumerate every reordering of the equal-sized blocks of ``base_permutation``.

    The permutation is cut into ``num_elements // block_length`` blocks of
    ``block_length`` (k) elements; the blocks are shuffled as whole units and
    the order inside each block is left untouched. When ``num_elements`` (N)
    is not a multiple of ``block_length`` or there is at most one block, the
    result is just a copy of ``base_permutation``.
    """
    block_count = num_elements // block_length
    if num_elements % block_length != 0 or block_count <= 1:
        # Nothing meaningful to reorder.
        return [list(base_permutation)]

    # Blocks as they currently appear, internal order preserved.
    blocks = [get_block_from_permutation(base_permutation, idx, block_length) for idx in range(block_count)]

    # One output per ordering of block indices: (N/k)! results in total.
    return [
        list(itertools.chain.from_iterable(blocks[idx] for idx in ordering))
        for ordering in itertools.permutations(range(block_count))
    ]


# --- End Helper functions for v5 (Original versions) ---


# --- Helper functions for v5 (New V2 versions for flexible block sizes) ---
def get_block_from_permutation_v2(permutation: List[int], block_idx: int, actual_block_lengths: List[int]) -> List[int]:
    """
    Return the slice of ``permutation`` covered by block ``block_idx`` when the
    blocks have the (possibly unequal) sizes listed in ``actual_block_lengths``.

    Raises:
        ValueError: If ``block_idx`` is not a valid index into ``actual_block_lengths``.
    """
    if not 0 <= block_idx < len(actual_block_lengths):
        raise ValueError(f"block_idx {block_idx} is out of bounds for {len(actual_block_lengths)} blocks.")

    # The block starts right after all preceding blocks.
    begin = sum(actual_block_lengths[:block_idx])
    return permutation[begin : begin + actual_block_lengths[block_idx]]


def set_block_in_permutation_v2(
    base_permutation: List[int], block_idx: int, sub_permutation_block: List[int], actual_block_lengths: List[int]
) -> List[int]:
    """
    Return a copy of ``base_permutation`` in which block ``block_idx`` has been
    overwritten with ``sub_permutation_block``. Block boundaries come from
    ``actual_block_lengths``; the input permutation is not modified.

    Raises:
        ValueError: If ``block_idx`` is out of range, or the replacement's
            length differs from the size of the targeted block.
    """
    if not 0 <= block_idx < len(actual_block_lengths):
        raise ValueError(f"block_idx {block_idx} is out of bounds for {len(actual_block_lengths)} blocks.")

    result = list(base_permutation)
    offset = sum(actual_block_lengths[:block_idx])
    expected_len = actual_block_lengths[block_idx]

    if len(sub_permutation_block) != expected_len:
        raise ValueError(
            f"Sub-permutation length {len(sub_permutation_block)} must match expected block length {expected_len} for block_idx {block_idx}"
        )

    # Element-wise writes (not slice assignment) so the list can never grow.
    for pos, value in enumerate(sub_permutation_block):
        result[offset + pos] = value
    return result


def generate_intra_block_permutations_v2(
    base_permutation: List[int], block_idx_to_permute: int, actual_block_lengths: List[int]
) -> List[List[int]]:
    """
    Enumerate every reordering of the elements inside one block.

    Only block ``block_idx_to_permute`` (sized per ``actual_block_lengths``) is
    shuffled; all other positions keep their values from ``base_permutation``.
    An invalid block index is logged as an error and, like a block of length
    0 or 1, yields just a copy of ``base_permutation``.
    """
    index_ok = bool(actual_block_lengths) and 0 <= block_idx_to_permute < len(actual_block_lengths)
    if not index_ok:
        logger.error(
            f"Invalid block_idx_to_permute {block_idx_to_permute} for actual_block_lengths: {actual_block_lengths}"
        )
        return [list(base_permutation)]

    # A block with fewer than two elements has only one arrangement.
    if actual_block_lengths[block_idx_to_permute] <= 1:
        return [list(base_permutation)]

    block_items = get_block_from_permutation_v2(base_permutation, block_idx_to_permute, actual_block_lengths)

    # One full permutation per arrangement of the block's elements.
    return [
        set_block_in_permutation_v2(base_permutation, block_idx_to_permute, list(arrangement), actual_block_lengths)
        for arrangement in itertools.permutations(block_items)
    ]


def generate_inter_block_permutations_v2(
    base_permutation: List[int], actual_block_lengths: List[int]
) -> List[List[int]]:
    """
    Enumerate every reordering of whole blocks of ``base_permutation``.

    Block boundaries are given by ``actual_block_lengths`` (sizes may differ).
    Blocks move as units; the order of elements inside each block is kept.
    With zero or one block, the result is just a copy of ``base_permutation``.
    """
    block_count = len(actual_block_lengths)
    if block_count <= 1:
        return [list(base_permutation)]

    # Blocks as they currently appear, internal order preserved.
    blocks = [
        get_block_from_permutation_v2(base_permutation, idx, actual_block_lengths) for idx in range(block_count)
    ]

    # Concatenate the blocks once for every ordering of their indices.
    return [
        list(itertools.chain.from_iterable(blocks[idx] for idx in ordering))
        for ordering in itertools.permutations(range(block_count))
    ]


# --- End Helper functions for v5 (V2 versions) ---


def hierarchical_permutation_search_v5(
    initial_permutation_list: List[int],
    target_len: int,  # N (also script_args.target_len)
    script_args: ScriptArguments,  # training_args are accessed via EVAL_SETUP
):
    logger.info(
        f"Starting hierarchical permutation search (v5) with initial_permutation={initial_permutation_list}, N (target_len)={target_len}"
    )

    current_base_permutation = list(initial_permutation_list)  # This is the permutation being optimized
    best_permutation_overall = list(initial_permutation_list)  # Tracks the best perm found across all k

    # Initial evaluation of the starting permutation
    # Train a model just with this single permutation to get a baseline loss
    temp_model_for_initial_eval, num_unique_init_train = train_model_for_generation(
        generation_permutations_as_lists=[current_base_permutation], generation_num_str="h_v5_initial_train"
    )
    if num_unique_init_train == 0 and len(current_base_permutation) > 0:
        logger.warning("Initial training (v5) did not use any permutations. Evaluation might be on an untrained model.")

    initial_eval_loss_tuple = evaluate_individual_on_trained_model(
        trained_model=temp_model_for_initial_eval,
        individual_permutation_as_list=current_base_permutation,
        individual_id_str="h_v5_initial_eval",
    )
    lowest_loss_overall = (
        initial_eval_loss_tuple[0]
        if initial_eval_loss_tuple and initial_eval_loss_tuple[0] is not None
        else float("inf")
    )

    logger.info(f"H-V5 Initial: Permutation {current_base_permutation}, Loss: {lowest_loss_overall:.4f}")
    if wandb.run:
        wandb.log(
            {
                "h_v5_initial_loss": lowest_loss_overall,
                "h_v5_initial_permutation_str": str(current_base_permutation),  # Log as string
            }
        )

    all_layers_data = []  # To store detailed logs for each k

    # k_intended_length iterates from 2 up to N/2 (inclusive)
    # This k is the "intended" or "minimum" length for most blocks.
    for k_intended_length in range(2, (target_len // 2) + 1):
        # for k_intended_length in range(2, 8):
        # breakpoint()
        # Calculate actual block lengths based on k_intended_length and target_len (N)

        if k_intended_length <= 0:
            logger.warning(f"H-V5: k_intended_length {k_intended_length} is invalid. Skipping.")
            continue
        if target_len <= 0:
            logger.warning(f"H-V5: target_len {target_len} is invalid. Skipping k loop.")
            break

        # New logic for actual_block_lengths
        actual_block_lengths = []
        if target_len > 0 and k_intended_length > 0:
            num_blocks_of_k_length = target_len // k_intended_length
            remainder = target_len % k_intended_length

            for _ in range(num_blocks_of_k_length):
                actual_block_lengths.append(k_intended_length)

            if remainder > 0:
                actual_block_lengths.append(remainder)

            # If the k_intended_length loop somehow allows k_intended_length > target_len
            # (e.g., for very small target_len where target_len // 2 < 2),
            # and num_blocks_of_k_length becomes 0, actual_block_lengths would be [remainder], which is target_len.
            # This forms a single block of target_len, which is a reasonable fallback.
            # However, if after this, actual_block_lengths is still empty for target_len > 0, it's an issue.
            if not actual_block_lengths and target_len > 0:
                # This case should ideally be covered if k_intended_length is always > 0.
                # If target_len > 0, and k_intended_length > 0, then either num_blocks_of_k_length >=1 or remainder > 0 (or both if k < N)
                # If k > N, num_blocks_of_k_length = 0, remainder = N, so blocks = [N]
                # So, this condition (empty blocks for N>0) should not be met if k_intended_length > 0.
                # For safety, if it implies one block was intended:
                logger.warning(
                    f"H-V5 k_intended={k_intended_length}: Calculated empty blocks for target_len={target_len}. Defaulting to one block."
                )
                actual_block_lengths = [target_len]

        # Sanity checks for block lengths
        if sum(actual_block_lengths) != target_len:
            # This check is crucial. If target_len is 0, sum should be 0 for actual_block_lengths = [].
            logger.error(
                f"H-V5 k_intended={k_intended_length}: Sum of calculated block lengths {sum(actual_block_lengths)} ({actual_block_lengths}) does not match target_len {target_len}. Skipping this k."
            )
            continue

        # if len(set(actual_block_lengths)) != 1:
        #     continue

        num_actual_total_blocks = len(actual_block_lengths)

        if num_actual_total_blocks == 0 and target_len > 0:
            logger.error(
                f"H-V5 k_intended={k_intended_length}: num_actual_total_blocks is 0 for N={target_len} > 0. This implies an issue in block calculation. Skipping k."
            )
            continue
        if num_actual_total_blocks == 0 and target_len == 0:  # No blocks, no work for this k
            logger.info(f"H-V5 k_intended={k_intended_length}: No blocks as N=0. Skipping k processing for this k.")
            continue
        if (
            num_actual_total_blocks == 1
            and target_len > 0
            and target_len <= k_intended_length
            and k_intended_length > (target_len // 2)
        ):
            # If k_intended_length is large (e.g. k=3, N=2 or N=3 leading to single block [2] or [3])
            # and this k is outside the primary loop range (e.g. N=3, k_loop is empty, but if k was passed directly)
            # or if k simply results in one block (e.g. N=5, k=4 or k=5 -> block [5])
            # Skip if this k results in only one block and k is not a small factor defining multiple blocks.
            # The main loop is for k up to N/2, so k resulting in 1 block implies k > N/2.
            # This prevents trying to permute a single block with itself or do inter-block with one block.
            if num_actual_total_blocks == 1 and k_intended_length > (target_len // 2) and target_len > 1:
                logger.info(
                    f"H-V5 k_intended={k_intended_length}: Results in a single block for N={target_len} and k > N/2. Skipping this k as trivial."
                )
                continue

        logger.info(
            f"\n--- H-V5 Processing Layer k_intended_length = {k_intended_length}, num_actual_blocks = {num_actual_total_blocks}, actual_lengths = {actual_block_lengths} ---"
        )

        current_k_layer_data_log = {
            "k_intended_length": k_intended_length,
            "num_actual_blocks": num_actual_total_blocks,
            "actual_lengths": actual_block_lengths,
            "permutation_at_start_of_k_layer": list(current_base_permutation),  # Perm before this k's optimization
            "intra_block_optimization_details": [],
            "inter_block_optimization_details": {},
            "permutation_after_k_layer": [],  # Perm after this k's optimization
            "loss_after_k_layer": float("inf"),
        }

        # --- 1. Intra-block optimization phase ---
        logger.info(f"H-V5 k={k_intended_length}: Starting Intra-block optimization phase.")
        # Start this phase with the current best permutation from previous k or initial
        perm_being_optimized_in_intra_phase = list(current_base_permutation)

        for block_idx in range(num_actual_total_blocks):  # Iterate through each block to optimize its internal order
            intra_block_step_log = {"block_idx": block_idx, "status": "pending"}
            logger.info(
                f"H-V5 k={k_intended_length}: Intra-block optimizing block {block_idx+1}/{num_actual_total_blocks}."
            )
            if actual_block_lengths[block_idx] <= 1:
                logger.info(
                    f"H-V5 k={k_intended_length}, block {block_idx}: Block length is {actual_block_lengths[block_idx]}, skipping intra-block optimization."
                )
                continue  # No optimization needed for blocks of length 1 or less

            # Generate k! candidate permutations by permuting only the current block_idx
            # The base for this generation is the perm_being_optimized_in_intra_phase, which updates after each block
            candidate_perms_for_this_block_step = generate_intra_block_permutations_v2(
                base_permutation=perm_being_optimized_in_intra_phase,
                block_idx_to_permute=block_idx,
                actual_block_lengths=actual_block_lengths,
            )

            if not candidate_perms_for_this_block_step:
                logger.warning(
                    f"H-V5 k={k_intended_length}, block {block_idx}: No candidates for intra-block step. Skipping this block's optimization."
                )
                intra_block_step_log.update({"status": "skipped_no_candidates", "num_candidates": 0})
                current_k_layer_data_log["intra_block_optimization_details"].append(intra_block_step_log)
                continue  # perm_being_optimized_in_intra_phase remains unchanged for this block

            intra_block_step_log["num_candidates"] = len(candidate_perms_for_this_block_step)
            logger.info(
                f"H-V5 k={k_intended_length}, block {block_idx}: Generated {len(candidate_perms_for_this_block_step)} candidates for this block."
            )

            # Train a model using these k! (or fewer if k is large and list is pre-filtered) candidates
            train_id_intra_step = f"h_v5_k{k_intended_length}_intra_b{block_idx}_train"
            trained_model_intra_step, num_unique_trained_intra_step = train_model_for_generation(
                candidate_perms_for_this_block_step, train_id_intra_step
            )
            intra_block_step_log["num_unique_trained_on"] = num_unique_trained_intra_step
            if num_unique_trained_intra_step == 0 and len(candidate_perms_for_this_block_step) > 0:
                logger.warning(
                    f"H-V5 k={k_intended_length}, block {block_idx}: Intra-block training step used 0 unique perms."
                )

            # Evaluate each of the k! candidate permutations using the just-trained model
            evaluated_results_intra_this_block_step = []
            for i, p_candidate in enumerate(candidate_perms_for_this_block_step):
                eval_id_intra_step = f"h_v5_k{k_intended_length}_intra_b{block_idx}_eval_p{i}"
                loss_tuple = evaluate_individual_on_trained_model(
                    trained_model_intra_step, p_candidate, eval_id_intra_step
                )
                evaluated_results_intra_this_block_step.append(
                    {"permutation": p_candidate, "loss": loss_tuple[0] if loss_tuple else float("inf")}
                )
                # breakpoint()

            if not evaluated_results_intra_this_block_step:
                logger.warning(
                    f"H-V5 k={k_intended_length}, block {block_idx}: No evaluation results for intra-block step. Block not updated."
                )
                intra_block_step_log.update({"status": "skipped_no_eval_results"})
                current_k_layer_data_log["intra_block_optimization_details"].append(intra_block_step_log)
                continue

            evaluated_results_intra_this_block_step.sort(key=lambda x: x["loss"])
            best_full_perm_after_this_block_opt = evaluated_results_intra_this_block_step[0]["permutation"]
            best_loss_for_this_block_step = evaluated_results_intra_this_block_step[0]["loss"]

            # IMPORTANT: Update perm_being_optimized_in_intra_phase with the best configuration found for this block.
            # This new configuration becomes the base for optimizing the next block.
            perm_being_optimized_in_intra_phase = list(best_full_perm_after_this_block_opt)

            logger.info(
                f"H-V5 k={k_intended_length}, block {block_idx}: Intra-block step best loss {best_loss_for_this_block_step:.4f}. Updated base perm for next block/phase."
            )
            intra_block_step_log.update(
                {
                    "status": "optimized",
                    "best_loss_for_block_step": best_loss_for_this_block_step,
                    "perm_config_after_block_opt": list(perm_being_optimized_in_intra_phase),  # Log the full perm
                }
            )
            current_k_layer_data_log["intra_block_optimization_details"].append(intra_block_step_log)

        logger.info(
            f"H-V5 k={k_intended_length}: Intra-block optimization phase finished. Resulting perm: {perm_being_optimized_in_intra_phase}"
        )
        # perm_after_intra_block_opt_phase is now the permutation with all its blocks internally optimized.

        # --- 2. Inter-block optimization phase ---
        logger.info(
            f"H-V5 k={k_intended_length}: Starting Inter-block optimization phase using perm from intra-block: {perm_being_optimized_in_intra_phase}"
        )
        inter_block_phase_log = {"status": "pending"}

        # Generate (N/k)! candidate permutations by reordering the blocks of perm_being_optimized_in_intra_phase
        candidate_perms_for_inter_block_phase = generate_inter_block_permutations_v2(
            base_permutation=perm_being_optimized_in_intra_phase, actual_block_lengths=actual_block_lengths
        )
        # breakpoint()

        if not candidate_perms_for_inter_block_phase:
            logger.warning(
                f"H-V5 k={k_intended_length}: No candidates for inter-block phase. Permutation remains {perm_being_optimized_in_intra_phase}."
            )
            # current_base_permutation for next k layer will be the result from intra-block phase
            current_base_permutation = list(perm_being_optimized_in_intra_phase)
            inter_block_phase_log.update({"status": "skipped_no_candidates", "num_candidates": 0})
            # Need to evaluate current_base_permutation to get a loss for this k-layer if we skipped inter-block eval
            # For simplicity, we might just use the loss from the last step of intra-block, or re-evaluate.
            # Let's try to re-evaluate current_base_permutation to have a consistent loss for the layer.
            temp_model_for_layer_eval, _ = train_model_for_generation(
                [current_base_permutation], f"h_v5_k{k_intended_length}_layer_eval_train_skip"
            )
            layer_loss_tuple = evaluate_individual_on_trained_model(
                temp_model_for_layer_eval, current_base_permutation, f"h_v5_k{k_intended_length}_layer_eval_skip"
            )
            current_k_layer_data_log["loss_after_k_layer"] = layer_loss_tuple[0] if layer_loss_tuple else float("inf")

        else:
            inter_block_phase_log["num_candidates"] = len(candidate_perms_for_inter_block_phase)
            logger.info(
                f"H-V5 k={k_intended_length}: Generated {len(candidate_perms_for_inter_block_phase)} candidates for inter-block phase."
            )

            train_id_inter_phase = f"h_v5_k{k_intended_length}_inter_train"
            trained_model_inter_phase, num_unique_trained_inter_phase = train_model_for_generation(
                candidate_perms_for_inter_block_phase, train_id_inter_phase
            )
            inter_block_phase_log["num_unique_trained_on"] = num_unique_trained_inter_phase
            if num_unique_trained_inter_phase == 0 and len(candidate_perms_for_inter_block_phase) > 0:
                logger.warning(f"H-V5 k={k_intended_length}: Inter-block training phase used 0 unique perms.")

            evaluated_results_inter_block_phase = []
            for i, p_candidate in enumerate(candidate_perms_for_inter_block_phase):
                eval_id_inter_phase = f"h_v5_k{k_intended_length}_inter_eval_p{i}"
                loss_tuple = evaluate_individual_on_trained_model(
                    trained_model_inter_phase, p_candidate, eval_id_inter_phase
                )
                # breakpoint()
                evaluated_results_inter_block_phase.append(
                    {"permutation": p_candidate, "loss": loss_tuple[0] if loss_tuple else float("inf")}
                )

            if not evaluated_results_inter_block_phase:
                logger.warning(
                    f"H-V5 k={k_intended_length}: No evaluation results for inter-block phase. Permutation remains from intra-block phase."
                )
                current_base_permutation = list(perm_being_optimized_in_intra_phase)
                inter_block_phase_log.update({"status": "skipped_no_eval_results"})
                # Similar to above, evaluate current_base_permutation for layer loss
                temp_model_for_layer_eval, _ = train_model_for_generation(
                    [current_base_permutation], f"h_v5_k{k_intended_length}_layer_eval_train_nores"
                )
                layer_loss_tuple = evaluate_individual_on_trained_model(
                    temp_model_for_layer_eval, current_base_permutation, f"h_v5_k{k_intended_length}_layer_eval_nores"
                )
                current_k_layer_data_log["loss_after_k_layer"] = (
                    layer_loss_tuple[0] if layer_loss_tuple else float("inf")
                )

            else:
                evaluated_results_inter_block_phase.sort(key=lambda x: x["loss"])
                best_perm_after_inter_block_phase = evaluated_results_inter_block_phase[0]["permutation"]
                best_loss_after_inter_block_phase = evaluated_results_inter_block_phase[0]["loss"]

                # This is the final permutation for this k_intended_length layer, to be used for next k or as final result
                current_base_permutation = list(best_perm_after_inter_block_phase)
                current_k_layer_data_log["loss_after_k_layer"] = best_loss_after_inter_block_phase

                logger.info(
                    f"H-V5 k={k_intended_length}: Inter-block phase finished. Best loss {best_loss_after_inter_block_phase:.4f}."
                )
                inter_block_phase_log.update(
                    {
                        "status": "optimized",
                        "best_loss_inter_block_phase": best_loss_after_inter_block_phase,
                    }
                )

        inter_block_phase_log["perm_after_inter_block_phase"] = list(
            current_base_permutation
        )  # Log the result of inter-block
        current_k_layer_data_log["inter_block_optimization_details"] = inter_block_phase_log
        current_k_layer_data_log["permutation_after_k_layer"] = list(current_base_permutation)

        all_layers_data.append(current_k_layer_data_log)

        # Update overall best permutation and loss found so far
        loss_for_this_k_layer = current_k_layer_data_log["loss_after_k_layer"]
        if loss_for_this_k_layer < lowest_loss_overall:
            lowest_loss_overall = loss_for_this_k_layer
            best_permutation_overall = list(
                current_base_permutation
            )  # current_base_permutation is the best after this k layer
            logger.info(
                f"H-V5 k={k_intended_length}: New overall best found: Loss {lowest_loss_overall:.4f}, Permutation {best_permutation_overall}"
            )
            if wandb.run:
                wandb.log(
                    {
                        "h_v5_overall_best_loss": lowest_loss_overall,
                        "h_v5_overall_best_permutation_str": str(best_permutation_overall),
                        "h_v5_k_for_overall_best": k_intended_length,
                    }
                )

        if wandb.run:  # Log details for the completed k-layer
            wandb.log(
                {
                    f"h_v5_k{k_intended_length}_loss_after_layer": loss_for_this_k_layer,
                    f"h_v5_k{k_intended_length}_perm_after_layer_str": str(current_base_permutation),
                    # f"h_v5_k{k_intended_length}_num_intra_candidates_total": sum(d.get("num_candidates",0) for d in current_k_layer_data_log["intra_block_optimization_details"]),
                    # f"h_v5_k{k_intended_length}_num_inter_candidates": current_k_layer_data_log["inter_block_optimization_details"].get("num_candidates",0),
                }
            )

    # End of H-V5 search loop for k_intended_length
    logger.info("\n--- Hierarchical Search V5 Finished ---")
    final_results_summary = {
        "search_version": "v5",
        "initial_permutation": initial_permutation_list,
        "best_permutation_overall": best_permutation_overall,
        "lowest_loss_overall": lowest_loss_overall,
        "target_len_N": target_len,
        "all_k_layers_data": all_layers_data,  # Contains detailed logs for each k
    }

    if best_permutation_overall:
        logger.info(f"Overall best permutation (v5): {best_permutation_overall}")
        logger.info(f"Corresponding lowest loss (v5): {lowest_loss_overall:.4f}")
    else:
        logger.warning("No solution found or search was interrupted (v5). Best permutation might be initial.")

    return final_results_summary


def _to_json_serializable(obj):
    """Recursively convert *obj* into something ``json.dump`` can write.

    NumPy scalars become Python scalars, NumPy arrays and torch tensors become
    (nested) lists, dataclass instances become dicts, and dicts / lists /
    tuples are walked recursively. Anything else is returned unchanged.

    Non-finite floats (e.g. ``float("inf")`` losses) are passed through as-is;
    with its default settings ``json.dump`` writes them as ``Infinity``/``NaN``.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # torch.Tensor should ideally be converted to lists/numbers before this stage
    if isinstance(obj, torch.Tensor):
        logger.warning("Serializing torch.Tensor found in results; should be list/float.")
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: _to_json_serializable(v) for k, v in obj.items()}
    # Tuples are emitted as JSON arrays anyway; walk them so nested NumPy values are converted too.
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(i) for i in obj]
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # asdict() only unwraps dataclasses/containers; recurse so field values get converted as well.
        return _to_json_serializable(dataclasses.asdict(obj))
    return obj  # Default for other types


def main():
    """Run the hierarchical permutation search (v5) end to end.

    Steps:
      1. Parse command-line arguments into training and script argument dataclasses.
      2. Call ``setup_evaluation_environment`` (populates ``EVAL_SETUP``).
      3. Optionally start a WandB run when ``script_args.wandb_project`` is set.
      4. Run ``hierarchical_permutation_search_v5`` from a hard-coded initial permutation.
      5. Write the results as JSON under ``training_args.output_dir`` and, if a
         WandB run is active, upload the file there too.

    Failures while saving are logged rather than raised.
    """
    hf_parser = HfArgumentParser((PermutationLossLoggingTrainingArguments, ScriptArguments))
    training_args, script_args = hf_parser.parse_args_into_dataclasses()

    # --- Setup ---
    setup_evaluation_environment(script_args, training_args)  # Populates EVAL_SETUP
    target_len_N_for_v5 = script_args.target_len  # N for v5 algorithm

    # Initial permutation for v5.
    # Alternative starting points (commented out); the trailing labels are the original author's.
    # initial_permutation = list(range(target_len_N_for_v5))
    # random.shuffle(initial_permutation)
    # logger.info(f"Using initial random permutation for v5 search: {initial_permutation}")
    # initial_permutation = [5, 8, 9, 7, 6, 4, 3, 2, 1, 0]
    # initial_permutation = [3,4,5,6,7,8,9,1,0,2]
    # initial_permutation = [6, 7, 8, 9, 5, 4, 3, 2, 1, 0]

    # initial_permutation = [6, 0, 5, 2, 3, 4, 1]  # n=7, relu
    # initial_permutation = [0, 3, 2, 1, 4, 5, 6, 7] # n=8, relu
    # initial_permutation = [0, 7, 6, 5, 4, 2, 3, 1, 8] # n=9, relu
    # initial_permutation = [8, 9, 10, 7, 6, 5, 4, 3, 2, 1, 0]  # n=11, relu
    # initial_permutation = [6, 7, 8, 9, 10, 11, 5, 4, 2, 3, 1, 0]  # n=12, relu
    # initial_permutation = [11, 12, 10, 9, 8, 7, 6, 5, 4, 2, 3, 1, 0]  # n=13, relu
    # initial_permutation = [1, 2, 4, 5, 0, 6, 7, 3]  # n=8, square mod
    # initial_permutation = [10, 11, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5]  # n=12, square mod
    # initial_permutation = [0, 1, 2, 3, 12, 11, 10, 4, 5, 6, 7, 8, 9] # n=13, square mod
    # initial_permutation = [0, 1, 5, 6, 7, 8, 9, 2, 3, 4, 10, 11, 12, 13] # n=14, square mod

    # initial_permutation = [0, 1, 7, 6, 4, 2, 5, 8, 3, 9, 10, 11, 12]  # n=13, m=4, index
    initial_permutation = [1, 2, 3, 4, 5, 6, 7, 8, 10, 9, 12, 0, 11]  # n=13, m=8, index

    # The starting point is hard-coded while N comes from the arguments, so flag a
    # mismatch early. Only a warning: how the search treats a mismatch is decided there.
    if isinstance(target_len_N_for_v5, int) and sorted(initial_permutation) != list(range(target_len_N_for_v5)):
        logger.warning(
            f"Initial permutation {initial_permutation} is not a permutation of "
            f"range({target_len_N_for_v5}); check that target_len matches the hard-coded initial permutation."
        )

    # --- WandB Initialization (if configured) ---
    if script_args.wandb_project:
        # Construct a descriptive run name for WandB
        timestamp_str = datetime.datetime.now().strftime("%y%m%d_%H%M")
        run_name = (
            script_args.wandb_run_name if script_args.wandb_run_name else f"v5_N{target_len_N_for_v5}_{timestamp_str}"
        )

        # Merge configs for logging (training_args keys win on a name clash)
        config_for_wandb = {
            **dataclasses.asdict(script_args),
            **dataclasses.asdict(training_args),
            "search_algorithm_version": "v5",
            "initial_permutation_used": str(initial_permutation),  # Log initial perm as string
        }

        wandb.init(
            project=script_args.wandb_project,
            name=run_name,
            config=config_for_wandb,
        )
        logger.info(f"WandB initialized for project '{script_args.wandb_project}' and run '{run_name}'.")

    logger.info(f"Starting Hierarchical Permutation Search v5 with N (target_len) = {target_len_N_for_v5}")

    # --- Execute v5 Search ---
    search_results_v5 = hierarchical_permutation_search_v5(
        initial_permutation_list=initial_permutation,
        target_len=target_len_N_for_v5,
        script_args=script_args,
        # training_args are accessed via EVAL_SETUP where needed
    )

    # --- Save Results ---
    if search_results_v5:
        output_dir = training_args.output_dir  # Main output directory from training_args

        timestamp_file = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Filename reflects it's v5 and includes N (target_len)
        filename = f"hierarchical_search_v5_N{target_len_N_for_v5}_results_{timestamp_file}.json"
        filepath = os.path.join(output_dir, filename)

        try:
            # exist_ok makes a separate existence check unnecessary.
            os.makedirs(output_dir, exist_ok=True)

            serializable_results = _to_json_serializable(search_results_v5)
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(serializable_results, f, indent=4)
            logger.info(f"Successfully saved detailed results of v5 search to {filepath}")

            if wandb.run:  # If WandB is active, try to save the results file there too
                try:
                    # Use base_path to store it relative to the run's directory structure in WandB
                    wandb.save(filepath, base_path=output_dir, policy="now")
                    logger.info(f"Saved results file {filepath} to WandB.")
                except Exception as e_wandb_save:
                    logger.error(f"Failed to save results file {filepath} to WandB: {e_wandb_save}")

        except Exception as e_serial:
            # Best-effort save: log with traceback instead of crashing after a long search.
            logger.error(f"Could not serialize and save v5 results to {filepath}: {e_serial}", exc_info=True)
    else:
        logger.warning("Search v5 did not produce any results to save.")

    if wandb.run:  # Ensure WandB run is finished if it was initialized
        wandb.finish()


# Run the v5 search only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
