#!/usr/bin/env python3
"""
Main Experiment Runner

Usage:
    # Step 1: Hyperfit a model
    python run_experiments.py --mode hyperfit --model TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
    
    # Step 2: Run all experiments
    python run_experiments.py --mode experiments --original_model TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T --hyperfitted_model ./checkpoints/hyperfitted/final
    
    # Or run individual experiments
    python run_experiments.py --mode experiment1 --original_model ... --hyperfitted_model ...
"""

import os
import sys
import argparse
import json
import torch
import numpy as np
import tqdm
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import logging
from peft import PeftModel, PeftConfig
from typing import Optional

# Put this file's directory first on sys.path so the sibling modules imported below resolve
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from hyperfitting_trainer import hyperfit_model, HyperfittingDataset, DATASET_CONFIGS
from experiments import (
    Experiment1_TemperatureMatching,
    Experiment2_RankAnalysis,
    Experiment3_SyntheticHyperfitting,
    Experiment4_RepresentationAnalysis,
    save_results,
)
from metrics import DistributionAnalyzer, GenerationMetrics

# Configure logging once at import time: INFO level, with timestamp,
# logger name and level prepended to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger for this script.
logger = logging.getLogger(__name__)


def _build_quantization_config(args) -> Optional[BitsAndBytesConfig]:
    """Translate the CLI quantization flags into a bitsandbytes config.

    Args:
        args: Parsed arguments exposing ``load_in_4bit``, ``load_in_8bit``,
            ``bnb_4bit_compute_dtype``, ``bnb_4bit_quant_type`` and
            ``bnb_4bit_use_double_quant``.

    Returns:
        ``None`` when neither flag is set; otherwise a ``BitsAndBytesConfig``
        for 4-bit loading (when ``load_in_4bit`` is set) or 8-bit loading.

    Raises:
        ValueError: If both 4-bit and 8-bit loading are requested.
    """
    want_4bit = args.load_in_4bit
    if want_4bit and args.load_in_8bit:
        raise ValueError("Only one of load_in_4bit or load_in_8bit can be enabled.")
    if not (want_4bit or args.load_in_8bit):
        return None

    # Names outside this table fall back to bfloat16.
    compute_dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }.get(args.bnb_4bit_compute_dtype, torch.bfloat16)

    if not want_4bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant,
    )


def _is_peft_checkpoint(path: str) -> bool:
    """Report whether ``path`` is a directory containing PEFT adapter files.

    A directory qualifies when at least one of ``adapter_config.json``,
    ``adapter_model.safetensors`` or ``adapter_model.bin`` exists in it.
    Anything that is not an existing directory yields ``False``.
    """
    if not os.path.isdir(path):
        return False
    marker_files = (
        "adapter_config.json",
        "adapter_model.safetensors",
        "adapter_model.bin",
    )
    return any(os.path.isfile(os.path.join(path, name)) for name in marker_files)


