#!/usr/bin/env python3
"""
Train configurable models on CIFAR-10, CIFAR-100, Stanford Cars, Camelyon17-WILDS with literature-recommended hyperparameters,
supporting a single random seed, selective prediction methods (MSP, Self-Adaptive Training),
and optional reduced training split for subsequent LP head training.

Usage examples:
    python train_main.py --arch resnet18 --dataset cifar10 --seed 0 --method msp --reduced_train
    python train_main.py --arch resnet18 --dataset cifar10 --seed 0 --method sat \
        --reduced_train --train_split_ratio 0.8
"""
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision
from torchvision import transforms as T
from torchvision.datasets import CIFAR10, CIFAR100, StanfordCars
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# ---------------------------------------------------------------------
# Self-Adaptive Training criterion
# ---------------------------------------------------------------------
class SelfAdaptiveTraining:
    """Self-Adaptive Training loss for a classifier with a trailing abstention logit.

    For every training example the criterion keeps (on the CPU) a moving
    average of the model's softmax output over the base classes. The loss is a
    soft-label cross-entropy whose target puts the averaged probability of the
    true class on that class and the remaining mass on the last logit.
    """

    def __init__(self, num_examples, num_base_classes, mom=0.9):
        """Allocate per-example state.

        Args:
            num_examples: Number of rows in the history table; `index` values
                passed to `__call__` address these rows.
            num_base_classes: Number of leading logits treated as real classes.
            mom: Weight given to the stored history in the moving average.
        """
        self.mom = mom
        self.num_base = num_base_classes
        # Row i holds the running class probabilities of example i.
        self.prob_history = torch.zeros(num_examples, num_base_classes)
        # 1 once example i has been written to prob_history, else 0.
        self.updated = torch.zeros(num_examples, dtype=torch.int)

    def _update_prob(self, prob, index, y):
        """Blend `prob` into the stored history for `index` and return the blend.

        Examples seen for the first time start from the one-hot encoding of
        their label `y` instead of the (all-zero) stored row.
        """
        device = prob.device
        one_hot_labels = torch.zeros_like(prob).scatter_(1, y.unsqueeze(1), 1.0)
        previous = self.prob_history[index].clone().to(device)
        seen_before = (self.updated[index] == 1).to(device).unsqueeze(-1).expand_as(prob)
        starting_point = torch.where(seen_before, previous, one_hot_labels)
        blended = self.mom * starting_point + (1 - self.mom) * prob
        # The history table lives on the CPU, so move the result back before storing.
        self.prob_history[index] = blended.cpu()
        self.updated[index] = 1
        return blended

    def __call__(self, logits, y, index):
        """Return the mean soft-label cross-entropy for one batch.

        Args:
            logits: Model outputs; the first `num_base` columns are class
                logits and the last column is the abstention logit.
            y: Integer class labels for the batch.
            index: Row positions of the batch examples in the history table.
        """
        # Detached: the moving-average targets must not receive gradients.
        base_prob = F.softmax(logits.detach()[:, :self.num_base], dim=1)
        averaged = self._update_prob(base_prob, index, y)

        rows = torch.arange(y.size(0))
        label_conf = averaged[rows, y]

        target = torch.zeros_like(logits)
        target[rows, y] = label_conf
        target[:, -1] = 1.0 - label_conf
        target = F.normalize(target, p=1, dim=1)

        per_example = -torch.sum(F.log_softmax(logits, dim=1) * target, dim=1)
        return torch.mean(per_example)

