import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Dict
import re

class ResponseDataset(Dataset):
    """Row-wise dataset over a response matrix and its split mask.

    Both inputs are converted to float32 tensors.  The split mask is decoded
    with the codes used by the comparisons below: 1 = train, 0 = test,
    2 = validation.  A response equal to -1 is treated as missing.

    Each item is one row of the matrix: ``(responses, usable_mask, idx)``,
    where ``usable_mask`` is a boolean tensor selecting the cells of that row
    that are both assigned to the training split and not missing.
    """

    def __init__(self, response_data: np.ndarray, train_mask: np.ndarray):
        self.data = torch.tensor(response_data, dtype=torch.float32)
        self.train_mask = torch.tensor(train_mask, dtype=torch.float32)
        self.train_indices = self.train_mask == 1
        self.test_indices = self.train_mask == 0
        self.val_indices = self.train_mask == 2
        # -1 is the missing-response sentinel.
        self.valid_data = self.data != -1

    def __len__(self):
        """Return the number of rows of the response matrix."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(row, usable_mask, idx)`` for row ``idx``."""
        # Missing cells must never reach the likelihood: a -1 "response"
        # would otherwise be scored as if it were a real observation.
        usable = self.train_indices[idx] & self.valid_data[idx]
        return self.data[idx], usable, idx

