import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import time
import model.transformer as tf
import sys
from utilities.logger import Logger

# # Set random seeds for reproducibility
# torch.manual_seed(42)
# np.random.seed(42)

# Device configuration: run on the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

#####################################
# Data Loading and Preparation
#####################################

def load_sst2_dataset(test_split=0.1, seed=None):
    """Load GLUE SST-2 and split its validation set into validation and test sets.

    The official ``train`` split is returned unchanged. The labeled
    ``validation`` split is randomly permuted and partitioned: the first part
    becomes the new validation set, the remainder becomes the test set.

    Args:
        test_split: NOTE(review): currently *ignored*. The test set is always
            half of the original validation split. The parameter is kept so
            existing callers keep working. Honoring it would silently change
            the split sizes of existing experiments, so confirm the intended
            fraction before wiring it in.
        seed: Optional integer. When given, the permutation is drawn from a
            private ``torch.Generator`` seeded with it. The same sentences then
            land in the test set on every run, and the global RNG is not
            touched. When ``None`` (default), the global torch RNG is used.

    Returns:
        ``(train_data, val_data, test_data)``. The val/test sets are produced
        with ``.select(indices)`` on the original validation split.
    """
    print("Loading SST-2 dataset...")
    dataset = load_dataset("glue", "sst2")

    # The dataset already has train and validation splits.
    train_data = dataset['train']
    original_val_data = dataset['validation']

    # Half of the original validation split becomes the test set (see the
    # NOTE on ``test_split`` in the docstring).
    val_size = len(original_val_data)
    test_size = int(val_size * 0.5)

    if seed is None:
        permutation = torch.randperm(val_size)
    else:
        generator = torch.Generator().manual_seed(seed)
        permutation = torch.randperm(val_size, generator=generator)
    indices = permutation.tolist()

    split_point = val_size - test_size
    val_data = original_val_data.select(indices[:split_point])
    test_data = original_val_data.select(indices[split_point:])

    print(f"Train size: {len(train_data)}, Validation size: {len(val_data)}, Test size: {len(test_data)}")

    return train_data, val_data, test_data

class SST2Dataset(Dataset):
    """Map-style dataset that tokenizes SST-2 sentences on the fly.

    Args:
        data: Indexable split whose rows expose ``'sentence'`` and ``'label'``
            keys and that supports ``.select(indices)`` (e.g. a Hugging Face
            ``datasets`` split).
        tokenizer: Callable invoked per row with ``max_length``,
            ``padding='max_length'``, ``truncation=True`` and
            ``return_tensors='pt'``. It must return ``'input_ids'`` and
            ``'attention_mask'`` tensors with a leading batch dimension of 1.
        train_samples: Keep only the first ``train_samples`` rows. ``None``
            keeps every row.
        max_length: Length every sequence is padded/truncated to.
    """

    def __init__(self, data, tokenizer, train_samples=10000, max_length=128):
        # Keep a handle on the full split in case more rows are needed later.
        self.data_whole = data
        if train_samples is None:
            limit = len(data)
        else:
            limit = min(train_samples, len(data))
        self.data = data.select(range(limit))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``label`` tensors for row ``idx``."""
        # Fetch the row once instead of indexing the split once per field.
        row = self.data[idx]

        encoding = self.tokenizer(
            row['sentence'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Drop the leading batch dimension of 1 so collate_fn can stack
        # samples into [batch, max_length].
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            # Integer class index, as nn.CrossEntropyLoss expects for targets.
            'label': torch.tensor(row['label'], dtype=torch.long),
        }

def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    Every sample must carry ``'input_ids'``, ``'attention_mask'`` and
    ``'label'`` tensors whose shapes match across the batch. In the output,
    the singular ``'label'`` key is renamed to ``'labels'``.
    """
    # (key in each sample, key in the batched output), in output order.
    field_map = (
        ('input_ids', 'input_ids'),
        ('attention_mask', 'attention_mask'),
        ('label', 'labels'),
    )
    return {
        batched_key: torch.stack([sample[sample_key] for sample in batch])
        for sample_key, batched_key in field_map
    }

