import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import torch.nn.functional as F
import warnings
import logging
import math
import shutil
import copy
import argparse
import pickle



def get_loaders(batch_size, dataset = "cifar10"):
    """Build normalized train/test DataLoaders for CIFAR-10 or MNIST.

    Images are converted to tensors and normalized with fixed per-channel
    mean/std constants. The data is downloaded on first use into
    ``../cifar10_data`` or ``../mnist_data`` (relative to the working
    directory).

    Args:
        batch_size: Batch size used by both loaders.
        dataset: ``"cifar10"`` or ``"mnist"``.

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles; the test
        loader does not. Both use ``num_workers=0``.

    Raises:
        ValueError: If ``dataset`` is not a supported name (previously the
            function silently returned ``None``).
    """
    if dataset == "cifar10":
        dset_cls, root = torchvision.datasets.CIFAR10, '../cifar10_data'
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    elif dataset == "mnist":
        dset_cls, root = torchvision.datasets.MNIST, '../mnist_data'
        mean, std = (0.1307,), (0.3081,)
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'cifar10' or 'mnist'"
        )

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=mean, std=std),
    ])

    def make_loader(train):
        # Only the training split is shuffled.
        dset = dset_cls(root, train=train, transform=transform, target_transform=None, download=True)
        return torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=train, num_workers=0)

    trainloader = make_loader(True)
    testloader = make_loader(False)
    return trainloader, testloader
    
    
# Memory-efficient accumulation of streaming statistics (batched Welford-style update) used later to compute correlations
def accumulate_statistics(hidden, errors, hidden_mean, hidden_M2, error_mean, error_M2, hidden_error_sum, total_samples):
    # Flatten hidden layers
    hidden_flat = hidden.reshape(hidden.size(0), -1).double().cpu()
    errors_flat = errors.double().cpu()  # Shape (batch_size, 10)
    batch_size = hidden.size(0)

    # Update the total number of samples processed
    new_total_samples = total_samples + batch_size

    # Update hidden statistics using Welford's algorithm
    delta_hidden = hidden_flat.mean(dim=0) - hidden_mean
    delta_hidden_scaled = delta_hidden * batch_size / new_total_samples
    hidden_mean += delta_hidden_scaled
    hidden_M2 += ((hidden_flat - hidden_mean).pow(2)).sum(dim=0)  # Accumulate squared differences for variance

    # Update error statistics using Welford's algorithm
    delta_error = errors_flat.mean(dim=0) - error_mean
    delta_error_scaled = delta_error * batch_size / new_total_samples
    error_mean += delta_error_scaled
    error_M2 += ((errors_flat - error_mean).pow(2)).sum(dim=0).cpu()  # Accumulate squared differences for variance

    # Update the hidden-error covariance sum (used later for correlation)
    hidden_error_sum += torch.matmul(hidden_flat.T, errors_flat).cpu()  # Shape (hidden_size, 10)
    return hidden_mean, hidden_M2, error_mean, error_M2, hidden_error_sum, new_total_samples


# Finalize the correlation calculation after all batches are processed
def finalize_correlation(hidden_mean, hidden_M2, error_mean, error_M2, hidden_error_sum, total_samples, eps=1e-6, min_cross_variance=1e-5):
    """Convert accumulated streaming statistics into Pearson correlations.

    Args:
        hidden_mean, hidden_M2: Mean and sum of squared deviations of each
            hidden feature, shape ``(hidden_size,)``.
        error_mean, error_M2: Mean and sum of squared deviations of each error
            dimension, shape ``(num_outputs,)``.
        hidden_error_sum: Raw sum of ``hidden^T @ errors``,
            shape ``(hidden_size, num_outputs)``.
        total_samples: Number of samples the statistics cover.
        eps: Added to the denominator for numerical stability.
        min_cross_variance: Entries whose variance product is at or below this
            are set to 0 (near-constant features carry no usable correlation).

    Returns:
        float64 tensor ``(hidden_size, num_outputs)`` of correlation
        coefficients. It is all zeros when ``total_samples < 2``.
    """
    hidden_mean, hidden_M2, error_mean, error_M2, hidden_error_sum = (
        torch.as_tensor(t, dtype=torch.float64)
        for t in (hidden_mean, hidden_M2, error_mean, error_M2, hidden_error_sum)
    )
    if total_samples < 2:
        # Correlation is undefined for fewer than two samples.
        return torch.zeros_like(hidden_error_sum)

    # Covariance and variances must share the same denominator, otherwise
    # every coefficient is scaled by (n - 1) / n. Population (1/n) form used.
    hidden_var = hidden_M2 / total_samples
    error_var = error_M2 / total_samples
    # NOTE: E[xy] - E[x]E[y] can lose precision when means are large relative
    # to standard deviations; float64 keeps this manageable.
    cov = hidden_error_sum / total_samples - torch.outer(hidden_mean, error_mean)

    cross_variances = torch.outer(hidden_var, error_var)
    correlation = cov / (torch.sqrt(cross_variances) + eps)
    return torch.where(cross_variances > min_cross_variance, correlation, torch.zeros_like(correlation))


