import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Dict
# import mlflow
import re

class ResponseDataset(Dataset):
    """Row-wise dataset over a (students x items) response matrix.

    Each sample is one student's row. Entries equal to ``-1`` in
    ``response_data`` are treated as missing (see ``valid_data``).

    Args:
        response_data: 2-D array of responses; ``-1`` marks a missing entry.
        train_mask: Array of the same shape; ``1`` marks a training cell and
            ``0`` marks a test cell.
    """

    def __init__(self, response_data: np.ndarray, train_mask: np.ndarray):
        self.data = torch.tensor(response_data, dtype=torch.float32)
        self.train_mask = torch.tensor(train_mask, dtype=torch.float32)

        # Boolean views of the split; neither one accounts for missing entries.
        self.train_indices = self.train_mask == 1
        self.test_indices = self.train_mask == 0

        # True where a response was actually observed.
        self.valid_data = self.data != -1

    def __len__(self):
        """Return the number of rows (students)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(responses, usable_train_mask, idx)`` for row ``idx``.

        The mask is the train split restricted to observed entries, so the
        ``-1`` sentinel is never selected as a training target even when the
        train mask is set on a missing cell.
        """
        responses = self.data[idx]
        train_mask = self.train_indices[idx] & self.valid_data[idx]
        return responses, train_mask, idx

