from src.priors.gaussian_prior import TruncatedGaussian, BASE_TASK_DESC
from src.correlation import Correlation
from src.utils import parse_json_block
import json
from scipy.stats import truncnorm, dirichlet, norm

class GaussianMixture(TruncatedGaussian):
    """Prior built as a mixture of truncated Gaussians.

    Each mixture component comes from one agent call: the same correlation is
    presented under a different task-description phrasing, and the agent's
    reported coefficient / standard deviation become that component's mean
    and std.
    """

    # Output key -> half-width of the interval around ``r_obs`` whose mixture
    # probability ``get_stats`` reports under that key.
    _PROBABILITY_HALF_WIDTHS = {
        'probability_005': 0.05,
        'probability_01': 0.1,
        'probability_015': 0.15,
        'probability_02': 0.2,
    }

    def __init__(self, agent):
        """Initialise the parent prior with ``agent`` and set this prior's name."""
        super().__init__(agent)
        self.name = 'gaussian_mixture'

    def get_user_msg(self, correlation: Correlation, task_desc: str):
        """Return the user message; delegates unchanged to the parent class."""
        return super().get_user_msg(correlation, task_desc)

    def get_prior(self, correlation: Correlation, num_components: int = 10):
        """Query the agent once per task-description phrasing.

        The phrasings are ``BASE_TASK_DESC`` followed by the rephrasings
        loaded from the JSON prompt file (the path is relative to the current
        working directory). Only the first ``num_components`` are used; if
        fewer are available, fewer components are returned.

        Args:
            correlation: The correlation to elicit a prior for.
            num_components: Maximum number of mixture components to collect.

        Returns:
            A dict with keys:
                'mean_l': predicted coefficients, rounded to 3 decimals.
                'std_l': predicted standard deviations, rounded to 3 decimals.
                'usage': token counts summed over all agent calls.
                'responses': the raw agent responses, in call order.

        Raises:
            KeyError: If a parsed response lacks 'coefficient' or
                'standard deviation', or a usage dict lacks a token counter.
        """
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(
            'prompts/truncated_gaussian_params_task_description_rephrasings.json',
            encoding='utf-8',
        ) as f:
            rephrasings = json.load(f)
        task_descriptions = ([BASE_TASK_DESC] + rephrasings)[:num_components]

        total_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        mean_l, std_l, resp_l = [], [], []
        for task_desc in task_descriptions:
            user_message = self.get_user_msg(correlation, task_desc=task_desc)
            print(user_message)
            response, usage = self.agent.call(user_message)
            for counter in total_usage:
                total_usage[counter] += usage[counter]
            json_block = parse_json_block(response)
            mean_l.append(round(float(json_block['coefficient']), 3))
            std_l.append(round(float(json_block['standard deviation']), 3))
            resp_l.append(response)
        return {'mean_l': mean_l, 'std_l': std_l, 'usage': total_usage, 'responses': resp_l}

    def get_stats(self, r_obs, means, stds, weights=None, a=-1, b=1):
        """Evaluate the truncated-Gaussian mixture at an observed value.

        Each component is a normal distribution renormalised by its mass on
        ``[a, b]``.

        Args:
            r_obs: Observed value at which the mixture is evaluated.
            means: Component means.
            stds: Component standard deviations.
            weights: Component weights. When ``None`` they are sampled from a
                flat Dirichlet distribution, so results differ between calls
                unless the random state is controlled by the caller.
            a: Lower truncation bound used for renormalisation.
            b: Upper truncation bound used for renormalisation.

        Returns:
            A dict with keys:
                'density_obs': weighted mixture density at ``r_obs``.
                'density_mean': weighted sum of each component's density at
                    its own mean.
                'probability_005', 'probability_01', 'probability_015',
                'probability_02': weighted sums of the inherited ``get_prob``
                    for half-widths 0.05, 0.1, 0.15 and 0.2, each rounded to
                    3 decimals.
        """
        if weights is None:
            weights = dirichlet.rvs([1] * len(means))[0]

        density_obs = 0
        density_mean = 0
        probabilities = dict.fromkeys(self._PROBABILITY_HALF_WIDTHS, 0)
        # zip stops at the shortest of weights / means / stds.
        for w, r_pred, sigma in zip(weights, means, stds):
            # Mass of the untruncated normal inside [a, b]; dividing by it
            # renormalises the component to the truncated support.
            mass_in_bounds = (
                norm.cdf(b, loc=r_pred, scale=sigma)
                - norm.cdf(a, loc=r_pred, scale=sigma)
            )
            scale_factor = 1 / mass_in_bounds
            density_obs += w * norm.pdf(r_obs, loc=r_pred, scale=sigma) * scale_factor
            density_mean += w * norm.pdf(r_pred, loc=r_pred, scale=sigma) * scale_factor
            for key, half_width in self._PROBABILITY_HALF_WIDTHS.items():
                probabilities[key] += w * self.get_prob(
                    r_obs, half_width, r_pred, sigma, scale_factor
                )

        stats = {'density_obs': density_obs, 'density_mean': density_mean}
        for key, value in probabilities.items():
            stats[key] = round(value, 3)
        return stats