import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ncx2, binom, poisson
from matplotlib.patches import Circle
from scipy.special import betainc

import matplotlib
import matplotlib.backends.backend_pgf
def common_texification(text):
    """Return *text* unchanged."""
    return text
# Install the identity function above on matplotlib's PGF backend module. This
# has to run before the `import tikzplotlib` statement that follows.
# NOTE(review): presumably a compatibility shim for Matplotlib releases that no
# longer provide `common_texification`, which tikzplotlib appears to expect --
# confirm against the pinned Matplotlib/tikzplotlib versions. Because the
# replacement is an identity function, it performs no escaping of any kind.
matplotlib.backends.backend_pgf.common_texification = common_texification
import tikzplotlib
import matplotlib.pyplot as plt
from pathlib import Path

# Apply the "latex_tikz.mplstyle" style sheet at import time. The path is
# resolved against the parent of the *current working directory* (not this
# file's location), so the script must be started from a directory whose
# parent contains that style file.
plt.style.use(Path.cwd().parent / "latex_tikz.mplstyle")

class RAGShortlistSimulator:
    """
    Simulates Theorem 1: minority query success under majority interference.

    The majority mean ``mu_maj`` is the origin. The minority query ``x_min``
    sits ``minority_distance * sigma_maj`` along the first axis and its target
    document ``d_min`` is offset from the query by ``epsilon`` along the same
    axis. A majority document "interferes" when it lies inside the ball
    centred on the query whose radius is the query-target distance. The
    target makes the shortlist when at most ``L - 1`` of the ``N_maj``
    majority documents interfere.

    Attributes set by the constructor:
        dim, sigma_maj, L : the corresponding constructor arguments.
        mu_maj, x_min, d_min : position vectors of length ``dim``.
        r_ball : radius of the retrieval ball, ``||x_min - d_min||``.
        p_interference : probability that one majority document interferes.
        N_c : critical majority population ``L / p_interference``
            (``inf`` when ``p_interference`` is exactly zero).
    """

    def __init__(self, dim=2, sigma_maj=1.0, minority_distance=3.0,
                 epsilon=0.3, L=20):
        """Place query and target, then derive ``p_interference`` and ``N_c``."""
        self.dim = dim
        self.sigma_maj = sigma_maj
        self.mu_maj = np.zeros(dim)

        # Minority query position (away from majority)
        self.x_min = np.zeros(dim)
        self.x_min[0] = minority_distance * sigma_maj

        # Target document position (close to query)
        self.d_min = self.x_min.copy()
        self.d_min[0] += epsilon  # slightly offset

        # Retrieval ball radius
        self.r_ball = np.linalg.norm(self.x_min - self.d_min)
        self.L = L

        self._compute_interference_prob()

    def _compute_interference_prob(self):
        """Compute probability that a majority document falls within retrieval ball.

        For X ~ N(mu_maj, sigma_maj^2 I) the scaled squared distance
        ``||X - x_min||^2 / sigma_maj^2`` follows a non-central chi-square
        distribution with ``dim`` degrees of freedom, so the probability is a
        single CDF evaluation. Sets ``p_interference`` and ``N_c``.
        """
        # Non-centrality parameter
        nc = np.linalg.norm(self.x_min - self.mu_maj) ** 2 / self.sigma_maj ** 2

        # Threshold for ball
        t = (self.r_ball ** 2) / (self.sigma_maj ** 2)

        # Probability a majority doc is within ball
        self.p_interference = ncx2.cdf(t, df=self.dim, nc=nc)

        # Critical population, Nc = L / p. A zero probability (e.g. a
        # zero-radius ball) means no finite population is critical; make that
        # explicit instead of relying on a divide-by-zero warning.
        if self.p_interference == 0:
            self.N_c = np.inf
        else:
            self.N_c = self.L / self.p_interference

    def compute_theory_success(self, N_maj):
        """Return the closed-form probability that the target is shortlisted.

        With K ~ Binomial(N_maj, p_interference) interferers, success is
        ``Pr[K <= L - 1]``. ``N_maj`` is truncated to an integer.
        """
        n = int(N_maj)
        k = self.L - 1
        p = float(self.p_interference)

        # An empty shortlist (L <= 0) can never contain the target.
        if k < 0:
            return 0.0
        # With at most L - 1 majority documents the bound K <= L - 1 is certain.
        if n <= k:
            return 1.0

        # Binomial CDF: Pr[X <= k] = I_{1-p}(n-k, k+1)
        return float(betainc(n - k, k + 1, 1.0 - p))

    def simulate_empirical_success(self, N_maj, trials=5000, seed=42):
        """Empirical simulation of shortlist inclusion.

        Returns ``(success_rate, standard_error, p_interference)`` where the
        rate is the fraction of ``trials`` binomial draws with at most
        ``L - 1`` interferers.
        """
        # Reuse the shared sampler so both public methods draw identically.
        K = self.sample_interference_counts(N_maj, trials=trials, seed=seed)
        success_rate = np.mean(K <= self.L - 1)
        # Floor the variance so a 0% or 100% rate still has a non-zero SE.
        se = np.sqrt(max(success_rate * (1 - success_rate) / trials, 1e-12))
        return success_rate, se, self.p_interference

    def sample_interference_counts(self, N_maj, trials=10000, seed=42):
        """Draw ``trials`` interferer counts K ~ Binomial(int(N_maj), p_interference)."""
        rng = np.random.default_rng(seed)
        return rng.binomial(int(N_maj), self.p_interference, size=trials)



