# train.py

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from data.generate_data import generate_toy_data
from models import AntiSymm21Model, PermEquiv21Model, SimpleMLP21

# Device-agnostic setup: use the GPU whenever PyTorch reports one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Step 0: Set up basic parameters
num_tuples = 10000  # first argument to generate_toy_data (presumably the sample count)
n = 8               # passed to generate_toy_data and to every model constructor
seed = 42
lr = 0.001
epochs = 30
batch_size = 50

# `seed` is handed to generate_toy_data, but the DataLoader's shuffling and any
# model constructor using torch's default initialisers draw from torch's global
# RNG, which is not otherwise seeded here. Seed it so training runs repeat.
torch.manual_seed(seed)

# Step 1: Build the toy dataset with `seed` and place all of it on `device`
# up front, so every batch the loader yields is already on that device.
X_train, y_train = generate_toy_data(num_tuples, n, seed)
X_train = X_train.to(device)
y_train = y_train.to(device)

# Wrap the tensors for mini-batching; shuffle=True reorders samples each epoch.
dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

# Step 2: Mean-squared-error loss, shared by every model trained below.
criterion = nn.MSELoss()

# Step 3: Define training function
def train_model(model, optimizer, criterion, train_loader, epochs, name="Model",
                target_device=None):
    """Train ``model`` for a fixed number of epochs and return it.

    Each epoch iterates over every batch of ``train_loader`` and takes one
    optimizer step per batch. It then prints the epoch's per-sample mean loss.
    The total wall-clock training time is printed at the end.

    Args:
        model: The ``nn.Module`` to train. It is moved to ``target_device``
            and switched to training mode.
        optimizer: Optimizer already built over ``model.parameters()``.
        criterion: Loss callable, invoked as ``criterion(outputs, targets)``.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        epochs: Number of full passes over ``train_loader``.
        name: Label used in the progress messages.
        target_device: Device to train on. When ``None`` (the default), the
            module-level ``device`` chosen at the top of this script is used.

    Returns:
        The same ``model`` object, after training.
    """
    if target_device is None:
        target_device = device  # script-wide default selected at start-up
    model.to(target_device)
    model.train()
    # perf_counter is monotonic and intended for measuring elapsed time.
    start_time = time.perf_counter()
    for epoch in range(epochs):
        running_loss = 0.0
        seen = 0
        for batch_inputs, batch_targets in train_loader:
            # No-op when the batch already lives on target_device.
            batch_inputs = batch_inputs.to(target_device)
            batch_targets = batch_targets.to(target_device)
            optimizer.zero_grad()
            outputs = model(batch_inputs)
            # NOTE(review): assumes outputs and batch_targets have matching
            # shapes; nn.MSELoss broadcasts e.g. (B, 1) vs (B,) with only a warning.
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()
            # Weight each batch by its size so a short final batch does not skew
            # the epoch mean (assumes criterion averages over the batch, as
            # nn.MSELoss does by default).
            batch_len = len(batch_inputs)
            running_loss += loss.item() * batch_len
            seen += batch_len
        # An empty loader has no samples to average over; report nan, don't crash.
        epoch_loss = running_loss / seen if seen else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss ({name}): {epoch_loss:.4f}")
    total_time = time.perf_counter() - start_time
    print(f"Training time for {name}: {total_time:.2f} seconds")
    return model

# Step 4: Initialize and train all models.
# Each entry pairs a freshly constructed model with the label used in the
# training log and in the saved file name.
# Disabled candidates: PermEquiv21Model(n) (imported at the top) and
# AntiSymm21Linear(n). NOTE(review): AntiSymm21Linear is not imported in this
# file; add the import before re-enabling it.
models_info = [
    (AntiSymm21Model(n), f"AntiSymm21Model({n})"),
    (SimpleMLP21(n), f"SimpleMLP21({n})"),
]

trained_models = []
for model, name in models_info:
    # Plain SGD at the shared learning rate. The previously tried alternative
    # was optim.Adam(model.parameters(), lr=lr).
    optimizer = optim.SGD(model.parameters(), lr=lr)
    trained_model = train_model(
        model,
        optimizer,
        criterion,
        train_loader,
        epochs,
        name,
    )
    trained_models.append((trained_model, name))

# Step 5: Write each trained model's state_dict into save_directory.
# Files are named from the lower-cased label, e.g. "AntiSymm21Model(8)" is
# written as "antisymm21model(8)_diag.pth".
save_directory = 'saved_models'
os.makedirs(save_directory, exist_ok=True)  # tolerate an existing directory

for model, name in trained_models:
    file_name = name.lower() + "_diag.pth"
    file_path = os.path.join(save_directory, file_name)
    torch.save(model.state_dict(), file_path)
    print(f"Saved {name} model to {file_path}")