#####################################
# Influence Analysis
#####################################

def plot_positional_influence(average_influences, total_influence, folder_name):
    """Bar-plot the average influence per sequence position and print a summary.

    Only positions whose average influence is strictly positive are drawn.
    The figure is saved to ``<folder_name>/positional_influence.png`` and then
    closed, so repeated calls do not accumulate open figures.

    Args:
        average_influences: 1-D numpy array of per-position averages, such as
            the first value returned by ``analyze_positional_influence_gradients``.
        total_influence: Scalar shown in the plot title and the printed summary.
        folder_name: Existing directory the PNG is written into.
    """
    average_influences = np.asarray(average_influences)
    # Positions with non-zero influence (padding-only positions stay at 0).
    positions = np.where(average_influences > 0)[0]
    values = average_influences[positions]

    plt.figure(figsize=(14, 6))
    plt.bar(positions, values)
    plt.xlabel('Position in Sequence')
    plt.ylabel('Average Influence')
    plt.title(f'Average Token Influence by Position\nTotal Influence: {total_influence:.4f}')

    # Mark position 0 (the [CLS] token position).
    plt.axvline(x=0, color='r', linestyle='--', label='[CLS] token')

    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(folder_name, 'positional_influence.png'))
    # The figure is only saved, never shown; release it.
    plt.close()

    print(f"\nTotal influence across all positions: {total_influence:.4f}")
    if values.size == 0:
        # np.mean/np.max/np.argmax would warn or raise on an empty array.
        print("No positions with positive influence; skipping per-position statistics.")
        return

    print(f"Average influence per position: {np.mean(values):.4f}")
    peak = np.argmax(values)
    print(f"Max influence: {values[peak]:.4f} at position {positions[peak]}")

    # Up to five positions, most influential first.
    top_indices = np.argsort(values)[-5:][::-1]
    print(f"\nTop {len(top_indices)} most influential positions:")
    for rank, idx in enumerate(top_indices, start=1):
        print(f"{rank}. Position {positions[idx]}: {values[idx]:.4f}")

