import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import poisson
from scipy.optimize import minimize_scalar
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import hashlib
import pickle
from pathlib import Path
from tqdm import tqdm
import warnings

# NOTE(review): this silences *all* warnings process-wide as soon as the module
# is imported, including any numerical RuntimeWarnings numpy/scipy may raise
# during the simulation or fit. Consider narrowing the category/module.
warnings.filterwarnings('ignore')

@dataclass
class ExperimentConfig:
    """Parameters for the clustered-vs-uniform interference simulation.

    Simulation fields are read by ``simulate_ppp`` / ``simulate_clustered``;
    ``cache_dir`` / ``use_cache`` control on-disk result caching and
    ``figsize`` / ``dpi`` control plotting.

    Raises:
        ValueError: from ``__post_init__`` if a count is < 1, a probability is
            outside [0, 1], or ``n_maj_values`` is empty / contains negatives.
    """

    d: int = 64  # embedding dimension; only used in printed/plot labels here
    n_minority: int = 100  # minority queries simulated per (seed, N_maj) point
    p_interfere_ppp: float = 5e-5  # per-document interference probability (uniform/PPP case)
    p_interfere_cluster_near: float = 0.012  # per-document probability inside a "near" cluster
    p_interfere_cluster_far: float = 1e-7  # per-document probability inside a "far" cluster
    n_clusters: int = 10  # K: majority documents are split across K clusters
    cluster_near_prob: float = 0.3  # chance each cluster is "near" (drawn once per seed)
    shortlist_size: int = 10  # L: a query succeeds if fewer than L documents interfere
    n_maj_values: List[int] = field(default_factory=lambda: [
        100, 500, 1000, 2000, 5000, 10000, 20000, 50000,
        100000, 200000, 350000, 500000, 750000, 1000000
    ])  # N_maj grid (x-axis of the plots)
    n_seeds: int = 30  # independent RNG seeds per grid point

    cache_dir: str = "./cache"
    use_cache: bool = True
    figsize: Tuple[int, int] = (12, 8)
    dpi: int = 150

    def __post_init__(self):
        # Fail fast: otherwise n_clusters=0 dies with ZeroDivisionError deep in
        # the simulation, and n_seeds=0 silently produces NaN means because
        # warnings are suppressed module-wide.
        for name in ('n_minority', 'n_clusters', 'shortlist_size', 'n_seeds'):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        for name in ('p_interfere_ppp', 'p_interfere_cluster_near',
                     'p_interfere_cluster_far', 'cluster_near_prob'):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must lie in [0, 1], got {value}")
        if len(self.n_maj_values) == 0:
            raise ValueError("n_maj_values must contain at least one value")
        if any(v < 0 for v in self.n_maj_values):
            raise ValueError("n_maj_values must be non-negative")


