import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.autograd import grad

from tqdm import tqdm
import numpy as np



# Width of the last layer whose activations are recorded. The rest of this
# script also uses it to pick the architecture and the checkpoint file.
n_layer = 10
print("n_layer: ", n_layer)

# Preprocessing applied to every MNIST image: convert it to a tensor, then
# shift and scale with mean 0.5 and std 0.5 so pixel values land in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fully connected ReLU network whose forward pass returns the recorded outputs of selected linear layers
class DenseNN(nn.Module):
    """Stack of linear layers joined by ReLU that records selected outputs.

    ``layer_sizes`` lists the width of every layer from input to output. One
    ``nn.Linear`` is created per consecutive pair of widths, and an
    ``nn.ReLU`` follows every linear layer except the last one.
    """

    def __init__(self, layer_sizes):
        super().__init__()
        n_linear = len(layer_sizes) - 1
        stack = []
        for pos in range(n_linear):
            stack.append(nn.Linear(layer_sizes[pos], layer_sizes[pos + 1]))
            # No non-linearity after the final linear layer.
            if pos + 1 != n_linear:
                stack.append(nn.ReLU())
        # Same module order as appending one by one, so the parameter names
        # ("layers.<index>.weight" / ".bias") stay compatible with checkpoints.
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through the stack and return the recorded outputs.

        Returns a list, in network order, holding the output of each linear
        layer (taken before the following ReLU) whose ``out_features`` is 20
        or equals the module-level ``n_layer``. A layer that satisfies both
        conditions is appended twice. The final network output is returned
        only if its layer matches one of these widths.
        """
        captured = []
        for module in self.layers:
            x = module(x)
            if not isinstance(module, nn.Linear):
                continue
            width = module.out_features
            if width == 20:
                captured.append(x)
            if width == n_layer:
                captured.append(x)
        return captured

# For every MNIST digit class, record the activations of the width-20 layer
# and of the width-`n_layer` layer of a pretrained DenseNN, together with the
# Jacobians of the latter w.r.t. the former and w.r.t. the flattened input
# pixels, and save all four tensors to disk.

# Architecture matching the checkpoint selected by `n_layer`. Fail early with
# a clear message instead of hitting an unbound `layer_sizes` later on.
if n_layer == 15:
    layer_sizes = [784, 100, 20, 15, 10]
elif n_layer == 10:
    layer_sizes = [784, 100, 20, 10]
else:
    raise ValueError("unsupported n_layer: " + str(n_layer) + " (expected 10 or 15)")

# Load the MNIST test split; every sample is preprocessed by `transform`.
batch_size = 1
test_dataset = MNIST(root='/home/causal_ksd/data_MNIST/original', train=False, transform=transform, download=True)

# Bucket the samples by label in a single pass so the dataset (and its
# transform) is not re-evaluated once per target class.
samples_by_class = {}
for sample in test_dataset:
    samples_by_class.setdefault(sample[1], []).append(sample)

# The model and its weights do not depend on the target class, so build and
# load them once. `map_location="cpu"` keeps loading working on CPU-only
# machines even if the checkpoint was saved from a GPU; the model itself is
# never moved off the CPU in this script.
model = DenseNN(layer_sizes)
model_weights_path = "/home/causal_ksd/models_MNIST/mnist_dense_nn_weights" + str(n_layer) + ".pth"
model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))

# Inference only. The network contains just Linear and ReLU modules, so the
# mode does not change the numbers; eval mode simply states the intent.
model.eval()

output_dir = "/home/causal_ksd/models_MNIST/"

for target_class in tqdm(range(10)):
    # Samples of the current digit only (empty list if the class is absent).
    filtered_dataset = samples_by_class.get(target_class, [])
    filtered_loader = DataLoader(filtered_dataset, batch_size=batch_size, shuffle=False)

    n_samples = len(filtered_dataset)
    X = torch.zeros((n_samples, n_layer))               # width-`n_layer` layer outputs
    Y = torch.zeros((n_samples, 20))                    # width-20 layer outputs
    sco = torch.zeros((n_samples, 20, n_layer))         # d X[:, k] / d Y
    sco_original = torch.zeros((n_samples, 784, n_layer))  # d X[:, k] / d input pixels

    for i, (inputs, _) in enumerate(filtered_loader):
        inputs = inputs.view(inputs.size(0), -1)  # Flatten each image to 784 values
        inputs.requires_grad = True
        outputs = model(inputs)
        activations = outputs[-1]  # Output of the layer of width `n_layer`
        hidden = outputs[-2]       # Output of the layer of width 20

        # Rows of the result buffers that belong to this batch.
        start = batch_size * i
        stop = start + inputs.size(0)

        # Store detached copies: writing grad-tracking tensors into X / Y
        # would make the buffers keep every sample's autograd graph alive.
        X[start:stop] = activations.detach()
        Y[start:stop] = hidden.detach()

        for k in range(n_layer):
            # Samples in a batch are independent, so differentiating the sum
            # over the batch yields each sample's own gradient row. A single
            # backward pass provides both gradients.
            d_hidden, d_inputs = grad(activations[:, k].sum(), (hidden, inputs), retain_graph=True)
            sco[start:stop, :, k] = d_hidden
            sco_original[start:stop, :, k] = d_inputs

    # Save the activations and Jacobians collected for this digit class.
    file_suffix = str(target_class) + "n_layer" + str(n_layer) + ".pt"
    torch.save(X, output_dir + "hl" + file_suffix)
    torch.save(Y, output_dir + "ll" + file_suffix)
    torch.save(sco, output_dir + "sco" + file_suffix)
    torch.save(sco_original, output_dir + "sco_original" + file_suffix)