import cvxpy as cp
import numpy as np
import pandas as pd
import scipy as sp
import random
from scipy.stats import mode
from snorkel.labeling.model import LabelModel
import itertools
from numpy.linalg import matrix_rank, svd
from itertools import combinations
from tensor_decomp import mixture_tensor_decomp_full, mse_perm
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def learn_structure(sigma_O, gamma = 1):
    """
    Split an observed second-moment matrix into sparse and low-rank parts.

    Solves
        min_R 0.5 * ||R O^{1/2}||_F^2 - tr(R) + lam * (gamma * ||S||_1 + ||L||_*)
        s.t.  R = S - L,  R, S, L PSD,
    where O is the symmetrized input. The smooth term is minimized at
    R = O^{-1}, so R plays the role of a precision estimate.
    lam = 4e-3 / sqrt(M) for an M x M input.

    Args:
        sigma_O: square (M, M) matrix or DataFrame (e.g. a correlation matrix).
        gamma: weight of the l1 penalty on S relative to the nuclear norm on L.

    Returns:
        (S, L) solution values; cvxpy leaves them None if the solve failed.
    """
    size = sigma_O.shape[0]
    reg = 4e-3/np.sqrt(size)

    sym = 1/2*(sigma_O+sigma_O.T)
    root = np.real(sp.linalg.sqrtm(sym))

    # Variables are created in the same order as before (L, S, R).
    low_rank = cp.Variable([size, size], PSD=True)
    sparse = cp.Variable([size, size], PSD=True)
    combined = cp.Variable([size, size], PSD=True)

    fit_term = 0.5*(cp.norm(combined @ root, 'fro')**2) - cp.trace(combined)
    penalty = reg*(gamma*cp.pnorm(sparse, 1) + cp.norm(low_rank, "nuc"))
    problem = cp.Problem(
        cp.Minimize(fit_term + penalty),
        [combined == sparse - low_rank, low_rank >> 0],
    )
    problem.solve(verbose=False)

    return sparse.value, low_rank.value

def learn_structure_debug(sigma_O, gamma=3, solver=cp.SCS, verbose=True, **solver_kwargs):
    """Debug-focused variant of learn_structure that logs solver diagnostics.

    Builds the same sparse + low-rank problem as ``learn_structure``; note the
    default ``gamma`` here is 3. Prints input sanity checks (finiteness, NaN
    count, Frobenius norm, eigenvalue range) plus the solver status and stats.
    Exceptions from the eigen-decomposition and from the solve are printed,
    not raised.

    Args:
        sigma_O: square array-like input matrix.
        gamma: weight of the l1 penalty on S relative to the nuclear norm on L.
        solver: cvxpy solver; ``None`` lets cvxpy choose.
        verbose: forwarded to ``Problem.solve``.
        **solver_kwargs: extra keyword arguments forwarded to ``Problem.solve``.

    Returns:
        (S, L) variable values; either may be None if the solve failed.
    """
    mat = np.asarray(sigma_O)
    size = mat.shape[0]
    lam = 4e-3/np.sqrt(size)

    sym = 0.5 * (mat + mat.T)
    print("finite:", np.isfinite(sym).all(), "nans:", np.isnan(sym).sum(), "||O||_F:", np.linalg.norm(sym))
    try:
        spectrum = np.linalg.eigvalsh(0.5 * (sym + sym.T))
        print("eig(O) min/max:", spectrum.min(), spectrum.max())
    except Exception as exc:
        print("eig error:", exc)

    root = np.real(sp.linalg.sqrtm(sym))
    if not np.isfinite(root).all():
        print("sqrtm produced non-finite entries")

    low_rank = cp.Variable((size, size), PSD=True)
    sparse = cp.Variable((size, size), PSD=True)
    combined = cp.Variable((size, size), PSD=True)

    smooth = 0.5 * cp.norm(combined @ root, 'fro') ** 2 - cp.trace(combined)
    penalty = lam * (gamma * cp.pnorm(sparse, 1) + cp.norm(low_rank, "nuc"))
    problem = cp.Problem(cp.Minimize(smooth + penalty), [combined == sparse - low_rank, low_rank >> 0])

    # Only pass `solver` when one was requested; a duplicate key in
    # solver_kwargs still raises TypeError (caught and printed below).
    solver_arg = {} if solver is None else {"solver": solver}
    try:
        problem.solve(verbose=verbose, **solver_arg, **solver_kwargs)
    except Exception as exc:
        print("solve exception:", repr(exc))

    print("status:", problem.status, "value:", problem.value)
    info = problem.solver_stats
    if info is not None:
        print("solver:", getattr(info, 'solver_name', None),
              "time:", getattr(info, 'solve_time', None),
              "iters:", getattr(info, 'num_iters', None))
        extras = getattr(info, 'extra_stats', None)
        if extras:
            try:
                print("extra_stats keys:", list(extras.keys()))
            except Exception:
                pass
    print("S is None?", sparse.value is None, "L is None?", low_rank.value is None)
    return sparse.value, low_rank.value

def decompose_precision(sigma_O, lam_S=1e-2, lam_L=1e-2, enforce_psd=True):
    """
    Sparse + low-rank split of the precision of ``sigma_O``.

    Solves
        min 0.5 * ||R O^{1/2}||_F^2 - tr(R) + lam_S * ||S||_1 + lam_L * ||L||_*
        s.t. R = S - L,  R, S PSD  (and L PSD when ``enforce_psd``),
    where O is the symmetrized input. The smooth term is minimized at
    R = O^{-1}, so R acts as a precision estimate.

    Args:
        sigma_O: square (M, M) matrix, e.g. a sample covariance.
        lam_S: l1 penalty weight on the sparse part S.
        lam_L: nuclear-norm penalty weight on the low-rank part L.
        enforce_psd: if True (default), constrain L to be PSD. If False, L is
            unconstrained apart from R = S - L. This mirrors
            ``decompose_covariance``; previously the flag was ignored.

    Returns:
        (S, L) solution values; cvxpy leaves them None if the solve failed.
    """
    M = sigma_O.shape[0]

    O = 1/2*(sigma_O+sigma_O.T)
    O_root = np.real(sp.linalg.sqrtm(O))

    # low-rank matrix (latent-factor contribution)
    L_cvx = cp.Variable((M, M), PSD=enforce_psd)

    # sparse matrix (direct dependencies)
    S = cp.Variable((M, M), PSD=True)

    # S - L, the full precision estimate
    R = cp.Variable((M, M), PSD=True)

    objective = cp.Minimize(0.5*(cp.norm(R @ O_root, 'fro')**2) - cp.trace(R) + lam_S*cp.pnorm(S,1) + lam_L*cp.norm(L_cvx, "nuc"))
    constraints = [R == S - L_cvx]
    if enforce_psd:
        constraints.append(L_cvx >> 0)

    cp.Problem(objective, constraints).solve(verbose=False)

    return S.value, L_cvx.value