def load_eval_data(
    num_samples: int = 300,
    context_length: int = 32,
    sequence_length: int = 256,
    seed: int = 42,
    tokenizer=None,
    dataset_names=None,
) -> tuple:
    """
    Build evaluation texts and prompts from fixed-length token windows.

    With the default arguments this matches the paper setup:
    - 3 datasets (Wikipedia, Fiction-Stories, BBC News)
    - 100 samples each
    - 256 token sequences, first 32 tokens as context

    Args:
        num_samples: Total number of sequences to collect; split evenly
            across the datasets (any remainder is dropped with a warning).
        context_length: Number of leading tokens of each sequence that are
            decoded into the prompt.
        sequence_length: Exact token length every kept sequence must have.
        seed: Seed for NumPy's global RNG and for the dataset shuffle.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``
            and offering ``decode``. Required.
        dataset_names: Keys of ``DATASET_CONFIGS`` to draw from. A single
            string is treated as one name. Defaults to
            ``["wikitext", "fiction-stories", "bbc-news"]``.

    Returns:
        ``(eval_texts, prompts)``: two parallel lists of decoded strings, the
        full sequences and their first ``context_length`` tokens.

    Raises:
        ValueError: If ``tokenizer`` is missing, ``context_length`` is not
            smaller than ``sequence_length``, ``dataset_names`` is empty,
            ``num_samples`` is smaller than the number of datasets, or a
            dataset name is not in ``DATASET_CONFIGS``.
    """
    if tokenizer is None:
        raise ValueError("tokenizer is required for evaluation data construction")
    if context_length >= sequence_length:
        raise ValueError("context_length must be smaller than sequence_length")

    if dataset_names is None:
        dataset_names = ["wikitext", "fiction-stories", "bbc-news"]
    elif isinstance(dataset_names, str):
        # A bare string would otherwise be iterated character by character.
        dataset_names = [dataset_names]
    else:
        dataset_names = list(dataset_names)

    if not dataset_names:
        # Previously surfaced as an opaque ZeroDivisionError below.
        raise ValueError("dataset_names must contain at least one dataset")
    if num_samples < len(dataset_names):
        # per_dataset would be 0, yet the collection loop only checks its
        # quota after appending, so one sample per dataset would slip through.
        raise ValueError(
            f"num_samples ({num_samples}) must be at least the number of "
            f"datasets ({len(dataset_names)})"
        )

    np.random.seed(seed)
    logger.info("Loading evaluation data (paper-matched)...")

    # Validate every name before loading anything, so a typo in a later name
    # does not surface only after earlier datasets were already processed.
    for dataset_name in dataset_names:
        if dataset_name not in DATASET_CONFIGS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(DATASET_CONFIGS.keys())}")

    per_dataset = num_samples // len(dataset_names)
    total_target = per_dataset * len(dataset_names)
    if total_target != num_samples:
        logger.warning(
            "num_samples not divisible by number of datasets; "
            f"using {total_target} total samples ({per_dataset} each)."
        )

    eval_sequences = []
    for dataset_name in dataset_names:
        config = DATASET_CONFIGS[dataset_name]
        ds_path = config["path"]
        ds_config = config["config"]
        text_field = config["text_field"]

        logger.info(f"Loading dataset for evaluation: {ds_path} (config={ds_config})")
        # NOTE(review): trust_remote_code=True lets the dataset repository run
        # its own loading code locally; confirm every configured source is trusted.
        load_args = (ds_path, ds_config) if ds_config else (ds_path,)
        ds = load_dataset(*load_args, split=config["split"], trust_remote_code=True)

        ds = ds.shuffle(seed=seed)
        collected = 0

        for item in tqdm.tqdm(ds, desc=f"Collecting eval samples from {dataset_name}"):
            text = item.get(text_field, None)
            # Skip missing, non-string, and whitespace-only entries.
            if not isinstance(text, str) or not text.strip():
                continue

            tokens = tokenizer(
                text,
                truncation=True,
                max_length=sequence_length,
                padding=False,
                add_special_tokens=True,
            )["input_ids"]

            # Truncation caps the length at sequence_length, so only texts
            # long enough to fill the whole window are kept.
            if len(tokens) != sequence_length:
                continue

            eval_sequences.append(tokens)
            collected += 1
            if collected >= per_dataset:
                break

        if collected < per_dataset:
            logger.warning(
                f"Only collected {collected}/{per_dataset} samples for {dataset_name}. "
                "Consider a smaller sequence_length or more data."
            )

    eval_texts = [
        tokenizer.decode(tokens, skip_special_tokens=True)
        for tokens in eval_sequences
    ]
    prompts = [
        tokenizer.decode(tokens[:context_length], skip_special_tokens=True)
        for tokens in eval_sequences
    ]

    logger.info(f"Loaded {len(eval_texts)} evaluation texts and {len(prompts)} prompts")
    return eval_texts, prompts


