#abstract class
from abc import ABC, abstractmethod
from typing import Callable
import numpy as np
class MinimumBayesAction(ABC):
    """Common interface for estimators that collapse samples into one action.

    Concrete subclasses are callables: they receive an array of samples and
    return a single array, the action chosen for their particular loss.
    The class cannot be instantiated until ``__call__`` is overridden.
    """

    @abstractmethod
    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Reduce ``samples`` to a single action (implemented by subclasses).

        Args:
            samples: the samples to summarise; the layout each subclass
                expects is documented on that subclass.

        Returns:
            The selected action. This base implementation does nothing and
            returns ``None``.
        """

class HammingMinimumBayesAction(MinimumBayesAction):
    """Bayes action under Hamming loss: a per-bit majority vote.

    Hamming loss is a sum of independent per-bit errors, so the expected loss
    is minimised one bit at a time: predict 1 wherever the empirical marginal
    P(bit = 1) is at least 0.5 (an exact tie resolves to 1).
    """

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Return the bit-wise majority vote over ``samples``.

        Args:
            samples: array-like of shape (n_samples, n_bits) holding only
                the values 0 and 1.

        Returns:
            Integer array of shape (n_bits,). With zero rows no bit has any
            support, so the all-zero vector is returned.

        Raises:
            ValueError: if ``samples`` is not two-dimensional.
            AssertionError: if ``samples`` contains values other than 0 and 1.
        """
        samples = np.asarray(samples)
        if samples.ndim != 2:
            raise ValueError(
                "samples must be a 2D array of shape (n_samples, n_bits)"
            )
        if not np.isin(samples, [0, 1]).all():
            # Raised explicitly instead of via `assert` so the check still
            # runs under `python -O`; the exception type is kept so existing
            # callers see the same error as before.
            raise AssertionError("Samples must be binary (0 or 1)")

        n_samples, n_bits = samples.shape
        if n_samples == 0:
            # Averaging zero rows would only produce NaN and a RuntimeWarning.
            return np.zeros(n_bits, dtype=int)

        # Empirical marginal P(bit = 1) for every bit position.
        probability_per_bit = samples.mean(axis=0)
        return (probability_per_bit >= 0.5).astype(int)
    