def analyze_positional_influence_gradients(model, tokenizer, val_data, train_samples=20, max_length=128, device='cpu'):
    """Estimate how strongly each sequence position influences the loss.

    For up to ``train_samples`` randomly chosen rows of ``val_data``, the
    cross-entropy loss against the row's label is back-propagated to the
    token embeddings. For one sentence, a position's influence is the mean
    absolute value of the loss gradient over that position's embedding
    dimensions. Each position's influence is averaged over the sentences in
    which it holds a real token (``attention_mask == 1``).

    Args:
        model: Must expose ``embedding(input_ids)`` and
            ``forward_with_embeddings(embeddings, attention_mask)``. The latter
            presumably returns class logits (they are fed to
            ``nn.CrossEntropyLoss``). The model is moved to ``device`` and put
            in eval mode. Its parameter ``.grad`` fields hold the last
            sentence's gradients afterwards.
        tokenizer: Callable returning ``'input_ids'`` / ``'attention_mask'``
            tensors padded/truncated to ``max_length``.
        val_data: Indexable rows with ``'sentence'`` and ``'label'`` keys.
        train_samples: Number of sentences to sample (at most ``len(val_data)``
            are used).
        max_length: Padding/truncation length and size of the returned array.
        device: Torch device to run on.

    Returns:
        ``(average_influences, total_influence)``. The first is a numpy array
        of shape ``(max_length,)``, holding 0 where a position never held a
        real token. The second is the sum of that array.
    """
    model = model.to(device)
    model.eval()

    criterion = nn.CrossEntropyLoss()

    # Randomly sample sentences from the validation set.
    indices = torch.randperm(len(val_data))[:train_samples].tolist()
    sentences = [val_data[i]['sentence'] for i in indices]
    labels = [torch.tensor([val_data[i]['label']]).to(device) for i in indices]
    num_sentences = len(sentences)

    # Running sum of influence per position, and how many sentences had a
    # real token there.
    position_influences = np.zeros(max_length)
    position_counts = np.zeros(max_length)

    print(f"Analyzing positional influence using gradients across {num_sentences} random sentences...")

    for i, (sentence, label) in enumerate(zip(sentences, labels)):
        print(f"Processing sentence {i+1}/{num_sentences}: '{sentence[:50]}...'")

        encoding = tokenizer(
            sentence,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Embed ourselves so the gradient w.r.t. the embedding output can be
        # read back. It is a non-leaf tensor, hence retain_grad().
        embeddings = model.embedding(input_ids)
        embeddings.retain_grad()
        outputs = model.forward_with_embeddings(embeddings, attention_mask)
        loss = criterion(outputs, label)

        model.zero_grad()
        # A fresh graph is built per sentence, so it need not be retained.
        loss.backward()

        # Mean |dLoss/dEmbedding| over the embedding dimension, per position.
        per_position = embeddings.grad[0].abs().mean(dim=-1)

        seq_len = min(max_length, per_position.shape[0])
        real_tokens = (attention_mask[0, :seq_len] == 1).cpu().numpy()
        influence = per_position[:seq_len].detach().cpu().numpy()
        position_influences[:seq_len] += np.where(real_tokens, influence, 0.0)
        position_counts[:seq_len] += real_tokens

    # Average only over positions that held at least one real token.
    valid_positions = position_counts > 0
    average_influences = np.zeros_like(position_influences)
    average_influences[valid_positions] = position_influences[valid_positions] / position_counts[valid_positions]

    total_influence = np.sum(average_influences)

    return average_influences, total_influence

#####################################
# Model Evaluation
#####################################

def evaluate_model(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    Args:
        model: Called as ``model(input_ids, attention_mask)``. It must return
            per-class scores along dimension 1, which are argmax-ed to get
            predictions. The model is left in eval mode.
        data_loader: Iterable of batches shaped like ``collate_fn`` output
            (``'input_ids'``, ``'attention_mask'``, ``'labels'``).
        device: Device the batch tensors are moved to.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    # Inference only: do not build an autograd graph for every batch.
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    return 100 * correct / total

#####################################
# Main Execution
#####################################

def _parse_args():
    """Build the command-line interface for this experiment and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="SST-2 sentiment classification with a simple transformer")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="The batch size")
    parser.add_argument("--epochs", type=int, default=20,
                        help="The number of epochs")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="The learning rate to set. ")
    parser.add_argument("--layers", type=int, default=2,
                        help="The number of layers to add.")
    parser.add_argument("--heads", type=int, default=2,
                        help="The number of heads to use")
    parser.add_argument("--n", type=int, default=50,
                        help="The max input sequence length.")
    parser.add_argument("--d", type=int, default=18,
                        help="The embedding model dimension.")
    parser.add_argument("--train_samples", type=int, default=2000,
                        help="The number of samples to use.")
    parser.add_argument("--num_seeds", type=int, default=5,
                        help="Number of seeds to run the experiment with.")
    parser.add_argument("--noise_reg", type=float, default=0.0,
                        help="Strength of noise regularization (default = 0.0)")
    parser.add_argument("--noise_reg_r", type=float, default=0.05,
                        help="Regularization parameter for noise regularization (default = 0.05)")
    parser.add_argument("--patience", type=int, default=5,
                        help="Early stopping patience (default = 5)")
    parser.add_argument("--lr_factor", type=float, default=0.5,
                        help="Learning rate reduction factor (default = 0.5)")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay for optimizer (default = 0.01)")
    parser.add_argument("--rho_list", type=float, nargs='+', default=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
                        help="List of rho values to experiment with (default = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5])")
    return parser.parse_args()


def main():
    """Train ``tf.SimpleTransformer`` on SST-2 over several seeds and report test accuracy.

    Steps:
      1. Create a timestamped ``plots/<YYYYmmdd-HHMMSS>`` folder.
      2. Redirect ``sys.stdout`` to a ``Logger`` writing
         ``script_output.log`` inside that folder.
      3. Build the train/validation/test loaders.
      4. Delegate training and evaluation to ``tf.run_multiple_seeds``.
      5. Print the per-seed test accuracies with their mean and standard
         deviation.
    """
    args = _parse_args()

    # Create a timestamped folder to save plots and logs.
    folder_name = f"plots/{time.strftime('%Y%m%d-%H%M%S')}"
    os.makedirs(folder_name, exist_ok=True)

    # Route all subsequent print statements through the Logger. stdout is
    # not restored afterwards.
    log_file = os.path.join(folder_name, "script_output.log")
    sys.stdout = Logger(log_file)

    batch_size = args.batch_size
    max_length = args.n
    train_samples = args.train_samples

    print("--START--")
    print(f"Sequence Length: {args.n}")
    print(f"Embedding Dimension: {args.d}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning Rate: {args.lr}")
    print(f"Epochs: {args.epochs}")
    print(f"Layers: {args.layers}")
    print(f"Attention heads: {args.heads}")
    print(f"Sample size: {train_samples}")
    print(f"Noise Regularization Strength: {args.noise_reg}")
    print(f"Noise Regularization Parameter: {args.noise_reg_r}")
    print(f"Weight Decay: {args.weight_decay}")
    print(f"Early Stopping Patience: {args.patience}")
    print(f"Learning Rate Reduction Factor: {args.lr_factor}")
    print(f"Rho values: {args.rho_list}")
    print("---------")

    # Rows of the form (sentence, binary label).
    train_data, val_data, test_data = load_sst2_dataset()

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Validation and test pass 10000 as the row cap, intended to exceed the
    # split size so every row is kept.
    train_dataset = SST2Dataset(train_data, tokenizer, train_samples, max_length)
    val_dataset = SST2Dataset(val_data, tokenizer, 10000, max_length)
    test_dataset = SST2Dataset(test_data, tokenizer, 10000, max_length)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    # All loaders share the configured batch size; only training is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    model_args = {
        'vocab_size': tokenizer.vocab_size,
        'd_model': args.d,
        'n_layers': args.layers,
        'n_heads': args.heads,
    }

    train_kwargs = {
        'lr': args.lr,
        'device': device,
        'weight_decay': args.weight_decay,
        'patience': args.patience,
        'lr_factor': args.lr_factor,
        'rho': args.rho_list,
        'input_length': args.n
    }

    # Seeds come from the unseeded global NumPy RNG; log them so a run can
    # be reproduced later.
    seeds = np.random.randint(0, 10000, size=args.num_seeds).tolist()
    print(f"Seeds: {seeds}")

    results = tf.run_multiple_seeds(
        model_class=tf.SimpleTransformer,
        model_args=model_args,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=args.epochs,
        folder_name=folder_name,
        vocab_size=tokenizer.vocab_size,
        seeds=seeds,
        noise_reg_strength=args.noise_reg,
        noise_reg_r=args.noise_reg_r,
        learn_function_stabilities=None,
        **train_kwargs
    )

    test_accuracies = results['test_accuracies']
    print(f"Test Accuracies: {test_accuracies}")
    print(f"Mean: {np.mean(test_accuracies):.2f}%, Std: {np.std(test_accuracies):.2f}%")

    # Positional influence analysis (gradients), currently disabled.
    # NOTE(review): `model` is not defined in main(); re-enabling this needs a
    # trained model from tf.run_multiple_seeds (e.g. returned in `results`).
    # print("\n===== Positional Influence Analysis (Gradients) =====")
    # average_influences_grad, total_influence_grad = analyze_positional_influence_gradients(
    #     model, tokenizer, val_data, train_samples=50, max_length=max_length, device=device
    # )
    # plot_positional_influence(average_influences_grad, total_influence_grad, folder_name)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()