def run_hyperfitting(args):
    """Hyperfit ``args.model`` using the CLI-provided settings.

    Forwards the training, LoRA and quantization options found on ``args``
    to ``hyperfit_model`` and logs the final training loss together with
    the expected checkpoint location.

    Returns:
        The ``(model, history)`` pair returned by ``hyperfit_model``.
    """
    rule = "=" * 60
    for line in (rule, "HYPERFITTING TRAINING", rule):
        logger.info(line)

    # Options forwarded under the same keyword name they carry on ``args``.
    forwarded_options = (
        "num_samples",
        "sequence_length",
        "num_epochs",
        "learning_rate",
        "batch_size",
        "save_dir",
        "torch_dtype",
        "dataset_seed",
        "dataset_shuffle",
        "use_lora",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
        "lora_target_modules",
        "lora_target_layers",
        "lora_bias",
        "load_in_4bit",
        "load_in_8bit",
        "bnb_4bit_compute_dtype",
        "bnb_4bit_quant_type",
        "bnb_4bit_use_double_quant",
    )
    training_kwargs = {option: getattr(args, option) for option in forwarded_options}
    trained = hyperfit_model(model_name=args.model, **training_kwargs)
    model, history = trained

    logger.info("\nHyperfitting complete!")
    logger.info(f"Final training loss: {history['train_loss'][-1]:.6f}")
    logger.info(f"Model saved to: {args.save_dir}/final")

    return model, history


def load_models(args):
    """Load the original model, the hyperfitted model and their tokenizer.

    The tokenizer comes from ``args.original_model``; when it has no pad
    token, the EOS token is reused. If ``args.hyperfitted_model`` is
    detected as a PEFT adapter directory, a second copy of the base model
    is loaded and the adapter is attached to it; otherwise the path is
    loaded as a full checkpoint. All models share the same dtype, device
    map and quantization settings.

    Args:
        args: Parsed CLI arguments; reads ``torch_dtype``,
            ``original_model``, ``hyperfitted_model`` and the quantization
            flags consumed by ``_build_quantization_config``.

    Returns:
        ``(original_model, hyperfitted_model, tokenizer)``.

    Raises:
        ValueError: Propagated from ``_build_quantization_config`` when both
            4-bit and 8-bit loading are requested.
    """
    logger.info("Loading models...")

    # Determine dtype
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if args.torch_dtype not in dtype_map:
        # --torch_dtype is a free-form string, so a typo such as "fp16" would
        # otherwise change precision without any trace in the logs.
        logger.warning(
            "Unrecognized torch_dtype %r; falling back to bfloat16 (supported: %s)",
            args.torch_dtype,
            ", ".join(dtype_map),
        )
    dtype = dtype_map.get(args.torch_dtype, torch.bfloat16)

    # One kwargs dict for every from_pretrained call keeps the two models
    # under comparison on identical load settings.
    load_kwargs = {
        "torch_dtype": dtype,
        "device_map": "auto",
        "quantization_config": _build_quantization_config(args),
    }

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.original_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load original model
    logger.info(f"Loading original model: {args.original_model}")
    original_model = AutoModelForCausalLM.from_pretrained(args.original_model, **load_kwargs)

    # Load hyperfitted model
    logger.info(f"Loading hyperfitted model: {args.hyperfitted_model}")
    if not _is_peft_checkpoint(args.hyperfitted_model):
        hyperfitted_model = AutoModelForCausalLM.from_pretrained(args.hyperfitted_model, **load_kwargs)
        return original_model, hyperfitted_model, tokenizer

    peft_config = PeftConfig.from_pretrained(args.hyperfitted_model)
    base_model_name = peft_config.base_model_name_or_path or args.original_model
    logger.info(f"Detected PEFT adapter; loading base model: {base_model_name}")
    if base_model_name != args.original_model:
        # The adapter's recorded base takes precedence; flag the mismatch
        # because the comparison baseline is still args.original_model.
        logger.warning(
            "Adapter base model %r differs from original_model %r",
            base_model_name,
            args.original_model,
        )
    # A separate base instance is loaded here; original_model is not reused
    # as the adapter's base.
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, **load_kwargs)
    hyperfitted_model = PeftModel.from_pretrained(
        base_model,
        args.hyperfitted_model,
        is_trainable=False,
    )

    return original_model, hyperfitted_model, tokenizer


def run_experiment1(original_model, hyperfitted_model, tokenizer, eval_texts, prompts, output_dir):
    """Execute Experiment 1 (temperature matching) and persist its result.

    Only the first 50 evaluation texts and the first 30 prompts are handed
    to the experiment, with ``max_new_tokens=224``. The result is written
    via ``save_results`` to ``output_dir`` and then returned.
    """
    runner = Experiment1_TemperatureMatching(
        original_model=original_model,
        hyperfitted_model=hyperfitted_model,
        tokenizer=tokenizer,
    )

    # Subsets keep the runtime manageable.
    text_subset = eval_texts[:50]
    prompt_subset = prompts[:30]
    outcome = runner.run(eval_texts=text_subset, prompts=prompt_subset, max_new_tokens=224)

    save_results(outcome, output_dir)
    return outcome


