from typing import Union, Tuple
import numpy as np
import torch
from scipy import stats
from sklearn.metrics import mutual_info_score
import torch.nn.functional as F
def kl_divergence(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor],
    epsilon: float = 1e-10
) -> float:
    """Return the Kullback-Leibler divergence ``KL(p || q)`` in nats.

    Both inputs are treated as (possibly unnormalised) discrete distributions
    over the same support. They are flattened, ``epsilon`` is added to every
    entry to avoid ``log(0)`` and division by zero, and each is rescaled to
    sum to 1 before evaluating ``sum(p_i * log(p_i / q_i))``.

    Args:
        p: Non-negative weights of the first distribution. Tensors are
            detached and copied to the CPU; any array-like is accepted.
        q: Non-negative weights of the second distribution; must have the
            same number of elements as ``p``.
        epsilon: Smoothing constant added to every entry before normalising.

    Returns:
        The divergence as a Python float (0.0 when ``p`` and ``q`` match).

    Raises:
        ValueError: If ``p`` and ``q`` have different numbers of elements.
    """
    # Cast to float64 first so float16/bfloat16 tensors can be exported and
    # the log-sum is evaluated at full precision.
    if isinstance(p, torch.Tensor):
        p = p.detach().cpu().to(torch.float64).numpy()
    if isinstance(q, torch.Tensor):
        q = q.detach().cpu().to(torch.float64).numpy()
    # Flatten so that e.g. shapes (n, 1) and (n,) are paired element-wise
    # instead of silently broadcasting to an (n, n) grid.
    p = np.asarray(p, dtype=np.float64).ravel()
    q = np.asarray(q, dtype=np.float64).ravel()
    if p.size != q.size:
        raise ValueError(
            f"p and q must have the same number of elements, got {p.size} and {q.size}"
        )
    p = p + epsilon
    q = q + epsilon
    p = p / np.sum(p)
    q = q / np.sum(q)
    return float(np.sum(p * np.log(p / q)))
def js_divergence(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor],
    epsilon: float = 1e-10
) -> float:
    """Return the Jensen-Shannon divergence between ``p`` and ``q`` in nats.

    Both inputs are flattened, smoothed by adding ``epsilon`` to every entry,
    and normalised to sum to 1. With ``m = (p + q) / 2`` the result is
    ``0.5 * KL(p || m) + 0.5 * KL(q || m)``, which is symmetric in its
    arguments and lies in ``[0, ln 2]``.

    Args:
        p: Non-negative weights of the first distribution. Tensors are
            detached and copied to the CPU; any array-like is accepted.
        q: Non-negative weights of the second distribution; must have the
            same number of elements as ``p``.
        epsilon: Smoothing constant added to every entry before normalising.
            It is applied exactly once.

    Returns:
        The divergence as a Python float.

    Raises:
        ValueError: If ``p`` and ``q`` have different numbers of elements.
    """
    if isinstance(p, torch.Tensor):
        p = p.detach().cpu().to(torch.float64).numpy()
    if isinstance(q, torch.Tensor):
        q = q.detach().cpu().to(torch.float64).numpy()
    # Pair entries element-wise rather than relying on numpy broadcasting.
    p = np.asarray(p, dtype=np.float64).ravel()
    q = np.asarray(q, dtype=np.float64).ravel()
    if p.size != q.size:
        raise ValueError(
            f"p and q must have the same number of elements, got {p.size} and {q.size}"
        )
    p = p + epsilon
    q = q + epsilon
    p = p / np.sum(p)
    q = q / np.sum(q)
    m = 0.5 * (p + q)
    # p, q and m are already strictly positive and normalised, so the KL
    # terms are evaluated directly; re-smoothing them would ignore the
    # caller's epsilon and perturb the distributions a second time.
    kl_pm = np.sum(p * np.log(p / m))
    kl_qm = np.sum(q * np.log(q / m))
    return float(0.5 * (kl_pm + kl_qm))
def wasserstein_distance(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor]
) -> float:
    """Earth mover's distance between two histograms on an integer grid.

    Each input is flattened and read as a histogram whose i-th entry is the
    mass located at position ``i``. Both are rescaled to total mass 1, and
    the 1-D first Wasserstein distance is computed with
    :func:`scipy.stats.wasserstein_distance`. The two inputs may have
    different lengths.

    Args:
        p: Non-negative bin weights of the first histogram (tensors are
            detached and moved to the CPU).
        q: Non-negative bin weights of the second histogram.

    Returns:
        The distance, measured in units of bin index.
    """
    def _to_unit_mass(weights):
        # Tensor -> ndarray, then flatten and scale so the weights sum to 1.
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()
        flat = weights.flatten()
        return flat / np.sum(flat)

    u_weights, v_weights = _to_unit_mass(p), _to_unit_mass(q)
    return stats.wasserstein_distance(
        np.arange(u_weights.size),
        np.arange(v_weights.size),
        u_weights,
        v_weights,
    )