def get_weights(L_est, threshold = 0):
    """
    Derive per-judge weights from a low-rank matrix ``L_est``.

    The eigenpairs are sorted by eigenvalue, largest first. The leading pair
    (assumed to be the shared quality factor) gives the starting weights
    ``v_0 * sqrt(lambda_0)``. Every further pair with eigenvalue
    >= max(threshold, 0) is then subtracted as ``v_i * sqrt(lambda_i)``.
    This subtraction rule is empirical: it worked better in practice than
    projecting the other factors out orthogonally (Gram-Schmidt).

    Args:
        L_est: square matrix, e.g. the low-rank part from ``learn_structure``.
        threshold: non-leading eigenpairs with eigenvalue below this are
            skipped. Negative eigenvalues are always skipped. The default 0
            reproduces the historical behaviour (only negative ones skipped).

    Returns:
        1-D array of length ``L_est.shape[0]``. Eigenvector signs are
        whatever ``np.linalg.eig`` returns.
    """
    eigenval, eigenvec = np.linalg.eig(L_est)
    # L_est is expected to be symmetric. eig can still return complex arrays
    # for slightly asymmetric input; keep only real parts so the weights stay
    # real. This is a no-op when eig already returned real arrays.
    eigenval, eigenvec = np.real(eigenval), np.real(eigenvec)
    order = np.argsort(eigenval)[::-1]
    eigenval, eigenvec = eigenval[order], eigenvec[:, order]

    # Leading factor (assumed = Q) gives the starting weights.
    weights = eigenvec[:, 0] * np.sqrt(eigenval[0])

    # Subtract every retained spurious factor. The floor at 0 avoids
    # sqrt of negative eigenvalues.
    cutoff = max(threshold, 0)
    for i in range(1, len(eigenval)):
        if eigenval[i] < cutoff:
            continue
        weights = weights - eigenvec[:, i] * np.sqrt(eigenval[i])

    return weights

def majority_vote(df):
    """
    Return the row-wise most frequent value of ``df`` as a Series.

    ``DataFrame.mode(axis=1)`` returns one column per tied mode (pandas
    returns them sorted), so taking column 0 keeps the first tied mode.
    """
    modes_per_row = df.mode(axis=1)
    first_mode = modes_per_row[0]
    return first_mode

def ws_aggregate(df, seed=123, n_epochs=1000, decimals=6):
    """Snorkel label model aggregation with on-the-fly encoding.

    Encodes ``df`` with ``encode_for_label_models``, runs ``run_label_model``,
    and maps each predicted index back through the returned inverse mapping.
    Indices without a mapping become NaN.

    Returns:
        float ndarray with one entry per predicted index.
    """
    encoded, idx_to_value = encode_for_label_models(df, decimals=decimals)
    label_indices = run_label_model(encoded, seed=seed, n_epochs=n_epochs)
    decoded = [idx_to_value.get(label, np.nan) for label in label_indices]
    return np.array(decoded, dtype=float)

def caresl_aggregate(df, gamma=1.0):
    """CARE-SL aggregation with optional gamma tuning.

    Computes the judge correlation matrix, splits it with ``learn_structure``,
    derives weights from the low-rank part with ``get_weights``, and returns
    the weighted average of the judge columns, normalized by the weight sum.

    Args:
        df: DataFrame of judge scores (rows = items, columns = judges).
        gamma: Regularization strength passed into the sparse+low-rank decomposition.

    Returns:
        1-D ndarray with one aggregated score per row of ``df``.
    """
    corr_matrix = df.corr()  # Compute correlation matrix
    _, L_est = learn_structure(corr_matrix, gamma=gamma)  # Sparse + low-rank decomposition
    weights = np.asarray(get_weights(L_est))  # Weights from latent structure
    # One matrix-vector product instead of a per-column Python loop. Only the
    # first len(weights) columns are used, as before.
    scores = df.iloc[:, :len(weights)].to_numpy(dtype=float)
    return scores @ weights / weights.sum()

class UWS:
    """
    Unsupervised weighted averaging of noisy voters.

    ``fit`` estimates each voter's error E[||lambda_i - y||^2] without labels
    via the triplet method. It sets weights proportional to
    ``dim / (2 * error)``, normalized to sum to 1.
    """

    def __init__(self, n_voters: int, dim=1):
        """
        Initializes the UWS aggregator with uniform weights.

        Args:
            n_voters (int): number of generators. This can be the number of models or the number of prompts.
            dim (int): dimension of the embeddings
        """
        self.n_voters = n_voters
        self.dim = dim
        self.theta = np.ones(n_voters)

    def fit(self, lambda_arr: np.ndarray):
        """
        Fits weights using triplet method.

        For each voter i, the triplet estimate is averaged over every unordered
        pair (j, k) of the other voters.

        Args:
            lambda_arr (np.ndarray): scores from noisy voters. Has shape (n_samples, n_voters)

        Raises:
            ValueError: if there are fewer than 3 voters (no triplet exists).
        """
        _, n_voters = lambda_arr.shape
        if n_voters < 3:
            raise ValueError("UWS.fit needs at least 3 voters to form triplets.")

        diff = np.zeros(n_voters)  # E[||\lambda_i - y||^2]
        for i in range(n_voters):
            others = [v for v in range(n_voters) if v != i]
            estimates = [
                triplet(lambda_arr[:, i], lambda_arr[:, j], lambda_arr[:, k])
                for j, k in combinations(others, 2)
            ]
            diff[i] = np.mean(estimates)

        # Convert to canonical parameters, then normalize to sum to 1.
        theta = self.dim / (2 * diff)
        self.theta = theta / theta.sum()

    def predict(self, lambda_arr: np.ndarray):
        """
        Predicts the true value as the theta-weighted average of the voters.

        Args:
            lambda_arr (np.ndarray): voter outputs whose last axis has length n_voters,
                e.g. shape (n_voters,) or (n_samples, n_voters)

        Returns:
            y_pred (np.ndarray): weighted average over the last axis
        """
        predicted_y = 1 / self.theta.sum() * lambda_arr.dot(self.theta)
        return predicted_y


