import torch
from torch import nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import copy

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

import numpy as np

# Fix both RNGs so the data set and the weight initialisation are reproducible
np.random.seed(10)
torch.manual_seed(10)

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

# Generate the data points: uniform samples in the square [-2, 2] x [-2, 2]
num_data = 135
num_classes = 3
class_colors = ['red', 'green', 'blue']

X = np.random.uniform(-2.0, 2.0, size=(num_data, 2))

# Labels: 0 inside the unit disc; outside it, 1 below the horizontal axis
# and 2 on or above it.
Y = [
    0 if np.linalg.norm(point) < 1 else (1 if point[1] < 0 else 2)
    for point in X
]

# float32 inputs and int64 class targets
X = torch.tensor(X, dtype=torch.float32)
Y = torch.tensor(Y, dtype=torch.long)

# Create the mesh for showing decision boundaries: a regular grid with
# spacing h covering the data's bounding box plus a 0.5 margin on each side
h = 0.02
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# The model
model_width = 10


class NeuralNetwork(nn.Module):
    """One-hidden-layer classifier for 2-D inputs.

    ``feature_map`` is a ``Linear(2, model_width)`` layer followed by a ReLU;
    ``lin_classifier`` maps those features to ``num_classes`` logits. The two
    parts are kept as separate attributes so each can be addressed on its own.
    """

    def __init__(self):
        super().__init__()
        # The hidden layer is created before the classifier so that seeded
        # weight initialisation draws random numbers in the same order.
        hidden_layer = nn.Linear(2, model_width)
        self.feature_map = nn.Sequential(hidden_layer, nn.ReLU())
        self.lin_classifier = nn.Linear(model_width, num_classes)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for ``x``."""
        return self.lin_classifier(self.feature_map(x))


model = NeuralNetwork().to(device)

# Training hyperparameters
num_iterations = 3000
l2_lambda = 0.01

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimizer & lr_scheduler
# Two parameter groups: the feature map is left without weight decay, while
# the linear classifier is decayed with coefficient 2 * l2_lambda.
feature_group = {'params': model.feature_map.parameters()}
classifier_group = {
    'params': model.lin_classifier.parameters(),
    'weight_decay': 2 * l2_lambda,
}
optimizer = torch.optim.SGD(
    [feature_group, classifier_group],
    lr=1.0,
    momentum=0.5,
)

# The learning-rate factor moves linearly from 0.2 to 0.3 over the whole run
scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.2,
    end_factor=0.3,
    total_iters=num_iterations,
)

# Best loss and parameters seen so far (filled in by the training loop)
best_loss_val = float('inf')
best_loss_params = None

# Train the model (full-batch gradient descent on the whole data set)
loss_val = []

# The model lives on `device`, so its inputs must too; otherwise the forward
# pass fails with a device mismatch on CUDA/MPS machines. Dedicated copies are
# used because the plotting/evaluation code relies on X and Y staying on the CPU.
X_train = X.to(device)
Y_train = Y.to(device)

# Enable for batch normalization and dropout layers. The mode is sticky, so
# setting it once before the loop is enough.
model.train()

for _ in range(num_iterations):

    # Compute prediction & loss
    logits = model(X_train)
    loss = loss_fn(logits, Y_train)
    current_loss = loss.item()
    loss_val.append(current_loss)

    # Update best loss (if necessary). The snapshot is taken before the
    # optimizer step, so it holds exactly the parameters that produced
    # `current_loss`.
    if current_loss < best_loss_val:
        best_loss_val = current_loss
        best_loss_params = copy.deepcopy(model.state_dict())

    # Backprop
    loss.backward()

    # GD step
    optimizer.step()
    optimizer.zero_grad()

    # Go next
    scheduler.step()

# Everything after training feeds the model CPU tensors and converts results
# to NumPy, so bring the model back to the CPU.
model.to("cpu")
    
# Plot the training losses (one value per iteration) in their own figure
loss_fig = plt.figure("training losses")
loss_ax = loss_fig.gca()
loss_ax.set_title('Training losses')
loss_ax.plot(loss_val)

# Test points: one 2-D query point per row
test_points = torch.tensor([[-1.25, 1.25], [1.25, 1.25]], dtype=torch.float32)

# For plotting the decision boundaries
def plot_dec_boundaries():
    """Shade the class-probability regions of ``model`` on the current axes.

    The module-level mesh ``xx``/``yy`` is pushed through the module-level
    ``model``; for each of the three classes the softmax probability is drawn
    with ``plt.contourf`` using a two-entry colormap that goes from fully
    transparent to the class colour (red, green, blue) at alpha 0.25.

    The forward pass runs under ``torch.no_grad()`` and on whatever device the
    model's parameters are on, so the function can be called from any context.
    """
    grid = torch.as_tensor(
        np.column_stack([xx.ravel(), yy.ravel()]), dtype=torch.float32)

    with torch.no_grad():
        model_device = next(model.parameters()).device
        # One softmax for all classes; columns are the per-class probabilities
        probs = torch.softmax(model(grid.to(model_device)), dim=1).cpu().numpy()

    al = 0.25
    # RGB of the regions for class 0, 1 and 2, in drawing order
    region_rgb = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
    for class_idx, rgb in enumerate(region_rgb):
        region_cmap = ListedColormap([rgb + (0.0,), rgb + (al,)])
        class_prob = probs[:, class_idx].reshape(xx.shape)
        plt.contourf(xx, yy, class_prob, cmap=region_cmap)

# Start evaluation
# eval() is the same as train(False); one call is enough.
model.eval()

# 'highest' keeps float32 matmuls in full float32 precision. The weaker
# 'high' setting allows reduced-precision internals on supporting hardware,
# which is the opposite of what is wanted for the sign tests below. It is set
# once, before any evaluation matmul runs.
torch.set_float32_matmul_precision('highest')


def _save_class_figure(name, class_xy, point_alphas, color, test_xy, spectrum=None):
    """Draw one evaluation figure and save it as ``<name>.png``.

    name         -- figure label, also used as the output file stem
    class_xy     -- (n, 2) array with the training points of the predicted class
    point_alphas -- one alpha per row of ``class_xy``
    color        -- colour used for the training points
    test_xy      -- (x, y) of the test point, drawn as a black dot
    spectrum     -- optional list of [x, y] points joined by a thin black line
    """
    plt.figure(name)

    # Plot the decision boundaries
    plot_dec_boundaries()

    # Plot the spectrum (if any) underneath the points
    if spectrum is not None:
        path = np.array(spectrum)
        plt.plot(path[:, 0], path[:, 1], c='black', alpha=0.5, linewidth=1)

    # Plot the training points
    plt.scatter(class_xy[:, 0], class_xy[:, 1], alpha=point_alphas, c=color)

    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())

    # Plot the test point
    plt.scatter(test_xy[0], test_xy[1], c='black', s=50)

    # Save to file
    plt.savefig(name + ".png", dpi=400, bbox_inches='tight')


with torch.no_grad():

    # Load best loss model (nothing to load if no training iteration ran)
    if best_loss_params is not None:
        model.load_state_dict(best_loss_params)

    # Get the discriminants for the classes: row i of W and entry i of b
    # define the logit of class i as a function of the features.
    W = model.lin_classifier.weight.detach()
    b = model.lin_classifier.bias.detach()

    # Separate the training points by class (independent of the test point)
    X_sep = [X[Y == class_idx] for class_idx in range(num_classes)]

    # Do for all the test points
    for tp, test_point in enumerate(test_points):

        test_x = test_point.view(1, 2)
        test_xy = (test_x[0][0].item(), test_x[0][1].item())

        # Get the predicted class for the test point; softmax is monotone, so
        # the arg-max of the logits is the predicted class.
        test_pred_class = model(test_x).argmax(1).item()

        class_points = X_sep[test_pred_class]
        class_xy = class_points.numpy()

        # Get the features of the training points (in the predicted class)
        features = model.feature_map(class_points)

        # feature of the test point
        test_feature = model.feature_map(test_x)[0]

        # Generate the spectrums
        for i in range(num_classes):
            # The relevant discriminants: the predicted class' own logit, or
            # its margin over class i.
            if i == test_pred_class:
                weight = W[test_pred_class]
                bias = b[test_pred_class]
            else:
                weight = W[test_pred_class] - W[i]
                bias = b[test_pred_class] - b[i]

            # Get the support indices: training points whose discriminant
            # value is strictly below the test point's.
            support_mask = torch.matmul(features - test_feature, weight) < 0
            support_indices = support_mask.nonzero().flatten()

            # Get the support set and its features
            X_support = class_points[support_indices]
            features_support = features[support_indices]

            # Get the similarity and importance values
            similarity = torch.matmul(features_support, test_feature)
            importance = -(torch.matmul(features_support, weight) + bias)

            # Sort the support points by increasing similarity
            _, sorted_indices = torch.sort(similarity, descending=False)
            X_support = X_support[sorted_indices]
            importance = importance[sorted_indices]

            # Generate the spectrum: repeatedly take the most important point
            # among those not yet passed, then continue right after it.
            spec = []
            k = 0
            num_support = support_indices.numel()
            while k < num_support:
                max_index = int(torch.argmax(importance[k:]))
                spec.append(X_support[k + max_index].tolist())
                k += max_index + 1

            identifier = f"{tp}_{i}"

            # Distinguish between support (opaque) and non-support (faint)
            support_alpha = np.where(support_mask.numpy(), 1.0, 0.1).tolist()

            # Plot the support set and save to file
            _save_class_figure(
                "support_" + identifier,
                class_xy,
                support_alpha,
                class_colors[test_pred_class],
                test_xy)

            # Plot the spectrum (ending at the test point) & save to file
            spec.append([test_xy[0], test_xy[1]])
            _save_class_figure(
                "spec_" + identifier,
                class_xy,
                support_alpha,
                class_colors[test_pred_class],
                test_xy,
                spectrum=spec)
            



