
from src.priors.gaussian_prior import TruncatedGaussian, BASE_TASK_DESC
from src.correlation import Correlation
from src.utils import parse_json_block
import json
from prompts.reasoning import instruction_reasoning_parameters, prompt

class TruncatedGaussianReasoning(TruncatedGaussian):
    """Truncated-Gaussian prior whose parameters are elicited via a reasoning prompt.

    Builds a user message from ``prompts.reasoning.prompt`` describing the two
    variables of a correlation, sends it to ``self.agent``, and reads the
    predicted direction, coefficient and standard deviation from the JSON block
    in the agent's reply.
    """

    def __init__(self, agent):
        # NOTE(review): the base class presumably stores ``agent`` as
        # ``self.agent`` (``get_prior`` calls ``self.agent.call``) — confirm in
        # TruncatedGaussian.
        super().__init__(agent)
        self.name = "truncated_gaussian_with_reasoning"

    def get_user_msg(self, correlation: Correlation, task_desc: str = instruction_reasoning_parameters):
        """Format the reasoning prompt for ``correlation``.

        Side effect: stores ``correlation`` and its two variables on
        ``self.correlation``, ``self.var1`` and ``self.var2``.

        Args:
            correlation: Correlation whose ``var1``/``var2`` are described in
                the prompt (table, table description, attribute names and
                variable descriptions are all taken from them).
            task_desc: Task description inserted into the prompt; defaults to
                ``instruction_reasoning_parameters``.

        Returns:
            The result of ``prompt.format(...)``, i.e. the user message to send.
        """
        self.correlation = correlation
        self.var1 = correlation.var1
        self.var2 = correlation.var2
        # The table and its description are read from var1 only; the code
        # assumes both variables live in the same table.
        return prompt.format(
            task_desc=task_desc,
            table=self.var1.table,
            tbl_desc=self.var1.table_desc,
            attr1=self.var1.attr,
            attr2=self.var2.attr,
            var1_desc=self.var1.var_desc,
            var2_desc=self.var2.var_desc,
        )

    def get_prior(self, correlation: Correlation):
        """Ask the agent for a prior on ``correlation``'s coefficient.

        The agent's reply is expected to contain a JSON block with the keys
        ``'direction'``, ``'coefficient'`` and ``'standard deviation'``. When
        ``direction`` is ``'positive'`` or ``'negative'`` (compared
        case-insensitively, ignoring surrounding whitespace), the coefficient's
        sign is forced to match it. Any other direction value leaves the
        coefficient exactly as the model reported it.

        Args:
            correlation: The correlation to elicit a prior for.

        Returns:
            dict with ``'predicted_coef'`` and ``'predicted_std'`` (floats),
            the agent's ``'usage'`` and raw ``'response'``, and the parsed
            ``'response_json'``.

        Raises:
            KeyError: if an expected key is missing from the parsed JSON block
                (assuming ``parse_json_block`` returns a dict — confirm).
            ValueError / TypeError: if the coefficient or standard deviation
                cannot be converted by ``float()``.
        """
        user_message = self.get_user_msg(correlation)
        # Echo the full prompt to stdout (kept as-is; callers may rely on it).
        print(user_message)
        response, usage = self.agent.call(user_message)
        json_block = parse_json_block(response)
        direction = json_block['direction']
        pred_coef = json_block['coefficient']
        pred_std = json_block['standard deviation']
        pred_coef, pred_std = float(pred_coef), float(pred_std)
        # The model states the direction separately from the magnitude; trust
        # the stated direction for the sign. Normalize first so LLM variations
        # like "Positive" or " negative " are not silently ignored.
        normalized_direction = str(direction).strip().lower()
        if normalized_direction == 'positive':
            pred_coef = abs(pred_coef)
        elif normalized_direction == 'negative':
            pred_coef = -abs(pred_coef)
        return {
            'predicted_coef': pred_coef,
            'predicted_std': pred_std,
            'usage': usage,
            'response': response,
            'response_json': json_block,
        }