import os
import pandas as pd
import torch
import argparse
import json
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, GPTNeoModel, GPT2ForSequenceClassification, GPTNeoForSequenceClassification, GPT2Model, GPTNeoConfig, GPT2Config, GPTNeoXForSequenceClassification, GPTNeoXConfig, RobertaPreLayerNormForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from collections import Counter
import random
import copy
from tqdm import tqdm
from scipy.stats import ttest_rel
from torch.nn.functional import softmax

# Dataset that tokenizes each text on access and returns input_ids, attention_mask and label tensors
class TweetsDataset(Dataset):
    """Map-style dataset pairing raw texts with integer labels.

    Each item is tokenized lazily in ``__getitem__`` with
    ``padding='max_length'`` and ``truncation=True``.

    Args:
        texts: Indexable collection of input strings.
        labels: Indexable collection of labels, aligned with ``texts``.
        tokenizer: Callable tokenizer; it is invoked with ``return_tensors='pt'``
            and must return a mapping with ``'input_ids'`` and
            ``'attention_mask'``. (Presumably a Hugging Face tokenizer with a
            pad token already configured -- verify against the caller.)
        max_length: Length every sequence is padded/truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``label`` for sample ``idx``."""
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading batch dimension added by return_tensors='pt'.
        # A bare squeeze() would also collapse the sequence dimension when
        # max_length == 1, producing 0-d tensors that cannot be batched.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long),
        }



# Classification wrapper around a pretrained backbone, with optional removal of LayerNorm parameters
class CustomClassificationModel(nn.Module):
    """Sequence classifier built on a pretrained backbone.

    ``gpt2-medium`` and ``andreasmadsen/efficient_mlm_m0.40`` are loaded as
    ``*ForSequenceClassification`` models and their logits are returned
    directly. Any other ``model_name`` is loaded with ``AutoModel`` and a
    bias-free linear classifier is applied to the backbone's ``pooler_output``.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``.
        num_labels: Number of output classes.
        remove: When ``'layer_norm'``, the weight and bias of the matching
            normalisation modules are set to ``None`` (``ln_1``/``ln_2``/``ln_f``
            for ``gpt2-medium``; every module whose name contains ``LayerNorm``
            for ``roberta-base`` and ``andreasmadsen/efficient_mlm_m0.40``).
            Any other value, or any other model, leaves the backbone untouched.
    """

    # Checkpoints that are loaded together with their own classification head.
    _MODELS_WITH_OWN_HEAD = ("gpt2-medium", "andreasmadsen/efficient_mlm_m0.40")

    def __init__(self, model_name, num_labels, remove = 'none'):
        super(CustomClassificationModel, self).__init__()
        self.model_name = model_name

        if model_name == "gpt2-medium":
            model_config = GPT2Config.from_pretrained(self.model_name, num_labels=num_labels)
            self.backbone = GPT2ForSequenceClassification.from_pretrained(self.model_name, config = model_config)
            # Swap the built-in 'score' head for a fresh bias-free linear layer
            in_features = self.backbone.score.in_features
            self.backbone.score = nn.Linear(in_features, num_labels, bias = False)

            # fix model padding token id
            self.backbone.config.pad_token_id = self.backbone.config.eos_token_id

        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            self.backbone = RobertaPreLayerNormForSequenceClassification.from_pretrained(self.model_name, num_labels = num_labels)

        else:
            self.backbone = AutoModel.from_pretrained(self.model_name)

        if remove == 'layer_norm':
            # Name fragments identifying the normalisation modules to strip
            if model_name == "gpt2-medium":
                markers = ('ln_1', 'ln_2', 'ln_f')
            elif model_name in ("roberta-base", "andreasmadsen/efficient_mlm_m0.40"):
                markers = ('LayerNorm',)
            else:
                markers = ()

            for name, module in self.backbone.named_modules():
                if any(marker in name for marker in markers):
                    module.weight = None
                    module.bias = None

        if model_name not in self._MODELS_WITH_OWN_HEAD:
            self.classifier = nn.Linear(self.backbone.config.hidden_size, num_labels, bias = False)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits for a batch."""
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        if self.model_name in self._MODELS_WITH_OWN_HEAD:
            return outputs.logits

        # Every AutoModel backbone goes through the external classifier.
        # Previously only "roberta-base" was handled here and any other
        # checkpoint silently returned None.
        return self.classifier(outputs.pooler_output)
        
# Classification wrapper that strips LayerNorm parameters from selected transformer layers only
class CustomClassificationModel_layer_analysis(nn.Module):
    """Sequence classifier whose LayerNorm parameters are removed per layer.

    The backbone is loaded as in ``CustomClassificationModel``:
    ``gpt2-medium`` and ``andreasmadsen/efficient_mlm_m0.40`` come with their
    own classification head; any other checkpoint is loaded with ``AutoModel``
    and a bias-free linear classifier on ``pooler_output``.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``.
        num_labels: Number of output classes.
        tokenizer: Unused; kept for call-site compatibility.
        remove_layers: Iterable of layer indices whose normalisation modules
            get ``weight`` and ``bias`` set to ``None``. ``None`` (the default)
            removes nothing.
    """

    def __init__(self, model_name, num_labels, tokenizer = None, remove_layers = None):
        super(CustomClassificationModel_layer_analysis, self).__init__()
        self.model_name = model_name
        # The default None used to crash on the membership test below
        layers_to_strip = set(remove_layers) if remove_layers is not None else set()

        if model_name == "gpt2-medium":
            model_config = GPT2Config.from_pretrained(self.model_name, num_labels=num_labels)
            self.backbone = GPT2ForSequenceClassification.from_pretrained(self.model_name, config = model_config)
            # Swap the built-in 'score' head for a fresh bias-free linear layer
            in_features = self.backbone.score.in_features
            self.backbone.score = nn.Linear(in_features, num_labels, bias = False)

            # fix model padding token id
            self.backbone.config.pad_token_id = self.backbone.config.eos_token_id

        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            self.backbone = RobertaPreLayerNormForSequenceClassification.from_pretrained(self.model_name, num_labels = num_labels)

        else:
            self.backbone = AutoModel.from_pretrained(model_name)

        # Per-architecture rule: which module names are stripped, and which
        # separator precedes the layer index inside the module name.
        if model_name == "gpt2-medium":
            markers, index_separator = ('ln_1', 'ln_2'), ".h."
        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            markers, index_separator = ('intermediate.LayerNorm', 'attention.LayerNorm'), ".layer."
        else:
            markers, index_separator = ('output.LayerNorm',), ".layer."

        for name, module in self.backbone.named_modules():
            if not any(marker in name for marker in markers):
                continue
            layer_index = int(name.split(index_separator)[1].split(".")[0])
            if layer_index in layers_to_strip:
                module.weight = None
                module.bias = None

        if self.model_name != "gpt2-medium" and model_name != "andreasmadsen/efficient_mlm_m0.40":
            self.classifier = nn.Linear(self.backbone.config.hidden_size, num_labels, bias = False)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits for a batch."""
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        if self.model_name == "gpt2-medium" or self.model_name == "andreasmadsen/efficient_mlm_m0.40":
            return outputs.logits

        return self.classifier(outputs.pooler_output)


