import torch
import numpy as np
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

def gpt2_generate_hooks_and_model(n, device, head_outputs, sampling=False, num_samples=10):
    """
    Load pretrained GPT-2, register per-layer attention hooks, and draw input samples.

    Args:
        n: Length (in tokens) of every sample.
        device: The device to run the model on.
        head_outputs: A dict filled in by the hooks on every forward pass,
            mapping layer index -> tensor of shape (batch, n_heads, seq_len, head_dim).
        sampling: How to draw samples. Must be "model" (continuations of a fixed
            prompt sampled from GPT-2 itself) or "uniform" (token ids drawn
            uniformly from the model vocabulary). Any other value, including the
            default False, raises ValueError.
        num_samples: The number of samples to draw.

    Returns:
        model: The GPT-2 model (eval mode, hooks registered).
        samples: List of `num_samples` leaf tensors of shape (1, n, d) holding
            token embeddings, each with requires_grad=True.
        attention_mask: For sampling="model", a (1, n) long tensor (1 = real
            token, 0 = padding) describing the LAST sample only (None when
            num_samples == 0). Always None for sampling="uniform".
        params: Tuple (n_heads, n_layers, vocab_size). vocab_size is read after
            the pad token is added, so it includes that extra id.

    Raises:
        ValueError: If `sampling` is not "model" or "uniform". This is checked
            before the model is loaded so invalid calls fail fast.
    """
    # Validate first: loading/downloading GPT-2 is slow and pointless on bad input.
    if sampling not in ("model", "uniform"):
        raise ValueError("Invalid sampling method. Use 'model' or 'uniform'.")

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Eager attention so the attention modules run their plain Python forward.
    config = GPT2Config.from_pretrained("gpt2", attn_implementation="eager")
    model = GPT2LMHeadModel.from_pretrained("gpt2", config=config).to(device)
    model.eval()

    # GPT-2 ships without a pad token; add one and grow the embedding table to match.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    _gpt2_register_head_hooks(model, head_outputs)

    n_heads = model.config.n_head
    n_layers = len(model.transformer.h)
    vocab_size = model.config.vocab_size
    params = (n_heads, n_layers, vocab_size)

    if sampling == "model":
        samples, attention_mask = _gpt2_sample_from_model(model, tokenizer, n, device, num_samples)
    else:
        samples = _gpt2_sample_uniform(model, n, device, num_samples, vocab_size)
        attention_mask = None

    return model, samples, attention_mask, params


def _gpt2_register_head_hooks(model, head_outputs):
    """Register a forward hook on every block's `attn` that stores its output split per head."""
    default_n_heads = model.config.n_head

    def make_hook(layer_id):
        # Factory binds layer_id per hook (avoids the late-binding closure pitfall).
        def hook(module, inputs, output):
            # The attention module may return a tuple; its first element is the output.
            if isinstance(output, tuple):
                output = output[0]
            batch, seq_len, width = output.size()
            n_heads = getattr(module, "num_heads", default_n_heads)
            head_dim = width // n_heads
            # NOTE(review): in HF GPT2Attention this output is taken after the
            # c_proj output projection, so this is a per-head split of the
            # projected output rather than the raw per-head values — confirm intent.
            head_outputs[layer_id] = output.reshape(batch, seq_len, n_heads, head_dim).permute(0, 2, 1, 3)
        return hook

    for layer_id, block in enumerate(model.transformer.h):
        block.attn.register_forward_hook(make_hook(layer_id))


