import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm
import random

# Seed the Python, PyTorch and NumPy global RNGs so model initialisation and
# data generation are repeatable from run to run.
seed = 100
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)  # `random` is imported by this script, so seed it too
# Learned vector field of the neural ODE: a ReLU MLP mapping R^2 -> R^2.
class ODEFunc(nn.Module):
    """Right-hand side f(t, y) of dy/dt = f(t, y), parameterised by an MLP.

    Layers, in order: Linear(2, 100) -> ReLU -> Linear(100, 100) -> ReLU ->
    Linear(100, 100) -> ReLU -> Linear(100, 2). The field ignores ``t``.
    """

    def __init__(self):
        super().__init__()
        widths = (2, 100, 100, 100, 2)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            if layers:
                # A ReLU separates consecutive Linear layers (none after the last one).
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        # `t` is required by the solver's f(t, y) calling convention; the field is autonomous.
        output = self.net(y)
        return output

class NeuralODE(nn.Module):
    """Integrate a learned vector field from an initial state over given times."""

    def __init__(self, ode_func):
        super().__init__()
        # Stored as a submodule attribute, so its parameters are registered with this module.
        self.ode_func = ode_func

    def forward(self, y0, t):
        # Solve dy/dt = ode_func(t, y) from y0, reporting the state at each time in `t`
        # (torchdiffeq stacks these along a new leading axis); default solver settings.
        trajectory = odeint(self.ode_func, y0, t)
        return trajectory

# Generalization gap: how much worse the model does on unseen data.
def calculate_generalization_gap(train_loss, test_loss):
    """Return ``test_loss - train_loss``.

    A positive value means the test loss exceeds the training loss.
    """
    gap = test_loss - train_loss
    return gap

# Function to calculate V
def calculate_V(model, t_max, L_sigma, A, B, norm_z0, L_theta, L=None):
    """Evaluate the complexity term

        V = norm_z0 + t_max * L_sigma * B * (L_sigma * A) ** (L - 1) * exp(t_max * L_theta)

    Args:
        model: module whose weight matrices supply ``L`` when ``L`` is None;
            not read otherwise.
        t_max: final integration time.
        L_sigma: activation factor (presumably its Lipschitz constant).
        A: weight-norm factor.
        B: bias-norm factor.
        norm_z0: norm of the initial state.
        L_theta: factor inside the exponential. When it is a ``torch.Tensor``
            the exponential is taken with ``torch.exp`` so autograd can flow
            through V; otherwise ``np.exp`` is used.
        L: exponent parameter. When None (the default) it is the largest
            spectral norm (``torch.linalg.norm(..., ord=2)``) among ``model``'s
            parameters with more than one dimension, or 0 if there are none.
            NOTE(review): bounds of this shape usually take L to be the network
            depth; confirm which meaning is intended.

    Returns:
        V as a float/NumPy scalar, or as a tensor when tensor inputs are given.
    """
    if L is None:
        # Derive L from `model` so the function does not depend on a global named L.
        L = max(
            (torch.linalg.norm(p, ord=2).item() for p in model.parameters() if p.dim() > 1),
            default=0,
        )
    exp = torch.exp if isinstance(L_theta, torch.Tensor) else np.exp
    return norm_z0 + t_max * L_sigma * B * (L_sigma * A) ** (L - 1) * exp(t_max * L_theta)

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Synthetic 2-D data set: noisy samples along the circle (sin t, cos t).
def generate_real_life_data(n_samples, noise_std=0.5, t_end=10.0):
    """Return ``n_samples`` noisy points near the unit circle.

    Point i is ``(sin t_i, cos t_i)`` with ``t_i`` evenly spaced on
    ``[0, t_end]``, and each coordinate perturbed by independent Gaussian noise
    of standard deviation ``noise_std``. The noise is drawn from NumPy's global
    RNG, so ``np.random.seed`` makes the output repeatable.

    Args:
        n_samples: number of points.
        noise_std: noise standard deviation (default 0.5).
        t_end: end of the parameter range (default 10.0).

    Returns:
        float32 tensor of shape (n_samples, 2) on the module-level ``device``.
    """
    t = np.linspace(0, t_end, n_samples)
    # x-noise is drawn before y-noise, keeping the RNG stream order fixed.
    x = np.sin(t) + noise_std * np.random.randn(n_samples)
    y = np.cos(t) + noise_std * np.random.randn(n_samples)
    return torch.as_tensor(np.stack([x, y], axis=1), dtype=torch.float32, device=device)

