#!/usr/bin/env python3
"""
Train a loss-prediction head (LP) using one or more base models (supports single, temperature-scaled, or deep-ensemble inputs).
"""
import argparse
import os
import random
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms as T
from torchvision.datasets import CIFAR10, CIFAR100, StanfordCars
from tqdm import tqdm

# ---------------------------------------------------------------------
# Loss Predictor Definition
# ---------------------------------------------------------------------
class LossPredictor(nn.Module):
    """Small MLP head that regresses a per-sample loss / error score.

    Architecture: two hidden ``Linear -> ReLU -> Dropout`` stages followed by
    a ``Linear`` layer with a single output unit.

    Args:
        feat_dim: size of the feature vector fed to the head.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        dropout: dropout probability applied after each hidden layer.
        include_sep: when True the input carries one extra scalar feature
            (appended by the caller), so the first layer expects
            ``feat_dim + 1`` inputs.
    """

    def __init__(self, feat_dim, hidden1=128, hidden2=64, dropout=0.5, include_sep=False):
        super().__init__()
        self.include_sep = include_sep
        layers = []
        width_in = feat_dim + int(bool(include_sep))
        # Layers are created in the same order as a flat Sequential literal so
        # that seeded initialisation and state_dict keys (net.0, net.3, net.6)
        # stay stable.
        for width_out in (hidden1, hidden2):
            layers.extend([nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(dropout)])
            width_in = width_out
        layers.append(nn.Linear(width_in, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map inputs of shape ``(..., in_features)`` to scores of shape ``(...)``."""
        scores = self.net(x)
        return scores.squeeze(dim=-1)

# ---------------------------------------------------------------------
# Indexed Dataset Wrapper
# ---------------------------------------------------------------------
class IndexedDataset(torch.utils.data.Dataset):
    """Wrap a dataset so every sample also yields its own index.

    ``__getitem__`` returns ``(img, target, idx)``, where ``img`` and
    ``target`` are the first two elements of the wrapped sample.
    """

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Fetch the underlying sample exactly once. Indexing twice would load
        # and (randomly) transform the image twice, doubling the per-sample cost.
        sample = self.ds[idx]
        return sample[0], sample[1], idx

# ---------------------------------------------------------------------
# Utilities: load model with optional abstain class
# ---------------------------------------------------------------------
def get_model(arch, dataset, extra_class=False):
    """Build an untrained ResNet-18 classifier for ``dataset``.

    Args:
        arch: accepted for interface compatibility; currently ignored (a
            ResNet-18 is always built).
        dataset: one of ``"cifar10"``, ``"cifar100"``, ``"stanfordcars"``;
            any other value raises ``KeyError``.
        extra_class: when True, add one output unit on top of the dataset's
            classes (e.g. an abstain class).

    Returns:
        A ``torchvision`` ResNet-18 with randomly initialised weights. For the
        CIFAR datasets the stem is a 3x3 stride-1 conv with no max-pool.
    """
    n_base = {"cifar10": 10, "cifar100": 100, "stanfordcars": 196}[dataset]
    n_out = n_base + 1 if extra_class else n_base
    net = torchvision.models.resnet18(weights=None, num_classes=n_out)
    small_inputs = dataset in ("cifar10", "cifar100")
    if small_inputs:
        # Keep spatial resolution on small images: replace the 7x7/stride-2
        # stem and drop the initial max-pool.
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        net.maxpool = nn.Identity()
    return net


def make_transforms(dataset):
    """Return ``(train_transform, test_transform)`` for ``dataset``.

    Training pipelines apply random crop/flip augmentation; test pipelines are
    deterministic. Both end with ``ToTensor`` + ``Normalize``.

    Raises:
        ValueError: if ``dataset`` is not recognised.
    """
    imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    if dataset in ("cifar10", "cifar100"):
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        train_ops = [T.RandomCrop(32, 4), T.RandomHorizontalFlip()]
        test_ops = []
    elif dataset == "tinyimagenet200":
        mean, std = imagenet_mean, imagenet_std
        train_ops = [T.RandomResizedCrop(64), T.RandomHorizontalFlip()]
        test_ops = [T.Resize(64), T.CenterCrop(64)]
    elif dataset in ("stanfordcars", "camelyon17", "fmow"):
        mean, std = imagenet_mean, imagenet_std
        train_ops = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
        test_ops = [T.Resize(256), T.CenterCrop(224)]
    else:
        raise ValueError(f"Unknown dataset {dataset}")

    def finish(ops):
        # Shared tail: convert to a tensor, then normalise per channel.
        return T.Compose([*ops, T.ToTensor(), T.Normalize(mean, std)])

    return finish(train_ops), finish(test_ops)

# ---------------------------------------------------------------------
# Main training
# ---------------------------------------------------------------------
def _lp_inputs(x, y, base_models, backbones, temperature, include_sep):
    """Run the frozen base model(s) on one batch and build LP inputs/targets.

    The pooled backbone output of each model is computed once and reused for
    its logits via ``model.fc``, rather than running the whole network a
    second time. Features and softmax probabilities are averaged across the
    models. Everything here runs under ``torch.no_grad()``: the base models are
    frozen, so no autograd graph is needed through them.

    Args:
        x: input batch, already on the target device.
        y: integer labels for ``x``.
        base_models: base classifiers (each exposes ``.fc``).
        backbones: per-model ``nn.Sequential`` of every child except ``fc``.
        temperature: logits are divided by this before the softmax.
        include_sep: if True, append the baseline score as an extra feature.

    Returns:
        ``(feats, sep, true)``: LP input features; baseline score
        ``1 - max(mean probs)``; and a float target that is 1.0 where the
        averaged prediction's argmax differs from ``y``, else 0.0.
    """
    with torch.no_grad():
        feats_per_model, probs_per_model = [], []
        for model, backbone in zip(base_models, backbones):
            pooled = backbone(x).flatten(1)
            feats_per_model.append(pooled)
            probs_per_model.append(F.softmax(model.fc(pooled) / temperature, dim=1))
        feats = torch.stack(feats_per_model).mean(0)
        probs = torch.stack(probs_per_model).mean(0)
        # NOTE(review): for SAT models the logits include the extra abstain
        # unit, so argmax/max here range over it too — confirm this is intended.
        sep = 1 - probs.max(dim=1)[0]
        true = (probs.argmax(dim=1) != y).float()
        if include_sep:
            feats = torch.cat([feats, sep.unsqueeze(1)], dim=1)
    return feats, sep, true


def _run_epoch(lp, loader, base_models, backbones, temperature, include_sep,
               device, opt=None, desc=None):
    """Run one pass of ``loader`` and return summed squared errors.

    When ``opt`` is given, the LP head is trained (train mode, gradient steps).
    Otherwise it is only evaluated (eval mode, no grad).

    Returns:
        ``(base_sse, lp_sse)``: summed squared error of the baseline score and
        of the LP prediction against the 0/1 error target.
    """
    training = opt is not None
    lp.train(training)
    base_sse = lp_sse = 0.0
    batches = tqdm(loader, desc=desc) if desc is not None else loader
    for x, y, _ in batches:
        x, y = x.to(device), y.to(device)
        feats, sep, true = _lp_inputs(x, y, base_models, backbones, temperature, include_sep)
        base_sse += F.mse_loss(sep, true, reduction='sum').item()
        with torch.set_grad_enabled(training):
            pred = lp(feats)
            loss = F.mse_loss(pred, true, reduction='sum')
        if training:
            opt.zero_grad()
            loss.backward()
            opt.step()
        lp_sse += loss.item()
    return base_sse, lp_sse


def main():
    """Train an LP head on frozen base model(s) and save metrics and weights.

    Per epoch, the script records the mean squared error of the baseline score
    (``1 - max prob``) and of the LP head against the 0/1 error indicator, on
    both the LP split and the test set. It writes these to a CSV under
    ``./results`` and saves the LP ``state_dict`` under ``<ckpt_dir>/<dataset>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", choices=["cifar10","cifar100","stanfordcars"], required=True)
    parser.add_argument("--arch", choices=["resnet18"], default="resnet18")
    parser.add_argument("--method", choices=["msp","sat","temp","de"], required=True,
                        help="Method: msp, sat, temp (temp-scaled msp), or de (deep ensemble)")
    parser.add_argument("--base_method", choices=["msp","sat"], default=None,
                        help="For de, base method (default msp)")
    parser.add_argument("--ensemble_seeds", type=str, default=None,
                        help="Seeds for ensemble (comma-separated), default single seed")
    parser.add_argument("--temp", type=float, default=1.0,
                        help="Temperature for scaling logits (used if method=temp)")
    parser.add_argument("--seed", type=int, required=True,
                        help="Seed for LP training/data split")
    parser.add_argument("--reduced_train", action="store_true",
                        help="Use reduced train split (keep LP indices)")
    parser.add_argument('--data_dir', required=True,
                        help='Root dir for dataset')
    parser.add_argument('--ckpt_dir', required=True,
                        help='Root dir for checkpoints')
    parser.add_argument("--indices_dir", default="./indices")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--wd", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--hidden1", type=int, default=128)
    parser.add_argument("--hidden2", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--include_sep", action="store_true")
    args = parser.parse_args()

    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed)
    if args.device.startswith("cuda"):
        torch.cuda.manual_seed_all(args.seed)

    # determine base_method: "temp" rescales MSP models; "de" defaults to msp members
    if args.method == "de":
        base_method = args.base_method or "msp"
    elif args.method == "temp":
        base_method = "msp"
    else:
        base_method = args.method

    # ensemble seeds only apply to "de"; otherwise a single model at --seed
    if args.method == "de" and args.ensemble_seeds:
        seeds = [int(s) for s in args.ensemble_seeds.split(",")]
    else:
        seeds = [args.seed]

    suffix = "_reduced" if args.reduced_train else ""
    ckpt_subdir = os.path.join(args.ckpt_dir, args.dataset)

    # load frozen base models; each backbone shares modules with its model
    base_models, backbones = [], []
    for s in seeds:
        path = os.path.join(ckpt_subdir,
            f"{args.arch}_{args.dataset}_{base_method}_seed{s}{suffix}.pt")
        m = get_model(args.arch, args.dataset, extra_class=(base_method == "sat")).to(args.device)
        m.load_state_dict(torch.load(path, map_location=args.device))
        m.eval()
        base_models.append(m)
        # every child except the final fc -> pooled features
        backbones.append(nn.Sequential(*list(m.children())[:-1]))

    # data
    ttrain, ttest = make_transforms(args.dataset)
    if args.dataset == "cifar10":
        ds_full = CIFAR10(args.data_dir, True, transform=ttrain, download=True)
        ds_test_raw = CIFAR10(args.data_dir, False, transform=ttest, download=True)
    elif args.dataset == "cifar100":
        ds_full = CIFAR100(args.data_dir, True, transform=ttrain, download=True)
        ds_test_raw = CIFAR100(args.data_dir, False, transform=ttest, download=True)
    else:  # stanfordcars
        ds_full = StanfordCars(args.data_dir, split="train", transform=ttrain, download=False)
        ds_test_raw = StanfordCars(args.data_dir, split="test", transform=ttest, download=False)
    ds_indexed = IndexedDataset(ds_full)
    ds_test = IndexedDataset(ds_test_raw)

    # load LP indices (held-out part of the train split used to fit the LP)
    idx_file = os.path.join(args.indices_dir,
        f"{args.dataset}_{args.arch}_{base_method}_seed{args.seed}_lp_indices.npy")
    lp_idx = np.load(idx_file)
    if len(lp_idx) == 0:
        raise ValueError(f"No LP indices found in {idx_file}")
    lp_ds = Subset(ds_indexed, lp_idx.tolist())
    lp_loader = DataLoader(lp_ds, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(ds_test, batch_size=args.batch_size, shuffle=False)

    # setup LP head
    feat_dim = base_models[0].fc.in_features
    lp = LossPredictor(feat_dim, hidden1=args.hidden1, hidden2=args.hidden2,
                       dropout=args.dropout, include_sep=args.include_sep).to(args.device)
    opt = optim.Adam(lp.parameters(), lr=args.lr, weight_decay=args.wd)
    sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=2)

    # temperature only applies to the "temp" method; dividing by 1.0 is a no-op
    temperature = args.temp if args.method == "temp" else 1.0

    records = []
    for epoch in range(1, args.epochs + 1):
        tr_base, tr_lp = _run_epoch(lp, lp_loader, base_models, backbones, temperature,
                                    args.include_sep, args.device, opt=opt,
                                    desc=f"Epoch {epoch}")
        btr = tr_base / len(lp_ds); ltr = tr_lp / len(lp_ds); adv_tr = btr - ltr

        te_base, te_lp = _run_epoch(lp, test_loader, base_models, backbones, temperature,
                                    args.include_sep, args.device)
        bte = te_base / len(ds_test); lte = te_lp / len(ds_test); adv_te = bte - lte

        records.append([epoch, btr, ltr, adv_tr, bte, lte, adv_te])
        print(f"Epoch {epoch}: BASE Train={btr:.4f}, LP Train={ltr:.4f}, Adv Train={adv_tr:.4f} | "
              f"BASE Test={bte:.4f}, LP Test={lte:.4f}, Adv Test={adv_te:.4f}")
        # NOTE(review): the LR schedule is driven by the *test* advantage, so
        # test data influences training; consider a held-out validation split.
        sch.step(adv_te)

    # NOTE(review): "de" runs with --base_method sat and msp still share these
    # output names; consider encoding base_method as well.
    os.makedirs("./results", exist_ok=True)
    csv_path = os.path.join("./results",
        f"adv_{args.arch}_{args.dataset}_{args.method}_seed{args.seed}{suffix}.csv")
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "baseline_train", "lp_train", "adv_train", "baseline_test", "lp_test", "adv_test"])
        writer.writerows(records)
    print(f"Saved metrics to {csv_path}")

    # suffix keeps --reduced_train runs from overwriting full-train runs
    outname = os.path.join(ckpt_subdir,
        f"lp_{args.arch}_{args.dataset}_{args.method}_seed{args.seed}{suffix}.pt")
    os.makedirs(os.path.dirname(outname), exist_ok=True)
    torch.save(lp.state_dict(), outname)
    print(f"Saved LP model to {outname}")

if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run LP training.
    main()