import numpy as np
import math
from src.correlation import Correlation
from src.utils import parse_json_block
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from src.llm_clients.base_llm_client import BaseLLMClient
from prompts.token_prob import BASE_TASK_DESC, USER_MSG_WITHOUT_JOIN_SINGLE_NUM

"""
Only single-number predictions are supported, e.g. "+100", "-15", "+45".
The model must output the correlation coefficient multiplied by 100, as a single number prefixed with + or -.
"""

class TokenProbDistSingleToken:
    """Derive a correlation prior from an LLM's token log-probabilities.

    The model is asked for a signed number (the correlation coefficient
    multiplied by 100). The top log-probabilities of the first two generated
    positions -- a sign token followed by a number token -- are combined into
    a discrete distribution over coefficients in [-1, 1].
    """

    def __init__(self, agent: BaseLLMClient):
        """Store the LLM client.

        Args:
            agent: Client whose ``call`` method is used by ``get_prior``.
        """
        self.agent = agent
        self.name = "token_prob_dist_single_token"

    def get_user_msg(self, correlation: Correlation, task_desc: str = BASE_TASK_DESC):
        """Format the user prompt for ``correlation``.

        Side effect: ``correlation`` and its two variables are cached on the
        instance as ``self.correlation``, ``self.var1`` and ``self.var2``.

        Args:
            correlation: Pair of variables to ask the model about.
            task_desc: Task description inserted into the prompt template.

        Returns:
            The ``USER_MSG_WITHOUT_JOIN_SINGLE_NUM`` template filled in.
        """
        self.correlation = correlation
        self.var1 = self.correlation.var1
        self.var2 = self.correlation.var2

        # Table name and description are taken from var1 only (this prompt
        # variant has no join between tables).
        return USER_MSG_WITHOUT_JOIN_SINGLE_NUM.format(
            task_desc=task_desc,
            table=self.var1.table,
            tbl_desc=self.var1.table_desc,
            attr1=self.var1.attr,
            attr2=self.var2.attr,
            var1_desc=self.var1.var_desc,
            var2_desc=self.var2.var_desc
        )

    def get_distribution(self, logits) -> list:
        """Build a normalized distribution over coefficients from log-probs.

        Args:
            logits: Sequence of per-position entries, each exposing
                ``top_logprobs``: an iterable of objects with ``token`` and
                ``logprob`` attributes. Only the first two positions are
                used (sign token, then number token); any further positions
                are ignored.

        Returns:
            List of ``(value, probability)`` tuples sorted by descending
            probability, where ``value`` is the parsed number divided by 100.
            The probabilities are renormalized over the retained candidates.
            The list is empty when no sign/number combination parses to a
            number in [-100, 100] or when fewer than two positions are given.
        """
        # Keep only the two positions that matter. Indexing a fixed two-slot
        # list by every position used to raise IndexError on longer outputs.
        positions = [
            [(cand.token, cand.logprob) for cand in choice.top_logprobs]
            for choice in list(logits)[:2]
        ]
        sign_candidates = positions[0] if len(positions) > 0 else []
        number_candidates = positions[1] if len(positions) > 1 else []

        possible_values = []
        for sign_text, sign_lp in sign_candidates:
            # The first token must be an explicit sign.
            if sign_text not in ('+', '-'):
                continue
            for num_text, num_lp in number_candidates:
                try:
                    num_value = float(sign_text + num_text)  # e.g. "+10" -> 10.0
                except ValueError:
                    continue  # Not numeric, skip this combination.

                # Keep only values within [-100, +100]; rescale to [-1, 1].
                if -100.0 <= num_value <= 100.0:
                    possible_values.append((num_value / 100, sign_lp + num_lp))

        if not possible_values:
            return []

        possible_values.sort(key=lambda item: item[1], reverse=True)

        # Shift by the largest log-prob before exponentiating so that very
        # negative log-probs do not all underflow to 0 and make the
        # normalizer zero. A non-finite maximum is left unshifted to avoid
        # producing NaN from (-inf) - (-inf).
        max_lp = possible_values[0][1]
        shift = max_lp if math.isfinite(max_lp) else 0.0
        weights = [math.exp(lp - shift) for _, lp in possible_values]
        normalizer = sum(weights)

        return [
            (val, weight / normalizer)
            for (val, _), weight in zip(possible_values, weights)
        ]

    def get_prior(self, correlation: Correlation) -> dict:
        """Query the model and return its prior distribution.

        Calls ``self.agent.call(user_message, with_log_prob=True)``, which is
        expected to return ``(response_msg, logits, usage)``.

        Args:
            correlation: Pair of variables to ask the model about.

        Returns:
            Dict with keys ``'distribution'`` (see ``get_distribution``),
            ``'usage'`` and ``'response'``.
        """
        user_message = self.get_user_msg(correlation)
        # The prompt and the reply are echoed to stdout.
        print(user_message)
        response_msg, logits, usage = self.agent.call(user_message, with_log_prob=True)
        print(response_msg)
        distribution = self.get_distribution(logits)
        return {'distribution': distribution, 'usage': usage, 'response': response_msg}

    def get_stats(self, distribution, r_obs) -> dict:
        """Summarize ``distribution`` and score it against ``r_obs``.

        Args:
            distribution: Iterable of ``(value, probability)`` pairs.
            r_obs: Observed correlation coefficient.

        Returns:
            Dict with the weighted mean (``'predicted_coef'``), the weighted
            standard deviation (``'predicted_std'``), the KDE density at
            ``r_obs`` and at the mean, and the KDE probability mass within
            +/-0.05, 0.1, 0.15 and 0.2 of ``r_obs`` (rounded to 3 decimals).

        Raises:
            ValueError: If ``distribution`` is empty.
            Errors raised by ``gaussian_kde`` propagate unchanged, e.g. when
            the values are too few or have no spread for a density estimate.
        """
        # Materialize once: the pairs are traversed several times below.
        pairs = list(distribution)
        if not pairs:
            raise ValueError("distribution is empty; cannot compute statistics")

        values, probs = zip(*pairs)
        mean = sum(val * prob for val, prob in pairs)
        variance = sum(prob * (val - mean) ** 2 for val, prob in pairs)
        std = math.sqrt(variance)

        # Smooth the discrete distribution into a continuous one; the scalar
        # bw_method is used directly as the KDE bandwidth factor.
        kde_dist = gaussian_kde(values, weights=probs, bw_method=0.3)

        # Density at the observed value and at the predicted mean.
        density_obs = kde_dist(r_obs)[0]
        density_mean = kde_dist(mean)[0]

        # Probability mass in symmetric windows around the observed value.
        probability_005 = kde_dist.integrate_box_1d(r_obs - 0.05, r_obs + 0.05)
        probability_01 = kde_dist.integrate_box_1d(r_obs - 0.1, r_obs + 0.1)
        probability_015 = kde_dist.integrate_box_1d(r_obs - 0.15, r_obs + 0.15)
        probability_02 = kde_dist.integrate_box_1d(r_obs - 0.2, r_obs + 0.2)

        return {
            'predicted_coef': mean,
            'predicted_std': std,
            'density_obs': density_obs,
            'density_mean': density_mean,
            'probability_005': round(probability_005, 3),
            'probability_01': round(probability_01, 3),
            'probability_015': round(probability_015, 3),
            'probability_02': round(probability_02, 3)
        }