import os
import pandas as pd
from pathlib import Path
import pprint

# ==============================================================================
# Part 1: Configuration (User Settings)
# ==============================================================================
# --- (REQUIRED) Set this to the base output directory of the first script ---
# Expected layout (as read by main() and find_final_analysis_csv()):
#   <BASE_RESULTS_DIR>/<dataset>/{mlp,attention,both}/<RUN_PARAMS_SUBDIR>/analysis/epoch_<N>_analysis.csv
BASE_RESULTS_DIR = "./output/stage_1_sweep_results"

# --- (OPTIONAL) Configure the output directory and other parameters ---
# One "<dataset>_experiments.py" file is written here per dataset.
OUTPUT_DIR = "./output/stage_2_generated_configs"
# Number of top-ranked layers used for the Top-N ensembles; also the window
# size used when searching for the strongest consecutive block of layers.
NUM_TOP_LAYERS = 3
# This should match the hyperparameter string from the first script's output folders
RUN_PARAMS_SUBDIR = "lr_1e-03_bs_8_rank_16"

# ==============================================================================
# Part 2: The "Scout" - File Discovery Utilities
# ==============================================================================

def find_final_analysis_csv(base_path: Path) -> Path | None:
    """Finds the analysis CSV from the last epoch in a directory."""
    analysis_dir = base_path / "analysis"
    if not analysis_dir.exists():
        print(f"  [Warning] Analysis directory not found in: {base_path}")
        return None

    csv_files = list(analysis_dir.glob("epoch_*_analysis.csv"))
    if not csv_files:
        print(f"  [Warning] No analysis CSVs found in: {analysis_dir}")
        return None

    # Find the latest epoch by parsing the filename
    latest_epoch = -1
    latest_file = None
    for f in csv_files:
        try:
            epoch_num = int(f.stem.split("_")[1])
            if epoch_num > latest_epoch:
                latest_epoch = epoch_num
                latest_file = f
        except (ValueError, IndexError):
            continue

    return latest_file

# ==============================================================================
# Part 3: The "Analyst" - Data Processing Utilities
# ==============================================================================

def get_ranked_layers_from_csv(csv_path: Path) -> list[int]:
    """Return layer numbers ordered from highest to lowest ``total_frobenius_norm``.

    The CSV must contain ``layer_num`` and ``total_frobenius_norm`` columns.
    Layers with equal norms are ordered by ascending ``layer_num`` so the ranking
    is deterministic regardless of row order. Rows with a missing norm sort last
    (pandas' default ``na_position``).

    Args:
        csv_path: Path to an analysis CSV.

    Returns:
        Layer numbers, strongest first.
    """
    df = pd.read_csv(csv_path)
    df_sorted = df.sort_values(
        by=["total_frobenius_norm", "layer_num"],
        ascending=[False, True],  # strongest first; ties broken by lower layer
    )
    return df_sorted["layer_num"].tolist()

def _read_layer_norms(csv_path: Path) -> dict[int, float]:
    """Map each ``layer_num`` in an analysis CSV to its ``total_frobenius_norm``."""
    df = pd.read_csv(csv_path)
    # Column-wise conversion keeps layer numbers as ints; iterrows() would upcast
    # them to the row's common dtype (float), producing float dict keys.
    layers = df["layer_num"].astype(int).tolist()
    norms = df["total_frobenius_norm"].astype(float).tolist()
    return dict(zip(layers, norms))


def create_norm_lookup_table(mlp_csv: Path, attn_csv: Path) -> dict:
    """Create a per-layer norm lookup for the MLP-only and attention-only runs.

    Args:
        mlp_csv: Analysis CSV from the MLP-only run.
        attn_csv: Analysis CSV from the attention-only run.

    Returns:
        ``{"mlp": {layer: norm, ...}, "attn": {layer: norm, ...}}`` with int
        layer keys and float norms. If a layer appears more than once in a CSV,
        the last row wins.
    """
    return {
        "mlp": _read_layer_norms(mlp_csv),
        "attn": _read_layer_norms(attn_csv),
    }

def find_top_consecutive_block(csv_path: Path, window_size: int) -> list[int] | None:
    """Finds the consecutive block of layers with the highest summed norm."""
    df = pd.read_csv(csv_path).sort_values(by="layer_num")
    if len(df) < window_size:
        return None

    max_norm_sum = -1
    best_block_start_index = -1

    for i in range(len(df) - window_size + 1):
        window = df.iloc[i : i + window_size]
        current_sum = window["total_frobenius_norm"].sum()
        if current_sum > max_norm_sum:
            max_norm_sum = current_sum
            best_block_start_index = i

    top_block_df = df.iloc[best_block_start_index : best_block_start_index + window_size]
    return top_block_df["layer_num"].tolist()

