import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score

# Hugging Face model id of the pretrained encoder.
BACKBONE = "distilbert-base-uncased"

# In-domain (ID) intents occupy the label ids 0..149.
NUM_ID_LABELS = 150

# In the author's data the largest label id is 150 and there are 151 unique
# labels, so intent == 150 is treated as the OOD/OOS class.
# NOTE(review): this was inferred from label statistics only -- confirm it
# against the dataset's label names.
OOD_LABEL = NUM_ID_LABELS

class BayesianLinear(nn.Module):
    """Linear layer with a factorised Gaussian distribution over its parameters.

    Every weight and bias entry has a learnable mean (``*_mu``) and an
    unconstrained scale parameter (``*_rho``); the standard deviation is
    ``softplus(rho) + 1e-8``. Each call to ``forward`` draws a fresh parameter
    sample (there is no train/eval switch) and also returns the KL divergence
    from the current posterior to a zero-mean Gaussian prior with standard
    deviation ``prior_sigma``.
    """

    def __init__(self, in_features, out_features, prior_sigma=1.0):
        super().__init__()
        self.prior_sigma = prior_sigma

        weight_shape = (out_features, in_features)
        bias_shape = (out_features,)

        # Means start at zero; rho = -5 gives a small initial std (softplus(-5)).
        self.w_mu = nn.Parameter(torch.zeros(*weight_shape))
        self.w_rho = nn.Parameter(torch.full(weight_shape, -5.0))
        self.b_mu = nn.Parameter(torch.zeros(*bias_shape))
        self.b_rho = nn.Parameter(torch.full(bias_shape, -5.0))

    def _sigma(self, rho):
        """Map the unconstrained ``rho`` to a strictly positive std."""
        return F.softplus(rho) + 1e-8

    def forward(self, x):
        """Apply one sampled linear map to ``x``.

        Returns:
            A ``(logits, kl)`` pair: the output of the sampled layer and the
            scalar KL divergence of the posterior from the prior.
        """
        std_w = self._sigma(self.w_rho)
        std_b = self._sigma(self.b_rho)

        # Weight noise is drawn before bias noise (keeps RNG consumption order).
        noise_w = torch.randn_like(self.w_mu)
        noise_b = torch.randn_like(self.b_mu)

        weight = self.w_mu + std_w * noise_w
        bias = self.b_mu + std_b * noise_b
        return F.linear(x, weight, bias), self.kl_div(std_w, std_b)

    def kl_div(self, w_sigma, b_sigma):
        """Return KL(posterior || prior) summed over all weights and biases."""
        prior_var = self.prior_sigma ** 2

        def gaussian_kl(mean, std):
            # Closed-form KL between N(mean, std^2) and N(0, prior_var).
            var = std.pow(2)
            quadratic = (var + mean.pow(2)) / prior_var
            log_ratio = torch.log(prior_var / var)
            return 0.5 * torch.sum(quadratic - 1.0 + log_ratio)

        weight_kl = gaussian_kl(self.w_mu, w_sigma)
        bias_kl = gaussian_kl(self.b_mu, b_sigma)
        return weight_kl + bias_kl

