import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.init as init
import numpy as np
from typing import List, Dict
from sklearn.metrics import (
    roc_auc_score, accuracy_score,
    precision_score, recall_score, f1_score
)

class ResponseDataset(Dataset):
    """Per-student (row-wise) view of a response matrix and its split mask.

    Indexing with ``idx`` yields ``(responses, is_train, idx)``:
      * ``responses`` -- row ``idx`` of the response matrix as float32,
      * ``is_train``  -- boolean vector, True where the split mask equals 1,
      * ``idx``       -- the row index itself, so callers can look up
        per-row parameters for the batch.
    """

    def __init__(self, response_data: np.ndarray, train_mask: np.ndarray):
        self.data = torch.tensor(response_data, dtype=torch.float32)
        self.train_mask = torch.tensor(train_mask, dtype=torch.int8)

        # Boolean views of the integer split mask: 1 marks training cells,
        # 0 marks test cells.
        self.train_indices = self.train_mask.eq(1)
        self.test_indices = self.train_mask.eq(0)

        # Cells holding -1 are missing responses.
        self.valid_data = self.data.ne(-1)

    def __len__(self):
        # One sample per row of the response matrix.
        return self.data.shape[0]

    def __getitem__(self, idx):
        row_responses = self.data[idx]
        row_is_train = self.train_indices[idx]
        return row_responses, row_is_train, idx