class Cache:
    """On-disk pickle cache for per-seed simulation results.

    Each entry is keyed on every config field the simulations read, plus the
    experiment type, seed and N_maj. Changing any of those parameters therefore
    produces a cache miss instead of a stale hit.

    NOTE: entries are loaded with ``pickle``. Only point ``d`` at a directory
    this program itself writes to, never at untrusted files.
    """

    def __init__(self, d):
        self.d = Path(d)
        self.d.mkdir(parents=True, exist_ok=True)

    def key(self, c, t, s, n):
        """Return a hex digest identifying (config ``c``, experiment ``t``, seed ``s``, N_maj ``n``)."""
        # Every field read by simulate_ppp / simulate_clustered must be part of
        # the key; otherwise editing it would silently reuse stale results.
        parts = (
            c.p_interfere_ppp, c.p_interfere_cluster_near, c.p_interfere_cluster_far,
            c.cluster_near_prob, c.n_clusters, c.n_minority, c.shortlist_size,
            t, s, n,
        )
        h = "_".join(str(p) for p in parts)
        return hashlib.md5(h.encode()).hexdigest()

    def _path(self, c, t, s, n):
        """Return the file that stores the entry for (c, t, s, n)."""
        return self.d / f"{self.key(c, t, s, n)}.pkl"

    def get(self, c, t, s, n):
        """Return the cached value for (c, t, s, n).

        Returns None when caching is disabled, on a miss, or when the stored
        entry is truncated or corrupt (so the caller recomputes it).
        """
        if not c.use_cache:
            return None
        f = self._path(c, t, s, n)
        if not f.exists():
            return None
        try:
            with open(f, 'rb') as fh:
                return pickle.load(fh)
        except (EOFError, pickle.UnpicklingError):
            # Half-written or damaged entry: treat it as a miss rather than crash.
            return None

    def set(self, c, t, s, n, v):
        """Store ``v`` for (c, t, s, n) if caching is enabled."""
        if not c.use_cache:
            return
        f = self._path(c, t, s, n)
        tmp = f.with_name(f.name + ".tmp")
        # Write to a temp file, then rename it into place (os.replace semantics;
        # atomic on POSIX). An interrupted run never leaves a partial pickle.
        with open(tmp, 'wb') as fh:
            pickle.dump(v, fh)
        tmp.replace(f)


def simulate_ppp(config, n_maj, rng):
    """PPP: each majority doc has uniform interference probability.

    For each of ``config.n_minority`` queries, draw the number of the
    ``n_maj`` majority documents that interfere (Binomial(n_maj, p)). Return the
    fraction of queries with fewer than ``config.shortlist_size`` interferers.
    """
    interferer_counts = rng.binomial(n_maj, config.p_interfere_ppp, config.n_minority)
    survived = interferer_counts < config.shortlist_size
    return np.mean(survived)


def simulate_clustered(config, n_maj, rng):
    """Clustered case with K clusters, some near (high p), some far (low p).

    The ``n_maj`` majority documents are split as evenly as possible across
    ``config.n_clusters`` clusters. Each cluster is independently marked "near"
    with probability ``config.cluster_near_prob``, drawn once per call and so
    shared by all queries. Its documents interfere with probability
    ``p_interfere_cluster_near`` if near, else ``p_interfere_cluster_far``.

    Returns:
        Fraction of ``config.n_minority`` queries with fewer than
        ``config.shortlist_size`` interfering documents in total.
    """
    L = config.shortlist_size
    K = config.n_clusters

    # Spread the n_maj % K leftover documents over the first clusters. A plain
    # n_maj // K split silently drops them, and drops *all* documents when
    # n_maj < K. For n_maj divisible by K the sizes (and the RNG draws) are
    # unchanged.
    base, extra = divmod(n_maj, K)
    sizes = [base + (1 if k < extra else 0) for k in range(K)]

    is_near = rng.random(K) < config.cluster_near_prob
    probs = [config.p_interfere_cluster_near if near else config.p_interfere_cluster_far
             for near in is_near]

    successes = 0
    for _ in range(config.n_minority):
        # One binomial draw per cluster, in cluster order (keeps the RNG stream
        # identical to the per-cluster loop used previously).
        total = sum(rng.binomial(size, p) for size, p in zip(sizes, probs))
        if total < L:
            successes += 1
    return successes / config.n_minority


def run(config, exp_type, n_maj, seed, cache):
    """Return one simulated success rate, consulting ``cache`` first.

    ``exp_type == 'ppp'`` selects ``simulate_ppp``; any other value selects
    ``simulate_clustered``. A fresh ``np.random.default_rng(seed)`` is used for
    each computation, and the result is written back through ``cache.set``.
    """
    cached = cache.get(config, exp_type, seed, n_maj)
    if cached is not None:
        return cached

    generator = np.random.default_rng(seed)
    if exp_type == 'ppp':
        value = simulate_ppp(config, n_maj, generator)
    else:
        value = simulate_clustered(config, n_maj, generator)

    cache.set(config, exp_type, seed, n_maj, value)
    return value


