from dataclasses import dataclass
from enum import Enum

import numpy as np
import torch
from scipy.stats import norm
from typing import Tuple, Union

from sympy import symbols
from sympy.stats import Normal, independent, E, std
from sympy.stats.rv import RandomSymbol


# Sentinel cell values written into result grids instead of a 0/1 outcome.
IGNORED_VALUE = -1.  # cell deliberately skipped
INVALID_VALUE = -2.  # cell whose parameters are out of the valid range


# Constant c of the Johnson-Lindenstrauss relation  n = c * ln(v) / eps**2.
c_JL = 8        # original JL embedding (gaussian)
# c_JL = 4      # optimal embedding? (gaussian + QR)


def N_JL(v, eps, c=c_JL):
    """Return the JL dimension ``c * ln(v) / eps**2`` for ``v`` points and distortion ``eps``.

    Works element-wise on numpy arrays. The default ``c`` is bound to the
    value of ``c_JL`` at definition time.
    """
    return c * np.log(v) / eps**2


def epsilon_JL(v, n, c=c_JL):
    """Return the JL distortion ``sqrt(c * ln(v) / n)`` for ``v`` points in ``n`` dimensions.

    This is the inverse of :func:`N_JL` with respect to ``eps``. Works
    element-wise on numpy arrays. The default ``c`` is bound to the value of
    ``c_JL`` at definition time.
    """
    return np.sqrt(c * np.log(v) / n)


class EmbeddingType(str, Enum):
    """Names of the supported embedding-matrix constructions.

    Members are also ``str`` instances, so they compare equal to their string
    values, and ``EmbeddingType("identity")`` looks a member up by value.

    This class must NOT be decorated with ``@dataclass``: the decorator would
    generate a field-less ``__eq__`` under which all members compare equal to
    each other, and would set ``__hash__`` to ``None``.
    """

    IDENTITY = "identity"
    SCALED_GAUSSIAN = "scaled_gaussian"
    SCALED_GAUSSIAN_QR = "scaled_gaussian_qr"
    RAW_GAUSSIAN_QR = "raw_gaussian_qr"


class EstimationMethod(str, Enum):
    """Names of the available accuracy-estimation methods.

    Members are also ``str`` instances, so they compare equal to their string
    values, and ``EstimationMethod("clt")`` looks a member up by value.

    This class must NOT be decorated with ``@dataclass``: the decorator would
    generate a field-less ``__eq__`` under which all members compare equal to
    each other, and would set ``__hash__`` to ``None``.
    """

    EXACT = "exact"
    CLT = "clt"


class BoundMethod(str, Enum):
    """Names of the accuracy lower-bound strategies; members are also ``str`` instances."""
    CLT = "clt"  # Gaussian (CLT-based) bound built from the margin mean and std-dev
    DETERMINISTIC = "deterministic"  # closed-form high-probability bound built from the margin mean


def create_embedding_matrix(V: int, M: int, embedding_type: EmbeddingType | str, device=None):
    """
    Sample an embedding matrix of the requested type.

    Parameters:
      V, M           requested number of rows / columns of the embedding
      embedding_type an EmbeddingType member or its string value
      device         optional torch device the result is moved to
    Returns:
      A torch tensor, nominally of shape (V, M):
        • IDENTITY: the V×V identity (requires M == V).
        • SCALED_GAUSSIAN: i.i.d. standard normal entries times sqrt(1/M).
        • RAW_GAUSSIAN_QR: the Q factor of a reduced QR of a V×M standard
          normal matrix.
        • SCALED_GAUSSIAN_QR: the same Q factor times sqrt(V/M).
    Raises:
      TypeError   if embedding_type is neither an EmbeddingType nor a str
      ValueError  if the name is unknown, or IDENTITY is requested with M != V
    """

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if not isinstance(embedding_type, str):  # EmbeddingType members are str instances too
        raise TypeError(
            f"embedding_type must be an EmbeddingType or str, got {type(embedding_type).__name__}"
        )

    # Lookup by value; returns the canonical member (raises ValueError for unknown names).
    embedding_type = EmbeddingType(embedding_type)

    # Enum members are singletons, so dispatch on identity.
    if embedding_type is EmbeddingType.IDENTITY:
        if M != V:
            raise ValueError(f"identity embedding requires M == V, got {M = }, {V = }")
        embedding = torch.eye(V)  # (V, V)

    elif embedding_type is EmbeddingType.SCALED_GAUSSIAN:
        embedding = torch.randn(V, M) * np.sqrt(1 / M)  # (V, M)

    elif (
        embedding_type is EmbeddingType.RAW_GAUSSIAN_QR
        or embedding_type is EmbeddingType.SCALED_GAUSSIAN_QR
    ):
        gaussian = torch.randn(V, M)  # (V, M)
        # NOTE(review): a reduced QR should return min(V, M) columns, so M > V
        # would not yield a (V, M) result — confirm whether M > V must be supported.
        q_factor, _ = torch.linalg.qr(gaussian, mode="reduced")

        scale = np.sqrt(V / M) if embedding_type is EmbeddingType.SCALED_GAUSSIAN_QR else 1
        embedding = scale * q_factor

    else:
        raise ValueError(f"unsupported {embedding_type = }")

    if device is not None:
        embedding = embedding.to(device)

    return embedding  # (V, M)


