from kan import *
from tqdm import tqdm
import random
import numpy as np
import torch
from torch import nn
import matplotlib
import os
import noise

# --- Environment / reproducibility setup -----------------------------------
# PyTorch's reproducibility notes list this cuBLAS workspace size as a
# requirement for deterministic cuBLAS results on CUDA >= 10.2.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
# Select the Tk-based interactive matplotlib backend.
matplotlib.use('TkAgg')

# A single seed shared by Python's, NumPy's and PyTorch's global RNGs.
seed = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)


# Define a MLP class
class MLP(nn.Module):
    def __init__(self, shape):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(len(shape) - 1):
            self.layers.append(nn.Linear(shape[i], shape[i + 1]))

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)


# Define a function to train MLP
def train(model, dataset, epochs, lr, device='cuda'):
    """Fit ``model`` to the training split of ``dataset`` with full-batch L-BFGS.

    Args:
        model: ``nn.Module`` to optimise. It is moved to ``device`` in place
            and left in training mode.
        dataset: dict providing ``'train_input'`` and ``'train_label'`` tensors.
        epochs: number of ``optimizer.step`` calls. A single L-BFGS step may
            evaluate the closure several times.
        lr: L-BFGS learning rate.
        device: device the model and the training tensors are moved to.

    Returns:
        list of float: training RMSE (square root of the MSE) measured after
        each step.
    """
    model.to(device)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Cast labels to the model's parameter dtype. nn.MSELoss cannot
    # backpropagate when, e.g., float64 labels meet float32 predictions.
    # This is a no-op when the dtypes already agree.
    param_dtype = next(model.parameters()).dtype
    train_input = dataset['train_input'].to(device)
    train_label = dataset['train_label'].to(device=device, dtype=param_dtype)

    # L-BFGS re-evaluates the objective through this closure. It is defined
    # once rather than rebuilt every epoch.
    def closure():
        optimizer.zero_grad()
        loss = loss_fn(model(train_input), train_label)
        loss.backward()
        return loss

    train_losses = []
    model.train()
    for _ in tqdm(range(epochs)):
        optimizer.step(closure)
        # Measure the post-step error without building an autograd graph.
        with torch.no_grad():
            train_rmse = torch.sqrt(loss_fn(model(train_input), train_label))
        train_losses.append(train_rmse.item())

    return train_losses


# Benchmark target functions. Each reads only the first input column,
# ``x[:, [0]]``, and returns a column tensor with one row per sample.


def f1(x):
    """Parabola: x**2."""
    return x[:, [0]] ** 2


def f2(x):
    """Exponential: exp(x)."""
    return torch.exp(x[:, [0]])


def f3(x):
    """Absolute value |x| (kink at 0)."""
    return torch.abs(x[:, [0]])


def f4(x):
    """1 - sqrt(|x|) (cusp at 0, infinite slope there)."""
    return 1 - torch.sqrt(torch.abs(x[:, [0]]))


def f5(x):
    """Rectangular pulse: 1.0 on the open interval (-0.5, 0.5), 0.0 elsewhere.

    The result is always float32 because of ``.float()``.
    """
    col = x[:, [0]]
    return ((col > -0.5) & (col < 0.5)).float()


def f6(x):
    """1 - 4x**2 where |x| < 0.5, and 1 elsewhere.

    The function jumps from ~0 to 1 at |x| = 0.5.
    """
    col = x[:, [0]]
    # ones_like avoids allocating a host tensor and copying it to x's device
    # on every call, and keeps x's dtype.
    return torch.where(torch.abs(col) < 0.5, 1 - 4 * col ** 2, torch.ones_like(col))


def f7(x):
    """1 / x**2 (singular at x = 0)."""
    return 1 / (x[:, [0]] ** 2)


def f8(x):
    """1 / (1 - x**2) - 1 (singular at x = +/-1)."""
    return 1 / (1 - x[:, [0]] ** 2) - 1


def f9(x):
    """cos(1 / x): oscillates infinitely fast as x approaches 0."""
    return torch.cos(1 / x[:, [0]])


def f10(x):
    """cos(2*pi / (1 - x**2)): oscillates infinitely fast as x approaches +/-1."""
    return torch.cos(2 * torch.pi / (1 - x[:, [0]] ** 2))

# ---- Experiment configuration ----------------------------------------------
f = f3  # target function used to generate each dataset
# The sigma search below stops once noise.add_noise reports SNR <= SNR_goal.
# Units follow that helper (presumably dB — verify).
SNR_goal = -4
train_nums = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]
test_losses_kan1 = []
test_losses_kan2 = []
test_losses_mlp1 = []
test_losses_mlp2 = []

# Use the GPU when present; fall back to CPU so the script still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _rmse_on_test_split(model, data):
    """Return the RMSE (sqrt of the MSE) of ``model`` on ``data['test_input']`` vs ``data['test_label']``."""
    with torch.no_grad():  # evaluation only: no autograd graph needed
        predictions = model(data['test_input'])
        return torch.sqrt(nn.MSELoss()(predictions, data['test_label'])).item()


for train_num in train_nums:
    dataset = create_dataset(f, n_var=1, train_num=train_num, device=device)

    # Add noise to the training labels. Raise sigma in steps of 0.001 until
    # the SNR reported by noise.add_noise is at or below SNR_goal.
    clean_label = dataset['train_label']
    # Loop-invariant: copy the clean labels to the host once, not once per sigma.
    clean_label_np = clean_label.cpu().numpy()
    sigma = 0.001
    SNR = 1000  # sentinel above any goal so the search runs at least once
    while SNR > SNR_goal:
        train_label_noise, SNR = noise.add_noise(clean_label_np, sigma, seed=seed)
        sigma += 0.001
    # Keep the clean labels' dtype. If add_noise returns float64, mismatched
    # label/prediction dtypes can break nn.MSELoss's backward in train().
    dataset['train_label'] = torch.tensor(train_label_noise, dtype=clean_label.dtype, device=device)

    kan_model1 = KAN(width=[1, 5, 1], grid=3, k=3, seed=0, device=device)
    kan_results = kan_model1.fit(dataset, opt="LBFGS", steps=50, lr=1)
    test_losses_kan1.append(kan_results['test_loss'][-1])

    kan_model2 = KAN(width=[1, 10, 1], grid=3, k=3, seed=0, device=device)
    kan_results = kan_model2.fit(dataset, opt="LBFGS", steps=50, lr=0.1)
    test_losses_kan2.append(kan_results['test_loss'][-1])

    # MLP baselines, trained with the same step counts and learning rates as the KANs.
    mlp_model1 = MLP([1, 39, 1]).to(device)
    train_losses1 = train(model=mlp_model1, dataset=dataset, epochs=50, lr=1, device=device)
    test_losses_mlp1.append(_rmse_on_test_split(mlp_model1, dataset))

    mlp_model2 = MLP([1, 79, 1]).to(device)
    train_losses2 = train(model=mlp_model2, dataset=dataset, epochs=50, lr=0.1, device=device)
    test_losses_mlp2.append(_rmse_on_test_split(mlp_model2, dataset))

# Report the final test losses: one header line per model, then one value per
# training-set size, in the order the sizes were run.
_report = (
    ('############# kan 1 #############', test_losses_kan1),
    ('############# mlp 1 #############', test_losses_mlp1),
    ('############# kan 2 #############', test_losses_kan2),
    ('############# mlp 2 #############', test_losses_mlp2),
)
for _header, _losses in _report:
    print(_header)
    for _value in _losses:
        print(_value)