def fit_theory(n_maj, empirical, L, log_beta_bounds=(-15, -3)):
    """Fit a Poisson-threshold model to empirical success rates.

    Model: P(success | N) = P[Poisson(beta * N) <= L - 1]. This is the Poisson
    approximation to "fewer than L of N documents interfere at rate beta".
    beta is fit by least squares over log(beta), searched in
    ``log_beta_bounds``.

    Args:
        n_maj: sequence of N_maj values.
        empirical: observed success rates, same length as ``n_maj``.
        L: shortlist size.
        log_beta_bounds: (low, high) search interval for log(beta). The default
            keeps the previous hard-coded range.

    Returns:
        (theory, beta): model predictions at ``n_maj`` (ndarray) and the fitted beta.
    """
    n_arr = np.asarray(n_maj, dtype=float)
    emp = np.asarray(empirical, dtype=float)

    def predict(beta):
        # poisson.cdf broadcasts over the rate array; no per-element Python loop.
        return poisson.cdf(L - 1, beta * n_arr)

    def loss(log_beta):
        return np.sum((predict(np.exp(log_beta)) - emp) ** 2)

    res = minimize_scalar(loss, bounds=log_beta_bounds, method='bounded')
    beta = np.exp(res.x)
    return predict(beta), beta


def run_experiment(config):
    """Run both experiments over the N_maj grid and fit the Poisson theory curve.

    For each experiment ('ppp', then 'clustered') and each N_maj in
    ``config.n_maj_values``, run ``config.n_seeds`` seeds through ``run`` and
    record the mean and (population) std of the per-seed success rates.

    Returns:
        dict with keys 'n_maj', 'config', 'ppp', 'clustered' (each holding
        'means'/'stds' ndarrays), 'theory' and 'beta' (from ``fit_theory`` on
        the PPP means).
    """
    cache = Cache(config.cache_dir)
    banner = "=" * 80

    print(banner)
    print("Clustered vs. Uniform Interference Experiment")
    print(banner)
    print(f"d={config.d}, L={config.shortlist_size}, seeds={config.n_seeds}")
    print(f"PPP p={config.p_interfere_ppp:.2e}")
    print(f"Clustered: K={config.n_clusters}, p_near={config.p_interfere_cluster_near:.2e}")
    print(banner)

    results = {'n_maj': config.n_maj_values, 'config': config}

    for exp in ('ppp', 'clustered'):
        print(f"\n{exp.upper()}...")
        means, stds = [], []
        for n in tqdm(config.n_maj_values):
            per_seed = [run(config, exp, n, seed, cache) for seed in range(config.n_seeds)]
            means.append(np.mean(per_seed))
            stds.append(np.std(per_seed))
        results[exp] = {'means': np.array(means), 'stds': np.array(stds)}

    theory, beta = fit_theory(
        np.array(config.n_maj_values), results['ppp']['means'], config.shortlist_size)
    results['theory'] = theory
    results['beta'] = beta

    return results