class FBetaMinimumBayesAction(MinimumBayesAction):
    """
    Expected-F_beta maximiser based on GFM (Algorithm 1) of Waegeman et al.
    (JMLR 2014), using the F_beta generalisation

        F_beta(y, h) = (1 + beta^2) * TP / (beta^2 * s_y + k),

    with s_y = sum_i y_i (positives in y) and k = sum_i h_i (positives in h).

    https://jmlr.org/papers/volume15/waegeman14a/waegeman14a.pdf

    Input: binary label vectors sampled from P(Y), shape (n_samples, m).
    Output: the h in {0,1}^m with the largest empirical E[F_beta(Y, h)].
    """

    def __init__(self, beta: float = 1.0):
        """Store ``beta``; raises ValueError unless it is strictly positive."""
        if beta <= 0:
            raise ValueError("beta must be > 0")
        self.beta = float(beta)

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Return the binary vector maximising the empirical expected F_beta.

        Args:
            samples: array-like of shape (n_samples, m) with entries 0/1.

        Returns:
            Integer array of shape (m,); an empty integer array when m == 0.

        Raises:
            ValueError: if ``samples`` is not 2D or is not binary.
        """
        samples = np.asarray(samples)
        if samples.ndim != 2:
            raise ValueError("samples must be a 2D array of shape (n_samples, m)")
        if not np.isin(samples, [0, 1]).all():
            raise ValueError("samples must be binary (0/1)")

        n_labels = samples.shape[1]
        if n_labels == 0:
            return np.array([], dtype=int)

        # Empirical distribution over the distinct label vectors.
        configs, freq = np.unique(samples, axis=0, return_counts=True)
        prob = freq.astype(float) / freq.sum()            # (U,)
        positives = configs.sum(axis=1).astype(float)     # (U,) -> s_y

        beta_sq = self.beta ** 2
        gain = 1.0 + beta_sq

        # delta[i, k-1] = sum over y with y_i = 1 of
        #                 (1 + beta^2) * P(y) / (beta^2 * s_y + k)
        sizes = np.arange(1, n_labels + 1, dtype=float)[None, :]   # (1, m)
        per_config = (gain * prob[:, None]) / (beta_sq * positives[:, None] + sizes)
        delta = configs.T @ per_config                             # (m, m)

        # k = 0: predicting nothing scores exactly P(Y = 0).
        all_off = ~configs.any(axis=1)
        best_score = float(prob[all_off][0]) if all_off.any() else 0.0
        best_h = np.zeros(n_labels, dtype=int)

        # k = 1..m: the best size-k prediction takes the k largest entries of
        # column k-1, and its expected F_beta is the sum of those entries.
        # Ties are broken however argpartition orders them.
        for size, column in enumerate(delta.T, start=1):
            chosen = np.argpartition(-column, kth=size - 1)[:size]
            expected_f = column[chosen].sum()
            if expected_f > best_score:
                best_score = expected_f
                best_h = np.zeros(n_labels, dtype=int)
                best_h[chosen] = 1

        return best_h

class ArgminMinimumBayesAction(MinimumBayesAction):
    """Select the sample with the lowest mean loss against all other samples.

    Every sample is used both as a candidate action and as a reference, so
    the result is always one of the rows of ``samples``.
    """

    def __init__(self, loss_fn):
        """Store the loss, which is called as ``loss_fn(candidate, reference)``."""
        self.loss_fn = loss_fn

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Return the row of ``samples`` with the smallest average loss.

        Args:
            samples: 2D array; each row is one sample.

        Returns:
            The chosen row. With a single row that row is returned without
            calling the loss at all.
        """
        # Unpacking into two names also rejects anything that is not 2D.
        n_samples, _ = samples.shape
        if n_samples == 1:
            return samples[0]

        # Ordered (candidate, reference) index pairs, skipping self-pairs.
        # Pairs sharing a candidate are contiguous: n_samples - 1 per candidate.
        pairs = [
            (cand, ref)
            for cand in range(n_samples)
            for ref in range(n_samples)
            if cand != ref
        ]

        # One loss evaluation per pair; candidate is the prediction, the other
        # sample is the reference.
        losses = np.array(
            [self.loss_fn(samples[cand], samples[ref]) for cand, ref in pairs]
        )

        # Average each candidate's contiguous run of losses.
        per_candidate = n_samples - 1
        mean_loss = [
            np.mean(losses[start:start + per_candidate])
            for start in range(0, n_samples * per_candidate, per_candidate)
        ]

        # argmin picks the first candidate on ties.
        return samples[np.argmin(mean_loss)]

    
class ExactMatchMinimumBayesAction(MinimumBayesAction):
    """One-hot vector for the column that is active in the most samples.

    NOTE(review): this equals the exact-match (0/1 loss) Bayes action when
    every sample is a one-hot class indicator, which is presumably how the
    class is used. One-hotness is not enforced here; for multi-hot samples
    the result is simply the most frequent bit. Confirm against callers.
    """

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Return a one-hot vector marking the most frequent column.

        Args:
            samples: array-like of shape (n_samples, n_bits) holding only
                the values 0 and 1.

        Returns:
            Integer array of shape (n_bits,) with a single 1. Ties, including
            the zero-row case, resolve to the lowest column index.

        Raises:
            ValueError: if ``samples`` is not 2D or has zero columns.
            AssertionError: if ``samples`` contains values other than 0 and 1.
        """
        samples = np.asarray(samples)
        if samples.ndim != 2:
            raise ValueError(
                "samples must be a 2D array of shape (n_samples, n_bits)"
            )
        n_bits = samples.shape[1]
        if n_bits == 0:
            raise ValueError("samples must have at least one column")
        if not np.isin(samples, [0, 1]).all():
            # Raised explicitly instead of via `assert` so the check still
            # runs under `python -O`; the exception type is kept so existing
            # callers see the same error as before.
            raise AssertionError("Samples must be binary (0 or 1)")

        # Column counts rank the classes exactly like the per-class
        # probabilities (counts / n_samples) and stay well-defined, with no
        # NaN or warning, even when there are no rows.
        votes = samples.sum(axis=0)

        mb_action = np.zeros(n_bits, dtype=int)
        mb_action[np.argmax(votes)] = 1
        return mb_action

class CosineMinimumBayesAction(MinimumBayesAction):
    """Bayes action under cosine distance: the renormalised mean embedding.

    For unit-norm samples y and a unit-norm action a, the expected cosine
    distance is 1 - <mean(y), a>, which is smallest when a points in the
    direction of the sample mean.
    """

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Return the unit vector pointing along the mean of ``samples``.

        Args:
            samples: array-like of shape (n_samples, dim); every row must
                have unit Euclidean norm (within ``np.allclose`` tolerance).

        Returns:
            Array of shape (dim,) with unit norm. If the samples cancel
            exactly (zero mean) every unit vector has the same expected
            distance, so the first sample is returned instead of dividing
            by zero.

        Raises:
            ValueError: if ``samples`` is not 2D or has no rows.
            AssertionError: if any row does not have unit norm.
        """
        samples = np.asarray(samples)
        if samples.ndim != 2 or samples.shape[0] == 0:
            raise ValueError(
                "samples must be a non-empty 2D array of shape (n_samples, dim)"
            )

        norms = np.linalg.norm(samples, axis=1)
        if not np.allclose(norms, 1.0):
            # Raised explicitly instead of via `assert` so the check still
            # runs under `python -O`; the exception type is kept so existing
            # callers see the same error as before.
            raise AssertionError(
                "All samples must have unit norm for cosine distance."
            )

        mean_embedding = np.mean(samples, axis=0)
        mean_embedding_norm = np.linalg.norm(mean_embedding)
        if mean_embedding_norm == 0:
            # Degenerate case: the mean carries no direction, so all actions
            # tie. Return a valid unit vector rather than a NaN vector.
            return samples[0] / norms[0]

        # Normalize to unit length
        return mean_embedding / mean_embedding_norm


