"""
Inference Demo Script
Anonymous ICML 2026 Submission

Demonstrates how to load and use a trained IdeaGatedModel for text generation
with variance alignment as described in the paper.
"""

import torch
import torch.nn.functional as F
import argparse
import math
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from peft import PeftModel
from model import IdeaGatedModel


def load_model(checkpoint_path, base_model_name="mistralai/Mistral-7B-v0.1", device='cuda'):
    """
    Loads a trained IdeaGatedModel and its tokenizer from a checkpoint directory.

    The tokenizer is read from ``checkpoint_path`` when possible and from
    ``base_model_name`` otherwise. The model is constructed from
    ``base_model_name``, after which the LoRA adapters and the idea-head
    weights stored under ``checkpoint_path`` are loaded into it.

    Args:
        checkpoint_path (str): Path to checkpoint directory containing the LoRA
            adapters and an ``idea_head.pt`` state dict
        base_model_name (str): HuggingFace model ID for base model (default: Mistral-7B)
        device (str): Device to load model on

    Returns:
        model: Loaded IdeaGatedModel, switched to eval mode
        tokenizer: Associated tokenizer, with the pad token set to the EOS token
    """
    print(f"Loading model from {checkpoint_path}...")
    print(f"Base model: {base_model_name}")

    # Prefer the tokenizer saved with the checkpoint; fall back to the base
    # model's tokenizer. Only load failures are caught here, so that Ctrl-C
    # (KeyboardInterrupt) and SystemExit are no longer swallowed.
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    except (OSError, ValueError):
        print(f"Tokenizer not found in checkpoint, loading from {base_model_name}")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Build the architecture from the base model name (not the checkpoint path).
    # NOTE(review): presumably IdeaGatedModel.__init__ loads the base weights
    # from HuggingFace; confirm it does not already wrap the base model in LoRA,
    # otherwise the adapter load below would wrap it a second time.
    print("Initializing model architecture...")
    model = IdeaGatedModel(
        model_name=base_model_name,
        device=device,
        alpha_max=0.5
    )

    # Attach the trained LoRA adapters stored in the checkpoint directory.
    print(f"Loading LoRA adapters from {checkpoint_path}...")
    model.base_model = PeftModel.from_pretrained(model.base_model, checkpoint_path)

    # Restore the idea head weights.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources (consider weights_only=True if the installed torch
    # version supports it).
    idea_head_path = f"{checkpoint_path}/idea_head.pt"
    print(f"Loading Idea Head from {idea_head_path}...")
    model.idea_head.load_state_dict(torch.load(idea_head_path, map_location=device))

    model.eval()
    print("Model loaded successfully!")
    return model, tokenizer


def compute_global_gamma(model, val_loader, device, num_samples=10000):
    """
    Computes the global variance scaling factor γ from validation set.

    Paper formula: γ = E[σ(z_token)] / E[σ(z'_idea)]
    where z'_idea is mean-centered.

    One standard deviation per batch is taken over all elements of the token
    logits and of the mean-centered idea logits; γ is the ratio of the two
    per-batch averages. Whole batches are consumed until at least
    ``num_samples`` samples have been seen, so the last batch may overshoot.

    Args:
        model: IdeaGatedModel instance; called as
            ``model(x, alpha=0.0, return_s1=True)`` and expected to return a
            3-tuple whose last two items are the idea and token logits
        val_loader: Validation data loader; the input tensor is read from
            index 0 of each batch
        device: Device
        num_samples: Number of samples to use (default: 10000)

    Returns:
        float: Global scaling factor (paper reports γ ≈ 3.42)

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``num_samples <= 0``), since γ would otherwise be NaN.
    """
    print(f"Computing global variance scaling factor γ from {num_samples} validation samples...")
    model.eval()
    token_stds = []
    idea_stds = []

    samples_processed = 0

    with torch.no_grad():
        for batch in val_loader:
            if samples_processed >= num_samples:
                break

            x = batch[0].to(device)

            # Get both streams (no gating)
            _, idea_logits, token_logits = model(x, alpha=0.0, return_s1=True)

            # Mean-center idea logits along the last dimension
            idea_centered = idea_logits - idea_logits.mean(dim=-1, keepdim=True)

            # Collect one standard deviation per batch for each stream
            token_stds.append(token_logits.std().item())
            idea_stds.append(idea_centered.std().item())

            samples_processed += x.size(0)

    # Fail loudly instead of returning NaN (mean of an empty list) as the scale.
    if not token_stds:
        raise ValueError(
            "Cannot compute gamma: no validation batches were processed "
            "(empty val_loader or num_samples <= 0)."
        )

    # Ratio of average spreads; the epsilon guards against a zero idea spread.
    gamma = float(np.mean(token_stds) / (np.mean(idea_stds) + 1e-6))
    print(f"Computed global γ = {gamma:.2f} (paper reports ≈3.42)")

    return gamma


