import numpy as np
import matplotlib.pyplot as plt
import dataclasses
from typing import Literal
from core.algo.schema import DebiasStrategy
from core.reasoning.schema import ReasoningTrajectory, ReasoningMode
from core.domain.schema import ProblemDomain
from core.policy.schema import Sample, Policy
from utils.io_utils import dump_file, complete_path
from utils.stats_utils import logistic_coef_loss_ci
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, r2_score, log_loss
from scipy import stats


class MartingaleStrategy(DebiasStrategy):
    """The Martingale debiasing strategy.

    Quantifies how far a model's step-by-step beliefs deviate from a martingale,
    i.e. how predictable each belief update is from the belief held before it.
    """

    # Minimum absolute belief change counted as a directional move by the
    # sign-based (logistic) predictability metrics.
    _CHANGE_THRESHOLD: float = 1e-2

    # Number of largest-|update| pairs written to the outlier dump.
    _NUM_OUTLIERS: int = 10

    def measure_belief_entrenchment(
        self, 
        samples: list[ReasoningTrajectory],
        belief_change_type: Literal["difference", "sign", "mixed"] = "difference",
        step_wise: bool = True,
        informative_switch: bool = False,
        informative_coef: float = 0.2,
        make_plot: bool = True,
        skip_first: bool = False,
    ) -> tuple[float, dict[str, object]]:
        """
        Calculate the extent of deviation from the martingale property.

        Martingale property: the belief update (delta) should be unpredictable from the prior belief.
        This method pairs each prior belief with the subsequent update and measures how well
        the prior predicts the update (linear R^2 and/or logistic predictability of its direction).

        Side effects: writes ``prior_delta_pairs.json`` and ``prior_delta_pairs_per_traj.json``
        via ``dump_file``; when ``make_plot`` is set, also saves ``belief_update_plot.png`` and
        ``belief_update_outliers.json``.

        :param samples: list of samples.
            - Each sample (a complete reasoning trajectory) contains a list of reasoning steps, with that step's belief, representing the probability assigned to an event at each timestep.
            - The lists need not be of the same length.
            - Trajectories with fewer than two usable steps (after ``skip_first``) contain no update and are skipped.
        :type samples: list[ReasoningTrajectory]
        :param belief_change_type: How we calculate belief change. Defaults to "difference".
            - "difference": Use the difference between the prior and the posterior (linear regression R^2).
            - "sign": Use the direction of the difference (mean of the two logistic predictabilities).
            - "mixed": Average of the two above, with the linear R^2 weighted equally to the combined sign score.
        :type belief_change_type: Literal["difference", "sign", "mixed"], optional
        :param step_wise: Whether to calculate the martingale metric step-wise (one pair per consecutive step).
            If False, each trajectory contributes a single pair from its first used step to its last step.
            Step-wise increases statistical power but may be unfair in cross-reasoning-paradigm comparisons. Defaults to True.
        :type step_wise: bool, optional
        :param informative_switch: Whether to subtract an informativeness reward
            (``informative_coef`` times the mean absolute update) from the final metric. Defaults to False.
        :type informative_switch: bool, optional
        :param informative_coef: The coefficient for the informativeness reward. Defaults to 0.2.
        :type informative_coef: float, optional
        :param make_plot: Whether to save a prior-vs-update scatter plot and dump the largest-update outliers. Defaults to True.
        :type make_plot: bool, optional
        :param skip_first: Whether to ignore the first step of every trajectory. Defaults to False.
        :type skip_first: bool, optional
        :raises ValueError: If ``belief_change_type`` is not one of the supported values,
            or if no trajectory yields a (prior, update) pair.
        :return: 
            - float: The martingale metric, representing the extent of deviation from the martingale property. 0 means the updates are unpredictable from the prior (before any informativeness reward).
            - dict[str, object]: A dictionary of metrics, including the breakdown of different components of the martingale metric
              (nested dicts for the upward/downward logistic and linear regressions, plus configuration flags).
        :rtype: tuple[float, dict[str, object]]
        """
        # Validate up front so we fail before writing any files.
        if belief_change_type not in ("difference", "sign", "mixed"):
            raise ValueError(
                f"Unknown belief_change_type {belief_change_type!r}; "
                "expected 'difference', 'sign' or 'mixed'."
            )

        pair_samples, pair_samples_per_traj, sample_contexts = self._collect_prior_delta_pairs(
            samples, step_wise=step_wise, skip_first=skip_first
        )
        if not pair_samples:
            raise ValueError(
                "No (prior, update) pairs could be formed: every trajectory has fewer than "
                "two usable steps."
            )

        dump_file("prior_delta_pairs.json", pair_samples)
        dump_file("prior_delta_pairs_per_traj.json", pair_samples_per_traj)

        if make_plot:
            self._save_update_plot_and_outliers(samples, pair_samples, sample_contexts)

        priors = np.array([pair[0] for pair in pair_samples]).reshape(-1, 1)  # regressor: prior belief
        deltas = np.array([pair[1] for pair in pair_samples])  # target: belief update

        # Logistic regression predicting an upward move (delta > threshold) from the prior.
        up_labels = [(pair[1] > self._CHANGE_THRESHOLD) for pair in pair_samples]
        pos_coef, pos_coef_ci, pos_pred, pos_pred_ci = self._logistic_predictability(priors, up_labels)

        # Downward regression. NOTE(review): labels are True when the belief did NOT fall by
        # more than the threshold (delta > -threshold), i.e. the complement of a downward move,
        # so the reported coefficient carries the sign of "not moving down". Kept as-is for
        # output compatibility.
        down_labels = [(pair[1] > -self._CHANGE_THRESHOLD) for pair in pair_samples]
        neg_coef, neg_coef_ci, neg_pred, neg_pred_ci = self._logistic_predictability(priors, down_labels)

        # How much of the update's variance is linearly explained by the prior.
        linear_model = LinearRegression()
        linear_model.fit(priors, deltas)
        linear_score = r2_score(deltas, linear_model.predict(priors))

        # Significance of the prior -> update slope.
        p_value = stats.linregress(priors.flatten(), deltas).pvalue

        # Smaller credit for making non-trivial updates at all, so that a model is not pushed
        # towards never updating its belief merely to look like a martingale.
        informativeness_reward = informative_coef * np.mean(np.abs(deltas))

        if belief_change_type == "difference":
            final_metric = linear_score
        elif belief_change_type == "sign":
            final_metric = (pos_pred + neg_pred) / 2
        else:  # "mixed"
            final_metric = (linear_score * 2 + pos_pred + neg_pred) / 4

        if informative_switch:
            final_metric -= informativeness_reward

        return final_metric, {
            "upward_logistic_regression": {
                "coef": pos_coef,
                "coef_ci": pos_coef_ci,
                "predictability": pos_pred,
                "predictability_ci": pos_pred_ci,
            },
            "downward_logistic_regression": {
                "coef": neg_coef,
                "coef_ci": neg_coef_ci,
                "predictability": neg_pred,
                "predictability_ci": neg_pred_ci,
            },
            "linear_regression": {
                "predictability": linear_score,
                "p_value": p_value,
                "coef": linear_model.coef_.tolist(),
                # asarray guards against intercept_ being a plain Python float.
                "intercept": np.asarray(linear_model.intercept_).tolist(),
            },
            "informativeness_reward": informativeness_reward,
            "step_wise": step_wise,
            "belief_change_type": belief_change_type,
            "informative_switch": informative_switch,
            "final_metric": final_metric,
        }

    @staticmethod
    def _collect_prior_delta_pairs(
        samples: list[ReasoningTrajectory],
        *,
        step_wise: bool,
        skip_first: bool,
    ) -> tuple[list[tuple], list[tuple], list[tuple[int, int]]]:
        """Turn trajectories into (prior, delta, correct_option) pairs.

        :return: ``(pairs, per_traj_pairs, contexts)`` where
            - ``pairs`` holds one tuple per measured update (step-wise or whole-trajectory),
            - ``per_traj_pairs`` holds one first-used-step-to-last-step tuple per trajectory,
            - ``contexts[k]`` is ``(trajectory index in samples, index of the posterior step)`` for ``pairs[k]``.
            ``correct_option`` is only carried along for the JSON dumps; the regressions ignore it.
        """
        pair_samples = []
        pair_samples_per_traj = []
        sample_contexts = []
        start = int(skip_first)
        for sample_idx, traj in enumerate(samples):
            beliefs = [step.belief for step in traj.steps]
            correct_option = traj.problem.correct_option
            # At least two usable steps are needed for any belief update to exist.
            if len(beliefs) - start < 2:
                continue
            # Non-step-wise: a single jump from the first used step to the last step.
            step_size = 1 if step_wise else len(beliefs) - 1 - start
            for i in range(start, len(beliefs) - step_size):
                pair_samples.append((beliefs[i], beliefs[i + step_size] - beliefs[i], correct_option))
                sample_contexts.append((sample_idx, i + step_size))
            pair_samples_per_traj.append((beliefs[start], beliefs[-1] - beliefs[start], correct_option))
        return pair_samples, pair_samples_per_traj, sample_contexts

    @classmethod
    def _save_update_plot_and_outliers(
        cls,
        samples: list[ReasoningTrajectory],
        pair_samples: list[tuple],
        sample_contexts: list[tuple[int, int]],
    ) -> None:
        """Save a prior-vs-update scatter plot and dump the largest-|update| pairs with their context."""
        priors = np.array([pair[0] for pair in pair_samples])
        deltas = np.array([pair[1] for pair in pair_samples])

        # Use a dedicated figure so we neither clobber nor leak the caller's current figure.
        fig, ax = plt.subplots()
        try:
            # Small jitter keeps overlapping points visible.
            ax.scatter(
                priors + np.random.normal(0, 0.01, len(priors)),
                deltas + np.random.normal(0, 0.01, len(deltas)),
            )
            ax.set_xlabel("Prior Belief")
            ax.set_ylabel("Belief Update (Posterior - Prior)")
            ax.set_title("Belief Update vs Prior Belief")
            fig.savefig(complete_path("belief_update_plot.png"))
        finally:
            plt.close(fig)

        # Outliers in terms of absolute belief update.
        ranked = sorted(range(len(pair_samples)), key=lambda idx: abs(pair_samples[idx][1]), reverse=True)
        full_outliers = []
        for idx in ranked[: cls._NUM_OUTLIERS]:
            traj_idx, step_idx = sample_contexts[idx]
            traj = samples[traj_idx]
            full_outliers.append(
                (dataclasses.asdict(traj), dataclasses.asdict(traj.steps[step_idx]), pair_samples[idx])
            )
        dump_file("belief_update_outliers.json", full_outliers)

    @staticmethod
    def _log_loss_to_predictability(loss: float) -> float:
        """Map a log loss onto predictability relative to the log(2) coin-flip baseline (0 = baseline)."""
        random_baseline = float(np.log(2))
        return (random_baseline - loss) / random_baseline

    @classmethod
    def _logistic_predictability(cls, priors: np.ndarray, labels: list[bool]) -> tuple:
        """Fit the logistic helper and return ``(coef, coef_ci, predictability, predictability_ci)``."""
        coef, coef_ci, loss, loss_ci = logistic_coef_loss_ci(priors, labels)
        predictability = cls._log_loss_to_predictability(loss)
        # The loss -> predictability map is decreasing, so the CI bounds are reversed to keep
        # them ordered (assuming loss_ci is (lower, upper)).
        predictability_ci = [cls._log_loss_to_predictability(bound) for bound in loss_ci[::-1]]
        return coef, coef_ci, predictability, predictability_ci

    def construct_mitigation_dataset(
        self,
        samples: list[tuple[Policy, list[ReasoningTrajectory]]],
        domain: ProblemDomain,
        reasoning_mode: ReasoningMode,
    ) -> list[Sample]:
        """Build the mitigation dataset; currently delegates unchanged to the base-class implementation."""
        # TODO: Consider the contribution/gradient of each trajectory in the batch
        return super().construct_mitigation_dataset(samples, domain, reasoning_mode)
