import numpy as np
from scipy.special import digamma, logsumexp

class PairwiseConflictEM:
    """Variational-EM estimator of per-document correctness from pairwise conflicts.

    Given a pairwise conflict matrix C, estimates for every document i the
    posterior probability y_i that the document is "correct".

    - C is binarised into a 0/1 matrix B with a threshold.
    - TT pairs (both correct) follow Bernoulli(p_L): the low-conflict
      family L (p_L is expected to be small).
    - TF/FT pairs (exactly one correct) follow Bernoulli(p_H): the
      high-conflict family H (p_H is expected to be large).
    - FF pairs (both incorrect) follow the mixture
      rho * Bernoulli(p_H) + (1 - rho) * Bernoulli(p_L).
    """

    def __init__(self,
                 prior_mu=0.65, prior_kappa=4.0,
                 mode='variational',     # 'variational' or 'map'
                 max_iter=200, tol=1e-4, damping=0.5,
                 eps=1e-9, init_params=None,
                 normalize=False, seed=None):
        """Store hyper-parameters; no computation happens here.

        Args:
            prior_mu: Mean of the Beta prior on the fraction of correct
                documents; must lie strictly between 0.5 and 1.
            prior_kappa: Concentration of that Beta prior
                (alpha = mu * kappa, beta = (1 - mu) * kappa).
            mode: 'variational' uses the expected log-odds under the Beta
                posterior; 'map' uses the log-odds of the MAP estimate.
            max_iter: Maximum number of EM iterations.
            tol: Stop once the largest per-document change in q falls
                below this value.
            damping: Weight given to the freshly computed q in each update
                (the previous q receives 1 - damping).
            eps: Small constant used to clip probabilities and to guard
                denominators.
            init_params: Optional dict with any of the keys 'threshold',
                'p_L', 'p_H', 'rho' overriding the initial values.
            normalize: If true, min-max scale the off-diagonal entries of
                C to [0, 1] before binarising.
            seed: Seed for the random jitter applied to the initial q.
        """
        assert 0.5 < prior_mu < 1.0, "prior_mu 应 > 0.5（表达‘正确多数’）"
        assert mode in ('variational', 'map')
        self.prior_mu = prior_mu
        self.prior_kappa = prior_kappa
        self.mode = mode
        self.max_iter = max_iter
        self.tol = tol
        self.damping = damping
        self.eps = eps
        self.init_params = init_params or {}
        self.normalize = normalize
        self.rng = np.random.default_rng(seed)

        # Threshold used to binarise the conflict matrix C.
        self.binarize_threshold = self.init_params.get('threshold', 0.5)
        # Initial Bernoulli parameters: probability of observing a conflict
        # in the low-conflict (L) and high-conflict (H) families.
        self.p_L_init = self.init_params.get('p_L', 0.2)
        self.p_H_init = self.init_params.get('p_H', 0.8)

        self.k_ = None
        self.q_ = None
        self.params_ = {}
        self.history_ = []

    @staticmethod
    def _sigmoid(x):
        """Numerically stable element-wise logistic function."""
        # Coerce to float so integer input does not yield a truncated
        # integer output buffer.
        x = np.asarray(x, dtype=float)
        out = np.empty_like(x)
        pos = x >= 0
        neg = ~pos
        # Split by sign so that exp() is only ever called on non-positive
        # arguments and cannot overflow.
        out[pos] = 1 / (1 + np.exp(-x[pos]))
        ex = np.exp(x[neg])
        out[neg] = ex / (1 + ex)
        return out

    def _scale_to_unit(self, C):
        """Min-max scale the off-diagonal entries of C to [0, 1].

        The diagonal is excluded from the min/max and set to 0 in the
        result. If the off-diagonal range is (numerically) zero, the values
        are only clipped to [0, 1].
        """
        k = C.shape[0]
        mask = ~np.eye(k, dtype=bool)
        vals = C[mask]
        if vals.size == 0:
            # No off-diagonal entries (k <= 1): nothing to scale, and
            # np.min / np.max would raise on an empty array.
            return np.zeros_like(C, dtype=float)
        vmin, vmax = np.min(vals), np.max(vals)
        if vmax <= vmin + self.eps:
            D = np.clip(C, 0.0, 1.0)
        else:
            D = (C - vmin) / (vmax - vmin)
        np.fill_diagonal(D, 0.0)
        D = np.clip(D, 0.0, 1.0)
        return D

    def _log_bernoulli_pdf(self, B_matrix, p):
        """Element-wise log-likelihood of the 0/1 matrix B under Bernoulli(p).

        p is clipped to [eps, 1 - eps] so neither logarithm diverges.
        """
        p = np.clip(p, self.eps, 1 - self.eps)
        log_p = np.log(p)
        log_1_p = np.log(1 - p)
        return B_matrix * log_p + (1 - B_matrix) * log_1_p

    def fit(self, C):
        """Run EM on a pairwise conflict matrix.

        Args:
            C: (k, k) array-like; C[i, j] is the conflict strength between
                documents i and j (larger means more conflict). It is
                symmetrised, its diagonal is zeroed, and it is optionally
                min-max scaled before being binarised with the threshold.

        Returns:
            dict with
              'y': length-k array of posterior correctness probabilities,
              'params': fitted p_L, p_H, rho plus the prior's alpha/beta
                  and the mode,
              'history': list of per-iteration dicts
                  (iter, delta_q, rho, p_L, p_H).
        """
        C = np.asarray(C, dtype=float)
        assert C.ndim == 2 and C.shape[0] == C.shape[1], "C 必须是方阵"
        k = C.shape[0]
        self.k_ = k

        # Symmetrise, zero the diagonal, optionally rescale.
        C = 0.5 * (C + C.T)
        np.fill_diagonal(C, 0.0)

        if self.normalize:
            C = self._scale_to_unit(C)

        # Binarise the data.
        B = (C > self.binarize_threshold).astype(float)
        np.fill_diagonal(B, 0.0)

        # Initial Bernoulli and mixture parameters.
        p_L = float(self.p_L_init)
        p_H = float(self.p_H_init)
        rho = float(self.init_params.get('rho', 0.5))
        rho = np.clip(rho, 1e-3, 1 - 1e-3)

        # Beta prior on pi (the fraction of correct documents).
        alpha_pi = self.prior_mu * self.prior_kappa
        beta_pi = (1 - self.prior_mu) * self.prior_kappa

        # Initialise q: documents with a below-average conflict rate start
        # above 0.5; a little random jitter is mixed in.
        row_mean = np.clip(B.mean(axis=1), 0.0, 1.0)
        logits0 = -4.0 * (row_mean - row_mean.mean())
        q = self._sigmoid(logits0)
        q = np.clip(0.9 * q + 0.1 * self.rng.random(k), 1e-3, 1 - 1e-3)

        # Loop invariants: each unordered pair is counted once in the
        # M-step through the strict upper triangle.
        triu_mask = np.triu(np.ones_like(B, dtype=bool), k=1)
        b_vals = B[triu_mask]

        history = []
        for it in range(1, self.max_iter + 1):
            # ---- Prior term (pi) ----
            sum_q = np.sum(q)
            if self.mode == 'variational':
                # E[log pi] - E[log(1 - pi)] under Beta(a_pi, b_pi) is
                # digamma(a_pi) - digamma(b_pi).
                a_pi = alpha_pi + sum_q
                b_pi = beta_pi + (k - sum_q)
                prior_term = digamma(a_pi) - digamma(b_pi)
            else:  # MAP
                pi_map = (alpha_pi - 1 + sum_q) / (alpha_pi + beta_pi - 2 + k)
                pi_map = np.clip(pi_map, 1e-3, 1 - 1e-3)
                prior_term = np.log(pi_map) - np.log(1 - pi_map)

            # ---- E-step: per-pair log densities ----
            ell_L = self._log_bernoulli_pdf(B, p_L)
            ell_H = self._log_bernoulli_pdf(B, p_H)
            # Both-incorrect pairs: log of the rho-weighted H/L mixture.
            ell_00 = logsumexp(
                np.stack([np.log(rho) + ell_H, np.log(1 - rho) + ell_L], axis=0),
                axis=0
            )

            # ---- E-step: update q ----
            # Log-odds contribution of pair (i, j) to "i is correct":
            #   j correct   -> ell_L - ell_H
            #   j incorrect -> ell_H - ell_00
            A = ell_L - ell_H
            B_term = ell_H - ell_00
            np.fill_diagonal(A, 0.0)
            np.fill_diagonal(B_term, 0.0)

            s1 = A @ q
            s2 = B_term @ (1.0 - q)
            logits_new = prior_term + s1 + s2
            q_new = self._sigmoid(logits_new)

            # Damped update and clipping; keep the previous q so the
            # convergence check can measure the step actually taken.
            q_prev = q
            q = self.damping * q_new + (1 - self.damping) * q_prev
            q = np.clip(q, 1e-5, 1 - 1e-5)

            # ---- M-step: update parameters ----
            diff = (np.log(rho) + ell_H) - (np.log(1 - rho) + ell_L)
            tau_H = self._sigmoid(diff)   # responsibility of H within FF pairs
            np.fill_diagonal(tau_H, 0.0)

            q_col = q.reshape(-1, 1)
            q_row = q.reshape(1, -1)
            QQ = q_col @ q_row                  # both correct
            Q1 = q_col @ (1 - q_row)            # i correct, j incorrect
            Q2 = (1 - q_col) @ q_row            # i incorrect, j correct
            FF = (1 - q_col) @ (1 - q_row)      # both incorrect

            W_L = QQ + FF * (1 - tau_H)         # weight of the low-conflict bucket
            W_H = Q1 + Q2 + FF * (tau_H)        # weight of the high-conflict bucket
            for M in (W_L, W_H):
                np.fill_diagonal(M, 0.0)

            # p_L: weighted conflict frequency in the low-conflict bucket.
            w_L = W_L[triu_mask]
            wsum_L = np.sum(w_L) + self.eps
            p_L_num = np.sum(w_L * b_vals)
            p_L = np.clip(p_L_num / wsum_L, self.eps, 1 - self.eps)

            # p_H: weighted conflict frequency in the high-conflict bucket.
            w_H = W_H[triu_mask]
            wsum_H = np.sum(w_H) + self.eps
            p_H_num = np.sum(w_H * b_vals)
            p_H = np.clip(p_H_num / wsum_H, self.eps, 1 - self.eps)

            # rho: share of FF weight attributed to the H component.
            FF_w = FF[triu_mask]
            tau_w = tau_H[triu_mask]
            rho_num = np.sum(FF_w * tau_w)
            rho_den = np.sum(FF_w) + self.eps
            rho = float(np.clip(rho_num / rho_den, 1e-3, 1 - 1e-3))

            # Convergence: largest change actually applied to q this
            # iteration. Comparing q_new with the already-updated q would
            # scale the step by (1 - damping) and be ~0 when damping == 1.
            delta_q = float(np.max(np.abs(q - q_prev)))
            history.append(dict(iter=it, delta_q=delta_q, rho=float(rho),
                                p_L=float(p_L), p_H=float(p_H)))
            if delta_q < self.tol:
                break

        self.q_ = q
        self.params_ = dict(
            p_L=float(p_L), p_H=float(p_H),
            rho=float(rho),
            alpha_pi=float(alpha_pi), beta_pi=float(beta_pi),
            mode=self.mode
        )
        self.history_ = history

        return {'y': q.copy(), 'params': self.params_.copy(), 'history': history}