def triplet(i_arr: np.ndarray, j_arr: np.ndarray, k_arr: np.ndarray):
    """
    Applies the triplet method to estimate voter i's error.

    Returns 0.5 * (D_ij + D_ik - D_jk), where D_ab is the mean over samples of
    the squared Euclidean distance between voters a and b. Assuming
    independent, zero-mean voter errors, this estimates E[||lambda_i - y||^2].

    Previously ``np.linalg.norm(..., ord=2)`` was applied without an axis. That
    gives the spectral norm for 2-D input and the total sum for 1-D input, so
    ``.mean()`` had no effect. Distances are now averaged per sample.

    Args:
        i_arr (np.ndarray): outputs of voter i. Shape (n_samples,) or (n_samples, dim)
        j_arr (np.ndarray): outputs of voter j, same shape as i_arr
        k_arr (np.ndarray): outputs of voter k, same shape as i_arr

    Returns:
        diff (float): triplet estimate of voter i's mean squared error
    """
    def _mean_sq_dist(a, b):
        d = np.atleast_1d(np.asarray(a) - np.asarray(b))
        if d.ndim == 1:
            per_sample = np.square(d)  # one scalar per sample
        else:
            # squared Euclidean norm of each sample over all trailing axes
            per_sample = np.square(d).sum(axis=tuple(range(1, d.ndim)))
        return per_sample.mean()

    diff_ij = _mean_sq_dist(i_arr, j_arr)
    diff_ik = _mean_sq_dist(i_arr, k_arr)
    diff_jk = _mean_sq_dist(j_arr, k_arr)
    return 0.5 * (diff_ij + diff_ik - diff_jk)

class ContinuousLabelModel():
    """
    Label model for continuous labeling functions (LFs).

    ``fit`` builds a joint second-moment matrix over the m LFs and the latent
    label Y. The LF-LF block comes from the data. The LF-Y entries are solved
    from LF triplets, assuming the LFs are conditionally independent given Y.
    ``predict`` applies the linear estimate Sigma_YO Sigma_O^{-1} l to each
    row. This treats the inputs as zero-mean (O is the uncentered second
    moment); centering is presumably the caller's job.
    """

    def __init__(self, use_triplets=True):
        self.use_triplets = use_triplets  # only choice right now

    def fit(self, L_train, var_Y, median=True, seed=10):
        """
        Estimate Sigma_hat (shape (m + 1, m + 1)) from LF outputs.

        Args:
            L_train: (n, m) array of LF outputs.
            var_Y: assumed variance of the latent label Y.
            median: if True, each LF-Y entry is the median over all triplets
                containing that LF. Otherwise one random triplet per LF is used.
            seed: seeds Python's global ``random`` module (used when median=False).
        """
        self.n, self.m = L_train.shape
        m = self.m
        self.O = np.transpose(L_train) @ L_train / self.n
        self.Sigma_hat = np.zeros([m + 1, m + 1])
        self.Sigma_hat[:m, :m] = self.O

        random.seed(seed)

        if median:
            # Collect every triplet estimate for each LF, then take medians.
            acc_collection = {i: [] for i in range(m)}
            for i, j, k in combinations(range(m), 3):
                acc_collection[i].append(np.sqrt(self.O[i, j] * self.O[i, k] * var_Y / self.O[j, k]))
                acc_collection[j].append(np.sqrt(self.O[j, i] * self.O[j, k] * var_Y / self.O[i, k]))
                acc_collection[k].append(np.sqrt(self.O[k, i] * self.O[k, j] * var_Y / self.O[i, j]))

            for i in range(m):
                acc = np.median(acc_collection[i])
                self.Sigma_hat[i, m] = acc
                self.Sigma_hat[m, i] = acc
        else:
            for i in range(m):
                # random.sample rejects sets on Python 3.11+. A sorted list
                # gives the same draw older versions produced from the set.
                others = [v for v in range(m) if v != i]
                j, k = random.sample(others, 2)
                # solve from triplet using conditional independence
                acc = np.sqrt(self.O[i, j] * self.O[i, k] * var_Y / self.O[j, k])
                self.Sigma_hat[i, m] = acc
                self.Sigma_hat[m, i] = acc

        # we filled in all but the right-bottom corner, add it in
        self.Sigma_hat[m, m] = var_Y
        return

    def predict(self, L):
        """
        Predict Y for every row of ``L``. ``L`` may have any number of rows.

        Only the first m columns of ``L`` are used. Stores the result in
        ``self.Y_hat`` and returns it, shape (L.shape[0],).
        """
        m = self.m
        L = np.asarray(L)
        # Y_hat[i] = Sigma[m, :m] @ inv(Sigma[:m, :m]) @ L[i, :m]; solve once
        # instead of inverting the matrix for every row.
        coef = np.linalg.solve(self.Sigma_hat[:m, :m].T, self.Sigma_hat[m, :m])
        self.Y_hat = L[:, :m] @ coef
        return self.Y_hat

    def score(self, Y_samples, metric="mse"):
        """
        Mean squared error of the last ``predict`` output against ``Y_samples``.

        ``metric`` is accepted for compatibility, but only MSE is computed.
        Only the first len(self.Y_hat) entries of ``Y_samples`` are used.
        """
        n_pred = len(self.Y_hat)
        y = np.asarray(Y_samples, dtype=float)[:n_pred]
        return np.mean((y - self.Y_hat) ** 2)

def uws_aggregate(df):
    """
    Aggregate the columns of ``df`` (one column per voter) with UWS.

    Fits UWS weights on the data and returns the weighted average of each row.
    """
    scores = df.to_numpy()
    model = UWS(df.shape[1])
    model.fit(scores)
    return model.predict(scores)

# def uws_aggregate(df, mean_est, var_est):
#     n_voters = df.shape[1]
#     df_array = df.to_numpy()
#     normalized_df_array = (df_array - mean_est) / np.sqrt(var_est)
#     clm = ContinuousLabelModel()
#     clm.fit(normalized_df_array, var_Y=1)
#     normalized_pred = clm.predict(normalized_df_array)
#     unnormalized_pred = normalized_pred * np.sqrt(var_est) + mean_est
#     return unnormalized_pred

# def uws_aggregate(df, mean_est, var_est):
#     n_voters = df.shape[1]
#     df_array = df.to_numpy()
#     df_array = df_array
#     clm = ContinuousLabelModel()
#     clm.fit(df_array, var_Y=var_est)
#     normalized_pred = clm.predict(df_array)
#     unnormalized_pred = normalized_pred + mean_est
#     return unnormalized_pred

# find the column permutation of mu_hat that best aligns it with mu_true
def find_best_permutation(mu_hat, mu_true):
    """
    Return the column order of ``mu_hat`` that best matches ``mu_true``.

    Tries every permutation ``perm`` of the k columns. The cost of a
    permutation is the sum over j of ||mu_hat[:, perm[j]] - mu_true[:, j]||.
    Returns the first permutation with minimal cost, as a list. The search
    costs O(k!) time.
    """
    n_cols = mu_hat.shape[1]
    winner = None
    lowest = np.inf

    for candidate in itertools.permutations(range(n_cols)):
        total = 0.0
        for col, src in enumerate(candidate):
            total += np.linalg.norm(mu_hat[:, src] - mu_true[:, col])
        if total < lowest:
            winner, lowest = candidate, total
    return list(winner)

