from src.correlation import Correlation
from src.priors.gaussian_prior import TruncatedGaussian
from src.utils import parse_json_block
from prompts.range_estimation import instruction_reasoning_range, prompt_range_prediction
from scipy.stats import norm
import numpy as np


class RangeGaussian(TruncatedGaussian):
    """Correlation prior built from a range returned by an agent.

    The agent is prompted for a coefficient, a lower bound, an upper bound and
    a confidence level. The bounds are mapped to Fisher z-space (arctanh) and
    a normal distribution is fitted there so that the interval between the
    bounds carries the stated confidence mass.
    """

    def __init__(self, agent):
        """Store the agent used to query for the range and set the prior name.

        :param agent: Object exposing ``call(user_message)``, which
            ``get_prior`` unpacks as ``(response, usage)``.
        """
        # NOTE(review): the base-class __init__ is not invoked here; confirm
        # that TruncatedGaussian needs no initialization of its own.
        self.agent = agent
        self.name = 'range_gaussian_with_predicted_mean_and_confidence_level'

    @staticmethod
    def _clamp_correlation(r, eps=1e-12):
        """Clamp ``r`` into [-1 + eps, 1 - eps] so that arctanh(r) is finite."""
        return min(1.0 - eps, max(-1.0 + eps, r))

    def get_user_msg(self, correlation: Correlation, task_desc: str=instruction_reasoning_range):
        """Build the range-prediction prompt for a pair of variables.

        Also records ``correlation`` and its two variables on the instance
        (``self.correlation``, ``self.var1``, ``self.var2``).

        :param correlation: Correlation whose ``var1``/``var2`` supply the
            table, table description, attribute names and variable
            descriptions inserted into the prompt.
        :param task_desc: Task description inserted into the prompt.
        :return: The formatted prompt string.
        """
        self.correlation = correlation
        self.var1 = self.correlation.var1
        self.var2 = self.correlation.var2

        return prompt_range_prediction.format(
            # Honor the caller-supplied description; previously the parameter
            # was accepted but the module-level default was always used.
            task_desc=task_desc,
            table=self.var1.table,
            tbl_desc=self.var1.table_desc,
            attr1=self.var1.attr,
            attr2=self.var2.attr,
            var1_desc=self.var1.var_desc,
            var2_desc=self.var2.var_desc
        )

    def build_fisher_gaussian(self, r_min, r_max, confidence=0.95):
        """Fit a normal in Fisher z-space to a correlation interval.

        The endpoints are clamped strictly inside (-1, 1), put in ascending
        order, and transformed with arctanh. The mean is the z-space midpoint
        and the standard deviation is chosen so that the interval spans the
        central ``confidence`` mass of the normal.

        :param r_min: Lower bound of the correlation range.
        :param r_max: Upper bound of the correlation range.
        :param confidence: Probability mass assigned to the interval; expected
            to lie strictly between 0 and 1 (other values make the normal
            quantile non-finite or NaN).
        :return: Tuple ``(z_mean, z_std)``. ``z_std`` is 0 when both bounds
            coincide after clamping.
        """
        # Reversed bounds would otherwise produce a negative standard deviation.
        if r_min > r_max:
            r_min, r_max = r_max, r_min

        # 1. Convert endpoints to z-space. Both endpoints are clamped on both
        #    sides so that e.g. r_min == 1.0 cannot reach arctanh(1) == inf.
        z_min = np.arctanh(self._clamp_correlation(r_min))
        z_max = np.arctanh(self._clamp_correlation(r_max))

        # 2. Mean is the midpoint in z-space
        z_mean = 0.5 * (z_min + z_max)

        # 3. Determine std by assuming [z_min, z_max] is ~ 'confidence' coverage
        z_alpha2 = norm.ppf((1 + confidence) / 2.0)  # ~1.96 for 95%
        half_width = 0.5 * (z_max - z_min)
        z_std = half_width / z_alpha2
        return z_mean, z_std

    def probability_in_interval(self, r_center, radius, z_mean, z_std):
        """
        Compute P(r in [r_center - radius, r_center + radius]) given a normal
        distribution in z-space with mean=z_mean, std=z_std. Clamps the interval
        to [-1, 1] to avoid invalid arctanh.

        :param r_center: Center of the correlation interval.
        :param radius: Half-width of the correlation interval.
        :param z_mean: Mean of the normal in z-space.
        :param z_std: Standard deviation of the normal in z-space.
        :return: Probability mass of the (clamped) interval.
        """
        # Clamp to valid correlation range (both sides, so an out-of-range
        # center cannot push an endpoint onto or past +/-1)
        r_lower = self._clamp_correlation(r_center - radius)
        r_upper = self._clamp_correlation(r_center + radius)

        # Convert to z-space
        z_lower = np.arctanh(r_lower)
        z_upper = np.arctanh(r_upper)

        # Standardize
        z_lower_std = (z_lower - z_mean) / z_std
        z_upper_std = (z_upper - z_mean) / z_std

        # Probability is difference of standard normal CDFs
        return norm.cdf(z_upper_std) - norm.cdf(z_lower_std)

    def pdf_in_r_space(self, r, z_mean, z_std):
        """
        PDF of the Fisher-z-based normal at correlation r.
        That is, we compute the PDF of z ~ N(z_mean, z_std^2) at z = arctanh(r),
        and then apply the Jacobian for the transform r = tanh(z).

        PDF(r) = PDF(z) * |dz/dr|.
        where z = arctanh(r) and dz/dr = 1/(1 - r^2).

        :param r: Correlation at which to evaluate the density.
        :param z_mean: Mean of the normal in z-space.
        :param z_std: Standard deviation of the normal in z-space.
        :return: Density at ``r``; 0.0 when ``r`` is not strictly inside (-1, 1).
        """
        if r <= -1.0 or r >= 1.0:
            return 0.0  # outside valid correlation range

        z = np.arctanh(r)
        # PDF of z in normal
        pdf_z = norm.pdf(z, loc=z_mean, scale=z_std)
        # Jacobian: dz/dr = 1 / (1 - r^2)
        jacobian = 1.0 / (1.0 - r**2)
        return pdf_z * jacobian

    def get_stats(self, r_obs, r_pred, z_mean, z_std):
        """Summarize how the prior scores an observed and a predicted correlation.

        :param r_obs: Observed correlation.
        :param r_pred: Predicted correlation.
        :param z_mean: Mean of the normal in z-space.
        :param z_std: Standard deviation of the normal in z-space.
        :return: Dict with the r-space density at ``r_obs`` and ``r_pred`` and
            the probability mass (rounded to 3 decimals) within 0.05, 0.10,
            0.15 and 0.20 of ``r_obs``.
        """
        density_obs = self.pdf_in_r_space(r_obs, z_mean, z_std)
        density_mean = self.pdf_in_r_space(r_pred, z_mean, z_std)

        probability_005 = self.probability_in_interval(r_obs, 0.05, z_mean, z_std)
        probability_01 = self.probability_in_interval(r_obs, 0.10, z_mean, z_std)
        probability_015 = self.probability_in_interval(r_obs, 0.15, z_mean, z_std)
        probability_02 = self.probability_in_interval(r_obs, 0.20, z_mean, z_std)

        return {
            'density_obs': density_obs,
            'density_mean': density_mean,
            'probability_005': round(probability_005, 3),
            'probability_01': round(probability_01, 3),
            'probability_015': round(probability_015, 3),
            'probability_02': round(probability_02, 3)
        }

    def find_standard_deviation(self, a, b, confidence):
        """
        Compute the standard deviation (σ) of a Gaussian whose central
        ``confidence`` mass spans an interval of width ``b - a``.

        :param a: Lower bound of the range
        :param b: Upper bound of the range
        :param confidence: Desired probability (X%) that values fall within [a, b] (e.g., 0.95 for 95%)
        :return: Standard deviation (σ)
        """
        # Convert confidence into lower/upper tail probabilities
        alpha = (1.0 - confidence) / 2.0  # Probability in each tail
        lower_tail = alpha
        upper_tail = 1.0 - alpha

        # Get z-scores for these tail probabilities
        z_lower = norm.ppf(lower_tail)  # e.g., ~ -1.96 for 95%
        z_upper = norm.ppf(upper_tail)  # e.g., ~ +1.96 for 95%

        # Solve for sigma
        sigma = (b - a) / (z_upper - z_lower)

        return sigma

    def get_prior(self, correlation: Correlation):
        """Query the agent for a correlation range and fit the z-space prior.

        Prints the prompt and the raw response, parses the JSON block of the
        response, and reads its ``coefficient``, ``lower_bound``,
        ``upper_bound`` and ``confidence_level`` entries as floats.

        :param correlation: Correlation to build the prompt for.
        :return: Dict with ``predicted_coef``, ``z_std``, ``z_mean``,
            ``confidence_level``, ``usage``, ``response`` and
            ``response_json``.
        """
        user_message = self.get_user_msg(correlation)
        print(user_message)
        response, usage = self.agent.call(user_message)
        print(response)
        json_block = parse_json_block(response)

        mu = float(json_block['coefficient'])
        lower = float(json_block['lower_bound'])
        upper = float(json_block['upper_bound'])
        confidence_level = float(json_block['confidence_level'])

        # The prior is centered on the z-space midpoint of the range; the
        # predicted coefficient is reported alongside it, not used in the fit.
        z_mean, z_std = self.build_fisher_gaussian(lower, upper, confidence_level)
        combined_data = {'predicted_coef': mu, 'z_std': z_std, "z_mean": z_mean,  "confidence_level": confidence_level, 'usage': usage, 'response': response, 'response_json': json_block}
        return combined_data
    

