# train_experts_by_label.py
import os
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import wandb

from model import ResNetv2_rej     # <- assumes you already have this

# ------------------------------------------------------------------------------
# 0) Dataset wrapper – keep only chosen labels but DO NOT relabel them
# ------------------------------------------------------------------------------
class LabelFilteredSVHN(Dataset):
    """
    View of an SVHN-style dataset restricted to a chosen set of labels.

    Only samples whose label is in ``allowed_labels`` are exposed. Labels are
    returned with their original values (no re-indexing), so a classifier head
    sized for the full label space can still be trained on the subset.

    Args:
        base_ds: Dataset exposing ``data`` (indexable, one channel-first image
            per sample), ``labels`` (one integer label per sample) and
            ``transform`` (callable or ``None``). An optional
            ``target_transform`` attribute is honoured as well.
        allowed_labels: Label values to keep; duplicates are ignored.

    Attributes:
        base: The wrapped dataset.
        allowed: ``allowed_labels`` as a set.
        idxs: Positions in ``base`` of the kept samples, in ascending order.
    """

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, base_ds: "datasets.SVHN", allowed_labels: List[int]):
        super().__init__()
        self.base = base_ds
        self.allowed = set(allowed_labels)

        # Select matching samples with one vectorised membership test instead
        # of a Python-level loop over every label. The int64 cast truncates
        # like ``int(lbl)`` would, so non-integer label dtypes still match.
        labels = np.asarray(self.base.labels).astype(np.int64)
        wanted = np.array(sorted(self.allowed), dtype=np.int64)
        self.idxs = np.flatnonzero(np.isin(labels, wanted)).tolist()

    def __len__(self):
        """Return the number of samples that survived the label filter."""
        return len(self.idxs)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th kept sample."""
        real_idx = self.idxs[idx]

        # SVHN stores images as uint8 (C,32,32); convert to HWC for ToTensor
        img = np.transpose(self.base.data[real_idx], (1, 2, 0))
        if self.base.transform is not None:
            img = self.base.transform(img)

        label = int(self.base.labels[real_idx])  # original value, not re-indexed

        # Honour the base dataset's label transform, mirroring how its image
        # transform is applied above; absent or None means "leave as is".
        target_transform = getattr(self.base, "target_transform", None)
        if target_transform is not None:
            label = target_transform(label)
        return img, label


# ------------------------------------------------------------------------------
# 1) Config & label groups
# ------------------------------------------------------------------------------
WANDB_PROJECT = "my-svhn-moe"

# Training hyper-parameters; main() passes these to wandb.init as run config.
CONFIG = {
    "epochs": 20,
    "batch_size": 64,
    "lr": 1e-3,
}

# One label subset per expert; neighbouring groups overlap on purpose or not,
# but each expert only ever sees samples from its own group.
LABEL_GROUPS = [
    [1, 2, 3],  # expert 1
    [2, 3, 4],  # expert 2
    [3, 4, 5],  # expert 3
    [4, 5, 6],  # expert 4
    [5, 6, 7],  # expert 5
    [0, 8, 9],  # expert 6
]
N_EXPERTS = len(LABEL_GROUPS)


# ------------------------------------------------------------------------------
# 2) Helpers
# ------------------------------------------------------------------------------
def make_model() -> nn.Module:
    """Build a fresh ResNetv2_rej with 10 output classes, depth 10, no dropout."""
    hyperparams = {"depth": 10, "num_classes": 10, "dropout": 0.0}
    return ResNetv2_rej(**hyperparams)


def train_epoch(model, loader, crit, opt, device):
    """
    Run one optimisation pass over ``loader`` and report its metrics.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        crit: Loss callable taking ``(outputs, targets)``. The running average
            weights each batch loss by batch size, which presumes ``crit``
            returns a per-batch mean — confirm if a different reduction is used.
        opt: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are NaN when
        ``loader`` yields no samples (e.g. ``drop_last=True`` with fewer
        samples than one batch) instead of raising ``ZeroDivisionError``.
    """
    model.train()
    tloss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = crit(out, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accuracy uses the outputs computed before this step's update.
        tloss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)

    if total == 0:
        # Metrics are undefined without samples; NaN keeps that visible in logs.
        return float("nan"), float("nan")
    return tloss / total, correct / total


@torch.no_grad()
def eval_epoch(model, loader, crit, device):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable of ``(inputs, targets)`` batches.
        crit: Loss callable taking ``(outputs, targets)``. The running average
            weights each batch loss by batch size, which presumes ``crit``
            returns a per-batch mean — confirm if a different reduction is used.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are NaN when
        ``loader`` yields no samples instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    eloss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = crit(out, y)

        eloss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)

    if total == 0:
        # Metrics are undefined without samples; NaN keeps that visible in logs.
        return float("nan"), float("nan")
    return eloss / total, correct / total


# ------------------------------------------------------------------------------
# 3) Main
# ------------------------------------------------------------------------------
def _build_expert_loaders(train_raw, test_raw, batch_size):
    """Create a label-filtered train/test DataLoader pair per LABEL_GROUPS entry."""
    experts = []
    for labels in LABEL_GROUPS:
        tr_ds = LabelFilteredSVHN(train_raw, labels)
        te_ds = LabelFilteredSVHN(test_raw, labels)
        experts.append({
            "labels": labels,
            "train": DataLoader(tr_ds, batch_size=batch_size,
                                shuffle=True, num_workers=2, drop_last=True),
            "test": DataLoader(te_ds, batch_size=batch_size,
                               shuffle=False, num_workers=2),
        })
    return experts


def _train_expert(k, exp, cfg, device):
    """Train expert ``k`` from scratch, log every epoch, checkpoint the best one."""
    print(f"\n=== Expert {k} – training on labels {exp['labels']} "
          f"({len(exp['train'].dataset)} samples) ===")
    model = make_model().to(device)
    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=cfg.lr)

    # ``best`` stays None until a checkpoint has actually been written, so the
    # summary line below can never claim a save that did not happen.
    best, ckpt = None, f"expert_{k}_best.pth"
    for epoch in range(cfg.epochs):
        tl, ta = train_epoch(model, exp["train"], crit, opt, device)
        # NOTE(review): "val" metrics come from the SVHN *test* split, and the
        # best checkpoint is selected on them, so the reported best accuracy
        # is optimistic. Consider holding out part of the train split instead.
        vl, va = eval_epoch(model, exp["test"], crit, device)

        print(f"  ep {epoch+1:02d}/{cfg.epochs} "
              f"train {ta*100:5.1f}%  val {va*100:5.1f}%")
        wandb.log({
            f"exp{k}/train_loss": tl,
            f"exp{k}/train_acc":  ta,
            f"exp{k}/val_loss":   vl,
            f"exp{k}/val_acc":    va,
            "epoch": epoch + 1,
            "expert": k,
        })

        # The first epoch always checkpoints; later ones only on improvement.
        if best is None or va > best:
            best = va
            torch.save(model.state_dict(), ckpt)

    if best is None:
        print(f"✗ Expert {k}: no epochs were run, no checkpoint written")
    else:
        print(f"✓ Expert {k}: best val acc = {best*100:.2f}% (saved {ckpt})")


def main():
    """
    Train one classifier per label group on SVHN and log to Weights & Biases.

    Starts a wandb run, downloads SVHN into ``./SVHN`` if needed, builds a
    label-filtered train/test loader pair for every entry of ``LABEL_GROUPS``
    and trains a fresh model for each. The best-scoring weights of expert
    ``k`` are written to ``expert_{k}_best.pth`` in the working directory.
    """
    wandb.init(
        project=WANDB_PROJECT,
        name="experts-by-label",
        config={**CONFIG, "label_groups": LABEL_GROUPS, "n_experts": N_EXPERTS},
    )
    # Read hyper-parameters back from wandb rather than from CONFIG directly.
    cfg = wandb.config
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device:", device)

    # transforms (RGB mean/std from official SVHN stats)
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4377, 0.4438, 0.4728),
                             std=(0.1980, 0.2010, 0.1970)),
    ])

    train_raw = datasets.SVHN(root="./SVHN", split="train", download=True, transform=tfm)
    test_raw = datasets.SVHN(root="./SVHN", split="test", download=True, transform=tfm)

    experts = _build_expert_loaders(train_raw, test_raw, cfg.batch_size)

    # Experts are numbered from 1 in logs, metric names and checkpoint files.
    for k, exp in enumerate(experts, 1):
        _train_expert(k, exp, cfg, device)

    print("\nAll experts finished. See Weights & Biases for full logs.")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