def assert_dependency(mu_full, indep):
    """
    Make selected judges' conditional means independent of Q and/or C.

    Rows of ``mu_full`` index the (C, Q) combinations (0,0), (0,1), (1,0),
    (1,1); columns index judges. For each judge column in ``indep``:
      - 'Q': rows that differ only in Q are made equal (1 <- 0, 3 <- 2);
      - 'C': rows that differ only in C are made equal (2 <- 0, 3 <- 1).
    'Q' is applied before 'C'. ``mu_full`` is modified in place and returned.
    """
    for col, latents in indep.items():
        if 'Q' in latents:
            mu_full[[1, 3], col] = mu_full[[0, 2], col]
        if 'C' in latents:
            mu_full[[2, 3], col] = mu_full[[0, 1], col]
    return mu_full

# generate random conditional-mean matrices (one per view) for g judges per group
def generate_mu_views(g, indep_structure=None, thresh=1.0):
    """
    Return a list of random conditional-mean matrices, one per independence pattern.

    Each matrix has shape (4, g): rows are the (C, Q) combinations and columns
    are judges. Entries are drawn from U(1, 4) using NumPy's global RNG.
    Draws are rejected until the raw matrix has rank 4 and, after
    ``assert_dependency`` applies the pattern, the smallest singular value is
    >= ``thresh``.

    Args:
        g: number of judges per view; must be >= 4 because a (4, g) matrix
            needs g >= 4 to reach rank 4.
        indep_structure: list of {judge_index: ['Q'|'C', ...]} dicts, one per
            view. Defaults to
            [{0: ['Q'], 1: ['C']}, {0: ['C']}, {2: ['C'], 3: ['Q']}].
        thresh: minimum singular value required of each accepted matrix.

    Raises:
        ValueError: if g < 4. Previously this case looped forever.
    """
    if indep_structure is None:
        # Built per call to avoid a shared mutable default argument.
        indep_structure = [
            {0: ['Q'], 1: ['C']},
            {0: ['C']},
            {2: ['C'], 3: ['Q']},
        ]
    if g < 4:
        raise ValueError("generate_mu_views requires g >= 4 (a 4 x g matrix must reach rank 4).")

    views = []
    for struct in indep_structure:
        while True:
            M = np.random.uniform(1, 4, size=(4, g))   # each row is one (C,Q) combination
            if matrix_rank(M) < 4:
                continue

            M = assert_dependency(M, struct)           # enforce independencies

            s_min = svd(M, compute_uv=False)[-1]       # smallest singular value
            if s_min >= thresh:
                views.append(M)
                break
    return views


def find_three_groups(S: np.ndarray, threshold: float):
    """
    Brute-force split of n judges into three groups with weak cross-group links.

    Group sizes are fixed to (n // 3, n // 3, n - 2 * (n // 3)); the last group
    takes the remainder. Candidate splits are enumerated in lexicographic
    order. The first one whose every cross-group pair satisfies
    |S[i, j]| <= threshold is returned as three lists of indices. Returns None
    if no split qualifies. The search is exponential in n.

    Raises:
        ValueError: if n < 3.
    """
    n = S.shape[0]
    if n < 3:
        raise ValueError("Need at least 3 judges to form three groups.")

    size_a = size_b = n // 3

    def separated(left, right):
        return all(abs(S[a, b]) <= threshold for a in left for b in right)

    everyone = list(range(n))
    for first in combinations(everyone, size_a):
        taken_a = set(first)
        pool = [v for v in everyone if v not in taken_a]
        for second in combinations(pool, size_b):
            taken_b = set(second)
            third = [v for v in pool if v not in taken_b]
            checks = [
                separated(first, second),
                separated(first, third),
                separated(second, third),
            ]
            if all(checks):
                return list(first), list(second), list(third)
    return None

# # if we do not know a good threshold, just use the minimum val that gives us a threshold
# def find_three_groups_auto(S: np.ndarray):
#     """
#     Finds three disjoint 4-sets (g1,g2,g3) whose maximum
#     cross-block |S[i,j]| is as small as possible.

#     Returns
#     -------
#     (g1, g2, g3, threshold)
#       g1, g2, g3 : lists of 4 indices each
#       threshold : the value max|S[i,j]| for that partition
#     """
#     n = S.shape[0]
#     if n < 12:
#         raise ValueError("Need at least 12 judges")

#     all_quads = list(combinations(range(n), 4))
#     best_t = np.inf
#     best_groups = None

#     for g1, g2, g3 in combinations(all_quads, 3):
#         # skip if they overlap
#         if set(g1) & set(g2) or set(g1) & set(g3) or set(g2) & set(g3):
#             continue

#         # compute the worst cross-block correlation
#         t = 0.0
#         for A, B in ((g1, g2), (g1, g3), (g2, g3)):
#             # flatten the max over all i in A, j in B
#             t = max(t, np.max(np.abs(S[np.ix_(A, B)])))

#         # keep the triple with the smallest t
#         if t < best_t:
#             best_t = t
#             best_groups = (list(g1), list(g2), list(g3))

#             # optional early exit if you hit zero
#             if best_t == 0:
#                 break

#     if best_groups is None:
#         return None  # no valid partition found

#     g1, g2, g3 = best_groups
#     return g1, g2, g3, best_t


def find_three_groups_auto(S: np.ndarray):
    """
    Split judges into groups of sizes (n // 3, n // 3, n - 2 * (n // 3)),
    minimizing the largest cross-group |S[i, j]|.

    Works on the symmetrized magnitudes max(|S|, |S^T|) with a zero diagonal.
    It binary-searches the sorted distinct off-diagonal values for the
    smallest threshold t at which a split is feasible. Feasibility is decided
    on the connected components of the graph of entries > t (see
    ``_is_feasible_partition``).

    Returns
    -------
    (g1, g2, g3, threshold), or None if no split can be reconstructed.
      g1, g2, g3 : sorted index lists with sizes s1, s2, s3
      threshold  : the minimal achievable max cross-group |S[i, j]|, as float

    Raises
    ------
    ValueError if n < 3 or S is not square.
    """
    n = S.shape[0]
    if n < 3:
        raise ValueError("Need at least 3 judges to form three groups.")
    if S.shape[0] != S.shape[1]:
        raise ValueError("S must be square.")

    s1 = s2 = n // 3
    s3 = n - 2 * s1  # last group absorbs the remainder

    abs_sym = np.maximum(np.abs(S), np.abs(S.T))
    np.fill_diagonal(abs_sym, 0.0)

    rows, cols = np.triu_indices(n, 1)
    thresholds = np.unique(abs_sym[rows, cols])

    # Invariant: thresholds[lo] infeasible (lo = -1 is a sentinel),
    # thresholds[hi] feasible.
    lo, hi = -1, thresholds.size - 1
    if not _is_feasible_partition(abs_sym, thresholds[hi], s1, s2, s3):
        return None

    while hi - lo > 1:
        mid = (lo + hi) // 2
        if _is_feasible_partition(abs_sym, thresholds[mid], s1, s2, s3):
            hi = mid
        else:
            lo = mid

    best_t = thresholds[hi]
    groups = _find_partition(abs_sym, best_t, s1, s2, s3)
    if groups is None:
        return None

    first, second, third = groups
    return first, second, third, float(best_t)