class ReverseKLMinimumBayesAction(MinimumBayesAction):
    """Bayes-optimal estimator for reverse KL divergence loss.

    For KL(p_est || p_true), i.e. KL(q || P), the Bayes estimator is the
    normalized geometric mean of the posterior samples:

        q_i ∝ exp( E[ log P_i ] )

    With MC samples p^(m), E[log P_i] is estimated by the empirical mean of
    the logs.

    Notes:
    - If any sample has P_i = 0, then log P_i = -inf, which forces q_i = 0
      under the exact objective. To avoid support collapse caused by
      numerical/estimation artifacts, keep epsilon > 0 so probabilities are
      floored before the log is taken.
    """

    def __init__(self, epsilon: float = 1e-12, strict_zeros: bool = False):
        """
        Args:
            epsilon: Floor applied to probabilities before taking logs.
                     Set to 0.0 to use exact logs (may produce -inf).
            strict_zeros: If True, do NOT floor; any zero in any sample forces
                          q_i = 0 (exact reverse-KL behavior). Overrides epsilon.

        Raises:
            ValueError: if ``epsilon`` is negative.
        """
        if epsilon < 0:
            raise ValueError("epsilon must be >= 0")
        self.epsilon = float(epsilon)
        self.strict_zeros = bool(strict_zeros)

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """
        Args:
            samples: array of shape (n_samples, K), each row a probability
                simplex (rows are re-normalized defensively).

        Returns:
            q: array of shape (K,), Bayes-optimal action under reverse-KL.

        Raises:
            ValueError: if ``samples`` is not 2D, is empty, contains
                non-finite or negative entries, has a row with no mass, or
                if every coordinate ends up with -inf mean log-probability.
        """
        samples = np.asarray(samples, dtype=np.float64)
        if samples.ndim != 2:
            raise ValueError(f"samples must be 2D (n_samples, K), got shape {samples.shape}")
        if samples.shape[0] == 0 or samples.shape[1] == 0:
            raise ValueError("samples must contain at least one row and one column")
        if not np.all(np.isfinite(samples)):
            raise ValueError("samples contains non-finite entries")

        # Basic sanity: nonnegative and row-normalized (allow small numerical drift)
        if np.any(samples < -1e-15):
            raise ValueError("samples contains negative entries")
        # Entries inside the tolerance are rounding noise. Clamp them to an
        # exact 0 so the logarithm below never sees a negative number (which
        # would turn into NaN and poison the whole estimate).
        samples = np.maximum(samples, 0.0)
        row_sums = samples.sum(axis=1, keepdims=True)
        if np.any(row_sums <= 0):
            raise ValueError("some sample rows sum to <= 0")
        samples = samples / row_sums  # re-normalize defensively

        # strict_zeros overrides epsilon: no flooring at all.
        floor = 0.0 if self.strict_zeros else self.epsilon
        if floor > 0.0:
            # Floor then renormalize each row so it's still a simplex
            floored = np.maximum(samples, floor)
            floored /= floored.sum(axis=1, keepdims=True)
            log_samples = np.log(floored)
        else:
            # Exact: log(0) = -inf is the intended result here and propagates
            # to q_i = 0, so the divide-by-zero warning is silenced.
            with np.errstate(divide="ignore"):
                log_samples = np.log(samples)

        a_hat = np.mean(log_samples, axis=0)  # shape (K,)

        # Convert to q via softmax(a_hat), shifting by the maximum for
        # stability. exp(-inf) = 0 handles collapsed coordinates, as long as
        # not all of them are -inf.
        a_max = np.max(a_hat)
        if not np.isfinite(a_max):
            raise ValueError(
                "All coordinates have -inf mean log-probability. "
                "This can happen if every coordinate is zero in at least one sample. "
                "Use epsilon>0 (non-strict) or check your samples."
            )

        # The maximal coordinate contributes exp(0) = 1, so the normalizer is
        # always at least 1 and needs no further guarding.
        weights = np.exp(a_hat - a_max)
        q = weights / weights.sum()

        # Final defensive renorm
        return q / q.sum()