class Standard3PLIRT(nn.Module):
    def __init__(
        self,
        response_data: np.ndarray,
        train_mask: np.ndarray,
        student_names: List[str],
        test_names:    List[str],
        split_difficulty: bool = True,
        split_ability:    bool = True,
        lr: float = 1e-3,
        batch_size: int = 64,
        max_epochs:  int = 1000,
        device: str = None,
        eps:   float = 1e-3,
        embedding_dim: int = 32,
        hidden_sizes:  List[int] = [128,64,32],
        theta_init:    float = 0.5,
        theta_max:     float = 1.0,
        difficulty_base_init:  float = 0.0,
        difficulty_base_max:   float = 3.0,
        difficulty_base_min:   float = 0.0,
        difficulty_other_init: float = 0.0,
        difficulty_other_max:  float = 1.5,
        difficulty_other_min:  float = 0,
        a_init:   float = 0.2,
        a_scale:  float = 3.0,
        c_init:   float = 0.25,
        enable_abs_clamp: bool = True,
        use_guessing:    bool = False
    ):
        """Build the item/ability parameters, the dataset and the optimizer.

        Args:
            response_data: Response matrix; rows are indexed by
                ``student_names`` and columns by ``test_names``.  -1 marks a
                missing response.
            train_mask: Array compared against 1 (train), 0 (test) and
                2 (validation) to assign each cell to a split.
            student_names: One name per row of ``response_data``.
            test_names: One name per column.  A leading ``<no_question>``,
                ``<no_image>`` or ``<no_info>`` tag marks a variant of a base
                problem.
            split_difficulty: If true, variants of the same base problem
                share one parameter slot and difficulty is decomposed into
                base/text/image/synergy parts; otherwise each test name has
                its own slot and a single difficulty.
            split_ability: If true, every student gets four ability
                components; otherwise one.
            lr: Learning rate of the Adam optimizer.
            batch_size: Number of rows per DataLoader batch.
            max_epochs: Stored for use by the training loop.
            device: Not used by this constructor; the model is always placed
                on the CPU.
            eps: Stored as ``self.eps``.
            embedding_dim: Not used by this constructor.
            hidden_sizes: Not used by this constructor (never read or
                mutated, so the shared list default is harmless).
            theta_init, theta_max: Initial value and stored upper bound of
                the ability parameters.
            difficulty_base_init, difficulty_base_max, difficulty_base_min:
                Initial value and stored bounds of the base difficulty.
            difficulty_other_init, difficulty_other_max, difficulty_other_min:
                Initial value and stored bounds of the text/image/synergy
                difficulty parts.
            a_init, a_scale: Initial value and stored upper bound of the
                discrimination parameters.
            c_init: Initial value of the guessing parameter.
            enable_abs_clamp: Stored as ``self.enable_abs_clamp``.
            use_guessing: Stored as ``self.use_guessing``.
        """
        super().__init__()

        self.device = torch.device("cpu")
        self.split_difficulty = split_difficulty
        self.split_ability    = split_ability
        self.enable_abs_clamp = enable_abs_clamp
        self.use_guessing     = use_guessing
        self.eps              = eps

        # Report how the observed (non-missing) cells are distributed over
        # the three splits.
        valid_data = response_data != -1
        total_valid = valid_data.sum()
        total_train = np.logical_and(valid_data, train_mask == 1).sum()
        total_test  = np.logical_and(valid_data, train_mask == 0).sum()
        total_val   = np.logical_and(valid_data, train_mask == 2).sum()
        if total_valid:
            print(f"Train Ratio: {total_train/total_valid*100:.1f}%  Test Ratio {total_test/total_valid*100:.1f}%  Val Ratio: {total_val/total_valid*100:.1f}%")
        print("I'm MMIRT.")

        self.data        = response_data
        self.train_mask  = train_mask
        self.num_students= response_data.shape[0]
        self.num_items   = response_data.shape[1]
        self.student_names = student_names
        self.test_names    = test_names
        self.a_scale     = a_scale
        self.theta_max   = theta_max
        self.difficulty_base_max   = difficulty_base_max
        self.difficulty_base_min   = difficulty_base_min
        self.difficulty_other_max  = difficulty_other_max
        self.difficulty_other_min  = difficulty_other_min
        self.a_init      = a_init
        self.c_init      = c_init
        self.difficulty_base_init = difficulty_base_init
        self.difficulty_other_init = difficulty_other_init
        self.theta_init  = theta_init

        # Map every test name to the base problem whose parameters it uses.
        if split_difficulty:
            variant_prefix = re.compile(r"^(?:<no_question>|<no_image>|<no_info>)")
            self.problem_base_names = []
            self.problem_to_base = {}
            seen_bases = set()
            for name in test_names:
                base = variant_prefix.sub("", name)
                if base not in seen_bases:
                    seen_bases.add(base)
                    self.problem_base_names.append(base)
                self.problem_to_base[name] = base
        else:
            self.problem_base_names = test_names[:]
            self.problem_to_base = {n:n for n in test_names}
        self.num_unique_problems = len(self.problem_base_names)

        def _item_param(value):
            # One learnable scalar per parameter slot, all starting at `value`.
            return nn.Parameter(torch.full((self.num_unique_problems,), value, device=self.device))

        self.a_base_raw    = _item_param(a_init)
        self.a_text_raw    = _item_param(a_init)
        self.a_image_raw   = _item_param(a_init)
        self.a_synergy_raw = _item_param(a_init)
        # guessing
        self.c_raw         = _item_param(c_init)

        if split_difficulty:
            self.b_base_raw    = _item_param(difficulty_base_init)
            self.b_text_raw    = _item_param(difficulty_other_init)
            self.b_image_raw   = _item_param(difficulty_other_init)
            self.b_synergy_raw = _item_param(difficulty_other_init)
        else:
            self.b_raw         = _item_param(difficulty_base_init)

        # Position of the first occurrence of each base name; this matches
        # what list.index() returns but costs O(1) per lookup.
        base_position = {}
        for pos, base in enumerate(self.problem_base_names):
            base_position.setdefault(base, pos)
        self.param_indices = torch.tensor(
            [base_position[self.problem_to_base[n]] for n in test_names],
            device=self.device, dtype=torch.long
        )

        output_dim = 4 if split_ability else 1
        self.theta = nn.Parameter(torch.full((self.num_students, output_dim), theta_init, device=self.device))

        # Per-item indicator pair (text, image) derived from the name tags.
        tags = []
        for n in test_names:
            if "<no_info>" in n:
                tags.append([0.0,0.0])
            elif "<no_image>" in n:
                tags.append([1.0,0.0])
            elif "<no_question>" in n:
                tags.append([0.0,1.0])
            else:
                tags.append([1.0,1.0])
        self.register_buffer("mask", torch.tensor(tags, dtype=torch.float32))

        # --- Optim / DataLoader ---
        self.to(self.device)
        self.dataset   = ResponseDataset(response_data, train_mask)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.max_epochs = max_epochs
        self.loss_history = []

    def get_parameters(self):
        if self.split_difficulty:
            a_base    = self.a_base_raw[self.param_indices]
            a_text    = self.a_text_raw[self.param_indices]
            a_image   = self.a_image_raw[self.param_indices]
            a_synergy = self.a_synergy_raw[self.param_indices]
            c         = self.c_raw[self.param_indices]
            b_base    = self.b_base_raw[self.param_indices]
            b_text    = self.b_text_raw[self.param_indices]
            b_image   = self.b_image_raw[self.param_indices]
            b_synergy = self.b_synergy_raw[self.param_indices]
            return a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c
        else:
            a_base    = self.a_base_raw[self.param_indices]
            a_text    = self.a_text_raw[self.param_indices]
            a_image   = self.a_image_raw[self.param_indices]
            a_synergy = self.a_synergy_raw[self.param_indices]
            c         = self.c_raw[self.param_indices]
            b         = self.b_raw[self.param_indices]
            return a_base, a_text, a_image, a_synergy, b, c

    def compute_b_full(self, *args):
        """b_full = b_base + r_text*b_text + r_image*b_image + r_text*r_image*b_synergy"""
        if self.split_difficulty:
            b_base, b_text, b_image, b_synergy = args
            r_text  = self.mask[:,0]
            r_image = self.mask[:,1]
            return (b_base
                    + r_text*r_text * b_text  # r_text*b_text
                    + r_image*r_image*b_image # r_image*b_image
                    + r_text*r_image*b_synergy)
        else:
            return args[0]

    def forward(self, responses, response_mask, student_idx):
        if self.split_ability:
            th = self.theta[student_idx]  # (batch,4)
            theta_base    = th[:,0].unsqueeze(1)
            theta_text    = th[:,1].unsqueeze(1)
            theta_image   = th[:,2].unsqueeze(1)
            theta_synergy = th[:,3].unsqueeze(1)
        else:
            th_full = self.theta[student_idx].squeeze(-1).unsqueeze(1)  # (batch,1)

        params = self.get_parameters()
        if self.split_ability:
            a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = params
            b_full = self.compute_b_full(b_base, b_text, b_image, b_synergy).unsqueeze(0)
        else:
            a_base, a_text, a_image, a_synergy, b, c = params
            b_full = self.compute_b_full(b).unsqueeze(0)

        r_text  = self.mask[:,0].unsqueeze(0)  # (1,num_items)
        r_image = self.mask[:,1].unsqueeze(0)

        if self.split_ability:
            a_full = (a_base.unsqueeze(0)
                      + r_text * a_text.unsqueeze(0)
                      + r_image * a_image.unsqueeze(0)
                      + r_text*r_image * a_synergy.unsqueeze(0))
            theta_full = (theta_base
                      + r_text * theta_text
                      + r_image * theta_image
                      + r_text*r_image * theta_synergy)
            linear = a_full * theta_full - b_full  
        else:
            a_full = (a_base
                      + r_text.squeeze(0)*a_text
                      + r_image.squeeze(0)*a_image
                      + (r_text*r_image).squeeze(0)*a_synergy).unsqueeze(0)
            linear = a_full * th_full - b_full

        linear = torch.clamp(linear, -15, 15)
        if self.use_guessing:
            p = c.unsqueeze(0) + (1 - c.unsqueeze(0)) / (1 + torch.exp(-linear))
        else:
            p = 1 / (1 + torch.exp(-linear))
        p = torch.clamp(p, self.eps, 1-self.eps)

        valid = response_mask
        rp = p[valid]
        rr = responses[valid]
        if rr.numel() > 0:
            nll = -(rr*torch.log(rp) + (1-rr)*torch.log(1-rp)).mean()
        else:
            nll = torch.tensor(0.0, device=self.device)
        return nll

    def fit(self):
        best_loss = float("inf")
        best_epoch = 0
        patience = 100
        for epoch in range(self.max_epochs):
            epoch_loss = 0.0
            cnt = 0
            for resp, mask, idx in self.dataloader:
                resp = resp.to(self.device)
                mask = mask.to(self.device)
                idx  = idx.to(self.device)
                self.optimizer.zero_grad()
                loss = self.forward(resp, mask, idx)
                if loss.item() > 0:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                    self.optimizer.step()
                    self._clamp_params()
                    epoch_loss += loss.item()
                    cnt += 1
            if cnt > 0:
                avg = epoch_loss/cnt
                self.loss_history.append(avg)
                if avg < best_loss:
                    best_loss = avg; best_epoch = epoch
                if epoch - best_epoch >= patience:
                    print(f"Early stopping at epoch {epoch+1}, best loss {best_loss:.4f}")
                    break
                if (epoch+1)%100==0 or epoch==0:
                    print(f"Epoch {epoch+1}/{self.max_epochs}  Loss={avg:.4f}")
        return self.get_estimates()

    def _clamp_params(self):
        with torch.no_grad():
            # theta: [0, theta_max]
            self.theta.data.clamp_(0, self.theta_max)
            self.theta.data[:, 0].clamp_(0, 0.1)
            for p in [self.a_base_raw, self.a_text_raw, self.a_image_raw, self.a_synergy_raw]:
                p.data.clamp_(1e-4, self.a_scale)
            # c: [1e-4, 0.5]
            self.c_raw.data.clamp_(1e-4, 0.5)
            if self.split_difficulty:
                self.b_base_raw.data.clamp_(self.difficulty_base_min, self.difficulty_base_max)
                for p in [self.b_text_raw, self.b_image_raw, self.b_synergy_raw]:
                    p.data.clamp_(-self.difficulty_other_max, 0.0)
            else:
                self.b_raw.data.clamp_(0.0, self.difficulty_base_max)

    def get_estimates(self) -> Dict:
        theta_np = self.theta.detach().cpu().numpy()
        if self.split_ability:
            theta_dict = {
                self.student_names[i]: {
                    "theta_base":    theta_np[i,0],
                    "theta_text":    theta_np[i,1],
                    "theta_image":   theta_np[i,2],
                    "theta_synergy": theta_np[i,3],
                }
                for i in range(self.num_students)
            }
        else:
            theta_dict = {
                self.student_names[i]: {"theta": float(theta_np[i,0])}
                for i in range(self.num_students)
            }

        params = self.get_parameters()
        res = {
            "theta": theta_dict,
        }
        if self.split_difficulty:
            a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = [p.detach().cpu().numpy() for p in params]
            b_full = self.compute_b_full(
                torch.tensor(b_base), torch.tensor(b_text),
                torch.tensor(b_image), torch.tensor(b_synergy)
            ).numpy()
        else:
            a_base, a_text, a_image, a_synergy, b, c = [p.detach().cpu().numpy() for p in params]
            b_full = b

        res.update({
            "discrimination_base":   dict(zip(self.test_names, a_base)),
            "discrimination_text":   dict(zip(self.test_names, a_text)),
            "discrimination_image":  dict(zip(self.test_names, a_image)),
            "discrimination_synergy":dict(zip(self.test_names, a_synergy)),
            "difficulty_full":       dict(zip(self.test_names, b_full)),
            "guessing":              dict(zip(self.test_names, c)),
        })
        return res

    def compute_item_fisher_information(self, model_name: str) -> Dict[str, np.ndarray]:
        self.eval()
        with torch.no_grad():
            fisher = {}
            theta_all = self.theta.detach()
            for j, name in enumerate(self.test_names):
                r_text  = self.mask[j,0].item()
                r_image = self.mask[j,1].item()
                idx     = self.param_indices[j].item()
                a_base    = self.a_base_raw[idx].item()
                a_text    = self.a_text_raw[idx].item()
                a_image   = self.a_image_raw[idx].item()
                a_synergy = self.a_synergy_raw[idx].item()
                if self.split_difficulty:
                    b_base  = self.b_base_raw[idx].item()
                    b_text  = self.b_text_raw[idx].item()
                    b_image = self.b_image_raw[idx].item()
                    b_synergy = self.b_synergy_raw[idx].item()
                    b_full = b_base + r_text*b_text + r_image*b_image + r_text*r_image*b_synergy
                else:
                    b_full = self.b_raw[idx].item()
                # a_full, theta_full
                a_full = a_base + r_text*a_text + r_image*a_image + r_text*r_image*a_synergy
                theta_vals = theta_all[:,0] if not self.split_ability else (
                    theta_all[:,0] + r_text*theta_all[:,1] + r_image*theta_all[:,2] + r_text*r_image*theta_all[:,3]
                )
                # p_i(theta)
                z = a_full * (theta_vals) - b_full
                z = torch.clamp(z, -15, 15)
                if self.use_guessing:
                    c = self.c_raw[idx].item()
                    p = c + (1-c)/(1+torch.exp(-z))
                else:
                    p = 1/(1+torch.exp(-z))
                p = torch.clamp(p, self.eps, 1-self.eps)
                # fisher = a_full^2 * p*(1-p)
                info = (a_full**2) * p*(1-p)
                fisher[name] = info.mean().cpu().numpy() 
        self.train()
        return fisher
    
    def evaluate_validation(self) -> Dict[str, float]:
        """Score the model on the observed cells of the validation split.

        Returns:
            The metric dictionary described in ``_evaluate_split``; an empty
            dict if the validation split has no observed cell.
        """
        return self._evaluate_split(self.dataset.val_indices, report_counts=False)

    def evaluate_predictions(self) -> Dict[str, float]:
        """Score the model on the observed cells of the test split.

        Also prints how many evaluated cells belong to "shuffle" items and
        how many to the remaining items.

        Returns:
            The metric dictionary described in ``_evaluate_split``; an empty
            dict if the test split has no observed cell.
        """
        return self._evaluate_split(self.dataset.test_indices, report_counts=True)

    def _evaluate_split(self, split_indices, report_counts: bool = False) -> Dict[str, float]:
        """Compute classification metrics on one split of the response matrix.

        The success probability of every evaluated (student, item) cell is
        computed with the same model as the training loss:
        ``sigmoid(a_full * theta_full - b_full)``, optionally lifted by the
        guessing parameter, with the logit clamped to [-15, 15] and the
        probability clamped to [eps, 1 - eps].

        Args:
            split_indices: Boolean tensor over the response matrix selecting
                the split; it is intersected with the dataset's
                non-missing mask.
            report_counts: If true, print the number of evaluated cells in
                the "shuffle" and "normal" item groups.

        Returns:
            Dict with ``roc_auc``, ``accuracy``, ``precision``, ``recall``,
            ``f1``, ``perplexity`` (exp of the mean negative
            log-likelihood), ``roc_auc_shuffle`` (items whose name contains
            "shuffle" or "from") and ``roc_auc_normal`` (all other items).
            A group AUC is 0.0 when the group is empty, and any AUC is NaN
            when its labels contain a single class.  Returns ``{}`` when no
            cell is selected.  The model is put back into training mode on
            every exit path.
        """
        from sklearn.metrics import (
            roc_auc_score, accuracy_score,
            precision_score, recall_score, f1_score
        )

        def _auc(labels, scores):
            # ROC AUC is undefined when only one class is present.
            if np.unique(labels).size < 2:
                return float("nan")
            return roc_auc_score(labels, scores)

        self.eval()
        try:
            with torch.no_grad():
                # nonzero() lists the selected cells in row-major order.
                cells = torch.nonzero(split_indices & self.dataset.valid_data, as_tuple=False)
                if cells.shape[0] == 0:
                    return {}

                rows_np = cells[:, 0].cpu().numpy()
                cols_np = cells[:, 1].cpu().numpy()
                student_indices = cells[:, 0].to(self.device)
                item_indices    = cells[:, 1].to(self.device)

                masks   = self.mask[item_indices]
                r_text  = masks[:, 0]
                r_image = masks[:, 1]
                r_both  = r_text * r_image

                # get_parameters() is laid out as: four discriminations,
                # the difficulty part(s), then the guessing parameter.
                params = self.get_parameters()
                a_base, a_text, a_image, a_synergy = (p[item_indices] for p in params[:4])
                c = params[-1][item_indices]
                a_full = a_base + r_text*a_text + r_image*a_image + r_both*a_synergy

                if self.split_difficulty:
                    b_base, b_text, b_image, b_synergy = (p[item_indices] for p in params[4:8])
                    b_full = b_base + r_text*b_text + r_image*b_image + r_both*b_synergy
                else:
                    b_full = params[4][item_indices]

                theta_rows = self.theta[student_indices]
                if self.split_ability:
                    theta_full = (theta_rows[:, 0]
                                  + r_text*theta_rows[:, 1]
                                  + r_image*theta_rows[:, 2]
                                  + r_both*theta_rows[:, 3])
                else:
                    theta_full = theta_rows[:, 0]

                # clamp, sigmoid, clamp
                linear = torch.clamp(a_full * theta_full - b_full, -15, 15)
                probs = torch.sigmoid(linear)
                if self.use_guessing:
                    probs = c + (1 - c) * probs
                probs = torch.clamp(probs, self.eps, 1 - self.eps)

                true_np = np.asarray(self.data)[rows_np, cols_np]
                true_responses = torch.as_tensor(true_np, dtype=probs.dtype, device=self.device)
                probs_np = probs.cpu().numpy()
                pred_np  = (probs > 0.5).float().cpu().numpy()

                item_names = [self.test_names[j] for j in cols_np]
                shuffle_mask = np.array([("shuffle" in n) or ("from" in n) for n in item_names])
                normal_mask  = ~shuffle_mask
                if report_counts:
                    num_shuffle = int(shuffle_mask.sum())
                    num_normal = int(normal_mask.sum())
                    print(f"Number of shuffle items: {num_shuffle}")
                    print(f"Number of normal items: {num_normal}")

                metrics = {
                    "roc_auc":    _auc(true_np, probs_np),
                    "accuracy":   accuracy_score(true_np, pred_np),
                    "precision":  precision_score(true_np, pred_np),
                    "recall":     recall_score(true_np, pred_np),
                    "f1":         f1_score(true_np, pred_np),
                }
                # perplexity
                nll = -(true_responses * torch.log(probs) + (1 - true_responses) * torch.log(1 - probs)).mean()
                metrics["perplexity"] = torch.exp(nll).item()

                for key, group in (("roc_auc_shuffle", shuffle_mask), ("roc_auc_normal", normal_mask)):
                    if group.any():
                        metrics[key] = _auc(true_np[group], probs_np[group])
                    else:
                        metrics[key] = 0.0
                return metrics
        finally:
            self.train()


    def update_single_theta(
        self,
        model_name: str,
        problem_names: List[str],
        lr: float = 1e-3,
        max_epochs: int = 1000,
        patience: int = 50
    ) -> Dict:
        if self.split_difficulty:
            selected_indices = [
                i for i, name in enumerate(self.test_names)
                if self.problem_to_base[name] in problem_names
            ]
        else:
            selected_indices = [
                i for i, name in enumerate(self.test_names)
                if name in problem_names
            ]
        if not selected_indices:
            print("Input problem names are not found in the model.")
            return self.get_estimates()

        sel_idx = torch.tensor(selected_indices, dtype=torch.long, device=self.device)
        optimizer = torch.optim.Adam([self.theta], lr=lr)

        best_loss = float("inf")
        epochs_no_improve = 0

        self.train()
        for epoch in range(max_epochs):
            epoch_loss = 0.0
            batch_count = 0

            for responses, response_mask, student_idx in self.dataloader:
                responses = responses.to(self.device)
                response_mask = response_mask.to(self.device)
                student_idx = student_idx.to(self.device)

                resp_sub = responses[:, sel_idx]
                mask_sub = response_mask[:, sel_idx]

                
                if self.split_ability:
                    theta_batch = self.theta[student_idx]  # (batch,4)
                else:
                    theta_batch = self.theta[student_idx].squeeze(-1).unsqueeze(1)  # (batch,1)

                
                params = self.get_parameters()
                if self.split_ability:
                    a_base_all, a_text_all, a_image_all, a_synergy_all, \
                    b_base_all, b_text_all, b_image_all, b_synergy_all, c_all = params

                    a_base    = a_base_all[sel_idx]
                    a_text    = a_text_all[sel_idx]
                    a_image   = a_image_all[sel_idx]
                    a_synergy = a_synergy_all[sel_idx]
                    b_base    = b_base_all[sel_idx]
                    b_text    = b_text_all[sel_idx]
                    b_image   = b_image_all[sel_idx]
                    b_synergy = b_synergy_all[sel_idx]
                    c_param   = c_all[sel_idx]
                else:
                    a_base_all, a_text_all, a_image_all, a_synergy_all, b_all, c_all = params
                    a_base  = a_base_all[sel_idx]
                    c_param = c_all[sel_idx]
                    if self.split_difficulty:
                        b_base    = self.b_base_raw[sel_idx]
                        b_text    = self.b_text_raw[sel_idx]
                        b_image   = self.b_image_raw[sel_idx]
                        b_synergy = self.b_synergy_raw[sel_idx]
                    else:
                        b        = b_all[sel_idx]

                
                tags = self.mask[sel_idx]  # (n_selected,2)
                r_text  = tags[:,0].unsqueeze(0)  # (1,n)
                r_image = tags[:,1].unsqueeze(0)

                
                if self.split_ability:
                    
                    tb = theta_batch[:,0].unsqueeze(1)
                    tt = theta_batch[:,1].unsqueeze(1)
                    ti = theta_batch[:,2].unsqueeze(1)
                    ts = theta_batch[:,3].unsqueeze(1)
                    # b_full, a_full, theta_full
                    b_full = b_base.unsqueeze(0) + r_text*b_text.unsqueeze(0) + r_image*b_image.unsqueeze(0) + r_text*r_image*b_synergy.unsqueeze(0)
                    a_full = a_base.unsqueeze(0) + r_text*a_text.unsqueeze(0) + r_image*a_image.unsqueeze(0) + r_text*r_image*a_synergy.unsqueeze(0)
                    theta_full = tb + r_text*tt + r_image*ti + r_text*r_image*ts
                    linear = a_full * theta_full - b_full
                else:
                    if self.split_difficulty:
                        b_full = b_base.unsqueeze(0) + r_text*b_text.unsqueeze(0) + r_image*b_image.unsqueeze(0) + r_text*r_image*b_synergy.unsqueeze(0)
                    else:
                        b_full = b.unsqueeze(0)
                    a_full = a_base.unsqueeze(0) + r_text*a_text.unsqueeze(0) + r_image*a_image.unsqueeze(0) + r_text*r_image*a_synergy.unsqueeze(0)
                    theta_full = theta_batch
                    linear = a_full * (theta_full) - b_full

                
                linear = torch.clamp(linear, -15, 15)
                if self.use_guessing:
                    p = c_param.unsqueeze(0) + (1 - c_param.unsqueeze(0)) / (1 + torch.exp(-linear))
                else:
                    p = 1 / (1 + torch.exp(-linear))
                p = torch.clamp(p, self.eps, 1 - self.eps)

                
                valid = mask_sub.bool()
                if valid.any():
                    rp = p[valid]
                    rr = resp_sub[valid]
                    loss = -(rr * torch.log(rp) + (1 - rr) * torch.log(1 - rp)).mean()
                else:
                    continue

                
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_([self.theta], 1.0)
                optimizer.step()

                self._clamp_params()

                epoch_loss += loss.item()
                batch_count += 1

            #
            if batch_count > 0:
                avg_loss = epoch_loss / batch_count
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"[update_single_theta] Early stopping at epoch {epoch+1}")
                break

        self.train()
        return self.get_estimates()