def _is_feasible_partition(absS, threshold, s1, s2, s3):
    """
    Return whether the judges can be split into groups of sizes (s1, s2, s3)
    with every cross-group ``absS`` entry <= ``threshold``.

    Judges linked by an entry above ``threshold`` must share a group. The
    question therefore reduces to packing the connected components of that
    graph into the three sizes.
    """
    must_join = absS > threshold
    components = _find_connected_components(must_join)
    return _can_partition_components(components, s1, s2, s3)


def _find_partition(absS, threshold, s1, s2, s3):
    """
    Build an actual split for ``threshold``.

    Components of the graph ``absS > threshold`` are assigned to groups by
    ``_partition_components`` and expanded back to judge indices.
    Returns (g1, g2, g3) as sorted lists, or None if no assignment exists.
    """
    components = _find_connected_components(absS > threshold)
    labels = _partition_components(components, s1, s2, s3)
    if labels is None:
        return None

    buckets = ([], [], [])
    for members, label in zip(components, labels):
        if label == 0:
            target = 0
        elif label == 1:
            target = 1
        else:
            target = 2
        buckets[target].extend(members)

    return tuple(sorted(bucket) for bucket in buckets)


def _find_connected_components(adj):
    """Find connected components using BFS."""
    from collections import deque
    
    n = adj.shape[0]
    visited = np.zeros(n, dtype=bool)
    components = []
    
    for i in range(n):
        if not visited[i]:
            # BFS from node i
            queue = deque([i])
            visited[i] = True
            component = [i]
            
            while queue:
                u = queue.popleft()
                # Find neighbors
                neighbors = np.where(adj[u])[0]
                for v in neighbors:
                    if not visited[v]:
                        visited[v] = True
                        queue.append(v)
                        component.append(v)
            
            components.append(component)
    
    return components


def _can_partition_components(components, s1, s2, s3):
    """Check if components can be partitioned into required sizes using DP."""
    sizes = [len(comp) for comp in components]
    total = sum(sizes)
    
    # Quick check
    if total != s1 + s2 + s3 or any(size > max(s1, s2, s3) for size in sizes):
        return False
    
    # DP: dp[i][j][k] = can we achieve size j in group1 and size k in group2 using first i components?
    m = len(sizes)
    
    dp = np.zeros((m + 1, s1 + 1, s2 + 1), dtype=bool)
    dp[0][0][0] = True
    
    for i in range(1, m + 1):
        size = sizes[i - 1]
        for j in range(s1 + 1):
            for k in range(s2 + 1):
                if dp[i-1][j][k]:
                    # Place in group1
                    if j + size <= s1:
                        dp[i][j + size][k] = True
                    # Place in group2  
                    if k + size <= s2:
                        dp[i][j][k + size] = True
                    # Place in group3
                    dp[i][j][k] = True
    
    return dp[m][s1][s2]


def _partition_components(components, s1, s2, s3):
    """Find actual partition of components into groups using DP backtracking."""
    sizes = [len(comp) for comp in components]
    m = len(sizes)
    
    # DP table
    dp = np.zeros((m + 1, s1 + 1, s2 + 1), dtype=bool)
    dp[0][0][0] = True
    
    for i in range(1, m + 1):
        size = sizes[i - 1]
        for j in range(s1 + 1):
            for k in range(s2 + 1):
                if dp[i-1][j][k]:
                    if j + size <= s1:
                        dp[i][j + size][k] = True
                    if k + size <= s2:
                        dp[i][j][k + size] = True
                    dp[i][j][k] = True
    
    if not dp[m][s1][s2]:
        return None
    
    # Backtrack to find actual assignment
    assignment = [0] * m
    j, k = s1, s2
    
    for i in range(m, 0, -1):
        size = sizes[i - 1]
        
        # Try group1 first
        if j >= size and dp[i-1][j - size][k]:
            assignment[i-1] = 0
            j -= size
        # Then group2
        elif k >= size and dp[i-1][j][k - size]:
            assignment[i-1] = 1
            k -= size
        # Finally group3
        else:
            assignment[i-1] = 2
    
    return assignment


def decompose_covariance(Sigma_hat, lam_S=1e-2, lam_L=1e-2, enforce_psd=True):
    """
    Split ``Sigma_hat`` into sparse + low-rank parts by penalized least squares.

    Solves min 0.5 * ||S + L - Sigma_hat||_F^2 + lam_S * ||S||_1 + lam_L * ||L||_*,
    with L constrained PSD when ``enforce_psd``.

    Returns:
        (S, L) solution values; cvxpy leaves them None if the solve failed.
    """
    dim = Sigma_hat.shape[0]
    sparse_part = cp.Variable((dim, dim))
    low_rank_part = cp.Variable((dim, dim), PSD=enforce_psd)

    misfit = 0.5 * cp.norm(sparse_part + low_rank_part - Sigma_hat, 'fro')**2
    penalty = lam_S * cp.norm1(sparse_part) + lam_L * cp.normNuc(low_rank_part)
    cp.Problem(cp.Minimize(misfit + penalty)).solve()

    return sparse_part.value, low_rank_part.value


def _gaussian_block_logpdf_batch(Xv, mu_vk, cov):
    """
    Vectorized multivariate Gaussian log-density for all components in one view.
    """
    n_samples, d_view = Xv.shape
    try:
        factor, lower = sp.linalg.cho_factor(cov, lower=True, check_finite=False)
    except np.linalg.LinAlgError:
        # add progressively larger ridge if needed for numerical stability
        jitter = 1e-9
        while True:
            try:
                factor, lower = sp.linalg.cho_factor(
                    cov + jitter * np.eye(d_view), lower=True, check_finite=False
                )
                break
            except np.linalg.LinAlgError:
                jitter *= 10
    log_det = 2.0 * np.sum(np.log(np.diag(factor)))
    diff = Xv[:, None, :] - mu_vk.T[None, :, :]
    flat = diff.reshape(-1, d_view)
    sol = sp.linalg.cho_solve((factor, lower), flat.T, check_finite=False).T
    quad = np.sum(flat * sol, axis=1).reshape(n_samples, -1)
    return (-0.5 * (d_view * np.log(2.0 * np.pi) + log_det + quad)).T


