import re
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# import mlflow

class ResponseDataset(Dataset):
    """Per-student view of a response matrix together with its split assignment.

    ``response_data`` holds one row per student and one column per item; ``-1``
    marks a missing response. ``train_mask`` has the same shape and encodes the
    split of each cell: ``1`` = train, ``0`` = test, ``2`` = validation.
    """

    def __init__(self, response_data: np.ndarray, train_mask: np.ndarray):
        self.data = torch.tensor(response_data, dtype=torch.float32)
        self.train_mask = torch.tensor(train_mask, dtype=torch.float32)

        # Boolean masks (same shape as the data) for each split.
        self.train_indices = self.train_mask == 1
        self.test_indices = self.train_mask == 0
        self.val_indices = self.train_mask == 2

        # Cells that actually hold a response.
        self.valid_data = self.data != -1

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(responses, loss_mask, idx)`` for student row ``idx``.

        ``loss_mask`` selects cells that are in the training split AND hold a
        real response; ``-1`` placeholders are not valid 0/1 targets and would
        otherwise be fed into the cross-entropy loss.
        """
        responses = self.data[idx]
        train_mask = self.train_indices[idx] & self.valid_data[idx]
        return responses, train_mask, idx

class Standard3PLIRT(nn.Module):
    def __init__(
        self,
        response_data: np.ndarray,
        train_mask: np.ndarray,
        student_names: List[str],
        test_names: List[str],
        split_difficulty: bool = True,
        split_ability: bool = True,
        lr: float = 1e-3,
        batch_size: int = 64,
        max_epochs: int = 1000,
        device: Optional[str] = None,
        eps: float = 1e-3,
        embedding_dim: int = 32,
        hidden_sizes: Optional[List[int]] = None,
        theta_init: float = 0.5,
        theta_max: float = 1.0,
        difficulty_base_init: float = 0.0,
        difficulty_base_max: float = 3.0,
        difficulty_base_min: float = 0.0,
        difficulty_other_init: float = 0.0,
        difficulty_other_max: float = 1.5,
        difficulty_other_min: float = 0,
        a_init: float = 0.2,
        a_scale: float = 3.0,
        c_init: float = 0.25,
        enable_abs_clamp: bool = True,
        use_guessing: bool = False
    ):
        """Build the IRT model's parameters, data loader and optimizer.

        Args:
            response_data: (num_students, num_items) response matrix; ``-1``
                marks a missing response.
            train_mask: Same shape as ``response_data``; 1 = train, 0 = test,
                2 = validation.
            student_names: One name per row of ``response_data``.
            test_names: One name per column of ``response_data``. A leading
                ``<no_question>``, ``<no_image>`` or ``<no_info>`` tag marks a
                variant of a base problem with that modality removed.
            split_difficulty: If True, item variants that share a base problem
                share parameters, and difficulty is split into base/text/image/
                synergy components. Otherwise each item has a single ``b``.
            split_ability: If True, each student has 4 ability components
                (base, text, image, synergy); otherwise a single theta.
            lr: Adam learning rate.
            batch_size: Number of students per mini-batch.
            max_epochs: Upper bound on training epochs in ``fit``.
            device: Requested torch device. Without CUDA it always falls back
                to CPU.
            eps: Predicted probabilities are clamped to ``[eps, 1 - eps]``.
            embedding_dim: Unused (never stored).
            hidden_sizes: Unused (never stored); defaults to ``None`` to avoid a
                shared mutable default.
            theta_init / theta_max: Initial value / upper clamp for abilities.
            difficulty_base_*: Init and clamp bounds for ``b_base`` (or ``b``).
            difficulty_other_*: Init and magnitude bounds for the non-positive
                ``b_text``/``b_image``/``b_synergy`` components.
            a_init / a_scale: Initial value / upper clamp for discriminations.
            c_init: Initial guessing parameter.
            enable_abs_clamp: Clamp parameters to closed intervals rather than
                just enforcing their sign.
            use_guessing: Use the 3PL guessing floor ``c``.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        # Without CUDA every device request (including an explicit one) falls back to CPU.
        self.device = self.device if torch.cuda.is_available() else torch.device("cpu")
        self.split_difficulty = split_difficulty
        self.split_ability = split_ability
        self.enable_abs_clamp = enable_abs_clamp
        self.use_guessing = use_guessing
        print(f"Using device: {self.device}")
        print(f"Split difficulty: {split_difficulty}")
        print(f"Split ability: {split_ability}")
        print(f"Absolute clamp enabled: {enable_abs_clamp}")
        print(f"Use guessing (c parameter): {self.use_guessing}")
        print("I'm MMMultimodal IRT")

        # Report how the valid (non -1) responses are distributed over the splits.
        valid_data = response_data != -1
        split_codes = np.array(train_mask)
        train_data = split_codes == 1
        test_data = split_codes == 0
        val_data = split_codes == 2

        total_valid = np.sum(valid_data)
        total_train = np.sum(train_data & valid_data)
        total_test = np.sum(test_data & valid_data)
        total_val = np.sum(val_data & valid_data)

        if total_valid > 0:
            train_ratio = total_train / total_valid * 100
            test_ratio = total_test / total_valid * 100
            val_ratio = total_val / total_valid * 100
            print(f"Training data ratio: {train_ratio:.1f}% ({total_train}/{total_valid} valid responses)")
            print(f"Test data ratio: {test_ratio:.1f}% ({total_test}/{total_valid} valid responses)")
            print(f"Validation data ratio: {val_ratio:.1f}% ({total_val}/{total_valid} valid responses)")
        else:
            print("No valid responses found in the data")

        self.eps = eps
        self.data = response_data
        self.train_mask = train_mask
        self.num_students = response_data.shape[0]
        self.num_items = response_data.shape[1]
        self.student_names = student_names
        self.test_names = test_names
        self.difficulty_base_max = difficulty_base_max
        self.difficulty_base_min = difficulty_base_min
        self.difficulty_other_max = difficulty_other_max
        self.difficulty_other_min = difficulty_other_min
        self.a_scale = a_scale
        self.theta_max = theta_max

        if split_difficulty:
            # Items whose names differ only by a leading modality-removal tag
            # share one base problem (and hence one parameter slot).
            self.problem_to_base = {
                name: re.sub(r"^(?:<no_question>|<no_image>|<no_info>)", "", name)
                for name in test_names
            }
            # Unique base names in first-seen order (dicts preserve insertion order).
            self.problem_base_names = list(
                dict.fromkeys(self.problem_to_base[name] for name in test_names)
            )
        else:
            self.problem_base_names = test_names
            self.problem_to_base = {name: name for name in test_names}

        self.num_unique_problems = len(self.problem_base_names)

        # Discriminations and guessing: one entry per base problem.
        self.a_base_raw = nn.Parameter(torch.full((self.num_unique_problems,), a_init, device=self.device))
        self.a_text_raw = nn.Parameter(torch.full((self.num_unique_problems,), a_init, device=self.device))
        self.a_image_raw = nn.Parameter(torch.full((self.num_unique_problems,), a_init, device=self.device))
        self.a_synergy_raw = nn.Parameter(torch.full((self.num_unique_problems,), a_init, device=self.device))

        self.c_raw = nn.Parameter(torch.full((self.num_unique_problems,), c_init, device=self.device))

        if split_difficulty:
            # b_base: [difficulty_base_min, difficulty_base_max] (see clamp_parameters)
            self.b_base_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_base_init, device=self.device))
            # b_text, b_image, b_synergy: [-difficulty_other_max, -difficulty_other_min]
            self.b_text_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
            self.b_image_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
            self.b_synergy_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
        else:
            # b: [0, difficulty_base_max]
            self.b_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_base_init, device=self.device))

        # Map each item to its base problem's parameter slot. A dict of first
        # occurrences gives the same answer as list.index() in O(1) per lookup.
        first_position = {}
        for position, base_name in enumerate(self.problem_base_names):
            first_position.setdefault(base_name, position)
        self.param_indices = torch.tensor(
            [first_position[self.problem_to_base[name]] for name in test_names],
            dtype=torch.long,
            device=self.device,
        )

        output_dim = 4 if split_ability else 1
        self.theta = nn.Parameter(torch.full((self.num_students, output_dim), theta_init, device=self.device))

        # Per-item modality flags [has_text, has_image] derived from the name tags.
        tags = []
        for name in self.test_names:
            if "<no_info>" in name:
                tags.append([0.0, 0.0])
            elif "<no_image>" in name:
                tags.append([1.0, 0.0])
            elif "<no_question>" in name:
                tags.append([0.0, 1.0])
            else:
                tags.append([1.0, 1.0])
        self.register_buffer("mask", torch.tensor(tags, dtype=torch.float32))

        self.to(self.device)
        self.dataset = ResponseDataset(response_data, train_mask)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.max_epochs = max_epochs
        self.loss_history = []

    def get_parameters(self):
        """Expand per-base-problem parameters to one entry per item.

        Every raw parameter vector is gathered with ``self.param_indices``, so
        items that share a base problem receive identical values.

        Returns:
            With ``split_difficulty``: ``(a_base, a_text, a_image, a_synergy,
            b_base, b_text, b_image, b_synergy, c)``; otherwise
            ``(a_base, a_text, a_image, a_synergy, b, c)``.
        """
        gather = self.param_indices
        discriminations = (
            self.a_base_raw[gather],
            self.a_text_raw[gather],
            self.a_image_raw[gather],
            self.a_synergy_raw[gather],
        )
        guessing = self.c_raw[gather]

        if not self.split_difficulty:
            return (*discriminations, self.b_raw[gather], guessing)

        difficulties = (
            self.b_base_raw[gather],
            self.b_text_raw[gather],
            self.b_image_raw[gather],
            self.b_synergy_raw[gather],
        )
        return (*discriminations, *difficulties, guessing)

    def compute_b_full(self, *args):
        """Combine difficulty components into one effective difficulty per item.

        With ``split_difficulty`` the arguments are ``(b_base, b_text, b_image,
        b_synergy)`` and the result is
        ``r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy``,
        where ``r_text`` / ``r_image`` are the two columns of ``self.mask``.
        Otherwise the single argument ``b`` is returned, moved to ``self.device``.
        """
        if not self.split_difficulty:
            return args[0].to(self.device)

        b_base, b_text, b_image, b_synergy = (part.to(self.device) for part in args)
        has_text = self.mask[:, 0].to(self.device)
        has_image = self.mask[:, 1].to(self.device)
        return (has_image * b_image
                + has_text * b_text
                + b_base
                + has_text * has_image * b_synergy)

    def forward(self, responses, response_mask, student_idx):
        """Return the mean binary cross-entropy of the model on one batch.

        Args:
            responses: (batch, num_items) float tensor of observed responses, as
                yielded by ``ResponseDataset``; ``-1`` marks a missing cell.
            response_mask: (batch, num_items) boolean tensor selecting the cells
                that contribute to the loss.
            student_idx: (batch,) integer tensor of row indices into ``self.theta``.

        Returns:
            Scalar tensor holding the mean negative log-likelihood over the
            selected cells, or a constant ``0.0`` tensor (no gradient) when no
            cell is selected.
        """
        # Missing responses (-1) are never valid 0/1 targets, even if the mask selects them.
        response_mask = response_mask & (responses != -1)

        theta = self.theta[student_idx]  # (batch, 4) if split_ability else (batch, 1)

        params = self.get_parameters()
        if self.split_difficulty:
            a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = params
            b_full = self.compute_b_full(b_base, b_text, b_image, b_synergy)
        else:
            a_base, a_text, a_image, a_synergy, b, c = params
            b_full = self.compute_b_full(b)

        # Per-item parameters become (1, num_items) rows that broadcast against
        # per-student (batch, 1) abilities into a (batch, num_items) grid.
        a_base = a_base.unsqueeze(0)
        b_full = b_full.unsqueeze(0)
        c = c.unsqueeze(0)

        if self.split_ability:
            theta_base = theta[:, 0].unsqueeze(1)
            theta_text = theta[:, 1].unsqueeze(1)
            theta_image = theta[:, 2].unsqueeze(1)
            theta_synergy = theta[:, 3].unsqueeze(1)
            a_text = a_text.unsqueeze(0)
            a_image = a_image.unsqueeze(0)
            a_synergy = a_synergy.unsqueeze(0)
            r_text = self.mask[:, 0].unsqueeze(0)
            r_image = self.mask[:, 1].unsqueeze(0)
            # NOTE(review): the synergy term is gated by r_text and scaled by
            # a_image (not r_image), whereas compute_b_full gates b_synergy by
            # r_text * r_image. predict/evaluate_predictions/Fisher info use the
            # same form, so it is kept here for consistency; confirm intent.
            linear = (a_base * theta_base +
                      r_image * a_image * theta_image +
                      r_text * a_text * theta_text +
                      a_image * r_text * a_synergy * theta_synergy -
                      b_full)
        else:
            # theta keeps its trailing singleton axis, so (batch, 1) - (1, num_items)
            # yields the full (batch, num_items) grid.
            linear = (theta - b_full) * a_base
        linear = torch.clamp(linear, min=-15, max=15)

        if self.use_guessing:
            p = c + (1 - c) / (1 + torch.exp(-linear))
        else:
            p = 1 / (1 + torch.exp(-linear))
        p = torch.clamp(p, self.eps, 1 - self.eps)

        valid_responses = responses[response_mask]
        valid_p = p[response_mask]
        if valid_responses.numel() == 0:
            return torch.tensor(0.0, device=self.device)
        return -(valid_responses * torch.log(valid_p)
                 + (1 - valid_responses) * torch.log(1 - valid_p)).mean()

    def get_estimates(self):
        """Collect the current parameter values into dictionaries keyed by name.

        Returns:
            A dict with per-item mappings (keyed by ``self.test_names``) for
            ``discrimination_base/text/image/synergy``, ``difficulty_full`` and
            ``guessing``, plus a ``theta`` mapping keyed by student name. When
            ``split_difficulty`` is set it also contains ``difficulty_base``,
            ``difficulty_text``, ``difficulty_image`` and ``difficulty_synergy``.
        """
        abilities = self.theta.detach().cpu().numpy()
        theta_dict = {}
        if self.split_ability:
            component_keys = ("theta_base", "theta_text", "theta_image", "theta_synergy")
            for row in range(self.num_students):
                theta_dict[self.student_names[row]] = {
                    key: abilities[row, col] for col, key in enumerate(component_keys)
                }
        else:
            for row in range(self.num_students):
                theta_dict[self.student_names[row]] = {"theta": float(abilities[row, 0])}

        params = self.get_parameters()
        if self.split_difficulty:
            # Effective difficulty straight from the gathered tensors (b_base..b_synergy).
            b_full = self.compute_b_full(*params[4:8]).detach().cpu().numpy()
            (a_base, a_text, a_image, a_synergy,
             b_base, b_text, b_image, b_synergy, c) = [t.detach().cpu().numpy() for t in params]
        else:
            a_base, a_text, a_image, a_synergy, b_full, c = [t.detach().cpu().numpy() for t in params]

        names = self.test_names
        result = {
            "discrimination_base": dict(zip(names, a_base)),
            "discrimination_text": dict(zip(names, a_text)),
            "discrimination_image": dict(zip(names, a_image)),
            "discrimination_synergy": dict(zip(names, a_synergy)),
            "difficulty_full": dict(zip(names, b_full)),
            "guessing": dict(zip(names, c)),
            "theta": theta_dict,
        }
        if self.split_difficulty:
            result["difficulty_base"] = dict(zip(names, b_base))
            result["difficulty_text"] = dict(zip(names, b_text))
            result["difficulty_image"] = dict(zip(names, b_image))
            result["difficulty_synergy"] = dict(zip(names, b_synergy))
        return result

    def clamp_parameters(self):
        """Project every parameter back into its allowed range, in place.

        With ``enable_abs_clamp`` each parameter is clamped to a closed
        interval; without it only the sign / lower bound is enforced:

            theta:                 [1e-4, theta_max]  (column 0 first capped at 0.1)
            a_base/text/image/syn: [1e-4, a_scale]
            c:                     [1e-4, 0.5]
            b_base (or b):         [difficulty_base_min, difficulty_base_max]
            b_text/image/synergy:  [-difficulty_other_max, -difficulty_other_min]

        The single-``b`` model (``split_difficulty=False``) uses the same
        ``difficulty_base_min`` lower bound as ``b_base``.
        """
        with torch.no_grad():
            # NOTE(review): column 0 of theta (the only column when split_ability
            # is False) is capped at a hard-coded 0.1, independent of theta_max
            # and enable_abs_clamp; confirm this is intended.
            self.theta.data[:, 0].clamp_(0, 0.1)

            positive_params = [
                (self.theta, self.theta_max),
                (self.a_base_raw, self.a_scale),
                (self.a_text_raw, self.a_scale),
                (self.a_image_raw, self.a_scale),
                (self.a_synergy_raw, self.a_scale),
                (self.c_raw, 0.5),
            ]
            for param, max_val in positive_params:
                if self.enable_abs_clamp:
                    param.data.clamp_(1e-4, max_val)
                else:
                    param.data.clamp_(min=1e-4)

            if self.split_difficulty:
                base_params = [self.b_base_raw]
                # Modality-specific components may only lower the difficulty.
                negative_params = [self.b_text_raw, self.b_image_raw, self.b_synergy_raw]
            else:
                base_params = [self.b_raw]
                negative_params = []

            for param in base_params:
                if self.enable_abs_clamp:
                    param.data.clamp_(self.difficulty_base_min, self.difficulty_base_max)
                else:
                    param.data.clamp_(min=0)

            for param in negative_params:
                if self.enable_abs_clamp:
                    param.data.clamp_(-self.difficulty_other_max, -self.difficulty_other_min)
                else:
                    param.data.clamp_(max=0)
                    
    def fit(self):
        """Train all parameters with mini-batch updates and return the estimates.

        Each epoch iterates ``self.dataloader``. Batches whose loss is not
        positive (no scored responses) are skipped. After every optimizer step,
        gradients are norm-clipped to 1.0 beforehand and parameters are
        re-projected with ``clamp_parameters`` afterwards. Training stops early
        once the mean epoch loss has not improved for 100 consecutive epochs.

        Returns:
            The dictionary produced by ``get_estimates()``.
        """
        patience = 100
        best_loss, best_epoch = float("inf"), 0

        for epoch in range(self.max_epochs):
            batch_losses = []
            for responses, response_mask, student_idx in self.dataloader:
                responses = responses.to(self.device)
                response_mask = response_mask.to(self.device)
                student_idx = student_idx.to(self.device)

                self.optimizer.zero_grad()
                loss = self.forward(responses, response_mask, student_idx)
                loss_value = loss.item()
                if loss_value <= 0:
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.clamp_parameters()
                batch_losses.append(loss_value)

            if not batch_losses:
                continue

            avg_loss = sum(batch_losses) / len(batch_losses)
            self.loss_history.append(avg_loss)

            if avg_loss < best_loss:
                best_loss, best_epoch = avg_loss, epoch

            if epoch - best_epoch >= patience:
                print(f"Early stopping triggered. No improvement for {patience} epochs.")
                print(f"Best loss was {best_loss:.4f} at epoch {best_epoch + 1}")
                break

            if epoch == 0 or (epoch + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{self.max_epochs}, Loss: {avg_loss:.4f}")

        return self.get_estimates()

    def evaluate_predictions(self) -> Dict[str, float]:
        """Score the held-out test cells and return classification metrics.

        A cell is scored when both ``self.dataset.test_indices`` and
        ``self.dataset.valid_data`` are true for it. Items whose name contains
        ``'shuffle'`` or ``'from'`` also form a separate "shuffle" group; every
        other item is "normal".

        Returns:
            ``{}`` when there are no test cells. Otherwise a dict with
            ``roc_auc``, ``accuracy``, ``precision``, ``recall``, ``f1``,
            ``perplexity`` (exp of the mean NLL), ``roc_auc_shuffle`` and
            ``roc_auc_normal`` (0.0 when that group is empty).

        The module is in eval mode while scoring. It is always switched back to
        training mode, including on the empty-result path.
        """
        self.eval()
        try:
            with torch.no_grad():
                test_mask = self.dataset.test_indices & self.dataset.valid_data
                # Row-major (student, item) pairs: same order as a student-then-item scan.
                student_cpu, item_cpu = torch.nonzero(test_mask, as_tuple=True)
                if student_cpu.numel() == 0:
                    return {}

                student_indices = student_cpu.to(self.device)
                item_indices = item_cpu.to(self.device)
                probs = self._pair_probabilities(student_indices, item_indices)
                pred_binary = (probs > 0.5).float()

                student_np = student_cpu.numpy()
                item_np = item_cpu.numpy()
                true_np = np.asarray(self.data)[student_np, item_np]
                true_responses = torch.as_tensor(true_np, device=self.device)

                probs_np = probs.cpu().numpy()
                pred_binary_np = pred_binary.cpu().numpy()

                item_names = [self.test_names[idx] for idx in item_np]
                shuffle_mask = np.array([('shuffle' in name) or ('from' in name) for name in item_names])
                normal_mask = ~shuffle_mask

                num_shuffle = int(shuffle_mask.sum())
                num_normal = int(normal_mask.sum())
                print(f"Number of shuffle items: {num_shuffle}")
                print(f"Number of normal items: {num_normal}")

                from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
                metrics = {
                    'roc_auc': roc_auc_score(true_np, probs_np),
                    'accuracy': accuracy_score(true_np, pred_binary_np),
                    'precision': precision_score(true_np, pred_binary_np),
                    'recall': recall_score(true_np, pred_binary_np),
                    'f1': f1_score(true_np, pred_binary_np),
                }
                nll = -(true_responses * torch.log(probs)
                        + (1 - true_responses) * torch.log(1 - probs)).mean()
                metrics['perplexity'] = torch.exp(nll).item()

                if num_shuffle > 0:
                    metrics['roc_auc_shuffle'] = roc_auc_score(true_np[shuffle_mask], probs_np[shuffle_mask])
                else:
                    metrics['roc_auc_shuffle'] = 0.0

                if num_normal > 0:
                    metrics['roc_auc_normal'] = roc_auc_score(true_np[normal_mask], probs_np[normal_mask])
                else:
                    metrics['roc_auc_normal'] = 0.0

                return metrics
        finally:
            self.train()

    def _pair_probabilities(self, student_indices, item_indices):
        """Return clamped P(correct) for each (student_indices[k], item_indices[k]) pair.

        Every per-item parameter is gathered by ``item_indices`` before use, so
        all operands share the (n_pairs,) shape in every model configuration.
        """
        theta = self.theta[student_indices]
        masks = self.mask[item_indices]

        params = self.get_parameters()
        if self.split_difficulty:
            a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = params
            b_full = self.compute_b_full(b_base, b_text, b_image, b_synergy)
        else:
            a_base, a_text, a_image, a_synergy, b, c = params
            b_full = self.compute_b_full(b)

        a_base, a_text, a_image, a_synergy, b_full, c = (
            t[item_indices] for t in (a_base, a_text, a_image, a_synergy, b_full, c)
        )

        if self.split_ability:
            r_text, r_image = masks[:, 0], masks[:, 1]
            linear = (a_base * theta[:, 0]
                      + r_image * a_image * theta[:, 2]
                      + r_text * a_text * theta[:, 1]
                      + a_image * r_text * a_synergy * theta[:, 3]
                      - b_full)
        else:
            linear = (theta[:, 0] - b_full) * a_base
        linear = torch.clamp(linear, min=-15, max=15)

        if self.use_guessing:
            probs = c + (1 - c) / (1 + torch.exp(-linear))
        else:
            probs = 1 / (1 + torch.exp(-linear))
        return torch.clamp(probs, self.eps, 1 - self.eps)

    def predict(self) -> np.ndarray:
        """Return predicted success probabilities for every student/item pair.

        Returns:
            NumPy array of shape (num_students, num_items) with probabilities
            clamped to ``[eps, 1 - eps]``. No gradients are tracked, and the
            module is switched back to training mode before returning.
        """
        self.eval()
        with torch.no_grad():
            item_params = self.get_parameters()
            if self.split_difficulty:
                (a_base, a_text, a_image, a_synergy,
                 b_base, b_text, b_image, b_synergy, c) = item_params
                difficulty = self.compute_b_full(b_base, b_text, b_image, b_synergy)
            else:
                a_base, a_text, a_image, a_synergy, b, c = item_params
                difficulty = self.compute_b_full(b)

            # (1, num_items) rows broadcast against (num_students, 1) abilities.
            difficulty = difficulty.unsqueeze(0)
            a_base = a_base.unsqueeze(0)
            c = c.unsqueeze(0)

            if self.split_ability:
                ability = self.theta
                has_text = self.mask[:, 0].unsqueeze(0)
                has_image = self.mask[:, 1].unsqueeze(0)
                a_text = a_text.unsqueeze(0)
                a_image = a_image.unsqueeze(0)
                a_synergy = a_synergy.unsqueeze(0)
                logits = (a_base * ability[:, 0].unsqueeze(1)
                          + has_image * a_image * ability[:, 2].unsqueeze(1)
                          + has_text * a_text * ability[:, 1].unsqueeze(1)
                          + a_image * has_text * a_synergy * ability[:, 3].unsqueeze(1)
                          - difficulty)
            else:
                ability = self.theta.squeeze(-1).unsqueeze(1)
                logits = (ability - difficulty) * a_base
            logits = torch.clamp(logits, min=-15, max=15)

            denominator = 1 + torch.exp(-logits)
            probs = c + (1 - c) / denominator if self.use_guessing else 1 / denominator
            probs = torch.clamp(probs, self.eps, 1 - self.eps)
        self.train()
        return probs.cpu().numpy()
    
    def compute_item_fisher_information(self, model_name: str) -> Dict[str, np.ndarray]:
        """Return each item's Fisher information, averaged over all students' abilities.

        The information is evaluated at every row of ``self.theta`` and then
        averaged. It equals ``w(p) * g g^T``, where ``g`` is the gradient of the
        item's linear predictor w.r.t. the ability vector. The weight is
        ``w = p(1-p)`` for the 2PL model and
        ``w = (p-c)^2 (1-p) / ((1-c)^2 p)`` when ``use_guessing`` (3PL). The 3PL
        weight reduces to ``p(1-p)`` when ``c = 0``.

        Args:
            model_name: Not used by this method.

        Returns:
            Mapping from item name to a NumPy array: a 4x4 matrix over
            (theta_base, theta_text, theta_image, theta_synergy) when
            ``split_ability`` is set, otherwise a 1x1 matrix.
        """
        def information_weight(p, c_param):
            # Scalar weight multiplying (d linear / d theta)^2 in the item information.
            if self.use_guessing:
                return (p - c_param) ** 2 * (1 - p) / ((1 - c_param) ** 2 * p)
            return p * (1 - p)

        fisher_info_dict = {}
        self.eval()
        try:
            with torch.no_grad():
                theta_all = self.theta.detach()

                for j in range(self.num_items):
                    item_name = self.test_names[j]
                    r_text = self.mask[j, 0].item()
                    r_image = self.mask[j, 1].item()
                    base_idx = int(self.param_indices[j].item())

                    if self.split_difficulty:
                        b_base = self.b_base_raw[base_idx].item()
                        b_text = self.b_text_raw[base_idx].item()
                        b_image = self.b_image_raw[base_idx].item()
                        b_synergy = self.b_synergy_raw[base_idx].item()
                        b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                    else:
                        b_full = self.b_raw[base_idx].item()

                    if self.split_ability:
                        a_base = self.a_base_raw[base_idx].item()
                        a_text = self.a_text_raw[base_idx].item()
                        a_image = self.a_image_raw[base_idx].item()
                        a_synergy = self.a_synergy_raw[base_idx].item()
                        c_param = self.c_raw[base_idx].item()

                        # Gradient of the linear predictor w.r.t. the 4 ability components.
                        g = torch.tensor([a_base,
                                          r_text * a_text,
                                          r_image * a_image,
                                          r_text * a_synergy * a_image], device=self.device)

                        linear = (a_base * theta_all[:, 0] +
                                  r_image * a_image * theta_all[:, 2] +
                                  r_text * a_text * theta_all[:, 1] +
                                  r_text * a_synergy * a_image * theta_all[:, 3] -
                                  b_full)
                        linear = torch.clamp(linear, min=-15, max=15)
                        if self.use_guessing:
                            p = c_param + (1 - c_param) / (1 + torch.exp(-linear))
                        else:
                            p = 1 / (1 + torch.exp(-linear))
                        p = torch.clamp(p, self.eps, 1 - self.eps)

                        # (num_students, 1, 1) weights times the (4, 4) outer product.
                        weight = information_weight(p, c_param).unsqueeze(1).unsqueeze(2)
                        info_matrices = weight * torch.outer(g, g)
                        fisher_info_dict[item_name] = torch.mean(info_matrices, dim=0).cpu().numpy()
                    else:
                        a = self.a_base_raw[base_idx].item()
                        c_param = self.c_raw[base_idx].item() if self.use_guessing else 0.0

                        z = a * (theta_all.squeeze(-1) - b_full)
                        z = torch.clamp(z, min=-15, max=15)
                        if self.use_guessing:
                            p = c_param + (1 - c_param) / (1 + torch.exp(-z))
                        else:
                            p = 1 / (1 + torch.exp(-z))
                        p = torch.clamp(p, self.eps, 1 - self.eps)

                        info_vals = information_weight(p, c_param) * (a ** 2)
                        fisher_info_dict[item_name] = np.array([[torch.mean(info_vals).item()]])
        finally:
            self.train()
        return fisher_info_dict

    def update_single_theta(self, model_name: str, problem_names: List[str],
                              lr: float = 1e-3, max_epochs: int = 1000, patience: int = 50):
        if self.split_difficulty:
            selected_indices = [i for i, name in enumerate(self.test_names)
                                if self.problem_to_base[name] in problem_names]
        else:
            selected_indices = [i for i, name in enumerate(self.test_names)
                                if name in problem_names]

        if len(selected_indices) == 0:
            print("No items found for the specified problem names.")
            return self.get_estimates()

        selected_indices_tensor = torch.tensor(selected_indices, dtype=torch.long, device=self.device)

        optimizer = torch.optim.Adam([self.theta], lr=lr)

        best_loss = float("inf")
        best_epoch = 0
        epochs_no_improve = 0

        self.train()
        
        for epoch in range(max_epochs):
            epoch_loss = 0.0
            batch_count = 0

            for responses, response_mask, student_idx in self.dataloader:
                responses = responses.to(self.device)       # shape: (batch_size, num_items)
                response_mask = response_mask.to(self.device)
                student_idx = student_idx.to(self.device)

                responses_sub = responses[:, selected_indices_tensor]
                mask_sub = response_mask[:, selected_indices_tensor]

                if self.split_ability:
                    theta_batch = self.theta[student_idx]  # shape: (batch,4)
                    
                    selected_param_indices = self.param_indices[selected_indices_tensor]

                    a_base_all, a_text_all, a_image_all, a_synergy_all, b_base_all, b_text_all, b_image_all, b_synergy_all, c_all = self.get_parameters()
                    
                    a_base = a_base_all[selected_param_indices]
                    a_text = a_text_all[selected_param_indices]
                    a_image = a_image_all[selected_param_indices]
                    a_synergy = a_synergy_all[selected_param_indices]
                    c_param = c_all[selected_param_indices]
                    
                    b_base = b_base_all[selected_param_indices]
                    b_text = b_text_all[selected_param_indices]
                    b_image = b_image_all[selected_param_indices]
                    b_synergy = b_synergy_all[selected_param_indices]
                    
                    mask_items = self.mask[selected_indices_tensor]  # shape: (num_selected, 2)
                    # r_text, r_image: (num_selected,)
                    r_text = mask_items[:, 0]
                    r_image = mask_items[:, 1]
                    
                    b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                    
                    a_base = a_base.unsqueeze(0)         # (1, num_selected)
                    a_text = a_text.unsqueeze(0)
                    a_image = a_image.unsqueeze(0)
                    a_synergy = a_synergy.unsqueeze(0)
                    b_full = b_full.unsqueeze(0)         # (1, num_selected)
                    c_param = c_param.unsqueeze(0)         # (1, num_selected)
                    r_text = r_text.unsqueeze(0)           # (1, num_selected)
                    r_image = r_image.unsqueeze(0)         # (1, num_selected)

                    theta_base = theta_batch[:, 0].unsqueeze(1)  # (batch,1)
                    theta_text = theta_batch[:, 1].unsqueeze(1)
                    theta_image = theta_batch[:, 2].unsqueeze(1)
                    theta_synergy = theta_batch[:, 3].unsqueeze(1)

                    linear = (a_base * theta_base +
                              r_image * a_image * theta_image +
                              r_text * a_text * theta_text +
                              (r_text * a_synergy * a_image) * theta_synergy -
                              b_full)
                    linear = torch.clamp(linear, min=-15, max=15)
                    if self.use_guessing:
                        p = c_param + (1 - c_param) / (1 + torch.exp(-linear))
                    else:
                        p = 1 / (1 + torch.exp(-linear))
                    p = torch.clamp(p, self.eps, 1 - self.eps)
                else:
                    theta_batch = self.theta[student_idx].squeeze(-1)  # (batch_size,)
                    selected_param_indices = self.param_indices[selected_indices_tensor]
                    a_all, a_text_all, a_image_all, a_synergy_all, b_all, c_all = self.get_parameters()
                    a_val = a_all[selected_param_indices]
                    c_val = c_all[selected_param_indices]
                    if self.split_difficulty:
                        b_base = self.b_base_raw[selected_param_indices]
                        b_text = self.b_text_raw[selected_param_indices]
                        b_image = self.b_image_raw[selected_param_indices]
                        b_synergy = self.b_synergy_raw[selected_param_indices]
                        mask_items = self.mask[selected_indices_tensor]
                        r_text = mask_items[:, 0]
                        r_image = mask_items[:, 1]
                        b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                    else:
                        b_full = self.b_raw[selected_param_indices]
                    a_val = a_val.unsqueeze(0)
                    b_full = b_full.unsqueeze(0)
                    c_val = c_val.unsqueeze(0)
                    theta_batch = theta_batch.unsqueeze(1)
                    z = a_val * (theta_batch - b_full)
                    z = torch.clamp(z, min=-15, max=15)
                    if self.use_guessing:
                        p = c_val + (1 - c_val) / (1 + torch.exp(-z))
                    else:
                        p = 1 / (1 + torch.exp(-z))
                    p = torch.clamp(p, self.eps, 1 - self.eps)
                
                valid_responses = responses_sub[mask_sub]
                valid_p = p[mask_sub]
                if valid_responses.numel() > 0:
                    
                    loss = -(valid_responses * torch.log(valid_p) + (1 - valid_responses) * torch.log(1 - valid_p)).mean()
                    
                    if not torch.isfinite(loss):
                        continue
                else:
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_([self.theta], max_norm=1.0)
                optimizer.step()

                epoch_loss += loss.item()
                batch_count += 1

            if batch_count == 0:
                continue
            avg_loss = epoch_loss / batch_count
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_epoch = epoch
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. Best loss: {best_loss:.4f} at epoch {best_epoch+1}")
                break

            if (epoch+1) % 100 == 0 or epoch == 0:
                print(f"[update_single_theta] Epoch {epoch+1}/{max_epochs}, Loss: {avg_loss:.4f}")

        self.train()
        return self.get_estimates()
    
    def evaluate_validation(self):
        self.eval()
        with torch.no_grad():
            valid_mask = self.dataset.val_indices & self.dataset.valid_data
            valid_pairs = [(i, j) for i in range(self.num_students) for j in range(self.num_items) if valid_mask[i, j]]

            if not valid_pairs:
                return {}

            student_indices = torch.tensor([pair[0] for pair in valid_pairs], device=self.device)
            item_indices = torch.tensor([pair[1] for pair in valid_pairs], device=self.device)

            theta_components = self.theta[student_indices]
            masks = self.mask[item_indices]
            r_text = masks[:, 0]
            r_image = masks[:, 1]

            if self.split_ability:
                theta_base = theta_components[:, 0]
                theta_text = theta_components[:, 1]
                theta_image = theta_components[:, 2]
                theta_synergy = theta_components[:, 3]
                
                if self.split_difficulty:
                    a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = self.get_parameters()
                    a_base = a_base[item_indices]
                    a_text = a_text[item_indices]
                    a_image = a_image[item_indices]
                    a_synergy = a_synergy[item_indices]
                    b_base = b_base[item_indices]
                    b_text = b_text[item_indices]
                    b_image = b_image[item_indices]
                    b_synergy = b_synergy[item_indices]
                    c = c[item_indices]
                    b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                else:
                    a_base, a_text, a_image, a_synergy, b, c = self.get_parameters()
                    a_base = a_base[item_indices]
                    a_text = a_text[item_indices]
                    a_image = a_image[item_indices]
                    a_synergy = a_synergy[item_indices]
                    b = b[item_indices]
                    c = c[item_indices]
                    b_full = b

                
                linear = (a_base * theta_base +
                          r_image * a_image * theta_image +
                            r_text * a_text * theta_text +
                            a_image * r_text * a_synergy * theta_synergy -
                            b_full)
                linear = torch.clamp(linear, min=-15, max=15)
                if self.use_guessing:
                    probs = c + (1 - c) / (1 + torch.exp(-linear))
                else:
                    probs = 1 / (1 + torch.exp(-linear))
                probs = torch.clamp(probs, self.eps, 1 - self.eps)
            else:
                theta_full = theta_components.squeeze(-1)
                if self.split_difficulty:
                    a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = self.get_parameters()
                    a_base = a_base[item_indices]
                    b_full = self.compute_b_full(b_base, b_text, b_image, b_synergy)[item_indices]
                else:
                    a_base, a_text, a_image, a_synergy, b, c = self.get_parameters()
                    a_base = a_base[item_indices]
                    b_full = self.compute_b_full(b)[item_indices]
                z = (theta_full - b_full) * a_base
                z = torch.clamp(z, min=-15, max=15)
                if self.use_guessing:
                    probs = c + (1 - c) / (1 + torch.exp(-z))
                else:
                    probs = 1 / (1 + torch.exp(-z))
                probs = torch.clamp(probs, self.eps, 1 - self.eps)
            pred_binary = (probs > 0.5).float()
            true_responses = torch.tensor(self.data[student_indices.cpu(), item_indices.cpu()], device=self.device)
            from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
            probs_np = probs.cpu().numpy()
            true_np = true_responses.cpu().numpy()
            pred_binary_np = pred_binary.cpu().numpy()
            metrics = {
                "roc_auc": roc_auc_score(true_np, probs_np),
                "accuracy": accuracy_score(true_np, pred_binary_np),
                "precision": precision_score(true_np, pred_binary_np),
                "recall": recall_score(true_np, pred_binary_np),
                "f1": f1_score(true_np, pred_binary_np),
            }
            nll = -(true_responses * torch.log(probs) + (1 - true_responses) * torch.log(1 - probs)).mean()
            perplexity = torch.exp(nll).item()
            metrics["perplexity"] = perplexity
        self.train()
        return metrics
    
    def compute_item_kli(self, student_id, model_name="Standard3PLIRT"):
        self.eval()  
        with torch.no_grad():
            
            if isinstance(student_id, str):
                student_idx = self.student_names.index(student_id)
            else:
                student_idx = student_id
            
            
            kli_dict = {}
            
            
            theta = self.theta[student_idx].detach().clone()  # (d,) or (1,)
            
            
            n_samples = 25  
            c = 3.0  
            
            
            for j in range(self.num_items):
                item_name = self.test_names[j]
                
                r_text = self.mask[j, 0].item()  # 0 or 1
                r_image = self.mask[j, 1].item()  # 0 or 1
                
                
                base_idx = int(self.param_indices[j].item())
                
                if self.split_ability:
                    
                    a_base = self.a_base_raw[base_idx].item()
                    a_text = self.a_text_raw[base_idx].item()
                    a_image = self.a_image_raw[base_idx].item()
                    a_synergy = self.a_synergy_raw[base_idx].item()
                    c_param = self.c_raw[base_idx].item()
                    
                    if self.split_difficulty:
                        b_base = self.b_base_raw[base_idx].item()
                        b_text = self.b_text_raw[base_idx].item()
                        b_image = self.b_image_raw[base_idx].item()
                        b_synergy = self.b_synergy_raw[base_idx].item()
                        b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                    else:
                        b_full = self.b_raw[base_idx].item()
                    
                    
                    theta_base = theta[0].item()
                    theta_text = theta[1].item()
                    theta_image = theta[2].item()
                    theta_synergy = theta[3].item()
                    
                    linear_current = (a_base * theta_base +
                                    r_image * a_image * theta_image +
                                    r_text * a_text * theta_text +
                                    r_text * a_synergy * a_image * theta_synergy -
                                    b_full)
                    linear_current = max(min(linear_current, 15), -15)  # clamp
                    
                    if self.use_guessing:
                        p_current = c_param + (1 - c_param) / (1 + np.exp(-linear_current))
                    else:
                        p_current = 1 / (1 + np.exp(-linear_current))
                    p_current = max(min(p_current, 1 - self.eps), self.eps)  # clamp
                    
                    
                    kl_sum = 0.0
                    
                    ranges = []
                    stddevs = [0.2, 0.1, 0.1, 0.05]  
                    for d in range(4):
                        ranges.append((theta[d].item() - c * stddevs[d], theta[d].item() + c * stddevs[d]))
                    
                    
                    for _ in range(n_samples):
                        
                        theta_sample = [np.random.uniform(r[0], r[1]) for r in ranges]
                        
                        
                        linear_sample = (a_base * theta_sample[0] +
                                        r_image * a_image * theta_sample[2] +
                                        r_text * a_text * theta_sample[1] +
                                        r_text * a_synergy * a_image * theta_sample[3] -
                                        b_full)
                        linear_sample = max(min(linear_sample, 15), -15)  # clamp
                        
                        if self.use_guessing:
                            p_sample = c_param + (1 - c_param) / (1 + np.exp(-linear_sample))
                        else:
                            p_sample = 1 / (1 + np.exp(-linear_sample))
                        p_sample = max(min(p_sample, 1 - self.eps), self.eps)  # clamp
                        
                        
                        q_current = 1 - p_current
                        q_sample = 1 - p_sample
                        
                        # p_current*log(p_current/p_sample) + q_current*log(q_current/q_sample)
                        kl_value = p_current * np.log(p_current / p_sample) + q_current * np.log(q_current / q_sample)
                        kl_sum += kl_value
                    
                    
                    kli_value = kl_sum / n_samples
                    
                else:
                    
                    a = self.a_base_raw[base_idx].item()
                    if self.split_difficulty:
                        b_base = self.b_base_raw[base_idx].item()
                        b_text = self.b_text_raw[base_idx].item()
                        b_image = self.b_image_raw[base_idx].item()
                        b_synergy = self.b_synergy_raw[base_idx].item()
                        b_full = r_image * b_image + r_text * b_text + b_base + r_text * r_image * b_synergy
                    else:
                        b_full = self.b_raw[base_idx].item()
                    
                    
                    theta_val = theta.item()
                    z_current = a * (theta_val - b_full)
                    z_current = max(min(z_current, 15), -15)  # clamp
                    
                    if self.use_guessing:
                        c_param = self.c_raw[base_idx].item()
                        p_current = c_param + (1 - c_param) / (1 + np.exp(-z_current))
                    else:
                        p_current = 1 / (1 + np.exp(-z_current))
                    p_current = max(min(p_current, 1 - self.eps), self.eps)  # clamp
                    
                    
                    kl_sum = 0.0
                    stddev = 0.2 
                    theta_range = (theta_val - c * stddev, theta_val + c * stddev)
                    
                    
                    for i in range(n_samples):
                        
                        theta_sample = theta_range[0] + (theta_range[1] - theta_range[0]) * i / (n_samples - 1)
                        
                        
                        z_sample = a * (theta_sample - b_full)
                        z_sample = max(min(z_sample, 15), -15)  # clamp
                        
                        if self.use_guessing:
                            p_sample = c_param + (1 - c_param) / (1 + np.exp(-z_sample))
                        else:
                            p_sample = 1 / (1 + np.exp(-z_sample))
                        p_sample = max(min(p_sample, 1 - self.eps), self.eps)  # clamp
                        
                        
                        q_current = 1 - p_current
                        q_sample = 1 - p_sample
                        
                        kl_value = p_current * np.log(p_current / p_sample) + q_current * np.log(q_current / q_sample)
                        kl_sum += kl_value
                    
                    
                    kli_value = kl_sum / n_samples * (theta_range[1] - theta_range[0])
                
                
                kli_dict[item_name] = float(kli_value)
        
        self.train() 
        return kli_dict

    def compute_item_kli_optimized(self, student_id, model_name="Standard3PLIRT"):
        
        self.eval()  
        with torch.no_grad():
            
            if isinstance(student_id, str):
                student_idx = self.student_names.index(student_id)
            else:
                student_idx = student_id
            
            
            kli_dict = {}
            
            
            theta = self.theta[student_idx].detach().clone()  # (d,) or (1,)
            
            
            all_r_text = self.mask[:, 0]  # (num_items,)
            all_r_image = self.mask[:, 1]  # (num_items,)
            all_param_indices = self.param_indices  # (num_items,)
            
            if self.split_ability:
                
                if self.split_difficulty:
                    a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = self.get_parameters()
                    
                    b_full = all_r_image * b_image + all_r_text * b_text + b_base + all_r_text * all_r_image * b_synergy
                else:
                    a_base, a_text, a_image, a_synergy, b, c = self.get_parameters()
                    b_full = b
                
                
                theta_base = theta[0].item()
                theta_text = theta[1].item()
                theta_image = theta[2].item()
                theta_synergy = theta[3].item()
                
                
                linear_terms = (a_base * theta_base +
                            all_r_image * a_image * theta_image +
                            all_r_text * a_text * theta_text +
                            all_r_text * all_r_image * a_synergy * theta_synergy -
                            b_full)
                linear_terms = torch.clamp(linear_terms, min=-15, max=15)
                
                if self.use_guessing:
                    probs = c + (1 - c) / (1 + torch.exp(-linear_terms))
                else:
                    probs = 1 / (1 + torch.exp(-linear_terms))
                probs = torch.clamp(probs, self.eps, 1 - self.eps)
                
                
                for j in range(self.num_items):
                    
                    r_text = all_r_text[j].item()
                    r_image = all_r_image[j].item()
                    a_base_val = a_base[j].item()
                    a_text_val = a_text[j].item()
                    a_image_val = a_image[j].item()
                    a_synergy_val = a_synergy[j].item()
                    
                    
                    g = np.array([
                        a_base_val,
                        r_text * a_text_val,
                        r_image * a_image_val,
                        r_text * r_image * a_synergy_val
                    ])
                    
                    
                    p = probs[j].item()
                    coefficient = p * (1 - p)
                    
                    
                    uncertainties = np.array([0.2, 0.1, 0.1, 0.05]) 
                    kli_value = np.sum((g**2) * uncertainties) * coefficient
                    
                    kli_dict[self.test_names[j]] = float(kli_value)
            else:
                
                if self.split_difficulty:
                    a_base, a_text, a_image, a_synergy, b_base, b_text, b_image, b_synergy, c = self.get_parameters()
                    b_full = all_r_image * b_image + all_r_text * b_text + b_base + all_r_text * all_r_image * b_synergy
                else:
                    a_base, a_text, a_image, a_synergy, b, c = self.get_parameters()
                    b_full = b
                
                theta_val = theta.item()
                
                
                z = (theta_val - b_full) * a_base
                z = torch.clamp(z, min=-15, max=15)
                
                if self.use_guessing:
                    probs = c + (1 - c) / (1 + torch.exp(-z))
                else:
                    probs = 1 / (1 + torch.exp(-z))
                probs = torch.clamp(probs, self.eps, 1 - self.eps)
                
                
                for j in range(self.num_items):
                    p = probs[j].item()
                    a_val = a_base[j].item()
                    
                    fisher_info = (a_val ** 2) * p * (1 - p)
                    
                    kli_value = fisher_info * 0.2 
                    
                    kli_dict[self.test_names[j]] = float(kli_value)
        
        self.train() 
        return kli_dict