def plot(results, path):
    """Plot PPP and clustered success curves against the fitted theory curve.

    Shaded bands are 1.96 * std / sqrt(n_seeds) around each mean. Dotted
    vertical lines mark the first N_maj where each mean drops below 0.5.
    The figure is written to ``path`` and to a sibling file with a ``.png``
    suffix.

    Args:
        results: dict produced by ``run_experiment``.
        path: output file path (str or Path), typically ending in ``.pdf``.
    """
    c = results['config']
    n = np.array(results['n_maj'])
    out_path = Path(path)

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=c.figsize, dpi=c.dpi)

    # Values
    m = results['ppp']['means']
    m2 = results['clustered']['means']
    theory = results['theory']

    # Theory (fitted to the PPP means)
    ax.plot(n, theory, 'r-', lw=2.5, label=f'Theory (L={c.shortlist_size})', zorder=5)

    # PPP
    s = results['ppp']['stds']
    ci = 1.96 * s / np.sqrt(c.n_seeds)
    ax.fill_between(n, m - ci, m + ci, color='#1F78B4', alpha=0.25)
    ax.plot(n, m, '#1F78B4', lw=2.5, ls='--', marker='o', ms=6, mfc='white', mew=2,
            label='Empirical: PPP/Uniform', zorder=4)

    # Clustered
    s2 = results['clustered']['stds']
    ci2 = 1.96 * s2 / np.sqrt(c.n_seeds)
    ax.fill_between(n, m2 - ci2, m2 + ci2, color='#33A02C', alpha=0.25)
    ax.plot(n, m2, '#33A02C', lw=2.5, ls=':', marker='s', ms=6, mfc='white', mew=2,
            label=f'Empirical: Clustered (K={c.n_clusters})', zorder=3)

    def nc(means, th=0.5):
        """Return the first N_maj whose mean success is below ``th``, or None."""
        idx = np.flatnonzero(means < th)
        return n[idx[0]] if idx.size else None

    nc_ppp, nc_cl = nc(m), nc(m2)
    # Compare against None explicitly rather than relying on truthiness.
    if nc_ppp is not None:
        ax.axvline(nc_ppp, color='#1F78B4', ls=':', alpha=0.6)
        ax.text(nc_ppp * 1.1, 0.55, f'$N_c^{{PPP}}$≈{nc_ppp:,}', fontsize=10, color='#1F78B4')
    if nc_cl is not None:
        ax.axvline(nc_cl, color='#33A02C', ls=':', alpha=0.6)
        ax.text(nc_cl * 1.1, 0.45, f'$N_c^{{Clust}}$≈{nc_cl:,}', fontsize=10, color='#33A02C')

    ax.set_xlabel('Number of Majority Documents ($N_{maj}$)', fontsize=12)
    ax.set_ylabel('Minority Success Probability (Recall@1)', fontsize=12)
    ax.set_title(f'Phase Transition: Clustered vs Uniform Interference\n'
                 f'($d$={c.d}, $L$={c.shortlist_size}, $K$={c.n_clusters})', fontsize=14, fontweight='bold')

    ax.set_xscale('log')
    ax.set_xlim(n[0] * 0.7, n[-1] * 1.5)
    ax.set_ylim(-0.02, 1.05)
    ax.legend(loc='upper right', fontsize=11)
    ax.grid(True, alpha=0.4)

    fig.tight_layout()
    fig.savefig(out_path, dpi=c.dpi, bbox_inches='tight')
    # Derive the PNG name from the suffix only. str.replace('.pdf', '.png')
    # rewrites '.pdf' anywhere in the path, saves over the same file when the
    # suffix isn't '.pdf', and fails outright for Path arguments.
    fig.savefig(out_path.with_suffix('.png'), dpi=c.dpi, bbox_inches='tight')

    print(f"\n {path}")


