import math
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def _ensure_tensor(value: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Return ``value`` as a tensor placed on ``device``.

    An existing tensor is moved with ``Tensor.to`` and keeps its dtype; any
    other input is handed to ``torch.tensor`` and built as ``float32``.
    """
    if not isinstance(value, torch.Tensor):
        return torch.tensor(value, dtype=torch.float32, device=device)
    return value.to(device)


def _get_mean_cov(input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Empirical mean and covariance taken along dim 1 of ``input_tensor``.

    The transpose of dims 1 and 2 requires a 3-D input; dim 1 is the axis that
    is averaged over (presumably ``(batch, tokens, features)`` -- confirm
    against callers).  The covariance is divided by ``input_tensor.size(1)``,
    i.e. no Bessel correction is applied.

    Returns:
        ``(mean, cov)``: ``mean`` has dim 1 reduced away, ``cov`` is the
        product of the transposed centred data with the centred data.
    """
    n_rows = input_tensor.size(1)
    mean = input_tensor.mean(dim=1)
    # Re-insert the reduced axis so the mean broadcasts against every row.
    deviations = input_tensor - mean[:, None]
    cov = torch.matmul(deviations.transpose(1, 2), deviations) / n_rows
    return mean, cov


def _metric_error(mean, cov, mean_lim, cov_lim) -> Tuple[float, float]:
    """Distance between estimated and reference moments, as Python floats.

    Returns:
        ``(mean_err, cov_err)``: the L2 norm of ``mean - mean_lim`` over the
        last dim and the Frobenius norm of ``cov - cov_lim`` over the last two
        dims, each averaged over whatever dims remain.
    """
    mean_gap = mean - mean_lim
    l2_per_item = torch.linalg.vector_norm(mean_gap, ord=2, dim=-1)

    cov_gap = cov - cov_lim
    fro_per_item = torch.linalg.matrix_norm(cov_gap, ord='fro', dim=(-2, -1))

    return l2_per_item.mean().item(), fro_per_item.mean().item()


def _sample_tokens(distribution: torch.distributions.MultivariateNormal, n: int) -> torch.Tensor:
    """Draw ``n`` samples from ``distribution`` and prepend a size-1 dim."""
    # Indexing with None is the same as unsqueeze(0).
    return distribution.sample((n,))[None]


def _project_qkv(tokens: torch.Tensor, wq: torch.Tensor, wk: torch.Tensor, wv: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Apply the query, key and value projections to ``tokens``.

    Each weight is applied as ``tokens @ w.T``.

    Returns:
        The ``(q, k, v)`` projections, in that order.
    """
    return tuple(tokens @ weight.T for weight in (wq, wk, wv))


def attention_operator(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tiling: bool = False) -> torch.Tensor:
    """Softmax dot-product attention of ``q`` against ``k``, applied to ``v``.

    Args:
        q, k, v: query, key and value tensors; the last dim of ``q`` sets the
            ``1 / sqrt(dim)`` scaling of the scores.
        tiling: if True, delegate to ``F.scaled_dot_product_attention`` with
            its default arguments instead of the explicit matmul/softmax path.

    Returns:
        The attention output, ``softmax(q k^T * scale) v`` on the manual path.
    """
    inv_sqrt_dim = 1.0 / math.sqrt(q.size(-1))
    if not tiling:
        logits = (q @ k.transpose(-1, -2)) * inv_sqrt_dim
        weights = logits.softmax(dim=-1)
        return weights @ v
    return F.scaled_dot_product_attention(q, k, v)


def _one_monte_carlo(
    distribution: torch.distributions.MultivariateNormal,
    n: int,
    k: int,
    mean_lim: torch.Tensor,
    cov_lim: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    tiling: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Repeat the size-``n`` attention experiment ``k`` times and summarise.

    Every repetition samples ``n`` fresh tokens, projects them with
    ``wq``/``wk``/``wv``, applies ``attention_operator`` and scores the
    empirical mean/covariance of the output against ``mean_lim``/``cov_lim``.

    Returns:
        ``(mean_err_mean, mean_err_std, cov_err_mean, cov_err_std)`` as 0-dim
        tensors.  The two ``*_std`` entries are the across-repetition standard
        deviation divided by ``sqrt(k)``.
    """
    per_rep_errors = []
    for _ in range(k):
        tokens = _sample_tokens(distribution, n)
        q, key_proj, v = _project_qkv(tokens, wq, wk, wv)
        attn_out = attention_operator(q, key_proj, v, tiling=tiling)
        moments = _get_mean_cov(attn_out)
        # Each entry is a (mean_err, cov_err) pair of Python floats.
        per_rep_errors.append(_metric_error(*moments, mean_lim, cov_lim))

    mean_errors = torch.tensor([pair[0] for pair in per_rep_errors])
    cov_errors = torch.tensor([pair[1] for pair in per_rep_errors])

    root_k = np.sqrt(k)
    return (
        mean_errors.mean(),
        mean_errors.std() / root_k,
        cov_errors.mean(),
        cov_errors.std() / root_k,
    )


def _full_monte_carlo(
    distribution: torch.distributions.MultivariateNormal,
    n_min: int,
    n_max: int,
    nb_tot: int,
    k: int,
    mean_lim: torch.Tensor,
    cov_lim: torch.Tensor,
    wq: torch.Tensor,
    wk: torch.Tensor,
    wv: torch.Tensor,
    tiling: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Sweep ``_one_monte_carlo`` over a geometric grid of sample sizes.

    The grid holds ``nb_tot`` values from ``np.geomspace(n_min, n_max)``,
    truncated to integers by ``astype(int)``.

    Returns:
        ``(n_values, mean_err_means, mean_err_stds, cov_err_means,
        cov_err_stds)`` as NumPy arrays with one entry per sample size.
    """
    sample_sizes = np.geomspace(n_min, n_max, nb_tot).astype(int)

    # One (mean_m, mean_s, cov_m, cov_s) row per sample size.
    rows = [
        _one_monte_carlo(distribution, n, k, mean_lim, cov_lim, wq, wk, wv, tiling)
        for n in sample_sizes
    ]

    # Turn the list of rows into four per-statistic arrays.
    columns = [np.array([row[idx] for row in rows]) for idx in range(4)]
    return (np.array(sample_sizes), *columns)


def _plot_convergence(
    ax,
    n_values,
    err_means,
    err_stds,
    slope,
    intercept,
    fit_cov,
    title,
    ylabel,
):
    """Draw measured errors and their log-log power-law fit on ``ax``.

    Args:
        ax: axes object to draw on.
        n_values: sample sizes; must provide ``.min()`` and ``.max()``.
        err_means, err_stds: values and error-bar half-widths per sample size.
        slope, intercept: line fitted to ``log(err)`` against ``log(n)``.
        fit_cov: 2x2 covariance of the fit; used here with ``[0, 0]`` as the
            slope variance, ``[1, 1]`` as the intercept variance and
            ``[0, 1]`` as their covariance.
        title, ylabel: axes title and y-axis label.
    """
    ax.errorbar(
        n_values, err_means, yerr=err_stds,
        fmt='o', capsize=3, color='blue',
        label='Data ± std (k reps)',
    )

    # Dense grid for the fitted curve and its uncertainty band.
    n_fit = np.geomspace(n_values.min(), n_values.max(), 100)
    log_n_fit = np.log(n_fit)
    fit_line = np.exp(intercept) * n_fit ** slope

    slope_var = fit_cov[0, 0]
    intercept_var = fit_cov[1, 1]
    cross_cov = fit_cov[0, 1]
    slope_std = np.sqrt(slope_var)

    # Variance of (intercept + slope * log n), propagated from the fit covariance.
    fit_var = log_n_fit**2 * slope_var + intercept_var + 2 * log_n_fit * cross_cov
    fit_std = np.sqrt(fit_var)
    log_fit = intercept + slope * log_n_fit
    upper = np.exp(log_fit + fit_std)
    lower = np.exp(log_fit - fit_std)

    ax.plot(n_fit, fit_line, '--', color='red', label=f'Fit: $O(n^{{{slope:.2f} \\pm {slope_std:.2f}}})$')
    ax.fill_between(n_fit, lower, upper, alpha=0.2, color='red', label='Fit ± 1σ')

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Sample size n')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(title)


def polynomial_fit(
    n_values,
    mean_err_means,
    mean_err_stds,
    cov_err_means,
    cov_err_stds,
    sanity_checks: bool = True,
):
    """Fit power laws to the mean and covariance errors in log-log space.

    A degree-1 ``np.polyfit`` of ``log(err)`` against ``log(n_values)`` is run
    for each error series.  ``mean_err_stds`` and ``cov_err_stds`` do not
    enter the fit; they are only passed on as error bars for the plots.

    Args:
        n_values: sample sizes.
        mean_err_means, cov_err_means: errors to fit, one per sample size.
        mean_err_stds, cov_err_stds: error-bar half-widths for the plots.
        sanity_checks: if True, create a two-panel figure showing both fits
            (the figure is neither shown nor returned here).

    Returns:
        ``(slope_mean, slope_mean_std, slope_cov, slope_cov_std)`` where each
        ``*_std`` is the square root of the slope variance reported by
        ``np.polyfit``.
    """
    log_n = np.log(n_values)

    def _loglog_line(errors):
        # polyfit returns coefficients highest degree first: (slope, intercept).
        coeffs, coeff_cov = np.polyfit(log_n, np.log(errors), deg=1, cov=True)
        fitted_slope, fitted_intercept = coeffs
        return fitted_slope, fitted_intercept, coeff_cov

    slope_mean, intercept_mean, cov_mean = _loglog_line(mean_err_means)
    slope_cov, intercept_cov, cov_cov = _loglog_line(cov_err_means)

    slope_mean_std = np.sqrt(cov_mean[0, 0])
    slope_cov_std = np.sqrt(cov_cov[0, 0])

    if sanity_checks:
        _, axes = plt.subplots(1, 2, figsize=(12, 5))
        panels = (
            (axes[0], mean_err_means, mean_err_stds, slope_mean, intercept_mean,
             cov_mean, 'Mean Convergence', 'L2 Mean Error'),
            (axes[1], cov_err_means, cov_err_stds, slope_cov, intercept_cov,
             cov_cov, 'Covariance Convergence', 'Frobenius Cov Error'),
        )
        for ax, means, stds, slope, intercept, fit_cov, title, ylabel in panels:
            _plot_convergence(ax, n_values, means, stds, slope, intercept, fit_cov, title, ylabel)
        plt.tight_layout()

    return slope_mean, slope_mean_std, slope_cov, slope_cov_std



def run_experiment_finite_gaussian(
    sigma: torch.Tensor,
    n_reference: int,
    d: int,
    *,
    n_min: int = 64,
    n_max: Optional[int] = None,
    nb_tot: int = 8,
    k: int = 5,
    weights: Optional[Dict[str, torch.Tensor]] = None,
    tiling: bool = False,
    seed: int = 0,
    device: Optional[torch.device] = None,
    sanity_checks: bool = True,
) -> Dict[str, np.ndarray]:
    """
    Run XP2: finite-sample Gaussian reference + Monte Carlo subsampling.

    A zero-mean Gaussian with covariance ``sigma`` is sampled ``n_reference``
    times; the attention output on that sample provides the reference mean
    and covariance.  Fresh samples of increasing size are then pushed through
    the same attention map and the error to the reference is fitted with a
    power law in log-log space.

    The mean vector and the random weights are created in ``sigma``'s dtype,
    and user-supplied weights are cast to it.  Without this, a ``float64``
    ``sigma`` tensor would be mixed with ``float32`` tensors in the sampling
    and projection matmuls and fail with a dtype mismatch.

    Args:
        sigma: covariance matrix (d x d); non-tensor input becomes float32.
        n_reference: number of tokens used to build the finite reference distribution.
        d: embedding dimension.
        n_min: smallest Monte Carlo sample size.
        n_max: largest sample size; defaults to ``max(4 * n_min, n_reference)``.
        nb_tot: number of geometrically spaced sample sizes.
        k: repetitions per sample size.
        weights: optional dict with "q", "k", "v" matrices of shape [d, d];
            drawn as ``randn(d, d) / sqrt(d)`` when omitted.
        tiling: if True, use scaled_dot_product_attention.
        seed: random seed, passed to ``torch.manual_seed``.
        device: torch device; CUDA when available, otherwise CPU, if omitted.
        sanity_checks: if True, plot log-log fits.

    Returns:
        Dict with MC arrays ("n_values", "mean_err_means", "mean_err_stds",
        "cov_err_means", "cov_err_stds") and polynomial fit slopes
        ("slope_mean", "slope_mean_std", "slope_cov", "slope_cov_std").

    Raises:
        KeyError: if ``weights`` is given but lacks "q", "k" or "v".
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.manual_seed(seed)
    sigma = _ensure_tensor(sigma, device)
    # Sampling and the projections multiply these tensors together, which
    # requires one common dtype; follow sigma's.
    dtype = sigma.dtype

    if n_max is None:
        n_max = max(n_min * 4, n_reference)

    if weights is None:
        weight_scale = 1.0 / math.sqrt(d)
        wq = torch.randn(d, d, device=device, dtype=dtype) * weight_scale
        wk = torch.randn(d, d, device=device, dtype=dtype) * weight_scale
        wv = torch.randn(d, d, device=device, dtype=dtype) * weight_scale
    else:
        wq = _ensure_tensor(weights["q"], device).to(dtype)
        wk = _ensure_tensor(weights["k"], device).to(dtype)
        wv = _ensure_tensor(weights["v"], device).to(dtype)

    mean = torch.zeros(d, device=device, dtype=dtype)
    distribution = torch.distributions.MultivariateNormal(mean, covariance_matrix=sigma)

    # Pure evaluation: only scalar errors are extracted, so do not record an
    # autograd graph even if the supplied weights require gradients.
    with torch.no_grad():
        reference_tokens = _sample_tokens(distribution, n_reference)
        q_ref, k_ref, v_ref = _project_qkv(reference_tokens, wq, wk, wv)
        attn_reference = attention_operator(q_ref, k_ref, v_ref, tiling=tiling)
        mean_lim, cov_lim = _get_mean_cov(attn_reference)

        n_values, mean_means, mean_stds, cov_means, cov_stds = _full_monte_carlo(
            distribution,
            n_min,
            n_max,
            nb_tot,
            k,
            mean_lim,
            cov_lim,
            wq,
            wk,
            wv,
            tiling,
        )

    slope_mean, slope_mean_std, slope_cov, slope_cov_std = polynomial_fit(
        n_values,
        mean_means,
        mean_stds,
        cov_means,
        cov_stds,
        sanity_checks=sanity_checks,
    )

    return {
        "n_values": n_values,
        "mean_err_means": mean_means,
        "mean_err_stds": mean_stds,
        "cov_err_means": cov_means,
        "cov_err_stds": cov_stds,
        "slope_mean": slope_mean,
        "slope_mean_std": slope_mean_std,
        "slope_cov": slope_cov,
        "slope_cov_std": slope_cov_std,
    }
