import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torch.utils.data import DataLoader, TensorDataset

from data.generate_data import generate_toy_data
from models import AntiSymm21Model, PermEquiv21Model, SimpleMLP21

# Device-agnostic setup: use the CUDA GPU when PyTorch reports one, else the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Step 0: Experiment hyper-parameters
num_tuples = 10_000  # first argument to generate_toy_data (presumably the sample count — confirm)
n = 8                # size parameter shared by generate_toy_data and every model constructor
seed = 42            # forwarded to generate_toy_data only; model init is not seeded here
lr = 1e-3            # learning rate given to each model's optimizer
epochs = 20          # number of passes over train_loader during training
batch_size = 32      # mini-batch size used by the DataLoader

# Step 1: Build the toy dataset and stage it on the target device.
X_train, y_train = generate_toy_data(num_tuples, n, seed)
# The whole dataset is placed on `device` up front, so batches drawn from the
# loader are already on that device.
X_train = X_train.to(device)
y_train = y_train.to(device)

# Mini-batch iterator over (input, target) pairs; shuffle=True reshuffles every epoch.
dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Step 2: Loss function — mean squared error averaged over the batch (PyTorch default reduction)
criterion = nn.MSELoss(reduction="mean")

# Step 3: Define training function with loss tracking
def train_model(model, optimizer, criterion, train_loader, epochs, name="Model"):
    """Train ``model`` in place and record the mean training loss of every epoch.

    Args:
        model: ``nn.Module`` to train; it is moved to the module-level ``device``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``. Assumed to
            return a per-batch *mean* (e.g. ``nn.MSELoss()`` with its default
            ``reduction='mean'``) so it can be re-weighted by batch size.
        train_loader: Iterable of ``(inputs, targets)`` tensor batches.
        epochs: Number of full passes over ``train_loader``.
        name: Label used in the progress and timing messages.

    Returns:
        ``(model, epoch_losses)`` where ``epoch_losses[i]`` is the per-sample
        mean loss accumulated during epoch ``i + 1``.

    Raises:
        ValueError: If ``train_loader`` yields no samples during an epoch.
    """
    model.to(device)
    model.train()
    start_time = time.perf_counter()  # monotonic clock intended for measuring intervals

    # Per-epoch training loss, in epoch order
    epoch_losses = []

    for epoch in range(epochs):
        running_loss = 0.0
        num_samples = 0
        for batch_inputs, batch_targets in train_loader:
            # No-op when the loader already yields tensors on `device`.
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_targets)
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its size: the last batch is smaller whenever
            # the dataset size is not a multiple of the batch size, and a plain
            # average of batch means would over-weight it.
            batch_len = batch_inputs.size(0)
            running_loss += loss.item() * batch_len
            num_samples += batch_len

        if num_samples == 0:
            raise ValueError(f"train_loader yielded no samples while training {name}")

        # Per-sample mean loss for the current epoch
        avg_loss = running_loss / num_samples
        epoch_losses.append(avg_loss)

        print(f"Epoch [{epoch+1}/{epochs}], Loss ({name}): {avg_loss:.4f}")

    total_time = time.perf_counter() - start_time
    print(f"Training time for {name}: {total_time:.2f} seconds")

    # Return the trained model and the epoch losses
    return model, epoch_losses

# Step 4: Initialize and train all models
# Each entry: (model instance, display/checkpoint name, legend label for the loss plot)
models_info = [
    (AntiSymm21Model(n), f"AntiSymm21Model({n})", "AntiSymmPermEquiv"),
    (PermEquiv21Model(n), f"PermEquiv21Model({n})", "PermEquiv"),
    (SimpleMLP21(n), f"SimpleMLP21({n})", "MLP"),
]

trained_models = []  # (trained model, model_name) pairs, in training order
epoch_losses_dict = {}  # plot label -> list of per-epoch training losses

for model, model_name, plot_name in models_info:
    # Move the model to its device BEFORE building the optimizer, as the PyTorch
    # optim docs recommend, so the optimizer is bound to the final parameter objects.
    model.to(device)
    # Every model uses plain SGD with the same learning rate.
    # An earlier variant used optim.Adam(model.parameters(), lr=lr) here instead.
    optimizer = optim.SGD(model.parameters(), lr=lr)
    trained_model, epoch_losses = train_model(
        model, optimizer, criterion, train_loader, epochs, name=model_name
    )
    trained_models.append((trained_model, model_name))
    epoch_losses_dict[plot_name] = epoch_losses  # Store the losses for plotting

# Step 5: Persist the state_dict of every trained model under 'saved_models/'.
save_directory = 'saved_models'
os.makedirs(save_directory, exist_ok=True)  # no error if the folder already exists

for trained, label in trained_models:
    # File name: lower-cased model label followed by the "_diag.pth" suffix
    checkpoint_path = os.path.join(save_directory, f"{label.lower()}_diag.pth")
    torch.save(trained.state_dict(), checkpoint_path)
    print(f"Saved {label} model to {checkpoint_path}")

# Step 6: Plot the loss vs. epoch for each model
plt.figure(figsize=(6, 4))  # 6x4-inch figure

for name, losses in epoch_losses_dict.items():
    # Derive the x-axis from the recorded losses so the lengths always match.
    plt.plot(range(1, len(losses) + 1), losses, label=name)

ax = plt.gca()
# set_yscale must come first: changing the scale resets the axis formatter/locators.
ax.set_yscale('log')  # Logarithmic scale for better visualization
ax.yaxis.set_major_formatter(ticker.LogFormatterMathtext(base=10.0))  # Use 10^x tick labels
ax.yaxis.set_minor_locator(ticker.NullLocator())  # Turn off minor ticks
ax.xaxis.set_major_locator(ticker.MultipleLocator(5))  # One x tick every 5 epochs

plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.legend()
plt.grid(True)

# Save into 'figures/' (created if missing). The file name depends only on n,
# so re-running with the same n overwrites the previous plot.
os.makedirs('figures', exist_ok=True)
plt.savefig(f'figures/train_error_vs_epochs_({n}).png', dpi=300)

plt.show()

