import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torch.autograd as autograd

from sklearn.datasets import make_circles, make_classification, make_moons

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

import copy 

import numpy as np

# Fix both RNG streams so the synthetic data and the weight initialisation
# are reproducible from run to run.
np.random.seed(10)
torch.manual_seed(10)

# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

# Generate the data points
num_data = 135
num_classes = 3
class_colors = ['red', 'green', 'blue']

# Points are drawn uniformly from the square [-2, 2) x [-2, 2).
raw_points = np.random.uniform(low=-2.0, high=2.0, size=(num_data, 2))

# Labelling rule: inside the unit circle -> class 0; otherwise the lower
# half-plane (second coordinate < 0) -> class 1 and the rest -> class 2.
raw_labels = [
    0 if np.linalg.norm(point) < 1 else (1 if point[1] < 0 else 2)
    for point in raw_points
]

# Features as float32, labels as int64 class indices.
X = torch.from_numpy(raw_points).to(torch.float32)
Y = torch.tensor(raw_labels, dtype=torch.long)

# Create the mesh for showing decision boundaries: a regular grid with step h
# covering the data's bounding box padded by `margin` on every side.
h = 0.02
margin = 0.5
x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# The model
model_width = 10
class NeuralNetwork(nn.Module):
    """Two-layer classifier: Linear -> ReLU feature map, then a linear head.

    Args:
        in_features: Number of input coordinates per sample (default 2).
        width: Number of hidden units. ``None`` (default) uses the
            module-level ``model_width`` at construction time.
        n_classes: Number of output logits. ``None`` (default) uses the
            module-level ``num_classes`` at construction time.

    Attributes are kept as ``feature_map`` (an ``nn.Sequential`` whose first
    entry is the hidden ``nn.Linear``) and ``lin_classifier`` so that code
    addressing those sub-modules by name keeps working.
    """

    def __init__(self, in_features=2, width=None, n_classes=None):
        super().__init__()
        # Resolve the defaults lazily so that calling NeuralNetwork() reads the
        # module-level settings exactly when the instance is created.
        if width is None:
            width = model_width
        if n_classes is None:
            n_classes = num_classes
        self.feature_map = nn.Sequential(
            nn.Linear(in_features, width),
            nn.ReLU(),
            )
        self.lin_classifier = nn.Linear(width, n_classes)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for the batch ``x``."""
        return self.lin_classifier(self.feature_map(x))
    
model = NeuralNetwork().to(device)

# Training hyperparameters
num_iterations = 3000
l2_lambda = 0.01

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimizer & lr_scheduler
# SGD's weight_decay adds weight_decay * p to each gradient, so a value of
# 2 * l2_lambda corresponds to a penalty of l2_lambda * ||p||^2 on the
# lin_classifier parameters only. The penalty is applied by the optimizer and
# is therefore NOT included in the loss values logged below.
optimizer = torch.optim.SGD(
    [{'params' : model.feature_map.parameters()},
     {'params' : model.lin_classifier.parameters(), 'weight_decay' : (2 * l2_lambda)}],
    lr=1.0,
    momentum=0.5)
# Scales the base lr by a factor that moves linearly from 0.2 to 0.3 over the
# whole run.
scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.2,
    end_factor=0.3,
    total_iters=num_iterations)

# Best loss and parameters
# NOTE(review): best_loss_params is only recorded here; nothing in this
# section loads it back, so the model keeps its final-iteration weights.
best_loss_val = float('inf')
best_loss_params = None

# The model lives on `device`, so the training batch must live there too;
# feeding the CPU tensors X / Y directly fails on CUDA / MPS machines.
X_train = X.to(device)
Y_train = Y.to(device)

# Train the model (full-batch gradient descent)
loss_val = []
for _ in range(num_iterations):

    # Enable for batch normalization and dropout layers
    model.train()

    # Compute prediction & loss
    logits = model(X_train)
    loss = loss_fn(logits, Y_train)
    current_loss = loss.item()
    loss_val.append(current_loss)

    # Update best loss (if necessary); the snapshot is taken before the
    # optimizer step, so it matches the weights that produced this loss.
    if current_loss < best_loss_val:
        best_loss_val = current_loss
        best_loss_params = copy.deepcopy(model.state_dict())

    # Backprop
    loss.backward()

    # GD step
    optimizer.step()
    optimizer.zero_grad()

    # Go next
    scheduler.step()

# Everything after training (test point, Hessian, plots) works with CPU
# tensors and NumPy, so bring the trained model back to the CPU.
model.to("cpu")
    
# Plot the training losses (one value per iteration, in recording order)
loss_figure = plt.figure("training losses")
loss_axes = loss_figure.gca()
loss_axes.set_title('Training losses')
loss_axes.plot(loss_val)

# Test point: a single 2-D sample (batch of one) and its class index.
test_x = torch.tensor([[-1.25, 1.25]], dtype=torch.float32)
test_y = torch.tensor([2], dtype=torch.long)
    
# --- Compute the Hessian of the model --- #
# Flatten all the parameters into one vector. The concatenation order
# (hidden weight, hidden bias, classifier weight, classifier bias) is the
# order in which make_function slices the vector apart again.
hidden_layer = model.feature_map[0]
model_weight_1, model_bias_1 = hidden_layer.weight, hidden_layer.bias
model_weight_2, model_bias_2 = model.lin_classifier.weight, model.lin_classifier.bias
params = torch.cat([
    tensor.reshape(-1)
    for tensor in (model_weight_1, model_bias_1, model_weight_2, model_bias_2)
])
# The function generator
def make_function(x, y, hidden_width=10, n_classes=3):
    """Build a loss function of a flat parameter vector for fixed data.

    The returned closure re-implements the network's forward pass
    (linear -> ReLU -> linear) directly on slices of one 1-D parameter
    vector, so that autograd can differentiate the loss with respect to all
    parameters at once.

    Args:
        x: Input samples of shape ``(n, n_features)``.
        y: Targets for those samples, in the form the module-level
            ``loss_fn`` expects.
        hidden_width: Number of hidden units encoded in the vector
            (default 10).
        n_classes: Number of output logits encoded in the vector (default 3).

    Returns:
        ``model_fn(params)``, which evaluates ``loss_fn`` on the logits for
        ``x``. ``params`` must be laid out as
        ``[w_1 (hidden_width x n_features), b_1, w_2 (n_classes x hidden_width), b_2]``,
        each flattened row-major; ``model_fn`` raises ``ValueError`` if the
        vector has a different length.
    """
    n_features = x.shape[1]

    # Slice boundaries of the four parameter groups inside the flat vector.
    w1_end = hidden_width * n_features
    b1_end = w1_end + hidden_width
    w2_end = b1_end + n_classes * hidden_width
    b2_end = w2_end + n_classes

    def model_fn(params):
        if params.numel() != b2_end:
            raise ValueError(
                f"expected {b2_end} parameters, got {params.numel()}")

        w_1 = params[:w1_end].view(hidden_width, n_features)
        b_1 = params[w1_end:b1_end].view(hidden_width, 1)
        w_2 = params[b1_end:w2_end].view(n_classes, hidden_width)
        b_2 = params[w2_end:b2_end].view(n_classes, 1)

        # Samples are kept as columns: (hidden_width, n) then (n_classes, n).
        # The column-shaped biases broadcast across the n samples.
        hidden = torch.relu(torch.add(torch.matmul(w_1, torch.transpose(x, 0, 1)), b_1))
        output = torch.add(torch.matmul(w_2, hidden), b_2)

        # Back to one row of logits per sample, (n, n_classes), for loss_fn.
        return loss_fn(torch.transpose(output, 0, 1), y)

    return model_fn
# Compute the hessian
# Average of the per-sample Hessians of the loss with respect to the flat
# parameter vector, i.e. the Hessian of the mean training loss.
# NOTE(review): the weight-decay term used by the optimizer is not part of
# make_function's loss, so it is not reflected in this Hessian; confirm that
# this is intended.
hessian = None
for i in range(num_data):
    sample_hessian = autograd.functional.hessian(
        make_function(X[i:i+1], Y[i:i+1]), params)
    hessian = sample_hessian if hessian is None else torch.add(hessian, sample_hessian)
hessian = torch.multiply(hessian, 1/num_data)
# Add small perturbation to make it invertible. The identity is sized from the
# parameter vector and created with the Hessian's dtype/device rather than
# being hard-coded to a 63 x 63 CPU matrix.
damping = 0.001 * torch.eye(
    params.numel(), dtype=hessian.dtype, device=hessian.device)
hessian = torch.add(hessian, damping)
inv_hessian = torch.inverse(hessian)

# Compute the influence function
n_params = params.numel()

# Gradient of the test-point loss with respect to the parameters. It does not
# depend on the training index, so it is computed once instead of once per
# training point.
test_loss = make_function(test_x, test_y)(params)
grad_test = autograd.grad(outputs=test_loss, inputs=params)[0].view(n_params, 1)

# model_val[i]: norm of the parameter shift -H^{-1} grad_i for training point i
# att_val[i]:   grad_test^T (-H^{-1} grad_i), the shift projected on the test gradient
model_val = []
att_val = []
for i in range(num_data):

    # --- Find the influence value of the training point i --- #
    # Compute the gradient of the training point
    train_loss = make_function(X[i:i+1], Y[i:i+1])(params)
    grad_train = autograd.grad(outputs=train_loss, inputs=params)[0].view(n_params, 1)

    # Compute the global influence
    parameter_shift = -torch.matmul(inv_hessian, grad_train)
    model_val.append(torch.norm(parameter_shift).item())

    # Compute the test influence
    a_val = torch.matmul(grad_test.view(1, n_params), parameter_shift)
    att_val.append(a_val.item())
    
def _unit_rescale(values, exponent=1):
    """Map ``|values| ** exponent`` linearly onto [0, 1] for use as alphas.

    An empty input comes back as an empty array, and when all values are
    equal (zero range) every entry becomes 1.0, so neither case divides by
    zero or calls min/max on an empty sequence.
    """
    magnitudes = np.power(np.abs(np.asarray(values, dtype=float)), exponent)
    if magnitudes.size == 0:
        return magnitudes
    low, high = magnitudes.min(), magnitudes.max()
    if high == low:
        return np.ones_like(magnitudes)
    return (magnitudes - low) / (high - low)


def _draw_decision_regions(region_maps, region_cmaps):
    """Overlay one filled contour plot per class on the current figure."""
    for region, cmap in zip(region_maps, region_cmaps):
        plt.contourf(xx, yy, region, cmap=cmap)


def _frame_axes():
    """Clamp the current axes to the mesh extent and hide the tick marks."""
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())


# Drawing colour of each class, indexed by class label: the name is used for
# the scatter points and the RGB triple for the translucent region colormap.
_CLASS_COLOR_NAMES = ('red', 'green', 'blue')
_CLASS_COLOR_RGB = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))

# Normalize the model influence to [0, 1] so it can be used as point opacity
contrast = 1
model_val = _unit_rescale(model_val, contrast)
att_values = np.asarray(att_val, dtype=float)

# Separate the training points by class and influence sign.
# Labels 0 and 1 select their class; every other label falls into the last
# group. A value >= 0 counts as "positive", anything else as "negative".
train_points = X.detach().cpu().numpy()
train_labels = Y.detach().cpu().numpy()
class_masks = [
    train_labels == 0,
    train_labels == 1,
    (train_labels != 0) & (train_labels != 1),
]
positive_mask = att_values >= 0
negative_mask = ~positive_mask

# --- Plot the results --- #

# For plotting the decision boundaries: evaluate the model on every mesh node
# (gradients are not needed) and turn the logits into class probabilities once.
with torch.no_grad():
    mesh_nodes = np.column_stack([xx.ravel(), yy.ravel()])
    model_device = next(model.parameters()).device
    Z = model(torch.tensor(mesh_nodes, dtype=torch.float32, device=model_device))
    class_probs = torch.softmax(Z, dim=1).cpu().numpy()

al = 0.25
# Probability surface of each class, reshaped to the mesh.
region_maps = [
    class_probs[:, k].reshape(xx.shape) for k in range(len(_CLASS_COLOR_RGB))
]
# Two-colour maps per class: fully transparent and the class colour at opacity `al`.
region_cmaps = [
    ListedColormap([rgb + (0.0,), rgb + (al,)]) for rgb in _CLASS_COLOR_RGB
]

# Plot training points
plt.figure("training points")
_draw_decision_regions(region_maps, region_cmaps)
for mask, color in zip(class_masks, _CLASS_COLOR_NAMES):
    if mask.any():
        plt.scatter(train_points[mask, 0], train_points[mask, 1], c=color)
_frame_axes()
plt.savefig("example_1_1.png", dpi=200, bbox_inches='tight')

# Plot training points with normalized attribution values to model
plt.figure("attributions to model")
_draw_decision_regions(region_maps, region_cmaps)
for mask, color in zip(class_masks, _CLASS_COLOR_NAMES):
    if mask.any():
        plt.scatter(train_points[mask, 0], train_points[mask, 1],
                    alpha=model_val[mask], c=color)
_frame_axes()
plt.savefig("example_1_2.png", dpi=200, bbox_inches='tight')

# Plot training points with normalized attribution values to test point
plt.figure("attributions to test point")
_draw_decision_regions(region_maps, region_cmaps)
# Each (class, sign) group is normalized on its own; '+' marks values >= 0 and
# '_' marks negative values. Empty groups are skipped instead of crashing.
for mask, color in zip(class_masks, _CLASS_COLOR_NAMES):
    for sign_mask, marker in ((positive_mask, '+'), (negative_mask, '_')):
        chosen = mask & sign_mask
        if not chosen.any():
            continue
        plt.scatter(train_points[chosen, 0], train_points[chosen, 1],
                    alpha=_unit_rescale(att_values[chosen], contrast),
                    c=color, marker=marker)
_frame_axes()

# Plot the test point
plt.scatter(test_x[0][0].item(), test_x[0][1].item(), c='black')

plt.savefig("synthetic_if_0", dpi =400, bbox_inches='tight')
