import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torch.autograd as autograd

from sklearn.datasets import make_circles, make_classification, make_moons

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

import copy 

import numpy as np

# Seed NumPy (data sampling) and PyTorch (weight initialisation) so that
# every run produces the same dataset and the same trained model.
RANDOM_SEED = 10
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

# Synthetic 3-class dataset: num_data points drawn uniformly from [-2, 2]^2
# and labelled by region:
#   0 -> inside the unit circle
#   1 -> outside the unit circle, lower half-plane (y < 0)
#   2 -> outside the unit circle, upper half-plane (y >= 0)
num_data = 135
num_classes = 3
class_colors = ['red', 'green', 'blue']  # plotting colour per class label

points = np.random.uniform(-2.0, 2.0, size=(num_data, 2))
labels = [
    0 if np.linalg.norm(p) < 1 else (1 if p[1] < 0 else 2)
    for p in points
]

# Inputs as float32 (the model's parameter dtype), labels as int64 class
# indices as expected by CrossEntropyLoss / one_hot.
X = torch.tensor(points, dtype=torch.float32)
Y = torch.tensor(labels, dtype=torch.long)

# Create the mesh for showing decision boundaries: the bounding box of the
# training inputs, padded by 0.5 on every side, sampled every `h` units.
# `.item()` turns each 0-d tensor reduction into a plain Python float, so
# NumPy only ever sees ordinary scalars (and this keeps working wherever X
# lives).
h = 0.02
x_min = X[:, 0].min().item() - 0.5
x_max = X[:, 0].max().item() + 0.5
y_min = X[:, 1].min().item() - 0.5
y_max = X[:, 1].max().item() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# The model: a one-hidden-layer ReLU feature map followed by a linear classifier
model_width = 10  # hidden width, i.e. the dimension of the features f(x)


class NeuralNetwork(nn.Module):
    """Small MLP whose feature map and linear head are separate submodules.

    Keeping ``feature_map`` apart from ``lin_classifier`` lets callers compute
    the penultimate-layer features on their own via ``model.feature_map(x)``.
    """

    def __init__(self):
        super().__init__()
        # The hidden layer is created before the classifier so that weight
        # initialisation draws from the seeded RNG in a fixed order.
        hidden = nn.Linear(2, model_width)
        self.feature_map = nn.Sequential(hidden, nn.ReLU())
        self.lin_classifier = nn.Linear(model_width, num_classes)

    def forward(self, x):
        """Return class logits for inputs ``x`` (assumed shape (batch, 2))."""
        features = self.feature_map(x)
        return self.lin_classifier(features)


model = NeuralNetwork().to(device)

# --- Training configuration ---
num_iterations = 3000  # number of full-batch optimisation steps
l2_lambda = 0.01       # L2 regularisation strength on the linear classifier

# Softmax cross-entropy, averaged over the batch (default reduction='mean')
loss_fn = nn.CrossEntropyLoss()

# SGD with two parameter groups. The feature map gets SGD's default
# weight_decay (0). The linear classifier gets weight_decay = 2 * l2_lambda:
# SGD adds weight_decay * w to each gradient, which is the gradient of the
# penalty l2_lambda * ||w||^2.
feature_params = {'params': model.feature_map.parameters()}
classifier_params = {
    'params': model.lin_classifier.parameters(),
    'weight_decay': (2 * l2_lambda),
}
optimizer = torch.optim.SGD(
    [feature_params, classifier_params],
    lr=1.0,
    momentum=0.5,
)

# Scale the base learning rate linearly from 0.2x to 0.3x over the run
scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.2,
    end_factor=0.3,
    total_iters=num_iterations,
)

# Lowest training loss seen so far and a snapshot of the parameters that
# achieved it.
# NOTE(review): best_loss_params is not loaded back into `model` anywhere in
# this file, so the model keeps its final-iteration parameters. Confirm this
# is intended.
best_loss_val = float('inf')
best_loss_params = None