def run_experiment2(original_model, hyperfitted_model, tokenizer, eval_texts, output_dir):
    """Execute Experiment 2 (rank analysis) and persist its result.

    Only the first 30 evaluation texts are handed to the experiment, with
    ``max_seq_length=256``. The result is written via ``save_results`` to
    ``output_dir`` and then returned.
    """
    runner = Experiment2_RankAnalysis(
        original_model=original_model,
        hyperfitted_model=hyperfitted_model,
        tokenizer=tokenizer,
    )

    # A subset keeps the runtime manageable.
    text_subset = eval_texts[:30]
    outcome = runner.run(eval_texts=text_subset, max_seq_length=256)

    save_results(outcome, output_dir)
    return outcome


def run_experiment3(original_model, hyperfitted_model, tokenizer, eval_texts, prompts, output_dir):
    """Execute Experiment 3 (synthetic hyperfitting) and persist its result.

    The first 30 evaluation texts serve as calibration texts and the first
    20 prompts as test prompts, with ``max_new_tokens=224`` and five scale
    values from 0.01 to 0.5. The result is written via ``save_results`` to
    ``output_dir`` and then returned.
    """
    runner = Experiment3_SyntheticHyperfitting(
        original_model=original_model,
        hyperfitted_model=hyperfitted_model,
        tokenizer=tokenizer,
    )

    calibration_subset = eval_texts[:30]
    prompt_subset = prompts[:20]
    scale_grid = [0.01, 0.05, 0.1, 0.2, 0.5]
    outcome = runner.run(
        calibration_texts=calibration_subset,
        test_prompts=prompt_subset,
        max_new_tokens=224,
        scales=scale_grid,
    )

    save_results(outcome, output_dir)
    return outcome


def run_experiment4(original_model, hyperfitted_model, tokenizer, eval_texts, output_dir):
    """Execute Experiment 4 (representation analysis) and persist its result.

    Only the first 20 evaluation texts are handed to the experiment, with
    ``max_seq_length=256``. The result is written via ``save_results`` to
    ``output_dir`` and then returned.
    """
    runner = Experiment4_RepresentationAnalysis(
        original_model=original_model,
        hyperfitted_model=hyperfitted_model,
        tokenizer=tokenizer,
    )

    # A subset keeps the runtime manageable.
    text_subset = eval_texts[:20]
    outcome = runner.run(eval_texts=text_subset, max_seq_length=256)

    save_results(outcome, output_dir)
    return outcome


def _log_banner(title: str) -> None:
    """Log ``title`` framed by two 60-character rule lines."""
    rule = "=" * 60
    logger.info("\n" + rule)
    logger.info(title)
    logger.info(rule)


