# Train and evaluate the baseline CNN on rotated (or rotated+translated) MNIST, initializing from a saved prior mean.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from models import BaselineCNN
from dataset import RotatedMNISTDataset

def set_flat_params_to(model, flat_params):
    """Copy a flat parameter vector into the model's parameters, in place.

    Parameters are filled in the order produced by ``model.parameters()``;
    each one receives the next ``p.numel()`` entries of ``flat_params``,
    reshaped to that parameter's shape.

    Args:
        model: Module whose parameters are overwritten.
        flat_params: Tensor holding exactly as many elements as the model
            has parameter entries. It is flattened before slicing.

    Raises:
        ValueError: If the number of elements in ``flat_params`` differs from
            the model's total parameter count. The check runs before anything
            is copied, so the model is left untouched on failure.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    if flat_params.numel() != expected:
        # A mismatch almost always means the vector was saved for a different
        # architecture; silently using a prefix of it would hide that.
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements, "
            f"but the model has {expected} parameters"
        )

    flat = flat_params.reshape(-1)
    pointer = 0
    # no_grad lets us write into leaf parameters without autograd tracking.
    with torch.no_grad():
        for p in params:
            numel = p.numel()
            p.copy_(flat[pointer:pointer + numel].reshape(p.shape))
            pointer += numel

def train_baseline(
    data_dir="rotated_mnist",
    batch_size=128,
    lr=1e-3,
    epochs=20,
    seed=0,
    device="cpu",
    save_path="rotated_mnist/baseline_cnn.pt",
    prior_path="rotated_mnist/prior_mu_baseline.pt"
):
    """Train BaselineCNN starting from a saved prior mean and report test accuracy.

    The model is initialized from the tensor stored at ``prior_path``, trained
    with Adam and cross-entropy loss, and checkpointed to ``save_path`` whenever
    validation accuracy improves. After training, the best checkpoint written
    during this run is reloaded and evaluated on the test split.

    Args:
        data_dir: Directory containing ``train.pt``, ``val.pt`` and ``test.pt``.
        batch_size: Batch size used for all three loaders.
        lr: Adam learning rate.
        epochs: Number of passes over the training set.
        seed: Seed passed to ``torch.manual_seed``.
        device: Device the model and batches are moved to (e.g. ``"cpu"`` or
            ``"cuda"``).
        save_path: Where the best ``state_dict`` is written.
        prior_path: File loaded with ``torch.load`` and copied into the model
            via ``set_flat_params_to`` (presumably a flat parameter vector).

    Returns:
        The model, holding the best-validation weights when at least one epoch
        ran; otherwise the prior-initialized weights.
    """
    torch.manual_seed(seed)

    # Load pre-saved splits (train/val/test) produced by the dataset generation script
    train_set = RotatedMNISTDataset(os.path.join(data_dir, "train.pt"))
    val_set   = RotatedMNISTDataset(os.path.join(data_dir, "val.pt"))
    test_set  = RotatedMNISTDataset(os.path.join(data_dir, "test.pt"))

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader   = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    # Instantiate model and initialize weights to the pre-trained prior mean
    model = BaselineCNN().to(device)
    prior_mu = torch.load(prior_path, map_location=device)
    set_flat_params_to(model, prior_mu)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create the checkpoint directory up front so a missing folder fails fast
    # instead of after a full epoch of training.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_acc = 0.0
    # Tracks whether THIS run wrote a checkpoint, so we never reload a stale
    # file left behind by an earlier run.
    checkpoint_saved = False

    for epoch in range(1, epochs+1):
        # ---- train ----
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch average is per-sample even when the last batch is smaller.
            total_loss += loss.item() * imgs.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

        train_loss = total_loss / total
        train_acc = correct / total

        # ---- validate ----
        val_acc = evaluate(model, val_loader, device)

        print(f"Epoch {epoch:02d}: "
              f"Train loss {train_loss:.4f}, Train acc {train_acc:.4f}, Val acc {val_acc:.4f}")

        # Save best model by validation accuracy. The first epoch is always
        # saved so a checkpoint exists even if val_acc never rises above 0.
        if not checkpoint_saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            checkpoint_saved = True

    print(f"Best val acc: {best_val_acc:.4f}")
    # Reload the best checkpoint (if any epoch ran) and evaluate on the test set.
    if checkpoint_saved:
        model.load_state_dict(torch.load(save_path, map_location=device))
    test_acc = evaluate(model, test_loader, device)
    print(f"Test acc: {test_acc:.4f}")
    return model

def evaluate(model, loader, device="cpu"):
    """Return the fraction of samples in ``loader`` the model classifies correctly.

    The model is switched to eval mode (and left in it), and the pass runs with
    gradients disabled. Each batch is moved to ``device``; the predicted class
    is the index of the largest model output along dim 1.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            # max over dim 1 returns (values, indices); the indices are the predictions
            predicted = model(batch_imgs).max(1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.shape[0]
    return n_correct / n_seen

if __name__ == "__main__":
    # Set to True to train on the rotated+translated variant instead.
    translated = False

    # Each variant keeps its data, checkpoint and prior mean in its own directory.
    variant_paths = {
        False: {
            "data_dir": "rotated_mnist",
            "save_path": "rotated_mnist/baseline_cnn.pt",
            "prior_path": "rotated_mnist/prior_mu_baseline.pt",
        },
        True: {
            "data_dir": "rotated_translated_mnist",
            "save_path": "rotated_translated_mnist/baseline_cnn.pt",
            "prior_path": "rotated_translated_mnist/prior_mu_baseline.pt",
        },
    }

    # All other hyperparameters are shared between the two variants.
    # To use a GPU when available, pass
    #   device="cuda" if torch.cuda.is_available() else "cpu"
    train_baseline(
        batch_size=128,
        lr=1e-3,
        epochs=1,
        seed=0,
        device="cpu",
        **variant_paths[translated],
    )