"""
Utility script to compute the global variance scaling factor γ
Anonymous ICML 2026 Submission

This script computes the variance scaling factor from a validation dataset
as described in Section 3.3 of the paper.

Usage:
    python compute_gamma.py --checkpoint <path> --num_samples 10000
"""

import torch
import torch.nn.functional as F
import argparse
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from peft import PeftModel
from model import IdeaGatedModel
from data import StreamDataset


def compute_global_gamma(model, val_loader, device, num_samples=10000):
    """
    Compute the global variance scaling factor γ from a validation set.

    Paper formula: γ = E[σ(z_token)] / E[σ(z'_idea)]
    where z'_idea is the idea logits mean-centered over their last dimension.

    Each processed batch contributes one standard deviation per stream
    (taken over every element of that batch's logits tensor); the two
    expectations are the plain means of those per-batch values.

    The paper reports γ ≈ 3.42 for Mistral-7B on FineWeb-Edu.

    Args:
        model: IdeaGatedModel instance. It is called as
            ``model(x, alpha=0.0, return_s1=True)`` and must return a
            3-tuple whose second and third items are the idea and token
            logits.
        val_loader: Iterable of batches; ``batch[0]`` must be the input tensor.
        device: Device the input tensor is moved to before the forward pass.
        num_samples: Number of samples to use (default: 10000). Iteration
            stops at the first batch boundary at or past this count, so the
            last batch may overshoot it.

    Returns:
        float: Global scaling factor γ

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``num_samples <= 0``), since γ would otherwise be NaN.
    """
    print("Computing global variance scaling factor γ...")
    print(f"Target samples: {num_samples}")

    model.eval()
    token_stds = []
    idea_stds = []
    samples_processed = 0

    # Progress is reported each time another `report_every` samples have been
    # consumed. Tracking a threshold (rather than testing divisibility) keeps
    # the report working for batch sizes that do not divide `report_every`.
    report_every = 1000
    next_report = report_every

    with torch.no_grad():
        for batch in val_loader:
            if samples_processed >= num_samples:
                break

            x = batch[0].to(device)

            # alpha=0.0 requests both streams without gating.
            _, idea_logits, token_logits = model(x, alpha=0.0, return_s1=True)

            # Mean-center idea logits (as done in forward pass)
            idea_centered = idea_logits - idea_logits.mean(dim=-1, keepdim=True)

            # NOTE(review): std() is taken over the whole batch tensor, not
            # per position along the last dimension - confirm this matches
            # the definition of σ in Section 3.3 of the paper.
            token_stds.append(token_logits.std().item())
            idea_stds.append(idea_centered.std().item())

            samples_processed += x.size(0)

            if samples_processed >= next_report:
                print(f"  Processed {samples_processed}/{num_samples} samples...")
                # Jump past any thresholds a large batch stepped over.
                next_report = (samples_processed // report_every + 1) * report_every

    if not token_stds:
        raise ValueError(
            "No validation batches were processed; cannot compute γ "
            "(is the data loader empty or num_samples <= 0?)"
        )

    # Compute global scaling factor; the epsilon guards against a zero
    # idea-stream standard deviation.
    mean_token_std = np.mean(token_stds)
    mean_idea_std = np.mean(idea_stds)
    gamma = float(mean_token_std / (mean_idea_std + 1e-6))

    print("\nResults:")
    print(f"  E[σ(z_token)] = {mean_token_std:.4f}")
    print(f"  E[σ(z'_idea)] = {mean_idea_std:.4f}")
    print(f"  γ = {gamma:.4f}")
    print("\nPaper reports γ ≈ 3.42 for reference")

    return gamma


def load_model(checkpoint_path, base_model_name="mistralai/Mistral-7B-v0.1", device='cuda'):
    """
    Build an IdeaGatedModel and restore its LoRA adapters and Idea Head.

    Args:
        checkpoint_path: Location handed to ``PeftModel.from_pretrained``;
            ``idea_head.pt`` is read from ``<checkpoint_path>/idea_head.pt``.
        base_model_name: Name used for both the tokenizer and the
            IdeaGatedModel.
        device: Device passed to IdeaGatedModel and used as ``map_location``
            when loading the Idea Head weights.

    Returns:
        tuple: ``(model, tokenizer)`` with the model switched to eval mode.
    """
    print(f"Loading model from {checkpoint_path}...")

    # The tokenizer reuses its EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = IdeaGatedModel(model_name=base_model_name, device=device, alpha_max=0.5)

    # Replace the base model with its adapter-loaded counterpart.
    adapted_backbone = PeftModel.from_pretrained(model.base_model, checkpoint_path)
    model.base_model = adapted_backbone

    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects - only point this at trusted checkpoints.
    head_state = torch.load(f"{checkpoint_path}/idea_head.pt", map_location=device)
    model.idea_head.load_state_dict(head_state)

    model.eval()
    return model, tokenizer


def main():
    """
    Command-line entry point: load a checkpoint and print its γ.

    Parses the CLI options, loads the model and tokenizer, streams the
    validation data through ``compute_global_gamma`` and prints the flag to
    pass at inference time. If a ``NotImplementedError`` surfaces while the
    dataset is built or consumed, a hint about implementing ``StreamDataset``
    is printed instead.
    """
    cli = argparse.ArgumentParser(
        description='Compute global variance scaling factor γ'
    )
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--checkpoint', {'type': str, 'required': True,
                          'help': 'Path to model checkpoint'}),
        ('--base_model', {'type': str, 'default': 'mistralai/Mistral-7B-v0.1',
                          'help': 'Base model name'}),
        ('--num_samples', {'type': int, 'default': 10000,
                           'help': 'Number of validation samples (default: 10000)'}),
        ('--batch_size', {'type': int, 'default': 4,
                          'help': 'Batch size for processing'}),
        ('--block_size', {'type': int, 'default': 512,
                          'help': 'Sequence length'}),
        ('--device', {'type': str, 'default': 'cuda',
                      'help': 'Device (cuda/cpu)'}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    model, tokenizer = load_model(opts.checkpoint, opts.base_model, opts.device)

    print("\nSetting up validation data loader...")
    print("Note: You must implement StreamDataset in data.py")

    try:
        # skip_samples=7000000 is intended to select the validation split.
        dataset = StreamDataset(
            tokenizer,
            block_size=opts.block_size,
            skip_samples=7000000,
        )
        loader = DataLoader(dataset, batch_size=opts.batch_size)
        gamma = compute_global_gamma(
            model,
            loader,
            opts.device,
            num_samples=opts.num_samples,
        )
    except NotImplementedError:
        print("\nERROR: StreamDataset not implemented in data.py")
        print("Please implement the data loading interface first.")
        print("\nFor now, using default γ=3.42 from paper.")
    else:
        rule = '=' * 60
        print(f"\n{rule}")
        print(f"Use this value in inference: --gamma {gamma:.4f}")
        print(f"{rule}")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