class KLMinimumBayesAction(MinimumBayesAction):
    """Bayes action for the forward KL loss KL(p_true || p_est) = KL(P || q).

    Only the cross-entropy part of E[KL(P || q)] depends on q, and it is
    minimised by the posterior mean:

        q_i = E[ P_i ]
    """

    def __call__(self, samples: np.ndarray) -> np.ndarray:
        """Average the row-normalised samples.

        Args:
            samples: array-like of shape (n_samples, K); rows are treated as
                probability vectors and rescaled to sum to one.

        Returns:
            q: array of shape (K,), the mean of the rescaled rows.

        Raises:
            ValueError: if ``samples`` is not 2D, has an entry below -1e-15,
                or has a row whose sum is not positive.
        """
        dists = np.asarray(samples, dtype=np.float64)
        if dists.ndim != 2:
            raise ValueError(f"samples must be 2D (n_samples, K), got shape {dists.shape}")

        # Tolerate tiny negative drift, reject anything clearly negative.
        if (dists < -1e-15).any():
            raise ValueError("samples contains negative entries")

        mass = dists.sum(axis=1, keepdims=True)
        if (mass <= 0).any():
            raise ValueError("some sample rows sum to <= 0")

        # Rescale every row onto the simplex, then take the column-wise mean.
        q = (dists / mass).mean(axis=0)

        # Internal invariant: a mean of simplex points lies on the simplex.
        assert np.isclose(q.sum(), 1.0), "Output is not a valid probability distribution"
        return q

class MAPAction(MinimumBayesAction):
    """Return the most frequent sample (the mode of the empirical pmf)."""

    def __call__(self, samples: np.ndarray, encoded_samples: np.ndarray = None) -> np.ndarray:
        """Pick the modal sample, optionally returning its encoded form.

        Args:
            samples: array whose first axis indexes the samples; rows are
                compared for exact equality.
            encoded_samples: optional alternative representation of the
                samples, assumed to be aligned index-for-index with
                ``samples`` (entry i encodes sample i) — confirm with callers.

        Returns:
            The most frequent row of ``samples``, or, when
            ``encoded_samples`` is given, the encoded entry belonging to the
            first occurrence of that row. Ties between equally frequent rows
            go to the row that sorts first.
        """
        n_samples = len(samples)
        if n_samples == 1:
            # Honour encoded_samples on the shortcut path as well, so the
            # return type does not depend on how many samples were drawn.
            return samples[0] if encoded_samples is None else encoded_samples[0]

        # np.unique returns the distinct rows in sorted order, so a position
        # in `unique_samples` is not a position in `samples`. `first_seen`
        # maps each distinct row back to its first occurrence in the input.
        unique_samples, first_seen, counts = np.unique(
            samples, axis=0, return_index=True, return_counts=True
        )

        # The largest count is the largest empirical probability.
        mode = np.argmax(counts)

        # Return encoded sample if provided
        if encoded_samples is not None:
            return encoded_samples[first_seen[mode]]
        return unique_samples[mode]