def calculate_embedding_gram_entry_mean_and_variance(
    V: int,
    M: int,
    embedding_type: EmbeddingType,
    diag: bool,  # True for diagonal entry, False for off-diagonal
) -> Tuple[float, float]:
    """
    Returns (μ, var) for a single entry of the V×V Gram matrix G = E Eᵀ,
    with E ∈ ℝ^{V×M}:
      • IDENTITY: exact values (G = I, requires M == V).
      • SCALED_GAUSSIAN: E_{ik} ∼ N(0, 1/M), so G_ij = ∑ₖ E_{ik} E_{jk}.
      • RAW_GAUSSIAN_QR: E = Q with Q from reduced QR on N(0,1)^(V×M),
        so G = QQᵀ is a random rank-M orthogonal projector (G = I for M = V).
      • SCALED_GAUSSIAN_QR: same as RAW, but G is rescaled by V/M so that
        it has unit expected diagonal.

    The QR formulas assume M ≤ V.

    Parameters:
      V, M           dimensions of the embedding
      embedding_type which embedding scheme to use (member or its string value)
      diag           whether to compute the diagonal (True) or off-diagonal (False)
    Returns:
      (mean, variance) of the chosen entry G_ij.
    Raises:
      ValueError for an unknown embedding type, or IDENTITY with M != V.
    """

    # Lookup by value: accepts members and plain strings, returns the canonical member.
    kind = EmbeddingType(embedding_type)

    # Enum members are singletons, so dispatch on identity.
    if kind is EmbeddingType.IDENTITY:
        if M != V:
            raise ValueError("Identity requires M == V")
        mu = 1.0 if diag else 0.0
        var = 0.0

    elif kind is EmbeddingType.SCALED_GAUSSIAN:
        if diag:
            mu = 1.0
            var = 2.0 / M
        else:
            mu = 0.0
            var = 1.0 / M

    elif kind is EmbeddingType.RAW_GAUSSIAN_QR or kind is EmbeddingType.SCALED_GAUSSIAN_QR:

        # P = QQᵀ is the orthogonal projector onto a uniformly random
        # M-dimensional subspace of ℝ^V:
        #   P_ii ∼ Beta(M/2, (V−M)/2)
        #       ⇒ E[P_ii] = M/V,  Var(P_ii) = 2 M (V−M) / [V² (V+2)]
        #   P² = P ⇒ ∑ⱼ P_ij² = P_ii, and all off-diagonal entries share one law
        #       ⇒ Var(P_ij) = E[P_ij²] = (E[P_ii] − E[P_ii²]) / (V−1)
        #                   = M (V−M) / [V (V−1) (V+2)]
        # Sanity check: both variances vanish at M = V, where P = I exactly.
        if diag:
            mu = M / V
            var = 2.0 * M * (V - M) / (V ** 2 * (V + 2))
        else:
            mu = 0.0
            # A 1×1 Gram matrix has no off-diagonal entry; avoid dividing by V − 1 = 0.
            var = M * (V - M) / (V * (V - 1) * (V + 2)) if V != 1 else 0.0

        if kind is EmbeddingType.SCALED_GAUSSIAN_QR:
            # G is scaled by V/M: means scale linearly, variances quadratically.
            norm_scale = V / M
            mu = 1.0 if diag else 0.0
            var = var * norm_scale ** 2

    else:
        raise ValueError(f"unknown embedding_type={embedding_type!r}")

    return mu, var


