#!/usr/bin/env python3

import argparse
import json
import os
import random
from typing import Any, Dict, List


def load_json(path: str) -> Any:
    """Parse the UTF-8 encoded JSON file at ``path`` and return the decoded value.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever Python object the JSON document decodes to.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def save_json(data: Any, path: str) -> None:
    """Serialize ``data`` as JSON into the file at ``path``.

    The file is opened in write mode (replacing any existing content) with
    UTF-8 encoding; output is indented by two spaces and non-ASCII characters
    are written as-is rather than escaped.

    Args:
        data: JSON-serializable value to write.
        path: Destination file path.
    """
    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(data, sink, indent=2, ensure_ascii=False)


def sample_indices(total: int, k: int) -> List[int]:
    """Pick distinct random positions from ``range(total)``.

    The number of positions drawn is ``k``, capped at ``total`` so that asking
    for more items than exist returns every index instead of failing. The draw
    uses the module-level ``random`` state, so seeding ``random`` beforehand
    makes the result reproducible.

    Args:
        total: Size of the population being indexed.
        k: Requested number of indices.

    Returns:
        A list of unique indices in the order ``random.sample`` produced them.
    """
    population = range(total)
    draw_count = total if total < k else k
    return random.sample(population, draw_count)


# (output field name, key inside a layer's metrics mapping) pairs. The last two
# differ on purpose: the source arrays are plural ("min_idxs"/"max_idxs") while
# the per-sample output fields are singular.
_PER_LAYER_FIELDS = (
    ("high_kl_divergences", "high_kl_divergences"),
    ("top_1_differences", "top_1_differences"),
    ("idx_of_first_top_1_difference", "idx_of_first_top_1_difference"),
    ("idx_of_first_high_kl_divergence", "idx_of_first_high_kl_divergence"),
    ("min_idx", "min_idxs"),
    ("max_idx", "max_idxs"),
)


def _metric_value_at(metrics: Any, source_key: str, idx: int) -> Any:
    """Return ``metrics.get(source_key)[idx]``, or ``None`` if it cannot be read."""
    try:
        return metrics.get(source_key)[idx]
    except (AttributeError, TypeError, LookupError):
        # AttributeError: ``metrics`` is not a mapping (no ``.get``).
        # TypeError: the metric is absent (``None[idx]``) or not indexable.
        # LookupError: ``idx`` is out of range / not a valid key.
        return None


def collect_per_layer_metrics(layer_metrics: Dict[str, Dict[str, Any]], idx: int) -> Dict[str, Any]:
    """Extract the metric values belonging to sample ``idx`` for every layer.

    Each metric is looked up independently: a metric that is missing, is not
    indexable, or is too short for ``idx`` yields ``None`` for that field only,
    so one bad array no longer discards the layer's other valid metrics. Only
    the exceptions such a lookup can raise are handled; anything else
    propagates instead of being silently swallowed.

    Args:
        layer_metrics: Mapping from layer key to that layer's metrics mapping,
            where each metric is expected to be indexable by sample position.
        idx: Position of the sample whose values should be extracted.

    Returns:
        A dict keyed by layer key. Every value is a dict with the fields
        ``high_kl_divergences``, ``top_1_differences``,
        ``idx_of_first_top_1_difference``, ``idx_of_first_high_kl_divergence``,
        ``min_idx`` and ``max_idx``, each holding the sample's value or ``None``.
    """
    per_layer: Dict[str, Any] = {}
    for layer_key, metrics in layer_metrics.items():
        per_layer[layer_key] = {
            field: _metric_value_at(metrics, source_key, idx)
            for field, source_key in _PER_LAYER_FIELDS
        }
    return per_layer


def main() -> None:
    """Sample generated texts from an experiment directory and pair them with metrics.

    Reads ``generated_lines.json`` and ``prob_based_metrics.json`` from the
    experiment directory given on the command line, draws a seeded random
    sample of unsteered texts and of the steered texts stored under the
    selected ``--steered_layer`` key, attaches the per-layer metrics at each
    sampled index, and writes everything to a JSON file.

    Raises:
        FileNotFoundError: If either input file does not exist.
        ValueError: If an input file is not a JSON object, if ``steered_texts``
            is not an object containing the requested layer key, or if the
            selected text collections are not lists.
    """
    parser = argparse.ArgumentParser(description="Sample text lines and pair with probability-based metrics.")
    parser.add_argument(
        "exp_dir",
        help="Path to the experiment directory containing generated_lines.json and prob_based_metrics.json",
    )
    parser.add_argument(
        "--unsteered_k",
        type=int,
        default=10,
        help="Number of unsteered samples to select (default: 10)",
    )
    parser.add_argument(
        "--steered_k",
        type=int,
        default=10,
        help="Number of steered samples to select (default: 10)",
    )
    parser.add_argument(
        "--steered_layer",
        default="14",
        help="Key under 'steered_texts' whose texts are sampled (default: 14)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Optional output JSON path. Defaults to <exp_dir>/sample_prob_pairs.json",
    )

    args = parser.parse_args()

    # Seed the module-level generator so both sample draws are reproducible.
    random.seed(args.seed)

    input_file = os.path.join(args.exp_dir, "generated_lines.json")
    metrics_file = os.path.join(args.exp_dir, "prob_based_metrics.json")

    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")
    if not os.path.exists(metrics_file):
        raise FileNotFoundError(f"Metrics file not found: {metrics_file}")

    line_data = load_json(input_file)
    metrics_data = load_json(metrics_file)

    # Both files are accessed with ``.get`` below, so they must decode to objects.
    if not isinstance(line_data, dict):
        raise ValueError(f"Input file is not a JSON object: {input_file}")
    if not isinstance(metrics_data, dict):
        raise ValueError(f"Metrics file is not a JSON object: {metrics_file}")

    unsteered_texts: List[str] = line_data.get("unsteered_texts", [])

    # Steered texts are grouped by a string key; validate before indexing so a
    # missing section or key gives a clear error instead of a TypeError/KeyError.
    steered_by_layer = line_data.get("steered_texts", {})
    if not isinstance(steered_by_layer, dict) or args.steered_layer not in steered_by_layer:
        raise ValueError(
            f"Input file has no 'steered_texts' entry for layer key {args.steered_layer!r}"
        )
    steered_texts: List[str] = steered_by_layer[args.steered_layer]

    if not isinstance(unsteered_texts, list) or not isinstance(steered_texts, list):
        raise ValueError("Input file missing 'unsteered_texts' or 'steered_texts' lists")

    layer_metrics: Dict[str, Dict[str, Any]] = metrics_data.get("layer_metrics", {})
    metadata = metrics_data.get("metadata", {})

    # Sample indices
    unsteered_indices = sample_indices(len(unsteered_texts), args.unsteered_k)
    steered_indices = sample_indices(len(steered_texts), args.steered_k)

    # NOTE(review): the same ``layer_metrics`` arrays are indexed for both the
    # unsteered and the steered samples -- confirm they are aligned with both lists.
    samples: List[Dict[str, Any]] = []
    for sample_type, texts, indices in (
        ("unsteered", unsteered_texts, unsteered_indices),
        ("steered", steered_texts, steered_indices),
    ):
        for idx in indices:
            samples.append(
                {
                    "type": sample_type,
                    "index": idx,
                    "text": texts[idx],
                    "per_layer_metrics": collect_per_layer_metrics(layer_metrics, idx),
                }
            )

    output = {
        "exp_dir": os.path.abspath(args.exp_dir),
        "seed": args.seed,
        "metadata": metadata,
        "num_unsteered": len(unsteered_indices),
        "num_steered": len(steered_indices),
        "samples": samples,
    }

    out_path = args.output or os.path.join(args.exp_dir, "sample_prob_pairs.json")
    save_json(output, out_path)
    print(f"Saved {len(samples)} samples to: {out_path}")


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":

    main()