def plot_main_phase_transition(save_path='fig1_main_transition'):
    """Main phase transition demonstration.

    Builds a 2-D simulator (L=20, minority_distance=1.5, epsilon=0.5,
    sigma_maj=1.0), sweeps the majority population from 0.3 N_c to 2.5 N_c and
    plots the closed-form success probability against a Monte Carlo estimate
    with 95% error bars. The figure is written to ``save_path + '.png'`` and
    ``save_path + '.tex'`` and then shown.

    Returns a ``(DataFrame, simulator)`` pair; the frame has the columns
    ``N_maj``, ``theory``, ``empirical``, ``ci_lo`` and ``ci_hi``.
    """
    L = 20
    minority_distance = 1.5
    epsilon = 0.5
    sigma_maj = 1.0
    trials = 5000
    base_seed = 42

    sim = RAGShortlistSimulator(
        dim=2, sigma_maj=sigma_maj, minority_distance=minority_distance,
        epsilon=epsilon, L=L
    )

    print(f"Main Experiment - N_c: {sim.N_c:.1f}, p: {sim.p_interference:.4f}")

    N_min = int(0.3 * sim.N_c)
    N_max = int(2.5 * sim.N_c)
    # Roughly 40 grid points, but never a step smaller than 10 documents.
    N_vals = np.arange(N_min, N_max, max(10, (N_max - N_min) // 40))

    results = {'N_maj': [], 'theory': [], 'empirical': [], 'ci_lo': [], 'ci_hi': []}

    for idx, N in enumerate(N_vals):
        theory_prob = sim.compute_theory_success(N)
        # Give every population size its own deterministic seed. Reusing one
        # seed for all points would restart the generator from the same state
        # each time and correlate the Monte Carlo errors across the sweep,
        # while the error bars present the points as independent estimates.
        emp_prob, se, _ = sim.simulate_empirical_success(
            N, trials=trials, seed=base_seed + idx
        )
        results['N_maj'].append(N)
        results['theory'].append(theory_prob)
        results['empirical'].append(emp_prob)
        # Normal-approximation 95% interval, clipped to valid probabilities.
        results['ci_lo'].append(max(0.0, emp_prob - 1.96 * se))
        results['ci_hi'].append(min(1.0, emp_prob + 1.96 * se))

    fig, ax = plt.subplots(figsize=(8, 5))
    # errorbar expects (lower, upper) offsets relative to the plotted value.
    yerr = np.vstack([
        np.array(results['empirical']) - np.array(results['ci_lo']),
        np.array(results['ci_hi']) - np.array(results['empirical'])
    ])

    ax.errorbar(results['N_maj'], results['empirical'], yerr=yerr,
                fmt='o', alpha=0.6, markersize=4, label='Empirical (±95% CI)',
                color='steelblue')
    ax.plot(results['N_maj'], results['theory'],
            linewidth=2.5, label='Theory', color='darkred')
    ax.axvline(sim.N_c, linestyle='--', linewidth=2,
               label=f'$N_c$ ≈ {sim.N_c:.0f}', color='green', alpha=0.7)

    ax.set_xlabel('Number of Majority Documents ($N_{maj}$)')
    ax.set_ylabel('Minority Success Probability')
    ax.legend(loc='best', fontsize=11)

    plt.tight_layout()
    plt.savefig(save_path + '.png', dpi=300)  # Increased DPI
    tikzplotlib.save(
        save_path + '.tex',
        strict=False,
        textsize=10.0,
        override_externals=True,
        externalize_tables=False
    )
    plt.show()

    return pd.DataFrame(results), sim


def plot_shortlist_dependence(save_path='fig3_shortlist_dependence'):
    """Plot theoretical success versus normalised population for several L.

    For each shortlist size in [10, 20, 50, 100] a 2-D simulator is built
    (sigma_maj=1.0, minority_distance=1.5, epsilon=0.5) and the closed-form
    success probability is evaluated on 60 evenly spaced populations between
    0.4 N_c and 1.8 N_c. Curves are drawn against ``N_maj / N_c``; note that
    ``compute_theory_success`` truncates each population to an integer while
    the x-coordinate uses the untruncated value. The figure is written to
    ``save_path + '.png'`` and ``save_path + '.tex'`` and then shown.
    """
    print("\nExperiment 3: Shortlist Size Dependence...")

    # Geometry shared by every curve; only the shortlist size varies.
    geometry = dict(dim=2, sigma_maj=1.0, minority_distance=1.5, epsilon=0.5)
    shortlist_sizes = [10, 20, 50, 100]

    figure, axis = plt.subplots(figsize=(9, 5.5))
    palette = plt.cm.plasma(np.linspace(0.2, 0.9, len(shortlist_sizes)))

    for L, curve_color in zip(shortlist_sizes, palette):
        sim = RAGShortlistSimulator(L=L, **geometry)

        populations = np.linspace(0.4 * sim.N_c, 1.8 * sim.N_c, 60)
        success_probs = []
        for population in populations:
            success_probs.append(sim.compute_theory_success(population))

        curve_label = f'L={L} ($N_c$≈{sim.N_c:.0f})'
        axis.plot(populations / sim.N_c, success_probs,
                  linewidth=2.5, label=curve_label, color=curve_color)

        print(f"  L={L}: N_c={sim.N_c:.1f}")

    # Mark the critical point shared by all normalised curves.
    axis.axvline(1.0, linestyle='--', color='gray', linewidth=1.5, alpha=0.6)
    axis.set_xlabel('Normalized Majority Population ($N_{maj} / N_c$)', fontsize=12)
    axis.set_ylabel('Minority Success Probability', fontsize=12)

    axis.legend(loc='best', fontsize=10)
    axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path + '.png', dpi=300, bbox_inches='tight')

    tikz_options = dict(
        strict=False,
        textsize=10.0,
        override_externals=True,
        externalize_tables=False,
    )
    tikzplotlib.save(save_path + '.tex', **tikz_options)
    plt.show()


def plot_embedding_visualization(save_path='fig4_embedding_space'):
    """Scatter-plot one sampled corpus below and one above the critical size.

    Builds a 2-D simulator (L=20, minority_distance=1.5, epsilon=0.5,
    sigma_maj=1.0) and, for N_maj = 0.5 N_c and 1.5 N_c, draws majority
    documents from a zero-mean normal with standard deviation ``sigma_maj``.
    Documents inside the retrieval ball are highlighted as interferers and
    each panel is annotated with SUCCESS when fewer than L documents
    interfere, FAILURE otherwise. The figure is written to
    ``save_path + '.png'`` and ``save_path + '.tex'`` and then shown.
    """
    print("\nExperiment 4: Embedding Space Visualization...")

    L = 20
    minority_distance = 1.5
    epsilon = 0.5

    sim = RAGShortlistSimulator(
        dim=2, sigma_maj=1.0, minority_distance=minority_distance,
        epsilon=epsilon, L=L
    )

    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
    # One generator shared by both panels, so the second panel continues the
    # stream of the first.
    rng = np.random.default_rng(42)

    scenarios = [
        (int(0.5 * sim.N_c), 'Sub-Critical: $N_{maj} = 0.5 N_c$'),
        (int(1.5 * sim.N_c), 'Super-Critical: $N_{maj} = 1.5 N_c$')
    ]

    # The retrieval-ball outline and axis limits depend only on the simulator,
    # so compute them once instead of once per panel.
    theta = np.linspace(0, 2*np.pi, 100)
    circle_x = sim.x_min[0] + sim.r_ball * np.cos(theta)
    circle_y = sim.x_min[1] + sim.r_ball * np.sin(theta)
    bound = max(minority_distance + 2.0, 4.0)

    for ax, (N_maj, title) in zip(axes, scenarios):
        maj_docs = rng.normal(0, sim.sigma_maj, size=(N_maj, sim.dim))

        # A document interferes when it is at least as close to the query as
        # the target document is.
        distances = np.linalg.norm(maj_docs - sim.x_min, axis=1)
        interferers = maj_docs[distances <= sim.r_ball]
        non_interferers = maj_docs[distances > sim.r_ball]
        n_interferers = len(interferers)

        ax.scatter(non_interferers[:, 0], non_interferers[:, 1],
                   s=8, alpha=0.3, color='lightblue', label='Majority docs', rasterized=False)
        ax.scatter(interferers[:, 0], interferers[:, 1],
                   s=15, alpha=0.6, color='red', label=f'Interferers (K={n_interferers})', rasterized=False)

        ax.plot(sim.x_min[0], sim.x_min[1],
                marker='o', linestyle='None',
                markersize=8,  # points, not s=area
                markerfacecolor='gold',
                markeredgecolor='black',
                markeredgewidth=1,
                label='Minority query',
                zorder=10, alpha=0.6)

        ax.plot(sim.d_min[0], sim.d_min[1],
                marker='s', linestyle='None',
                markersize=8,
                markerfacecolor='lime',
                markeredgecolor='black',
                markeredgewidth=1,
                label='Target doc',
                zorder=10, alpha=0.8)

        ax.plot(circle_x, circle_y, '--', color='orange',
                linewidth=2, label='Retrieval ball')

        ax.set_xlim(-bound, bound)
        ax.set_ylim(-bound, bound)
        ax.set_aspect('equal')
        ax.set_xlabel('Dimension 1')
        ax.set_ylabel('Dimension 2')
        ax.set_title(title, fontweight='bold')
        ax.legend(loc='upper left', framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle=':', linewidth=0.5)

        # The target is shortlisted when fewer than L documents beat it.
        shortlisted = n_interferers < L
        success = "SUCCESS" if shortlisted else "FAILURE"
        edge_color = 'darkgreen' if shortlisted else 'darkred'

        ax.text(0.95, 0.05, f'{success}\n$K={n_interferers}, L={L}$',
                transform=ax.transAxes, fontsize=10, fontweight='bold',
                ha='right', va='bottom',
                bbox=dict(boxstyle='round,pad=0.5',
                          facecolor='white', alpha=0.9,
                          edgecolor=edge_color, linewidth=1.5))

    plt.tight_layout()
    plt.savefig(save_path + '.png', dpi=150, bbox_inches='tight')

    tikzplotlib.save(
        save_path + '.tex',
        strict=False,
        textsize=10.0,
        override_externals=True,
        externalize_tables=False
    )

    plt.show()


if __name__ == "__main__":
    # Produce the figures in the established order: fig1, fig4, then fig3.
    for make_figure in (
        plot_main_phase_transition,
        plot_embedding_visualization,
        plot_shortlist_dependence,
    ):
        make_figure()