# Metrics function
def compute_metrics(predictions, labels):
    """Return ``(accuracy, precision, recall, f1)`` for the given predictions.

    Precision, recall and F1 are macro-averaged over classes.
    """
    acc = accuracy_score(labels, predictions)
    macro_scorers = (precision_score, recall_score, f1_score)
    precision, recall, f1 = [
        scorer(labels, predictions, average='macro') for scorer in macro_scorers
    ]
    return acc, precision, recall, f1

# Training function
def train_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is expected to be a dict with ``'input_ids'``,
    ``'attention_mask'`` and ``'label'`` tensors.

    Returns:
        ``(avg_loss, acc, precision, recall, f1, optimizer)`` where the loss is
        averaged over batches and the metrics come from ``compute_metrics``
        on the predictions gathered during the epoch.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    epoch_preds = []
    epoch_targets = []

    for batch in train_loader:
        optimizer.zero_grad()

        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        logits = model(ids, attention_mask=mask)
        batch_loss = loss_fn(logits, targets)
        running_loss += batch_loss.item()

        # Collect predictions/targets on CPU for the epoch-level metrics
        epoch_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        epoch_targets.extend(targets.cpu().numpy())

        batch_loss.backward()
        optimizer.step()

    mean_loss = running_loss / len(train_loader)
    acc, precision, recall, f1 = compute_metrics(epoch_preds, epoch_targets)
    torch.cuda.empty_cache()
    return mean_loss, acc, precision, recall, f1, optimizer

# Evaluation function
def evaluate_model(model, val_loader, device, lm = False):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Callable as ``model(input_ids, attention_mask=...)`` returning logits.
        val_loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.
        lm: When true, the collected predictions and labels are printed.

    Returns:
        ``(avg_loss, acc, precision, recall, f1)`` with the loss averaged over batches.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    collected_preds = []
    collected_targets = []

    with torch.no_grad():
        for batch in val_loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            logits = model(ids, attention_mask=mask)
            running_loss += loss_fn(logits, targets).item()

            # Move results to CPU for metric computation
            collected_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            collected_targets.extend(targets.cpu().numpy())

    if lm:
        print("LM Predictions: ", collected_preds)
        print("LM Labels: ", collected_targets)

    mean_loss = running_loss / len(val_loader)
    acc, precision, recall, f1 = compute_metrics(collected_preds, collected_targets)
    torch.cuda.empty_cache()
    return mean_loss, acc, precision, recall, f1

