#!/usr/bin/env python3
"""
Combine two vectorized dataset directories (each containing prompt_embeddings.pt and
response_sae_repr.pt): concatenate them row-wise and apply a single shared shuffle so that
each prompt row stays aligned with its response row.

Usage example:
  python scripts/combine_vectorized_embeddings.py \
    --dataset-a \
      /data/home/Yunsheng/alignment-handbook/datasets/vectorized_last_alpaca_dataset_sae_llama3b_layers_14 \
    --dataset-b \
      /data/home/Yunsheng/alignment-handbook/datasets/vectorized_smoltalk_single_round_sae_llama3b_layers_14 \
    --output-dir \
      /data/home/Yunsheng/alignment-handbook/datasets/combined_alpaca_smoltalk_sae_llama3b_layers_14 \
    --seed 42
"""

import argparse
import json
from pathlib import Path
from typing import Dict, Any

import torch


def load_pair(directory: Path):
    """Load the row-aligned prompt/response tensors and optional metadata from a directory.

    Args:
        directory: Directory expected to contain ``prompt_embeddings.pt`` and
            ``response_sae_repr.pt``. A ``metadata.json`` file is optional.

    Returns:
        A ``(prompt, response, meta)`` tuple. Both tensors are loaded onto the
        CPU. ``meta`` is the parsed ``metadata.json`` object, an empty dict when
        the file is absent, or a single-key ``{"warning": ...}`` dict when the
        file cannot be read/parsed or does not hold a JSON object.

    Raises:
        FileNotFoundError: If either tensor file is missing.
        TypeError: If either loaded object is not a ``torch.Tensor``.
        ValueError: If either tensor has no leading (row) dimension, or the two
            tensors have a different number of rows.
    """
    prompt_path = directory / "prompt_embeddings.pt"
    response_path = directory / "response_sae_repr.pt"

    # Check both files up front (prompt first) so the error names the missing path.
    for path in (prompt_path, response_path):
        if not path.exists():
            raise FileNotFoundError(f"Missing file: {path}")

    # NOTE(security): torch.load unpickles its input; only point this script at
    # dataset directories you trust.
    prompt = torch.load(prompt_path, map_location="cpu")
    response = torch.load(response_path, map_location="cpu")

    if not isinstance(prompt, torch.Tensor) or not isinstance(response, torch.Tensor):
        raise TypeError("Loaded objects must be torch.Tensor")
    # size(0) on a 0-dim tensor raises an opaque IndexError; report it clearly instead.
    if prompt.dim() == 0 or response.dim() == 0:
        raise ValueError(
            f"Tensors in {directory} must have a leading row dimension: "
            f"prompt.dim()={prompt.dim()} response.dim()={response.dim()}"
        )
    if prompt.size(0) != response.size(0):
        raise ValueError(
            f"Mismatched number of rows in {directory}: prompt={prompt.size(0)} response={response.size(0)}"
        )

    # Metadata is best-effort: a missing or broken file must never block loading.
    meta_path = directory / "metadata.json"
    meta: Dict[str, Any] = {}
    if meta_path.exists():
        try:
            parsed = json.loads(meta_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            meta = {"warning": "failed_to_parse_metadata_json"}
        else:
            if isinstance(parsed, dict):
                meta = parsed
            else:
                # Valid JSON but not an object (e.g. a list or null): keep the dict contract.
                meta = {"warning": "metadata_json_is_not_an_object"}

    return prompt, response, meta


def main(argv=None):
    """Combine two vectorized datasets and shuffle them with one shared permutation.

    Loads the prompt/response tensors from ``--dataset-a`` and ``--dataset-b``
    via ``load_pair``, concatenates each along dim 0, applies one seeded
    permutation to both, and writes ``prompt_embeddings.pt``,
    ``response_sae_repr.pt``, ``metadata.json`` and ``shuffle_permutation.pt``
    into ``--output-dir``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so the script entry point
            behaves as before.

    Raises:
        ValueError: If ``--output-dir`` resolves to one of the input
            directories (the inputs would be overwritten), or if the
            concatenated tensors end up with different row counts.

    Errors raised by ``load_pair`` and ``torch.cat`` propagate unchanged.
    """
    parser = argparse.ArgumentParser(description="Combine and jointly shuffle vectorized datasets")
    parser.add_argument("--dataset-a", type=str, required=True, help="Path to first vectorized dataset dir")
    parser.add_argument("--dataset-b", type=str, required=True, help="Path to second vectorized dataset dir")
    parser.add_argument("--output-dir", type=str, required=True, help="Directory to save combined tensors")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling")
    args = parser.parse_args(argv)

    dir_a = Path(args.dataset_a)
    dir_b = Path(args.dataset_b)
    out_dir = Path(args.output_dir)

    # The outputs reuse the input file names, so writing into an input directory
    # would silently destroy that dataset. Refuse before touching the filesystem.
    if out_dir.resolve() in (dir_a.resolve(), dir_b.resolve()):
        raise ValueError(
            f"--output-dir must differ from the input dataset directories "
            f"(inputs would be overwritten): {out_dir}"
        )

    # Load (source metadata is not part of the compact summary written below)
    prompt_a, response_a, _ = load_pair(dir_a)
    prompt_b, response_b, _ = load_pair(dir_b)
    rows_a = int(prompt_a.size(0))
    rows_b = int(prompt_b.size(0))

    # Concatenate along batch dimension
    prompt_all = torch.cat([prompt_a, prompt_b], dim=0)
    response_all = torch.cat([response_a, response_b], dim=0)
    # Drop the per-dataset tensors now so they do not inflate peak memory during the shuffle.
    del prompt_a, prompt_b, response_a, response_b

    if prompt_all.size(0) != response_all.size(0):
        raise ValueError("Concatenated prompt and response have different number of rows")

    # Create a single shared permutation from a dedicated, seeded generator
    num_rows = prompt_all.size(0)
    generator = torch.Generator()
    generator.manual_seed(args.seed)
    permutation = torch.randperm(num_rows, generator=generator)

    # Apply the same permutation to both tensors to keep alignment
    prompt_shuffled = prompt_all.index_select(0, permutation)
    response_shuffled = response_all.index_select(0, permutation)

    # Create the output directory only once the inputs have loaded successfully,
    # so a failed run does not leave an empty directory behind.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save outputs
    torch.save(prompt_shuffled, out_dir / "prompt_embeddings.pt")
    torch.save(response_shuffled, out_dir / "response_sae_repr.pt")

    # Save a compact metadata summary (optional)
    combined_meta = {
        "sources": {
            "dataset_a": str(dir_a),
            "dataset_b": str(dir_b),
        },
        "num_rows_a": rows_a,
        "num_rows_b": rows_b,
        "num_rows_total": int(num_rows),
        "prompt_shape": list(prompt_shuffled.size()),
        "response_shape": list(response_shuffled.size()),
        "seed": args.seed,
        "note": "prompt and response were shuffled with the SAME permutation to preserve alignment",
    }
    (out_dir / "metadata.json").write_text(json.dumps(combined_meta, ensure_ascii=False, indent=2), encoding="utf-8")

    # Save the permutation indices for reproducibility/audit
    torch.save(permutation, out_dir / "shuffle_permutation.pt")

    print(f"✅ Combined and shuffled tensors saved to: {out_dir}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
