import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import json
import os

from data.generate_data import generate_toy_data_pfaff

# === Model imports ===
from models import AntiSymm20Model, SimpleMLP20, PermEquiv20Model
#PermEquiv32Model
# Add additional model imports as needed
# from AnotherModel import AnotherModel

# === Device setup ===
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# === Config ===
# Problem size passed to the data generator and to every model constructor.
n = 4
# Base seed; each run r uses seed + r (train data) and seed + 1000 + r (test data).
seed = 42

# Optimisation hyper-parameters.
lr = 1e-3
epochs = 20
batch_size = 32

# Number of independent repetitions per (model, training size) pair.
n_runs = 5
# Training-set sizes to sweep: 30, 300, 3000, 30000.
training_sizes = [30 * 10**k for k in range(4)]

# === Function to train and evaluate a model ===
def train_and_evaluate(train_loader, test_loader, model, optimizer, criterion, epochs,
                       run_device=None):
    """Train ``model`` for ``epochs`` passes over ``train_loader`` and return its mean test loss.

    Args:
        train_loader: iterable of ``(inputs, targets)`` batches used for training.
        test_loader: iterable of ``(inputs, targets)`` batches used for evaluation.
        model: ``nn.Module`` to train; it is moved to the target device in place.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss function. Assumed to use mean reduction over the batch
            (the default for ``nn.MSELoss``), because each batch loss is
            re-weighted by its batch size below.
        epochs: number of full passes over ``train_loader`` (0 skips training).
        run_device: device to run on; when ``None`` the module-level
            ``device`` is used.

    Returns:
        float: per-sample average of ``criterion`` over the whole test set.
        The model is left in eval mode.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Only fall back to the module-level global when no device was given.
    target_device = device if run_device is None else run_device
    model.to(target_device)

    model.train()
    for _ in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(target_device), targets.to(target_device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

    model.eval()
    test_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(target_device), targets.to(target_device)
            n_batch = inputs.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not skew the overall average.
            test_loss += criterion(model(inputs), targets).item() * n_batch
            total_samples += n_batch

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test loss")
    return test_loss / total_samples

# === Models under comparison ===
# Each entry is (model class, display name, legend label). Classes rather than
# instances are stored so every run builds a freshly initialised model via
# ``cls(n)``; reusing a single instance would carry trained weights from one run
# (and one training size) into the next.
models_info = [
    (AntiSymm20Model, f"AntiSymm20Model({n})", "AntiSymmPermEquiv"),
    (PermEquiv20Model, f"PermEquiv20Model({n})", "PermEquiv"),
    (SimpleMLP20, f"SimpleMLP20({n})", "MLP"),
]

# === Storage for test errors ===
# test_errors[legend label] -> one list of per-run test MSEs per training size.
test_errors = {plot_name: [] for _, _, plot_name in models_info}

# === Main experiment loop ===
for train_size in training_sizes:
    print(f"\n==> Training Size: {train_size}")

    # Generate each run's train/test split once and share it across all models,
    # so every architecture is scored on exactly the same data.
    run_data = [
        (
            generate_toy_data_pfaff(train_size, n, seed + run),
            generate_toy_data_pfaff(1000, n, seed + 1000 + run),
        )
        for run in range(n_runs)
    ]

    for model_cls, model_name, plot_name in models_info:
        errors = []

        for run, ((X_train, y_train), (X_test, y_test)) in enumerate(run_data):
            print(f"  [{model_name}] Run {run + 1}/{n_runs}")

            # Seed torch per run so weight initialisation and DataLoader
            # shuffling are reproducible (the data seeds alone do not cover them).
            torch.manual_seed(seed + run)
            train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
            test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=16, shuffle=False)

            # Fresh, independently initialised model for every run; moved to the
            # device before the optimizer is built, as PyTorch recommends.
            model_instance = model_cls(n).to(device)
            optimizer = optim.SGD(model_instance.parameters(), lr=lr)
            criterion = nn.MSELoss()

            test_error = train_and_evaluate(train_loader, test_loader, model_instance, optimizer, criterion, epochs)
            errors.append(test_error)

        avg_error = np.mean(errors)
        std_dev = np.std(errors)
        print(f"  → Avg Test MSE for {model_name}: {avg_error:.4f} ± {std_dev:.4f}")
        test_errors[plot_name].append(errors)

# === Plot results ===
import matplotlib.ticker as ticker

# Training sizes are drawn as evenly spaced categorical ticks, so sizes spanning
# several orders of magnitude do not bunch up on a linear axis.
training_labels = [str(size) for size in training_sizes]
x_pos = np.arange(len(training_sizes))

plt.figure(figsize=(6, 4))

# One line per model: mean test MSE over runs at each training size.
# error_lists holds one inner list (the per-run errors) per training size,
# so axis=1 averages across runs.
# TODO: add std error bars over runs (e.g. plt.errorbar) if wanted.
for model_name, error_lists in test_errors.items():
    mean_errors = np.mean(error_lists, axis=1)
    plt.plot(x_pos, mean_errors, marker='o', label=model_name)

plt.xticks(ticks=x_pos, labels=training_labels)

# Log-scale y-axis labelled as powers of ten, without minor ticks.
plt.yscale('log')
ax = plt.gca()
ax.yaxis.set_major_formatter(ticker.LogFormatterMathtext(base=10.0))
ax.yaxis.set_minor_locator(ticker.NullLocator())

plt.xlabel('Training Set Size', fontsize=12)
plt.ylabel('Test MSE', fontsize=12)
plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
plt.legend(fontsize=11)
plt.tight_layout()

# Create the output directory if needed; exist_ok avoids the check-then-create race.
os.makedirs('figures', exist_ok=True)

plt.savefig(f'figures/test_error_vs_train_size_log_{n}_expand.png', dpi=300)
plt.show()