def cosine_similarity(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor]
) -> float:
    """Return the cosine similarity of ``p`` and ``q`` viewed as flat vectors.

    Both inputs are flattened, so arrays of any shape with the same number
    of elements are compared entry by entry. For 1-D inputs this matches a
    plain ``F.cosine_similarity`` over the vectors.

    Args:
        p: First vector (numpy array, torch tensor, or any array-like).
        q: Second vector; must have as many elements as ``p``.

    Returns:
        The similarity as a Python number in ``[-1, 1]``.
    """
    # as_tensor shares memory with numpy inputs and accepts lists; flattening
    # avoids the original (1, *shape) unsqueeze producing a multi-element
    # result (and a failing .item()) for N-D inputs.
    p_t = torch.as_tensor(p).flatten()
    q_t = torch.as_tensor(q).flatten()
    # Use one floating dtype for both operands: the cosine kernel is not
    # implemented for integer tensors, and mixed numpy/torch inputs often
    # differ in precision (float64 vs float32).
    dtype = torch.promote_types(p_t.dtype, q_t.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    p_t = p_t.to(dtype)
    q_t = q_t.to(device=p_t.device, dtype=dtype)
    return F.cosine_similarity(p_t.unsqueeze(0), q_t.unsqueeze(0)).item()
def mutual_information(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor],
    bins: int = 50
) -> float:
    """Estimate the mutual information ``I(p; q)`` in nats from paired samples.

    ``p`` and ``q`` are flattened and read as paired observations
    ``(p[i], q[i])`` of two scalar variables. Their joint distribution is
    approximated with a 2-D histogram, and the mutual information of that
    discretised distribution is returned:
    ``sum_xy p(x, y) * log(p(x, y) / (p(x) * p(y)))``.

    Args:
        p: Samples of the first variable. Tensors are detached and copied to
            the CPU; any array-like is accepted.
        q: Samples of the second variable; must have as many elements as ``p``.
        bins: Bin specification forwarded to ``np.histogram2d``. An int gives
            that many equal-width bins along each axis.

    Returns:
        A non-negative mutual information estimate as a Python float.

    Raises:
        ValueError: If ``p`` and ``q`` differ in size or are empty.
    """
    if isinstance(p, torch.Tensor):
        p = p.detach().cpu().to(torch.float64).numpy()
    if isinstance(q, torch.Tensor):
        q = q.detach().cpu().to(torch.float64).numpy()
    x = np.asarray(p, dtype=np.float64).ravel()
    y = np.asarray(q, dtype=np.float64).ravel()
    if x.size != y.size:
        raise ValueError(
            "p and q must contain paired samples of equal size, "
            f"got {x.size} and {y.size}"
        )
    if x.size == 0:
        raise ValueError("p and q must not be empty")

    # Mutual information needs the JOINT distribution. Separate 1-D
    # histograms of p and q carry no information about their dependence.
    joint_counts, _, _ = np.histogram2d(x, y, bins=bins)
    p_xy = joint_counts / joint_counts.sum()
    p_x = p_xy.sum(axis=1, keepdims=True)  # marginal over p's bins, column
    p_y = p_xy.sum(axis=0, keepdims=True)  # marginal over q's bins, row
    # Empty cells contribute 0 (t * log t -> 0 as t -> 0), so only occupied
    # cells are summed. Their marginals are positive, so the log is finite.
    occupied = p_xy > 0
    joint = p_xy[occupied]
    independent = (p_x * p_y)[occupied]
    mi = np.sum(joint * np.log(joint / independent))
    # Clamp tiny negative round-off for (near-)independent inputs.
    return float(max(mi, 0.0))
def compute_distribution_metrics(
    p: Union[np.ndarray, torch.Tensor],
    q: Union[np.ndarray, torch.Tensor],
    bins: int = 50
) -> dict:
    """Collect similarity metrics between ``p`` and ``q`` into a dict.

    The returned mapping currently holds a single entry,
    ``'cosine_similarity'``, produced by :func:`cosine_similarity`.

    Args:
        p: First input, forwarded unchanged to ``cosine_similarity``.
        q: Second input, forwarded unchanged to ``cosine_similarity``.
        bins: Accepted for interface compatibility; not used by the current
            implementation.

    Returns:
        A new dict mapping metric name to its computed value.
    """
    result = dict()
    result['cosine_similarity'] = cosine_similarity(p, q)
    return result
