from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

from .config import ExistenceConfig
from .belief import OTHER_TOKEN_ID

@dataclass
class SlotMetrics:
    """Holonomy and CEI metrics for one slot under the real, swap and shuffle conditions."""

    # Identity of the scored slot.
    seq_id: int
    position: int
    true_token_id: int

    # Mean per-cell holonomy magnitude of the belief grid, per condition.
    h_slot_real: float
    h_slot_swap: float
    h_slot_shuffle: float

    # KL divergence of the joint belief from the product-of-experts belief, per condition.
    cei_real: float
    cei_swap: float
    cei_shuffle: float

    # Stokes-consistency diagnostics (compute_all_metrics fills this from the
    # "real" grid). Defaults to None, so the annotation must allow None.
    stokes_check_data: Optional[Dict[str, Any]] = None

    # Frequency rank of true_token_id; 0 when not supplied.
    token_freq_rank: int = 0

def compute_log_odds_coordinates(
    belief_grid: Dict[Tuple[int, int], Dict[int, float]],
    candidate_set: Set[int],
    eps: float = 1e-12
) -> Dict[Tuple[int, int], Dict[int, float]]:
    """Map every grid belief to its log-ratio against the (0, 0) belief.

    For each (L, R) key of ``belief_grid`` and every state in
    ``candidate_set`` plus OTHER_TOKEN_ID, the coordinate is
    ``log(mu_s) - log(pi_s)``, where ``pi = belief_grid[(0, 0)]``.
    Missing probabilities, and probabilities below ``eps``, are floored
    to ``eps`` before taking the log.

    Raises:
        KeyError: if ``belief_grid`` has no (0, 0) entry.
    """
    prior = belief_grid[(0, 0)]
    states = set(candidate_set) | {OTHER_TOKEN_ID}

    def _log_ratio(dist, state):
        # Floor both numerator and denominator at eps to keep the logs finite.
        numerator = max(dist.get(state, eps), eps)
        denominator = max(prior.get(state, eps), eps)
        return np.log(numerator) - np.log(denominator)

    return {
        cell: {state: _log_ratio(dist, state) for state in states}
        for cell, dist in belief_grid.items()
    }

def compute_cell_holonomy(
    u_grid: Dict[Tuple[int, int], Dict[int, float]],
    L1: int, L2: int,
    R1: int, R2: int,
    candidate_set: Set[int]
) -> Dict[int, float]:
    """Return the per-state mixed second difference of ``u_grid`` over one cell.

    For each state ``s`` in ``candidate_set`` plus OTHER_TOKEN_ID:
    ``u[L2,R2] - u[L2,R1] - u[L1,R2] + u[L1,R1]``, where a state missing
    from a corner dict contributes 0.0.

    Raises:
        KeyError: if any of the four corners is absent from ``u_grid``.
    """
    states = set(candidate_set) | {OTHER_TOKEN_ID}

    u_22 = u_grid[(L2, R2)]
    u_21 = u_grid[(L2, R1)]
    u_12 = u_grid[(L1, R2)]
    u_11 = u_grid[(L1, R1)]

    return {
        state: (
            u_22.get(state, 0.0)
            - u_21.get(state, 0.0)
            - u_12.get(state, 0.0)
            + u_11.get(state, 0.0)
        )
        for state in states
    }

