import torch
import torch.nn.functional as F
import numpy as np

def measure_noise_stability(model, n, r, vocab_size, num_trials=100, device='cpu'):
    """
    Estimates the noise stability E[f(X) * f(Y)] / Var[f(X)] of a model.

    X is a batch of uniformly random token sequences. Y is a noisy copy of X in
    which every token is independently resampled uniformly from the vocabulary
    with probability 1 - r. f is the argmax of the model output along dim 1,
    cast to float32.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask=mask)``.
            It is switched to eval mode here.
        n: Input sequence length.
        r: Correlation coefficient between X and Y. A token is kept with
            probability r, so any r <= 0 resamples every token and any
            r >= 1 leaves Y identical to X.
        vocab_size: Token ids are drawn uniformly from [0, vocab_size).
        num_trials: Number of samples to average over.
        device: Device on which the random inputs and the mask are created.

    Returns:
        Estimated noise stability (float). Returns 1.0 when the sample
        variance of f(X) is zero (constant function) or cannot be estimated
        because fewer than two samples are available.
    """
    model.eval()

    X = torch.randint(0, vocab_size, size=(num_trials, n), device=device)

    # Build the r-correlated copy in one vectorized step instead of writing
    # num_trials * n scalars from a Python loop (each write is a separate
    # host-to-device copy on CUDA).
    resample = torch.rand(num_trials, n, device=device) < 1 - r
    fresh_tokens = torch.randint(0, vocab_size, size=(num_trials, n), device=device)
    Y = torch.where(resample, fresh_tokens, X)

    # One mask row per sample: every position is attended to. The same
    # all-ones mask is valid for both X and Y.
    attention_mask = torch.ones(num_trials, n, device=device)

    with torch.no_grad():
        out_X = model(X, attention_mask=attention_mask)
        out_Y = model(Y, attention_mask=attention_mask)
        f_X = torch.argmax(out_X, dim=1).to(torch.float32)
        f_Y = torch.argmax(out_Y, dim=1).to(torch.float32)

    # NOTE(review): the numerator is the raw second moment E[f(X) f(Y)], not a
    # covariance, and f takes the raw argmax index (0, 1, ...). This matches
    # the textbook definition only for zero-mean f such as a +/-1 valued
    # prediction -- confirm whether predictions should be mapped to {-1, 1}.

    # torch.var is undefined (NaN) for fewer than two samples; treat that
    # like the constant-function case below.
    if f_X.numel() < 2:
        return 1.0

    expected_product = torch.mean(f_X * f_Y).item()
    estimated_variance = torch.var(f_X).item()

    # Zero variance means f is constant on the sample, so report 1.
    if estimated_variance == 0:
        return 1.0

    return expected_product / estimated_variance

# TODO: Maybe try different ways of smoothing out this function.
def compute_batch_noise_stability_with_grad(model, 
                                            input_ids, 
                                            attention_mask, 
                                            r=0.05, 
                                            vocab_size=30000, 
                                            device='cpu'):
    """
    Compute a differentiable noise-stability score for one batch.

    The model is run on ``input_ids`` and on a perturbed copy in which
    attended tokens (``attention_mask == 1``) are replaced by uniformly random
    token ids. Both outputs are passed through a softmax over dim 1, and the
    score is the batch mean of the dot product of the two probability vectors.
    Neither forward pass is wrapped in ``no_grad``, so the result can be
    back-propagated.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)``.
        input_ids: Integer tensor of token ids. It is not modified.
        attention_mask: Tensor broadcastable against ``input_ids``; only
            positions equal to 1 are eligible for replacement.
        r: Correlation parameter. Each eligible token is replaced with
            probability (1 - r) / 2.
        vocab_size: Replacement ids are drawn uniformly from [0, vocab_size).
        device: Retained for backward compatibility. Random tensors are now
            created on ``input_ids.device`` so they always live on the same
            device as the inputs, whatever is passed here.

    Returns:
        Scalar tensor holding the mean consistency between the original and
        perturbed predictions.
    """
    # Original predictions, with gradient tracking.
    orig_outputs = model(input_ids, attention_mask)
    orig_probs = F.softmax(orig_outputs, dim=1)

    # NOTE(review): tokens are replaced with probability (1 - r) / 2 here,
    # whereas the other estimators in this file resample with probability
    # 1 - r. Confirm which convention is intended before changing it.
    replace = torch.rand(input_ids.shape, device=input_ids.device) > (1 + r) / 2
    replace = replace & (attention_mask == 1)

    # Draw a candidate replacement for every position and select with
    # torch.where. This avoids the host sync needed to count masked positions
    # and leaves `input_ids` untouched.
    random_ids = torch.randint(0, vocab_size, input_ids.shape,
                               device=input_ids.device, dtype=input_ids.dtype)
    perturbed_ids = torch.where(replace, random_ids, input_ids)

    # Predictions on the perturbed inputs, also with gradient tracking.
    perturbed_outputs = model(perturbed_ids, attention_mask)
    perturbed_probs = F.softmax(perturbed_outputs, dim=1)

    # Soft probabilities instead of hard argmax predictions keep the score
    # differentiable with respect to the model parameters.
    consistency = torch.sum(orig_probs * perturbed_probs, dim=1).mean()

    return consistency

