import os
import json
import random
import math
from pathlib import Path
from collections import Counter
import string
import re
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    GenerationConfig, BitsAndBytesConfig
)
import transformers
import argparse
import random
import subprocess, sys, csv, glob
import types


# ---------------------- Runtime Configs ---------------------- #

# SLURM job id when launched under SLURM, otherwise "manual_run".
JOB_ID = os.getenv("SLURM_JOB_ID") or "manual_run"
# Experiment script executed (via sys.executable) once per run of a sweep;
# it lives in the same directory as this file.
CHILD_SCRIPT = os.path.join(os.path.dirname(__file__), "two_choices_experiment.py")

### Output dirs
# NOTE(review): OUTPUTS is not referenced anywhere in the visible code of this file.
OUTPUTS = Path("logs") / JOB_ID
# The two values below are defaults only: the __main__ block rebinds them from
# --output_dir / --plot_dir before any sweep runs.
OUTPUT_DIR = "OUTPUT/"  # forwarded to every child run as --output_dir
PLOT_DIR = "plots/"     # root under which each sweep writes plots, run artifacts and summary.csv



# ---------------------- Filesystem, Subprocess & Plotting Utilities ---------------------- #
def ensure_dir(path):
    """Create directory ``path`` together with any missing parents.

    Does nothing if the directory is already there.
    """
    os.makedirs(name=path, exist_ok=True)

def _ensure_dir(p: str):
    """Make sure the directory that will contain file path ``p`` exists.

    A bare file name (no directory component) needs nothing created.
    """
    parent = os.path.dirname(p)
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)

def _launch_child_and_load_artifact(child_script: str, cli_args: list[str], artifact_json_path: str):
    """Run ``child_script`` in a fresh interpreter and return the JSON it wrote.

    The child is started with the current interpreter (``sys.executable``) and
    receives ``cli_args`` verbatim; no shell is involved. It is expected to
    write its results to ``artifact_json_path``. The callers in this file pass
    that same path to the child via ``--artifact_json``.

    Args:
        child_script: path of the Python script to execute.
        cli_args: command-line arguments forwarded to the child.
        artifact_json_path: JSON file read after the child exits. Its parent
            directory is created beforehand so the child can write into it.

    Returns:
        The decoded JSON object.

    Raises:
        subprocess.CalledProcessError: the child exited with a non-zero status.
        FileNotFoundError: the child exited cleanly but no artifact exists.
    """
    # A separate process per run keeps state (e.g. dynamically loaded modules)
    # from leaking between runs.
    _ensure_dir(artifact_json_path)
    subprocess.run([sys.executable, child_script, *cli_args], check=True)
    # NOTE(review): an artifact left over from an earlier run at the same path
    # would be read here if the child exits 0 without rewriting it.
    with open(artifact_json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def plot_accuracy_comparison(results, param_name, out_dir, baseline_results=None):
    """Plot every QA metric against the swept parameter and save a PNG.

    Args:
        results: mapping param value -> run dict with a ``'metrics'`` dict of
            metric name -> score. Every run must contain each metric reported
            by the run at the smallest param value.
        param_name: x-axis label; lowercased to build ``<param>_accuracy.png``.
        out_dir: output directory (created if missing).
        baseline_results: optional. Either a single baseline series (param
            value -> run dict) or several series given as label -> series.
            A series with more than one entry is drawn as a curve over the
            sorted ``results`` keys. A single-entry series becomes a horizontal
            reference line. Empty series are skipped.

    Raises:
        ValueError: ``results`` is empty, or a run lacks one of the metrics.
    """
    if not results:
        raise ValueError("plot_accuracy_comparison: `results` is empty")
    xs = sorted(results.keys())

    # Metric names come from the first run; every other run must provide them.
    # (Explicit raise instead of assert so validation survives `python -O`.)
    metric_keys = list(results[xs[0]]['metrics'].keys())
    for v in xs:
        missing = [k for k in metric_keys if k not in results[v]['metrics']]
        if missing:
            raise ValueError(f"Missing some metrics for param value {v}: {missing}")

    os.makedirs(out_dir, exist_ok=True)
    plt.figure(figsize=(8, 5))

    # Main curves: one per metric, cycling through markers.
    markers = ['o', 's', 'D', '^', 'v', '<', '>']
    for idx, k in enumerate(metric_keys):
        scores = [results[v]['metrics'][k] for v in xs]
        plt.plot(xs, scores,
                 marker=markers[idx % len(markers)],
                 linestyle='-',
                 label=k)

    def _plot_baseline(label, series, marker, linestyle):
        """Draw one baseline series for every metric (curve or reference line)."""
        if not series:
            return
        for k in metric_keys:
            if len(series) > 1:
                ys = [series[v]['metrics'][k] for v in xs]
                plt.plot(xs, ys, marker=marker, linestyle=linestyle, label=f'{label} {k}')
            else:
                only = next(iter(series.values()))
                plt.axhline(y=only['metrics'][k], linestyle=linestyle, label=f'{label} {k}')

    if baseline_results:
        # label -> series form: the values are series dicts, not run dicts.
        is_multi = all(isinstance(v, dict) and ('metrics' not in v)
                       for v in baseline_results.values())
        if is_multi:
            baseline_markers = {'Strict': 'x', 'Load-only': '+'}
            baseline_styles = {'Strict': '--', 'Load-only': '-.'}
            for label, br in baseline_results.items():
                _plot_baseline(label, br,
                               baseline_markers.get(label, 'x'),
                               baseline_styles.get(label, '--'))
        else:
            _plot_baseline('Baseline', baseline_results, 'x', '--')

    plt.xlabel(param_name)
    plt.ylabel('Score')
    plt.title(f'Accuracy vs {param_name}')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"{param_name.lower()}_accuracy.png"))
    plt.close()