def _log_key_findings(result1, result2, result3) -> None:
    """Log the headline numbers of experiments 1-3, skipping absent sections."""
    _log_banner("KEY FINDINGS")

    # Experiment 1 findings
    if "generation_comparison" in result1.results:
        comp = result1.results["generation_comparison"]["aggregated"]
        logger.info("\nExperiment 1 (Temperature Matching):")
        logger.info(f"  Matched temperature: {result1.results['matched_temperature']:.4f}")
        logger.info(f"  Original model TTR: {comp['original_greedy']['mean_ttr']:.4f}")
        logger.info(f"  Original + matched temp TTR: {comp['original_matched_temp']['mean_ttr']:.4f}")
        logger.info(f"  Hyperfitted model TTR: {comp['hyperfitted_greedy']['mean_ttr']:.4f}")

        # Key conclusion: hyperfitted TTR must beat the temperature-matched
        # original by more than 0.05 to count as a distinct effect.
        if comp['hyperfitted_greedy']['mean_ttr'] > comp['original_matched_temp']['mean_ttr'] + 0.05:
            logger.info("  → CONCLUSION: Hyperfitting ≠ Temperature scaling!")
        else:
            logger.info("  → CONCLUSION: Results inconclusive, need more investigation")

    # Experiment 2 findings
    if "top1_comparison" in result2.results:
        top1 = result2.results["top1_comparison"]
        logger.info("\nExperiment 2 (Rank Analysis):")
        logger.info(f"  Top-1 agreement rate: {top1['top1_agreement']:.4f}")
        logger.info(f"  Hyperfitted top-1 in original top-10: {top1['hyper_top1_in_orig_top10']:.4f}")

        if top1['top1_agreement'] < 0.8:
            logger.info("  → CONCLUSION: Significant rank changes detected!")

    # Experiment 3 findings
    if "baselines" in result3.results:
        baselines = result3.results["baselines"]
        original_ttr = baselines['original']['mean_ttr']
        hyperfitted_ttr = baselines['hyperfitted']['mean_ttr']

        logger.info("\nExperiment 3 (Synthetic Hyperfitting):")
        logger.info(f"  Original TTR: {original_ttr:.4f}")
        logger.info(f"  Hyperfitted TTR: {hyperfitted_ttr:.4f}")

        # Guard against a missing or empty sweep; max() on it would raise.
        synthetic = result3.results.get("synthetic_by_scale") or {}
        if synthetic:
            best_scale = max(synthetic, key=lambda s: synthetic[s]["mean_ttr"])
            best_ttr = synthetic[best_scale]['mean_ttr']
            logger.info(f"  Best synthetic TTR (scale={best_scale}): {best_ttr:.4f}")

            # An additive epsilon in the denominator is sign-sensitive: a gap
            # of about -1e-6 would cancel to zero. Test the magnitude instead.
            gap = hyperfitted_ttr - original_ttr
            if abs(gap) < 1e-6:
                logger.info("  → Hyperfitted and original TTR are equal; relative improvement is undefined")
            else:
                improvement = (best_ttr - original_ttr) / gap
                logger.info(f"  → Synthetic achieves {improvement*100:.1f}% of hyperfitting improvement")


