import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Load the trained model
from model import CNN

cnn = CNN()
# Every tensor this script builds lives on the CPU, so remap the checkpoint
# there as well; without map_location a state dict saved from a CUDA device
# cannot be deserialised on a machine that has no GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
cnn.load_state_dict(torch.load('model_trained.pth', map_location='cpu'))

# Load the dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

# MNIST training split rooted at ./data; download=True asks torchvision to
# fetch the files when they are not present yet.
train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)

def get_training_points(class_from):
    """Return every training image labelled ``class_from`` as a float tensor.

    Reads the module-level ``train_data`` dataset.

    Args:
        class_from: Class label to look up in ``train_data.targets``.

    Returns:
        A ``torch.float`` tensor of shape ``(n, 1, *image_shape)`` holding the
        entries of ``train_data.data`` whose target equals ``class_from``,
        with a singleton axis inserted at position 1. ``n`` is 0 (and the
        trailing axes are kept) when no sample carries that label.
    """
    # NOTE(review): this reads ``train_data.data`` directly, which presumably
    # bypasses the ``ToTensor`` transform given to the dataset, so the values
    # are on the raw storage scale -- confirm that matches what the model was
    # trained on.
    # A boolean mask selects all matching images in one vectorised step
    # instead of walking the dataset sample by sample through Python lists.
    class_mask = train_data.targets == class_from
    return train_data.data[class_mask].unsqueeze(1).to(torch.float)

# MNIST test split under the same ./data root. No download is requested
# here, so the files are expected to be on disk already.
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor())

def _build_spectrum(X_f, X_local, test_point_f, weight, bias):
    """Build the spectrum of training points for one linear discriminant.

    The support set is every row ``x`` of ``X_f`` with
    ``(x - test_point_f) @ weight < 0``. It is ordered by ascending local
    value; then, repeatedly, the remaining point with the smallest
    discriminant value ``x @ weight + bias`` is added to the spectrum and
    everything up to and including it in that order is discarded.

    Args:
        X_f: 2-D tensor of flattened training features, one row per point.
        X_local: 1-D tensor holding ``X_f @ test_point_f`` (the local values).
        test_point_f: 1-D tensor of flattened test-point features.
        weight: 1-D weight vector of the discriminant.
        bias: Scalar bias of the discriminant.

    Returns:
        A tuple ``(spec, support_size)`` where ``spec`` is a list of row
        indices into ``X_f`` in the order they were selected and
        ``support_size`` is the number of points in the support set.
    """
    # Get the indices of the support set
    support_mask = torch.matmul(X_f - test_point_f, weight) < 0
    support_idxs = support_mask.nonzero().flatten()
    support_size = support_idxs.size(0)

    # Sort the support set by local values
    support_local = torch.index_select(X_local, 0, support_idxs)
    _, local_order = torch.sort(support_local, descending=False)
    support_idxs = torch.index_select(support_idxs, 0, local_order)

    # Negated global (discriminant) values of the sorted support set. They do
    # not change while the spectrum is built, so compute them once up front.
    support_f = torch.index_select(X_f, 0, support_idxs)
    neg_global = -(torch.matmul(support_f, weight) + bias)

    # Generate the spectrum
    spec = []
    start = 0
    while start < support_size:
        # Offset, within the remaining tail, of the largest negated value
        offset = torch.argmax(neg_global[start:], dim=0).item()
        spec.append(support_idxs[start + offset].item())
        start += offset + 1

    return spec, support_size


cnn.eval()
with torch.no_grad():

    class_eval = 0
    output_dir = "results_ext"

    # Plot output flag
    plot_output = True
    if plot_output:
        # savefig does not create missing directories
        os.makedirs(output_dir, exist_ok=True)

    # Get the test points of the class
    class_test_points = test_data.data[test_data.targets == class_eval]

    # Try random tests
    random_test_points = np.random.randint(0, len(class_test_points), size=1)
    # Interesting examples for class 5
    # random_test_points = [829]

    # Get the training points of the class
    X = get_training_points(class_eval)

    # Get the discriminant for the class
    named_params = dict(cnn.named_parameters())
    class_weight = named_params.get('out.weight')
    class_bias = named_params.get('out.bias')
    if class_weight is None or class_bias is None:
        # Fail here with a clear message instead of an opaque TypeError later
        raise KeyError("model has no 'out.weight' / 'out.bias' parameters")

    # Transform to features
    X_f = cnn.conv2(cnn.conv1(X))
    X_f = X_f.view(X_f.size(0), -1)

    # For each test point index
    for raw_index in random_test_points:

        # Get the test point (plain int so indexing and formatting are uniform)
        test_index = int(raw_index)
        test_point = class_test_points[test_index].unsqueeze(0).to(torch.float)

        print(f"Evaluating test point: {test_index} of class: {class_eval}")

        # Plot test point
        if plot_output:
            test_fig = plt.figure()
            plt.axis('off')
            plt.imshow(test_point[0], cmap='gray')
            plt.savefig(
                f"{output_dir}/mnist_test_{class_eval}_{test_index}.pdf",
                bbox_inches='tight')
            # Release the figure; otherwise figures pile up across iterations
            plt.close(test_fig)

        # Convert to features
        test_point_f = cnn.conv2(cnn.conv1(test_point))
        test_point_f = test_point_f.view(-1)

        # Compute the local values
        X_local = torch.matmul(X_f, test_point_f)

        # Compute the spectrums, one per class index (MNIST has ten)
        for j in range(10):

            if j == class_eval:
                # For general spectrum
                weight = class_weight[class_eval]
                bias = class_bias[class_eval]
            else:
                # For contrasting spectrums
                weight = class_weight[class_eval] - class_weight[j]
                bias = class_bias[class_eval] - class_bias[j]

            spec, support_size = _build_spectrum(
                X_f, X_local, test_point_f, weight, bias)
            print(f"Support size against {j}: {support_size}")
            print(f"Spectrum size: {len(spec)}")

            if plot_output and spec:
                # squeeze=False always yields a 2-D array of axes, so a
                # one-image spectrum takes the same path as a longer one.
                # (The old single-image branch indexed spec with a stale
                # loop counter and could raise IndexError.)
                fig, axs = plt.subplots(1, len(spec), squeeze=False)
                # Images are laid out in reverse selection order
                for ax, train_idx in zip(axs[0], reversed(spec)):
                    ax.axis('off')
                    ax.imshow(X[train_idx][0], cmap='gray')

                fig.savefig(
                    f"{output_dir}/mnist_spec_{class_eval}_{test_index}_{j}.pdf",
                    bbox_inches='tight')
                plt.close(fig)


