import os
import glob
import pandas as pd
import torch
import argparse
import json
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, AdamW, DeiTForImageClassification, ViTMSNForImageClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from collections import Counter
import random
import copy
from tqdm import tqdm
from scipy.stats import ttest_rel
from torch.nn.functional import softmax
from sklearn.model_selection import StratifiedShuffleSplit
from PIL import Image

# Custom dataset class
class NICODataset(Dataset):
    """Map-style dataset pairing in-memory images with their labels.

    Every item is a dict with the keys ``"image"`` and ``"label"``. When a
    transform is supplied, the stored image is wrapped in a ``torch.Tensor``
    and passed through the transform; otherwise it is returned untouched.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # The label sequence defines how many samples the dataset exposes.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {"image": self.data[idx], "label": self.labels[idx]}
        if self.transform:
            sample["image"] = self.transform(torch.tensor(sample["image"]))
        return sample




# Define a custom model class
class CustomClassificationModel(nn.Module):
    """Image classifier wrapping a pretrained ViT-MSN backbone.

    Args:
        model_name: Hugging Face checkpoint name. Only
            ``"facebook/vit-msn-small"`` is supported.
        num_labels: Number of output classes passed to the backbone.
        remove: ``'layer_norm'`` sets ``weight`` and ``bias`` to ``None`` on
            every backbone sub-module whose qualified name contains
            ``"layernorm"``; any other value leaves the backbone untouched.

    Raises:
        ValueError: If ``model_name`` is not a supported checkpoint.
    """

    _SUPPORTED_MODELS = ("facebook/vit-msn-small",)

    def __init__(self, model_name, num_labels, remove = 'none'):
        super(CustomClassificationModel, self).__init__()
        self.model_name = model_name

        # Fail early: an unsupported name used to leave `self.backbone` unset
        # and only surfaced later as an AttributeError inside forward().
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                f"expected one of {self._SUPPORTED_MODELS}"
            )

        self.backbone = ViTMSNForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

        if remove == 'layer_norm':
            # Drop the learnable affine parameters of every LayerNorm module.
            for name, module in self.backbone.named_modules():
                if 'layernorm' in name:
                    module.weight = None
                    module.bias = None

    def forward(self, input_imgs):
        """Return the backbone's output object for a batch of images."""
        return self.backbone(input_imgs)
        

# Define a custom model class
class CustomClassificationModel_layer_analysis(nn.Module):
    """ViT-MSN classifier with LayerNorm parameters removed in chosen layers.

    Args:
        model_name: Hugging Face checkpoint name. Only
            ``"facebook/vit-msn-small"`` is supported.
        num_labels: Number of output classes passed to the backbone.
        remove_layers: Iterable of encoder layer indices. For each listed
            index, the ``layernorm_before`` and ``layernorm_after`` modules of
            that layer get ``weight`` and ``bias`` set to ``None``. ``None``
            (the default) removes nothing.

    Raises:
        ValueError: If ``model_name`` is not a supported checkpoint.
    """

    _SUPPORTED_MODELS = ("facebook/vit-msn-small",)

    def __init__(self, model_name, num_labels, remove_layers = None):
        super(CustomClassificationModel_layer_analysis, self).__init__()
        self.model_name = model_name

        # Fail early: an unsupported name used to leave `self.backbone` unset
        # and only surfaced later as an AttributeError inside forward().
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                f"expected one of {self._SUPPORTED_MODELS}"
            )

        self.backbone = ViTMSNForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

        # The default of None previously crashed on the `in` membership test.
        layers_to_strip = set(remove_layers) if remove_layers is not None else set()

        for name, module in self.backbone.named_modules():
            if 'layernorm_before' in name or 'layernorm_after' in name:
                # The layer index is the path component right after ".layer.".
                layer_index = int(name.split(".layer.")[1].split(".")[0])
                if layer_index in layers_to_strip:
                    module.weight = None
                    module.bias = None

    def forward(self, input_imgs):
        """Return the backbone's output object for a batch of images."""
        return self.backbone(input_imgs)


# Metrics function
def compute_metrics(predictions, labels):
    """Score predictions against ground-truth labels.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` where precision, recall
        and F1 are macro-averaged.
    """
    accuracy = accuracy_score(labels, predictions)
    macro_precision, macro_recall, macro_f1 = (
        scorer(labels, predictions, average='macro')
        for scorer in (precision_score, recall_score, f1_score)
    )
    return accuracy, macro_precision, macro_recall, macro_f1