# ==============================================================================
# Part 4: The "Architect" - Experiment Generation Logic
# ==============================================================================

def generate_track1_configs(top_mlp_layers: list, top_attn_layers: list, n: int) -> dict:
    """Build the component-level ("Track 1") Soloist, Duet and Ensemble experiments.

    Args:
        top_mlp_layers: MLP layer numbers ranked strongest first.
        top_attn_layers: Attention layer numbers ranked strongest first.
        n: How many top layers of each kind go into the ensembles.

    Returns:
        Ordered dict of experiment name -> list of module names
        (``"mlp_<L>"`` / ``"attn_<L>"``). Empty if either ranking is empty.
    """
    if not (top_mlp_layers and top_attn_layers):
        return {}

    best_mlp = top_mlp_layers[0]
    best_attn = top_attn_layers[0]

    ensemble_mlp = [f"mlp_{layer}" for layer in top_mlp_layers[:n]]
    ensemble_attn = [f"attn_{layer}" for layer in top_attn_layers[:n]]

    return {
        # Single strongest component of each kind, alone and paired.
        f"Track1_Solo_MLP_L{best_mlp}": [f"mlp_{best_mlp}"],
        f"Track1_Solo_Attn_L{best_attn}": [f"attn_{best_attn}"],
        f"Track1_Duet_MLP{best_mlp}_Attn{best_attn}": [f"mlp_{best_mlp}", f"attn_{best_attn}"],
        # Top-n ensembles per kind, then both kinds together.
        f"Track1_Ensemble_Top{n}_MLP": ensemble_mlp,
        f"Track1_Ensemble_Top{n}_Attn": ensemble_attn,
        f"Track1_Ensemble_Top{n}_Combined": ensemble_mlp + ensemble_attn,
    }

def _decide_circuit_link(layer: int, norm_lookup: dict, max_layer: int) -> list[str]:
    """Helper to decide the best MLP->Attn circuit link for a given layer."""
    # Score A: MLP_{L-1} -> Attn_L
    score_a = -1
    if layer > 0:
        norm_mlp_prev = norm_lookup["mlp"].get(layer - 1, 0)
        norm_attn_curr = norm_lookup["attn"].get(layer, 0)
        score_a = norm_mlp_prev + norm_attn_curr

    # Score B: MLP_L -> Attn_{L+1}
    score_b = -1
    if layer < max_layer:
        norm_mlp_curr = norm_lookup["mlp"].get(layer, 0)
        norm_attn_next = norm_lookup["attn"].get(layer + 1, 0)
        score_b = norm_mlp_curr + norm_attn_next

    if score_a > score_b:
        return [f"mlp_{layer-1}", f"attn_{layer}"]
    elif score_b > -1:
        return [f"mlp_{layer}", f"attn_{layer+1}"]
    else: # Edge case if no valid circuit can be formed
        return []

def generate_track2_configs(both_csv: Path, norm_lookup: dict, n: int) -> dict:
    """Generate circuit-based ("Track 2") experiments linking MLP and attention layers.

    Three kinds of experiment are produced, each mapping a config name to a list
    of module names (``"mlp_<L>"`` / ``"attn_<L>"``):

    1. Single-link circuit around the highest-norm layer of the joint run.
    2. Chained circuit over the strongest consecutive block of ``n`` layers,
       starting one layer earlier so the chain feeds into the block.
    3. Distributed circuit: the best link for each of the top ``n`` layers.

    Args:
        both_csv: Analysis CSV from the run that trained MLP and attention together.
        norm_lookup: ``{"mlp": {layer: norm}, "attn": {layer: norm}}``, as built
            by ``create_norm_lookup_table``.
        n: Number of top layers, and width of the consecutive block.

    Returns:
        Dict of experiment name -> module list. Empty when ``n`` is not positive,
        the CSV yields no layers, or there are no MLP norms to anchor circuits on.
    """
    configs = {}
    if n <= 0 or not norm_lookup["mlp"]:
        return configs

    # Step A: identify hotspots in the joint run.
    top_overall_layers = get_ranked_layers_from_csv(both_csv)[:n]
    if not top_overall_layers:
        return configs
    top_consecutive_block = find_top_consecutive_block(both_csv, window_size=n)

    # Step B: MLP layer range; every mlp_L -> attn_{L+1} link must stay inside it.
    min_layer_num = min(norm_lookup["mlp"])
    max_layer_num = max(norm_lookup["mlp"])

    # Step C: architect experiments.
    # 1. Single-link circuit around the strongest layer.
    top_layer = top_overall_layers[0]
    best_link = _decide_circuit_link(top_layer, norm_lookup, max_layer_num)
    if best_link:
        configs[f"Track2_SingleCircuit_L{top_layer}"] = best_link

    # 2. Chained-block circuit.
    if top_consecutive_block:
        first, last = top_consecutive_block[0], top_consecutive_block[-1]
        chain = []
        # Start one layer before the block so the chain feeds into it. The lower
        # bound prevents emitting e.g. "mlp_-1" when the block starts at layer 0.
        for layer in range(first - 1, first + n):
            if min_layer_num <= layer < max_layer_num:
                chain.extend([f"mlp_{layer}", f"attn_{layer + 1}"])
        if chain:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            configs[f"Track2_ChainedCircuit_L{first}-{last}"] = list(dict.fromkeys(chain))

    # 3. Distributed Top-N circuit: best link for each top layer.
    distributed_circuit = [
        module
        for layer in top_overall_layers
        for module in _decide_circuit_link(layer, norm_lookup, max_layer_num)
    ]
    if distributed_circuit:
        configs[f"Track2_DistributedCircuit_Top{n}_Layers"] = list(dict.fromkeys(distributed_circuit))

    return configs