def measure_noise_stability_of_function(function, 
                                        n, 
                                        r, 
                                        num_trials=1000, 
                                        relevant_coords=None, 
                                        device='cpu'):
    """
    Estimates E[f(X) * f(Y)] for a function on random bit strings.

    X is a batch of uniformly random vectors in {0, 1}^n. Y is a noisy copy of
    X in which every coordinate is independently resampled uniformly from
    {0, 1} with probability 1 - r. The outputs of ``function`` are used as
    returned (no rescaling to {-1, 1} is applied).

    Args:
        function: Callable invoked as ``function(x, relevant_coords)`` once
            per row, where ``x`` is a 1-D integer tensor of length n. It must
            return a value convertible to a float.
        n: Number of input coordinates.
        r: Correlation coefficient between X and Y. A coordinate is kept with
            probability r, so r <= 0 resamples everything and r >= 1 leaves
            Y identical to X.
        num_trials: Number of (X, Y) pairs to average over.
        relevant_coords: Passed through unchanged as the second argument of
            ``function``.
        device: Device on which the collected outputs are multiplied and
            averaged. The random inputs themselves are created on the CPU.

    Returns:
        The sample mean of f(X) * f(Y) as a Python float.
    """
    X = torch.randint(0, 2, size=(num_trials, n))

    # Vectorized resampling: pick the coordinates to redraw and fill them from
    # a fresh uniform sample, instead of writing elements from a Python loop.
    resample = torch.rand(num_trials, n) < 1 - r
    fresh_bits = torch.randint(0, 2, size=(num_trials, n))
    Y = torch.where(resample, fresh_bits, X)

    results_X = [function(x, relevant_coords) for x in X]
    results_X = torch.tensor(results_X, dtype=torch.float32).to(device)
    results_Y = [function(y, relevant_coords) for y in Y]
    results_Y = torch.tensor(results_Y, dtype=torch.float32).to(device)

    return torch.mean(results_X * results_Y).item()

def model_measure_noise_stability(model, n, r, vocab_size, num_trials=100, device='cpu'):
    """
    Measures the noise stability of a language model from its logits on
    correlated prompts.

    X is a batch of uniformly random prompts. Y is a noisy copy of X in which
    every token is independently resampled uniformly from the vocabulary with
    probability 1 - r. For each pair, the logits at sequence position 0 are
    compared.

    Args:
        model: Callable invoked as ``model(input_ids)`` whose result exposes a
            ``.logits`` tensor indexed as (batch, position, vocab). It is
            switched to eval mode here.
        n: Input sequence length.
        r: Correlation coefficient between X and Y (float in [0, 1]).
        vocab_size: Token ids are drawn uniformly from [0, vocab_size).
        num_trials: Number of prompt pairs to average over. Must be positive.
        device: Device on which the random prompts are created.

    Returns:
        tuple: (average_inner_product, average_squared_logits)
            - average_inner_product: Mean over trials of the element-wise
              product of the X and Y logits, averaged over the vocabulary
              axis (i.e. the inner product divided by the logit dimension).
            - average_squared_logits: Mean over trials of the squared X
              logits averaged over the vocabulary axis, E[mean(f(X)^2)].

    Raises:
        ValueError: If ``num_trials`` is not positive.
    """
    if num_trials < 1:
        raise ValueError("num_trials must be a positive integer")

    model.eval()

    # Random prompts X and their r-correlated copies Y, built in one
    # vectorized step rather than element by element.
    X = torch.randint(0, vocab_size, size=(num_trials, n), device=device)
    resample = torch.rand(num_trials, n, device=device) < 1 - r
    fresh_tokens = torch.randint(0, vocab_size, size=(num_trials, n), device=device)
    Y = torch.where(resample, fresh_tokens, X)

    inner_products = []
    squared_logits = []

    # Small batches keep peak memory bounded: the logits tensor is
    # (batch, n, vocab).
    batch_size = min(10, num_trials)

    with torch.no_grad():
        for start in range(0, num_trials, batch_size):
            stop = start + batch_size  # slicing clamps at num_trials

            # NOTE(review): for a causal LM such as GPT-2 the logits at
            # position 0 depend only on the first input token. Confirm that
            # position 0 (rather than the last position) is intended.
            logits_X_first = model(X[start:stop]).logits[:, 0, :]
            logits_Y_first = model(Y[start:stop]).logits[:, 0, :]

            # Per-sample mean over the vocabulary axis.
            batch_inner_products = torch.mean(logits_X_first * logits_Y_first, dim=1)
            inner_products.extend(batch_inner_products.cpu().tolist())

            batch_mean_squared_logits = torch.mean(logits_X_first ** 2, dim=1)
            squared_logits.extend(batch_mean_squared_logits.cpu().tolist())

    return np.mean(inner_products), np.mean(squared_logits)
