#!/usr/bin/env python3
"""
Train models on the cifar10_n / cifar100_n setups: hold out 10k samples of the
original training set, add the original test set to the training pool, and save
the held-out indices for later evaluation.
"""
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
import torchvision
from torchvision import transforms as T
from torchvision.datasets import CIFAR10, CIFAR100, StanfordCars
from tqdm import tqdm

# ---------------------------------------------------------------------
# Self-Adaptive Training criterion (unchanged)
# ---------------------------------------------------------------------
class SelfAdaptiveTraining:
    """Self-Adaptive Training (SAT) loss for a model with one extra trailing logit.

    Keeps a per-example running average of the model's softmax over the base
    classes and turns it into soft targets. The running probability of the
    labelled class is assigned to that class, and the remaining mass is
    assigned to the last logit.

    Args:
        num_examples: number of training examples; every ``index`` passed to
            ``__call__`` must lie in ``[0, num_examples)``.
        num_base_classes: number of real classes. The logits passed to
            ``__call__`` are expected to have more columns than this, with the
            last column acting as the extra class.
        mom: momentum of the running average (weight given to the history).
    """

    def __init__(self, num_examples, num_base_classes, mom=0.9):
        self.num_base = num_base_classes
        self.mom = mom
        # Running per-example probabilities over the base classes (CPU tensor).
        self.prob_history = torch.zeros(num_examples, num_base_classes)
        # Flag per example: 1 once its history row has been written, else 0.
        self.updated = torch.zeros(num_examples, dtype=torch.int)

    def _update_prob(self, prob, index, y):
        """Blend ``prob`` into the stored history of ``index`` and return the result.

        On an example's first visit the stored row is still unset, so the
        one-hot label is used as the starting point instead.
        """
        dev = prob.device
        label_onehot = torch.zeros_like(prob).scatter_(1, y.unsqueeze(1), 1.0)
        previous = self.prob_history[index].clone().to(dev)
        seen = (self.updated[index] == 1).to(dev).unsqueeze(-1).expand_as(prob)
        anchor = torch.where(seen, previous, label_onehot)
        blended = self.mom * anchor + (1 - self.mom) * prob
        # History is stored on the CPU regardless of the compute device.
        self.prob_history[index] = blended.cpu()
        self.updated[index] = 1
        return blended

    def __call__(self, logits, y, index):
        """Return the batch-mean soft-target cross-entropy.

        Args:
            logits: model outputs; the first ``num_base`` columns are the
                real classes and the last column is the extra class.
            y: integer labels in ``[0, num_base)``.
            index: dataset indices of the batch rows (same device as
                ``prob_history``, i.e. CPU), used to look up the history.
        """
        base_prob = F.softmax(logits.detach()[:, :self.num_base], dim=1)
        smoothed = self._update_prob(base_prob, index, y)
        rows = torch.arange(y.size(0))
        target_mass = smoothed[rows, y]
        soft = torch.zeros_like(logits)
        soft[rows, y] = target_mass
        # Whatever the labelled class does not keep goes to the extra class.
        soft[:, -1] = 1.0 - target_mass
        soft = F.normalize(soft, p=1, dim=1)
        per_example = -torch.sum(F.log_softmax(logits, 1) * soft, dim=1)
        return torch.mean(per_example)