def plot_imbalance_multipliers(results, param_name, out_dir, baseline_results=None):
    """Plot mean and p95 aggregate imbalance multipliers against the swept parameter.

    Values are read from ``run['per_batch_imbalance']['aggregate']`` under the
    keys ``'mean_I_agg'`` and ``'p95_I_agg'``. Missing values become NaN on
    curves and are skipped for horizontal reference lines.

    Args:
        results: mapping param value -> run dict.
        param_name: x-axis label; lowercased to build ``<param>_imbalance.png``.
        out_dir: output directory (created if missing).
        baseline_results: optional. Either a single series (param value -> run
            dict) or label -> series. A multi-entry series is drawn as a curve
            over the sorted ``results`` keys; a single-entry series becomes
            horizontal lines. Empty series are skipped.
    """
    os.makedirs(out_dir, exist_ok=True)
    xs = sorted(results.keys())

    def _aggregate(run):
        """Return the run's aggregate-imbalance dict, or {} if absent."""
        return run.get('per_batch_imbalance', {}).get('aggregate', {})

    def _column(runs, key):
        """Extract ``key`` from each run's aggregate, NaN when missing."""
        return [_aggregate(r).get(key, np.nan) for r in runs]

    main_runs = [results[x] for x in xs]
    plt.figure(figsize=(8, 5))

    # Main curves
    plt.plot(xs, _column(main_runs, 'mean_I_agg'), marker='o', color='C0', label='mean $I^{agg}$')
    plt.plot(xs, _column(main_runs, 'p95_I_agg'), marker='s', color='C1', label='p95 $I^{agg}$')

    # Distinct styles per baseline label; unknown labels fall back to "Baseline".
    baseline_styles = {
        "Strict": {"mean": {"marker": "D", "linestyle": "--", "color": "C2"},
                   "p95":  {"marker": "x", "linestyle": "--", "color": "C2"}},
        "Load-only": {"mean": {"marker": "^", "linestyle": ":", "color": "C3"},
                      "p95":  {"marker": "v", "linestyle": ":", "color": "C3"}},
        "Baseline": {"mean": {"marker": "P", "linestyle": "-.", "color": "C4"},
                     "p95":  {"marker": "*", "linestyle": "-.", "color": "C4"}},
    }

    def _plot_baseline(label, series):
        """Draw one baseline series as curves or horizontal lines."""
        if not series:
            return  # nothing to draw; previously raised IndexError
        styles = baseline_styles.get(label, baseline_styles["Baseline"])
        if len(series) > 1:
            runs = [series[x] for x in xs]
            plt.plot(xs, _column(runs, 'mean_I_agg'), label=f'{label} mean $I^{{agg}}$', **styles["mean"])
            plt.plot(xs, _column(runs, 'p95_I_agg'), label=f'{label} p95 $I^{{agg}}$', **styles["p95"])
        else:
            agg = _aggregate(next(iter(series.values())))
            if 'mean_I_agg' in agg:
                plt.axhline(y=agg['mean_I_agg'], label=f'{label} mean $I^{{agg}}$', **styles["mean"])
            if 'p95_I_agg' in agg:
                plt.axhline(y=agg['p95_I_agg'], label=f'{label} p95 $I^{{agg}}$', **styles["p95"])

    if baseline_results:
        # label -> series form: the values are series dicts, not run dicts.
        is_multi = all(isinstance(v, dict) and ('per_batch_imbalance' not in v)
                       for v in baseline_results.values())
        if is_multi:
            for label, series in baseline_results.items():
                _plot_baseline(label, series)
        else:
            _plot_baseline("Baseline", baseline_results)

    plt.xlabel(param_name)
    plt.ylabel('Imbalance multiplier')
    plt.title(f'Aggregate imbalance vs {param_name}')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, f"{param_name.lower()}_imbalance.png"))
    plt.close()

