#!/usr/bin/env python
# coding=utf-8
"""
Precompute average alignment metric ratios for each module _type_ (e.g. q_proj, k_proj, ...)
and write the sorted (increasing) list to disk.

The script is meant for exploratory analysis / debugging of the layer-selection logic
implemented in `train.select_layers_by_ratio`.  Concretely it will

1.  Load a model + tokenizer given by `--model_name_or_path`.
2.  Build a (tokenised) dataset using `dataset_transformation.get_cached_dataset_tulu` – the same
    helper that the training script relies on.  You specify the dataset via
    `--dataset_mixer_list` (e.g.  "allenai/tulu-3-sft-mixture 1.0").
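3.  Run `train.compute_and_save_alignment_metrics` to obtain the per-module statistics
    (saved under `<standardized_dir>/raw_metrics/<dataset_key>_metrics.json`, where
    `<standardized_dir>` is `train.get_standardized_output_dir(<model_name_or_path>)`
    regardless of `--output_dir`). This step is skipped if that file already exists.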
4.  Aggregate those per-module stats **per module suffix** and compute the average
    `actual/random` ratio for each suffix (a module whose `random` score is zero or
    missing contributes `inf`).
5.  Save the resulting list, sorted in *increasing* order, to
    `<output_dir>/module_type_avg_ratios_increasing.json`.

Example:
    python -m src.precompute_module_type_ratios \
        --model_name_or_path meta-llama/Llama-3.2-1B \
        --dataset_mixer_list allenai/tulu-3-sft-mixture 1.0
"""
from __future__ import annotations

import argparse
import json
import os
import math
from typing import List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import datasets

from dataset_transformation import (
    TokenizerConfig,
    get_cached_dataset_tulu,
    TOKENIZED_SFT_DATASET_KEYS,
)

# `compute_and_save_alignment_metrics` + helper (for naming the output directory)
# live inside the training script and are imported directly from there to avoid code dup.
from train import compute_and_save_alignment_metrics, get_standardized_output_dir

from accelerate import PartialState  # Needed for accelerate logging utilities


# Module-name suffixes aggregated by `compute_suffix_avg_ratios` in `main`.
# The order is deliberate: the final sort is stable, so suffixes with equal
# average ratios appear in this order in the output JSON.
DEFAULT_TARGET_MODULE_TYPES = "q_proj o_proj v_proj k_proj gate_proj up_proj down_proj".split()


# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------

def build_dataset(
    dataset_mixer_list: List[str],
    dataset_mixer_list_splits: List[str],
    tokenizer_name_or_path: str,
    dataset_transform_fn: List[str] | None = None,
    max_seq_length: int | None = None,
    seed: int = 42,
) -> Tuple["datasets.Dataset", "TokenizerConfig"]:
    """Download / tokenise the dataset using the same utility as the training script.

    Args:
        dataset_mixer_list: Interleaved ``[name, frac_or_num, ...]`` list forwarded
            to ``get_cached_dataset_tulu``.
        dataset_mixer_list_splits: Split per dataset. A single entry is repeated
            once per ``(name, frac_or_num)`` pair in ``dataset_mixer_list``.
        tokenizer_name_or_path: Tokenizer used to build the ``TokenizerConfig``.
        dataset_transform_fn: Transform function names. When ``None`` they are
            chosen heuristically: the MetaMathQA transforms if any mixer entry
            contains "MetaMathQA", otherwise the Tulu SFT transforms.
        max_seq_length: Passed as ``max_seq_length`` to the first transform fn.
            When the MetaMathQA transforms are picked automatically and this is
            ``None``, it defaults to 1024; otherwise it is passed through as-is.
        seed: Seed for the shuffle applied after loading.

    Returns:
        ``(dataset, tc)``: the shuffled dataset with ``set_format(type="pt")``
        applied, and the ``TokenizerConfig`` used to build it.
    """
    # Minimal `TokenizerConfig` – we do *not* need the fancy options used in train.py
    tc = TokenizerConfig(
        tokenizer_name_or_path=tokenizer_name_or_path,
        use_fast=True,
    )

    if dataset_transform_fn is None:
        # Heuristic: if MetaMathQA in dataset names use metamathqa transforms else Tulu ones
        if any("MetaMathQA" in name for name in dataset_mixer_list):
            dataset_transform_fn = [
                "sft_metamathqa_tokenize_and_truncate_v1",
                "sft_metamathqa_filter_v1",
            ]
            if max_seq_length is None:
                max_seq_length = 1024
        else:
            dataset_transform_fn = [
                "sft_tulu_tokenize_and_truncate_v1",
                "sft_tulu_filter_v1",
            ]

    # Exactly one kwargs dict per transform fn. `max_seq_length` goes to the first
    # fn (the tokenize/truncate step in both default lists); later fns get no
    # extra kwargs. A fixed two-element list would not match a user-supplied
    # `--dataset_transform_fn` of a different length.
    transform_fn_args = [
        {"max_seq_length": max_seq_length} if idx == 0 else {}
        for idx in range(len(dataset_transform_fn))
    ]

    # If the user only provided one split, repeat it for each (name, frac) pair.
    if len(dataset_mixer_list_splits) == 1:
        dataset_mixer_list_splits = dataset_mixer_list_splits * (len(dataset_mixer_list) // 2)

    dataset = get_cached_dataset_tulu(
        dataset_mixer_list=dataset_mixer_list,
        dataset_mixer_list_splits=dataset_mixer_list_splits,
        tc=tc,
        dataset_transform_fn=dataset_transform_fn,
        transform_fn_args=transform_fn_args,
        target_columns=TOKENIZED_SFT_DATASET_KEYS,
        dataset_skip_cache=False,
    )

    # Shuffle for some randomisation and set PT format (needed for alignment metrics helper)
    dataset = dataset.shuffle(seed=seed)
    dataset.set_format(type="pt")
    return dataset, tc


def load_model_and_tokenizer(model_name_or_path: str):
    """Load a causal LM and its tokenizer, with the model in eval mode.

    The model is loaded in bfloat16. When CUDA is available it is placed with
    ``device_map="auto"``, which may spread it over every visible GPU rather
    than just the first one. Otherwise no device map is used, so it stays on CPU.

    The tokenizer is guaranteed to have a pad token. EOS is reused when present.
    Otherwise a new ``[PAD]`` token is added, and the model's input embeddings
    are grown if they are too small to contain the new token id.

    Returns:
        ``(model, tokenizer)``
    """
    # We keep things simple here; no qlora, no flash-attn toggle, etc.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=True,
        trust_remote_code=False,
    )

    # Ensure the tokenizer has a pad token; otherwise default to eos.
    added_pad_token = False
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # Fall back to adding a new pad token (rare for modern LLM tokenizers).
            # `add_special_tokens` returns the number of tokens actually added.
            added_pad_token = tokenizer.add_special_tokens({"pad_token": "[PAD]"}) > 0

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="auto" if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
    )

    # A freshly added token may get an id beyond the embedding matrix. Without
    # resizing, padded batches would index out of range. Only grow, never shrink:
    # some checkpoints already carry spare embedding rows.
    if added_pad_token and len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    model.eval()
    return model, tokenizer


def compute_suffix_avg_ratios(
    metrics: dict,
    target_suffixes: List[str],
) -> List[Tuple[str, float]]:
    """Average the ``actual / random`` ratio over all modules sharing a name suffix.

    Args:
        metrics: Mapping ``module_name -> scores``. Each ``scores`` mapping is read
            via ``.get("actual", 0.0)`` and ``.get("random", 0.0)``.
        target_suffixes: Suffixes to aggregate over, matched with ``str.endswith``.

    Returns:
        ``(suffix, mean_ratio)`` pairs sorted by ``mean_ratio`` in increasing
        order. The sort is stable, so ties keep ``target_suffixes`` order.
        Suffixes that match no module are omitted. A module whose ``random``
        score is missing or zero contributes ``math.inf``.
    """

    def _module_ratio(scores) -> float:
        denominator = scores.get("random", 0.0)
        if not denominator:
            return math.inf
        return scores.get("actual", 0.0) / denominator

    per_suffix = []
    for suffix in target_suffixes:
        # NOTE(review): plain `endswith` also matches fused module names
        # (e.g. "gate_up_proj" for "up_proj", "qkv_proj" for "v_proj").
        module_ratios = [
            _module_ratio(scores)
            for module_name, scores in metrics.items()
            if module_name.endswith(suffix)
        ]
        if not module_ratios:
            continue  # no module carries this suffix -> leave it out
        per_suffix.append((suffix, sum(module_ratios) / len(module_ratios)))

    # `sorted` is stable, so suffixes with equal averages keep their input order.
    return sorted(per_suffix, key=lambda pair: pair[1])


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(description="Precompute module-type average ratios (alignment metrics).")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="HF model name or local path.",
    )
    parser.add_argument(
        "--dataset_mixer_list",
        type=str,
        nargs="+",
        required=True,
        help="Interleaved list like:  <dataset1> <frac1/num1> <dataset2> <frac2> ...",
    )
    parser.add_argument(
        "--dataset_mixer_list_splits",
        type=str,
        nargs="+",
        default=["train"],
        help="Splits corresponding to the datasets (defaults to 'train').",
    )
    parser.add_argument("--sample_size", type=int, default=100)
    parser.add_argument("--max_seq_length", type=int, default=None)
    parser.add_argument(
        "--dataset_transform_fn",
        type=str,
        nargs="+",
        default=None,
        help="Override dataset transform function list if provided.",
    )
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to place the metric outputs (defaults to a name based on the model).",
    )
    return parser.parse_args()


def _dataset_key_from_mixer(dataset_mixer_list: List[str]) -> str:
    """Return the key used to name the raw-metrics file (same logic as train.py).

    Entries that are numbers once dots are removed (fractions / sample counts)
    are dropped. The remaining dataset names are reduced to their last ``/``
    component, and only the first one is kept. ``"train_dataset"`` is the
    fallback when no name is left.
    """
    dataset_names = [
        name.split("/")[-1]
        for name in dataset_mixer_list
        if not name.replace(".", "").isdigit()
    ]
    return "_".join(dataset_names[:1]) or "train_dataset"


def main():
    """Ensure raw alignment metrics exist, then write the per-module-type summary.

    Raw metrics are always read from, and if missing written to,
    ``<standardized_dir>/raw_metrics/<dataset_key>_metrics.json``. Here
    ``<standardized_dir>`` is ``get_standardized_output_dir(model_name_or_path)``,
    independent of ``--output_dir``. The model and dataset are only loaded when
    that file is missing. Only the sorted summary JSON is written to the output
    directory.
    """
    args = parse_args()

    # Initialize accelerate PartialState so that accelerate.logging works without an Accelerator instance
    PartialState()

    output_dir = args.output_dir or get_standardized_output_dir(args.model_name_or_path)
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Raw metrics always live in the *standard* alignment directory.
    # ------------------------------------------------------------------
    dataset_key = _dataset_key_from_mixer(args.dataset_mixer_list)
    alignment_dir = get_standardized_output_dir(args.model_name_or_path)
    metrics_file = os.path.join(alignment_dir, "raw_metrics", f"{dataset_key}_metrics.json")

    if not os.path.exists(metrics_file):
        # The model and tokenised dataset are only needed to compute the metrics,
        # so skip loading them entirely when the metrics file is already cached.
        model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
        dataset, _tc = build_dataset(
            dataset_mixer_list=args.dataset_mixer_list,
            dataset_mixer_list_splits=args.dataset_mixer_list_splits,
            tokenizer_name_or_path=args.model_name_or_path,
            dataset_transform_fn=args.dataset_transform_fn,
            max_seq_length=args.max_seq_length,
        )
        compute_and_save_alignment_metrics(
            model,
            tokenizer,
            dataset,
            dataset_key,
            alignment_dir,
            sample_size=args.sample_size,
            max_length=args.max_length,
        )

    # ------------------------------------------------------------------
    # Load metrics and aggregate per module type
    # ------------------------------------------------------------------
    with open(metrics_file, "r", encoding="utf-8") as f:
        metrics = json.load(f)

    avg_ratios = compute_suffix_avg_ratios(metrics, DEFAULT_TARGET_MODULE_TYPES)

    # Save to disk (only the final summary JSON inside `output_dir`)
    ratios_path = os.path.join(output_dir, "module_type_avg_ratios_increasing.json")
    with open(ratios_path, "w", encoding="utf-8") as f:
        json.dump(avg_ratios, f, indent=2)
    print(f"✅  Saved sorted module-type avg ratios to {ratios_path}")


if __name__ == "__main__":
    # Run only when executed as a script (or via `python -m`), not on import.
    main()