# Perform the experiment: for every regularization strength, train `num_trials`
# freshly initialised models and record each model's generalization gap.
regularization_params = [0.0, 0.01, 0.08, 0.1]  # Different regularization parameters
num_trials = 20
num_epochs = 25
generalization_gaps = []  # one list of per-trial gaps per entry of regularization_params

delta = 0.1  # Given value of delta (not used in the code shown here)


def _largest(values):
    """Return the max of a list of 0-d tensors (autograd preserved), or 0.0 if empty."""
    return torch.stack(values).max() if values else torch.tensor(0.0)


def _complexity_penalty(model, y0, t):
    """Return the complexity term V as a tensor attached to the autograd graph.

    Same formula as calculate_V,
        V = ||y0|| + t_max * L_sigma * B * (L_sigma * A) ** (L - 1) * exp(t_max * L_theta),
    with A = largest Frobenius norm of a weight matrix, B = largest bias norm,
    L = largest spectral norm of a weight matrix, and L_theta = product of the
    spectral norms of the vector field's Linear layers. Every
    parameter-dependent factor stays a tensor, so ``reg_param * V`` contributes
    gradients. A V assembled from ``.item()`` floats would be a constant that
    the optimizer never sees.
    """
    params = list(model.parameters())
    weights = [p for p in params if p.dim() > 1]
    biases = [p for p in params if p.dim() == 1]

    L_sigma = 1.0
    t_max = t.max().item()
    norm_z0 = torch.norm(y0)  # depends only on the data, constant w.r.t. the parameters
    A = _largest([torch.norm(w, p=2) for w in weights])
    B = _largest([torch.norm(b, p=2) for b in biases])
    # NOTE(review): L is the largest spectral norm of a weight matrix; bounds of
    # this shape usually use the network depth here — confirm the intent.
    L = _largest([torch.linalg.norm(w, ord=2) for w in weights])
    L_theta = torch.ones((), device=y0.device)
    for layer in model.ode_func.net:
        if isinstance(layer, nn.Linear):
            L_theta = L_theta * torch.linalg.norm(layer.weight, ord=2)

    return norm_z0 + t_max * L_sigma * B * (L_sigma * A) ** (L - 1) * torch.exp(t_max * L_theta)


n_samples = 2000  # Fixed number of samples per data set
t = torch.tensor([0.0, 1.0]).to(device)  # Integration times: start and end of the flow

for reg_param in regularization_params:
    gaps_for_param = []

    for _ in tqdm.tqdm(range(num_trials), desc=f"reg_param={reg_param}"):
        # Fresh training data and a freshly initialised model for every trial.
        y0 = generate_real_life_data(n_samples)
        model = NeuralODE(ODEFunc()).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.01)

        for _epoch in range(num_epochs):
            optimizer.zero_grad()
            # y_pred has a leading time axis (one slice per entry of `t`, per
            # torchdiffeq's convention); y0 broadcasts against every slice.
            y_pred = model(y0, t)
            loss = torch.mean((y_pred - y0) ** 2)
            if reg_param:
                # Skipped at reg_param == 0 so that 0 * inf (an overflowing exp)
                # cannot inject NaNs into the gradients.
                loss = loss + reg_param * _complexity_penalty(model, y0, t)
            # Exactly one backward pass per step: a second backward on the data
            # loss would accumulate its gradient twice.
            loss.backward()
            optimizer.step()

        # Evaluate the model after its final update, on both data sets, so the
        # train and test losses describe the same parameters.
        with torch.no_grad():
            train_loss = torch.mean((model(y0, t) - y0) ** 2)
            test_y0 = generate_real_life_data(n_samples)  # New initial condition for test
            test_loss = torch.mean((model(test_y0, t) - test_y0) ** 2)

        # Calculate generalization gap
        gen_gap = calculate_generalization_gap(train_loss.item(), test_loss.item())
        gaps_for_param.append(gen_gap)

    generalization_gaps.append(gaps_for_param)

# Box plot of the per-trial generalization gaps, one box per regularization strength.
plt.figure(figsize=(10, 6))
ax = sns.boxplot(data=generalization_gaps)
ax.set_xticks(range(len(regularization_params)))
ax.set_xticklabels(regularization_params)
ax.set_xlabel('Regularization Parameter')
ax.set_ylabel('Generalization Gap')
ax.set_title('Generalization Gap vs Regularization Parameter')
plt.show()