class Standard3PLIRT(nn.Module):
    def __init__(
        self,
        response_data: np.ndarray,
        train_mask: np.ndarray,
        student_names: List[str],
        test_names: List[str],
        split_difficulty: bool = False,
        split_ability: bool = False,
        lr: float = 1e-3,
        batch_size: int = 64,
        max_epochs: int = 1000,
        device: str = None,
        eps: float = 1e-3,
        embedding_dim: int = 32,          
        hidden_sizes: List[int] = [128, 64, 32], 
        theta_init: float = 0.5,
        theta_max: float = 1.0, 
        difficulty_base_init: float = 0.0,
        difficulty_base_max: float = 3.0,
        difficulty_other_init: float = 0.0,
        difficulty_other_max: float = 1.5,
        a_init: float = 0.2,
        a_scale: float = 3.0,  
        c_init: float = 0.25,
        enable_abs_clamp: bool = False,
        use_guessing: bool = False       
    ):
        """Build the parameters, data loader and optimizer of the IRT model.

        Args:
            response_data: (students x items) matrix; ``-1`` marks a missing
                response.
            train_mask: Same shape as ``response_data``; ``1`` = train cell,
                ``0`` = test cell.
            student_names: One name per row of ``response_data``.
            test_names: One name per column. A leading ``<no_question>``,
                ``<no_image>`` or ``<no_info>`` tag marks a modality variant.
            split_difficulty: If true, variants of the same base problem share
                base/text/image/synergy difficulty terms and a guessing term.
            split_ability: If true, each student has four ability components
                instead of one.
            lr: Learning rate of the Adam optimizer.
            batch_size: Number of students per batch.
            max_epochs: Upper bound on epochs run by ``fit``.
            device: Torch device string. Whatever is requested, CPU is used
                when CUDA is unavailable.
            eps: Probabilities are clamped to ``[eps, 1 - eps]``.
            embedding_dim: Accepted but not used by this constructor.
            hidden_sizes: Accepted but not used by this constructor (and never
                mutated, so the list default is harmless here).
            theta_init: Initial value of every ability entry.
            theta_max, difficulty_base_max, difficulty_other_max, a_scale:
                Stored on the instance; this constructor does not apply them.
            difficulty_base_init: Initial value of the base difficulty.
            difficulty_other_init: Initial value of the text/image/synergy
                difficulty terms.
            a_init: Initial raw discrimination (exponentiated when used).
            c_init: Initial guessing value.
            enable_abs_clamp: If true, ``clamp_parameters`` also applies upper
                bounds.
            use_guessing: If true, the guessing term ``c`` enters the response
                probability.
        """
        super().__init__()
        requested_device = (
            torch.device(device)
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        # Any requested device is replaced by CPU when CUDA is unavailable.
        self.device = requested_device if torch.cuda.is_available() else torch.device("cpu")
        self.split_difficulty = split_difficulty
        self.split_ability = split_ability
        self.enable_abs_clamp = enable_abs_clamp
        self.use_guessing = use_guessing
        print(f"Using device: {self.device}")
        print(f"Split difficulty: {split_difficulty}")
        print(f"Split ability: {split_ability}")
        print(f"Absolute clamp enabled: {enable_abs_clamp}")
        print(f"Use guessing parameter c: {use_guessing}")
        print("I'm not asplit Multimodal IRT")

        # Report how the observed responses are divided between train and test.
        valid_data = response_data != -1
        train_data = np.array(train_mask) == 1
        test_data = np.array(train_mask) == 0

        total_valid = np.sum(valid_data)
        total_train = np.sum(train_data & valid_data)
        total_test = np.sum(test_data & valid_data)

        if total_valid > 0:
            train_ratio = total_train / total_valid * 100
            test_ratio = total_test / total_valid * 100
            print(f"Training data ratio: {train_ratio:.1f}% ({total_train}/{total_valid} valid responses)")
            print(f"Test data ratio: {test_ratio:.1f}% ({total_test}/{total_valid} valid responses)")
        else:
            print("No valid responses found in the data")

        self.eps = eps
        self.data = response_data
        self.train_mask = train_mask
        self.num_students = response_data.shape[0]
        self.num_items = response_data.shape[1]
        self.student_names = student_names
        self.test_names = test_names
        self.difficulty_base_max = difficulty_base_max
        self.difficulty_other_max = difficulty_other_max
        self.a_scale = a_scale
        self.theta_max = theta_max

        # Map every item name to the base problem whose parameters it shares.
        if split_difficulty:
            self.problem_to_base = {
                name: re.sub(r"^(?:<no_question>|<no_image>|<no_info>)", "", name)
                for name in test_names
            }
            # dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
            self.problem_base_names = list(
                dict.fromkeys(self.problem_to_base[name] for name in test_names)
            )
            self.num_unique_problems = len(self.problem_base_names)
        else:
            self.problem_base_names = test_names
            self.problem_to_base = {name: name for name in test_names}
            self.num_unique_problems = self.num_items

        # Parameters are registered in a fixed order: a, c, difficulty, theta.
        self.a_raw = nn.Parameter(torch.full((self.num_items,), a_init, device=self.device))
        self.c_raw = nn.Parameter(torch.full((self.num_unique_problems,), c_init, device=self.device))

        if split_difficulty:
            self.b_base_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_base_init, device=self.device))
            self.b_text_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
            self.b_image_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
            self.b_synergy_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_other_init, device=self.device))
        else:
            self.b_raw = nn.Parameter(torch.full((self.num_unique_problems,), difficulty_base_init, device=self.device))

        # Position of each item's shared parameters. setdefault keeps the first
        # occurrence, matching what list.index would return for duplicates.
        first_position = {}
        for position, base_name in enumerate(self.problem_base_names):
            first_position.setdefault(base_name, position)
        self.param_indices = torch.tensor(
            [first_position[self.problem_to_base[name]] for name in test_names],
            dtype=torch.long,
            device=self.device,
        )

        # One ability value per student, or four components when split.
        output_dim = 4 if split_ability else 1
        self.theta = nn.Parameter(torch.full((self.num_students, output_dim), theta_init, device=self.device))

        # Per-item [text, image] indicator derived from the tag in the name.
        tags = []
        for name in self.test_names:
            if "<no_info>" in name:
                tags.append([0.0, 0.0])
            elif "<no_image>" in name:
                tags.append([1.0, 0.0])
            elif "<no_question>" in name:
                tags.append([0.0, 1.0])
            else:
                tags.append([1.0, 1.0])
        self.register_buffer("mask", torch.tensor(tags, dtype=torch.float32))

        self.to(self.device)
        self.dataset = ResponseDataset(response_data, train_mask)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=0)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.max_epochs = max_epochs
        self.loss_history = []

    def num_unique(self):
        """Return how many distinct base problems ``self.test_names`` contains.

        A leading ``<no_question>``, ``<no_image>`` or ``<no_info>`` tag is
        stripped from each name before names are compared.
        """
        tag_prefix = re.compile(r"^(?:<no_question>|<no_image>|<no_info>)")
        # A set gives O(n) de-duplication; only the count is needed.
        return len({tag_prefix.sub("", name) for name in self.test_names})

    def get_parameters(self):
        """Return the per-item parameter tensors.

        Returns:
            ``(a, b_base, b_text, b_image, b_synergy, c)`` when
            ``split_difficulty`` is set, otherwise ``(a, b, c)``. ``a`` is
            ``exp(a_raw)``; every other tensor is the raw parameter gathered
            through ``param_indices``.
        """
        item_of = self.param_indices
        discrimination = torch.exp(self.a_raw)  # shape: (num_items,)
        guessing = self.c_raw[item_of]

        if not self.split_difficulty:
            return discrimination, self.b_raw[item_of], guessing

        difficulty_terms = tuple(
            raw[item_of]
            for raw in (self.b_base_raw, self.b_text_raw, self.b_image_raw, self.b_synergy_raw)
        )
        return (discrimination, *difficulty_terms, guessing)

    def compute_b_full(self, *args):
        """Combine difficulty terms into one difficulty value per item.

        Args:
            *args: ``(b_base, b_text, b_image, b_synergy)`` when
                ``split_difficulty`` is set, otherwise a single tensor ``b``.

        Returns:
            Without the split, ``b`` moved to ``self.device``. With the split,
            ``b_base`` plus the text, image and synergy terms, each gated by
            the item's text/image indicators from ``self.mask``.
        """
        if not self.split_difficulty:
            return args[0].to(self.device)

        base, text, image, synergy = (term.to(self.device) for term in args)
        has_text = self.mask[:, 0].to(self.device)
        has_image = self.mask[:, 1].to(self.device)
        # Summation order is kept fixed so results match bit-for-bit.
        return has_image * image + has_text * text + base + has_text * has_image * synergy

    def compute_theta_full(self, theta_components):
        """Expand ability components into an ability value per item.

        Args:
            theta_components: Tensor with one row per student. With
                ``split_ability`` the first four columns are read as base,
                text, image and synergy components.

        Returns:
            ``theta_components`` unchanged when ``split_ability`` is off.
            Otherwise a (students x items) tensor in which the text, image
            and synergy components are gated by each item's indicators in
            ``self.mask``.
        """
        if not self.split_ability:
            return theta_components

        has_text = self.mask[:, 0]
        has_image = self.mask[:, 1]
        # Each component becomes a column vector so it broadcasts over items.
        base, text, image, synergy = (
            theta_components[:, column].unsqueeze(1) for column in range(4)
        )
        return base + text * has_text + image * has_image + synergy * has_text * has_image

    def forward(self, responses, response_mask, student_idx):
        """Return the mean negative log-likelihood of the selected responses.

        Args:
            responses: (batch x items) response tensor.
            response_mask: Boolean tensor of the same shape selecting the
                entries that contribute to the loss.
            student_idx: Row indices into ``self.theta`` for this batch.

        Returns:
            Scalar tensor: the mean Bernoulli negative log-likelihood over the
            selected entries, or ``0.0`` when nothing is selected.
        """
        ability = self.compute_theta_full(self.theta[student_idx])

        # The first and last entries are a and c; whatever lies between them
        # is exactly the argument list compute_b_full expects.
        a, *difficulty_terms, c = self.get_parameters()
        difficulty = self.compute_b_full(*difficulty_terms)

        logits = torch.clamp((ability - difficulty) * a, min=-15, max=15)
        denominator = 1 + torch.exp(-logits)
        if self.use_guessing:
            prob = c + (1 - c) / denominator
        else:
            prob = 1 / denominator
        prob = torch.clamp(prob, self.eps, 1 - self.eps)

        targets = responses[response_mask]
        if targets.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        selected = prob[response_mask]
        log_likelihood = targets * torch.log(selected) + (1 - targets) * torch.log(1 - selected)
        return -log_likelihood.mean()

    def get_estimates(self):
        """Collect the current parameter values into plain dictionaries.

        Returns:
            A dict with keys ``"discrimination"``, ``"difficulty_full"`` and
            ``"guessing"`` (each mapping item name to value) and ``"theta"``
            (mapping student name to a dict of ability values). With
            ``split_difficulty`` the keys ``"difficulty_base"``,
            ``"difficulty_text"``, ``"difficulty_image"`` and
            ``"difficulty_synergy"`` are added.

            With ``split_ability`` each student maps to the four raw
            components; otherwise to ``{"theta", "theta_full"}`` as Python
            floats.
        """
        theta_components = self.theta.detach().cpu().numpy()

        if self.split_ability:
            theta_dict = {
                self.student_names[i]: {
                    "theta_base": theta_components[i, 0],
                    "theta_text": theta_components[i, 1],
                    "theta_image": theta_components[i, 2],
                    "theta_synergy": theta_components[i, 3],
                }
                for i in range(self.num_students)
            }
        else:
            theta_full = self.compute_theta_full(self.theta.detach()).cpu().numpy()
            theta_dict = {
                self.student_names[i]: {
                    "theta": float(theta_components[i, 0]),
                    # Index the single column explicitly: float() on a
                    # 1-element array is deprecated in recent NumPy.
                    "theta_full": float(theta_full[i, 0]),
                }
                for i in range(self.num_students)
            }

        with torch.no_grad():
            a_t, *difficulty_terms, c_t = self.get_parameters()
            # Combine on tensors directly; no numpy round trip is needed.
            b_full = self.compute_b_full(*difficulty_terms).detach().cpu().numpy()

        a = a_t.detach().cpu().numpy()
        c = c_t.detach().cpu().numpy()

        result = {
            "discrimination": dict(zip(self.test_names, a)),
            "difficulty_full": dict(zip(self.test_names, b_full)),
            "guessing": dict(zip(self.test_names, c)),
            "theta": theta_dict,
        }

        if self.split_difficulty:
            b_base, b_text, b_image, b_synergy = (
                term.detach().cpu().numpy() for term in difficulty_terms
            )
            result.update({
                "difficulty_base": dict(zip(self.test_names, b_base)),
                "difficulty_text": dict(zip(self.test_names, b_text)),
                "difficulty_image": dict(zip(self.test_names, b_image)),
                "difficulty_synergy": dict(zip(self.test_names, b_synergy)),
            })

        return result
    
    def clamp_parameters(self):
        """Project parameters back into their allowed ranges, in place.

        Only two constraints are applied here:

        * with ``split_ability``, the first ability column is forced to 0;
        * ``c_raw`` is clamped to ``[0, 1]`` when ``enable_abs_clamp`` is set
          and to ``[0, inf)`` otherwise.

        No clamp is applied to ``a_raw`` or to any difficulty parameter.
        """
        with torch.no_grad():
            if self.split_ability:
                # Equal lower and upper bounds pin the column to zero.
                self.theta.data[:, 0].clamp_(0, 0)

            guessing = self.c_raw.data
            if self.enable_abs_clamp:
                guessing.clamp_(0, 1.0)
            else:
                guessing.clamp_(min=0)
                    
    def fit(self):
        """Train all parameters with ``self.optimizer`` and return estimates.

        Runs at most ``self.max_epochs`` epochs over ``self.dataloader``.
        Batches whose loss is not strictly positive are skipped. After each
        optimizer step the gradient-clipped update is followed by
        ``clamp_parameters``. The per-epoch mean batch loss is appended to
        ``self.loss_history``. Training stops early once 100 epochs have
        passed without a new best epoch loss.

        Returns:
            The dictionary produced by ``get_estimates``.
        """
        patience = 100
        best_loss, best_epoch = float("inf"), 0

        for epoch in range(self.max_epochs):
            running_total = 0.0
            used_batches = 0

            for responses, response_mask, student_idx in self.dataloader:
                self.optimizer.zero_grad()
                loss = self.forward(
                    responses.to(self.device),
                    response_mask.to(self.device),
                    student_idx.to(self.device),
                )

                loss_value = loss.item()
                # Written as "> 0" so that a NaN loss is skipped as well.
                if loss_value > 0:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                    self.optimizer.step()
                    self.clamp_parameters()

                    running_total += loss_value
                    used_batches += 1

            if used_batches == 0:
                continue

            avg_loss = running_total / used_batches
            self.loss_history.append(avg_loss)

            if avg_loss < best_loss:
                best_loss, best_epoch = avg_loss, epoch

            if epoch - best_epoch >= patience:
                print(f"Early stopping triggered. No improvement for {patience} epochs.")
                print(f"Best loss was {best_loss:.4f} at epoch {best_epoch + 1}")
                break

            if epoch == 0 or (epoch + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{self.max_epochs}, Loss: {avg_loss:.4f}")

        return self.get_estimates()

    def evaluate_predictions(self) -> Dict[str, float]:
        """Score the model on held-out responses.

        The evaluated cells are those where the dataset's test split and its
        valid-response mask are both true.

        Returns:
            An empty dict when there is no such cell. Otherwise a dict with
            ``roc_auc``, ``accuracy``, ``precision``, ``recall``, ``f1``
            (predictions thresholded at 0.5), ``roc_auc_shuffle`` and
            ``roc_auc_normal`` (AUC restricted to items whose name does or
            does not contain ``"shuffle"``/``"from"``; ``0`` when scikit-learn
            raises ``ValueError`` for that subset) and ``perplexity``
            (``exp`` of the mean negative log-likelihood).

        The module is put back into training mode on every exit path,
        including the empty result and exceptions.
        """
        self.eval()
        try:
            with torch.no_grad():
                test_mask = self.dataset.test_indices & self.dataset.valid_data
                # nonzero enumerates the true cells in row-major order, the
                # same order a nested (student, item) loop would produce.
                rows, cols = torch.nonzero(test_mask, as_tuple=True)
                if rows.numel() == 0:
                    return {}

                rows_np = rows.cpu().numpy()
                cols_np = cols.cpu().numpy()
                student_indices = rows.to(self.device)
                item_indices = cols.to(self.device)

                theta_components = self.theta[student_indices]
                masks = self.mask[item_indices]
                r_text = masks[:, 0]
                r_image = masks[:, 1]
                if self.split_ability:
                    theta_base = theta_components[:, 0]
                    theta_text = theta_components[:, 1]
                    theta_image = theta_components[:, 2]
                    theta_synergy = theta_components[:, 3]
                    theta_full = (theta_base + theta_text * r_text + theta_image * r_image
                                  + theta_synergy * r_text * r_image)
                else:
                    theta_full = theta_components.squeeze(-1)

                if self.split_difficulty:
                    a, b_base, b_text, b_image, b_synergy, c = self.get_parameters()
                    a = a[item_indices]
                    b_full = (r_image * b_image[item_indices] +
                              r_text * b_text[item_indices] +
                              b_base[item_indices] +
                              r_text * r_image * b_synergy[item_indices])
                else:
                    a, b, c = self.get_parameters()
                    a = a[item_indices]
                    b_full = b[item_indices]

                z = (theta_full - b_full) * a
                z = torch.clamp(z, min=-15, max=15)
                if self.use_guessing:
                    probs = c[item_indices] + (1 - c[item_indices]) / (1 + torch.exp(-z))
                else:
                    probs = 1 / (1 + torch.exp(-z))
                probs = torch.clamp(probs, self.eps, 1 - self.eps)

                pred_binary = (probs > 0.5).float()
                true_responses = torch.tensor(self.data[rows_np, cols_np], device=self.device)

                probs_np = probs.cpu().numpy()
                true_np = true_responses.cpu().numpy()
                pred_binary_np = pred_binary.cpu().numpy()

                # Imported lazily so the empty-result path needs no sklearn.
                from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score

                metrics = {
                    "roc_auc": roc_auc_score(true_np, probs_np),
                    "accuracy": accuracy_score(true_np, pred_binary_np),
                    "precision": precision_score(true_np, pred_binary_np),
                    "recall": recall_score(true_np, pred_binary_np),
                    "f1": f1_score(true_np, pred_binary_np),
                }

                # Classify each item once, then gather per evaluated pair.
                item_is_shuffle = np.array(
                    ["shuffle" in name or "from" in name for name in self.test_names],
                    dtype=bool,
                )
                mask_shuffle = item_is_shuffle[cols_np]
                mask_normal = ~mask_shuffle

                # AUC over the "shuffle"/"from" items only.
                try:
                    metrics["roc_auc_shuffle"] = roc_auc_score(
                        true_np[mask_shuffle], probs_np[mask_shuffle]
                    )
                except ValueError:
                    metrics["roc_auc_shuffle"] = 0

                # AUC over the remaining items.
                try:
                    metrics["roc_auc_normal"] = roc_auc_score(
                        true_np[mask_normal], probs_np[mask_normal]
                    )
                except ValueError:
                    metrics["roc_auc_normal"] = 0

                # Perplexity = exp(mean negative log-likelihood).
                nll = -(true_responses * torch.log(probs) +
                        (1 - true_responses) * torch.log(1 - probs)).mean()
                metrics["perplexity"] = torch.exp(nll).item()

                return metrics
        finally:
            self.train()

    def predict(self):
        """Return predicted success probabilities for every student and item.

        Switches the module to eval mode and does not switch it back.

        Returns:
            NumPy array of shape (students x items) with probabilities
            clamped to ``[eps, 1 - eps]``.
        """
        self.eval()
        with torch.no_grad():
            ability = self.compute_theta_full(self.theta)

            a, *difficulty_terms, c = self.get_parameters()
            difficulty = self.compute_b_full(*difficulty_terms)

            # Give 1-D item vectors a leading axis so they broadcast per row.
            a, difficulty, c = (
                vec.unsqueeze(0) if vec.dim() == 1 else vec
                for vec in (a, difficulty, c)
            )

            logits = torch.clamp((ability - difficulty) * a, min=-15, max=15)
            denominator = 1 + torch.exp(-logits)
            if self.use_guessing:
                prob = c + (1 - c) / denominator
            else:
                prob = 1 / denominator
            return torch.clamp(prob, self.eps, 1 - self.eps).cpu().numpy()
    
    def compute_item_fisher_information(self, model_name: str) -> Dict[str, float]:
        """Return each item's Fisher information summed over all students.

        For every (student, item) pair the information is
        ``(dp/dtheta)**2 / (p * (1 - p))``, where ``p`` is the response
        probability (including the guessing term when ``use_guessing`` is
        set). The logit is clamped to ``[-15, 15]``, as in the other methods
        of this class, so a saturated sigmoid cannot produce ``0 / 0``.

        Switches the module to eval mode and does not switch it back.

        Args:
            model_name: Used only in the progress message.

        Returns:
            Dict mapping each item name to its summed information.
        """
        print(f"Computing item Fisher information for model: {model_name}")
        self.eval()

        with torch.no_grad():
            theta_full = self.compute_theta_full(self.theta)

            a, *difficulty_terms, c = self.get_parameters()
            b_full = self.compute_b_full(*difficulty_terms)  # shape: (num_items,)

            # Give 1-D item vectors a leading axis so they broadcast per row.
            if a.dim() == 1:
                a = a.unsqueeze(0)
            if b_full.dim() == 1:
                b_full = b_full.unsqueeze(0)
            if c.dim() == 1:
                c = c.unsqueeze(0)

            z = (theta_full - b_full) * a  # shape: (num_students, num_items)
            # Without this clamp sigma rounds to exactly 0 or 1 for large |z|
            # and the ratio below becomes NaN.
            z = torch.clamp(z, min=-15, max=15)
            sigma = 1 / (1 + torch.exp(-z))

            if self.use_guessing:
                # p = c + (1 - c) * sigma(z)
                p = c + (1 - c) * sigma
                dp_dtheta = (1 - c) * a * sigma * (1 - sigma)
            else:
                p = sigma
                dp_dtheta = a * sigma * (1 - sigma)

            info = (dp_dtheta ** 2) / (p * (1 - p))
            item_info = info.sum(dim=0)  # shape: (num_items,)

            item_info_dict = {
                name: float(val)
                for name, val in zip(self.test_names, item_info.cpu().numpy())
            }

        return item_info_dict

    def update_single_theta(
        self,
        model_name: str,
        problem_names: List[str],
        lr: float = 1e-3,
        max_epochs: int = 1000,
        patience: int = 50
    ):
        """Refit only ``theta`` on a subset of items, keeping the rest fixed.

        Every parameter other than ``theta`` has ``requires_grad`` switched
        off for the duration of the call. The original flags are restored
        afterwards, also when training raises. A fresh Adam optimizer over
        ``theta`` minimizes the masked negative log-likelihood restricted to
        the selected items.

        Args:
            model_name: Used only in the progress message.
            problem_names: Item names (entries of ``self.test_names``) whose
                responses are used for the update.
            lr: Learning rate of the temporary optimizer.
            max_epochs: Upper bound on epochs.
            patience: Stop after this many consecutive epochs without a new
                best epoch loss.

        Returns:
            The dictionary produced by ``get_estimates``. If no item name
            matches, a warning is printed and nothing is trained.
        """
        print(f"Updating single theta for model: {model_name}")

        wanted = set(problem_names)  # O(1) membership per item
        selected_indices = [i for i, name in enumerate(self.test_names) if name in wanted]
        if not selected_indices:
            print("Warning: No matching problems found for the specified problem names.")
            return self.get_estimates()

        # Freeze everything except theta, remembering the previous flags.
        backup_requires_grad = {}
        for name, param in self.named_parameters():
            if name != "theta":
                backup_requires_grad[name] = param.requires_grad
                param.requires_grad = False

        try:
            theta_optimizer = optim.Adam([self.theta], lr=lr)

            # Item parameters are frozen, so their selected slices are
            # computed once instead of once per batch.
            with torch.no_grad():
                a, *difficulty_terms, c = self.get_parameters()
                b_full = self.compute_b_full(*difficulty_terms)
                a_sub = a[selected_indices]
                b_sub = b_full[selected_indices]
                c_sub = c[selected_indices]

            best_loss = float("inf")
            no_improve_counter = 0

            self.train()
            for epoch in range(max_epochs):
                epoch_loss = 0.0
                batch_count = 0

                for responses, train_mask, student_idx in self.dataloader:
                    responses = responses.to(self.device)
                    train_mask = train_mask.to(self.device).bool()
                    student_idx = student_idx.to(self.device)

                    responses_sub = responses[:, selected_indices]
                    train_mask_sub = train_mask[:, selected_indices]

                    theta_components = self.theta[student_idx]
                    if self.split_ability:
                        theta_sub = self.compute_theta_full(theta_components)[:, selected_indices]
                    else:
                        # The single ability column is repeated per selected item.
                        theta_sub = theta_components.expand(-1, len(selected_indices))

                    z = (theta_sub - b_sub) * a_sub
                    z = torch.clamp(z, min=-15, max=15)
                    if self.use_guessing:
                        p = c_sub + (1 - c_sub) / (1 + torch.exp(-z))
                    else:
                        p = 1 / (1 + torch.exp(-z))
                    p = torch.clamp(p, self.eps, 1 - self.eps)

                    valid_responses = responses_sub[train_mask_sub]
                    if valid_responses.numel() == 0:
                        continue
                    valid_p = p[train_mask_sub]
                    loss = -(valid_responses * torch.log(valid_p)
                             + (1 - valid_responses) * torch.log(1 - valid_p)).mean()

                    if not torch.isfinite(loss):
                        print("Warning: Loss is not finite. Skipping this batch.")
                        continue

                    theta_optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_([self.theta], max_norm=1.0)
                    theta_optimizer.step()

                    epoch_loss += loss.item()
                    batch_count += 1

                if batch_count == 0:
                    print("[update_single_theta] no valid responses in this batch.")
                    break

                avg_loss = epoch_loss / batch_count
                if epoch % 100 == 0:
                    print(f"[update_single_theta] Epoch {epoch+1}/{max_epochs}, Loss: {avg_loss:.6f}")

                if avg_loss < best_loss:
                    best_loss = avg_loss
                    no_improve_counter = 0
                else:
                    no_improve_counter += 1

                if no_improve_counter >= patience:
                    print(f"[update_single_theta] Early stopping triggered at epoch {epoch+1}.")
                    break
        finally:
            # Restore the flags even when training raised, so the model is
            # never left with its item parameters frozen.
            for name, param in self.named_parameters():
                if name in backup_requires_grad:
                    param.requires_grad = backup_requires_grad[name]

        self.train()

        return self.get_estimates()

    def compute_item_kli(self, student_id: int, n_points: int = 5) -> Dict[str, float]:
        """Return a Kullback-Leibler information index per item for a student.

        For each item, the KL divergence between the Bernoulli response
        distribution at the student's current ability and the one at an
        ability ``x`` is averaged over ``x`` drawn from a normal distribution
        centred on the current ability with a fixed standard deviation of
        ``1 / sqrt(5)``. The average is evaluated with Gauss-Hermite
        quadrature.

        Switches the module to eval mode and does not switch it back.

        Args:
            student_id: Row index of the student (a Python or NumPy integer in
                ``[0, num_students)``) or a name from ``self.student_names``.
            n_points: Number of quadrature nodes.

        Returns:
            Dict mapping each item name to its index value.

        Raises:
            ValueError: If ``student_id`` is neither a valid index nor a known
                student name.
        """
        from numpy.polynomial.hermite import hermgauss

        self.eval()

        if isinstance(student_id, (int, np.integer)) and 0 <= student_id < self.num_students:
            row = int(student_id)
        elif isinstance(student_id, str) and student_id in self.student_names:
            row = self.student_names.index(student_id)
        else:
            raise ValueError(f"Invalid student_id: {student_id}")

        points, weights = hermgauss(n_points)

        with torch.no_grad():
            theta_components = self.theta[row:row + 1]
            current_theta_full = self.compute_theta_full(theta_components)

            a, *difficulty_terms, c = self.get_parameters()
            b_full = self.compute_b_full(*difficulty_terms)

            z_current = torch.clamp((current_theta_full - b_full) * a, min=-15, max=15)
            if self.use_guessing:
                p_current = c + (1 - c) / (1 + torch.exp(-z_current))
            else:
                p_current = 1 / (1 + torch.exp(-z_current))
            p_current = torch.clamp(p_current, self.eps, 1 - self.eps)  # (1, num_items)

            def as_float64(tensor):
                # float32 -> float64 is exact, so this matches reading each
                # entry with .item().
                return tensor.detach().double().cpu().numpy()

            p_hat = as_float64(p_current[0])
            # Without split ability theta is (1, 1); expand it to one value
            # per item so both modes share the same code path.
            theta_current = as_float64(current_theta_full.expand_as(p_current)[0])
            alpha = as_float64(a)
            beta = as_float64(b_full)
            guess = as_float64(c) if self.use_guessing else None

        q_hat = 1 - p_hat
        std_dev = 1.0 / np.sqrt(5)

        # Quadrature abscissae for every item: shape (num_items, n_points).
        x = std_dev * np.sqrt(2) * points[None, :] + theta_current[:, None]
        z = np.clip(alpha[:, None] * (x - beta[:, None]), -15, 15)
        if self.use_guessing:
            p = guess[:, None] + (1 - guess[:, None]) / (1 + np.exp(-z))
        else:
            p = 1 / (1 + np.exp(-z))
        # After this clip both p and 1 - p are strictly positive.
        p = np.clip(p, self.eps, 1 - self.eps)
        q = 1 - p

        kli = p_hat[:, None] * np.log(p_hat[:, None] / p) + q_hat[:, None] * np.log(q_hat[:, None] / q)
        totals = (kli * weights[None, :]).sum(axis=1) / np.sqrt(np.pi)

        return dict(zip(self.test_names, totals))