import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


# Make runs reproducible: seed every RNG in use and disable cuDNN autotuning so kernels
# are chosen deterministically.
# NOTE(review): torch.manual_seed also seeds CUDA; the cuda.manual_seed_all(2) call below
# then overrides the CUDA seed, so CPU torch uses 3 while CUDA and NumPy use 2.
torch.manual_seed(3)
np.random.seed(2)
torch.cuda.manual_seed_all(2)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Hidden-layer width read by most of the network definitions below.
latent_dim: int = 200
# NOTE(review): this flag is not read anywhere in the visible code; the "deep" comparison
# figure is produced regardless of its value.
compare_with_deep: bool = True
# Define the neural network model with two layers using tanh activation
class TwoLayerTanhNN(nn.Module):
    """Shallow network: Linear(1, h) -> tanh -> Linear(h, 1).

    Args:
        hidden_dim: width ``h`` of the single hidden layer. When omitted, the
            module-level ``latent_dim`` is used.
    """

    def __init__(self, hidden_dim=None):
        super(TwoLayerTanhNN, self).__init__()
        width = latent_dim if hidden_dim is None else hidden_dim
        self.linear1 = nn.Linear(1, width)  # scalar input -> `width` hidden units
        self.tanh1 = nn.Tanh()
        self.linear2 = nn.Linear(width, 1)  # hidden units -> scalar output

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        x = self.linear1(x)
        x = self.tanh1(x)
        x = self.linear2(x)
        return x


class FourLayerTanhNN(nn.Module):
    """Deep network with three equal-width tanh hidden layers and a linear output.

    Layout: Linear(1, h) -> tanh -> Linear(h, h) -> tanh -> Linear(h, h) -> tanh
    -> Linear(h, 1).

    Args:
        hidden_dim: width ``h`` of every hidden layer. When omitted,
            ``int(latent_dim)`` (the module-level setting) is used.
    """

    def __init__(self, hidden_dim=None):
        super(FourLayerTanhNN, self).__init__()
        width = int(latent_dim) if hidden_dim is None else int(hidden_dim)
        self.linear1 = nn.Linear(1, width)  # scalar input -> hidden layer 1
        self.linear2 = nn.Linear(width, width)  # hidden layer 1 -> hidden layer 2
        self.linear3 = nn.Linear(width, width)  # hidden layer 2 -> hidden layer 3
        self.linear4 = nn.Linear(width, 1)  # hidden layer 3 -> scalar output
        self.tanh1 = nn.Tanh()
        self.tanh2 = nn.Tanh()
        self.tanh3 = nn.Tanh()

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        x = self.linear1(x)
        x = self.tanh1(x)
        x = self.linear2(x)
        x = self.tanh2(x)
        x = self.linear3(x)
        x = self.tanh3(x)
        x = self.linear4(x)
        return x

class FourLayerReLUNN(nn.Module):
    """Deep network with three equal-width ReLU hidden layers and a linear output.

    Layout: Linear(1, h) -> ReLU -> Linear(h, h) -> ReLU -> Linear(h, h) -> ReLU
    -> Linear(h, 1).

    Args:
        hidden_dim: width ``h`` of every hidden layer. When omitted,
            ``int(latent_dim)`` (the module-level setting) is used.
    """

    def __init__(self, hidden_dim=None):
        super(FourLayerReLUNN, self).__init__()
        width = int(latent_dim) if hidden_dim is None else int(hidden_dim)
        self.linear1 = nn.Linear(1, width)  # scalar input -> hidden layer 1
        self.linear2 = nn.Linear(width, width)  # hidden layer 1 -> hidden layer 2
        self.linear3 = nn.Linear(width, width)  # hidden layer 2 -> hidden layer 3
        self.linear4 = nn.Linear(width, 1)  # hidden layer 3 -> scalar output
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        x = self.relu3(x)
        x = self.linear4(x)
        return x


class TwoLayer2TanhNN(nn.Module):
    """Shallow network whose hidden layer is a Hadamard product of two tanh branches.

    Computes ``Linear(h, 1)(tanh(Linear(1, h)(x)) * tanh(Linear(1, h)(x)))`` with
    two independent first-layer projections.

    Args:
        hidden_dim: width ``h`` of each branch. Defaults to 100.
    """

    # NOTE(review): unlike the other models this default is not derived from `latent_dim`;
    # it equals latent_dim // 2 only while latent_dim == 200.
    def __init__(self, hidden_dim=100):
        super(TwoLayer2TanhNN, self).__init__()
        width = int(hidden_dim)
        self.linear1 = nn.Linear(1, width)  # value branch: scalar input -> `width` units
        self.linear1_2 = nn.Linear(1, width)  # gating branch: scalar input -> `width` units
        self.tanh1 = nn.Tanh()
        self.tanh1_2 = nn.Tanh()
        self.linear2 = nn.Linear(width, 1)  # gated hidden units -> scalar output

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        x1 = self.linear1(x)
        x1_mask = self.linear1_2(x)
        x1 = self.tanh1(x1)
        x1_mask = self.tanh1_2(x1_mask)
        x = x1 * x1_mask  # element-wise (Hadamard) gating of the two branches
        x = self.linear2(x)
        return x