def caret_aggregate(J, lam_S=0.1, lam_L=0.001, class_balance=50, ranks=(2,3,4), verbose=True):
    """
    CARET aggregation: turn an (n, p) matrix of judge scores into binary labels.

    Steps:
      1. Sparse + low-rank split of the sample covariance (``decompose_precision``)
         and a three-way judge grouping from the sparse part.
      2. Empirical (un-centered) tri-view third-moment tensor.
      3. Mixture/CP decomposition for each candidate rank; keep the rank with
         the lowest reconstruction error (``mse_perm``).
      4-5. Gaussian component posteriors using per-view covariances.
      6. Score each component against the leading eigenvector of the low-rank part.
      7. Sum the posteriors of the "Q=1" components and threshold at a percentile.

    Args:
        J: (n, p) array; rows are items, columns are judges.
        lam_S, lam_L: penalties forwarded to ``decompose_precision``.
        class_balance: percentile (0-100) of the Q=1 posterior used as the
            cutoff. Items strictly above it are labelled 1, so ties reduce the
            positive count.
        ranks: candidate mixture ranks K (must be non-empty).
        verbose: if True (default, the historical behaviour), print each
            candidate rank's reconstruction error.

    Returns:
        (n,) int array of labels in {0, 1}.

    Raises:
        RuntimeError: if the sparse + low-rank solve returned no solution.
        ValueError: if no three-group split was found or ``ranks`` is empty.
    """
    # 1) precision S+L and group split
    Sigma_hat = np.cov(J.T, bias=False)
    S_hat, L_hat = decompose_precision(Sigma_hat, lam_S=lam_S, lam_L=lam_L)
    if S_hat is None or L_hat is None:
        raise RuntimeError("decompose_precision returned no solution; check the solver status.")
    grouping = find_three_groups_auto(S_hat)
    if grouping is None:
        raise ValueError("Could not partition the judges into three groups.")
    G1, G2, G3, _ = grouping
    X1, X2, X3 = J[:, G1], J[:, G2], J[:, G3]
    n, p = J.shape
    d1, d2, d3 = X1.shape[1], X2.shape[1], X3.shape[1]

    # 2) empirical (un-centered) tri-view moment tensor, shape (d1, d2, d3)
    T_emp = np.einsum("ni,nj,nk->ijk", X1, X2, X3) / n

    # 3) pick rank by reconstruction MSE to empirical tensor.
    #    The einsum below fixes the output shapes: w_rec (K,), mu_v (d_v, K).
    best = None
    for K in ranks:
        w_rec, mu1, mu2, mu3 = mixture_tensor_decomp_full(
            w=np.ones(n)/n, x1=X1.T, x2=X2.T, x3=X3.T, k=K, debug=False
        )
        T_hat = np.einsum("i,ji,ki,li->jkl", w_rec, mu1, mu2, mu3)
        err = mse_perm(T_emp, T_hat, return_perm=False)
        if verbose:
            print(err)
        if (best is None) or (err < best["err"]):
            best = {"K": K, "w": w_rec, "mu1": mu1, "mu2": mu2, "mu3": mu3, "err": err}
    if best is None:
        raise ValueError("ranks must contain at least one candidate rank.")

    R = best["K"]
    w_rec = best["w"]
    mu_hat_1, mu_hat_2, mu_hat_3 = best["mu1"], best["mu2"], best["mu3"]

    # 4) per-view covariances (tiny ridge)
    ridge = 1e-6
    Sigma1_hat = np.cov(X1, rowvar=False) + ridge * np.eye(d1)
    Sigma2_hat = np.cov(X2, rowvar=False) + ridge * np.eye(d2)
    Sigma3_hat = np.cov(X3, rowvar=False) + ridge * np.eye(d3)

    # 5) component posteriors (all components j in 0..R-1)
    def comp_ll(j):
        return ( np.log(w_rec[j] + 1e-12)
               + multivariate_normal.logpdf(X1, mean=mu_hat_1.T[j], cov=Sigma1_hat, allow_singular=True)
               + multivariate_normal.logpdf(X2, mean=mu_hat_2.T[j], cov=Sigma2_hat, allow_singular=True)
               + multivariate_normal.logpdf(X3, mean=mu_hat_3.T[j], cov=Sigma3_hat, allow_singular=True) )
    log_like = np.vstack([comp_ll(r) for r in range(R)])              # (R, n)
    post = np.exp(log_like - logsumexp(log_like, axis=0, keepdims=True))

    # 6) map components -> Q=1: project each component's stacked mean onto the
    #    leading eigenvector of L_hat (sign fixed so its entries sum to >= 0).
    evals, evecs = np.linalg.eigh(L_hat)
    v = evecs[:, np.argmax(evals.real)].real
    if v.sum() < 0:
        v = -v

    mu_full = np.zeros((p, R))
    mu_full[G1, :] = mu_hat_1
    mu_full[G2, :] = mu_hat_2
    mu_full[G3, :] = mu_hat_3
    scores = v @ mu_full  # (R,)

    # pick Q=1 components by median split; fall back to the top component if degenerate
    thr = np.median(scores)
    q1 = np.where(scores >= thr)[0]
    if q1.size == 0 or q1.size == R:
        q1 = np.array([int(np.argmax(scores))])

    # 7) posterior for Q=1 and hard labels
    p_hat = post[q1, :].sum(axis=0)
    cutoff = np.percentile(p_hat, class_balance)
    y_pred = (p_hat > cutoff).astype(int)
    return y_pred


