from kan.KANLayer import *
from tqdm import tqdm
import torch
from torch.optim import Adam
import time
import matplotlib.pyplot as plt
# Wall-clock reference; the main script subtracts it to report elapsed time.
start_time = time.time()

# Hyperparameters
tol = 1e-1  # noise scale and half-width of the tolerance band in create_dataset
hidden_size = 128  # base width from which layer and target-vector sizes are derived
n_epochs = 5000  # optimisation steps spent on each layer

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")
# Define function
def function(x):
    """Target curve: sin(2*pi*x) shifted up by 2, evaluated element-wise on ``x``."""
    angle = 2 * torch.pi * x
    return torch.sin(angle) + 2

def _random_unit_row(width):
    """Draw a standard-normal (1, width) row on ``device`` and scale it to roughly unit L2 norm."""
    raw = torch.normal(0, 1, size=(1, width)).to(device)
    # The 1e-4 term keeps the division finite should the drawn row have a vanishing norm.
    return raw / (raw.norm(2, 1, keepdim=True) + 1e-4)


# Fixed random target directions, one per layer; goodness() compares each
# layer's output against the vector with the matching index.
p_vector0 = _random_unit_row(2 * hidden_size)
p_vector1 = _random_unit_row(hidden_size)
p_vector2 = _random_unit_row(hidden_size // 2)

# Define goodness functions properly
def goodness(layer, x, k):
    """Cosine similarity between a layer's output and that layer's fixed target vector.

    Args:
        layer: Callable whose result is indexable; element 0 of the result is
            taken as the activation tensor.
        x: Input forwarded to ``layer``.
        k: Selects the target vector: 0 -> ``p_vector0``, 1 -> ``p_vector1``,
            2 -> ``p_vector2``.

    Returns:
        Tensor of cosine similarities computed along dim 1 of the activation.

    Raises:
        ValueError: If ``k`` is not 0, 1 or 2 (previously this silently
            returned ``None``).
    """
    if k == 0:
        target = p_vector0
    elif k == 1:
        target = p_vector1
    elif k == 2:
        target = p_vector2
    else:
        raise ValueError(f"k must be 0, 1 or 2, got {k!r}")

    activation = layer(x)[0]
    # Reached through the explicitly imported ``torch`` so this does not depend
    # on ``nn`` being re-exported by a wildcard import.
    return torch.nn.functional.cosine_similarity(activation, target, dim=1, eps=1e-6)
        
def _summed_goodness(points, layers):
    """Sum goodness(layer_k, input_k, k) over the stack, feeding each layer the previous output."""
    total = None
    hidden = points
    last = len(layers) - 1
    for k, layer in enumerate(layers):
        score = goodness(layer, hidden, k)
        total = score if total is None else total + score
        if k < last:
            # Compute the next layer's input once instead of re-running the
            # earlier layers for every later goodness term.
            hidden = layer(hidden)[0]
    return total


def predict(x, layer1, layer2, layer3, n_points=1000, y_range=(0.5, 3.5)):
    """Return the candidate y values that the layers accept as in-tolerance at ``x``.

    A grid of ``n_points`` candidate y values is paired with ``x`` twice: once
    with label 1 and once with label 0. A candidate is kept when its summed
    goodness over the three layers is lower with label 1 than with label 0,
    which matches the direction the loss in ``train_layer`` pushes positive
    data.

    Args:
        x: A single x location, given as a Python number or a one-element tensor.
        layer1, layer2, layer3: The layers, applied in this order.
        n_points: Number of candidate y values on the grid.
        y_range: ``(low, high)`` bounds of the candidate grid.

    Returns:
        Tensor of shape (m, 1) holding the accepted candidates; m may be 0.
    """
    # as_tensor accepts ints, floats and tensors alike; reshape enforces that
    # exactly one x value was supplied.
    x = torch.as_tensor(x).float().to(device).reshape(1, 1)

    y_low, y_high = y_range
    trail_y = torch.linspace(y_low, y_high, n_points, device=device).view(n_points, 1)
    trail_x = x.repeat(n_points, 1)

    trail_points_intol = torch.cat((trail_x, trail_y, torch.ones_like(trail_x)), 1)
    trail_points_outtol = torch.cat((trail_x, trail_y, torch.zeros_like(trail_x)), 1)

    layers = (layer1, layer2, layer3)
    # Pure inference: no autograd graph is needed for the comparison below.
    with torch.no_grad():
        goodness_intol = _summed_goodness(trail_points_intol, layers)
        goodness_outtol = _summed_goodness(trail_points_outtol, layers)

    mask = goodness_intol < goodness_outtol
    return trail_y[mask]
    
def create_dataset(x, n):
    """Build the positive and negative training sets for the tolerance-band task.

    Ten noisy copies of ``function(x)`` (Gaussian noise scaled by ``tol``) form
    the in-band samples. Ten further rows are sampled uniformly outside the
    band: five between ``function(x) + tol`` and ``max(noisy y) + 2``, and five
    between ``min(noisy y) - 2`` and ``function(x) - tol``.

    Args:
        x: Tensor of shape (1, n) with the x locations.
        n: Number of x locations; must equal the number of elements in ``x``.

    Returns:
        ``(positive_data, negative_data)``, each of shape (3, 20 * n) with rows
        (x, y, label). Positive data labels in-band samples 1 and out-of-band
        samples 0; negative data carries the same coordinates with the labels
        swapped.
    """
    n_rows = 10  # noisy in-band rows, and also the total number of out-of-band rows
    n_half = n_rows // 2  # out-of-band rows above the band; the same number go below

    y_clean = function(x).view(1, n).repeat(n_rows, 1)
    # randn_like keeps the noise on the same device and dtype as the curve.
    y_noised = y_clean + torch.randn_like(y_clean) * tol
    x = x.repeat(n_rows, 1)

    label_one = torch.ones_like(x, dtype=torch.float)
    label_zero = torch.zeros_like(x, dtype=torch.float)

    in_tol_pos = torch.stack((x, y_noised, label_one), 0)  # (3, n_rows, n)
    in_tol_neg = torch.stack((x, y_noised, label_zero), 0)

    y_high = function(x[:n_half]) + tol  # upper edge of the tolerance band
    y_low = function(x[:n_half]) - tol  # lower edge of the tolerance band

    # Outer limits of the sampling region, 2 units beyond the noisy data.
    y_max = y_noised.max() + 2
    y_min = y_noised.min() - 2

    # Uniform random samples between each band edge and the matching outer limit.
    y_above = (y_max - y_high) * torch.rand_like(y_high) + y_high
    y_below = (y_low - y_min) * torch.rand_like(y_low) + y_min
    y_out_tol = torch.cat((y_above, y_below), 0)  # (n_rows, n): above rows, then below rows

    out_tol_pos = torch.stack((x, y_out_tol, label_zero), 0)  # 0 is the correct label here
    out_tol_neg = torch.stack((x, y_out_tol, label_one), 0)

    positive_data = torch.cat((in_tol_pos, out_tol_pos), 1).flatten(1, 2)
    negative_data = torch.cat((in_tol_neg, out_tol_neg), 1).flatten(1, 2)

    return positive_data, negative_data


def plot_data(positive_data, negative_data):
    """Scatter the positive samples and show the figure.

    Row 0 of ``positive_data`` supplies the horizontal coordinates and row 1
    the vertical ones. ``negative_data`` is accepted but not drawn.
    """
    xs, ys = positive_data[0], positive_data[1]
    plt.scatter(xs, ys, s=3, color="green", label="Positive")
    plt.title("Dataset")
    plt.legend()
    plt.show()

# Training function
def train_layer(layer, positive_data, negative_data, n_epochs, k, lr=1e-3):
    """Train a single layer so that positive data scores lower goodness than negative data.

    The loss is ``mean(log(1 + exp(goodness_pos - goodness_neg)))``, minimised
    with Adam. Gradients are clipped to a total norm of 1.0 before every update.

    Args:
        layer: Module to optimise; only its own parameters are updated.
        positive_data: 2-D tensor, transposed with ``.t()`` before being passed
            to ``layer``.
        negative_data: 2-D tensor handled the same way as ``positive_data``.
        n_epochs: Number of optimisation steps.
        k: Index of the target vector that ``goodness`` compares against.
        lr: Adam learning rate.
    """
    optimizer = Adam(layer.parameters(), lr=lr)

    # The inputs may be activations of an earlier layer. Detaching them keeps
    # backward() confined to this layer, so no graph has to be retained
    # between steps.
    pos_inputs = positive_data.detach().t()
    neg_inputs = negative_data.detach().t()

    def compute_loss():
        goodness_pos = goodness(layer, pos_inputs, k)
        goodness_neg = goodness(layer, neg_inputs, k)
        return torch.log(1 + torch.exp(goodness_pos - goodness_neg)).mean()

    def closure():
        optimizer.zero_grad()
        loss = compute_loss()
        loss.backward()
        # Clip between backward() and the parameter update; clipping after
        # optimizer.step() would have no effect on the step.
        torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
        return loss

    for epoch in tqdm(range(n_epochs)):
        optimizer.step(closure)
        if epoch % 1000 == 0:
            # Report the post-step loss without running another backward pass.
            with torch.no_grad():
                current_loss = compute_loss().item()
            tqdm.write(f"Epoch {epoch}, Loss: {current_loss}")

# Main script
if __name__ == "__main__":
    # Dataset setup
    x_low, x_high, n = 0, 1, 20
    x = torch.linspace(x_low, x_high, n).view(1, n)
    print("Shape of x: ", x.shape)
    positive_data, negative_data = create_dataset(x, n)
    plot_data(positive_data, negative_data)

    positive_data, negative_data = positive_data.to(device), negative_data.to(device)

    # Define layers
    layer1 = KANLayer(in_dim=3, out_dim=hidden_size * 2).to(device)
    layer2 = KANLayer(in_dim=hidden_size * 2, out_dim=hidden_size).to(device)
    layer3 = KANLayer(in_dim=hidden_size, out_dim=hidden_size // 2).to(device)

    # Train layers sequentially. Each trained layer's outputs are computed
    # without gradient tracking, so the next layer's training does not
    # backpropagate into layers that are already finished.
    print("Training Layer 1...")
    train_layer(layer1, positive_data, negative_data, n_epochs, 0)
    with torch.no_grad():
        y1_pos = layer1(positive_data.t())[0]
        y1_neg = layer1(negative_data.t())[0]

    # train_layer transposes its inputs, hence the .t() on the activations.
    print("Training Layer 2...")
    train_layer(layer2, y1_pos.t(), y1_neg.t(), n_epochs, 1)
    with torch.no_grad():
        y2_pos = layer2(y1_pos)[0]
        y2_neg = layer2(y1_neg)[0]

    print("Training Layer 3...")
    train_layer(layer3, y2_pos.t(), y2_neg.t(), n_epochs, 2)
    print("Time taken: ", time.time() - start_time)

    # Evaluation: for every test x, summarise the accepted y candidates by
    # their mean (prediction) and standard deviation (uncertainty band).
    x_test = torch.linspace(0, 1, 20)
    y_test = function(x_test)
    y_pred = torch.zeros_like(x_test)
    y_std = torch.zeros_like(x_test)
    for i, x_i in enumerate(x_test):
        # One predict() call serves both statistics. If no candidate is
        # accepted, the mean and std are NaN.
        accepted = predict(x_i, layer1, layer2, layer3)
        y_pred[i] = accepted.mean()
        y_std[i] = accepted.std()

    print("MSE: ", torch.mean((y_test - y_pred)**2))
    plt.scatter(x_test, y_test, label="True", marker='x', color='red')
    plt.plot(x_test, y_pred, label="Predicted", color='blue')
    plt.fill_between(x_test, y_pred - y_std, y_pred + y_std, color='gray', alpha=0.2)
    plt.title("Function Approximation using KAN")
    plt.legend()
    plt.show()