# Train and evaluate the equivariant CNN (escnn/e2cnn) on rotated (or rotated+translated) MNIST, initializing from a saved prior mean.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from models import EquivariantCNN
from dataset import RotatedMNISTDataset

def set_flat_params_to(model, flat_params):
    """Copy a flat parameter vector into the model's parameters, in place.

    The vector is consumed in the order of ``model.parameters()``: each
    parameter takes the next ``p.numel()`` entries, reshaped to its own shape.

    Args:
        model: Module whose parameters are overwritten.
        flat_params: Tensor holding exactly as many elements as the model has
            parameter entries. It is flattened before slicing, so any shape
            with the right element count is accepted.

    Raises:
        ValueError: If the element count of ``flat_params`` does not match the
            model's total parameter count. The check happens before any copy,
            so the model is never left partially overwritten.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)

    flat_params = flat_params.reshape(-1)
    if flat_params.numel() != expected:
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements, "
            f"but the model has {expected} parameters"
        )

    pointer = 0
    # no_grad lets us write into leaf parameters without autograd tracking.
    with torch.no_grad():
        for p in params:
            numel = p.numel()
            p.copy_(flat_params[pointer:pointer + numel].view_as(p))
            pointer += numel

def train_equivariant_cnn(
    data_dir="rotated_mnist",
    batch_size=128,
    lr=1e-3,
    epochs=20,
    seed=0,
    device="cpu",
    save_path="rotated_mnist/equivariant_cnn.pt",
    prior_path="rotated_mnist/prior_mu_equivariant.pt"
):
    """Train an EquivariantCNN from a saved prior mean and report test accuracy.

    The model is initialized from the flat parameter vector stored at
    ``prior_path``, trained with Adam and cross-entropy loss, and validated
    after every epoch. The state dict with the best validation accuracy seen
    in this run is written to ``save_path``; after training that checkpoint is
    reloaded and evaluated on the test split.

    Args:
        data_dir: Directory containing ``train.pt``, ``val.pt`` and ``test.pt``.
        batch_size: Batch size used for all three data loaders.
        lr: Learning rate passed to Adam.
        epochs: Number of training epochs.
        seed: Seed passed to ``torch.manual_seed``.
        device: Torch device string the model and batches are moved to.
        save_path: Where the best state dict is saved.
        prior_path: File holding the flat prior-mean parameter vector.

    Returns:
        The trained model, carrying the best checkpoint's weights when at
        least one epoch ran (otherwise the prior-initialized weights).
    """
    torch.manual_seed(seed)

    # Make sure the checkpoint directory exists up front, so a missing folder
    # does not surface only after a full epoch of training.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Load pre-saved splits (train/val/test) produced by the dataset generation script
    train_set = RotatedMNISTDataset(os.path.join(data_dir, "train.pt"))
    val_set   = RotatedMNISTDataset(os.path.join(data_dir, "val.pt"))
    test_set  = RotatedMNISTDataset(os.path.join(data_dir, "test.pt"))

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    # Instantiate model and initialize weights to the pre-trained prior mean
    model = EquivariantCNN().to(device)
    prior_mu = torch.load(prior_path, map_location=device)
    set_flat_params_to(model, prior_mu)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0.0
    # Tracks whether THIS run wrote save_path, so we never reload a stale
    # checkpoint left behind by an earlier run.
    checkpoint_saved = False

    for epoch in range(1, epochs+1):
        # ---- train ----
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch loss is a per-sample average even with a short last batch.
            total_loss += loss.item() * imgs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / total
        train_acc = correct / total

        # ---- validate ----
        val_acc = evaluate(model, val_loader, device)

        print(f"Epoch {epoch:02d}: "
              f"Train loss {train_loss:.4f}, Train acc {train_acc:.4f}, Val acc {val_acc:.4f}")

        # Save best model by validation accuracy. The first epoch always
        # saves, so a checkpoint exists even if validation accuracy is 0.
        if not checkpoint_saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True

    print(f"Best val acc: {best_val_acc:.4f}")
    # Reload the best checkpoint (only if this run produced one) and evaluate
    # on the test set.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=device))
    test_acc = evaluate(model, test_loader, device)
    print(f"Test acc: {test_acc:.4f}")
    return model

def evaluate(model, loader, device="cpu"):
    """Return the fraction of samples in ``loader`` the model classifies correctly.

    The model is switched to eval mode and run without gradient tracking. The
    predicted class of each sample is the index of the largest output along
    dimension 1.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Correct predictions divided by the number of labels seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    n_hits = 0
    n_seen = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            scores = model(batch_imgs)
            # max over dim 1 returns (values, indices); the indices are the
            # predicted classes.
            predicted = scores.max(1)[1]
            n_hits += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    return n_hits / n_seen

if __name__ == "__main__":
    # Flip to True to train on the rotated+translated dataset instead of the
    # rotation-only one.
    translated = False

    # Settings shared by both dataset variants.
    run_config = dict(batch_size=128, lr=1e-3, seed=0, device="cpu")

    # Variant-specific settings: data location, epoch budget, output paths.
    if translated:
        run_config.update(
            data_dir="rotated_translated_mnist",
            epochs=1,
            save_path="rotated_translated_mnist/equivariant_cnn.pt",
            prior_path="rotated_translated_mnist/prior_mu_equivariant.pt",
        )
    else:
        run_config.update(
            data_dir="rotated_mnist",
            epochs=20,
            save_path="rotated_mnist/equivariant_cnn.pt",
            prior_path="rotated_mnist/prior_mu_equivariant.pt",
        )

    train_equivariant_cnn(**run_config)