# The model was moved to `device`, so the inputs must live there too.
# Without this, the forward pass fails when device is "cuda" or "mps".
X_train = X.to(device)
Y_train = Y.to(device)

# Full-batch training; loss_val records the loss at every iteration.
loss_val = []
# Training mode only affects dropout/batch-norm layers. Nothing switches it
# off inside the loop, so setting it once is enough.
model.train()
for _ in range(num_iterations):
    optimizer.zero_grad()

    # Compute prediction & loss
    logits = model(X_train)
    loss = loss_fn(logits, Y_train)
    current_loss = loss.item()
    loss_val.append(current_loss)

    # Remember the best parameters (deep copy: state_dict returns live tensors)
    if current_loss < best_loss_val:
        best_loss_val = current_loss
        best_loss_params = copy.deepcopy(model.state_dict())

    # Backprop, gradient step, then advance the LR schedule once per step
    loss.backward()
    optimizer.step()
    scheduler.step()
    
# Plot the per-iteration training loss in its own named figure
loss_ax = plt.figure("training losses").gca()
loss_ax.set_title('Training losses')
loss_ax.plot(loss_val)

# Specify the test point
test_x = torch.Tensor([[1.25, 1.25]])

# Class whose training points are drawn in the "influence values" and
# "attributions to test point" figures. (1.25, 1.25) lies outside the unit
# circle in the upper half-plane, i.e. in class 2 under the labelling rule.
highlight_class = 2

# Exponent applied (sign-preserving) to attribution values before plotting,
# to adjust the contrast of their differences; 1 leaves them unchanged.
contrast = 1

# Opacity of the shaded decision regions
al = 0.25

# RGB colour of each class's shaded region (scatter points use class_colors)
region_rgb = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]


def _signed_power(values, exponent):
    """Return v ** exponent for v >= 0 and -(|v| ** exponent) for v < 0."""
    values = np.asarray(values, dtype=float)
    magnitude = np.power(np.abs(values), exponent)
    return np.where(values >= 0, magnitude, -magnitude)


def _min_max_scale(values):
    """Linearly rescale values to [0, 1] for use as per-point scatter alpha.

    Empty input is returned unchanged. Constant input maps to all ones instead
    of the NaNs that (v - min) / (max - min) would produce, which matplotlib
    rejects as alpha values.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    lo, hi = values.min(), values.max()
    if hi == lo:
        return np.ones_like(values)
    return (values - lo) / (hi - lo)


def _draw_decision_regions(prob_maps, cmaps):
    """Shade each class's probability map over the mesh on the current figure."""
    for prob, cmap in zip(prob_maps, cmaps):
        plt.contourf(xx, yy, prob, cmap=cmap)


def _finish_axes():
    """Clip the current axes to the mesh extent and hide the ticks."""
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())