# Training function
def train_epoch(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is a dict with ``'image'`` and ``'label'`` tensors. The model
    output is expected to expose a ``.logits`` attribute.

    Returns:
        Tuple ``(mean batch loss, accuracy, precision, recall, f1, optimizer)``.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    epoch_preds = []
    epoch_targets = []

    for batch in train_loader:
        optimizer.zero_grad()
        inputs = batch['image'].to(device)
        targets = batch['label'].to(device)

        # Forward pass and loss bookkeeping
        logits = model(inputs).logits
        batch_loss = loss_fn(logits, targets)
        running_loss += batch_loss.item()

        # Collect hard predictions on the CPU for the epoch-level metrics
        epoch_preds.extend(logits.argmax(dim=1).cpu().numpy())
        epoch_targets.extend(targets.cpu().numpy())

        # Backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

    mean_loss = running_loss / len(train_loader)
    acc, precision, recall, f1 = compute_metrics(epoch_preds, epoch_targets)
    torch.cuda.empty_cache()
    return mean_loss, acc, precision, recall, f1, optimizer

# Evaluation function
def evaluate_model(model, val_loader, device, lm = False):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Module whose output exposes ``.logits``.
        val_loader: Iterable of dict batches with ``'image'`` and ``'label'``.
        device: Device the batches are moved to.
        lm: When truthy, the raw predictions and labels are printed.

    Returns:
        Tuple ``(mean batch loss, accuracy, precision, recall, f1)``.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['image'].to(device)
            targets = batch['label'].to(device)

            logits = model(inputs).logits
            running_loss += loss_fn(logits, targets).item()

            # Keep hard predictions and labels on the CPU for the metrics
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    if lm:
        print("LM Predictions: ", all_preds)
        print("LM Labels: ", all_targets)

    mean_loss = running_loss / len(val_loader)
    acc, precision, recall, f1 = compute_metrics(all_preds, all_targets)
    torch.cuda.empty_cache()
    return mean_loss, acc, precision, recall, f1

def test_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and tally wrong predictions.

    Args:
        model: Module whose output exposes ``.logits``.
        test_loader: Iterable of dict batches with ``'image'`` and ``'label'``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean batch loss, accuracy, precision, recall, f1,
        misclassifications)``. ``misclassifications`` maps each class index
        (one key per logit column) to the number of wrong predictions that
        named that class.
    """
    # Switch to inference mode, matching evaluate_model.
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    predictions, true_labels = [], []
    num_classes = 0

    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            logits = model(images).logits
            # The loss was previously never accumulated, so 0 was always reported.
            total_loss += criterion(logits, labels).item()
            # Size the tally from the model output instead of a fixed 6 classes.
            num_classes = max(num_classes, logits.shape[1])

            # Get predictions and move data to CPU for metrics calculation
            predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Count wrong predictions, keyed by the (incorrectly) predicted class.
    misclassifications = {class_index: 0 for class_index in range(num_classes)}
    for predicted, actual in zip(predictions, true_labels):
        if predicted != actual:
            misclassifications[int(predicted)] += 1

    avg_loss = total_loss / len(test_loader)
    acc, precision, recall, f1 = compute_metrics(predictions, true_labels)
    torch.cuda.empty_cache()

    return avg_loss, acc, precision, recall, f1, misclassifications


# Function to count label occurrences
def count_labels(labels, dataset_name):
    """Print and return how often each label occurs in ``labels``.

    Returns:
        A ``Counter`` mapping label -> number of occurrences.
    """
    tally = Counter(labels)
    report = [f"Label counts for {dataset_name}:"]
    report.extend(f"  Label {label}: {count}" for label, count in tally.items())
    print("\n".join(report))
    return tally


def add_random_label(original_label, idx, seed, num_labels):
    """Pick a label in ``[0, num_labels)`` that differs from ``original_label``.

    When ``seed`` is not ``None`` the global ``random`` generator is re-seeded
    with ``seed + idx`` first, so the draw is repeatable per sample index.
    """
    if seed is not None:
        random.seed(seed + idx)

    # Every class id except the true one is a valid noisy replacement.
    candidates = [*range(num_labels)]
    candidates.remove(original_label)
    return random.choice(candidates)


def add_lm_to_texts(texts, labels, class_label, n, seed=28, num_labels = 100):
    """Inject label noise into up to ``n`` samples of class ``class_label``.

    The inputs are deep-copied, so the caller's data is left unchanged. Each
    selected sample gets a replacement label from ``add_random_label``.

    Returns:
        Tuple ``(samples copy, labels copy with noise, noisy samples,
        noisy labels)``.
    """
    random.seed(seed)

    # Positions of all samples that carry the class to be corrupted
    candidates = [pos for pos, current in enumerate(labels) if current == class_label]
    chosen = random.sample(candidates, min(n, len(candidates)))

    noisy_texts = copy.deepcopy(texts)
    noisy_labels = copy.deepcopy(labels)
    lm_texts, lm_labels, lm_labels_actual = [], [], []

    for pos in chosen:
        true_label = noisy_labels[pos]
        lm_labels_actual.append(true_label)
        noisy_labels[pos] = add_random_label(true_label, pos, seed, num_labels)  # noisy label
        lm_texts.append(noisy_texts[pos])
        lm_labels.append(noisy_labels[pos])

    print("Actual labels: ", lm_labels_actual)
    return noisy_texts, noisy_labels, lm_texts, lm_labels



def plot_metrics(epochs_list, train_list, val_list, test_list, lm_list, metric_type, save_path):
    """
    Draws the train/validation/test/LM curves of one metric and saves the figure.

    Args:
        epochs_list (list): Epoch numbers used as the x-axis.
        train_list (list): Training metric values over epochs.
        val_list (list): Validation metric values over epochs.
        test_list (list): Test metric values over epochs.
        lm_list (list): Label memorization metric values over epochs.
        metric_type (str): Metric name used in the legend, y-axis label and title.
        save_path (str): File path the figure is written to.
    """
    # (values, legend prefix, colour, marker) in drawing order
    curves = (
        (train_list, "Train ", "red", 'o'),
        (val_list, "Validation ", "blue", 's'),
        (test_list, "Test ", "yellow", 'x'),
        (lm_list, "LM ", "green", '^'),
    )

    plt.figure(figsize=(10, 6))
    for values, prefix, colour, marker in curves:
        plt.plot(epochs_list, values, label=prefix + metric_type, color=colour, marker=marker)

    # Axis labels, title, legend and grid
    plt.xlabel("Epochs", fontsize=12)
    plt.ylabel(metric_type, fontsize=12)
    plt.title(f"{metric_type} Trends over Epochs", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle="--", alpha=0.6)

    # Write the figure to disk and release it
    plt.savefig(save_path)
    plt.close()
    print(f"{metric_type} plot saved at {save_path}")



def save_metrics_to_csv(epochs_list, train_list, val_list, test_list, lm_list, csv_path):
    """
    Writes the per-epoch metric series to a CSV file, one row per epoch.

    Args:
        epochs_list (list): Epoch numbers.
        train_list (list): Training metric values.
        val_list (list): Validation metric values.
        test_list (list): Test metric values.
        lm_list (list): Label memorization metric values.
        csv_path (str): Destination path of the CSV file.
    """
    column_names = ("Epoch", "Train", "Validation", "Test", "Label Memorization")
    column_values = (epochs_list, train_list, val_list, test_list, lm_list)

    # Column order in the file follows the order of `column_names`.
    table = pd.DataFrame(dict(zip(column_names, column_values)))
    table.to_csv(csv_path, index=False)
    print(f"Metrics saved to CSV at {csv_path}")


def save_combined_metrics_to_json(data, file_name):
    """
    Serialize the metrics dictionary to a JSON file with 4-space indentation.

    Args:
        data (dict): Metrics to store.
        file_name (str): Path of the JSON file to write.
    """
    with open(file_name, 'w') as out_handle:
        json.dump(data, out_handle, indent=4)

    print(f"Combined metrics saved to {file_name}")



def sample_50_percent(train_texts, train_labels, seed = 28):
    """Draw a class-stratified random half of the dataset.

    For every distinct label, ``len(class) // 2`` samples are drawn without
    replacement. Classes are visited in the sorted order of ``np.unique``.

    Args:
        train_texts: Samples, as a NumPy array or any array-like (e.g. list).
        train_labels: Labels aligned with ``train_texts``; array or array-like.
        seed: Seed applied to NumPy's global random generator before sampling.

    Returns:
        Tuple ``(sampled_texts, sampled_labels)`` of NumPy arrays.
    """
    np.random.seed(seed)

    # Accept plain lists too: comparing a list with a scalar yields a single
    # bool, which used to break the per-class index lookup below.
    texts_arr = np.asarray(train_texts)
    labels_arr = np.asarray(train_labels)

    sampled_texts = []
    sampled_labels = []

    for label in np.unique(labels_arr):
        # Indices of all samples with the current label
        label_indices = np.where(labels_arr == label)[0]

        # Randomly keep half of them (rounded down)
        sampled_indices = np.random.choice(
            label_indices, size=len(label_indices) // 2, replace=False
        )

        sampled_texts.extend(texts_arr[sampled_indices])
        sampled_labels.extend(labels_arr[sampled_indices])

    return np.array(sampled_texts), np.array(sampled_labels)


def remove_classes(texts, labels, classes_to_remove):
    """
    Drops every sample whose label is in ``classes_to_remove``.

    Parameters:
        texts (list): Samples.
        labels (list): Labels aligned with ``texts``.
        classes_to_remove (set): Labels whose samples are discarded.

    Returns:
        filtered_texts (list): Samples that were kept, in original order.
        filtered_labels (list): Labels of the kept samples.
    """
    kept_pairs = [
        (sample, label)
        for sample, label in zip(texts, labels)
        if label not in classes_to_remove
    ]
    filtered_texts = [sample for sample, _ in kept_pairs]
    filtered_labels = [label for _, label in kept_pairs]
    return filtered_texts, filtered_labels



def calculate_ln_derivatives(model, model_name, loader, device):
    """Measure the loss gradient arriving at each encoder LayerNorm input.

    Backward hooks are attached to every module whose name contains
    ``layernorm_before`` or ``layernorm_after``. For each sample, the absolute
    gradient w.r.t. the module input is averaged over dim 1, summed over all
    samples of ``loader``, divided by the number of samples, and reduced to a
    single Frobenius norm per layer.

    The loss uses ``reduction='sum'`` and the hook sums over the batch
    dimension, so the result does not depend on how ``loader`` batches the
    data, and a smaller final batch is handled. For batch size 1 the values
    equal those of a mean-reduced loss.

    Args:
        model: Module whose output exposes ``.logits`` and whose LayerNorm
            module names contain ``".layer.<index>."``.
        model_name: Unused; kept for interface compatibility.
        loader: Iterable of dict batches with ``'image'`` and ``'label'``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(attn_layernorm_derivatives, output_layernorm_derivatives)``:
        dicts mapping layer index -> float norm, for ``layernorm_before`` and
        ``layernorm_after`` respectively.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')
    attn_layernorm_derivatives = dict()
    output_layernorm_derivatives = dict()
    sample_count = 0

    def make_hook(storage, layer_index):
        # Build a hook bound to one storage dict and one layer index.
        def hook(module, grad_input, grad_output):
            grad = grad_input[0]
            if grad is None:  # Ensure valid gradient
                return
            # Average over dim 1, then sum over the batch so the accumulator
            # shape no longer depends on the batch size.
            per_feature = grad.detach().abs().mean(dim=1).sum(dim=0)
            if layer_index in storage:
                storage[layer_index] += per_feature
            else:
                storage[layer_index] = per_feature.clone()
        return hook

    hooks = []
    try:
        for name, module in model.named_modules():
            if 'layernorm_before' in name:
                storage = attn_layernorm_derivatives
            elif 'layernorm_after' in name:
                storage = output_layernorm_derivatives
            else:
                continue
            layer_index = int(name.split(".layer.")[1].split(".")[0])
            hooks.append(module.register_full_backward_hook(make_hook(storage, layer_index)))

        for batch in loader:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            sample_count += images.shape[0]

            # Parameter gradients are not needed; do not let them pile up.
            model.zero_grad()
            logits = model(images).logits
            criterion(logits, labels).backward()
    finally:
        # Always detach the hooks, even if a batch raised.
        for hook in hooks:
            hook.remove()
        model.zero_grad()

    # Normalize by total samples and compute Frobenius norm
    for storage in (attn_layernorm_derivatives, output_layernorm_derivatives):
        for layer_index, accumulated in list(storage.items()):
            storage[layer_index] = torch.norm(accumulated / sample_count, p='fro').item()

    print("Attention LayerNorm grads: ", attn_layernorm_derivatives)
    print("Output LayerNorm grads: ", output_layernorm_derivatives)
    print()

    return attn_layernorm_derivatives, output_layernorm_derivatives