def plot_perplexity_comparison(results, param_name, output_dir, baseline_results=None):
    """Plot perplexity against the swept parameter and save ``<param>_perplexity.png``.

    ``baseline_results`` may be either a single series (param value -> run dict
    carrying ``'metrics'``) or a mapping label -> series. A series holding more
    than one entry is drawn as a dashed curve over the sorted ``results`` keys.
    A single-entry series is drawn as a dotted horizontal line.
    """
    ensure_dir(output_dir)
    xs = sorted(results.keys())

    def _ppl(run):
        return run['metrics']['perplexity']

    def _draw_baseline(series, name):
        if len(series) <= 1:
            plt.axhline(y=_ppl(list(series.values())[0]), linestyle=':', label=f'{name} Perplexity')
            return
        plt.plot(xs, [_ppl(series[x]) for x in xs], marker='s', linestyle='--', label=f'{name} Perplexity')

    plt.figure(figsize=(10, 6))
    plt.plot(xs, [_ppl(results[x]) for x in xs], marker='o', label='Perplexity')

    if baseline_results:
        several = all(isinstance(v, dict) and ('metrics' not in v)
                      for v in baseline_results.values())
        if several:
            for name, series in baseline_results.items():
                _draw_baseline(series, name)
        else:
            _draw_baseline(baseline_results, 'Baseline')

    plt.xlabel(param_name)
    plt.ylabel('Perplexity')
    plt.title(f'Perplexity vs {param_name}')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{param_name.lower()}_perplexity.png")
    plt.close()

def plot_latency_comparison(results, param_name, output_dir, baseline_results=None):
    """Plot average per-token latency against the swept parameter.

    The figure is saved as ``<output_dir>/<param>_latency.png``. Baselines
    follow the same convention as the other plot helpers: either a single
    series (param value -> run dict with ``'latency'``) or label -> series.
    Multi-entry series become dashed curves; single-entry series become
    dotted horizontal lines.
    """
    ensure_dir(output_dir)
    xs = sorted(results.keys())

    def _latency(run):
        return run['latency']['avg_latency_per_token_sec']

    ys = [_latency(results[x]) for x in xs]

    plt.figure(figsize=(10, 6))
    plt.plot(xs, ys, marker='o', label='Avg Latency per Token (s)')

    def _draw(series, name):
        if len(series) > 1:
            plt.plot(xs, [_latency(series[x]) for x in xs], marker='x', linestyle='--', label=f'{name} Latency')
        else:
            plt.axhline(y=_latency(list(series.values())[0]), linestyle=':', label=f'{name} Latency')

    if baseline_results:
        several = all(isinstance(v, dict) and ('latency' not in v)
                      for v in baseline_results.values())
        if not several:
            _draw(baseline_results, 'Baseline')
        else:
            for name, series in baseline_results.items():
                _draw(series, name)

    plt.xlabel(param_name)
    plt.ylabel('Latency (s/token)')
    plt.title(f'Average Latency per Token vs {param_name}')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{param_name.lower()}_latency.png")
    plt.close()