def plot_var(results, path):
    """Plot seed-to-seed spread of both experiments side by side.

    Left panel: std of the per-seed success rates vs N_maj. Right panel: mean
    success rate with ±1 std error bars. The figure is written to ``path`` and
    to a sibling file with a ``.png`` suffix.

    Args:
        results: dict produced by ``run_experiment``.
        path: output file path (str or Path), typically ending in ``.pdf``.
    """
    c = results['config']
    n = np.array(results['n_maj'])
    out_path = Path(path)
    xlim = (n[0] * 0.7, n[-1] * 1.5)

    fig, (a1, a2) = plt.subplots(1, 2, figsize=(14, 5), dpi=c.dpi)

    a1.plot(n, results['ppp']['stds'], 'o-', color='#1F78B4', lw=2, ms=6, label='PPP')
    a1.plot(n, results['clustered']['stds'], 's-', color='#33A02C', lw=2, ms=6, label='Clustered')
    a1.set_xscale('log')
    a1.set_xlabel('$N_{maj}$', fontsize=12)
    a1.set_ylabel('Std Dev', fontsize=12)
    a1.set_title('Variance Comparison', fontsize=13, fontweight='bold')
    a1.legend(fontsize=11)
    a1.grid(alpha=0.4)
    a1.set_xlim(*xlim)

    a2.errorbar(n, results['ppp']['means'], results['ppp']['stds'], fmt='o-', color='#1F78B4',
                lw=2, ms=5, capsize=3, label='PPP', alpha=0.8)
    a2.errorbar(n, results['clustered']['means'], results['clustered']['stds'], fmt='s-', color='#33A02C',
                lw=2, ms=5, capsize=3, label='Clustered', alpha=0.8)
    a2.set_xscale('log')
    a2.set_xlabel('$N_{maj}$', fontsize=12)
    a2.set_ylabel('Recall@1', fontsize=12)
    a2.set_title('Performance ± std', fontsize=13, fontweight='bold')
    a2.legend(fontsize=11)
    a2.grid(alpha=0.4)
    a2.set_ylim(-0.05, 1.1)
    a2.set_xlim(*xlim)

    fig.tight_layout()
    fig.savefig(out_path, dpi=c.dpi, bbox_inches='tight')
    # Suffix-based PNG name. str.replace('.pdf', '.png') would also rewrite
    # '.pdf' in directory names and save over the same file for other suffixes.
    fig.savefig(out_path.with_suffix('.png'), dpi=c.dpi, bbox_inches='tight')
    print(f" {path}")


def summary(results):
    """Print initial/final success rates and threshold crossings per experiment.

    For each experiment, report the first N_maj at which the mean success rate
    drops below 50%, 10% and 1% (or ">max" if it never does). Then print the
    fitted beta.
    """
    n_grid = np.array(results['n_maj'])
    rule = "=" * 80

    print(f"\n{rule}\nSUMMARY\n{rule}")
    for name in ('ppp', 'clustered'):
        means = results[name]['means']
        stds = results[name]['stds']

        crossings = []
        for threshold in (0.5, 0.1, 0.01):
            below = np.flatnonzero(means < threshold)
            crossings.append(n_grid[below[0]] if below.size else ">max")
        nc50, nc10, nc01 = crossings

        print(f"\n{name.upper()}: init={means[0]:.3f}±{stds[0]:.3f}, final={means[-1]:.3f}±{stds[-1]:.3f}")
        print(f"  Nc(50%)={nc50}, Nc(10%)={nc10}, Nc(1%)={nc01}")
    print(f"\nβ={results['beta']:.2e}\n{rule}")


def main():
    """
    Clustered vs. Uniform Interference Experiment

    Validates phase transition from Theorem 2.2: as N_maj increases past critical
    threshold Nc, minority retrieval success collapses to zero.

    Runs both experiments, prints a summary, and writes the main figure under
    ./results. Returns the results dict produced by ``run_experiment``.
    """
    n_maj_grid = [100, 500, 1000, 2000, 5000, 10000, 20000, 50000,
                  100000, 200000, 350000, 500000, 750000, 1000000]
    config = ExperimentConfig(
        d=64,
        n_minority=100,
        n_seeds=30,
        shortlist_size=10,
        n_maj_values=n_maj_grid,
        p_interfere_ppp=5e-5,
        n_clusters=10,
        cluster_near_prob=0.3,
        p_interfere_cluster_near=0.012,
        p_interfere_cluster_far=1e-7,
        cache_dir="./cache",
        use_cache=True,
    )

    outcome = run_experiment(config)
    summary(outcome)

    results_dir = Path("./results")
    results_dir.mkdir(exist_ok=True)
    figure_path = results_dir / "ppp_assumption_clustered_vs_uniform.pdf"
    plot(outcome, str(figure_path))
    print(f"\n All outputs in {results_dir}")

    return outcome


if __name__ == "__main__":
    main()