def gradients_analysis(args, train_imgs, train_labels, val_imgs, val_labels, test_imgs, test_labels, seed):
    """Evaluate a saved model and compare LayerNorm gradients on noisy vs. test data.

    Steps: print label statistics, relabel a share of the training samples of
    class ``args.desired_train_noise_label`` with random wrong labels, load
    the checkpoint at ``args.model_path``, evaluate it on the test split and
    on the relabelled ("LM") samples, then run ``calculate_ln_derivatives`` on
    both loaders.

    Args:
        args: Parsed CLI namespace providing ``model_name``, ``batch_size``,
            ``epochs``, ``device``, ``remove``, ``learning_rate``,
            ``percent_train_noisy_samps``, ``desired_train_noise_label`` and
            ``model_path``.
        train_imgs, train_labels: Training samples and labels.
        val_imgs, val_labels: Validation samples and labels (statistics only).
        test_imgs, test_labels: Test samples and labels.
        seed: Seed used when selecting and relabelling the noisy samples.
    """
    # Parameters from args
    model_name = args.model_name
    num_labels = len(set(train_labels))
    batch_size = args.batch_size
    epochs = args.epochs
    device = args.device
    remove = args.remove
    learning_rate = args.learning_rate
    percent_train_noisy_samps = args.percent_train_noisy_samps
    desired_train_noise_label = args.desired_train_noise_label
    model_path = args.model_path

    print(f"Model: {model_name}, Batch size: {batch_size}, Epochs: {epochs}")
    print(f"Learning rate: {learning_rate}, Device: {device}")
    print(f"Noise: {percent_train_noisy_samps}% with label {desired_train_noise_label}")

    # Count labels for train, val, and test datasets
    train_label_counts = count_labels(train_labels, "Train")
    count_labels(val_labels, "Validation")
    count_labels(test_labels, "Test")

    # The number of relabelled samples is a percentage of the whole training set.
    num_train_noisy_samps = int((percent_train_noisy_samps / 100) * sum(train_label_counts.values()))
    print(num_train_noisy_samps)

    train_imgs, train_labels, lm_imgs, lm_labels = add_lm_to_texts(
        train_imgs, train_labels, desired_train_noise_label, num_train_noisy_samps,
        num_labels=num_labels, seed=seed,
    )
    count_labels(train_labels, "Train")
    print(len(train_imgs))  # number of training samples
    print(train_imgs[0].shape)  # shape of a single stored image

    # Transform for resizing and normalization
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Only the test split and the relabelled samples are evaluated here.
    test_dataset = NICODataset(test_imgs, test_labels, transform=transform)
    lm_dataset = NICODataset(lm_imgs, lm_labels, transform=transform)

    # os.cpu_count() may return None; fall back to a single worker then.
    num_workers = min(3, os.cpu_count() or 1)
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=num_workers)
    lm_loader = DataLoader(lm_dataset, batch_size=1, num_workers=num_workers)

    # Model setup
    model = CustomClassificationModel(model_name, num_labels, remove=remove)
    model = model.to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))

    for name, param in model.named_parameters():
        print(f"Layer: {name}, Size: {param.size()}, req grad: {param.requires_grad}")

    # Testing phase
    test_loss, test_acc, test_precision, test_recall, test_f1 = evaluate_model(model, test_loader, device)
    print(f"Testing Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

    lm_loss, lm_acc, lm_precision, lm_recall, lm_f1 = evaluate_model(model, lm_loader, device, lm = True)
    print(f"LM Loss: {lm_loss:.4f}, Accuracy: {lm_acc:.4f}, Precision: {lm_precision:.4f}, Recall: {lm_recall:.4f}, F1: {lm_f1:.4f}")

    # The per-layer gradient norms are printed by calculate_ln_derivatives.
    calculate_ln_derivatives(model, model_name, lm_loader, device)
    calculate_ln_derivatives(model, model_name, test_loader, device)


def load_nico(main_folder, seed):
    """Load the NICO class folders and split them 80/10/10.

    Only sub-folders whose name is one of the known class names are read, and
    within them only files ending in ``.jpg`` (case-insensitive). Every image
    is converted to RGB, resized to 224x224 and stored as a ``(3, 224, 224)``
    NumPy array. Files that cannot be read are reported and skipped.

    Args:
        main_folder: Directory containing one sub-folder per class.
        seed: ``random_state`` for both stratified splits.

    Returns:
        Tuple ``(train_images, train_labels, val_images, val_labels,
        test_images, test_labels)`` of lists.
    """
    nico_classes = {"car": 0, "flower": 1, "penguin": 2, "camel": 3, "chair": 4, "monitor": 5, "truck": 6, "wheat": 7, "sword": 8, "seal": 9, "lion": 10, "fish": 11, "dolphin": 12, "lifeboat": 13, "tank": 14}

    images, labels = [], []

    # NOTE(review): os.listdir order is filesystem-dependent, so the split for
    # a given seed may differ between machines — confirm whether that matters.
    for folder_name in os.listdir(main_folder):
        if folder_name not in nico_classes:
            continue
        folder_path = os.path.join(main_folder, folder_name)
        if not os.path.isdir(folder_path):
            continue
        nico_label = nico_classes[folder_name]

        for file_name in os.listdir(folder_path):
            if not file_name.lower().endswith('.jpg'):
                continue
            file_path = os.path.join(folder_path, file_name)

            # Deliberately best-effort: one unreadable file must not abort the load.
            try:
                # The context manager releases the file handle even when decoding fails.
                with Image.open(file_path) as img:
                    image_resized = img.convert('RGB').resize(
                        (224, 224), resample=Image.Resampling.BICUBIC
                    )

                # (height, width, 3) -> (3, height, width)
                image_np = np.array(image_resized).transpose(2, 0, 1)

                # Explicit check instead of assert, which is stripped under -O.
                if image_np.shape != (3, 224, 224):
                    raise ValueError(f"Unexpected image shape: {image_np.shape}")

                images.append(image_np)
                labels.append(nico_label)
            except Exception as e:
                print(f"Failed to open {file_path}: {e}")

    print(len(images), len(labels))
    # 80% train, remaining 20% held out
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, train_size=0.8, stratify=labels, random_state=seed
    )

    print(len(train_images), len(train_labels), len(test_images), len(test_labels))
    # Held-out part is halved into validation and test
    val_images, test_images, val_labels, test_labels = train_test_split(
        test_images, test_labels, train_size=0.5, stratify=test_labels, random_state=seed
    )

    return train_images, train_labels, val_images, val_labels, test_images, test_labels



