from fractions import Fraction
import os
import re
from typing import Any, Dict, List, Optional, Union

import numpy as np
from numpy.typing import NDArray
from scipy.spatial.distance import jensenshannon
from scipy.stats import chisquare, entropy


def calculate_metrics(
    observed_counts: Union[List[Union[int, float]], NDArray[np.float64]],
    expected_probabilities: Union[List[float], NDArray[np.float64]],
    num_samples: int,
) -> Dict[str, float]:
    """
    Calculate goodness-of-fit metrics of observed counts against expected probabilities.

    Computes the Pearson chi-square statistic with Cohen's w effect size, the
    Jensen-Shannon divergence, the KL divergence and the total variation distance.

    Args:
        observed_counts: Observed count per category.
        expected_probabilities: Expected probability per category, same length as
            ``observed_counts``. Each entry goes through ``fractions.Fraction``
            first, so fraction strings such as ``"1/3"`` are accepted. The values
            must sum to 1 within ``np.isclose``'s default tolerance; they are then
            rescaled to sum to 1 so the expected counts total ``num_samples``.
        num_samples: Total number of samples. Must be non-negative; when positive,
            ``observed_counts / num_samples`` must sum to 1 (or be all zero).

    Returns:
        A dict with float values:
            "chi2_statistic": chi-square statistic of ``observed_counts`` against
                ``expected_probabilities * num_samples``; 0.0 when ``num_samples``
                is 0 and every count is 0.
            "effect_size_w": Cohen's w, ``sqrt(chi2 / sum(observed_counts))``;
                NaN when the chi-square test is not run.
            "js_divergence": Jensen-Shannon divergence (natural log), i.e. the
                squared Jensen-Shannon distance; NaN when ``num_samples`` is 0.
            "kl_divergence": KL(observed || expected) in nats; NaN when
                ``num_samples`` is 0.
            "tv_distance": total variation distance, 0.5 * sum(|P_i - Q_i|).

    Raises:
        ValueError: If the expected probabilities do not sum to 1, are empty while
            the observed counts are not, or differ in length from the observed
            counts; if ``num_samples`` is negative; or if ``num_samples`` is 0
            while the counts sum to more than 0. ``scipy.stats.chisquare`` may also
            raise ``ValueError`` when ``sum(observed_counts)`` and ``num_samples``
            differ by more than its (much tighter) tolerance.
        RuntimeError: If ``observed_counts / num_samples`` has a positive sum that
            is not close to 1.
    """
    obs_counts_np: NDArray[np.float64] = np.array(observed_counts, dtype=np.float64)
    exp_probs_np: NDArray[np.float64] = np.array(
        [float(Fraction(e)) for e in expected_probabilities], dtype=np.float64
    )

    if exp_probs_np.size > 0:
        exp_sum = float(np.sum(exp_probs_np))
        if not np.isclose(exp_sum, 1.0):
            raise ValueError(
                f"Expected probabilities must sum to 1. Sum is {np.sum(exp_probs_np)}"
            )
        # Always rescale: np.isclose tolerates ~1e-5 of drift, but scipy's
        # chisquare requires sum(f_obs) and sum(f_exp) to agree to ~1e-8, so any
        # un-normalized drift (e.g. 0.333333 * 3) would make it raise.
        exp_probs_np = exp_probs_np / exp_sum
    elif obs_counts_np.size != 0:
        raise ValueError(
            "Expected probabilities are empty but observed counts are not."
        )

    if len(obs_counts_np) != len(exp_probs_np):
        raise ValueError(
            "Observed counts and expected probabilities must have the same length."
        )
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}.")

    observed_probabilities: NDArray[np.float64]
    if num_samples == 0:
        if np.sum(obs_counts_np) > 0:
            raise ValueError("num_samples is 0, but observed_counts sum to > 0.")
        observed_probabilities = np.zeros_like(exp_probs_np)
    else:
        observed_probabilities = obs_counts_np / num_samples
        obs_prob_sum = float(np.sum(observed_probabilities))
        # An all-zero observation vector is deliberately not rejected here.
        if obs_prob_sum > 0 and not np.isclose(obs_prob_sum, 1.0):
            raise RuntimeError("The sum is not 1.")

    expected_counts: NDArray[np.float64] = exp_probs_np * num_samples

    chi2_stat: float
    cohen_w: float = np.nan
    if num_samples > 0 and np.all(expected_counts == 0):
        # Expected probabilities sum to 1 after rescaling, so with num_samples > 0
        # this is only reachable for empty inputs.
        chi2_stat = 0.0 if np.all(obs_counts_np == 0) else np.nan
    elif num_samples == 0 and np.all(obs_counts_np == 0):
        # Nothing observed and nothing expected: no deviation.
        chi2_stat = 0.0
    elif float(np.sum(expected_counts)) < 1e-9:
        # Defensive: no expected mass to test against (chisquare would divide by 0).
        chi2_stat = np.inf if float(np.sum(obs_counts_np)) > 1e-9 else 0.0
    else:
        res = chisquare(f_obs=obs_counts_np, f_exp=expected_counts)
        chi2_stat = float(res.statistic)
        # Cohen's w = sqrt(chi2 / N), with N the observed total.
        cohen_w = float(np.sqrt(chi2_stat / np.sum(obs_counts_np)))

    # jensenshannon returns the JS *distance*; squaring gives the divergence.
    js_div_val: float = (
        float(jensenshannon(observed_probabilities, exp_probs_np, base=np.e)) ** 2
    )

    # Sanity check against the definition JS = (KL(Q||M) + KL(P||M)) / 2 with
    # M = (P + Q) / 2. With num_samples == 0 the observed vector is all zeros and
    # both sides are NaN, so NaN must compare equal here (stripped under -O).
    m_probs = 0.5 * (observed_probabilities + exp_probs_np)
    js_pq = 0.5 * (
        entropy(pk=exp_probs_np, qk=m_probs)
        + entropy(pk=observed_probabilities, qk=m_probs)
    )
    assert np.isclose(js_div_val, js_pq, equal_nan=True), f"{js_div_val}; {js_pq}"

    kl_div_val: float = float(
        entropy(pk=observed_probabilities, qk=exp_probs_np, base=np.e)
    )

    # Total Variation Distance: TV(P, Q) = 0.5 * sum |P_i - Q_i|, 0 <= TV <= 1.
    tv_distance_val: float = float(
        0.5 * np.sum(np.abs(observed_probabilities - exp_probs_np))
    )

    return {
        "chi2_statistic": float(chi2_stat),
        "effect_size_w": float(cohen_w),
        "js_divergence": js_div_val,
        "kl_divergence": kl_div_val,
        "tv_distance": tv_distance_val,
    }


