# Train a prior mean by briefly fitting BaselineCNN or EquivariantCNN on the 'prior' split and save flattened weights.

from models import BaselineCNN
from models import EquivariantCNN
from generate_rotated_mnist import RotatedMNISTDataset
from torch.utils.data import DataLoader
import torch, torch.nn as nn, torch.optim as optim

# Experiment switches:
#   translated -- use the rotated + translated dataset instead of the rotated one
#   baseline   -- train BaselineCNN when True, EquivariantCNN otherwise
translated = False
baseline = True

# Dataset directory that corresponds to the `translated` switch.
path_mnist = "rotated_translated_mnist" if translated else "rotated_mnist"

# The run is pinned to the CPU; the automatic alternative is kept for reference:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

# Dataset: the precomputed 'prior' split stored under the chosen directory.
# This split is what the prior mean is learned from.
prior_ds = RotatedMNISTDataset(f"{path_mnist}/prior.pt")

# Shuffled mini-batches of 128 for the short prior fit; adjust num_workers
# to suit the machine the script runs on.
prior_loader = DataLoader(
    prior_ds,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# Build the model family picked by `baseline` and place it on `device`.
model = (BaselineCNN if baseline else EquivariantCNN)().to(device)

# Cross-entropy objective, optimised with Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=1e-3)

# Quick training loop: five passes over the prior split.
for epoch in range(5):
    model.train()
    # Accumulate per-batch losses so the epoch report reflects the whole pass
    # rather than whichever mini-batch happened to come last.
    epoch_loss = 0.0
    num_batches = 0
    for imgs, labels, angles in prior_loader:
        # Only images and labels are used to train the prior mean; the third
        # element of each batch (`angles`) is unpacked but not used.
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Without this check the report below would fail on an undefined
        # `loss`; fail with an explicit message instead.
        raise RuntimeError(
            "prior_loader yielded no batches; cannot train the prior mean"
        )
    # Mean of the per-batch losses seen during this epoch.
    print(f"Epoch {epoch:02d}: "
          f"loss {epoch_loss / num_batches:.4f}")


# save prior mean (flattened weights)
def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor:
    """Return every parameter of *model* concatenated into one 1-D CPU tensor.

    Parameters are visited in ``model.parameters()`` order; each one is
    detached from the autograd graph, flattened and moved to the CPU before
    concatenation.

    Args:
        model: Module whose parameters are flattened.

    Returns:
        A detached 1-D tensor on the CPU holding all parameter values. If the
        module has no parameters, an empty 1-D tensor is returned.
    """
    # reshape (rather than view) so that non-contiguous parameters, e.g. a
    # transposed weight, are flattened instead of raising a RuntimeError.
    flat = [p.detach().reshape(-1).cpu() for p in model.parameters()]
    if not flat:
        # torch.cat rejects an empty list, so handle parameter-free modules here.
        return torch.empty(0)
    return torch.cat(flat)

# Flatten the trained weights and store them next to the dataset; the file
# name records which model family produced the prior mean.
prior_mu = get_flat_params_from(model)
torch.save(
    prior_mu,
    path_mnist + ("/prior_mu_baseline.pt" if baseline else "/prior_mu_equivariant.pt"),
)