if __name__ == "__main__":
    # Command-line entry point: load NICO, then run the LayerNorm gradient
    # analysis once per seed with a previously saved model.
    parser = argparse.ArgumentParser(
        description="Analyse LayerNorm gradients of a saved ViT-MSN classifier on NICO with label noise."
    )

    parser.add_argument("--model_name", type=str, default="facebook/vit-msn-small", help="Name of the pretrained backbone.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size (reported in the log output).")
    parser.add_argument("--epochs", type=int, default=70, help="Number of epochs (reported in the log output).")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the model on.")
    parser.add_argument("--remove", type=str, default="none", help="Use 'layer_norm' to strip LayerNorm weights and biases from the backbone.")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate (reported in the log output).")
    parser.add_argument("--percent_train_noisy_samps", type=int, default=1, help="Percentage of the training set to relabel with noise.")
    parser.add_argument("--desired_train_noise_label", type=int, default=6, help="Class whose training samples receive a random wrong label.")
    parser.add_argument("--model_path", type=str, default = "saved_models_bias_impact/nico_dataset_model_vit_small.pth", help = "path of saved model")
    # Previously hard-coded; the default keeps the original location.
    parser.add_argument("--data_dir", type=str, default="public_ood_0412_nodomainlabel/train/", help="Folder containing one sub-folder per NICO class.")

    args = parser.parse_args()

    seeds_list = [28]
    for seed in seeds_list:
        train_imgs, train_labels, val_imgs, val_labels, test_imgs, test_labels = load_nico(args.data_dir, seed)
        gradients_analysis(args, train_imgs, train_labels, val_imgs, val_labels, test_imgs, test_labels, seed)