def compute_holonomy_magnitude(
    belief_grid: Dict[Tuple[int, int], Dict[int, float]],
    candidate_set: Set[int],
    config: ExistenceConfig,
    max_stokes_states: int = 5,
) -> Tuple[float, Dict[str, Any]]:
    """Return the mean per-cell holonomy magnitude of a belief grid plus Stokes-check data.

    The grid spanned by ``config.l_list`` x ``config.r_list`` is cut into
    cells between consecutive entries. For each cell, the per-state holonomy
    ``omega_s`` (from ``compute_cell_holonomy`` on the log-ratio coordinates)
    is collapsed to ``sqrt(sum_s mu_s * omega_s**2)``. Here ``mu`` is the
    belief at the cell's (L2, R2) corner, with missing states weighted by
    ``eps``. ``h_slot`` is the mean over cells, or 0.0 when there are no cells.

    For up to ``max_stokes_states`` states, the function records the holonomy
    of the whole rectangle and the sum of that state's cell holonomies. The
    cell sums telescope exactly to the rectangle whose corners are the first
    and last entries of ``l_list``/``r_list``, so those endpoints are used.
    For ascending lists they equal min/max.

    Args:
        belief_grid: maps (L, R) to a state -> probability dict. It must
            contain (0, 0) and every (l, r) in l_list x r_list.
        candidate_set: candidate token ids; OTHER_TOKEN_ID is always added.
        config: supplies ``eps``, ``l_list`` and ``r_list``.
        max_stokes_states: number of states sampled for the Stokes check.

    Returns:
        ``(h_slot, stokes_check_data)``. ``stokes_check_data`` has keys
        "rectangle_holonomy" and "sum_cell_holonomy" (state -> value) and
        "states_sampled" (list of states). The sampled entries are left empty
        when either grid list is empty.
    """
    eps = config.eps
    l_list = config.l_list
    r_list = config.r_list

    u_grid = compute_log_odds_coordinates(belief_grid, candidate_set, eps)

    all_states = list(set(candidate_set) | {OTHER_TOKEN_ID})
    omega_per_state: Dict[int, List[float]] = {s: [] for s in all_states}
    h_cells: List[float] = []

    # Visit every cell bounded by consecutive grid lines.
    for L1, L2 in zip(l_list[:-1], l_list[1:]):
        for R1, R2 in zip(r_list[:-1], r_list[1:]):
            omega = compute_cell_holonomy(u_grid, L1, L2, R1, R2, candidate_set)
            for s in all_states:
                omega_per_state[s].append(omega[s])

            # Weight each state's squared holonomy by the belief at the (L2, R2) corner.
            mu_corner = belief_grid[(L2, R2)]
            weighted_sq_sum = sum(
                mu_corner.get(s, eps) * omega[s] ** 2 for s in all_states
            )
            h_cells.append(np.sqrt(weighted_sq_sum))

    h_slot = np.mean(h_cells) if h_cells else 0.0

    stokes_check_data: Dict[str, Any] = {
        "rectangle_holonomy": {},
        "sum_cell_holonomy": {},
        "states_sampled": [],
    }

    # The cell sums telescope to the rectangle between the FIRST and LAST grid
    # lines. Using min/max instead gives a false mismatch for unsorted lists.
    # With an empty list there is no rectangle, so the check is skipped.
    if len(l_list) > 0 and len(r_list) > 0:
        rect_omega = compute_cell_holonomy(
            u_grid, l_list[0], l_list[-1], r_list[0], r_list[-1], candidate_set
        )
        for s in all_states[:max_stokes_states]:
            stokes_check_data["rectangle_holonomy"][s] = rect_omega[s]
            stokes_check_data["sum_cell_holonomy"][s] = sum(omega_per_state[s])
            stokes_check_data["states_sampled"].append(s)

    return h_slot, stokes_check_data

def compute_poe_distribution(
    mu_L: Dict[int, float],
    mu_R: Dict[int, float],
    pi: Dict[int, float],
    candidate_set: Set[int],
    eps: float = 1e-12
) -> Dict[int, float]:
    """Combine two beliefs as a product of experts relative to a prior.

    Each state in ``candidate_set`` plus OTHER_TOKEN_ID gets the
    unnormalized weight ``mu_L[s] * mu_R[s] / max(pi[s], eps)``, with
    missing entries read as ``eps``. The weights are normalized; a
    non-positive total is replaced by 1.0. Each probability is then
    floored at ``eps`` and the result renormalized so it sums to 1.
    """
    states = set(candidate_set) | {OTHER_TOKEN_ID}

    weights = {
        s: (mu_L.get(s, eps) * mu_R.get(s, eps)) / max(pi.get(s, eps), eps)
        for s in states
    }

    norm = sum(weights.values())
    norm = 1.0 if norm <= 0 else norm

    # Floor tiny probabilities so downstream logs stay finite, then renormalize.
    floored = {s: max(w / norm, eps) for s, w in weights.items()}
    renorm = sum(floored.values())
    return {s: value / renorm for s, value in floored.items()}