class FourLayer2TanhNN(nn.Module):
    """Deep network of three Hadamard-gated tanh layers followed by a linear output.

    Each hidden layer ``k`` computes ``tanh(linear_k(h)) * tanh(linear_k_2(h))``.
    The two projections have separate weights.

    Args:
        hidden_dim: width of every branch. When omitted, ``int(latent_dim / 2)``
            (half the module-level setting) is used.
    """

    def __init__(self, hidden_dim=None):
        super(FourLayer2TanhNN, self).__init__()
        width = int(latent_dim / 2) if hidden_dim is None else int(hidden_dim)
        # Registration order is unchanged, so parameter order and state_dict keys are stable.
        self.linear1 = nn.Linear(1, width)
        self.linear1_2 = nn.Linear(1, width)
        self.linear2 = nn.Linear(width, width)
        self.linear2_2 = nn.Linear(width, width)
        self.linear3 = nn.Linear(width, width)
        self.linear3_2 = nn.Linear(width, width)
        self.linear4 = nn.Linear(width, 1)
        self.tanh1 = nn.Tanh()
        self.tanh1_2 = nn.Tanh()
        self.tanh2 = nn.Tanh()
        self.tanh2_2 = nn.Tanh()
        self.tanh3 = nn.Tanh()
        self.tanh3_2 = nn.Tanh()

    @staticmethod
    def _gated(value_proj, mask_proj, value_act, mask_act, h):
        """Return ``value_act(value_proj(h)) * mask_act(mask_proj(h))`` (one gated layer)."""
        return value_act(value_proj(h)) * mask_act(mask_proj(h))

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        h = self._gated(self.linear1, self.linear1_2, self.tanh1, self.tanh1_2, x)
        h = self._gated(self.linear2, self.linear2_2, self.tanh2, self.tanh2_2, h)
        h = self._gated(self.linear3, self.linear3_2, self.tanh3, self.tanh3_2, h)
        return self.linear4(h)

# Define the neural network model with two layers using ReLU activation
class TwoLayerReLUNN(nn.Module):
    """Shallow network: Linear(1, h) -> ReLU -> Linear(h, 1).

    Args:
        hidden_dim: width ``h`` of the single hidden layer. When omitted, the
            module-level ``latent_dim`` is used.
    """

    def __init__(self, hidden_dim=None):
        super(TwoLayerReLUNN, self).__init__()
        width = latent_dim if hidden_dim is None else hidden_dim
        self.linear1 = nn.Linear(1, width)  # scalar input -> `width` hidden units
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(width, 1)  # hidden units -> scalar output

    def forward(self, x):
        """Map inputs whose last dimension is 1 to outputs whose last dimension is 1."""
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        return x

