import numpy as np
import math
import collections
from src.correlation import Correlation
from src.utils import parse_json_block, get_range
from scipy.stats import gaussian_kde
from src.llm_clients.base_llm_client import BaseLLMClient
from prompts.token_prob import BASE_TASK_DESC, USER_MSG_WITHOUT_JOIN, USER_MSG_CHECK_SEMANTIC_GROUNDING, USER_MSG_DISAMBIGUATE_AND_SEMANTIC_GROUNDING, USER_MSG_WITH_NEW_CONTEXT
import ast
from scipy.integrate import quad
from scipy.optimize import brentq

"""
Support output format:
```json
{{
  "coefficient": "<predicted correlation coefficient>",
}}
```
The predicted coefficient is a floating point number. 
GPT-4o gives better predictions when asking it to output the response in this format 
compared to outputting the coefficient as a single number.
"""

class KDEPrior:
    """Correlation-coefficient prior elicited from an LLM and smoothed with a KDE.

    ``get_prior`` prompts ``agent`` for a coefficient. When token
    log-probabilities are requested, it also turns the top alternatives for
    the answer's tokens into a discrete distribution over values in [-1, 1].
    The remaining methods summarize such a distribution with a weighted
    ``gaussian_kde`` renormalized to the interval [-1, 1].
    """

    def __init__(self, agent: BaseLLMClient, ignore_zero=False):
        self.agent = agent
        self.name = "kde_prior"
        # Prompt-variant switches. get_user_msg / get_prior honour the first
        # enabled one in this order; with all of them off the plain prompt
        # (USER_MSG_WITHOUT_JOIN) is used.
        self.check_semantic_grounding = False
        self.semantic_grounding_and_disambiguate = False
        self.use_new_context = False
        # When True, the summary methods drop the first (value, prob) entry
        # if its value is 0 and renormalize the remaining probabilities.
        self.ignore_zero = ignore_zero

    def _format_user_msg(self, template, task_desc, tbl_desc, **extra):
        """Fill ``template`` with the fields shared by every prompt variant."""
        return template.format(
            task_desc=task_desc,
            table=self.var1.table,
            tbl_desc=tbl_desc,
            attr1=self.var1.attr,
            attr2=self.var2.attr,
            var1_desc=self.var1.var_desc,
            var2_desc=self.var2.var_desc,
            **extra,
        )

    def get_user_msg(self, correlation: Correlation, task_desc: str = BASE_TASK_DESC):
        """Build the user prompt for ``correlation``.

        Side effect: stores ``correlation`` and its ``var1`` / ``var2`` on the
        instance. The table name and description come from ``var1``.

        The template is chosen by the first enabled flag among
        ``check_semantic_grounding``, ``semantic_grounding_and_disambiguate``
        and ``use_new_context``. If none is enabled, ``USER_MSG_WITHOUT_JOIN``
        is used.
        """
        self.correlation = correlation
        self.var1 = self.correlation.var1
        self.var2 = self.correlation.var2

        if self.check_semantic_grounding:
            return self._format_user_msg(
                USER_MSG_CHECK_SEMANTIC_GROUNDING, task_desc, self.var1.table_desc
            )
        if self.semantic_grounding_and_disambiguate:
            return self._format_user_msg(
                USER_MSG_DISAMBIGUATE_AND_SEMANTIC_GROUNDING, task_desc, self.var1.table_desc
            )
        if self.use_new_context:
            new_context = self.correlation.new_context
            # NOTE(review): the new context is passed both as `context` and
            # in place of the table description (`tbl_desc`). Kept as-is —
            # confirm this is intended for USER_MSG_WITH_NEW_CONTEXT.
            return self._format_user_msg(
                USER_MSG_WITH_NEW_CONTEXT, task_desc, new_context, context=new_context
            )
        return self._format_user_msg(USER_MSG_WITHOUT_JOIN, task_desc, self.var1.table_desc)

    def _detect_co_efficient_colon(self, tokens):
        """Return the index where ``tokens`` first contains ``['co', 'efficient', '":']``.

        The three tokens must be consecutive. Returns None if the pattern
        does not occur.
        """
        pattern = ["co", "efficient", '":']
        width = len(pattern)
        for start in range(len(tokens) - width + 1):
            if list(tokens[start:start + width]) == pattern:
                return start
        return None

    def get_relevent_tokens(self, logits):
        """Collect candidate (text, logprob) pairs for the pieces of the predicted number.

        ``logits`` is the per-token sequence returned by
        ``agent.call(..., with_log_prob=True)``. Each entry exposes ``token``
        and ``top_logprobs``; each top-logprob item exposes ``token`` and
        ``logprob``.

        The four tokens after the ``coefficient":`` key are read as:
        sign (with the opening quote), integer digit, decimal point, and
        fractional digits. The decimal-point slot is not used.

        Returns:
            ``(sign_token, unit_token, coef_token)``, each a list of
            ``(text, logprob)`` pairs.

        Raises:
            ValueError: if the coefficient key is not found, or fewer than
                four tokens follow it.
        """
        tokens = [choice.token for choice in logits]
        idx = self._detect_co_efficient_colon(tokens)
        if idx is None:
            raise ValueError("Could not locate the 'coefficient' key in the token log-probs")

        # Skip the three key tokens ('co', 'efficient', '":').
        top_logprobs = [choice.top_logprobs for choice in logits]
        window = top_logprobs[idx + 3:idx + 7]
        if len(window) < 4:
            raise ValueError(
                f"Expected 4 tokens after the 'coefficient' key, found {len(window)}"
            )
        sign_alts, unit_alts, _point_alts, coef_alts = window

        # Drop surrounding whitespace and the leading '"' of the JSON string.
        sign_token = [(alt.token.strip().lstrip('"'), alt.logprob) for alt in sign_alts]
        unit_token = [(alt.token.strip(), alt.logprob) for alt in unit_alts]

        coef_token = []
        for alt in coef_alts:
            text = alt.token.strip()  # also removes any '\n' / '\r'
            if text.endswith('}'):
                text = text.strip('}')
            coef_token.append((text, alt.logprob))
        return sign_token, unit_token, coef_token

    def get_distribution(self, logits):
        """Turn the answer's top-token alternatives into a distribution over coefficients.

        Each (sign, unit, fraction) combination is joined without the decimal
        point and parsed as a percentage, e.g. ('-', '0', '45') -> -45 -> -0.45.
        A one-character fraction is right-padded with '0' ('4' -> '40').
        Combinations that fail to parse or fall outside [-100, 100] are
        dropped. Each surviving combination contributes exp(sum of its three
        logprobs) to its coefficient value.

        Returns:
            ``(distribution, normalizer)``.
            - ``distribution``: a list of (coefficient, probability) pairs in
              first-seen order, with probabilities summing to 1.
            - ``normalizer``: the total unnormalized mass that survived
              parsing. It is 0 when nothing parsed, in which case
              ``distribution`` is empty.
        """
        sign_token, unit_token, coef_token = self.get_relevent_tokens(logits)

        mass = collections.defaultdict(float)
        for sign, lp_sign in sign_token:
            for unit, lp_unit in unit_token:
                # A bare '.' in the integer slot (after '', '+' or '-') is
                # read as the digit '0'.
                digit = '0' if (sign in ('+', '-', '') and unit == '.') else unit
                for coef, lp_coef in coef_token:
                    text = sign + digit + coef.ljust(2, '0')
                    try:
                        percent = float(text)
                    except ValueError:
                        continue  # not a number, e.g. a quote or brace token
                    if -100.0 <= percent <= 100.0:
                        mass[percent / 100] += math.exp(lp_sign + lp_unit + lp_coef)

        normalizer = sum(mass.values())
        distribution = [(value, weight / normalizer) for value, weight in mass.items()]
        return distribution, normalizer

    def get_prior(self, correlation: Correlation, temp=0.0, with_distribution=True):
        """Query the agent once for ``correlation`` and package its prediction.

        Returns:
            dict with 'distribution', 'normalizer', 'predicted_coef', 'usage'
            and 'response'. Extra keys depend on the prompt variant:
            - ``check_semantic_grounding``: adds 'semantic_grounding'.
            - ``semantic_grounding_and_disambiguate``: adds
              'defaulting_to_zero' and 'semantic_grounding'.
            Without ``with_distribution`` the distribution is ``[]`` and the
            normalizer is 0.0.

        Raises:
            ValueError: if no JSON block can be parsed from the response, or
                (with ``with_distribution``) the coefficient tokens cannot be
                located.
            KeyError: if the parsed JSON lacks a required field.
        """
        user_message = self.get_user_msg(correlation)
        if with_distribution:
            response_msg, logits, usage = self.agent.call(user_message, with_log_prob=True, temp=temp)
        else:
            response_msg, usage = self.agent.call(user_message, with_log_prob=False, temp=temp)

        try:
            json_block = parse_json_block(response_msg)
        except Exception as e:
            # Raise an error to notify the caller that JSON parsing failed.
            raise ValueError(f"Error parsing JSON block from response: {e}") from e
        pred_coef = json_block['coefficient']

        if with_distribution:
            distribution, normalizer = self.get_distribution(logits)
        else:
            distribution, normalizer = [], 0.0

        combined_data = {
            'distribution': distribution,
            'normalizer': normalizer,
            'predicted_coef': pred_coef,
            'usage': usage,
            'response': response_msg,
        }
        if self.check_semantic_grounding:
            combined_data['semantic_grounding'] = json_block['semantic_grounding']
        elif self.semantic_grounding_and_disambiguate:
            combined_data['defaulting_to_zero'] = json_block['defaulting_to_zero']
            combined_data['semantic_grounding'] = json_block['semantic_grounding']
        return combined_data

    def get_stats(self, distribution, r_obs, bw=0.3):
        """Summarize a discrete prior and score the observed coefficient ``r_obs``.

        Args:
            distribution: list of (value, probability) pairs. The probabilities
                are used directly as weights for the mean and variance.
                NOTE: the list is sorted in place by descending probability.
                The ``ignore_zero`` checks in this class inspect only the
                first entry, so stored distributions may rely on this order.
                Confirm before removing the side effect.
            r_obs: observed coefficient to evaluate.
            bw: ``bw_method`` passed to ``gaussian_kde``.

        Returns:
            dict with these entries:
            - the most probable value;
            - the distribution's mean and standard deviation;
            - the KDE density, rescaled to integrate to 1 on [-1, 1], at
              ``r_obs`` and at the mean;
            - the rescaled KDE mass over ``get_range(r_obs, w)`` for
              w in 0.05, 0.1, 0.15 and 0.2, rounded to 3 decimals.
            For a single-value distribution all KDE entries are -1.
        """
        distribution.sort(key=lambda item: item[1], reverse=True)
        most_probable = distribution[0][0]
        values, probs = zip(*distribution)
        mean = sum(val * prob for val, prob in distribution)
        variance = sum(prob * (val - mean) ** 2 for val, prob in distribution)
        std = math.sqrt(variance)

        if len(values) == 1:
            # A single point has zero spread, so no KDE is fitted.
            print("Only one value in the distribution")
            return {
                'most_probable': most_probable,
                'dist_mean': mean,
                'predicted_std': std,
                'density_obs': -1,
                'density_mean': -1,
                'probability_005': -1,
                'probability_01': -1,
                'probability_015': -1,
                'probability_02': -1
            }

        kde_dist = gaussian_kde(np.asarray(values), weights=np.asarray(probs), bw_method=bw)
        # Rescale so the KDE integrates to 1 over the valid range [-1, 1].
        scale_factor = 1 / kde_dist.integrate_box_1d(-1, 1)

        def window_mass(width):
            """Rescaled KDE mass over get_range(r_obs, width), rounded to 3 decimals."""
            return round(kde_dist.integrate_box_1d(*get_range(r_obs, width)) * scale_factor, 3)

        return {
            'most_probable': most_probable,
            'dist_mean': mean,
            'predicted_std': std,
            'density_obs': (kde_dist.pdf(r_obs) * scale_factor)[0],
            'density_mean': (kde_dist.pdf(mean) * scale_factor)[0],
            'probability_005': window_mass(0.05),
            'probability_01': window_mass(0.1),
            'probability_015': window_mass(0.15),
            'probability_02': window_mass(0.2)
        }

    @staticmethod
    def _parse_distribution(distribution):
        """Split a (value, probability) sequence, or its string repr, into two tuples.

        Raises:
            ValueError: if the distribution is malformed or empty.
        """
        if isinstance(distribution, str):
            # literal_eval only accepts Python literals, so stored strings
            # cannot execute code.
            distribution = ast.literal_eval(distribution)
        pairs = list(distribution)
        if not pairs:
            raise ValueError("distribution is empty")
        values, probs = zip(*pairs)
        return values, probs

    def _drop_leading_zero(self, values, probs):
        """Handle ``ignore_zero``: drop a leading 0 entry and renormalize the rest.

        Only the FIRST entry is inspected. This presumably relies on the
        distribution having been sorted by descending probability (see
        ``get_stats``), so that it drops a zero mode. TODO confirm.
        """
        if self.ignore_zero and values[0] == 0:
            values = values[1:]
            probs = probs[1:]
            total_prob = sum(probs)
            probs = [prob / total_prob for prob in probs]
        return values, probs

    def get_summary_stats(self, distribution, bw=None, q_vals=(0.025, 0.975)):
        """Return ``(mode, mean, lower, upper)`` of the prior's KDE truncated to [-1, 1].

        Args:
            distribution: string repr (or an already-parsed sequence) of
                (value, probability) pairs.
            bw: ``bw_method`` forwarded to ``gaussian_kde``; None keeps
                SciPy's default.
            q_vals: the (lower, upper) quantile levels to report.

        A prior consisting only of the value 0 is treated as uniform on
        [-1, 1]. In that case mode and mean are 0.0 and each quantile q maps
        to -1 + 2q.
        """
        from scipy.integrate import cumulative_trapezoid, trapezoid

        values, probs = self._parse_distribution(distribution)
        if len(values) == 1 and values[0] == 0:
            return 0.0, 0.0, -1 + q_vals[0] * 2, -1 + q_vals[1] * 2

        values, probs = self._drop_leading_zero(values, probs)
        kde_dist = gaussian_kde(np.asarray(values), weights=np.asarray(probs), bw_method=bw)

        # Evaluate the KDE on a fine grid over [-1, 1] and renormalize it
        # there (trapezoidal rule), giving the truncated pdf.
        x_grid = np.linspace(-1, 1, 1000)
        pdf_values = kde_dist.evaluate(x_grid)
        pdf_truncated = pdf_values / trapezoid(pdf_values, x_grid)

        # Integrate with the same rule so the CDF runs from exactly 0 to 1.
        cdf_values = cumulative_trapezoid(pdf_truncated, x_grid, initial=0)

        mean_val = trapezoid(x_grid * pdf_truncated, x_grid)
        mode_val = x_grid[np.argmax(pdf_values)]

        # Invert the monotone CDF by linear interpolation.
        lower = np.interp(q_vals[0], cdf_values, x_grid)
        upper = np.interp(q_vals[1], cdf_values, x_grid)
        return mode_val, mean_val, lower, upper

    def get_bandwidth(self, distribution):
        """Return the KDE kernel's standard deviation, in coefficient units.

        This equals ``factor * sqrt(weighted data variance)``. Returns 0.5 for
        a prior consisting only of the value 0.
        """
        values, probs = self._parse_distribution(distribution)
        if len(values) == 1 and values[0] == 0:
            return 0.5
        values, probs = self._drop_leading_zero(values, probs)
        kde_dist = gaussian_kde(np.asarray(values), weights=np.asarray(probs))
        # `covariance` is the data covariance scaled by factor**2, so its
        # square root is the kernel std without reading private attributes.
        return math.sqrt(kde_dist.covariance[0, 0])

    def get_density_at(self, r_obs, distribution, bw=None):
        """Return the prior's KDE density at ``r_obs``, renormalized to [-1, 1].

        Returns 0.5 (the uniform density on [-1, 1]) for a prior consisting
        only of the value 0. ``bw`` is forwarded to ``gaussian_kde`` as
        ``bw_method``; None keeps SciPy's default.
        """
        values, probs = self._parse_distribution(distribution)
        if len(values) == 1 and values[0] == 0:
            return 0.5
        values, probs = self._drop_leading_zero(values, probs)
        kde_dist = gaussian_kde(np.asarray(values), weights=np.asarray(probs), bw_method=bw)
        scale_factor = 1 / kde_dist.integrate_box_1d(-1, 1)
        return (kde_dist.pdf(r_obs) * scale_factor)[0]

    def get_ess(self, distribution):
        """Return the effective sample size (sum p)^2 / sum(p^2) of the prior's weights.

        The weights are renormalized first. Returns 0.5 for a prior
        consisting only of the value 0.
        """
        values, probs = self._parse_distribution(distribution)
        values = np.asarray(values, dtype=float)
        probs = np.asarray(probs, dtype=float)

        # Degenerate prior.
        if values.size == 1 and values[0] == 0:
            return 0.5

        # Drop a leading zero bucket (only the first entry is checked).
        if self.ignore_zero and values[0] == 0:
            values, probs = values[1:], probs[1:]

        probs = probs / probs.sum()
        return np.sum(probs) ** 2 / np.sum(probs ** 2)