def calculate_embedding_noise_entry_mean_and_variance(
    V: int,
    M: int,
    embedding_type: EmbeddingType,
    diag: bool,  # is diagonal entry
) -> Tuple[float, float]:
    """
    Returns (mean, variance) of one entry of the Gram-matrix noise ε = G − I.

    The Gram-entry moments come from
    calculate_embedding_gram_entry_mean_and_variance; subtracting the identity
    shifts the mean of a diagonal entry by 1 and leaves the variance unchanged.
    """

    gram_mean, gram_var = calculate_embedding_gram_entry_mean_and_variance(
        V=V, M=M, embedding_type=embedding_type, diag=diag,
    )

    # Entry of the identity that is subtracted from G_ij.
    identity_entry = 1. if diag else 0.

    return gram_mean - identity_entry, gram_var


def calculate_embedding_noise_entry_mean_and_std(
    V: int,
    M: int,
    embedding_type: EmbeddingType,
    diag: bool,  # is diagonal entry
) -> Tuple[float, float]:
    """
    Returns (mean, std-dev) of one entry of the Gram-matrix noise.

    Same as calculate_embedding_noise_entry_mean_and_variance, with the
    variance replaced by its square root.
    """

    noise_mean, noise_var = calculate_embedding_noise_entry_mean_and_variance(
        V=V, M=M, embedding_type=embedding_type, diag=diag,
    )

    return noise_mean, np.sqrt(noise_var)


def calculate_AR_single_entry_margin_mu_and_sigma(
    V: int,
    D: Union[int, np.ndarray],
    N: Union[int, np.ndarray],
    N_facts: int,
    embedding_type: EmbeddingType,
) -> Tuple[float, float]:
    """
    Compute the mean mu_Δ and std-dev sigma_Δ of the margin
      Δ = y_i − y_j
    fully separating signal and noise, including:
      - self-signal for i and for j (if j appears in any fact),
      - 1st and 2nd order noise terms,
      - cross-fact noise for all other facts.

    Keys: retrieval uses correct key m with noise ε_k_mm,
    but wrong value j if in some fact n contributes a self-signal·ε_k_nm.

    SCALED_GAUSSIAN uses a closed form:
      mu_Δ = 1,  var_Δ = 3/N + 5/D + (2·N_facts + 5)/(N·D).
    Every other embedding type is assembled from the entry-wise noise
    moments of the key embedding (dimension N) and the value embedding
    (dimension D).

    Parameters:
      V              first dimension passed to the noise-moment helper
      D              value-embedding dimension (scalar or array)
      N              key-embedding dimension (scalar or array)
      N_facts        number of stored facts
      embedding_type embedding scheme (member or its string value)
    Returns:
      (mean_delta, sigma_delta)
    """

    # Lookup by value: accepts members and plain strings, returns the canonical member.
    embedding_type = EmbeddingType(embedding_type)

    # Compare by identity: enum members are singletons, so `is` stays correct
    # even if `==` on the enum class does not distinguish members.
    if embedding_type is EmbeddingType.SCALED_GAUSSIAN:

        mean_delta = 1
        n = 1 / N
        d = 1 / D
        var_delta = 3*n + 5*d + (2*N_facts+5)*n*d
        sigma_delta = np.sqrt(var_delta)

        return mean_delta, sigma_delta

    # 1. Compute entry‐wise noise moments for k and v embeddings
    mu_k_d, var_k_d = calculate_embedding_noise_entry_mean_and_variance(V, N, embedding_type, diag=True)
    mu_k_o, var_k_o = calculate_embedding_noise_entry_mean_and_variance(V, N, embedding_type, diag=False)
    mu_v_d, var_v_d = calculate_embedding_noise_entry_mean_and_variance(V, D, embedding_type, diag=True)
    mu_v_o, var_v_o = calculate_embedding_noise_entry_mean_and_variance(V, D, embedding_type, diag=False)

    # 2. Mean of δ_ij
    #    E[ (1+ek_mm)(1+ev_ii - ev_ij) ] - E[ ek_nm(1+ev_jj - ev_ij) ]
    mean_delta = (1 + mu_v_d - mu_v_o) * (1 + mu_k_d - mu_k_o)

    # 3. Var[A] for A = U*V with U = (1+ek_mm), V = (1+ev_ii - ev_ij).
    #    Product of independent factors:
    #    Var[UV] = Var[U]Var[V] + Var[U]E[V]^2 + Var[V]E[U]^2
    var_U = var_k_d
    var_V = var_v_d + var_v_o
    E_U = 1 + mu_k_d
    E_V = 1 + mu_v_d - mu_v_o

    VarA = (
        var_U * var_V
        + var_U * (E_V ** 2)
        + var_V * (E_U ** 2)
    )

    # 4. Var[B] for B = - ek_nm*(1+ev_jj - ev_ij)
    var_Xp = var_k_o
    E_Xp = mu_k_o
    # same V‐stats as above
    VarB = (
        var_Xp * var_V
        + var_Xp * (E_V ** 2)
        + var_V * (E_Xp ** 2)
    )

    # 5. Var[C] over the (N_facts−2) cross‐fact terms T_qp = ek_qm*(ev_ip - ev_jp)
    #    Here E[ev_ip - ev_jp] = 0, Var = 2*var_v_o
    VarT = 2 * var_v_o * (var_k_o + mu_k_o ** 2)
    VarC = (N_facts - 2) * VarT

    # 6. Total variance & std‐dev
    variance_delta = VarA + VarB + VarC
    sigma_delta = np.sqrt(variance_delta)

    return mean_delta, sigma_delta