def parse_experiment_params_from_path(
    path_to_experiment_leaf_dir: str,
) -> Optional[Dict[str, Any]]:
    """
    Parse experiment parameters from the directory name of a specific run.

    Only the final path component is inspected; a trailing path separator is
    ignored. Example dirname:
    '0_experiment.num_samples=1000,model=deepseek_r1,prompt.type=rsm,prompt=rsp,...'

    Keys are located with unanchored ``re.search``, so a key also matches as the
    suffix of a longer key (this is how ``experiment.num_samples`` is found).
    NOTE(review): this also means e.g. ``judge_model=`` would satisfy ``model=``
    if it appeared first — confirm such keys never occur in run names.

    Args:
        path_to_experiment_leaf_dir: Path to the run directory.

    Returns:
        A dict with keys "model", "prompt_type", "prompt_name" (str or None) and
        "num_samples" (int or None). A value is None when its key is absent from
        the directory name; a warning is printed in that case, but the dict is
        still returned (this function never returns None itself).
    """
    # normpath drops a trailing separator, which would otherwise make basename "".
    dirname: str = os.path.basename(os.path.normpath(path_to_experiment_leaf_dir))

    # Output key -> regex whose first group captures the value up to the next comma.
    patterns: Dict[str, str] = {
        "model": r"model=([^,]+)",
        "prompt_type": r"prompt\.type=([^,]+)",
        # Main prompt identifier; "prompt.type=..." cannot match because of ".type".
        "prompt_name": r"prompt=([^,]+)",
        "num_samples": r"num_samples=(\d+)",
    }

    params: Dict[str, Any] = {}
    for key, pattern in patterns.items():
        match = re.search(pattern, dirname)
        params[key] = match.group(1) if match else None

    if params["num_samples"] is not None:
        # The pattern captures only digits, so int() cannot fail here.
        params["num_samples"] = int(params["num_samples"])

    if any(value is None for value in params.values()):
        print(
            f"Warning: Could not parse all critical parameters from directory name: {dirname}. Parsed: {params}"
        )
    return params
