import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint_adjoint as odeint
import matplotlib.pyplot as plt
import numpy as np
import random

# --------------------- Set Fixed Seed ------------------------
def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Passes ``seed`` to PyTorch (CPU, current CUDA device, and all CUDA
    devices), NumPy's global generator, and Python's ``random`` module, then
    sets ``cudnn.deterministic`` to True and ``cudnn.benchmark`` to False.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once at import time so everything defined below starts from a fixed state.
set_seed(42)

# --------------------- ODE Function --------------------------
class ODEFunc(nn.Module):
    """Time-independent vector field ``f(t, x) = W2 @ tanh(W1 @ x)``.

    Both linear maps are created without bias terms.

    Args:
        dim: Width of the state vector (input and output size).
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        # Keep this creation order: each nn.Linear draws its initial weights
        # from the global RNG, so reordering would change seeded results.
        self.linear1 = nn.Linear(dim, hidden_dim, bias=False)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=False)
        self.activation = nn.Tanh()

    def forward(self, t, x):
        """Return the state derivative at ``x``; ``t`` is accepted but unused."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(hidden)

    def lipschitz_constant(self):
        """Return the product of the two weight matrices' spectral norms.

        Because tanh is 1-Lipschitz, this product is an upper bound on the
        Lipschitz constant of the vector field with respect to ``x``.

        Returns:
            The bound as a Python float.
        """
        with torch.no_grad():
            spectral_norms = [
                torch.linalg.norm(layer.weight, ord=2)
                for layer in (self.linear1, self.linear2)
            ]
        first_norm, second_norm = spectral_norms
        return (first_norm * second_norm).item()

# --------------------- Neural ODE ----------------------------
class NeuralODE(nn.Module):
    """Classifier that flattens its input, integrates an ODE, and applies a linear head.

    The flattened input is used as the initial state of ``ODEFunc`` and is
    integrated from t=0 to t=1; the final state feeds a linear classifier.

    Args:
        input_dim: Number of features per sample after flattening.
        num_classes: Number of output logits. Defaults to 10, the previously
            hard-coded value.
        hidden_dim: Hidden width of the ODE vector field. Defaults to 64,
            which is ``ODEFunc``'s own default.
    """

    def __init__(self, input_dim, num_classes=10, hidden_dim=64):
        super().__init__()
        self.func = ODEFunc(input_dim, hidden_dim=hidden_dim)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``.

        Args:
            x: Batch of inputs; every dimension after the first is flattened.
        """
        # reshape (not view) so non-contiguous inputs are handled as well;
        # for contiguous inputs the result is identical.
        flat = x.reshape(x.size(0), -1)
        t = torch.tensor([0., 1.], device=flat.device)
        # odeint returns one state per requested time; keep the final one (t=1).
        final_state = odeint(self.func, flat, t)[-1]
        return self.classifier(final_state)

    def lipschitz_constant(self):
        """Return the Lipschitz bound reported by the underlying ``ODEFunc``."""
        return self.func.lipschitz_constant()

# --------------------- Train & Eval --------------------------
def train(model, device, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the running accuracy.

    Args:
        model: Module mapping a batch of inputs to class logits.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, targets)``.

    Returns:
        Fraction of samples whose arg-max logit matched the target. Each
        batch is scored with the logits computed before that batch's
        parameter update.
    """
    model.train()
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        predictions = outputs.argmax(dim=1)
        n_correct += predictions.eq(targets).sum().item()
        n_seen += targets.size(0)
    return n_correct / n_seen

@torch.no_grad()
def evaluate(model, device, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Puts the model in eval mode and runs with gradient tracking disabled.

    Args:
        model: Module mapping a batch of inputs to class logits.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples whose arg-max logit matched the target.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predictions = model(inputs).argmax(dim=1)
        n_correct += predictions.eq(targets).sum().item()
        n_seen += targets.size(0)
    return n_correct / n_seen

# --------------------- Main Loop -----------------------------
def main(epochs=10):
    """Train the Neural ODE on MNIST and plot its Lipschitz bound against the gap.

    After every epoch the train accuracy, test accuracy, Lipschitz bound of
    the ODE function, and generalization gap (test error minus train error)
    are printed; once training ends the per-epoch (bound, gap) pairs are
    plotted.

    Args:
        epochs: Number of training epochs to run. Defaults to 10, the
            previously hard-coded value.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_ds  = datasets.MNIST(root="./data", train=False, transform=transform)

    # Pinned host memory only speeds up host-to-GPU copies; requesting it on a
    # CPU-only machine just makes DataLoader emit a warning.
    use_pinned_memory = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2,
                              pin_memory=use_pinned_memory)
    test_loader  = DataLoader(test_ds, batch_size=128, num_workers=2,
                              pin_memory=use_pinned_memory)

    model = NeuralODE(input_dim=28*28).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    lips_list = []
    gap_list = []

    for epoch in range(1, epochs + 1):
        # NOTE: train_acc is the running accuracy accumulated while the weights
        # are still being updated, whereas test_acc uses the end-of-epoch
        # weights, so the two are not measured on exactly the same model.
        train_acc = train(model, device, train_loader, optimizer, criterion)
        test_acc = evaluate(model, device, test_loader)
        gap = (1 - test_acc) - (1 - train_acc)  # test error - train error
        lipschitz = model.lipschitz_constant()

        lips_list.append(lipschitz)
        gap_list.append(gap)

        print(f"Epoch {epoch:02d} → Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}, "
              f"Lipschitz: {lipschitz:.2f}, Gap: {gap:.4f}")

    # ------------------- Plot After Training ------------------
    # One point per epoch, connected in training order.
    plt.figure(figsize=(6, 4))
    plt.plot(lips_list, gap_list, 'o-', color='blue')
    plt.xlabel("Lipschitz Constant")
    plt.ylabel("Generalization Gap (Test Error - Train Error)")
    plt.title("Lipschitz vs Generalization Gap (Neural ODE on MNIST)")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# The guard keeps DataLoader worker processes (num_workers=2) from re-running
# training when they import this module.
if __name__ == "__main__":
    main()
