import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint  # Neural ODE integration
import matplotlib.pyplot as plt
import tqdm
import numpy as np
# Pick the compute device once at import time: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Seed the NumPy and PyTorch global RNGs so runs are repeatable.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
# Neural ODE block
class ODEFunc(nn.Module):
    """Vector field ``f(t, x)`` integrated by the Neural ODE.

    Computes ``relu(linear(x))``. The ``(t, x)`` signature is what
    ``odeint`` expects of its function argument; ``t`` itself is ignored,
    so the dynamics do not depend on time.

    Note: because the last operation is a ReLU, every component of the
    returned derivative is >= 0, so each hidden coordinate can only stay
    constant or grow along a trajectory.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Attribute names are part of the state_dict / module tree; keep them stable.
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, t, x):
        pre_activation = self.linear(x)
        return self.relu(pre_activation)

# Neural ODE model
class NeuralODE(nn.Module):
    """Input projection -> ODE evolution over t in [0, 1] -> linear readout.

    ``linear_in`` maps the input to a ``hidden_dim`` state, which is
    integrated with ``odeint`` using ``ODEFunc`` as the vector field from
    t=0 to t=1. The state at t=1 is mapped to ``output_dim`` values by
    ``linear_out``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1):
        """
        Args:
            input_dim: Size of the last dimension of the input tensor.
            hidden_dim: Size of the hidden state evolved by the ODE.
            output_dim: Size of the model output (default 1, i.e. one scalar
                per sample, matching the original behaviour).
        """
        super(NeuralODE, self).__init__()
        # Creation order is kept so seeded parameter initialisation is unchanged.
        self.odefunc = ODEFunc(hidden_dim)
        self.linear_in = nn.Linear(input_dim, hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return ``linear_out(h(1))`` where ``h(0) = linear_in(x)``."""
        h0 = self.linear_in(x)
        # Build the time grid on the state's own device/dtype instead of the
        # module-level ``device`` global, so the model still works after being
        # moved to a different device or dtype than the one chosen at import.
        t = torch.tensor([0.0, 1.0], dtype=h0.dtype, device=h0.device)
        # odeint returns the solution at every point of ``t``; keep t=1 only.
        h1 = odeint(self.odefunc, h0, t)[-1]
        return self.linear_out(h1)

# Train a model and evaluate it on the selected device (GPU when available, else CPU)
def train_and_evaluate(input_dim, hidden_dim, train_data, test_data, num_epochs=100, lr=0.01):
    """Train a fresh NeuralODE on ``train_data`` and return its test-set MSE.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the hidden state evolved by the ODE.
        train_data: Iterable of ``(inputs, targets)`` batches. It is iterated
            once per epoch, so it must be re-iterable (e.g. a list or a
            DataLoader), not a one-shot generator.
        test_data: Iterable of ``(inputs, targets)`` batches for evaluation.
        num_epochs: Number of passes over ``train_data``.
        lr: Adam learning rate (default 0.01, the previously hard-coded value).

    Returns:
        Mean squared error over all target elements in ``test_data`` as a
        float (0.0 if ``test_data`` is empty). This is the test error itself,
        not a train/test gap.
    """
    model = NeuralODE(input_dim, hidden_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for _ in tqdm.tqdm(range(num_epochs)):
        model.train()
        for inputs, targets in train_data:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    return _mean_loss(model, test_data, criterion)


def _mean_loss(model, data, criterion):
    """Return the element-weighted mean of ``criterion`` over all batches in ``data``.

    ``criterion`` is expected to average within a batch (as ``nn.MSELoss``
    does by default); weighting each batch by its number of target elements
    makes the result the true overall mean even when batch sizes differ,
    instead of a sum of per-batch means.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in data:
            inputs, targets = inputs.to(device), targets.to(device)
            n = targets.numel()
            total += criterion(model(inputs), targets).item() * n
            count += n
    return total / count if count else 0.0

# Generate data (or use your dataset)
def generate_dummy_data(n_samples, input_dim):
    """Create a synthetic regression dataset packaged as a single batch.

    Args:
        n_samples: Number of rows to draw.
        input_dim: Number of features per row.

    Returns:
        A one-element list ``[(X, y)]``. ``X`` has shape
        ``(n_samples, input_dim)`` with standard-normal entries. ``y`` is
        ``sin`` of each row's feature sum and has shape ``(n_samples, 1)``.
    """
    features = torch.randn(n_samples, input_dim)
    row_totals = torch.sum(features, dim=1, keepdim=True)
    targets = torch.sin(row_totals)
    single_batch = (features, targets)
    return [single_batch]

# Main function: train models with different numbers of hidden units and plot their test (generalization) error
def main():
    """Sweep the hidden width, train one NeuralODE per width, and plot the test error.

    One synthetic training set and one synthetic test set are generated up
    front and shared by every width. Each width's test error is printed and
    then plotted against the width.
    """
    input_dim = 2
    train_data = generate_dummy_data(100, input_dim)
    test_data = generate_dummy_data(30, input_dim)

    hidden_dims = list(range(100, 1000, 100))  # 100, 200, ..., 900
    generalization_errors = []

    for hidden_dim in hidden_dims:
        gen_error = train_and_evaluate(input_dim, hidden_dim, train_data, test_data)
        generalization_errors.append(gen_error)
        # train_and_evaluate returns the test-set error, not a train/test gap,
        # so label it the same way the plot below does.
        print(f"Hidden Dim: {hidden_dim}, Generalization Error: {gen_error}")

    # Plot the test (generalization) error vs number of hidden units
    plt.plot(hidden_dims, generalization_errors, marker='o')
    plt.xlabel('Number of Hidden Units')
    plt.ylabel('Generalization Error')
    plt.title('Generalization Error vs Number of Hidden Units')
    plt.grid(True)
    plt.show()

# Run the experiment only when executed as a script, not when this module is imported
if __name__ == "__main__":
    main()