# ---------------------------------------------------------------------
# Indexed dataset wrapper
# ---------------------------------------------------------------------
class IndexedDataset(Dataset):
    """Wrap a dataset so each item also carries the index it was fetched with.

    Items of the wrapped dataset may have more than two fields; only the first
    two (image and target) are kept.
    """

    def __init__(self, ds):
        """Store the wrapped dataset `ds`."""
        self.ds = ds

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return `(img, target, idx)` for position `idx`."""
        # Fetch the sample exactly once: indexing the wrapped dataset twice
        # would load and transform the item twice and discard one result.
        sample = self.ds[idx]
        return sample[0], sample[1], idx

# ---------------------------------------------------------------------
# Defaults & hyperparameters
# ---------------------------------------------------------------------
# Default number of training epochs per dataset (used when --epochs is not given).
EPOCHS_DEFAULT = dict(
    cifar10=200,
    cifar100=200,
    tinyimagenet200=200,
    stanfordcars=200,
    camelyon17=10,
    fmow=200,
)

# Per-architecture, per-dataset settings.
# Each value is a (learning rate, weight decay, batch size) tuple.
HPARAMS = dict(
    simple_cnn=dict(
        cifar10=(0.01, 1e-4, 128),
        cifar100=(0.01, 1e-4, 128),
        tinyimagenet200=(0.01, 1e-4, 128),
        stanfordcars=(0.01, 1e-4, 128),
        camelyon17=(0.01, 1e-4, 128),
        fmow=(0.01, 1e-4, 64),
    ),
    resnet18=dict(
        cifar10=(0.1, 5e-4, 128),
        cifar100=(0.1, 5e-4, 128),
        tinyimagenet200=(0.1, 5e-4, 256),
        stanfordcars=(0.01, 5e-4, 128),
        camelyon17=(0.01, 5e-4, 128),
        fmow=(0.01, 5e-4, 64),
    ),
    wideresnet=dict(
        cifar10=(0.1, 5e-4, 128),
        cifar100=(0.1, 5e-4, 128),
        tinyimagenet200=(0.1, 5e-4, 256),
        stanfordcars=(0.01, 5e-4, 128),
        camelyon17=(0.01, 5e-4, 128),
        fmow=(0.01, 5e-4, 64),
    ),
)

# ---------------------------------------------------------------------
# Model definitions
# ---------------------------------------------------------------------
# SimpleCNN omitted for brevity; same as original script

def get_model(arch, dataset, extra_class=False):
    """Build an untrained network for `dataset`.

    Args:
        arch: Architecture name; only "resnet18" is handled in this function.
        dataset: Dataset name, used to look up the number of classes.
        extra_class: If true, add one more output unit after the base classes.

    Returns:
        The constructed model.

    Raises:
        KeyError: If `dataset` is not a known dataset name.
        ValueError: If `arch` is not a known architecture.
    """
    class_counts = {
        "cifar10": 10,
        "cifar100": 100,
        "tinyimagenet200": 200,
        "stanfordcars": 196,
        "camelyon17": 2,
        "fmow": 62,
    }
    n_outputs = class_counts[dataset]
    if extra_class:
        n_outputs += 1

    if arch != "resnet18":
        # other archs...
        raise ValueError(f"Unknown arch {arch}")

    net = torchvision.models.resnet18(weights=None, num_classes=n_outputs)
    if dataset in ("cifar10", "cifar100"):
        # CIFAR variant of the stem: 3x3 stride-1 first conv and no max-pool.
        net.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        net.maxpool = nn.Identity()
    return net

# ---------------------------------------------------------------------
# Transforms and loaders
# ---------------------------------------------------------------------
def make_transforms(dataset):
    """Return the `(train, test)` torchvision transform pipelines for `dataset`.

    Both pipelines end with `ToTensor` followed by `Normalize`; they differ in
    the cropping / resizing / flipping steps placed in front.

    Raises:
        ValueError: If `dataset` is not a known dataset name.
    """
    imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    if dataset in ("cifar10", "cifar100"):
        # NOTE(review): the same statistics are applied to both CIFAR variants.
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        train_steps = [T.RandomCrop(32, 4), T.RandomHorizontalFlip()]
        test_steps = []
    elif dataset == "tinyimagenet200":
        mean, std = imagenet_mean, imagenet_std
        train_steps = [T.RandomResizedCrop(64), T.RandomHorizontalFlip()]
        test_steps = [T.Resize(64), T.CenterCrop(64)]
    elif dataset in ("stanfordcars", "camelyon17", "fmow"):
        mean, std = imagenet_mean, imagenet_std
        train_steps = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
        test_steps = [T.Resize(256), T.CenterCrop(224)]
    else:
        raise ValueError(f"Unknown dataset {dataset}")

    ttrain = T.Compose(train_steps + [T.ToTensor(), T.Normalize(mean, std)])
    ttest = T.Compose(test_steps + [T.ToTensor(), T.Normalize(mean, std)])
    return ttrain, ttest


def load_dataset(dataset, ttrain, ttest, base_dir):
    """Load the train and evaluation splits of `dataset`, wrapped in `IndexedDataset`.

    Args:
        dataset: Dataset name.
        ttrain: Transform applied to the training split.
        ttest: Transform applied to the evaluation split.
        base_dir: Root directory passed to the dataset constructors.

    Returns:
        `(train, test)` as `IndexedDataset` instances. For the WILDS datasets
        the evaluation split is the "val" subset.

    Raises:
        ValueError: If `dataset` is not a known dataset name.
    """
    if dataset in ("cifar10", "cifar100"):
        cifar_cls = CIFAR10 if dataset == "cifar10" else CIFAR100
        train_split = cifar_cls(base_dir, True, transform=ttrain, download=True)
        eval_split = cifar_cls(base_dir, False, transform=ttest, download=True)
    elif dataset == "tinyimagenet200":
        train_split = ImageFolder(os.path.join(base_dir, "tinyimagenet200/train"), ttrain)
        eval_split = ImageFolder(os.path.join(base_dir, "tinyimagenet200/val"), ttest)
    elif dataset in ("camelyon17", "fmow"):
        # Imported lazily so `wilds` is only needed for these two datasets.
        from wilds import get_dataset
        wilds_ds = get_dataset(dataset=dataset, root_dir=base_dir, download=True)
        train_split = wilds_ds.get_subset("train", transform=ttrain)
        eval_split = wilds_ds.get_subset("val", transform=ttest)
    elif dataset == "stanfordcars":
        train_split = StanfordCars(base_dir, split="train", transform=ttrain, download=False)
        eval_split = StanfordCars(base_dir, split="test", transform=ttest, download=False)
    else:
        raise ValueError(f"Unknown dataset {dataset}")

    return IndexedDataset(train_split), IndexedDataset(eval_split)

# ---------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------
def _class_logits(logits, has_abstain):
    """Return the logits used to predict a label, dropping a trailing abstention logit if present."""
    return logits[:, :-1] if has_abstain else logits


def _evaluate(model, loader, device, has_abstain):
    """Return the top-1 accuracy of `model` over `loader` (model is left in eval mode)."""
    model.eval()
    n_correct = n_seen = 0
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            preds = _class_logits(model(x), has_abstain).argmax(dim=1)
            n_correct += (preds == y).sum().item()
            n_seen += y.size(0)
    return n_correct / n_seen


def main():
    """Parse command-line arguments, train a model, and checkpoint its best epoch.

    With ``--reduced_train`` only a seeded random fraction of the training set
    is used, and the held-out indices are saved under ``./indices`` for later
    LP head training. With ``--method sat`` the model gets one extra output
    unit; it is trained with plain cross-entropy on the base classes for the
    first ``--sat_pretrain`` epochs and with `SelfAdaptiveTraining` afterwards.
    The state dict of the epoch with the highest evaluation accuracy is written
    below ``--ckpt_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", choices=EPOCHS_DEFAULT.keys(), default="cifar10")
    parser.add_argument("--arch", choices=HPARAMS.keys(), default="resnet18")
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--method", choices=["msp","sat"], default="msp")
    parser.add_argument("--sat_pretrain", type=int, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument('--data_dir', required=True,
                        help='Root dir for dataset')
    parser.add_argument('--ckpt_dir', required=True,
                        help='Root dir for checkpoints')
    parser.add_argument("--reduced_train", action="store_true",
                        help="Use only a fraction of training set and reserve rest for LP training")
    parser.add_argument("--train_split_ratio", type=float, default=0.8,
                        help="Fraction of data for base-model training when --reduced_train is set")
    args = parser.parse_args()

    # Reject split ratios that cannot yield a usable training subset.
    if args.reduced_train and not 0.0 < args.train_split_ratio <= 1.0:
        parser.error("--train_split_ratio must be in the interval (0, 1]")

    # determine epochs
    if args.epochs is None:
        args.epochs = EPOCHS_DEFAULT[args.dataset]
    if args.method == "sat" and args.sat_pretrain is None:
        args.sat_pretrain = args.epochs // 2

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device.startswith("cuda"):
        torch.cuda.manual_seed_all(args.seed)

    # data
    ttrain, ttest = make_transforms(args.dataset)
    full_train_ds, test_ds = load_dataset(args.dataset, ttrain, ttest, args.data_dir)

    # optional reduced split
    if args.reduced_train:
        total = len(full_train_ds)
        indices = np.arange(total)
        rng = np.random.RandomState(args.seed)
        rng.shuffle(indices)
        split = int(args.train_split_ratio * total)
        if split == 0:
            raise ValueError(
                f"--train_split_ratio {args.train_split_ratio} leaves no training examples "
                f"out of {total}"
            )
        train_idx, lp_idx = indices[:split].tolist(), indices[split:].tolist()
        os.makedirs("./indices", exist_ok=True)
        np.save(f"./indices/{args.dataset}_{args.arch}_{args.method}_seed{args.seed}_lp_indices.npy", np.array(lp_idx))
        # The subset wraps the indexed dataset, so items keep their index in the
        # full training set; the SAT history below is sized accordingly.
        train_ds = Subset(full_train_ds, train_idx)
    else:
        train_ds = full_train_ds

    lr, wd, batch_size = HPARAMS[args.arch][args.dataset]
    # Pinned host memory only helps host-to-GPU copies.
    pin = args.device.startswith("cuda")
    train_loader = DataLoader(train_ds, batch_size=batch_size,
                              shuffle=True, num_workers=args.workers, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False,
                             num_workers=args.workers, pin_memory=pin)

    # model setup
    extra = (args.method == "sat")
    model = get_model(args.arch, args.dataset, extra_class=extra).to(args.device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                          weight_decay=wd, nesterov=True)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

    sat_crit = None
    if extra:
        base_classes = model.fc.out_features - 1
        sat_crit = SelfAdaptiveTraining(len(full_train_ds), base_classes, mom=0.9)

    # checkpoint location is the same for every epoch, so resolve it once
    suffix = "_reduced" if args.reduced_train else ""
    fname = f"{args.arch}_{args.dataset}_{args.method}_seed{args.seed}{suffix}.pt"
    ckpt_path = os.path.join(args.ckpt_dir, args.dataset, fname)
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    best_acc = 0.0
    for epoch in range(1, args.epochs+1):
        model.train()
        correct = total = 0
        for x, y, idx in tqdm(train_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(args.device), y.to(args.device)
            optimizer.zero_grad()
            logits = model(x)
            if not extra:
                loss = F.cross_entropy(logits, y)
            elif epoch <= args.sat_pretrain:
                # Pretraining phase: ordinary cross-entropy on the base classes only.
                loss = F.cross_entropy(logits[:, :-1], y)
            else:
                loss = sat_crit(logits, y, idx)
            loss.backward()
            optimizer.step()
            # Score predictions over the base classes only: the abstention
            # logit is not a label and must not count as a (wrong) prediction.
            preds = _class_logits(logits, extra).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
        train_acc = correct/total

        # validation
        test_acc = _evaluate(model, test_loader, args.device, extra)
        scheduler.step()

        print(f"Epoch {epoch}/{args.epochs} - Train Acc: {train_acc:.4f}, Val Acc: {test_acc:.4f}")

        # checkpoint: always keep the first epoch so a file exists even when
        # accuracy never rises above the initial 0.0
        if epoch == 1 or test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), ckpt_path)

    print(f"Done. Best Validation Acc: {best_acc:.4f}")

# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