def plot_violation_comparison(
    x_values,
    avg_max_violations,
    worst_max_violations,
    baseline_avg=None,
    baseline_worst=None,
    xlabel="",
    title="",
    filename="violation_plot.png",
    output_dir="."
):
    """Plot the average and worst per-layer max violation against the swept parameter.

    Baselines may be passed as plain lists, drawn as one legacy "Baseline", or
    as dicts mapping label -> list. A list with more than one value is drawn as
    a curve over ``x_values``. A single value becomes a horizontal reference
    line. ``None`` and empty lists are skipped. If only one of the two
    baseline arguments is a dict, the other is treated as empty.

    Args:
        x_values: x coordinates shared by all curves.
        avg_max_violations / worst_max_violations: main-run series.
        baseline_avg / baseline_worst: optional baselines (see above).
        xlabel, title: axis label and figure title.
        filename: PNG name written inside ``output_dir``.
        output_dir: output directory (created if missing).
    """
    ensure_dir(output_dir)
    plt.figure(figsize=(10, 6))

    plt.plot(x_values, avg_max_violations, 'o-', label='Avg Max Violation')
    plt.plot(x_values, worst_max_violations, 's--', label='Worst Max Violation')

    def _plot_baseline_series(label, avg_list, worst_list, avg_color, worst_color):
        """Draw one baseline's avg/worst values as curves or reference lines."""
        if avg_list:
            if len(avg_list) > 1:
                plt.plot(x_values, avg_list, 'x-', color=avg_color, label=f'{label} Avg Max Viol.')
            else:
                plt.axhline(y=avg_list[0], color=avg_color, linestyle=':', label=f'{label} Avg Max Viol.')
        if worst_list:
            if len(worst_list) > 1:
                plt.plot(x_values, worst_list, '+--', color=worst_color, label=f'{label} Worst Max Viol.')
            else:
                plt.axhline(y=worst_list[0], color=worst_color, linestyle='--', label=f'{label} Worst Max Viol.')

    if isinstance(baseline_avg, dict) or isinstance(baseline_worst, dict):
        avg_map = baseline_avg if isinstance(baseline_avg, dict) else {}
        worst_map = baseline_worst if isinstance(baseline_worst, dict) else {}
        # Ordered, de-duplicated union of labels gives a deterministic legend
        # order; a set union would order string labels by hash.
        labels = list(dict.fromkeys([*avg_map, *worst_map]))
        for idx, label in enumerate(labels):
            # One colour per baseline so several singleton baselines stay
            # distinguishable. C0/C1 go to the two main curves on this fresh figure.
            color = f"C{idx + 2}"
            _plot_baseline_series(label, avg_map.get(label), worst_map.get(label), color, color)
    elif baseline_avg or baseline_worst:
        # Legacy single-baseline form (plain lists).
        _plot_baseline_series("Baseline",
                              baseline_avg if isinstance(baseline_avg, list) else None,
                              baseline_worst if isinstance(baseline_worst, list) else None,
                              'gray', 'black')

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Max Violation')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()