# ---------------------------------------------------------------------
# Indexed dataset wrapper
# ---------------------------------------------------------------------
class IndexedDataset(Dataset):
    """Wrap a dataset so that each item also carries its position in it.

    Each item of the wrapped dataset must be indexable with at least two
    fields ``(input, target, ...)``; only those two are kept, followed by the
    index. Per-example criteria use that index to look up their state.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Fetch the sample once: indexing ``self.ds`` twice would load the
        # image and run its transform twice for every item.
        sample = self.ds[idx]
        return sample[0], sample[1], idx

# ---------------------------------------------------------------------
# Defaults & HPARAMS
# ---------------------------------------------------------------------
# Number of training epochs for each dataset variant.
EPOCHS_DEFAULT = {
    "cifar10_n": 200,
    "cifar100_n": 200,
}

# Per-architecture, per-dataset settings stored as
# (learning_rate, weight_decay, batch_size).
HPARAMS = {
    "simple_cnn": {
        "cifar10_n": (0.01, 1e-4, 128),
        "cifar100_n": (0.01, 1e-4, 128),
    },
    "resnet18": {
        "cifar10_n": (0.1, 5e-4, 128),
        "cifar100_n": (0.1, 5e-4, 128),
    },
    "wideresnet": {
        "cifar10_n": (0.1, 5e-4, 128),
        "cifar100_n": (0.1, 5e-4, 128),
    },
}

from train_main import get_model as base_get_model, make_transforms as base_make_transforms

# ---------------------------------------------------------------------
# Wrapper to handle `_n` variants
# ---------------------------------------------------------------------
def get_model(arch, dataset, extra_class=False):
    """
    Map cifar10_n/cifar100_n back to their base for model instantiation.

    Only a trailing ``_n`` is removed before delegating to
    ``train_main.get_model``; any other name is passed through unchanged.

    Args:
        arch: architecture name understood by ``train_main.get_model``.
        dataset: dataset name, optionally ending in ``_n``.
        extra_class: forwarded unchanged to ``train_main.get_model``.

    Returns:
        Whatever ``train_main.get_model`` returns for the base dataset name.
    """
    # Strip only the suffix: str.replace('_n', '') would also delete an
    # "_n" occurring anywhere else in the name.
    base = dataset[:-len('_n')] if dataset.endswith('_n') else dataset
    return base_get_model(arch, base, extra_class)

# ---------------------------------------------------------------------
# Transforms wrapper
# ---------------------------------------------------------------------
def make_transforms(dataset):
    """
    Map `_n` variants back to base for transform construction.

    Only a trailing ``_n`` is removed before delegating to
    ``train_main.make_transforms``; any other name is passed through unchanged.

    Args:
        dataset: dataset name, optionally ending in ``_n``.

    Returns:
        Whatever ``train_main.make_transforms`` returns for the base name
        (``main`` unpacks it as a ``(train_transform, test_transform)`` pair).
    """
    # Strip only the suffix: str.replace('_n', '') would also delete an
    # "_n" occurring anywhere else in the name.
    base = dataset[:-len('_n')] if dataset.endswith('_n') else dataset
    return base_make_transforms(base)

# ---------------------------------------------------------------------
# Load & split dataset
# ---------------------------------------------------------------------
def load_dataset(dataset, ttrain, ttest, base_dir, seed, holdout_size=10000):
    """Build the training set and the held-out split for cifar10_n / cifar100_n.

    The original training split is shuffled with ``seed``. Its first
    ``holdout_size`` shuffled indices form the held-out set, and the remaining
    indices are concatenated with the original test split to form the
    training set.

    Args:
        dataset: ``'cifar10_n'`` or ``'cifar100_n'``.
        ttrain: transform applied to every training sample, including the
            original test images that are moved into training.
        ttest: transform applied to the held-out samples.
        base_dir: root directory for the torchvision download/cache.
        seed: seed of the shuffle that selects the held-out indices.
        holdout_size: number of original training images to hold out.

    Returns:
        ``(train_ds, test_ds, hold)``: the training and held-out datasets,
        each wrapped in ``IndexedDataset``, and the held-out indices into the
        original training split.

    Raises:
        ValueError: for an unsupported dataset name, or a ``holdout_size``
            outside ``[0, len(train split)]``.
    """
    base = dataset.split('_')[0]
    dataset_cls = {'cifar10': CIFAR10, 'cifar100': CIFAR100}.get(base)
    if dataset_cls is None:
        raise ValueError(f'Unsupported dataset: {dataset!r}')

    full_train = dataset_cls(base_dir, train=True, transform=ttrain, download=True)
    orig_test = dataset_cls(base_dir, train=False, transform=ttrain, download=True)
    # Same images as full_train but with the evaluation transform; previously
    # the held-out subset was drawn from full_train and so used ttrain.
    eval_train = dataset_cls(base_dir, train=True, transform=ttest, download=True)

    n_train = len(full_train)
    if not 0 <= holdout_size <= n_train:
        raise ValueError(f'holdout_size must be in [0, {n_train}], got {holdout_size}')

    # shuffle and split (same RNG usage as before, so indices are unchanged)
    idxs = np.arange(n_train)
    rng = np.random.default_rng(seed)
    rng.shuffle(idxs)
    hold = idxs[:holdout_size]
    remain = idxs[holdout_size:]

    train_ds = ConcatDataset([Subset(full_train, remain), orig_test])
    test_ds = Subset(eval_train, hold)
    return IndexedDataset(train_ds), IndexedDataset(test_ds), hold

# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _evaluate(model, loader, device, drop_extra_class):
    """Return the top-1 accuracy of ``model`` over ``loader`` (0.0 if empty).

    When ``drop_extra_class`` is true, the last logit (the SAT extra class,
    which no label maps to) is excluded from the argmax. Otherwise that logit
    could win the argmax and count as an error.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            if drop_extra_class:
                logits = logits[:, :-1]
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0