class MultiDim3PLIRT(nn.Module):
    """Multidimensional three-parameter logistic (3PL) IRT model trained with PyTorch.

    The probability that student ``s`` answers item ``i`` correctly is

        P = c_i + (1 - c_i) * sigmoid(a_i . theta_s - b_i)

    where ``a_i`` is an ``n_dims``-vector of discriminations, ``b_i`` a scalar
    difficulty, ``c_i`` a guessing floor and ``theta_s`` the student's
    ``n_dims``-dimensional ability. With ``use_guessing=False`` the ``c`` term
    is dropped and the model is a multidimensional 2PL.

    Conventions used throughout the class:
      * ``response_data`` is a (num_students, num_items) matrix. ``-1`` marks
        a missing response and is ignored everywhere. Other values are used
        as Bernoulli targets (presumably 0/1).
      * ``train_mask`` has the same shape. ``1`` marks training cells, ``0``
        test cells (``evaluate_predictions``) and ``2`` validation cells
        (``evaluate_validation``).
      * The logit ``a . theta - b`` is clipped to ``[-_Z_CLIP, _Z_CLIP]``.
    """

    # Logits are clipped to [-_Z_CLIP, _Z_CLIP] before the sigmoid.
    _Z_CLIP = 15.0

    def __init__(
        self,
        response_data: np.ndarray,
        train_mask: np.ndarray,
        student_names: List[str],
        test_names: List[str],
        n_dims: int,
        lr: float = 1e-3,
        batch_size: int = 64,
        max_epochs: int = 1000,
        device: str = None,
        eps: float = 1e-6,
        theta_init: float = 0.0,
        a_init: float = 0.2,
        a_scale: float = 3.0,
        b_init: float = 0.0,
        b_max: float = 3.0,
        c_init: float = 0.25,
        use_guessing: bool = True
    ):
        """Create the parameters, the training DataLoader and an Adam optimizer.

        Args:
            response_data: (num_students, num_items) responses; ``-1`` = missing.
            train_mask: same shape; 1 = train, 0 = test, 2 = validation.
            student_names: row labels (keys of ``theta`` in the estimates).
            test_names: column/item labels.
            n_dims: dimensionality of the ability and discrimination vectors.
            lr: Adam learning rate used by ``fit``.
            batch_size: number of student rows per mini-batch.
            max_epochs: upper bound on the number of epochs in ``fit``.
            device: torch device string; defaults to CUDA when available, else CPU.
            eps: probabilities are clamped to ``[eps, 1 - eps]`` before logs.
            theta_init: multiplier on a standard-normal ability init
                (the default 0.0 starts every ability at zero).
            a_init: multiplier on a standard-normal init of the *log*
                discriminations (``a = exp(a_raw)``).
            a_scale: stored as ``self.a_scale``; not used by this class.
            b_init: initial difficulty of every item.
            b_max: stored as ``self.b_max``; not used by this class.
            c_init: initial raw guessing value of every item.
            use_guessing: if False, the guessing term is omitted (2PL).
        """
        super().__init__()
        self.device = torch.device(device if device else
                                   ("cuda" if torch.cuda.is_available() else "cpu"))
        print(f"Using device: {self.device}")

        self.response_data = response_data
        self.train_mask_np = train_mask
        self.num_students, self.num_items = response_data.shape
        self.student_names = student_names
        self.test_names    = test_names
        self.n_dims        = n_dims
        self.eps           = eps
        # NOTE(review): a_scale and b_max are stored but no method of this
        # class reads them (get_parameters uses exp(a_raw) and the raw b).
        self.a_scale       = a_scale
        self.b_max         = b_max
        self.use_guessing  = use_guessing

        # Discriminations live in log space so that a = exp(a_raw) > 0.
        self.a_raw = nn.Parameter(
            torch.randn(self.num_items, n_dims, device=self.device) * a_init
        )

        self.b_raw = nn.Parameter(
            torch.full((self.num_items,), b_init, device=self.device)
        )

        # Clamped to [0, 1] in get_parameters.
        self.c_raw = nn.Parameter(
            torch.full((self.num_items,), c_init, device=self.device)
        )

        self.theta = nn.Parameter(
            torch.randn(self.num_students, n_dims, device=self.device) * theta_init
        )

        self.dataset    = ResponseDataset(response_data, train_mask)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        self.optimizer  = optim.Adam(self.parameters(), lr=lr)
        self.max_epochs = max_epochs

        self.to(self.device)

    def get_parameters(self):
        """Return the constrained item parameters ``(a, b, c)``.

        Returns:
            a: (I, D) discriminations, ``exp(a_raw)`` (strictly positive).
            b: (I,) difficulties, unconstrained.
            c: (I,) guessing values, ``c_raw`` clamped to ``[0, 1]``. ``c`` is
               returned even when ``use_guessing`` is False; the model's
               probability and information computations ignore it then.
        """
        a = torch.exp(self.a_raw)  # exp is already positive; no clamp needed
        b = self.b_raw
        c = torch.clamp(self.c_raw, min=0.0, max=1.0)
        return a, b, c

    def _prob(self, logits, c):
        """Clip ``logits`` and map them to P(correct).

        ``c`` must broadcast against ``logits``. It is ignored when
        ``use_guessing`` is False.
        """
        sig = torch.sigmoid(torch.clamp(logits, -self._Z_CLIP, self._Z_CLIP))
        if self.use_guessing:
            return c + (1 - c) * sig
        return sig

    @staticmethod
    def _bernoulli_nll(y, p):
        """Return the mean Bernoulli negative log-likelihood of targets ``y`` under ``p``."""
        return -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

    def forward(self, responses, train_mask, student_idx):
        """Compute the mean negative log-likelihood of a mini-batch's training cells.

        Args:
            responses: (B, I) float responses, ``-1`` = missing.
            train_mask: (B, I) boolean, True for training cells.
            student_idx: (B,) row indices into ``theta``.

        Returns:
            Scalar loss averaged over cells that are both training and observed,
            or a constant ``0.0`` tensor when there are none.
        """
        theta_batch = self.theta[student_idx]  # (B, D)
        a, b, c = self.get_parameters()        # a:(I,D), b,c:(I,)

        dot = torch.sum(a.unsqueeze(0) * theta_batch.unsqueeze(1), dim=2)  # (B, I)
        p = self._prob(dot - b.unsqueeze(0), c)
        p = torch.clamp(p, self.eps, 1 - self.eps)  # keep log() finite

        valid = train_mask & (responses != -1)
        if not valid.any():
            return torch.tensor(0.0, device=self.device)
        return self._bernoulli_nll(responses[valid], p[valid])

    def fit(self, patience: int = 100) -> Dict[str, Dict]:
        """Train every parameter with mini-batch Adam on the training cells.

        Args:
            patience: stop once the epoch-average loss has not improved for this
                many epochs (the default matches the former hard-coded value).

        Returns:
            ``get_estimates()`` after training.
        """
        self.train()
        best_loss  = float('inf')
        best_epoch = 0

        for epoch in range(1, self.max_epochs + 1):
            total_loss = 0.0
            count      = 0
            for responses, train_mask, stu_idx in self.dataloader:
                responses   = responses.to(self.device)
                train_mask  = train_mask.to(self.device)
                stu_idx     = stu_idx.to(self.device)

                self.optimizer.zero_grad()
                loss = self.forward(responses, train_mask, stu_idx)
                loss_value = loss.item()
                # forward returns a constant 0.0 when the batch has no usable
                # training cells; such batches are skipped.
                if loss_value > 0:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                    self.optimizer.step()
                    total_loss += loss_value
                    count += 1

            if count > 0:
                avg = total_loss / count
                if avg < best_loss:
                    best_loss, best_epoch = avg, epoch
                if epoch - best_epoch >= patience:
                    print(f"Early stopping at epoch {epoch}")
                    break
                if epoch == 1 or epoch % 100 == 0:
                    print(f"Epoch {epoch}/{self.max_epochs}, Loss: {avg:.4f}")

        return self.get_estimates()

    def get_estimates(self) -> Dict[str, Dict]:
        """Return the current estimates as plain Python values keyed by name.

        Keys:
            ``discrimination``: item name -> list of D floats.
            ``difficulty`` and ``guessing``: item name -> float.
            ``theta``: student name -> list of D floats.
        """
        a, b, c = [p.detach().cpu().numpy() for p in self.get_parameters()]
        theta   = self.theta.detach().cpu().numpy()

        return {
            "discrimination": {self.test_names[i]: a[i].tolist()   for i in range(self.num_items)},
            "difficulty":     {self.test_names[i]: float(b[i])     for i in range(self.num_items)},
            "guessing":       {self.test_names[i]: float(c[i])     for i in range(self.num_items)},
            "theta":          {self.student_names[i]: theta[i].tolist() for i in range(self.num_students)},
        }

    def _split_pairs(self, split_value: int):
        """Return ``(rows, cols)`` of observed cells whose split-mask value is ``split_value``.

        The indices come out in row-major order, the same order as a nested
        student/item loop.
        """
        mask = (self.train_mask_np == split_value) & (self.response_data != -1)
        return np.nonzero(mask)

    def _evaluate_split(self, split_value: int) -> Dict[str, float]:
        """Score predictions on one split.

        Reports ROC-AUC plus accuracy, precision, recall and F1 at a 0.5
        threshold. Returns ``{}`` when the split has no observed cells.
        ROC-AUC is undefined when the true labels contain a single class; it
        is reported as NaN there instead of letting ``roc_auc_score`` raise.
        """
        self.eval()
        with torch.no_grad():
            rows, cols = self._split_pairs(split_value)
            if rows.size == 0:
                return {}

            stud_idx = torch.as_tensor(rows, dtype=torch.long, device=self.device)
            item_idx = torch.as_tensor(cols, dtype=torch.long, device=self.device)
            a, b, c = self.get_parameters()
            dot = torch.sum(a[item_idx] * self.theta[stud_idx], dim=1)
            probs = self._prob(dot - b[item_idx], c[item_idx])
            probs = torch.clamp(probs, self.eps, 1 - self.eps)
            probs_np = probs.cpu().numpy()

        true_np = np.asarray(self.response_data[rows, cols], dtype=np.float32)
        pred_np = (probs_np > 0.5).astype(np.float32)
        if np.unique(true_np).size > 1:
            roc_auc = roc_auc_score(true_np, probs_np)
        else:
            roc_auc = float("nan")

        return {
            "roc_auc":   roc_auc,
            "accuracy":  accuracy_score(true_np, pred_np),
            "precision": precision_score(true_np, pred_np),
            "recall":    recall_score(true_np, pred_np),
            "f1":        f1_score(true_np, pred_np),
        }

    def evaluate_predictions(self) -> Dict[str, float]:
        """Score predictions on observed test cells (``train_mask == 0``)."""
        return self._evaluate_split(0)

    def evaluate_validation(self) -> Dict[str, float]:
        """Score predictions on observed validation cells (``train_mask == 2``)."""
        return self._evaluate_split(2)

    def predict(self) -> np.ndarray:
        """Return the (num_students, num_items) matrix of P(correct).

        Unlike the training and evaluation paths, these probabilities are not
        clamped to ``[eps, 1 - eps]``.
        """
        self.eval()
        with torch.no_grad():
            a, b, c = self.get_parameters()
            dot = torch.sum(a.unsqueeze(0) * self.theta.unsqueeze(1), dim=2)  # (S, I)
            return self._prob(dot - b.unsqueeze(0), c).cpu().numpy()

    def compute_item_fisher_information(
        self,
        model: str
    ) -> Dict[str, np.ndarray]:
        """Return per-item Fisher information matrices at one student's current ability.

        For item i, with z = a . theta - b (clipped), sigma = sigmoid(z) and
        P = c + (1 - c) * sigma:

            I_i = (1 - c)^2 * (sigma * (1 - sigma))^2 / (P * (1 - P) + eps) * a a^T

        When ``use_guessing`` is False the model has no guessing term, so c is
        taken as 0. The formula then reduces to the 2PL form
        ``sigma * (1 - sigma) * a a^T``. Previously the stored, unused ``c``
        leaked into this case.

        Args:
            model: a name from ``student_names``; its theta row is used.

        Returns:
            Mapping from item name to a (D, D) numpy array.

        Raises:
            ValueError: if ``model`` is not in ``student_names``.
        """
        if model not in self.student_names:
            raise ValueError(f"Unknown model: {model}")
        student_idx = self.student_names.index(model)

        self.eval()
        with torch.no_grad():
            theta = self.theta[student_idx]  # (D,)
            a, b, c = self.get_parameters()  # a:(I,D), b,c:(I,)
            if not self.use_guessing:
                c = torch.zeros_like(c)

            dot   = torch.sum(a * theta, dim=1)                          # (I,)
            z     = torch.clamp(dot - b, -self._Z_CLIP, self._Z_CLIP)    # (I,)
            sigma = torch.sigmoid(z)                                     # (I,)
            P     = c + (1 - c) * sigma                                  # (I,)

            num    = (1 - c) ** 2 * (sigma * (1 - sigma)) ** 2
            factor = num / (P * (1 - P) + self.eps)  # eps guards P(1-P) -> 0
            outer  = a.unsqueeze(2) * a.unsqueeze(1)                     # (I, D, D)
            info   = (factor[:, None, None] * outer).cpu().numpy()

        return {name: info[i] for i, name in enumerate(self.test_names)}

    def update_single_theta(
        self,
        model_name: str,
        problem_names: List[str],
        lr: float = 1e-3,
        max_epochs: int = 1000,
        patience: int = 50
    ) -> Dict[str, Dict]:
        """Re-fit one student's ability on a subset of items, keeping item parameters fixed.

        Only cells of ``model_name``'s row that are training cells
        (``train_mask == 1``), observed (not ``-1``), and belong to
        ``problem_names`` contribute to the loss. Every parameter except
        ``theta`` has ``requires_grad`` switched off for the call. The original
        flags are restored afterwards, even if training raises.

        Args:
            model_name: a name from ``student_names``.
            problem_names: item names to fit on; unknown names are ignored.
            lr: Adam learning rate for ``theta``.
            max_epochs: maximum number of optimisation steps (one per epoch).
            patience: stop after this many non-improving epochs.

        Returns:
            ``get_estimates()`` after the update. If ``model_name`` is unknown,
            a message is printed and the unchanged estimates are returned.
        """
        try:
            target_idx = self.student_names.index(model_name)
        except ValueError:
            print(f"Model name {model_name} not found")
            return self.get_estimates()

        wanted = set(problem_names)
        selected = [i for i, name in enumerate(self.test_names) if name in wanted]

        backup = {n: p.requires_grad for n, p in self.named_parameters()}
        for n, p in self.named_parameters():
            if n != "theta":
                p.requires_grad = False
        try:
            self.train()
            self._fit_single_theta(target_idx, selected, lr, max_epochs, patience)
        finally:
            for n, p in self.named_parameters():
                p.requires_grad = backup[n]

        return self.get_estimates()

    def _fit_single_theta(self, target_idx, selected, lr, max_epochs, patience):
        """Run the Adam loop behind ``update_single_theta`` for one student row.

        The student's row and the frozen item parameters do not change between
        epochs, so they are read once. The loop then takes one step per epoch.
        """
        if not selected:
            return

        responses, is_train, _ = self.dataset[target_idx]
        resp_sub = responses.to(self.device)[selected]
        mask_sub = is_train.to(self.device)[selected] & (resp_sub != -1)
        if not mask_sub.any():
            return
        y = resp_sub[mask_sub]

        with torch.no_grad():  # item parameters are frozen for this fit
            a, b, c = self.get_parameters()
        a_sub, b_sub, c_sub = a[selected], b[selected], c[selected]

        theta_opt = optim.Adam([self.theta], lr=lr)
        best_loss, no_improve = float('inf'), 0

        for epoch in range(1, max_epochs + 1):
            theta_t = self.theta[target_idx]                           # (D,)
            dot = torch.sum(theta_t.unsqueeze(0) * a_sub, dim=1)       # (K,)
            p = self._prob(dot - b_sub, c_sub)
            p = torch.clamp(p, self.eps, 1 - self.eps)

            loss = self._bernoulli_nll(y, p[mask_sub])
            if not torch.isfinite(loss):
                continue

            theta_opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_([self.theta], 1.0)
            theta_opt.step()

            avg = loss.item()
            if avg < best_loss:
                best_loss, no_improve = avg, 0
            else:
                no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
            if epoch == 1 or epoch % 100 == 0:
                print(f"Epoch {epoch}, Loss: {avg:.6f}")