def test_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and count misclassifications.

    Args:
        model: Callable as ``model(input_ids, attention_mask=...)`` returning logits.
        test_loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, acc, precision, recall, f1, misclassifications)``.
        ``avg_loss`` is the cross-entropy loss averaged over batches.
        ``misclassifications`` maps a class index to the number of wrong
        predictions that were assigned to that (predicted) class; keys 0-5 are
        always present and further keys are added if the model predicts them.
    """
    # Switch to inference mode, matching evaluate_model
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    predictions, true_labels = [], []

    # Keys 0-5 are kept for backward compatibility with existing consumers
    misclassifications = {i: 0 for i in range(6)}

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            logits = model(input_ids, attention_mask=attention_mask)
            # Accumulate the loss so the returned average is meaningful
            # (it used to stay at 0 because nothing was ever added).
            total_loss += criterion(logits, labels).item()

            # Get predictions and move data to CPU for metrics calculation
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Track misclassifications, keyed by the predicted class
    for predicted, actual in zip(predictions, true_labels):
        if predicted != actual:
            key = int(predicted)
            # .get() tolerates class indices beyond the six pre-seeded keys
            misclassifications[key] = misclassifications.get(key, 0) + 1

    avg_loss = total_loss / len(test_loader)
    acc, precision, recall, f1 = compute_metrics(predictions, true_labels)
    torch.cuda.empty_cache()

    return avg_loss, acc, precision, recall, f1, misclassifications


# Print and return the frequency of every label in a split
def count_labels(labels, dataset_name):
    """Print how often each label occurs and return the counts.

    Args:
        labels: Iterable of hashable labels.
        dataset_name: Name shown in the printed header.

    Returns:
        A ``Counter`` mapping label to number of occurrences.
    """
    tally = Counter(labels)
    print(f"Label counts for {dataset_name}:")
    for cls in tally:
        print(f"  Label {cls}: {tally[cls]}")
    return tally


def add_random_label(original_label, idx, seed, num_labels = 6):
    """Pick a random label in ``range(num_labels)`` that differs from ``original_label``.

    Args:
        original_label: Label to avoid; must be one of ``0 .. num_labels - 1``
            (a ``ValueError`` is raised otherwise).
        idx: Sample index, added to ``seed`` so every sample gets its own draw.
        seed: Base seed. When not ``None`` the global ``random`` module is
            re-seeded with ``seed + idx`` before drawing.
        num_labels: Number of classes.

    Returns:
        The randomly chosen replacement label.
    """
    if seed is not None:
        # Per-sample seed keeps the draw reproducible
        random.seed(seed + idx)

    # All labels 0 .. num_labels - 1, minus the one we must not return
    candidates = [*range(num_labels)]
    del candidates[candidates.index(original_label)]

    return random.choice(candidates)



def add_lm_to_texts(texts, labels, class_label, n, seed=28, num_labels = 6):
    """Inject label noise into up to ``n`` samples of one class.

    Samples whose label equals ``class_label`` are drawn at random (global
    ``random`` seeded with ``seed``) and each gets a different label from
    ``add_random_label``. The inputs are deep-copied, not modified.

    Args:
        texts: Indexable collection of texts.
        labels: Indexable collection of labels aligned with ``texts``.
        class_label: Only samples carrying this label are candidates.
        n: Maximum number of samples to relabel.
        seed: Seed for the sample selection and the per-sample label draw.
        num_labels: Number of classes.

    Returns:
        ``(texts_copy, noisy_labels, lm_texts, lm_labels)`` -- the copied
        texts, the copied labels with noise applied, and the texts / noisy
        labels of just the modified samples.
    """
    random.seed(seed)

    # Positions of every sample that belongs to the requested class
    candidate_positions = [pos for pos, lab in enumerate(labels) if lab == class_label]
    chosen_positions = random.sample(candidate_positions, min(n, len(candidate_positions)))

    # Work on copies so the caller's data stays untouched
    noisy_texts = copy.deepcopy(texts)
    noisy_labels = copy.deepcopy(labels)

    true_labels_of_modified = []
    lm_texts = []
    lm_labels = []

    for pos in chosen_positions:
        true_labels_of_modified.append(noisy_labels[pos])
        noisy_labels[pos] = add_random_label(noisy_labels[pos], pos, seed, num_labels = num_labels)  # noisy label
        lm_texts.append(noisy_texts[pos])
        lm_labels.append(noisy_labels[pos])

    print("Actual labels: ", true_labels_of_modified)
    return noisy_texts, noisy_labels, lm_texts, lm_labels


def plot_metrics(epochs_list, train_list, val_list, test_list, lm_list, metric_type, save_path):
    """
    Draw the train / validation / test / LM curves of one metric and save the figure.

    Args:
        epochs_list (list): Epoch numbers used for the x axis.
        train_list (list): Training metric value per epoch.
        val_list (list): Validation metric value per epoch.
        test_list (list): Test metric value per epoch.
        lm_list (list): Label memorization metric value per epoch.
        metric_type (str): Metric name used in labels and title (e.g. "Accuracy" or "Loss").
        save_path (str): Destination file path of the figure (including file name and extension).
    """
    # One entry per curve: legend prefix, values, colour, marker
    curves = (
        ("Train ", train_list, "red", 'o'),
        ("Validation ", val_list, "blue", 's'),
        ("Test ", test_list, "yellow", 'x'),
        ("LM ", lm_list, "green", '^'),
    )

    plt.figure(figsize=(10, 6))
    for prefix, values, colour, marker in curves:
        plt.plot(epochs_list, values, label=prefix + metric_type, color=colour, marker=marker)

    # Axis labels, title, legend and grid
    plt.xlabel("Epochs", fontsize=12)
    plt.ylabel(metric_type, fontsize=12)
    plt.title(f"{metric_type} Trends over Epochs", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle="--", alpha=0.6)

    # Write the figure to disk and release it
    plt.savefig(save_path)
    plt.close()
    print(f"{metric_type} plot saved at {save_path}")



def save_metrics_to_csv(epochs_list, train_list, val_list, test_list, lm_list, csv_path):
    """
    Write the per-epoch metric series to a CSV file, one column per series.

    Args:
        epochs_list (list): Epoch numbers.
        train_list (list): Training metric values.
        val_list (list): Validation metric values.
        test_list (list): Test metric values.
        lm_list (list): Label memorization metric values.
        csv_path (str): Destination file path (including file name and extension).
    """
    headers = ("Epoch", "Train", "Validation", "Test", "Label Memorization")
    series = (epochs_list, train_list, val_list, test_list, lm_list)

    # Column order follows `headers`; the row index is not written
    table = pd.DataFrame(dict(zip(headers, series)))
    table.to_csv(csv_path, index=False)
    print(f"Metrics saved to CSV at {csv_path}")


def save_combined_metrics_to_json(data, file_name):
    """
    Dump a metrics dictionary to a JSON file with 4-space indentation.

    Args:
        data (dict): Dictionary containing all metrics (accuracy or loss).
        file_name (str): Path of the JSON file to write.
    """
    with open(file_name, 'w') as json_file:
        json_file.write(json.dumps(data, indent=4))
    print(f"Combined metrics saved to {file_name}")



def sample_50_percent(train_texts, train_labels, seed = 28):
    """Draw a class-stratified random half of the training data.

    For every distinct label, ``len(class) // 2`` samples are drawn without
    replacement. NumPy's global RNG is seeded with ``seed`` first.

    Args:
        train_texts: Texts as a NumPy array or any sequence (list, tuple, ...).
        train_labels: Labels aligned with ``train_texts``; array or sequence.
        seed: Seed for ``np.random.seed``.

    Returns:
        ``(sampled_texts, sampled_labels)`` as NumPy arrays, grouped by label
        in the sorted order produced by ``np.unique``.
    """
    np.random.seed(seed)

    # Accept plain Python sequences too: fancy indexing with an index array
    # only works on ndarrays, so lists used to fail with a TypeError.
    texts_arr = np.asarray(train_texts)
    labels_arr = np.asarray(train_labels)

    sampled_texts = []
    sampled_labels = []

    for label in np.unique(labels_arr):
        # Positions of all samples with the current label
        label_indices = np.where(labels_arr == label)[0]

        # Randomly keep half of them (rounded down)
        sample_size = len(label_indices) // 2
        sampled_indices = np.random.choice(label_indices, size=sample_size, replace=False)

        sampled_texts.extend(texts_arr[sampled_indices])
        sampled_labels.extend(labels_arr[sampled_indices])

    return np.array(sampled_texts), np.array(sampled_labels)


def calculate_gradients(model, model_name, loader, device, test=False):
    """Measure per-layer weight-gradient norms over a data loader.

    For every batch the cross-entropy loss is back-propagated and the weight
    gradients of three module groups are summed per layer:

    * attention LayerNorm (``attention.output.LayerNorm`` / ``ln_1``)
    * output LayerNorm (non-attention ``output.LayerNorm`` / ``ln_2``)
    * first feed-forward projection (``intermediate.dense`` / ``mlp.c_fc``)

    Only ``"roberta-base"`` and ``"gpt2-medium"`` are recognised; for any
    other ``model_name`` the result dicts stay empty. The summed gradients
    are divided by the number of samples seen and reduced to a Frobenius norm.

    NOTE(review): the loss uses mean reduction per batch while the sum is
    divided by the sample count, so values scale with batch size -- confirm
    this is the intended normalisation.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)``.
        model_name: Architecture key selecting the module-name patterns.
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.
        test: Unused; kept for call-site compatibility.

    Returns:
        ``(attn_layernorm_gradients, output_layernorm_gradients, None)`` where
        each dict maps layer index to a float norm. The feed-forward norms are
        printed but the third slot stays ``None`` as before.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    attn_layernorm_gradients = dict()
    output_layernorm_gradients = dict()
    ffn_gradients = dict()

    def _target_for(name):
        """Map a module name to (storage dict, layer index), or None if untracked."""
        if model_name == "roberta-base":
            separator = ".layer."
            if 'attention.output.LayerNorm' in name:
                storage = attn_layernorm_gradients
            elif 'attention' not in name and 'output.LayerNorm' in name:
                storage = output_layernorm_gradients
            elif 'intermediate.dense' in name:
                storage = ffn_gradients
            else:
                return None
        elif model_name == "gpt2-medium":
            separator = ".h."
            if 'ln_1' in name:
                storage = attn_layernorm_gradients
            elif 'ln_2' in name:
                storage = output_layernorm_gradients
            elif 'mlp.c_fc' in name:
                storage = ffn_gradients
            else:
                return None
        else:
            return None
        return storage, int(name.split(separator)[1].split(".")[0])

    # Module names do not change between batches, so resolve them once
    tracked = []
    for name, module in model.named_modules():
        target = _target_for(name)
        if target is not None:
            tracked.append((module, target[0], target[1]))

    sample_count = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        sample_count += input_ids.shape[0]

        # Clear old gradients first: backward() accumulates into .grad, so
        # without this every earlier batch (and any stale gradient from
        # training) would be counted again on each later batch.
        model.zero_grad()

        # Forward and backward pass
        logits = model(input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()

        for module, storage, layer_index in tracked:
            weight = getattr(module, 'weight', None)
            # LayerNorms whose parameters were removed (weight set to None)
            # have no gradient to record
            if weight is None or weight.grad is None:
                continue
            grad = weight.grad.detach()
            if layer_index in storage:
                storage[layer_index] += grad
            else:
                # Clone so the running sum never aliases the live .grad tensor
                storage[layer_index] = grad.clone()

    # Do not leave analysis gradients on the model
    model.zero_grad()

    # Average over samples, then reduce each layer to its Frobenius norm
    for storage in (attn_layernorm_gradients, output_layernorm_gradients, ffn_gradients):
        for layer_index in storage:
            storage[layer_index] = torch.norm(storage[layer_index] / sample_count, p='fro').item()

    print("Attention LayerNorm grads: ", attn_layernorm_gradients)
    print("Output LayerNorm grads: ", output_layernorm_gradients)
    print("FFN grads: ", ffn_gradients)
    print()
    print()
    return attn_layernorm_gradients, output_layernorm_gradients, None


def calculate_ln_derivatives(model, model_name, loader, device):
    """Measure the loss gradient w.r.t. the *inputs* of the LayerNorm modules.

    Full backward hooks are attached to the attention and output LayerNorms
    of every layer. For each backward pass the hook takes ``grad_input[0]``,
    its absolute value, and the mean over dimension 1 (assumed to be the token
    axis -- TODO confirm). These per-batch tensors are summed, divided by the
    number of samples seen and reduced to a Frobenius norm per layer.

    Supported ``model_name`` values: ``"roberta-base"``, ``"gpt2-medium"`` and
    ``"andreasmadsen/efficient_mlm_m0.40"``; other names give empty dicts.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)``.
        model_name: Architecture key selecting the module-name patterns.
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(attn_layernorm_gradients, output_layernorm_gradients)``, each
        mapping layer index to a float.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    attn_layernorm_gradients = dict()
    output_layernorm_gradients = dict()
    sample_count = 0

    def accumulate(storage, layer_index, grad):
        """Add one batch's token-averaged absolute gradient to the running sum."""
        if grad is None:  # No gradient flowed to this input
            return
        grad_avg = grad.detach().abs().mean(dim=1)

        running = storage.get(layer_index)
        if running is None:
            storage[layer_index] = grad_avg.clone()
            return

        # Batches may differ in size (e.g. a smaller final batch). A plain
        # in-place add would raise on a shape mismatch, or silently broadcast
        # a single-sample batch over every row, so add row-aligned instead.
        rows = grad_avg.shape[0]
        if rows > running.shape[0]:
            grown = torch.zeros_like(grad_avg)
            grown[:running.shape[0]] = running
            storage[layer_index] = running = grown
        running[:rows] += grad_avg

    def select_storage(name):
        """Map a module name to (storage dict, layer index), or None if not hooked."""
        if model_name == "roberta-base":
            separator = ".layer."
            if 'attention.output.LayerNorm' in name:
                storage = attn_layernorm_gradients
            elif 'attention' not in name and 'output.LayerNorm' in name:
                storage = output_layernorm_gradients
            else:
                return None
        elif model_name == "gpt2-medium":
            separator = ".h."
            if 'ln_1' in name:
                storage = attn_layernorm_gradients
            elif 'ln_2' in name:
                storage = output_layernorm_gradients
            else:
                return None
        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            separator = ".layer."
            if 'attention.LayerNorm' in name:
                storage = attn_layernorm_gradients
            elif 'intermediate.LayerNorm' in name:
                storage = output_layernorm_gradients
            else:
                return None
        else:
            return None
        return storage, int(name.split(separator)[1].split(".")[0])

    hooks = []
    try:
        for name, module in model.named_modules():
            target = select_storage(name)
            if target is None:
                continue
            storage, layer_index = target
            # Bind storage/index as defaults so each hook keeps its own values
            hooks.append(module.register_full_backward_hook(
                lambda mod, gin, gout, store=storage, idx=layer_index: accumulate(store, idx, gin[0])
            ))

        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            sample_count += input_ids.shape[0]

            logits = model(input_ids, attention_mask=attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
    finally:
        # Always detach the hooks, even if a batch fails, so they cannot
        # keep firing during later training or evaluation.
        for hook in hooks:
            hook.remove()

    # Normalize by total samples and compute Frobenius norm
    for storage in (attn_layernorm_gradients, output_layernorm_gradients):
        for layer_index in storage:
            storage[layer_index] = torch.norm(storage[layer_index] / sample_count, p='fro').item()

    print("Attention LayerNorm gradients: ", attn_layernorm_gradients)
    print("Output LayerNorm gradients: ", output_layernorm_gradients)
    print()

    return attn_layernorm_gradients, output_layernorm_gradients


def calculate_ln_derivatives_output(model, model_name, loader, device):
    """Measure the loss gradient w.r.t. the *outputs* of the LayerNorm modules.

    Full backward hooks are attached to the attention and output LayerNorms
    of every layer. For each backward pass the hook takes ``grad_output[0]``,
    its absolute value, and the mean over dimension 1 (assumed to be the token
    axis -- TODO confirm). These per-batch tensors are summed, divided by the
    number of samples seen and reduced to a Frobenius norm per layer.

    Supported ``model_name`` values: ``"roberta-base"``, ``"gpt2-medium"`` and
    ``"andreasmadsen/efficient_mlm_m0.40"`` (the last one uses the same module
    patterns as ``calculate_ln_derivatives``); other names give empty dicts.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)``.
        model_name: Architecture key selecting the module-name patterns.
        loader: Iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(attn_layernorm_gradients, output_layernorm_gradients)``, each
        mapping layer index to a float.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    attn_layernorm_gradients = dict()
    output_layernorm_gradients = dict()
    sample_count = 0

    def accumulate(storage, layer_index, grad):
        """Add one batch's token-averaged absolute gradient to the running sum."""
        if grad is None:  # No gradient reached this output
            return
        grad_avg = grad.detach().abs().mean(dim=1)

        running = storage.get(layer_index)
        if running is None:
            storage[layer_index] = grad_avg.clone()
            return

        # Batches may differ in size (e.g. a smaller final batch). A plain
        # in-place add would raise on a shape mismatch, or silently broadcast
        # a single-sample batch over every row, so add row-aligned instead.
        rows = grad_avg.shape[0]
        if rows > running.shape[0]:
            grown = torch.zeros_like(grad_avg)
            grown[:running.shape[0]] = running
            storage[layer_index] = running = grown
        running[:rows] += grad_avg

    def select_storage(name):
        """Map a module name to (storage dict, layer index), or None if not hooked."""
        if model_name == "roberta-base":
            separator = ".layer."
            if 'attention.output.LayerNorm' in name:
                storage = attn_layernorm_gradients
            elif 'attention' not in name and 'output.LayerNorm' in name:
                storage = output_layernorm_gradients
            else:
                return None
        elif model_name == "gpt2-medium":
            separator = ".h."
            if 'ln_1' in name:
                storage = attn_layernorm_gradients
            elif 'ln_2' in name:
                storage = output_layernorm_gradients
            else:
                return None
        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            separator = ".layer."
            if 'attention.LayerNorm' in name:
                storage = attn_layernorm_gradients
            elif 'intermediate.LayerNorm' in name:
                storage = output_layernorm_gradients
            else:
                return None
        else:
            return None
        return storage, int(name.split(separator)[1].split(".")[0])

    hooks = []
    try:
        for name, module in model.named_modules():
            target = select_storage(name)
            if target is None:
                continue
            storage, layer_index = target
            # Bind storage/index as defaults so each hook keeps its own values
            hooks.append(module.register_full_backward_hook(
                lambda mod, gin, gout, store=storage, idx=layer_index: accumulate(store, idx, gout[0])
            ))

        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            sample_count += input_ids.shape[0]

            logits = model(input_ids, attention_mask=attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
    finally:
        # Always detach the hooks, even if a batch fails, so they cannot
        # keep firing during later training or evaluation.
        for hook in hooks:
            hook.remove()

    # Normalize by total samples and compute Frobenius norm
    for storage in (attn_layernorm_gradients, output_layernorm_gradients):
        for layer_index in storage:
            storage[layer_index] = torch.norm(storage[layer_index] / sample_count, p='fro').item()

    print("Attention LayerNorm gradients: ", attn_layernorm_gradients)
    print("Output LayerNorm gradients: ", output_layernorm_gradients)
    print()

    return attn_layernorm_gradients, output_layernorm_gradients

def capture_ln_inputs_l2_norm_sigma(model, model_name, loader, device):
    """Record L2 norm and standard deviation of the LayerNorm inputs.

    Forward hooks on the attention and output LayerNorms of every layer take
    the first positional input, average it over dimension 1 (assumed to be the
    token axis -- TODO confirm) and record the Frobenius norm and the standard
    deviation of that averaged tensor. One scalar pair is added per batch and
    the sums are finally divided by the number of samples seen.

    NOTE(review): the sums hold one value per batch but are divided by the
    sample count, so results depend on batch size -- confirm this is intended.

    Supported ``model_name`` values: ``"roberta-base"``, ``"gpt2-medium"`` and
    ``"andreasmadsen/efficient_mlm_m0.40"`` (the last one uses the same module
    patterns as ``calculate_ln_derivatives``); other names give empty dicts.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)``.
        model_name: Architecture key selecting the module-name patterns.
        loader: Iterable of dict batches with ``'input_ids'`` and ``'attention_mask'``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(attn_ln_inputs, attn_ln_std, output_ln_inputs, output_ln_std)``,
        each mapping layer index to a float.
    """
    model.eval()
    attn_ln_inputs = dict()
    output_ln_inputs = dict()
    attn_ln_std = dict()
    output_ln_std = dict()
    sample_count = 0

    def record(norm_storage, std_storage, layer_index, hidden):
        """Add one batch's input statistics to the running sums."""
        if hidden is None:  # Nothing to record
            return
        token_mean = hidden.detach().cpu().mean(dim=1)
        norm_storage[layer_index] = norm_storage.get(layer_index, 0) + torch.norm(token_mean, p='fro').item()
        std_storage[layer_index] = std_storage.get(layer_index, 0) + token_mean.std().item()

    def select_storage(name):
        """Map a module name to (norm dict, std dict, layer index), or None if not hooked."""
        if model_name == "roberta-base":
            separator = ".layer."
            if 'attention.output.LayerNorm' in name:
                pair = (attn_ln_inputs, attn_ln_std)
            elif 'attention' not in name and 'output.LayerNorm' in name:
                pair = (output_ln_inputs, output_ln_std)
            else:
                return None
        elif model_name == "gpt2-medium":
            separator = ".h."
            if 'ln_1' in name:
                pair = (attn_ln_inputs, attn_ln_std)
            elif 'ln_2' in name:
                pair = (output_ln_inputs, output_ln_std)
            else:
                return None
        elif model_name == "andreasmadsen/efficient_mlm_m0.40":
            separator = ".layer."
            if 'attention.LayerNorm' in name:
                pair = (attn_ln_inputs, attn_ln_std)
            elif 'intermediate.LayerNorm' in name:
                pair = (output_ln_inputs, output_ln_std)
            else:
                return None
        else:
            return None
        return pair[0], pair[1], int(name.split(separator)[1].split(".")[0])

    hooks = []
    try:
        for name, module in model.named_modules():
            target = select_storage(name)
            if target is None:
                continue
            norm_storage, std_storage, layer_index = target
            # Bind the storages/index as defaults so each hook keeps its own values
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, norms=norm_storage, stds=std_storage, idx=layer_index:
                    record(norms, stds, idx, inp[0] if inp else None)
            ))

        # Loop over batches; only the forward pass is needed
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            sample_count += input_ids.shape[0]

            with torch.no_grad():
                model(input_ids, attention_mask=attention_mask)
    finally:
        # Always detach the hooks, even if a batch fails, so they cannot
        # keep firing during later training or evaluation.
        for hook in hooks:
            hook.remove()

    # Normalize the accumulated sums by the total number of samples
    for storage in (attn_ln_inputs, attn_ln_std, output_ln_inputs, output_ln_std):
        for layer_index in storage:
            storage[layer_index] /= sample_count

    print("Attention LayerNorm Inputs L2-norm: ", attn_ln_inputs)
    print("Attention LayerNorm Inputs Std-dev: ", attn_ln_std)
    print("Output LayerNorm Inputs L2-norm: ", output_ln_inputs)
    print("Output LayerNorm Inputs Std-dev: ", output_ln_std)
    print()

    return attn_ln_inputs, attn_ln_std, output_ln_inputs, output_ln_std



def gradients_analysis(args, train_texts, train_labels, val_texts, val_labels, test_texts, test_labels, seed):
    """Evaluate a saved classifier and compute its LayerNorm derivatives.

    The function reports the label distribution of every split, carves a set
    of "lm" samples out of the training data with ``add_lm_to_texts``, loads
    the checkpoint at ``args.model_path`` into a ``CustomClassificationModel``,
    prints evaluation metrics for the test set and the lm set, and finally
    runs ``calculate_ln_derivatives`` on both sets.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``model_name``, ``batch_size``, ``epochs``, ``device``, ``remove``,
            ``learning_rate``, ``percent_train_noisy_samps``,
            ``desired_train_noise_label`` and ``model_path``.
        train_texts: Training texts; passed to ``add_lm_to_texts``.
        train_labels: Training labels; the number of distinct values is used
            as ``num_labels``.
        val_texts: Accepted for interface compatibility; not used here.
        val_labels: Validation labels; only passed to ``count_labels``.
        test_texts: Test texts to evaluate on.
        test_labels: Test labels to evaluate on.
        seed: Seed forwarded to ``add_lm_to_texts``.

    Returns:
        A 4-tuple ``(lm_attn_ln_gradients, lm_output_ln_gradients,
        test_attn_ln_gradients, test_output_ln_gradients)`` holding the two
        results of ``calculate_ln_derivatives`` for the lm set followed by
        the two results for the test set.
    """
    # Parameters from args
    model_name = args.model_name
    num_labels = len(set(train_labels))
    batch_size = args.batch_size
    epochs = args.epochs
    device = args.device
    remove = args.remove
    learning_rate = args.learning_rate
    percent_train_noisy_samps = args.percent_train_noisy_samps
    desired_train_noise_label = args.desired_train_noise_label
    model_path = args.model_path

    print(f"Model: {model_name}, Batch size: {batch_size}, Epochs: {epochs}")
    print(f"Learning rate: {learning_rate}, Device: {device}")
    print(f"Noise: {percent_train_noisy_samps}% with label {desired_train_noise_label}")

    # Count labels for train, val, and test datasets. Only the training counts
    # are needed afterwards; the other calls are kept because count_labels is
    # given a split name and presumably reports the distribution itself.
    train_label_counts = count_labels(train_labels, "Train")
    count_labels(val_labels, "Validation")
    count_labels(test_labels, "Test")

    # Number of samples handed to add_lm_to_texts, as a percentage of the
    # whole training set.
    num_train_noisy_samps = int((percent_train_noisy_samps / 100) * sum(train_label_counts.values()))
    print(num_train_noisy_samps)
    train_texts, train_labels, lm_texts, lm_labels = add_lm_to_texts(
        train_texts, train_labels, desired_train_noise_label, num_train_noisy_samps,
        num_labels=num_labels, seed=seed,
    )

    # Re-count the training labels after add_lm_to_texts changed them.
    count_labels(train_labels, "Train")

    # Tokenization and Data Preparation
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    left_padded_models = {
        "openai-community/gpt2",
        "EleutherAI/gpt-neo-125M",
        "gpt2-medium",
        "EleutherAI/pythia-160M",
        "Qwen/Qwen2-0.5B-Instruct",
    }
    if model_name in left_padded_models:
        # default to left padding
        tokenizer.padding_side = "left"
        # Reuse the EOS token as the PAD token for these models.
        tokenizer.pad_token = tokenizer.eos_token

    # Only the test and lm sets are consumed below, so no train/validation
    # datasets or loaders are built.
    test_dataset = TweetsDataset(test_texts, test_labels, tokenizer)
    lm_dataset = TweetsDataset(lm_texts, lm_labels, tokenizer)

    # os.cpu_count() may return None when the count cannot be determined.
    num_workers = min(3, os.cpu_count() or 1)
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=num_workers)
    lm_loader = DataLoader(lm_dataset, batch_size=1, num_workers=num_workers)

    # Model setup
    model = CustomClassificationModel(model_name, num_labels, remove=remove)
    model = model.to(device)
    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index) be loaded onto the requested one.
    model.load_state_dict(torch.load(model_path, map_location=device))

    for name, param in model.named_parameters():
        print(f"Layer: {name}, Size: {param.size()}, req grad: {param.requires_grad}")

    # Testing phase
    test_loss, test_acc, test_precision, test_recall, test_f1 = evaluate_model(model, test_loader, device)
    print(f"Testing Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

    lm_loss, lm_acc, lm_precision, lm_recall, lm_f1 = evaluate_model(model, lm_loader, device, lm=True)
    print(f"LM Loss: {lm_loss:.4f}, Accuracy: {lm_acc:.4f}, Precision: {lm_precision:.4f}, Recall: {lm_recall:.4f}, F1: {lm_f1:.4f}")

    lm_attn_ln_gradients, lm_output_ln_gradients = calculate_ln_derivatives(model, model_name, lm_loader, device)

    test_attn_ln_gradients, test_output_ln_gradients = calculate_ln_derivatives(model, model_name, test_loader, device)

    # Hand the results back so callers can inspect or persist them.
    return lm_attn_ln_gradients, lm_output_ln_gradients, test_attn_ln_gradients, test_output_ln_gradients

    
def swap_classes(df):
    """Swap the label ids 3 and 5 in ``df['label']``.

    The DataFrame is modified in place and the same object is returned.
    Both selection masks are computed before any value is written, so no
    temporary placeholder label is needed: rows carrying any other label
    (including -1) are left untouched, and the column never has to hold a
    negative value.

    Args:
        df: DataFrame with a ``'label'`` column.

    Returns:
        The same ``df`` object, with labels 3 and 5 exchanged.
    """
    # Capture both masks up front; writing first would make the second
    # comparison see already-swapped values.
    was_three = df['label'] == 3
    was_five = df['label'] == 5
    df.loc[was_three, 'label'] = 5
    df.loc[was_five, 'label'] = 3
    return df

if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, registered
    # in this order.
    cli_options = [
        ("--model_name", str, "roberta-base", "Model name to fine-tune."),
        ("--batch_size", int, 16, "Batch size for training."),
        ("--epochs", int, 70, "Number of epochs for training."),
        ("--device", str, "cuda:0", "Device to use for training."),
        ("--remove", str, "none", "Parameter to remove something (if applicable)."),
        ("--learning_rate", float, 2e-5, "Learning rate for the optimizer."),
        ("--percent_train_noisy_samps", int, 1, "Percentage of noisy samples in training data."),
        ("--desired_train_noise_label", int, 3, "Label to assign to noisy training samples."),
        ("--model_path", str, "saved_models_bias_impact/tweets_dataset_model_roberta.pth", "path of saved model"),
    ]

    parser = argparse.ArgumentParser(description="Fine-tune a BERT model with custom parameters.")
    for flag, option_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=option_type, default=default_value, help=help_text)
    args = parser.parse_args()

    ds = load_dataset("cardiffnlp/tweet_topic_single")

    # Turn each 2020 split into a pandas DataFrame and exchange label ids
    # 3 and 5 in it.
    train_df, val_df, test_df = (
        swap_classes(ds[split_name].to_pandas())
        for split_name in ('train_2020', 'validation_2020', 'test_2020')
    )

    separator = "---------------------------------------------------------------------------"
    seeds_list = [64]
    for seed in seeds_list:
        print(separator)
        print("Results for seed: " ,seed)
        # Fresh arrays are built for every seed so each run starts from the
        # same DataFrame contents.
        gradients_analysis(
            args,
            np.array(train_df['text']), np.array(train_df['label']),
            np.array(val_df['text']), np.array(val_df['label']),
            np.array(test_df['text']), np.array(test_df['label']),
            seed=seed,
        )
        print(separator)
        print()
        print()
    print()