def generate_text(model, tokenizer, prompt, max_length=100, alpha=0.5, gamma=3.42,
                  temperature=0.7, device='cuda'):
    """
    Generates text using the dual-stream model with variance alignment.

    Tokens are produced one at a time: the whole sequence so far is fed to the
    model on every step, the logits at the last position are sampled, and the
    sampled token is appended. Generation stops after ``max_length`` new
    tokens or as soon as the EOS token is produced. Only a single sequence is
    supported, because the EOS check reads the sampled token as a scalar.

    Args:
        model: IdeaGatedModel instance; called as
            ``model(input_ids, alpha=alpha, boost=gamma)`` and expected to
            return a pair whose first item is the final logits
        tokenizer: HuggingFace tokenizer
        prompt (str): Input text prompt
        max_length (int): Maximum number of newly generated tokens
        alpha (float): Gating intensity (paper uses fixed α=0.5)
        gamma (float): Variance scaling factor (paper reports γ≈3.42)
        temperature (float): Sampling temperature; 0 selects greedy (argmax)
            decoding instead of sampling
        device (str): Device

    Returns:
        str: Prompt plus generated continuation, decoded without special tokens

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    # Tokenize input
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    print(f"\nGenerating with α={alpha}, γ={gamma}, temperature={temperature}...")
    print(f"Prompt: {prompt}")
    print("-" * 50)

    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass with variance alignment (boost=gamma)
            final_logits, _ = model(input_ids, alpha=alpha, boost=gamma)
            next_token_logits = final_logits[:, -1, :]

            if temperature == 0:
                # Dividing by zero would yield NaN probabilities; decode
                # greedily instead.
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop at EOS
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode the full sequence (prompt + continuation)
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    """
    Command-line entry point for the inference demo.

    Parses the CLI options, loads the model and tokenizer, then generates text
    for the same prompt twice: once with the requested gating settings and once
    as a baseline with α=0 and γ=1.0. Both results are printed.
    """
    cli = argparse.ArgumentParser(description='Inference with IdeaGatedModel')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--checkpoint', dict(type=str, required=True,
                              help='Path to model checkpoint')),
        ('--base_model', dict(type=str, default='mistralai/Mistral-7B-v0.1',
                              help='Base model name (default: Mistral-7B-v0.1)')),
        ('--prompt', dict(type=str, default='Once upon a time',
                          help='Input prompt')),
        ('--max_length', dict(type=int, default=100,
                              help='Max generation length')),
        ('--alpha', dict(type=float, default=0.5,
                         help='Gating intensity (paper uses 0.5)')),
        ('--gamma', dict(type=float, default=3.42,
                         help='Variance scaling factor (paper reports 3.42)')),
        ('--compute_gamma', dict(action='store_true',
                                 help='Compute gamma from validation set instead of using default')),
        ('--temperature', dict(type=float, default=0.8,
                               help='Sampling temperature')),
        ('--device', dict(type=str, default='cuda',
                          help='Device (cuda/cpu)')),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    model, tokenizer = load_model(args.checkpoint, args.base_model, args.device)

    # --compute_gamma is currently a placeholder: it only prints a notice and
    # the value of --gamma is still used. To enable it, build a validation
    # loader and call compute_global_gamma, e.g.:
    #   val_loader = DataLoader(val_dataset, batch_size=4)
    #   gamma = compute_global_gamma(model, val_loader, args.device)
    gamma = args.gamma
    if args.compute_gamma:
        print("\nComputing global variance scaling factor from validation set...")
        print("Note: This requires a validation data loader. Implement as needed.")
        print(f"Using default γ={gamma} (set --gamma to override)")

    # Each run: (section heading, gating intensity, variance scaling factor).
    # The baseline disables gating, so no variance scaling is needed there.
    rule = "=" * 60
    runs = (
        ("IDEA-GATED GENERATION (System 1 + System 2)", args.alpha, gamma),
        ("BASELINE (System 1 only, α=0)", 0.0, 1.0),
    )
    for heading, run_alpha, run_gamma in runs:
        print("\n" + rule)
        print(heading)
        print(rule)
        text = generate_text(
            model,
            tokenizer,
            args.prompt,
            max_length=args.max_length,
            alpha=run_alpha,
            gamma=run_gamma,
            temperature=args.temperature,
            device=args.device,
        )
        print("\nGenerated Text:")
        print(text)
        print("-" * 50)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
