import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim

def f(x):
    """Target function: a unit-period sine wave shifted up by one.

    Operates element-wise on scalars and NumPy arrays, so the result lies
    in [0, 2] and repeats every 1.0 along x.
    """
    phase = 2 * np.pi * x
    return np.sin(phase) + 1


# Sample the target densely on [0, 10*pi]. f has period 1, so this covers
# roughly 31 periods with ~95 samples per period.
x = np.linspace(0, 10 * np.pi, 3000)
y = f(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fraction of samples held out for the "Test Loss" report. A tiny fraction
# such as 1e-5 rounds up to a single held-out point out of 3000
# (ceil(0.03) == 1), which makes the reported test loss meaningless.
TEST_FRACTION = 0.2


def _to_column_tensor(values):
    """Return ``values`` as a float32 tensor of shape (N, 1) on ``device``."""
    return torch.tensor(values, dtype=torch.float32).unsqueeze(1).to(device)


x_tensor = _to_column_tensor(x)
y_tensor = _to_column_tensor(y)
x_train, x_test, y_train, y_test = train_test_split(
    x_tensor, y_tensor, test_size=TEST_FRACTION, random_state=42
)

# NOTE(review): x_new / y_new and their tensors span [2*pi, 3*pi] but are not
# referenced anywhere else in this script as shown. Kept in case other code
# relies on these names.
x_new = np.linspace(2 * np.pi, 3 * np.pi, 2000)
y_new = f(x_new)
x_new_tensor = _to_column_tensor(x_new)
y_new_tensor = _to_column_tensor(y_new)

class CNO1D(nn.Module):
    """Five Conv1d layers, global average pooling, then a three-layer MLP head.

    Channel widths go 1 -> 32 -> 64 -> 128 -> 256 -> 256 through the conv
    stack (kernel 3, padding 1, so the sequence length is preserved). The
    head then maps 256 -> 128 -> 64 -> 1. ReLU follows every layer except
    the final linear one.

    ``forward`` inserts a channel axis at dim 1. With the (N, 1) batches
    used in this script, that yields (N, 1, 1): the sequence length is 1.
    Each 3-tap kernel therefore only applies its centre weight, because the
    other two taps see zero padding. In effect the network is a pointwise
    MLP on the scalar x.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so parameter initialisation
        # consumes the RNG identically and state_dict keys stay stable.
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv1d(256, 256, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of inputs to one scalar prediction per sample."""
        features = x.unsqueeze(1)  # channel axis for Conv1d
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = self.relu(conv(features))
        pooled = features.mean(dim=2)  # global average over the length axis
        for hidden in (self.fc1, self.fc2):
            pooled = self.relu(hidden(pooled))
        return self.fc3(pooled)
# ---- Training: full-batch Adam on the whole training split ----
model = CNO1D().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5000
log_every = 100
train_losses = []

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Read the scalar once and reuse it for both history and logging.
    loss_value = loss.item()
    train_losses.append(loss_value)

    step = epoch + 1
    if step % log_every == 0:
        print(f"Epoch [{step}/{num_epochs}], Loss: {loss_value:.4f}")




# ---- Evaluation: MSE on the held-out split, without gradient tracking ----
model.eval()
with torch.no_grad():
    y_pred = model(x_test)
    test_loss = criterion(y_pred, y_test)

print("Test Loss: {:.4f}".format(test_loss.item()))

# Predictions on new data
# ---- Predictions on a dense grid over the training range ----
# A separate name is used so the held-out ``x_test`` split created earlier is
# not overwritten by the plotting grid.
n_plot_points = 200
x_plot = torch.linspace(0, 10 * np.pi, n_plot_points)  # float32, on CPU
x_plot_tensor = x_plot.unsqueeze(1).to(device)  # (n_plot_points, 1)

model.eval()
with torch.no_grad():
    y_pred = model(x_plot_tensor)

# ---- MOVE EVERYTHING TO CPU BEFORE USING NUMPY / MATPLOTLIB ----
x_cpu = x_plot.numpy()  # (n_plot_points,)
# squeeze(1) rather than squeeze() so a single-point grid still yields a 1-D
# array instead of a 0-d scalar.
y_pred_cpu = y_pred.cpu().squeeze(1).numpy()  # (n_plot_points,)
y_true_cpu = f(x_cpu)  # ground truth on the same grid

# One scalar: the spread of the predicted values across the whole grid.
y_std = y_pred_cpu.std()

# ---- PLOTTING ----
plt.figure(figsize=(12, 6))

# These are samples of f on the plotting grid, not the training points.
plt.scatter(x_cpu, y_true_cpu,
            label="True f(x) on plot grid", marker='x', color="red")

plt.plot(x_cpu, y_pred_cpu, label="Predicted Curve")

# NOTE: this is NOT a confidence interval. It is a constant-width band of
# +/- one standard deviation of the predictions taken over the whole grid,
# shown only as a visual reference for the spread of the model output.
plt.fill_between(
    x_cpu,
    y_pred_cpu - y_std,
    y_pred_cpu + y_std,
    color='gray', alpha=0.2,
    label="Prediction ± 1 std (over grid)",
)

plt.title("Problem with CNN on Approximating Periodic Function")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()