### Compute the representer points values
model.eval()  # same as model.train(False)
with torch.no_grad():
    # Run the model on its own device. Results are moved to the CPU only
    # where NumPy / matplotlib need them.
    X_dev = X.to(device)
    Y_dev = Y.to(device)
    test_x_dev = test_x.to(device)

    # Predicted class probabilities for each training point
    logits = model(X_dev)
    pred_probab = F.softmax(logits, dim=1)
    Y_pred = pred_probab.argmax(1)

    # Penultimate-layer features f_i of the training points
    features = model.feature_map(X_dev)

    # Representer weights alpha_i = -1 / (2 * n * lambda) * dL_i/dlogits_i.
    # For softmax cross-entropy, dL_i/dlogits_i = p_i - y_i (y_i one-hot).
    # lambda is the L2 strength applied to the classifier through
    # weight_decay = 2 * lambda.
    Y_one_hot = F.one_hot(Y_dev, num_classes=num_classes)
    alpha = (Y_one_hot - pred_probab) / (2 * num_data * l2_lambda)

    # Prediction and features for the test point
    test_logits = model(test_x_dev)
    test_pred_probab = F.softmax(test_logits, dim=1)
    test_pred_class = test_pred_probab.argmax(1).item()
    test_features = model.feature_map(test_x_dev)[0]

    # att_vecs[i] = <f_test, f_i> * alpha_i: training point i's contribution
    # to every class logit of the test point, shape (num_data, num_classes).
    att_vecs = (features @ test_features)[:, None] * alpha
    # Summing the contributions gives the representer reconstruction of the
    # test logits.
    # NOTE(review): lin_classifier's bias has no term in this sum, so the
    # reconstruction need not equal test_logits exactly.
    test_replogits = att_vecs.sum(dim=0, keepdim=True)
    test_rep_pred_probab = F.softmax(test_replogits, dim=1)

    # Contribution of each training point to the predicted class's logit
    att_val = _signed_power(att_vecs[:, test_pred_class].cpu().numpy(), contrast)

    # Group by true class: the points, each point's weight towards its own
    # class, and each point's contribution to the test point.
    X_np = X.cpu().numpy()
    Y_np = Y.cpu().numpy()
    alpha_np = alpha.cpu().numpy()
    class_masks = [Y_np == k for k in range(num_classes)]
    X_by_class = [X_np[mask] for mask in class_masks]
    a_by_class = [alpha_np[mask, k] for k, mask in enumerate(class_masks)]
    c_by_class = [att_val[mask] for mask in class_masks]

    # --- Plot the results --- #

    # Class probabilities on the mesh, one (xx.shape) map per class
    grid = torch.tensor(
        np.column_stack([xx.ravel(), yy.ravel()]),
        dtype=torch.float32,
        device=device,
    )
    Z = model(grid)
    Z_prob = F.softmax(Z, dim=1).cpu().numpy()
    Z_by_class = [Z_prob[:, k].reshape(xx.shape) for k in range(num_classes)]
    # Each region fades from fully transparent to its colour at opacity `al`
    region_cmaps = [ListedColormap([(*rgb, 0.0), (*rgb, al)]) for rgb in region_rgb]

    # Figure: training points over the decision regions
    plt.figure("training points")
    _draw_decision_regions(Z_by_class, region_cmaps)
    for pts, color in zip(X_by_class, class_colors):
        plt.scatter(pts[:, 0], pts[:, 1], c=color)
    _finish_axes()
    plt.savefig("synthetic_example.png", dpi=400, bbox_inches='tight')

    # Figure: opacity = normalized weight of each point towards its own class
    plt.figure("attributions to model")
    _draw_decision_regions(Z_by_class, region_cmaps)
    for pts, a, color in zip(X_by_class, a_by_class, class_colors):
        a_scaled = _min_max_scale(_signed_power(a, contrast))
        plt.scatter(pts[:, 0], pts[:, 1], alpha=a_scaled, c=color)
    _finish_axes()
    plt.savefig("example_1_2.png", dpi=200, bbox_inches='tight')

    # Figure: raw contributions of the highlighted class's points
    plt.figure("influence values")
    plt.plot(c_by_class[highlight_class])

    # Figure: highlighted class's points, opacity = normalized contribution
    # to the test point's predicted-class logit
    plt.figure("attributions to test point")
    _draw_decision_regions(Z_by_class, region_cmaps)
    highlight_pts = X_by_class[highlight_class]
    plt.scatter(
        highlight_pts[:, 0],
        highlight_pts[:, 1],
        alpha=_min_max_scale(c_by_class[highlight_class]),
        c=class_colors[highlight_class],
    )
    _finish_axes()

    # Plot the test point
    plt.scatter(test_x[0][0].item(), test_x[0][1].item(), c='black', s=50)

    plt.savefig("synthetic_rp_1.png", dpi=400, bbox_inches='tight')

    
    
    
    