# Function to train any model
def train_model(model, x, y, epochs=10000, lr=0.01):
    """Fit ``model`` to ``(x, y)`` with full-batch Adam on the mean-squared error.

    Args:
        model: module mapping ``x`` to predictions shaped like ``y``.
        x: input tensor (must live on the same device as ``model``).
        y: target tensor.
        epochs: number of full-batch optimisation steps; 0 just evaluates.
        lr: Adam learning rate.

    Returns:
        ``(predictions, loss)``:
        - ``predictions`` is ``model(x)`` evaluated in eval mode after training.
        - ``loss`` is the detached 0-dim MSE of exactly those predictions.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()  # nothing inside the loop changes the mode, so set it once
    for _ in range(epochs):
        optimizer.zero_grad()
        train_loss = criterion(model(x), y)
        train_loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        predictions = model(x)
        # Score the predictions being returned. A loss captured inside the loop predates
        # the final optimizer.step(), so it belongs to different weights. It also carries
        # autograd history and would be undefined when epochs == 0.
        final_loss = criterion(predictions, y)
    return predictions, final_loss


# Sample the input grid: 1000 evenly spaced points on [-2*pi, 2*pi], arranged as a
# (1000, 1) column so each row is one scalar input for nn.Linear(1, ...).
x = torch.linspace(-2 * np.pi, 2 * np.pi, 1000)[:, None]

# Target signal: three sinusoids with different frequencies and amplitudes.
# The terms are summed in this exact left-to-right order so the float result is stable.
y = (
    10 * torch.sin(7 * x)
    + 15 * torch.sin(10 * x)
    + 5 * torch.cos(5 * x)
)

# Instantiate the models and move each one to `device`.
# Construction order matters: each constructor draws its initial weights from torch's
# global RNG, so this order determines every model's initialisation.
(
    deep_tanh_model,
    deep_relu_model,
    deep_tanh2_model,
    tanh_model,
    relu_model,
    tanh2_model,
) = (
    model_cls().to(device)
    for model_cls in (
        FourLayerTanhNN,
        FourLayerReLUNN,
        FourLayer2TanhNN,
        TwoLayerTanhNN,
        TwoLayerReLUNN,
        TwoLayer2TanhNN,
    )
)


# Count trainable parameters:
def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s parameters that require grad."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen tensors are not counted
        total += param.numel()
    return total

# Report each model's trainable-parameter count, one line per model.
for net, caption in (
    (tanh_model, 'parameters Shallow Tanh'),
    (relu_model, 'parameters Shallow ReLU'),
    (tanh2_model, 'parameters Shallow Tanh with Hadamard'),
    (deep_tanh_model, 'parameters Deep Tanh'),
    (deep_relu_model, 'parameters Deep ReLU'),
    (deep_tanh2_model, 'parameters Deep Tanh with Hadamard'),
):
    print(count_parameters(net), caption)

# Train the models. Inputs and targets are moved to the compute device once and shared by
# every run; each train_model call performs its own full optimisation loop.
x_dev, y_dev = x.to(device), y.to(device)
tanh_predictions, tanh_loss = train_model(tanh_model, x_dev, y_dev)
relu_predictions, relu_loss = train_model(relu_model, x_dev, y_dev)
deep_tanh_predictions, deep_tanh_loss = train_model(deep_tanh_model, x_dev, y_dev)
deep_relu_predictions, deep_relu_loss = train_model(deep_relu_model, x_dev, y_dev)
tanh2_predictions, tanh2_loss = train_model(tanh2_model, x_dev, y_dev)
deep_tanh2_predictions, deep_tanh2_loss = train_model(deep_tanh2_model, x_dev, y_dev)

colors_per_activation = {
    'tanh': 'blue',
    'relu': 'orange',
    'sigmoid': 'green',
    'timestanh': 'red',
    'times2tanh': 'purple',
    'timessigmoid': 'brown',
    'plustanh': 'pink',
    'layernorm': 'skyblue',
    '1024': 'darkslateblue',
    'selu': 'sandybrown',
}
# Plotting the results
scale = 0.9
x_np = x.cpu().numpy()  # convert once rather than once per plot call
y_np = y.cpu().numpy()

# One entry per subplot: (predictions, legend label, colour key, title prefix, loss).
shallow_panels = [
    (tanh_predictions, 'Tanh Network', 'tanh', 'Shallow Tanh', tanh_loss),
    (relu_predictions, 'ReLU Network', 'relu', 'Shallow ReLU', relu_loss),
    (tanh2_predictions, r'Tanh (Hadamard) Network', 'timestanh', r'Shallow Tanh (Hadamard)', tanh2_loss),
]

plt.figure(figsize=(12 * scale, 5 * scale))
for panel_idx, (preds, label, color_key, title, panel_loss) in enumerate(shallow_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.plot(x_np, y_np, label='Target', color='black')
    plt.plot(x_np, preds.cpu().numpy(), label=label, linestyle='--', color=colors_per_activation[color_key])
    plt.title('{}, Loss = {:.2f}'.format(title, panel_loss))
    plt.legend()

plt.tight_layout()
plt.savefig('Function_Approximation_Plots.pdf', format='pdf')
plt.show()
# close() discards the figure. Calling clf() afterwards would make pyplot create a
# brand-new empty figure (clf clears gcf(), which opens one when none exists).
plt.close()

plt.figure(figsize=(12 * scale, 5 * scale))

x_np = x.cpu().numpy()  # convert once rather than once per plot call
y_np = y.cpu().numpy()

# One entry per subplot: (predictions, legend label, colour key, title prefix, loss).
deep_panels = [
    (deep_tanh_predictions, 'Tanh Network', 'tanh', 'Deep Tanh', deep_tanh_loss),
    (deep_relu_predictions, 'ReLU Network', 'relu', 'Deep ReLU', deep_relu_loss),
    # NOTE(review): the third panel of this "deep" figure shows the *shallow* Hadamard model,
    # so deep_tanh2_model is trained but not plotted in this file. Disabled deep alternative:
    # (deep_tanh2_predictions, r'Tanh (Hadamard) Network', 'timestanh', r'Deep Tanh (Hadamard)', deep_tanh2_loss),
    (tanh2_predictions, r'Tanh (Hadamard) Network', 'timestanh', r'Shallow Tanh (Hadamard)', tanh2_loss),
]

for panel_idx, (preds, label, color_key, title, panel_loss) in enumerate(deep_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.plot(x_np, y_np, label='Target', color='black')
    plt.plot(x_np, preds.cpu().numpy(), label=label, linestyle='--', color=colors_per_activation[color_key])
    plt.title('{}, Loss = {:.2f}'.format(title, panel_loss))
    plt.legend()

plt.tight_layout()
plt.savefig('Function_Approximation_Plots_Deep.pdf', format='pdf')
plt.show()
plt.close()  # release the figure once it has been saved and shown