def main():
    """Parse CLI arguments, train one model, and keep its best checkpoint.

    Saves the held-out indices to ``./indices/``. Whenever held-out accuracy
    improves, saves the model ``state_dict`` to ``<ckpt_dir>/<dataset>/``.
    With ``--method sat`` the first half of the epochs uses cross-entropy and
    the rest uses ``SelfAdaptiveTraining``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['cifar10_n', 'cifar100_n'], required=True)
    parser.add_argument('--arch', choices=HPARAMS.keys(), required=True)
    parser.add_argument('--seed', type=int, required=True)
    parser.add_argument('--method', choices=['msp', 'sat'], required=True)
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--ckpt_dir', required=True)
    args = parser.parse_args()

    # The seed previously only drove the split; also seed init and shuffling.
    _seed_everything(args.seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    args.epochs = EPOCHS_DEFAULT[args.dataset]
    ttrain, ttest = make_transforms(args.dataset)
    train_ds, test_ds, held_idx = load_dataset(args.dataset, ttrain, ttest, args.data_dir, args.seed)

    # save held-out indices
    out_dir = './indices'
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, f'{args.dataset}_holdout_seed{args.seed}.npy'), held_idx)
    print(f'Held-out indices saved to {out_dir}/{args.dataset}_holdout_seed{args.seed}.npy')

    # HPARAMS entries are (learning_rate, weight_decay, batch_size).
    lr, wd, batch = HPARAMS[args.arch][args.dataset]
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)

    extra = args.method == 'sat'
    model = get_model(args.arch, args.dataset, extra_class=extra).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd, nesterov=True)
    sched = CosineAnnealingLR(optimizer, T_max=args.epochs)
    crit = None
    if extra:
        # NOTE(review): assumes every SAT-capable model exposes its classifier
        # as `.fc`, with the extra class as its last output.
        crit = SelfAdaptiveTraining(len(train_ds), model.fc.out_features - 1)

    sat_warmup = args.epochs // 2  # plain cross-entropy up to this epoch
    ckpt_dir = os.path.join(args.ckpt_dir, args.dataset)
    ckpt_path = os.path.join(ckpt_dir, f'{args.arch}_{args.dataset}_{args.method}_seed{args.seed}.pt')

    best = 0.0
    for e in range(1, args.epochs + 1):
        model.train()
        for x, y, idx in tqdm(train_loader, desc=f'Epoch {e}'):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            if crit is None or e <= sat_warmup:
                loss = F.cross_entropy(logits, y)
            else:
                # idx stays on the CPU: the SAT history tensors live there.
                loss = crit(logits, y, idx)
            loss.backward()
            optimizer.step()

        # NOTE(review): the best checkpoint is selected on the same held-out
        # split whose indices are saved for evaluation, which biases any later
        # evaluation on that split.
        acc = _evaluate(model, test_loader, device, drop_extra_class=extra)
        sched.step()
        print(f'Epoch {e}: Test Acc={acc:.4f}')
        if acc > best:
            best = acc
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
    print(f'Done. Best={best:.4f}')

if __name__ == '__main__':
    main()