def caret_aggregate_fast(J, lam_S=0.1, lam_L=0.001, class_balance=50, ranks=(2,3,4), verbose=True):
    """
    Vectorized counterpart of `caret_aggregate` that batches Gaussian log-likelihoods.

    Same pipeline and arguments as ``caret_aggregate``. The per-component
    ``multivariate_normal.logpdf`` calls are replaced by one Cholesky-based
    batch per view (``_gaussian_block_logpdf_batch``).

    Args:
        J: (n, p) array; rows are items, columns are judges.
        lam_S, lam_L: penalties forwarded to ``decompose_precision``.
        class_balance: percentile (0-100) of the Q=1 posterior used as the
            cutoff; items strictly above it are labelled 1.
        ranks: candidate mixture ranks K (must be non-empty).
        verbose: if True (default), print each candidate rank's reconstruction error.

    Returns:
        (n,) int array of labels in {0, 1}.

    Raises:
        RuntimeError: if the sparse + low-rank solve returned no solution.
        ValueError: if no three-group split was found or ``ranks`` is empty.
    """
    Sigma_hat = np.cov(J.T, bias=False)
    S_hat, L_hat = decompose_precision(Sigma_hat, lam_S=lam_S, lam_L=lam_L)
    if S_hat is None or L_hat is None:
        raise RuntimeError("decompose_precision returned no solution; check the solver status.")
    grouping = find_three_groups_auto(S_hat)
    if grouping is None:
        raise ValueError("Could not partition the judges into three groups.")
    G1, G2, G3, _ = grouping
    X1, X2, X3 = J[:, G1], J[:, G2], J[:, G3]
    n, p = J.shape
    d1, d2, d3 = X1.shape[1], X2.shape[1], X3.shape[1]

    # un-centered tri-view moment tensor, shape (d1, d2, d3)
    T_emp = np.einsum("ni,nj,nk->ijk", X1, X2, X3) / n

    # choose K by reconstruction error; w_rec is (K,), mu_v is (d_v, K)
    best = None
    for K in ranks:
        w_rec, mu1, mu2, mu3 = mixture_tensor_decomp_full(
            w=np.ones(n)/n, x1=X1.T, x2=X2.T, x3=X3.T, k=K, debug=False
        )
        T_hat = np.einsum("i,ji,ki,li->jkl", w_rec, mu1, mu2, mu3)
        err = mse_perm(T_emp, T_hat, return_perm=False)
        if verbose:
            print(err)
        if (best is None) or (err < best["err"]):
            best = {"K": K, "w": w_rec, "mu1": mu1, "mu2": mu2, "mu3": mu3, "err": err}
    if best is None:
        raise ValueError("ranks must contain at least one candidate rank.")

    R = best["K"]
    w_rec = best["w"]
    mu_hat_1, mu_hat_2, mu_hat_3 = best["mu1"], best["mu2"], best["mu3"]

    # per-view covariances (tiny ridge)
    ridge = 1e-6
    Sigma1_hat = np.cov(X1, rowvar=False) + ridge * np.eye(d1)
    Sigma2_hat = np.cov(X2, rowvar=False) + ridge * np.eye(d2)
    Sigma3_hat = np.cov(X3, rowvar=False) + ridge * np.eye(d3)

    # component posteriors, (R, n)
    log_weights = np.log(w_rec + 1e-12)[:, None]
    log_like = log_weights + _gaussian_block_logpdf_batch(X1, mu_hat_1, Sigma1_hat)
    log_like += _gaussian_block_logpdf_batch(X2, mu_hat_2, Sigma2_hat)
    log_like += _gaussian_block_logpdf_batch(X3, mu_hat_3, Sigma3_hat)
    post = np.exp(log_like - logsumexp(log_like, axis=0, keepdims=True))

    # project component means onto the leading eigenvector of L_hat (sign fixed)
    evals, evecs = np.linalg.eigh(L_hat)
    v = evecs[:, np.argmax(evals.real)].real
    if v.sum() < 0:
        v = -v

    mu_full = np.zeros((p, R))
    mu_full[G1, :] = mu_hat_1
    mu_full[G2, :] = mu_hat_2
    mu_full[G3, :] = mu_hat_3
    scores = v @ mu_full

    # median split of components; fall back to the top component if degenerate
    thr = np.median(scores)
    q1 = np.where(scores >= thr)[0]
    if q1.size == 0 or q1.size == R:
        q1 = np.array([int(np.argmax(scores))])

    p_hat = post[q1, :].sum(axis=0)
    cutoff = np.percentile(p_hat, class_balance)
    y_pred = (p_hat > cutoff).astype(int)
    return y_pred


def caret_aggregate_binary(
    J: np.ndarray,
    lam_S: float = 0.1,
    lam_L: float = 0.001,
    ranks = (2, 3, 4),
    clip_eps: float = 1e-6,
    map_rule: str = "median",   # "median" | "top1" | "sign"
    threshold: float | None = 0.5,   # used only when class_balance is None
    class_balance: float | None = 50,  # percentile cutoff; takes precedence over threshold
    return_all: bool = True,
):
    """
    CARET (binary): S+L -> tri-view tensor -> Bernoulli-mixture posteriors -> map components to Q=1.

    Args:
        J: (n, p) binary matrix of judge scores (0/1). Filter NaNs beforehand.
        lam_S, lam_L: penalties forwarded to ``decompose_precision`` for the
            sparse+low-rank precision split.
        ranks: non-empty iterable of candidate mixture ranks K. The K whose CP
            reconstruction has the lowest ``mse_perm`` error against the
            empirical tri-view moment tensor is kept.
        clip_eps: clamps mu to [eps, 1-eps] in the Bernoulli log-likelihood and
            is added to each mixture weight before taking its log.
        map_rule:
            - "median": components with v^T mu >= median(scores) are Q=1 (fallback to top1 if degenerate)
            - "top1": only the highest-score component is Q=1
            - "sign": components with v^T mu > 0 are Q=1 (fallback to top1 if none)
            Any other value behaves like "top1".
        threshold: fixed probability cutoff for the hard label. Only used when
            ``class_balance`` is None.
        class_balance: if set (0-100), the cutoff is the ``class_balance``-th
            percentile of p_hat, so roughly (100 - class_balance)% of items are
            labelled 1 (items exactly at the cutoff are labelled 0). Takes
            precedence over ``threshold``.
        return_all: if True, also returns diagnostics.

    Returns:
        y_pred: (n,) hard labels in {0,1} (label 1 where p_hat > cutoff)
        p_hat:  (n,) estimated P(Q=1 | J)
        extras: dict with internals (only if return_all=True)

    Raises:
        ValueError: if ``ranks`` is empty, or if both ``class_balance`` and
            ``threshold`` are None (no way to choose the cutoff).
    """
    J = np.asarray(J, dtype=float)
    n, p = J.shape

    # Validate cheap-to-check options up front, before the expensive
    # precision decomposition and tensor decompositions run.
    ranks = tuple(ranks)
    if not ranks:
        raise ValueError("ranks must contain at least one candidate mixture rank K.")
    if class_balance is None and threshold is None:
        raise ValueError("Either class_balance or threshold must be set to choose the label cutoff.")

    # 1) precision S+L and grouping
    Sigma_hat = np.cov(J.T, bias=False)
    S_hat, L_hat = decompose_precision(Sigma_hat, lam_S=lam_S, lam_L=lam_L)
    G1, G2, G3, _ = find_three_groups_auto(S_hat)
    X1, X2, X3 = J[:, G1], J[:, G2], J[:, G3]

    # 2) tri-view moment (un-centered; for binary this equals joint probs of "1")
    T_emp = np.einsum("ni,nj,nk->ijk", X1, X2, X3) / n  # (d1,d2,d3)

    # 3) choose K by tensor reconstruction error
    best = None
    for K in ranks:
        w_rec, mu1, mu2, mu3 = mixture_tensor_decomp_full(
            w=np.ones(n)/n, x1=X1.T, x2=X2.T, x3=X3.T, k=K, debug=False
        )
        T_hat = np.einsum("i,ji,ki,li->jkl", w_rec, mu1, mu2, mu3)
        err = mse_perm(T_emp, T_hat, return_perm=False)
        if (best is None) or (err < best["err"]):
            best = {"K": K, "w": w_rec, "mu1": mu1, "mu2": mu2, "mu3": mu3, "err": err}

    K = best["K"]
    w_rec = best["w"]
    mu1, mu2, mu3 = best["mu1"], best["mu2"], best["mu3"]   # shapes: (d_v, K)

    # 4) product-Bernoulli mixture posteriors
    def bernoulli_block_ll(Xv, mu_vk):
        """Per-row log-likelihood of block Xv under independent Bernoulli(mu_vk) columns."""
        mu_vk = np.clip(mu_vk, clip_eps, 1 - clip_eps)      # (d_v,)
        return (Xv * np.log(mu_vk) + (1 - Xv) * np.log(1 - mu_vk)).sum(axis=1)

    comp_ll = []
    for j in range(K):
        ll = np.log(w_rec[j] + clip_eps)
        ll += bernoulli_block_ll(X1, mu1.T[j])
        ll += bernoulli_block_ll(X2, mu2.T[j])
        ll += bernoulli_block_ll(X3, mu3.T[j])
        comp_ll.append(ll)

    log_like = np.vstack(comp_ll)                           # (K, n)
    # Normalise over components in log space for numerical stability.
    post = np.exp(log_like - logsumexp(log_like, axis=0, keepdims=True))  # (K,n)

    # 5) map components -> Q=1 using L's dominant direction
    evals, evecs = np.linalg.eigh(L_hat)
    v = evecs[:, int(np.argmax(evals.real))].real
    if v.sum() < 0:
        # Eigenvectors are sign-ambiguous; orient v so it points "positively".
        v = -v

    mu_full = np.zeros((p, K))
    mu_full[G1, :], mu_full[G2, :], mu_full[G3, :] = mu1, mu2, mu3  # restore original judge order
    scores = v @ mu_full                                            # (K,)

    if map_rule == "median":
        thr = np.median(scores)
        q1 = np.where(scores >= thr)[0]
        if (q1.size == 0) or (q1.size == K):
            q1 = np.array([int(np.argmax(scores))])                 # fallback
    elif map_rule == "sign":
        q1 = np.where(scores > 0)[0]
        if q1.size == 0:
            q1 = np.array([int(np.argmax(scores))])
    else:  # "top1"
        q1 = np.array([int(np.argmax(scores))])

    # 6) probabilities & hard labels
    p_hat = post[q1, :].sum(axis=0)                                  # (n,)

    # class_balance takes precedence; threshold is the fallback (validated above).
    if class_balance is not None:
        cutoff = np.percentile(p_hat, class_balance)
    else:
        cutoff = float(threshold)

    y_pred = (p_hat > cutoff).astype(int)

    if return_all:
        extras = dict(
            K=K, weights=w_rec, mu=(mu1, mu2, mu3), groups=(G1, G2, G3),
            S_hat=S_hat, L_hat=L_hat, comp_scores=scores, q1_idx=q1,
            tensor_err=best["err"], post=post
        )
        return y_pred, p_hat, extras
    return y_pred, p_hat

