import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchdiffeq import odeint
import pandas as pd
import tqdm
# Seed both NumPy's and PyTorch's global random number generators.
seed = 15000
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define a time-dependent neural ODE model
class ODEFunc(nn.Module):
    """Time-dependent vector field ``f(t, y)`` used as the ODE right-hand side.

    A two-layer perceptron (2 -> 50 -> 2) whose hidden pre-activations are
    multiplied by ``sin(t)`` before the ReLU, which makes the dynamics
    depend explicitly on time.
    """

    def __init__(self):
        super().__init__()
        # Both layers start out on the module-level ``device``.
        self.fc1 = nn.Linear(2, 50).to(device)
        self.fc2 = nn.Linear(50, 2).to(device)

    def forward(self, t, y):
        """Evaluate the vector field at time ``t`` and state ``y``.

        Args:
            t: Current time. Normally a tensor supplied by the ODE solver;
                plain Python numbers are accepted as well.
            y: State tensor whose last dimension is 2 (the input width of
                ``fc1``).

        Returns:
            The output of ``fc2``: a tensor whose last dimension is 2.
        """
        # Put the time gate on the same device as the state instead of the
        # module-level global, so the module keeps working after being moved
        # with ``.to(...)``. ``as_tensor`` leaves tensors untouched and lets
        # callers pass a plain number for ``t``.
        time_dep_weight = torch.sin(torch.as_tensor(t)).to(y.device)
        out = torch.relu(self.fc1(y) * time_dep_weight)
        return self.fc2(out)

# ODE Block
class ODEBlock(nn.Module):
    """Integrate ``ode_func`` over a time grid and return the final state."""

    def __init__(self, ode_func):
        """Store the vector field ``ode_func`` (called as ``ode_func(t, y)``)."""
        super().__init__()
        self.ode_func = ode_func

    def forward(self, x, t=None):
        """Solve the ODE starting from ``x`` and return the last solver output.

        Args:
            x: Initial state tensor.
            t: Time points handed to ``odeint``. A tensor or any array-like
                accepted by ``torch.as_tensor``. Defaults to ``[0.0, 1.0]``.

        Returns:
            ``odeint(...)[-1]``, i.e. the entry of the solver output for the
            last time point in ``t``. If anything inside the solve raises,
            the error is printed and ``x`` is returned unchanged.
        """
        try:
            if t is None:
                # Build the default grid per call rather than sharing one
                # tensor object created when the function was defined.
                t = torch.tensor([0.0, 1.0])
            # The solver needs a floating-point time grid on x's device.
            t = torch.as_tensor(t).to(x.device).float()
            out = odeint(self.ode_func, x, t)
            return out[-1]
        except Exception as e:
            # NOTE(review): this fallback turns the block into the identity
            # map on failure, so a broken solve only shows up as a printed
            # message. Kept because callers may rely on it never raising.
            print(f"Error in ODEBlock forward pass: {e}")
            return x


# Model with ODE block
class ODEModel(nn.Module):
    """Model consisting of a single ODE block wrapped around an ``ODEFunc``."""

    def __init__(self):
        super().__init__()
        vector_field = ODEFunc()
        block = ODEBlock(vector_field)
        # Keep the block on the module-level device.
        self.ode_block = block.to(device)

    def forward(self, x):
        """Pass ``x`` through the ODE block and return the block's output."""
        result = self.ode_block(x)
        return result

# Lipschitz proxy: largest spectral norm (top singular value) among the model's weight matrices
def lipschitz_constant(model):
    """Return the largest spectral norm among the model's weight matrices.

    Every parameter with more than one dimension is treated as a weight
    matrix; biases and other 1-D parameters are skipped. The result is the
    maximum, over those parameters, of their largest singular value. This is
    a per-layer proxy, not a bound for the composed network.

    Args:
        model: Any ``nn.Module``.

    Returns:
        A plain Python number detached from autograd (``0`` when the model
        has no multi-dimensional parameters). Because it carries no
        gradient, adding it to a loss does not influence optimisation.
    """
    max_singular_value = 0
    # Only a number is returned, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for param in model.parameters():
            if param.dim() > 1:  # Only consider weight matrices
                # svdvals replaces the deprecated torch.svd and avoids
                # computing the unused U and V factors. The values already
                # live on the parameter's device.
                top_singular_value = torch.linalg.svdvals(param).max().item()
                max_singular_value = max(max_singular_value, top_singular_value)
    return max_singular_value

# Synthetic regression data: every target is its input multiplied by 2.
# The training inputs are drawn before the validation inputs.
train_data = torch.randn(100, 2).to(device)
train_labels = 2 * train_data
val_data = torch.randn(20, 2).to(device)
val_labels = 2 * val_data

# Generalization gap and training with Lipschitz penalty
def train_model(lambda_penalty, num_epochs=50):
    """Train a fresh ``ODEModel`` and return its generalization gap.

    The model is fitted to the module-level ``train_data`` / ``train_labels``
    with full-batch Adam (lr=0.01) on MSE plus
    ``lambda_penalty * (largest spectral norm of the weight matrices)``.

    Args:
        lambda_penalty: Weight of the spectral-norm penalty. A value of 0
            disables the penalty.
        num_epochs: Number of optimisation steps.

    Returns:
        float: validation MSE minus training MSE, both measured after
        training and without the penalty term.
    """
    model = ODEModel().to(device)
    criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    for _ in range(num_epochs):
        optimizer.zero_grad()
        predictions = model(train_data)
        loss = criterion(predictions, train_labels)

        # The penalty must stay a tensor attached to the autograd graph.
        # Converting it to a Python float would add a constant to the loss
        # and leave the gradients (and thus training) independent of
        # lambda_penalty.
        if lambda_penalty:
            loss = loss + lambda_penalty * _spectral_norm_penalty(model)

        loss.backward()
        optimizer.step()

    # Evaluation needs no gradients.
    with torch.no_grad():
        train_loss = criterion(model(train_data), train_labels).item()
        val_loss = criterion(model(val_data), val_labels).item()
    return val_loss - train_loss


def _spectral_norm_penalty(model):
    """Return the largest top singular value over weight matrices, differentiably."""
    norms = [
        torch.linalg.svdvals(param).max()
        for param in model.parameters()
        if param.dim() > 1  # weight matrices only; skip biases
    ]
    if not norms:
        return 0.0
    return torch.stack(norms).max()

# Experiment configuration
lambdas = [0, 0.01, 0.1, 1]  # penalty weights λ to compare
num_repeats = 20  # training runs per λ
results = {}

# For every λ, train num_repeats models and record each generalization gap
for lambda_penalty in lambdas:
    gaps_for_lambda = results.setdefault(lambda_penalty, [])
    for _ in tqdm.tqdm(range(num_repeats)):
        generalization_gap = train_model(lambda_penalty)
        gaps_for_lambda.append(generalization_gap)

# Flatten the per-λ results into (λ, gap) pairs, one pair per training run
data = [
    (lambda_penalty, gap)
    for lambda_penalty, gaps in results.items()
    for gap in gaps
]

# Split the pairs into two parallel sequences for plotting
lambda_vals, generalization_gaps = zip(*data)
x_axis = pd.Series(lambda_vals)
y_axis = pd.Series(generalization_gaps)

# Box plot: one box per λ showing the spread of generalization gaps
plt.figure(figsize=(10, 6))
sns.boxplot(x=x_axis, y=y_axis)
plt.title('Generalization Gap vs λ with Lipschitz Constant Penalty')
plt.xlabel('λ (Penalization Factor)')
plt.ylabel('Generalization Gap')
plt.grid(True)
plt.show()