def calculate_AR_expected_accuracy(
    V: int,
    D: Union[int, np.ndarray],
    N: Union[int, np.ndarray],
    N_facts: int,
    embedding_type: EmbeddingType,
    method: EstimationMethod,
) -> Union[float, np.ndarray]:
    """
    Returns P(correct recall) ≈ [Φ(mu_delta/sigma_delta)]^(V-1),
    where (mu_delta, sigma_delta) =
    calculate_AR_single_entry_margin_mu_and_sigma(V, D, N, N_facts, embedding_type).

    Only EstimationMethod.CLT is implemented; any other method raises
    NotImplementedError (after the margin moments have been computed).
    """

    margin_mean, margin_std = calculate_AR_single_entry_margin_mu_and_sigma(
        V, D, N, N_facts, embedding_type,
    )

    if method is not EstimationMethod.CLT:
        raise NotImplementedError("Only CLT estimation is supported.")

    # Probability that the correct entry beats a single wrong entry ...
    single_entry_success = norm.cdf(margin_mean / margin_std)

    # ... raised to the number of wrong entries it has to beat.
    return single_entry_success ** (V - 1)


def calculate_MQAR_expected_accuracy(
    V: int,
    D: Union[int, np.ndarray],
    N: Union[int, np.ndarray],
    N_facts: int,
    N_queries: int,
    embedding_type: EmbeddingType,
    method: EstimationMethod,
) -> float:
    """
    Returns the expected accuracy of answering N_queries queries, computed as
    the single-query accuracy from calculate_AR_expected_accuracy raised to
    the power N_queries.
    """

    per_query = calculate_AR_expected_accuracy(V, D, N, N_facts, embedding_type, method)

    return per_query ** N_queries


