import numpy as np
import ast
from scipy.stats import gaussian_kde
from src.llm_clients.base_llm_client import BaseLLMClient
from scipy.stats import norm

class LCPrior:
    """Gaussian-mixture prior on [-1, 1] built from a discrete (value, probability) list.

    A distribution is passed as a string holding a Python literal list of
    ``(value, probability)`` pairs. It is parsed with ``ast.literal_eval``, so only
    literals are accepted. Each pair becomes a Gaussian component centred on
    ``value`` with standard deviation ``bw`` (a scalar, or one value per pair). The
    mixture is truncated to [-1, 1] and renormalised.

    A distribution consisting of the single value 0 is treated as uninformative.
    It maps to the uniform density on [-1, 1].
    """

    # Floor on the in-range probability mass, so that a mixture whose components
    # all sit far outside [-1, 1] does not cause a division by zero.
    _MIN_MASS = 2e-10

    def __init__(self, agent: "BaseLLMClient", ignore_zero=False):
        # LLM client kept on the instance; not used by the methods of this class.
        self.agent = agent
        self.name = "lc_prior"
        # When True, a leading value-0 bucket is discarded before building the
        # prior, because the LLM often uses 0 to mean "I don't know".
        self.ignore_zero = ignore_zero

    @staticmethod
    def _kde_bandwidth(values, probs):
        """Return the kernel std-dev that scipy's ``gaussian_kde`` uses for these weighted points.

        ``gaussian_kde`` evaluates ``sum_i w_i N(x; x_i, covariance)``. Using
        ``sqrt(covariance)`` as the mixture sigma therefore reproduces the KDE.
        Like ``gaussian_kde`` itself, this fails for fewer than two distinct values.
        """
        kde = gaussian_kde(values, weights=probs)
        return float(np.sqrt(kde.covariance[0, 0]))

    @staticmethod
    def _truncation_mass(values, probs, sigmas):
        """Probability mass of the mixture inside [-1, 1].

        ``sum_i p_i * (Phi((1 - r_i) / s_i) - Phi((-1 - r_i) / s_i))``
        """
        upper = norm.cdf((1 - values) / sigmas)
        lower = norm.cdf((-1 - values) / sigmas)
        return float(np.sum(probs * (upper - lower)))

    def _prepare(self, distribution, bw):
        """Parse ``distribution`` into ``(values, probs, sigmas)`` float arrays.

        ``probs`` is normalised to sum to 1, and ``sigmas`` has the shape of
        ``values``. When ``bw`` is None, every sigma is the ``gaussian_kde``
        bandwidth of the weighted values.

        Returns None when the prior is uninformative: either the distribution is
        the single value 0, or ``ignore_zero`` dropped a leading 0 bucket and the
        remaining pairs carry no positive probability.

        Raises:
            ValueError: if the probabilities (after any zero-bucket drop) do not
                sum to a positive number and no zero bucket was dropped.
        """
        values, probs = zip(*ast.literal_eval(distribution))
        values = np.asarray(values, dtype=float)
        probs = np.asarray(probs, dtype=float)

        if values.size == 1 and values[0] == 0:
            return None

        bw_arr = None if bw is None else np.asarray(bw, dtype=float)
        dropped_zero = False
        if self.ignore_zero and values[0] == 0:
            values, probs = values[1:], probs[1:]
            dropped_zero = True
            # A bandwidth vector aligned with the parsed pairs loses the same
            # entry. A vector already aligned with the kept pairs is left alone.
            if bw_arr is not None and bw_arr.ndim == 1 and bw_arr.size == values.size + 1:
                bw_arr = bw_arr[1:]

        total = probs.sum()
        if total <= 0:
            if dropped_zero:
                # Only the "don't know" bucket had mass, so the prior is uninformative.
                return None
            raise ValueError("Distribution has no positive probability mass.")
        probs = probs / total

        if bw_arr is None:
            sigmas = np.full(values.shape, self._kde_bandwidth(values, probs))
        else:
            sigmas = np.broadcast_to(bw_arr, values.shape)
        return values, probs, sigmas

    def get_density_at(self, r_obs, distribution, bw=None):
        """Evaluate the [-1, 1]-truncated prior density at ``r_obs``.

        Args:
            r_obs: Point at which to evaluate. Assumed scalar, since the result
                is cast to ``float``.
            distribution: String literal of ``(value, probability)`` pairs.
            bw: Component std-dev, either a scalar or one value per pair. None uses
                the ``gaussian_kde`` bandwidth, which is equivalent to a weighted
                KDE renormalised to [-1, 1].

        Returns:
            The density as a float, or 0.5 (uniform on [-1, 1]) for an
            uninformative prior. ``r_obs`` outside [-1, 1] is not forced to zero.

        Raises:
            ValueError: if the mixture has no mass inside [-1, 1].
        """
        prepared = self._prepare(distribution, bw)
        if prepared is None:
            return 0.5
        values, probs, sigmas = prepared

        mass = self._truncation_mass(values, probs, sigmas)
        if mass == 0:
            raise ValueError("Normalising mass is zero; choose a larger `bw`.")

        pdf_vals = norm.pdf(r_obs, loc=values, scale=sigmas)
        return float(np.sum(probs * pdf_vals) / mass)

    def posterior_pdf(self, r, values, probs, sigma):
        """Density of the [-1, 1]-truncated Gaussian mixture at the points ``r``.

        The density is zero outside [-1, 1]. Inside [-1, 1] it is::

            sum_i p_i N(r | r_i, s_i^2)
            / sum_i p_i [Phi((1 - r_i) / s_i) - Phi((-1 - r_i) / s_i)]

        Despite the name, no likelihood is applied here. The function evaluates
        the mixture defined by ``(values, probs, sigma)``.

        Args:
            r: Scalar or 1-D array of evaluation points.
            values: Component centres.
            probs: Component weights; they are normalised here.
            sigma: Component std-dev, either a scalar or one value per component.

        Returns:
            1-D float array with one density per point of ``r``. A scalar ``r``
            gives a 1-element array.
        """
        r = np.atleast_1d(np.asarray(r, dtype=float))
        values = np.asarray(values, dtype=float)
        probs = np.asarray(probs, dtype=float)
        probs = probs / probs.sum()
        sigmas = np.broadcast_to(np.asarray(sigma, dtype=float), values.shape)

        # (components, points) matrix of component densities, weighted and
        # summed over the component axis.
        comp = norm.pdf(r[None, :], loc=values[:, None], scale=sigmas[:, None])
        mix = np.sum(probs[:, None] * comp, axis=0)

        mass = max(self._truncation_mass(values, probs, sigmas), self._MIN_MASS)
        dens = mix / mass
        dens[(r < -1) | (r > 1)] = 0.0
        return dens

    def posterior_quantiles(self, values, probs, sigma, qs, grid_size=2001):
        """Compute quantiles of the truncated mixture by inverting a grid-based CDF.

        Args:
            values: Component centres; same meaning as in ``posterior_pdf``.
            probs: Component weights; same meaning as in ``posterior_pdf``.
            sigma: Component std-dev; same meaning as in ``posterior_pdf``.
            qs: Iterable of probabilities in (0, 1).
            grid_size: Number of grid points over [-1, 1].

        Returns:
            A dict ``{q: quantile_value}``.
        """
        r = np.linspace(-1, 1, grid_size)
        pdf = self.posterior_pdf(r, values, probs, sigma)
        # Trapezoidal cumulative integral, so the CDF is exactly 0 at r = -1.
        # A plain cumsum would start at pdf[0] and bias every quantile by about
        # half a grid step.
        steps = 0.5 * (pdf[1:] + pdf[:-1]) * np.diff(r)
        cdf = np.concatenate(([0.0], np.cumsum(steps)))
        cdf = cdf / cdf[-1]
        return {q: float(np.interp(q, cdf, r)) for q in qs}

    def posterior_mean(self, values, probs, sigma):
        """Closed-form mean of the [-1, 1]-truncated Gaussian mixture.

        For X ~ N(r_i, s_i^2) with a = (-1 - r_i) / s_i and b = (1 - r_i) / s_i::

            E[X 1{-1 < X < 1}] = r_i [Phi(b) - Phi(a)] - s_i [phi(b) - phi(a)]

        The weighted sum of these terms is divided by the in-range mass. That
        mass is floored at ``_MIN_MASS`` to avoid division by zero.
        """
        values = np.asarray(values, dtype=float)
        probs = np.asarray(probs, dtype=float)
        probs = probs / probs.sum()
        sigmas = np.broadcast_to(np.asarray(sigma, dtype=float), values.shape)

        a = (-1 - values) / sigmas
        b = (1 - values) / sigmas
        partial = values * (norm.cdf(b) - norm.cdf(a)) - sigmas * (norm.pdf(b) - norm.pdf(a))
        mass = max(self._truncation_mass(values, probs, sigmas), self._MIN_MASS)
        return np.sum(probs * partial) / mass

    def get_summary_stats(self, distribution, bw=None, q_vals=(0.025, 0.975)):
        """Summarise the truncated prior built from ``distribution``.

        Args:
            distribution: String literal of ``(value, probability)`` pairs.
            bw: Component std-dev, either a scalar or one value per pair. None uses
                the ``gaussian_kde`` bandwidth, as in ``get_density_at``.
            q_vals: The lower and upper quantile levels; only the first two
                entries are used.

        Returns:
            A tuple ``(mode, mean, lower_quantile, upper_quantile)``. For an
            uninformative prior (uniform on [-1, 1]) the values are
            ``(0.0, 0.0, -1 + 2*q_lo, -1 + 2*q_hi)``.
        """
        q_lo, q_hi = q_vals[0], q_vals[1]
        prepared = self._prepare(distribution, bw)
        if prepared is None:
            # Uniform on [-1, 1]: the mean is 0, quantile q sits at -1 + 2q, and
            # the mode is not unique, so 0 is reported.
            return 0.0, 0.0, -1 + 2 * q_lo, -1 + 2 * q_hi
        values, probs, sigmas = prepared

        dist_mean = self.posterior_mean(values, probs, sigmas)
        quantiles_map = self.posterior_quantiles(values, probs, sigmas, (q_lo, q_hi))

        # Odd grid size, so that r = 0 is a grid point.
        rs = np.linspace(-1, 1, 2001)
        dens = self.posterior_pdf(rs, values, probs, sigmas)
        mode = rs[np.argmax(dens)]
        return mode, dist_mean, quantiles_map[q_lo], quantiles_map[q_hi]