import numpy as np
from scipy.stats import truncnorm, norm
from src.correlation import Correlation
from src.utils import parse_json_block, get_range
import matplotlib.pyplot as plt
from src.llm_clients.base_llm_client import BaseLLMClient
from prompts.gaussian import BASE_TASK_DESC, USER_MSG_WITH_JOIN, USER_MSG_WITHOUT_JOIN

"""
truncated gaussian prior
The current implementation uses only one LLM call
"""

class TruncatedGaussian:
    """Gaussian prior over a correlation coefficient, truncated to [-1, 1].

    The mean and standard deviation of the prior are requested from an LLM
    agent with a single call (see ``get_prior``); the remaining methods
    evaluate densities, interval probabilities and summary statistics of
    the resulting truncated normal distribution.
    """

    # Floor applied to standard deviations before they are used as a scale,
    # so a predicted std of 0 does not cause a division by zero.
    _MIN_STD = 1e-5

    def __init__(self, agent: BaseLLMClient):
        """Store the LLM client used to elicit the prior."""
        self.agent = agent
        self.name = "gaussian_prior"

    @classmethod
    def _truncated_dist(cls, mean, std, lower=-1, upper=1):
        """Return a frozen ``truncnorm`` with support ``[lower, upper]``.

        ``scipy.stats.truncnorm`` expects its clip points in standardized
        units, i.e. ``(bound - loc) / scale``, so the conversion is done
        here once instead of in every caller.
        """
        std = max(std, cls._MIN_STD)  # Avoid division by zero
        a_std = (lower - mean) / std
        b_std = (upper - mean) / std
        return truncnorm(a_std, b_std, loc=mean, scale=std)

    def get_user_msg(self, correlation: Correlation, task_desc: str=BASE_TASK_DESC):
        """Build the user prompt for ``correlation``.

        Side effect: stores ``correlation`` and its two variables on the
        instance as ``self.correlation``, ``self.var1`` and ``self.var2``.

        The join template is used when the two variables live in different
        tables or when a join key is set; otherwise the single-table
        template is used. Note that only the single-table template receives
        ``task_desc`` and the table/variable descriptions.
        """
        self.correlation = correlation
        self.var1 = self.correlation.var1
        self.var2 = self.correlation.var2
        # If tables are different or join key is provided, use join message
        if self.var1.table != self.var2.table or self.correlation.join_key:
            return USER_MSG_WITH_JOIN.format(
                attr1=self.var1.attr,
                table1=self.var1.table,
                attr2=self.var2.attr,
                table2=self.var2.table,
                join_key=self.correlation.join_key,
                granu=self.correlation.granu
            )
        # Otherwise use message for a single table
        return USER_MSG_WITHOUT_JOIN.format(
            task_desc=task_desc,
            table=self.var1.table,
            tbl_desc=self.var1.table_desc,
            attr1=self.var1.attr,
            attr2=self.var2.attr,
            var1_desc=self.var1.var_desc,
            var2_desc=self.var2.var_desc
        )

    def get_prior(self, correlation: Correlation, temp: float=0.0):
        """Ask the agent for the prior's mean and standard deviation.

        Makes one call to ``self.agent.call`` and parses the JSON block in
        the response, reading its ``'coefficient'`` and
        ``'standard deviation'`` entries.

        Returns:
            dict with keys ``'predicted_coef'`` and ``'predicted_std'``
            (both converted to ``float``), ``'usage'`` and ``'response'``
            (as returned by the agent).

        Raises:
            KeyError: if the parsed JSON block lacks one of the two entries.
            ValueError / TypeError: if an entry cannot be converted to float.
        """
        user_message = self.get_user_msg(correlation)
        response, usage = self.agent.call(user_message, temp=temp)
        json_block = parse_json_block(response)
        pred_coef = json_block['coefficient']
        pred_std = json_block['standard deviation']
        return {
            'predicted_coef': float(pred_coef),
            'predicted_std': float(pred_std),
            'usage': usage,
            'response': response,
        }

    def get_prob(self, r_obs, error_bound, r_pred, sigma, scale_factor):
        """Return the scaled normal probability mass of an interval near ``r_obs``.

        The interval is whatever ``get_range(r_obs, error_bound)`` returns
        (presumably ``r_obs +/- error_bound``; see ``src.utils.get_range``
        for whether it is clipped). ``scale_factor`` is the truncation
        renormalisation constant computed by the caller.
        """
        lower, upper = get_range(r_obs, error_bound)
        mass = norm.cdf(upper, loc=r_pred, scale=sigma) - norm.cdf(lower, loc=r_pred, scale=sigma)
        return mass * scale_factor

    def get_stats(self, r_obs, r_pred, sigma, a=-1, b=1):
        """Evaluate the prior at the observed coefficient ``r_obs``.

        The normal distribution N(``r_pred``, ``sigma``) is renormalised by
        the mass it places on ``[a, b]``.

        Returns:
            dict with the (unrounded) densities ``'density_obs'`` (at
            ``r_obs``) and ``'density_mean'`` (at ``r_pred``), and the
            interval probabilities ``'probability_005'``,
            ``'probability_01'``, ``'probability_015'`` and
            ``'probability_02'`` for error bounds 0.05, 0.1, 0.15 and 0.2,
            each rounded to three decimals.
        """
        # Same floor as the truncnorm-based methods, so sigma == 0 does not
        # turn every statistic into NaN.
        sigma = max(sigma, self._MIN_STD)
        # Renormalise by the mass inside the truncation bounds [a, b].
        mass_in_bounds = norm.cdf(b, loc=r_pred, scale=sigma) - norm.cdf(a, loc=r_pred, scale=sigma)
        scale_factor = 1 / mass_in_bounds

        stats = {
            'density_obs': norm.pdf(r_obs, loc=r_pred, scale=sigma) * scale_factor,
            'density_mean': norm.pdf(r_pred, loc=r_pred, scale=sigma) * scale_factor,
        }
        for key, error_bound in (
            ('probability_005', 0.05),
            ('probability_01', 0.1),
            ('probability_015', 0.15),
            ('probability_02', 0.2),
        ):
            prob = self.get_prob(r_obs, error_bound, r_pred, sigma, scale_factor)
            stats[key] = round(prob, 3)
        return stats

    def get_summary_stats(self, predicted_coef, predicted_std, q_vals=(0.025, 0.975)):
        """Return the mean and two quantiles of the prior truncated to [-1, 1].

        Args:
            predicted_coef: location of the underlying normal distribution.
            predicted_std: its scale; floored at ``1e-5``.
            q_vals: two probabilities whose quantiles are returned.

        Returns:
            ``(mean, quantile(q_vals[0]), quantile(q_vals[1]))``.
        """
        dist = self._truncated_dist(predicted_coef, predicted_std)

        # Get the mean of the truncated distribution
        mean_val = dist.mean()

        # Compute the quantiles using the percent point function (ppf)
        quantiles = dist.ppf(q_vals)

        return mean_val, quantiles[0], quantiles[1]

    def get_density_at(self, r_obs, predicted_coef, predicted_std):
        """Return the density at ``r_obs`` of the prior truncated to [-1, 1].

        ``predicted_std`` is floored at ``1e-5`` before use.
        """
        return self._truncated_dist(predicted_coef, predicted_std).pdf(r_obs)

    def plot(self, r_pred, sigma, a=-1, b=1):
        """Plot the truncated prior density on ``[a, b]`` and show the figure."""
        x = np.linspace(a, b, 100)
        # The helper standardizes a and b; passing the raw bounds to
        # truncnorm would truncate at r_pred + a*sigma / r_pred + b*sigma.
        y = self._truncated_dist(r_pred, sigma, a, b).pdf(x)
        plt.plot(x, y)
        plt.title(f"Truncated Gaussian Prior: mean={r_pred}, std={sigma}")
        plt.xlabel("r")
        plt.ylabel("Density")
        plt.show()
    