class DistilBERT_BNN(nn.Module):
    """Pretrained encoder followed by a linear projection and a Bayesian head.

    The encoder is loaded from ``BACKBONE``. The hidden state at sequence
    position 0 (the [CLS] position for BERT-style tokenizers) is projected to
    ``proj_dim`` features and classified by a ``BayesianLinear`` layer.
    """

    def __init__(self, num_labels, proj_dim=256, prior_sigma=1.0):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(BACKBONE)
        encoder_width = self.backbone.config.hidden_size
        self.proj = nn.Linear(encoder_width, proj_dim)
        self.bayes = BayesianLinear(proj_dim, num_labels, prior_sigma=prior_sigma)

    def forward(self, input_ids, attention_mask):
        """Return ``(logits, kl)`` as produced by the Bayesian head."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the hidden state of the first token of each sequence.
        first_token = encoded.last_hidden_state[:, 0, :]
        projected = self.proj(first_token)
        return self.bayes(projected)

def set_backbone_trainable(model: "DistilBERT_BNN", trainable: bool, unfreeze_last_n_layers: int = 2):
    """Freeze the backbone, optionally re-enabling its last transformer layers.

    All parameters under ``model.backbone`` are first marked as not requiring
    gradients. When ``trainable`` is true, the final ``unfreeze_last_n_layers``
    entries of ``model.backbone.transformer.layer`` are then re-enabled.

    Args:
        model: Object exposing ``backbone.parameters()`` and
            ``backbone.transformer.layer`` (a sliceable sequence of modules).
        trainable: Whether any backbone layer should receive gradients.
        unfreeze_last_n_layers: Number of trailing layers to unfreeze. Values
            ``<= 0`` leave the whole backbone frozen; values larger than the
            number of layers unfreeze every layer.
    """
    for param in model.backbone.parameters():
        param.requires_grad = False

    # Guard against n <= 0: ``layers[-0:]`` is the *whole* sequence, which
    # would silently unfreeze every layer instead of none.
    if not trainable or unfreeze_last_n_layers <= 0:
        return

    layers = model.backbone.transformer.layer
    for layer in layers[-unfreeze_last_n_layers:]:
        for param in layer.parameters():
            param.requires_grad = True

def beta_warmup(step, warmup_steps):
    """Return the KL weight for ``step`` under a linear warm-up.

    The weight grows as ``step / warmup_steps`` and is capped at 1.0. A
    non-positive ``warmup_steps`` disables the warm-up (weight is always 1.0).
    """
    if warmup_steps <= 0:
        return 1.0
    fraction = step / warmup_steps
    return min(1.0, fraction)

@torch.no_grad()
def mc_probs(model, batch, T=20):
    """Run ``T`` forward passes and return per-pass and averaged probabilities.

    The model is switched to eval mode (and left in it). ``batch`` is passed
    to the model as keyword arguments, and the first element of the model's
    return value is treated as the logits.

    Returns:
        ``(probs, mean_probs)`` where ``probs`` stacks the softmax outputs of
        the ``T`` passes along a new leading dimension and ``mean_probs`` is
        their average over that dimension.
    """
    model.eval()
    samples = torch.stack(
        [torch.softmax(model(**batch)[0], dim=-1) for _ in range(T)],
        dim=0,
    )  # (T, B, C) for (B, C) logits
    return samples, samples.mean(dim=0)

@torch.no_grad()
def uncertainty_scores_from_probs(probs_T, mean_probs):
    """Compute predictive entropy and mutual information as NumPy arrays.

    Args:
        probs_T: Per-pass class probabilities; the first dimension indexes the
            passes and the last dimension indexes the classes.
        mean_probs: ``probs_T`` averaged over its first dimension.

    Returns:
        ``(entropy, mi)``: the entropy of ``mean_probs`` and that entropy minus
        the mean per-pass entropy, both moved to CPU and converted to NumPy.
    """

    def shannon(p):
        # Clamp inside the log only, so zero probabilities contribute 0.
        return -(p * p.clamp_min(1e-9).log()).sum(dim=-1)

    predictive = shannon(mean_probs)           # (B,)
    expected = shannon(probs_T).mean(dim=0)    # (T,B) -> (B,)
    mutual_info = predictive - expected        # (B,)
    return predictive.cpu().numpy(), mutual_info.cpu().numpy()

def _first_present_split(ds, names):
    """Return the first split of ``ds`` whose name appears in ``names``, else None."""
    for name in names:
        if name in ds:
            return ds[name]
    return None


def _resolve_ood_label(split, label_col, fallback):
    """Return the integer id of the out-of-scope class for ``split``.

    If the label column carries class names and one of them is ``"oos"``, its
    index is used; otherwise ``fallback`` is returned unchanged.
    """
    feature = split.features.get(label_col)
    names = getattr(feature, "names", None)
    if names and "oos" in names:
        return names.index("oos")
    return fallback


@torch.no_grad()
def _id_accuracy(model, loader, device, label_lut):
    """Return single-pass accuracy of ``model`` on an in-domain loader.

    ``label_lut`` maps raw dataset label ids to the contiguous class indices
    the model was trained on.
    """
    model.eval()
    y_true, y_pred = [], []
    for batch in loader:
        logits, _ = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
        y_pred.extend(logits.argmax(dim=-1).cpu().tolist())
        y_true.extend(label_lut[batch["labels"].to(device)].cpu().tolist())
    return accuracy_score(y_true, y_pred)


def _collect_uncertainty(model, loader, device, T):
    """Return ``(entropy_scores, mi_scores)`` lists for every example in ``loader``."""
    entropy_scores, mi_scores = [], []
    for batch in loader:
        inputs = {
            "input_ids": batch["input_ids"].to(device),
            "attention_mask": batch["attention_mask"].to(device),
        }
        probs_T, mean_probs = mc_probs(model, inputs, T=T)
        ent, mi = uncertainty_scores_from_probs(probs_T, mean_probs)
        entropy_scores.extend(ent.tolist())
        mi_scores.extend(mi.tolist())
    return entropy_scores, mi_scores


def main():
    """Train the Bayesian-head classifier on in-domain intents and score OOD detection.

    Steps: parse CLI options, load and tokenize the dataset, separate in-domain
    (ID) from out-of-scope (OOD) examples, train with cross-entropy plus a
    warmed-up KL term, report ID validation/test accuracy, and finally report
    AUROC/AUPRC for separating ID test from OOD test examples using predictive
    entropy and mutual information from Monte Carlo forward passes.

    Raises:
        ValueError: If the label column or required splits cannot be found, or
            if a required subset turns out to be empty.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=6)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_backbone", type=float, default=2e-5)
    ap.add_argument("--kl_warmup_frac", type=float, default=0.15)
    ap.add_argument("--freeze_epochs", type=int, default=2)
    ap.add_argument("--T", type=int, default=20)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()
    if args.T < 1:
        # mc_probs stacks T samples; zero samples cannot be stacked.
        ap.error("--T must be at least 1")

    # The out-of-scope examples are expected inside the regular splits, marked
    # by a dedicated class in the label column; a separate OOD split is also
    # accepted below if the loaded dataset provides one.
    ds = load_dataset("clinc_oos", "plus")

    tokenizer = AutoTokenizer.from_pretrained(BACKBONE)

    def tok(batch):
        return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=48)

    ds = ds.map(tok, batched=True)

    # The label column may be called "label" or "intent".
    train_raw = ds["train"]
    if "label" in train_raw.column_names:
        label_col = "label"
    elif "intent" in train_raw.column_names:
        label_col = "intent"
    else:
        raise ValueError(f"Could not find label column in {train_raw.column_names}")

    # Look splits up by key rather than by truthiness (an empty split is falsy).
    val_raw = _first_present_split(ds, ("validation", "val"))
    test_raw = _first_present_split(ds, ("test",))
    if val_raw is None or test_raw is None:
        raise ValueError(f"Expected train/validation/test splits. Available: {list(ds.keys())}")

    # Prefer the id of the class literally named "oos"; fall back to the
    # module-level OOD_LABEL when the column has no class names.
    ood_label = _resolve_ood_label(train_raw, label_col, OOD_LABEL)

    def is_in_domain(example):
        return example[label_col] != ood_label

    id_train = train_raw.filter(is_in_domain)
    id_val = val_raw.filter(is_in_domain)
    id_test = test_raw.filter(is_in_domain)

    ood_test = _first_present_split(ds, ("oos_test", "oos"))
    if ood_test is None:
        ood_test = test_raw.filter(lambda example: example[label_col] == ood_label)

    if len(id_train) == 0 or len(id_val) == 0 or len(id_test) == 0 or len(ood_test) == 0:
        raise ValueError(
            f"Empty subset after splitting on OOD label {ood_label}: "
            f"train={len(id_train)} val={len(id_val)} test={len(id_test)} ood={len(ood_test)}"
        )

    # The ID label ids need not be contiguous once the OOD id is removed, so
    # map them onto 0..num_labels-1 with a lookup table (-1 = not an ID class).
    id_label_ids = sorted({int(v) for v in id_train[label_col]})
    num_labels = len(id_label_ids)
    if num_labels != NUM_ID_LABELS:
        print(f"Note: found {num_labels} ID classes (NUM_ID_LABELS={NUM_ID_LABELS}).")
    max_raw_label = max(int(max(split[label_col])) for split in (train_raw, val_raw, test_raw))
    label_lut = torch.full((max(max_raw_label, int(ood_label)) + 1,), -1, dtype=torch.long)
    label_lut[torch.tensor(id_label_ids)] = torch.arange(num_labels)
    label_lut = label_lut.to(args.device)

    def format_split(split):
        split = split.rename_column(label_col, "labels")
        split.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        return split

    id_train = format_split(id_train)
    id_val = format_split(id_val)
    id_test = format_split(id_test)

    # OOD examples are only scored, never classified, so labels are not needed.
    ood_test.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = torch.utils.data.DataLoader(id_train, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(id_val, batch_size=args.batch_size)
    test_loader = torch.utils.data.DataLoader(id_test, batch_size=args.batch_size)
    ood_loader = torch.utils.data.DataLoader(ood_test, batch_size=args.batch_size)

    model = DistilBERT_BNN(num_labels=num_labels).to(args.device)

    head_params = list(model.proj.parameters()) + list(model.bayes.parameters())
    # Collected before any freezing, so every backbone parameter is registered;
    # frozen ones simply receive no gradient and are skipped by the optimizer.
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]

    # NOTE(review): weight decay also acts on the Bayesian head's mu/rho
    # parameters, on top of the explicit KL term -- confirm this is intended.
    optimizer = torch.optim.AdamW([
        {"params": head_params, "lr": args.lr_head},
        {"params": backbone_params, "lr": args.lr_backbone},
    ], weight_decay=0.01)

    total_steps = args.epochs * len(train_loader)
    kl_warmup_steps = int(args.kl_warmup_frac * total_steps)
    lr_sched = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.06 * total_steps),
        num_training_steps=total_steps,
    )

    dataset_size = len(id_train)
    global_step = 0

    for epoch in range(1, args.epochs + 1):
        model.train()

        # Head-only training for the first epochs, then the last two layers too.
        if epoch <= args.freeze_epochs:
            set_backbone_trainable(model, trainable=False)
        else:
            set_backbone_trainable(model, trainable=True, unfreeze_last_n_layers=2)

        losses = []
        for batch in train_loader:
            batch = {k: v.to(args.device) for k, v in batch.items()}
            logits, kl = model(batch["input_ids"], batch["attention_mask"])
            targets = label_lut[batch["labels"]]
            ce = F.cross_entropy(logits, targets)
            beta = beta_warmup(global_step, kl_warmup_steps)
            # CE is a per-example mean, so the KL is scaled by the dataset size.
            loss = ce + beta * (kl / dataset_size)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_sched.step()

            losses.append(loss.item())
            global_step += 1

        # ID validation accuracy (single stochastic pass, no gradient tracking).
        acc = _id_accuracy(model, val_loader, args.device, label_lut)
        print(f"Epoch {epoch}: train_loss={np.mean(losses):.4f} val_acc={acc:.4f}")

    # Final ID test accuracy
    test_acc = _id_accuracy(model, test_loader, args.device, label_lut)
    print(f"ID Test Accuracy: {test_acc:.4f}")

    # OOD scoring: higher entropy / MI should indicate OOD (positive class = 1).
    id_ent, id_mi = _collect_uncertainty(model, test_loader, args.device, args.T)
    ood_ent, ood_mi = _collect_uncertainty(model, ood_loader, args.device, args.T)

    y = np.array([0] * len(id_ent) + [1] * len(ood_ent), dtype=np.int32)
    s_ent = np.array(id_ent + ood_ent, dtype=np.float64)
    s_mi = np.array(id_mi + ood_mi, dtype=np.float64)

    auroc_ent = roc_auc_score(y, s_ent)
    auprc_ent = average_precision_score(y, s_ent)
    auroc_mi = roc_auc_score(y, s_mi)
    auprc_mi = average_precision_score(y, s_mi)

    print(f"OOD (Entropy) AUROC={auroc_ent:.4f} AUPRC={auprc_ent:.4f}")
    print(f"OOD (MI)      AUROC={auroc_mi:.4f} AUPRC={auprc_mi:.4f}")

# Run the training/evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()