def encode_for_label_models(df, decimals=6):
    """Discretize continuous judge outputs for label-model training.

    Every cell is rounded to ``decimals`` places and replaced by the rank of
    its rounded value among all distinct non-missing values in ``df`` (0 for
    the smallest). Missing cells are encoded as ``-1`` (the "abstain" value
    used by Snorkel label matrices) instead of breaking the integer cast.

    Args:
        df: DataFrame of numeric judge outputs.
        decimals: rounding precision applied before grouping equal values.

    Returns:
        encoded: DataFrame with the same index and columns holding integer codes.
        inverse_mapping: dict mapping each non-negative code back to its
            rounded float value.

    Raises:
        ValueError: if ``df`` contains no non-missing values.
    """
    rounded = df.round(decimals)
    values = rounded.to_numpy(dtype=float, na_value=np.nan)
    missing = np.isnan(values)

    # np.unique returns the distinct values already sorted ascending.
    unique_values = np.unique(values[~missing])
    if unique_values.size == 0:
        raise ValueError('No valid judge scores to encode.')

    # Each non-missing value is an exact member of unique_values, so
    # searchsorted yields its rank. Missing cells get a placeholder, then -1.
    codes = np.searchsorted(unique_values, np.where(missing, unique_values[0], values))
    codes[missing] = -1

    encoded = pd.DataFrame(codes, index=rounded.index, columns=rounded.columns).astype(int)
    inverse_mapping = {idx: val for idx, val in enumerate(unique_values)}
    return encoded, inverse_mapping


def run_label_model(encoded_df, seed=123, n_epochs=1000):
    """Fit a Snorkel ``LabelModel`` on an encoded label matrix and predict.

    The number of classes is inferred as one more than the largest code in
    ``encoded_df``. The same matrix is used for training and prediction, and
    ties between classes are broken with Snorkel's ``'random'`` policy.

    Args:
        encoded_df: DataFrame of integer-coded judge outputs.
        seed: random seed forwarded to ``LabelModel.fit``.
        n_epochs: number of training epochs.

    Returns:
        numpy array of integer hard predictions, one per row of ``encoded_df``.
    """
    label_matrix = encoded_df.to_numpy().astype(int)
    n_classes = int(label_matrix.max()) + 1

    label_model = LabelModel(cardinality=n_classes, verbose=False)
    label_model.fit(L_train=label_matrix, n_epochs=n_epochs, log_freq=100, seed=seed)

    predictions = label_model.predict(label_matrix, tie_break_policy='random')
    return predictions.astype(int)


def latent_labels_from_weights(n, w):
    """Build deterministic latent (C, Q) label vectors matching cell weights.

    ``w`` gives the weight of each (C, Q) cell in the fixed order
    (0,0), (0,1), (1,0), (1,1). The weights are normalised to sum to 1. ``n``
    items are then apportioned with the largest-remainder (Hamilton) method.
    Each cell gets floor(n * w_i) items plus at most one extra, and the extras
    go to the cells that truncation shortchanged most. A zero-weight cell
    therefore never receives an item. Items are laid out cell by cell in the
    order above.

    Args:
        n: total number of items.
        w: four non-negative, finite cell weights (need not sum exactly to 1).

    Returns:
        (C, Q): two integer arrays of length ``n``.

    Raises:
        ValueError: if ``w`` does not have exactly four entries, contains a
            negative or non-finite entry, or sums to zero.
    """
    weights = np.asarray(w, dtype=float).ravel()
    if weights.size != 4:
        raise ValueError(f"w must have 4 entries ordered (0,0),(0,1),(1,0),(1,1); got {weights.size}.")
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
        raise ValueError("w must contain only finite, non-negative weights.")
    total = weights.sum()
    if total <= 0:
        raise ValueError("w must have a positive sum.")

    exact = n * weights / total
    counts = np.floor(exact).astype(int)
    deficit = int(n - counts.sum())
    if deficit > 0:
        # Hand leftover items to the cells with the largest fractional parts.
        # The stable sort breaks ties in favour of the earlier cell.
        order = np.argsort(-(exact - counts), kind="stable")
        counts[order[:deficit]] += 1

    C = np.repeat(np.array([0, 0, 1, 1]), counts)
    Q = np.repeat(np.array([0, 1, 0, 1]), counts)
    return C, Q