def run_all_experiments(args):
    """Run experiments 1-4 in sequence inside a fresh timestamped directory.

    Creates ``<args.output_dir>/run_<timestamp>``, writes the run
    configuration to ``config.json`` there, loads both models and the
    evaluation data, runs the four experiments, and logs a findings summary.

    Args:
        args: Parsed CLI arguments. This function does not modify them.

    Returns:
        Dict mapping ``"experiment1"`` .. ``"experiment4"`` to the result
        returned by the corresponding ``run_experimentN`` function.
    """
    # Create output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    logger.info(f"Output directory: {output_dir}")

    # Save configuration before the expensive loading steps, so a run that
    # fails later still leaves a record of its settings. vars(args) is the
    # namespace's own __dict__; copy it so that adding the "model" key does
    # not overwrite args.model.
    config = dict(vars(args))
    config["model"] = args.original_model
    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Load models
    original_model, hyperfitted_model, tokenizer = load_models(args)

    # Load evaluation data
    eval_texts, prompts = load_eval_data(
        num_samples=args.num_eval_samples,
        context_length=32,
        sequence_length=args.sequence_length,
        tokenizer=tokenizer,
    )

    # Run experiments
    _log_banner("RUNNING EXPERIMENT 1: TEMPERATURE MATCHING")
    result1 = run_experiment1(original_model, hyperfitted_model, tokenizer, eval_texts, prompts, output_dir)

    _log_banner("RUNNING EXPERIMENT 2: RANK ANALYSIS")
    result2 = run_experiment2(original_model, hyperfitted_model, tokenizer, eval_texts, output_dir)

    _log_banner("RUNNING EXPERIMENT 3: SYNTHETIC HYPERFITTING")
    result3 = run_experiment3(original_model, hyperfitted_model, tokenizer, eval_texts, prompts, output_dir)

    _log_banner("RUNNING EXPERIMENT 4: REPRESENTATION ANALYSIS")
    result4 = run_experiment4(original_model, hyperfitted_model, tokenizer, eval_texts, output_dir)

    # Final summary
    _log_banner("ALL EXPERIMENTS COMPLETE")
    logger.info(f"\nResults saved to: {output_dir}")

    # Print key findings
    _log_key_findings(result1, result2, result3)

    return {
        "experiment1": result1,
        "experiment2": result2,
        "experiment3": result3,
        "experiment4": result4,
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every option this script accepts."""
    parser = argparse.ArgumentParser(description="Hyperfitting Analysis Experiments")

    # Mode selection
    parser.add_argument(
        "--mode",
        type=str,
        choices=["hyperfit", "experiments", "experiment1", "experiment2", "experiment3", "experiment4", "all"],
        default="all",
        help="What to run"
    )

    # Model arguments
    parser.add_argument(
        "--model",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        help="Model to hyperfit (for hyperfit mode)"
    )
    parser.add_argument(
        "--original_model",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        help="Original model path (for experiments mode)"
    )
    parser.add_argument(
        "--hyperfitted_model",
        type=str,
        default="./checkpoints/hyperfitted/final",
        help="Hyperfitted model path (for experiments mode)"
    )

    # Training arguments
    parser.add_argument("--num_samples", type=int, default=2000)
    parser.add_argument("--sequence_length", type=int, default=256)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--save_dir", type=str, default="./checkpoints/hyperfitted")
    parser.add_argument("--dataset_seed", type=int, default=42)
    # Shuffling is on by default; --no_dataset_shuffle is the switch that changes it.
    parser.add_argument("--dataset_shuffle", action="store_true", default=True)
    parser.add_argument("--no_dataset_shuffle", action="store_false", dest="dataset_shuffle")

    # LoRA / QLoRA options (for hyperfit mode)
    parser.add_argument("--use_lora", action="store_true", default=False)
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--lora_target_modules", type=str,
                        default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj")
    parser.add_argument("--lora_target_layers", type=str, default=None)
    parser.add_argument("--lora_bias", type=str, default="none",
                        choices=["none", "all", "lora_only"])

    # Quantization (for large models / QLoRA)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--bnb_4bit_compute_dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--bnb_4bit_quant_type", type=str, default="nf4",
                        choices=["nf4", "fp4"])
    # Double quantization is on by default; the --no_ variant disables it.
    parser.add_argument("--bnb_4bit_use_double_quant", action="store_true", default=True)
    parser.add_argument("--no_bnb_4bit_use_double_quant", action="store_false",
                        dest="bnb_4bit_use_double_quant")

    # Evaluation arguments
    parser.add_argument("--num_eval_samples", type=int, default=300)
    parser.add_argument("--output_dir", type=str, default="./results")

    # Hardware arguments
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--device", type=str, default="cuda")

    return parser


def main(argv: Optional[list] = None):
    """Parse command-line options and run the selected mode.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so calling ``main()`` with no
            arguments behaves as before.

    Modes:
        ``hyperfit`` runs training; ``experiments`` and ``all`` both run the
        full experiment suite; ``experiment1`` .. ``experiment4`` run a
        single experiment and save into ``--output_dir`` directly.
    """
    args = _build_arg_parser().parse_args(argv)

    # Run based on mode
    if args.mode == "hyperfit":
        run_hyperfitting(args)
        return

    if args.mode in ("experiments", "all"):
        run_all_experiments(args)
        return

    # Single-experiment modes share the same setup and differ only in the
    # runner and in whether that runner takes the prompts.
    single_runs = {
        "experiment1": (run_experiment1, True),
        "experiment2": (run_experiment2, False),
        "experiment3": (run_experiment3, True),
        "experiment4": (run_experiment4, False),
    }
    runner, takes_prompts = single_runs[args.mode]

    # NOTE(review): save_results is defined elsewhere and may already create
    # the directory; creating it here is harmless and avoids discovering a
    # missing path only after the experiment has finished.
    os.makedirs(args.output_dir, exist_ok=True)

    original_model, hyperfitted_model, tokenizer = load_models(args)
    eval_texts, prompts = load_eval_data(
        num_samples=args.num_eval_samples,
        context_length=32,
        sequence_length=args.sequence_length,
        tokenizer=tokenizer,
    )

    if takes_prompts:
        runner(original_model, hyperfitted_model, tokenizer, eval_texts, prompts, args.output_dir)
    else:
        runner(original_model, hyperfitted_model, tokenizer, eval_texts, args.output_dir)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