class Eps:
    """
    Factory of symbolic Gaussian noise variables for the entries of an
    embedding's Gram-matrix noise.

    The (mean, std-dev) of a diagonal and of an off-diagonal entry are computed
    once in the constructor; ``eps["ij"]`` (or ``eps("ij")``) then returns a
    sympy ``Normal`` random symbol named ``"{name}_{i}{j}"`` with the diagonal
    parameters when ``i == j`` and the off-diagonal ones otherwise.

    No independence is declared explicitly: ``sympy.stats.independent`` is a
    predicate that returns a bool and registers nothing, so it is not called
    here.
    """

    def __init__(
        self,
        name: str,
        V: int,
        dim: Union[int, symbols],
        embedding_type: EmbeddingType,
    ):
        """
        Parameters:
          name           prefix used for the names of all generated symbols
          V, dim         passed as (V, M) to the noise-moment helper
          embedding_type embedding scheme whose noise moments are used
        """
        self.name = name

        # compute diag moments
        mu_d, var_d = calculate_embedding_noise_entry_mean_and_variance(
            V, dim, embedding_type, diag=True,
        )
        # compute off-diag moments
        mu_o, var_o = calculate_embedding_noise_entry_mean_and_variance(
            V, dim, embedding_type, diag=False,
        )

        # store the raw parameters
        # NOTE(review): np.sqrt presumably cannot handle symbolic variances, so
        # `dim` likely has to be numeric despite the annotation — confirm.
        self._mu_d, self._sigma_d = mu_d, np.sqrt(var_d)
        self._mu_o, self._sigma_o = mu_o, np.sqrt(var_o)

        # prototype variables, one per family (diagonal / off-diagonal)
        self._diag = Normal(f"{name}_d", mu_d, self._sigma_d)
        self._off = Normal(f"{name}_o", mu_o, self._sigma_o)

    def __getitem__(self, ij: str) -> RandomSymbol:
        """Return the Normal noise variable for entry ``ij`` (any 2-element sequence of indices)."""
        i, j = ij
        if i == j:
            mu, sigma = self._mu_d, self._sigma_d
        else:
            mu, sigma = self._mu_o, self._sigma_o

        # new Normal with the family's mu, sigma and an entry-specific name
        return Normal(f"{self.name}_{i}{j}", mu, sigma)

    __call__ = __getitem__

    def __str__(self) -> str:
        """
        Return a summary of the diagonal and off-diagonal noise parameters.
        """
        var_d = self._sigma_d ** 2
        var_o = self._sigma_o ** 2
        return (
            f"{self.name}: "
            f"\n\tmu_d={self._mu_d}, var_d={var_d} (std={self._sigma_d})"
            f"\n\tmu_o={self._mu_o}, var_o={var_o} (std={self._sigma_o})"
        )

    def __repr__(self) -> str:
        return self.__str__()


def calculate_AR_lower_bound_accuracy(
    V: int,
    D: Union[int, np.ndarray],
    N: Union[int, np.ndarray],
    N_facts: int,
    embedding_type: EmbeddingType,
    method: BoundMethod,
) -> float:
    """
    Returns a lower bound on the retrieval accuracy.

    The margin moments (mu_delta, sigma_delta) are always computed first via
    calculate_AR_single_entry_margin_mu_and_sigma, then:
      • BoundMethod.CLT: [Φ(mu_delta / sigma_delta)]^(V-1).
      • BoundMethod.DETERMINISTIC: 1 − 4·V²·exp(−½·min(D, N)·δ²)
        with δ = mu_delta / (16·(N_facts − 1)).
        NOTE(review): the divisor is zero for N_facts == 1 — confirm callers
        never pass that.

    Raises:
      NotImplementedError for any other method.
    """

    margin_mean, margin_std = calculate_AR_single_entry_margin_mu_and_sigma(
        V, D, N, N_facts, embedding_type,
    )

    # CLT-based bound: one Gaussian tail per wrong coordinate.
    if method == BoundMethod.CLT:
        z_score = margin_mean / margin_std
        return norm.cdf(z_score) ** (V - 1)

    # Deterministic ETF-like high-prob bound, using only the margin mean.
    if method == BoundMethod.DETERMINISTIC:
        wrong_fact_count = N_facts - 1
        tolerance = margin_mean / (16 * wrong_fact_count)
        smallest_dim = np.minimum(D, N)
        return 1 - 4 * V**2 * np.exp(-0.5 * smallest_dim * tolerance**2)

    raise NotImplementedError("Unknown bound method.")