def compute_kl_divergence(
    p: Dict[int, float],
    q: Dict[int, float],
    eps: float = 1e-12
) -> float:
    """Return KL(p || q) over the states of ``p``, clamped to be non-negative.

    States with ``p_s <= eps`` are skipped. States missing from ``q``, or
    with ``q_s`` below ``eps``, use ``eps`` in the log. A negative total,
    which can occur when the inputs are not normalized, is returned as 0.0.
    """
    running = 0.0
    for state, p_val in p.items():
        if not p_val > eps:
            continue
        q_val = max(q.get(state, eps), eps)
        running += p_val * (np.log(p_val) - np.log(q_val))
    return max(running, 0.0)

def compute_cei(
    belief_grid: Dict[Tuple[int, int], Dict[int, float]],
    candidate_set: Set[int],
    config: ExistenceConfig
) -> float:
    """Return KL(joint || product-of-experts) for one belief grid.

    The joint belief is read at ``(config.l_max, config.r_max)``. The two
    experts are read at ``(l_max, 0)`` and ``(0, r_max)``, and the prior at
    ``(0, 0)``. The experts are combined by ``compute_poe_distribution``.
    """
    eps = config.eps
    l_end, r_end = config.l_max, config.r_max

    joint = belief_grid[(l_end, r_end)]
    expert_l = belief_grid[(l_end, 0)]
    expert_r = belief_grid[(0, r_end)]
    prior = belief_grid[(0, 0)]

    product_of_experts = compute_poe_distribution(
        expert_l, expert_r, prior, candidate_set, eps
    )
    return compute_kl_divergence(joint, product_of_experts, eps)

def compute_all_metrics(
    slot_results: List[Dict],
    config: ExistenceConfig,
    token_ranks: Dict[int, int]
) -> List[SlotMetrics]:
    """Build a SlotMetrics record for every slot result.

    Each entry of ``slot_results`` must provide "slot", "candidate_set" and
    "beliefs". "beliefs" holds one belief grid per condition: "real", "swap"
    and "shuffle". Holonomy and CEI are computed for each condition, and the
    Stokes data is kept from the "real" grid only. Tokens missing from
    ``token_ranks`` get rank ``len(token_ranks) + 1``.
    """
    conditions = ("real", "swap", "shuffle")
    collected: List[SlotMetrics] = []

    for record in slot_results:
        slot = record["slot"]
        candidates = record["candidate_set"]
        beliefs = record["beliefs"]

        holonomy: Dict[str, float] = {}
        stokes: Dict[str, Dict[str, Any]] = {}
        for cond in conditions:
            holonomy[cond], stokes[cond] = compute_holonomy_magnitude(
                beliefs[cond], candidates, config
            )

        cei = {
            cond: compute_cei(beliefs[cond], candidates, config)
            for cond in conditions
        }

        rank = token_ranks.get(slot.true_token_id, len(token_ranks) + 1)

        collected.append(
            SlotMetrics(
                seq_id=slot.seq_id,
                position=slot.position,
                true_token_id=slot.true_token_id,
                h_slot_real=holonomy["real"],
                h_slot_swap=holonomy["swap"],
                h_slot_shuffle=holonomy["shuffle"],
                cei_real=cei["real"],
                cei_swap=cei["swap"],
                cei_shuffle=cei["shuffle"],
                stokes_check_data=stokes["real"],
                token_freq_rank=rank,
            )
        )

    return collected