def plot_expert_distribution(
    results,
    output_dir,
    filename="token_distribution.png",
    normalize=True,
    title="Token Distribution Across Experts",
    baseline_results=None,
):
    """Plot per-expert token totals, summed over layers, for each run.

    Args:
        results: mapping param value -> run dict. Each run's ``'expert_counts'``
            maps layer -> {expert index (int or int-parsable key) -> token count}.
            Keys that cannot be parsed as int are ignored.
        output_dir: output directory (created if missing).
        filename: PNG name written inside ``output_dir``.
        normalize: plot each run as a percentage of its own total. A run whose
            total is not positive is left unnormalized.
        title: figure title.
        baseline_results: optional. Either a single series (param value -> run
            dict) or label -> series. Every baseline run becomes a dashed curve;
            empty series are skipped.

    Raises:
        ValueError: no int-parsable, non-negative expert index found in any run.
    """
    ensure_dir(output_dir)

    def _expert_ids(expert_counts_dict):
        """Yield every int-parsable expert key across all layers."""
        for layer_dict in expert_counts_dict.values():
            for k in layer_dict.keys():
                try:
                    yield int(k)
                except (ValueError, TypeError):
                    continue

    # Collect baseline curves as (legend label, run) pairs in drawing order.
    baseline_curves = []
    if baseline_results:
        # label -> series form: the values are series dicts, not run dicts.
        is_multi = all(isinstance(v, dict) and ('expert_counts' not in v)
                       for v in baseline_results.values())
        if is_multi:
            for label, series in baseline_results.items():
                if len(series) > 1:
                    baseline_curves += [(f"{label} {pv}", series[pv]) for pv in sorted(series.keys())]
                elif series:
                    baseline_curves.append((label, next(iter(series.values()))))
        elif len(baseline_results) > 1:
            baseline_curves += [(f"baseline_{pv}", baseline_results[pv])
                                for pv in sorted(baseline_results.keys())]
        else:
            baseline_curves.append(("baseline", next(iter(baseline_results.values()))))

    # Global expert-id range across main runs and baselines.
    all_runs = list(results.values()) + [run for _, run in baseline_curves]
    n_max = max((i for run in all_runs for i in _expert_ids(run["expert_counts"])), default=-1)
    if n_max < 0:
        raise ValueError("No expert counts found.")
    n_experts = n_max + 1
    x = np.arange(1, n_experts + 1)  # 1-based ticks

    def _curve(expert_counts_dict):
        """Sum counts per expert over layers, optionally as % of the run total."""
        totals = np.zeros(n_experts, dtype=float)
        for layer_dict in expert_counts_dict.values():
            for k, v in layer_dict.items():
                try:
                    idx = int(k)
                except (ValueError, TypeError):
                    continue
                if 0 <= idx < n_experts:
                    totals[idx] += float(v)
        total = totals.sum()
        return 100 * totals / total if normalize and total > 0 else totals

    plt.figure(figsize=(14, 8))

    for param_val in sorted(results.keys()):
        plt.plot(x, _curve(results[param_val]["expert_counts"]), marker="o", label=str(param_val))
    for label, run in baseline_curves:
        plt.plot(x, _curve(run["expert_counts"]), marker="x", linestyle="--", label=label)

    plt.xlabel("Expert Index (1-based)")
    plt.ylabel("% of Tokens Assigned" if normalize else "Token Count")
    plt.title(title)
    # At most ~16 ticks regardless of expert count.
    plt.xticks(np.arange(1, n_experts + 1, max(1, n_experts // 16)))
    if normalize:
        plt.ylim(0, 100)
    plt.grid(True)
    # Only show a legend when more than one curve was actually drawn.
    if len(results) + len(baseline_curves) > 1:
        plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()


# ---------------------- Experiment Runner ---------------------- #

def _violation_stats(layer_metrics):
    """Return ``(mean, max)`` of the per-layer ``'max_violation'`` values as floats.

    Returns ``(0.0, 0.0)`` for an empty list, so main runs and baselines are
    summarised the same way (``np.max`` of an empty list raises).
    """
    viols = [layer['max_violation'] for layer in layer_metrics]
    if not viols:
        return 0.0, 0.0
    return float(np.mean(viols)), float(np.max(viols))


def _baseline_record(art):
    """Condense a baseline run artifact into the record stored per sweep value."""
    avg_v, worst_v = _violation_stats(art['all_layer_metrics'])
    return {
        'avg_violation': avg_v,
        'worst_violation': worst_v,
        'latency': art['latency'],
        'metrics': art['metrics'],
        'all_layer_metrics': art['all_layer_metrics'],
        'expert_counts': art['expert_counts'],
        "per_batch_imbalance": art['per_batch_imbalance'],
    }


def run_experiment_grid(
    args,
    sweep_name: str,
    sweep_values: list,
    get_config_fn,
    baseline_config_fn=None,
    x_label="",
    plot_key="",
    sweep_baselines=True,
):
    """Run one parameter sweep, one child process per run, then plot the results.

    For each value in ``sweep_values``:
      * baselines are run with ``baseline_config_fn(val)``, unless
        ``baseline_config_fn`` is None or ``args.ignore_baseline`` is set. Two
        kinds are run: a strict baseline (``"baseline"`` selection) and a
        load-only baseline (``"load_only"`` selection);
      * the main run uses ``args.selection_method`` and ``get_config_fn(val)``.

    Each run executes ``CHILD_SCRIPT`` in a fresh interpreter. Per the child
    contract, it writes an artifact JSON under ``<PLOT_DIR>/<sweep_name>/runs/``
    and appends a row to the shared ``summary.csv`` passed via ``--result_csv``.
    The artifacts are loaded back into ``results`` and the baseline series.
    These are dumped to JSON and plotted into ``<PLOT_DIR>/<sweep_name>/``.

    Args:
        args: parsed CLI namespace (see ``parse_args``); most fields are
            forwarded to the child.
        sweep_name: sub-directory name and prefix of experiment names and files.
        sweep_values: values iterated over.
        get_config_fn: ``val -> dict``. Must contain ``plot_key`` and may
            override ``num_choices`` / ``sample_size``.
        baseline_config_fn: ``val -> dict`` for baselines, or None to skip them.
        x_label: x-axis label used in every plot.
        plot_key: config key whose value keys ``results``.
        sweep_baselines: if True, baselines run for every value and are plotted
            as curves. Otherwise they run only for the first value and are
            plotted as horizontal reference lines.

    Returns:
        dict mapping ``cfg[plot_key]`` -> main-run record.
    """
    # Where artifacts and the summary live for this sweep.
    sweep_root = os.path.join(PLOT_DIR, sweep_name)
    runs_dir = os.path.join(sweep_root, "runs")
    summary_csv = os.path.join(sweep_root, "summary.csv")
    # Create runs_dir itself (and sweep_root with it), not only its parent.
    os.makedirs(runs_dir, exist_ok=True)

    # Every run executes CHILD_SCRIPT, a separate experiment script (not this file).
    child_script = os.path.abspath(CHILD_SCRIPT)

    results = {}
    baseline_results_strict = {}
    baseline_results_load = {}

    def build_cli(experiment_name: str, selection_method: str, num_choices: int, sample_size: int):
        """Return ``(cli_args, artifact_json_path)`` for one child run.

        ``sample_size`` may be None, in which case the flag is omitted.
        """
        artifact_json = os.path.join(runs_dir, f"{experiment_name}.json")
        cli = [
            "--model_name", args.model_name,
            "--model_family", args.model_family,
            "--mode", args.mode,
            "--model_type", args.model_type,
            "--selection_method", selection_method,
            "--num_choices", str(num_choices),
        ]
        # --batch_size / --sample_size have no parser default. Forwarding None
        # would hand the child the literal string "None", so omit the flag instead.
        # NOTE(review): assumes the child has its own default (or reports a
        # clear "required" error) when the flag is absent.
        if args.batch_size is not None:
            cli += ["--batch_size", str(args.batch_size)]
        if sample_size is not None:
            cli += ["--sample_size", str(sample_size)]
        cli += [
            "--experiment_name", experiment_name,
            "--seed", str(args.seed),
            "--beta", str(args.beta),
            "--dataset_name", args.dataset_name,
            "--output_dir", OUTPUT_DIR,
            "--result_csv", summary_csv,
            "--artifact_json", artifact_json,
            "--sum_threshold", *[str(x) for x in args.sum_threshold],
            "--threshold_factor", *[str(x) for x in args.threshold_factor],
        ]

        # Baseline-type runs use baseline_lora_path when one is given. Otherwise
        # strict baselines get no --lora_path, and every other method gets
        # args.lora_path. A single flag replaces the former duplicate --lora_path,
        # which relied on the child's parser keeping the last occurrence.
        if selection_method in ("baseline", "load_only") and args.baseline_lora_path:
            cli += ["--lora_path", args.baseline_lora_path]
        elif selection_method != "baseline":
            cli += ["--lora_path", args.lora_path]

        if args.sample_before_load:
            cli.append("--sample_before_load")
        if args.vectorized:
            cli.append("--vectorized")

        cli += [
            "--max_new_tokens", str(args.max_new_tokens),
            "--max_prompt_length", str(args.max_prompt_length),
            "--max_perplexity_length", str(args.max_perplexity_length),
        ]
        return cli, artifact_json

    def run_baseline(val, base_cfg, suffix, selection_method, tag):
        """Launch one baseline run for ``val`` and return its condensed record."""
        name = f"{sweep_name}-{plot_key}{val}-{suffix}"
        print(f"[Baseline: {tag}] Running {name}")
        cli, artifact = build_cli(
            experiment_name=name,
            selection_method=selection_method,
            num_choices=base_cfg.get("num_choices", -1),
            sample_size=base_cfg.get("sample_size", args.sample_size),
        )
        return _baseline_record(_launch_child_and_load_artifact(child_script, cli, artifact))

    want_baselines = baseline_config_fn is not None and not args.ignore_baseline

    # Sweep
    for i, val in enumerate(sweep_values):
        if want_baselines and (i == 0 or sweep_baselines):
            base_cfg = baseline_config_fn(val)
            # strict baseline: score-only top-k
            baseline_results_strict[val] = run_baseline(
                val, base_cfg, "baseline_strict", "baseline", "strict")
            # load-only baseline: ignore scores, route by load
            baseline_results_load[val] = run_baseline(
                val, base_cfg, "baseline_loadonly", "load_only", "load_only")

        # Main run for this value; cfg may override num_choices / sample_size.
        cfg = get_config_fn(val)
        exp_name = f"{sweep_name}-{plot_key}{cfg[plot_key]}-{args.selection_method}"
        print(f"[Main] Running {exp_name}")
        main_cli, main_artifact = build_cli(
            experiment_name=exp_name,
            selection_method=args.selection_method,
            num_choices=cfg.get("num_choices", -1),
            sample_size=cfg.get("sample_size", args.sample_size),
        )
        art = _launch_child_and_load_artifact(child_script, main_cli, main_artifact)

        # In-memory structure the plotting helpers expect.
        results[cfg[plot_key]] = {
            **cfg,
            'expert_counts': art['expert_counts'],
            'all_layer_metrics': art['all_layer_metrics'],
            'metrics': art['metrics'],
            'latency': art['latency'],
            "per_batch_imbalance": art['per_batch_imbalance'],
        }

    # === Post-processing ===
    x_values = sorted(results.keys())
    stats = [_violation_stats(results[x]['all_layer_metrics']) for x in x_values]
    avg_max_violations = [s[0] for s in stats]
    worst_max_violations = [s[1] for s in stats]

    # Multi-baseline series for plots: label -> {"series", "avg", "worst"}.
    baselines = None
    if baseline_results_strict or baseline_results_load:
        baselines = {}
        for label, series in (("Strict", baseline_results_strict),
                              ("Load-only", baseline_results_load)):
            if sweep_baselines:
                if series:
                    baselines[label] = {
                        "series": series,
                        "avg": [series[x]['avg_violation'] for x in x_values],
                        "worst": [series[x]['worst_violation'] for x in x_values],
                    }
            else:
                # Baselines ran only for the first value: compare against that run.
                first = next(iter(series.values())) if series else None
                baselines[label] = {
                    "avg": [first['avg_violation']] if first is not None else None,
                    "worst": [first['worst_violation']] if first is not None else None,
                    "series": series,
                }

    # Only non-empty baseline series are handed to the plotting helpers.
    plot_baselines = None
    if baselines is not None:
        plot_baselines = {label: b["series"] for label, b in baselines.items() if b.get("series")}

    # Save the aggregated in-memory dicts.
    with open(os.path.join(sweep_root, f"{sweep_name}_results.json"), "w") as f:
        json.dump(results, f, indent=2, default=str)

    if baseline_results_strict or baseline_results_load:
        with open(os.path.join(sweep_root, f"{sweep_name}_baseline_results.json"), "w") as f:
            json.dump({"strict": baseline_results_strict, "load_only": baseline_results_load},
                      f, indent=2, default=str)

    plot_violation_comparison(
        x_values=x_values,
        avg_max_violations=avg_max_violations,
        worst_max_violations=worst_max_violations,
        baseline_avg=None if baselines is None else {k: v.get("avg") for k, v in baselines.items()},
        baseline_worst=None if baselines is None else {k: v.get("worst") for k, v in baselines.items()},
        xlabel=x_label,
        title=f'Expert Load Imbalance vs {x_label}',
        filename=f'{sweep_name}_violation.png',
        output_dir=sweep_root,
    )

    if args.mode == "perplexity":
        plot_perplexity_comparison(results, param_name=x_label, output_dir=sweep_root,
                                   baseline_results=plot_baselines)
    elif args.mode == "qa":
        plot_accuracy_comparison(results, param_name=x_label, out_dir=sweep_root,
                                 baseline_results=plot_baselines)

    plot_expert_distribution(results, output_dir=sweep_root, normalize=True,
                             filename="token_distribution_normalized.png",
                             baseline_results=plot_baselines)
    plot_expert_distribution(results, output_dir=sweep_root, normalize=False,
                             baseline_results=plot_baselines)

    plot_imbalance_multipliers(results, param_name=x_label, out_dir=sweep_root,
                               baseline_results=plot_baselines)

    print(f"[Done] Wrote plots and summaries to: {sweep_root}")
    return results


def run_num_choices_experiment(args):
    """Sweep over ``args.num_choices``.

    Each main run overrides ``num_choices`` with the swept value. Baselines get
    an empty config, so they fall back to ``num_choices=-1`` and
    ``args.sample_size``. They run only for the first value
    (``sweep_baselines=False``) and are drawn as reference lines.
    """
    return run_experiment_grid(
        args=args,
        sweep_name="num_choices",
        sweep_values=args.num_choices,
        get_config_fn=lambda choices: {'num_choices': choices},
        baseline_config_fn=lambda _choices: {},
        x_label="Number of Choices",
        plot_key="num_choices",
        sweep_baselines=False,
    )

def run_sample_size_experiment(args):
    """Sweep over ``args.sample_sizes``.

    Main and baseline runs both use the swept sample size. ``sweep_baselines``
    keeps its default of True, so baselines are re-run for every value. With
    more than one value they are plotted as curves.
    """
    def as_config(size):
        # Fresh dict per call, shared by main and baseline runs.
        return {'sample_size': size}

    return run_experiment_grid(
        args=args,
        sweep_name="sample_size",
        sweep_values=args.sample_sizes,
        get_config_fn=as_config,
        baseline_config_fn=as_config,
        x_label="Sample Size",
        plot_key="sample_size",
    )



# ---------------------- Runner Entry ---------------------- #

def parse_args():
    """Parse command-line options for the sweep driver.

    Output locations, ``--model_name`` and ``--model_family`` are required.
    Sweep values come from ``--num_choices`` and/or ``--sample_sizes``. Most
    of the remaining options configure the child runs. ``--batch_size`` and
    ``--sample_size`` have no default and are ``None`` when omitted.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    floats_help = "One or more floats (space separated)"

    # Output locations
    for flag in ("--output_dir", "--plot_dir"):
        add(flag, type=str, required=True)

    # General config for the actual runner
    for flag in ("--model_name", "--model_family"):
        add(flag, type=str, required=True)
    for flag in ("--lora_path", "--baseline_lora_path"):
        add(flag, type=str, required=False, default="")
    add("--mode", choices=["qa", "perplexity"], default="perplexity")
    add("--model_type", type=str, default="chat")
    add("--selection_method", choices=["gini", "threshold"], default="gini")
    add("--threshold_factor", type=float, nargs="+", default=[0.9], help=floats_help)
    add("--seed", type=int, default=42)
    add("--beta", type=int, default=1)
    add("--sample_before_load", action="store_true",
        help="If we sample prior to computing and selecting based on load")
    add("--vectorized", action="store_true")
    add("--ignore_baseline", action="store_true")
    # by default not used
    add("--sum_threshold", type=float, nargs="+", default=[-1], help=floats_help)

    # Experiment setup (sweep values)
    for flag in ("--num_choices", "--sample_sizes"):
        add(flag, type=int, nargs="+", default=[])

    # Generation / perplexity limits
    for flag in ("--max_new_tokens", "--max_prompt_length", "--max_perplexity_length"):
        add(flag, type=int, default=1)

    # Dataset
    add("--dataset_name", type=str, default="wiki")

    # Per-run defaults (None when omitted)
    for flag in ("--batch_size", "--sample_size"):
        add(flag, type=int)

    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: parse the CLI, point the module-level output
    # defaults at the requested locations, then run each requested sweep.
    args = parse_args()

    # Output
    # These assignments execute at module scope, so they rebind the
    # module-level OUTPUT_DIR / PLOT_DIR that run_experiment_grid reads
    # (PLOT_DIR for the sweep root, OUTPUT_DIR as the child's --output_dir).
    OUTPUT_DIR = args.output_dir
    PLOT_DIR = args.plot_dir


    print("Script path:", os.path.abspath(__file__))

    # Each sweep runs only when values were supplied for it. If neither
    # --num_choices nor --sample_sizes is given, no experiment is run.
    if args.num_choices:
        print(f"\nRunning experiments for different numbers of choices {args.num_choices} ...")
        run_num_choices_experiment(args)
    if args.sample_sizes:
        print(f"\nRunning experiments for different numbers of samples {args.sample_sizes}...")
        run_sample_size_experiment(args)