def calculate_MQAR_theoretical_JL_success(
        V: int, D_axis: np.ndarray, N_axis: np.ndarray,
        N_facts: int,
        avoid_eps_larger_than_1: bool = True,
        avoid_N_larger_than_D: bool = True,
        facts_noise_factor: float | None = None,
) -> np.ndarray:
    """
    Evaluate the JL success condition on a (D, N) grid.

    With eps_v = epsilon_JL(V, D) and eps_k = epsilon_JL(V, N), a cell is
    successful when
        1 − eps_v − 3·eps_k  >  facts_noise_factor · eps_v · eps_k.

    Parameters:
      V                       value passed as `v` to epsilon_JL
      D_axis, N_axis          1-D array-likes of D and N values
      N_facts                 number of facts; only used for the default
                              facts_noise_factor
      avoid_eps_larger_than_1 mark cells with eps_v > 1 or eps_k > 1 as INVALID_VALUE
      avoid_N_larger_than_D   mark cells with N > D as IGNORED_VALUE
                              (applied last, so it wins over INVALID_VALUE)
      facts_noise_factor      defaults to 2·N_facts − 3; clamped below at 0
    Returns:
      Float array of shape (len(D_axis), len(N_axis)) holding 1.0 / 0.0 for
      success / failure, or one of the sentinel values.
    """

    # Accept any array-like; D varies along rows, N along columns.
    D_col = np.asarray(D_axis)[:, None]
    N_row = np.asarray(N_axis)[None, :]

    eps_v = epsilon_JL(v=V, n=D_col)
    eps_k = epsilon_JL(v=V, n=N_row)

    if facts_noise_factor is None:
        facts_noise_factor = 2 * N_facts - 3  # JL theory

    facts_noise_factor = max(facts_noise_factor, 0)

    is_JL_successful = (
        (1 - eps_v - 3 * eps_k) > facts_noise_factor * eps_v * eps_k
    ).astype(float)

    # ignore invalid-epsilon values
    if avoid_eps_larger_than_1:
        is_JL_successful[(eps_v > 1) | (eps_k > 1)] = INVALID_VALUE

    # ignore (N > D) values
    if avoid_N_larger_than_D:
        is_JL_successful[N_row > D_col] = IGNORED_VALUE

    return is_JL_successful


def create_dim_and_epsilon_axes(V: int, min_dim: int, max_dim: int, n_steps: int, axis_type: str):
    """
    Build matching axes of embedding dimensions and JL distortions.

    Parameters:
      V                value passed as `v` to epsilon_JL / N_JL
      min_dim, max_dim range of dimensions to cover (max_dim itself is excluded
                       by np.arange in "linear_dim" mode)
      n_steps          number of steps the range is divided into
      axis_type        "linear_dim":     dimensions evenly spaced in [min_dim, max_dim)
                       "linear_epsilon": distortions evenly spaced between
                                         epsilon_JL(max_dim) and epsilon_JL(min_dim),
                                         converted to unique integer dimensions
    Returns:
      (dim_axis, eps_axis): integer dimensions in descending order and the
      distortion epsilon_JL(V, dim) of each.
    Raises:
      ValueError for an unknown axis_type.
    """

    if axis_type == "linear_dim":

        step = (max_dim - min_dim) / n_steps

        dim_axis = np.arange(min_dim, max_dim, step=step).astype(int)[::-1]

    elif axis_type == "linear_epsilon":

        # The smallest dimension gives the largest distortion and vice versa.
        max_epsilon = epsilon_JL(v=V, n=min_dim)
        min_epsilon = epsilon_JL(v=V, n=max_dim)

        # Step in epsilon units: the width of the range split into n_steps
        # (a ratio max/min would not be a step size).
        step = (max_epsilon - min_epsilon) / n_steps

        raw_eps_axis = np.arange(min_epsilon, max_epsilon, step=step)

        # Truncating to int can create duplicates; np.unique removes them
        # (and sorts ascending, hence the reversal).
        dim_axis = np.unique(N_JL(v=V, eps=raw_eps_axis).astype(int))[::-1]

    else:
        raise ValueError(f"unknown {axis_type = }")

    # Recompute epsilon from the integer dimensions so both axes agree exactly.
    eps_axis = epsilon_JL(v=V, n=dim_axis)

    return dim_axis, eps_axis