def _gpt2_sample_from_model(model, tokenizer, n, device, num_samples):
    """Sample `num_samples` length-n continuations of a fixed prompt; return (embeddings, last mask)."""
    # The prompt is identical for every sample, so encode it once.
    prompt_ids = tokenizer.encode("I had a fun day in ", return_tensors="pt").to(device)
    prompt_mask = torch.ones_like(prompt_ids)

    samples = []
    attention_mask = None
    for _ in range(num_samples):
        generated_ids = model.generate(
            prompt_ids,
            attention_mask=prompt_mask,
            max_length=n,
            do_sample=True,
            top_k=50,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Keep at most n tokens (the prompt is kept; only the tail is cut).
        generated_ids = generated_ids[:, :n]
        real_len = generated_ids.size(1)

        # Generation can stop early (e.g. at EOS); right-pad with EOS up to n.
        if real_len < n:
            pad = torch.full(
                (generated_ids.size(0), n - real_len),
                tokenizer.eos_token_id,
                dtype=generated_ids.dtype,
                device=generated_ids.device,
            )
            generated_ids = torch.cat([generated_ids, pad], dim=1)

        # detach() makes x a leaf so gradients w.r.t. the input land in x.grad.
        x = model.transformer.wte(generated_ids).detach().requires_grad_()  # (1, n, d)
        samples.append(x)

        # 1 for real tokens, 0 for padding. The padding above uses EOS (not the
        # pad id), so padded positions are zeroed by position explicitly.
        attention_mask = (generated_ids != tokenizer.pad_token_id).long()
        attention_mask[:, real_len:] = 0

    return samples, attention_mask


def _gpt2_sample_uniform(model, n, device, num_samples, vocab_size):
    """Embed `num_samples` sequences of n token ids drawn uniformly from [0, vocab_size)."""
    samples = []
    for _ in range(num_samples):
        input_ids = torch.randint(0, vocab_size, (1, n), device=device)
        # detach() for parity with model sampling: leaf tensor, no grads into wte.
        samples.append(model.transformer.wte(input_ids).detach().requires_grad_())  # (1, n, d)
    return samples

def gpt2_measure_noise_stability(model, n, r, vocab_size, num_trials=100, device='cpu', position=0, batch_size=10):
    """
    Estimate the noise stability of a language model's logits at one position.

    Draws random prompts X uniformly from [0, vocab_size). It builds correlated
    prompts Y by independently resampling each token of X with probability
    1 - r, then compares the logits the model produces at `position`.
    All randomness comes from torch's RNG, so `torch.manual_seed` makes a
    run reproducible.

    Args:
        model: The model. It is called as `model(ids).logits` and expected to
            return a (batch, n, vocab) tensor. It is put into eval mode here.
        n: Input sequence length.
        r: Correlation coefficient between X and Y, in [0, 1]. Each token of Y
            is resampled with probability 1 - r; a resampled token can
            coincide with the original by chance.
        vocab_size: Token ids are drawn from [0, vocab_size).
        num_trials: Number of (X, Y) pairs to average over; must be >= 1.
        device: Device to run the model on.
        position: Sequence position whose logits are compared. The default 0
            preserves the original behavior. For a causal model such as GPT-2,
            the logits at position 0 depend only on token 0, so the remaining
            n - 1 tokens do not affect the default measurement. Pass -1 to
            compare the prediction made after the full prompt.
        batch_size: Number of pairs per forward pass (bounds memory); must be >= 1.

    Returns:
        tuple: (average_inner_product, average_squared_logits)
            - average_inner_product: mean over trials of mean_v f(X)_v * f(Y)_v
            - average_squared_logits: mean over trials of mean_v f(X)_v ** 2

    Raises:
        ValueError: If r is outside [0, 1], or num_trials / batch_size < 1.
    """
    if not 0.0 <= r <= 1.0:
        raise ValueError(f"r must be in [0, 1], got {r!r}")
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    model.eval()

    X = torch.randint(0, vocab_size, size=(num_trials, n), device=device)
    # Vectorized correlation: resample each token independently with prob. 1 - r.
    resample = torch.rand(num_trials, n, device=device) < (1 - r)
    replacement = torch.randint(0, vocab_size, size=(num_trials, n), device=device)
    Y = torch.where(resample, replacement, X)

    inner_products = []
    squared_logits = []

    with torch.no_grad():
        # Batched forward passes keep memory bounded; slicing clamps the last batch.
        for start in range(0, num_trials, batch_size):
            stop = start + batch_size
            logits_X = model(X[start:stop]).logits[:, position, :]  # (batch, vocab)
            logits_Y = model(Y[start:stop]).logits[:, position, :]  # (batch, vocab)

            inner_products.extend(torch.mean(logits_X * logits_Y, dim=1).cpu().tolist())
            squared_logits.extend(torch.mean(logits_X ** 2, dim=1).cpu().tolist())

    return np.mean(inner_products), np.mean(squared_logits)