# Modified evaluation function for memory-efficient correlation calculation
def evaluateClassificationCorrelation(model, loader, device, printing=True, num_hidden_layers = 2, num_classes=10):
    """Evaluate a classifier and correlate hidden activations with output errors.

    Runs ``model`` over ``loader`` without gradients. It accumulates accuracy,
    MSE loss against one-hot targets, and streaming statistics (via
    ``accumulate_statistics``) between each flattened hidden layer and the
    output error ``y_hat - one_hot(y)``. Correlations are then produced per
    layer by ``finalize_correlation``.

    Args:
        model: Called as ``model(x)`` and expected to return
            ``(y_hat, preh_list, hidden_list)``, where ``hidden_list`` has at
            least ``num_hidden_layers`` entries. Inputs are cast to float64.
        loader: Iterable of ``(x, y)`` batches exposing ``.dataset``.
        device: Device the batches are moved to.
        printing: If True, print progress, accuracy, loss and the mean
            absolute correlation per layer.
        num_hidden_layers: Number of leading entries of ``hidden_list`` to use.
        num_classes: Width of the one-hot target / error vectors.

    Returns:
        ``(acc, correlations, avgloss)``:
        - ``acc``: fraction of correct predictions over ``len(loader.dataset)``.
        - ``correlations``: one ``(hidden_size, num_classes)`` tensor per layer.
        - ``avgloss``: MSE averaged over all samples and output units.

    Raises:
        ValueError: If ``loader.dataset`` is empty.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("loader.dataset is empty; cannot compute accuracy or correlations")

    model.eval()
    correct = 0
    loss = 0
    criterion = nn.MSELoss()

    # Per-layer streaming statistics. Hidden-layer tensors are allocated
    # lazily once the first batch reveals each layer's flattened width; the
    # error statistics start at 0 and become tensors after the first update.
    hidden_means = [None] * num_hidden_layers
    hidden_M2s = [None] * num_hidden_layers
    hidden_error_sums = [None] * num_hidden_layers
    error_means = [0] * num_hidden_layers
    error_M2s = [0] * num_hidden_layers
    total_samples = [0] * num_hidden_layers

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device).type(torch.float64), y.to(device)
            batch_size = y.size(0)
            y_hat, preh_list, hidden_list = model(x)

            # Predictions and per-sample errors against one-hot targets.
            pred = torch.argmax(y_hat, dim=1)
            correct += (y == pred).sum().item()
            y_one_hot = F.one_hot(y, num_classes=num_classes).to(torch.float64)
            errors = y_hat - y_one_hot

            # MSELoss averages within the batch; weight by batch size so that
            # dividing by num_samples gives a true dataset average (also
            # correct when the last batch is smaller).
            loss += criterion(y_hat, y_one_hot) * batch_size

            for i in range(num_hidden_layers):
                if hidden_means[i] is None:
                    hidden_size = hidden_list[i].reshape(hidden_list[i].size(0), -1).size(1)
                    # float64 so long-running sums do not lose precision.
                    hidden_means[i] = torch.zeros(hidden_size, dtype=torch.float64, device='cpu')
                    hidden_M2s[i] = torch.zeros_like(hidden_means[i])
                    hidden_error_sums[i] = torch.zeros(hidden_size, num_classes, dtype=torch.float64, device='cpu')

                (hidden_means[i], hidden_M2s[i], error_means[i], error_M2s[i],
                 hidden_error_sums[i], total_samples[i]) = accumulate_statistics(
                    hidden_list[i], errors, hidden_means[i], hidden_M2s[i],
                    error_means[i], error_M2s[i], hidden_error_sums[i], total_samples[i],
                )
            # Release batch tensors before the next forward pass to lower peak memory.
            del x, y, y_hat, preh_list, hidden_list, pred, errors, y_one_hot
            torch.cuda.empty_cache()

    if printing:
        print("done collecting stats")
    acc = correct / num_samples
    avgloss = loss / num_samples
    if printing:
        print('Accuracy :\t', acc)
        print('Loss :\t', avgloss)

    correlations = []
    for i in range(num_hidden_layers):
        correlation = finalize_correlation(hidden_means[i], hidden_M2s[i], error_means[i], error_M2s[i], hidden_error_sums[i], total_samples[i])
        correlations.append(correlation)
        if printing:
            print(f'Correlation between hidden layer {i + 1} and output errors (mean over all error dimensions): {torch.abs(correlation).mean().item()}')
    return acc, correlations, avgloss