# ==============================================================================
# Part 5: The Orchestrator
# ==============================================================================

def main():
    """Generate an experiment-config file for every dataset under ``BASE_RESULTS_DIR``.

    For each dataset subdirectory it:

    * locates the latest analysis CSV of the ``mlp``, ``attention`` and ``both``
      runs (for ``RUN_PARAMS_SUBDIR``);
    * ranks layers and builds the Track 1 and Track 2 experiment dicts;
    * writes them as ``EXPERIMENT_COMBINATIONS = {...}`` to
      ``<OUTPUT_DIR>/<dataset>_experiments.py``.

    Datasets missing any of the three CSVs are skipped with a message.
    """
    base_dir = Path(BASE_RESULTS_DIR)
    output_dir = Path(OUTPUT_DIR)

    # Validate the input before touching the filesystem so a bad path doesn't
    # leave behind an empty output directory.
    if not base_dir.is_dir():
        print(f"[ERROR] Base results directory not found: {base_dir}")
        return
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so processing/log order doesn't depend on filesystem listing order.
    datasets = sorted(d for d in base_dir.iterdir() if d.is_dir())
    print(f"Found {len(datasets)} datasets. Starting configuration generation...\n")

    for dataset_path in datasets:
        dataset_name = dataset_path.name
        print(f"{'='*60}\nProcessing Dataset: {dataset_name}\n{'='*60}")

        # --- Scout Phase ---
        mlp_path = dataset_path / "mlp" / RUN_PARAMS_SUBDIR
        attn_path = dataset_path / "attention" / RUN_PARAMS_SUBDIR
        both_path = dataset_path / "both" / RUN_PARAMS_SUBDIR

        mlp_csv = find_final_analysis_csv(mlp_path)
        attn_csv = find_final_analysis_csv(attn_path)
        both_csv = find_final_analysis_csv(both_path)

        if not all([mlp_csv, attn_csv, both_csv]):
            print(f"[SKIPPING] Missing one or more required analysis files for {dataset_name}.\n")
            continue

        print(f"  - Found MLP analysis: {mlp_csv.name}")
        print(f"  - Found Attention analysis: {attn_csv.name}")
        print(f"  - Found Both analysis: {both_csv.name}")

        # --- Analyst Phase ---
        top_mlp = get_ranked_layers_from_csv(mlp_csv)
        top_attn = get_ranked_layers_from_csv(attn_csv)
        norm_lookup = create_norm_lookup_table(mlp_csv, attn_csv)

        # --- Architect Phase ---
        print("\n  Generating experiment configurations...")
        track1 = generate_track1_configs(top_mlp, top_attn, n=NUM_TOP_LAYERS)
        track2 = generate_track2_configs(both_csv, norm_lookup, n=NUM_TOP_LAYERS)

        final_configs = {**track1, **track2}

        # --- Save Phase ---
        # The output is a Python module: a dict literal produced by pprint.
        output_filepath = output_dir / f"{dataset_name}_experiments.py"
        with open(output_filepath, "w", encoding="utf-8") as f:
            f.write(f"# Auto-generated experiment configurations for dataset: {dataset_name}\n")
            f.write("# Generated by 02_scout_and_architect.py\n\n")
            f.write("EXPERIMENT_COMBINATIONS = ")
            f.write(pprint.pformat(final_configs, indent=4, width=120))
            f.write("\n")

        print(f"  >>> Successfully saved configurations to: {output_filepath}\n")

    print("All datasets processed.")


if __name__ == "__main__":
    main()