from collections import defaultdict
import numpy as np
import pandas as pd
from scipy.optimize import root_scalar
import multiprocessing as mp
import logging
import os
import time
import traceback
from tqdm.auto import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import CategoricalNB
import matplotlib.pyplot as plt
import warnings
import pickle


# %aimport utils.data_sampler, utils.ada_grad, utils.plotting
from utils.data_sampler import (
    CivilCommentsDataSampler,
    MathJudgeLabelShiftDataSampler,
    MathJudgeConceptShiftDataSampler,
    SharedDependencyDataSampler,
    SharedDependencyDataSampler_b,
    SharedMarginalXDataSampler,
    LabelShiftDataSampler,
    MathSharedDependencyDataSampler,
)
from utils.ada_grad import AdaGrad
from utils.plotting import plot_power_vs_steps, get_power_vs_steps

CONCEPT_SHIFT_SETTINGS = {
    "shared_dependency",
    "shared_marginal_x",
    "math_shared_dependency",
    "math_judge_concept_shift",
}


###################################################################
def learn_naive_bayes(df):
    """Fit a categorical naive Bayes classifier that predicts ``y`` from ``x``.

    Parameters:
        df (pd.DataFrame): Training data with a discrete feature column 'x'
            and a label column 'y'.

    Returns:
        CategoricalNB: The fitted model.
    """
    # The feature is passed as a one-column DataFrame (2-D), the label as a
    # Series (1-D).
    model = CategoricalNB()
    model.fit(df[["x"]], df["y"])
    return model


###################################################################
def calc_phi(preds, phi, phi_prm=2):
    """Apply the score transform selected by ``phi`` to a set of predictions.

    Parameters:
        preds (array-like): Predictions, one entry per sample. Lists, pandas
            Series and numpy arrays are all accepted.
        phi (str or number): Selects the transform:
            - "mean": the mean of ``preds`` (returns a scalar).
            - "LR": per-sample Bernoulli likelihood ratio of ``p_1`` against
              ``p_null``.
            - "exp": ``exp(phi_prm * pred)`` per sample.
            - "oracle": 1 for a truthy prediction, 0.01 otherwise.
            - anything else: 1 for a truthy prediction, ``phi`` itself
              otherwise.
        phi_prm (number): Exponent weight; only used by the "exp" transform.

    Returns:
        A scalar for "mean", otherwise a list with one value per prediction.
    """
    # Coerce once so that plain Python lists behave like arrays; without this
    # ``phi_prm * preds`` would *repeat* a list instead of scaling it, and
    # ``preds.mean()`` would fail.
    preds = np.asarray(preds)

    if phi == "mean":
        return preds.mean()
    if phi == "LR":
        # NOTE(review): p_1 and p_null are not parameters of this function and
        # are not defined in the visible part of the file; this branch only
        # works if they exist as module-level names -- confirm.
        return [
            (p_1**pred)
            * ((1 - p_1) ** (1 - pred))
            / ((p_null) ** pred * (1 - p_null) ** (1 - pred))
            for pred in preds
        ]
    if phi == "exp":
        return list(np.exp(phi_prm * preds))
    if phi == "oracle":
        return [1 if pred else 0.01 for pred in preds]
    # Numeric ``phi``: it is the score assigned to a falsy prediction.
    return [1 if pred else phi for pred in preds]


##############################################################################
def estimate_binary_probabilities(df):
    """
    Estimates the probabilities from a DataFrame of binary samples (x, y):
      - P(Y = 1)
      - P(X = 1 | Y = 0)
      - P(X = 1 | Y = 1)

    Parameters:
        df (pd.DataFrame): A DataFrame with binary columns 'x' and 'y'
            (lower-case, as checked below).

    Returns:
        dict: Estimated probabilities keyed by "P(Y=1)", "P(X=1 | Y=0)" and
        "P(X=1 | Y=1)". A conditional probability is NaN when its
        conditioning class has no rows.

    Raises:
        ValueError: If either the 'x' or the 'y' column is missing.
    """
    if not {"x", "y"}.issubset(df.columns):
        raise ValueError("DataFrame must contain 'x' and 'y' columns.")

    def _mean_x_where_y(label):
        # Empirical P(X = 1 | Y = label); NaN if that class is absent.
        subset = df[df["y"] == label]
        return subset["x"].mean() if not subset.empty else float("nan")

    return {
        "P(Y=1)": df["y"].mean(),
        "P(X=1 | Y=0)": _mean_x_where_y(0),
        "P(X=1 | Y=1)": _mean_x_where_y(1),
    }


###################################################################
def classify_x_values(x_values, p_y1, p_x1_given_y0, p_x1_given_y1):
    """
    Assign each binary x the label with the larger joint probability
    P(Y=y) * P(X=x | Y=y) (Bayes rule; ties go to Y=0).

    Parameters:
        x_values (array-like): Binary X values (0 or 1).
        p_y1 (float): Estimated P(Y=1).
        p_x1_given_y0 (float): Estimated P(X=1 | Y=0).
        p_x1_given_y1 (float): Estimated P(X=1 | Y=1).

    Returns:
        np.ndarray: Predicted labels (0 or 1), one per entry of x_values.
    """
    xs = np.asarray(x_values)
    is_one = xs == 1
    prior_y0 = 1 - p_y1

    # Joint probability of (x, Y=0) and (x, Y=1) for every sample.
    joint_y0 = np.where(is_one, p_x1_given_y0, 1 - p_x1_given_y0) * prior_y0
    joint_y1 = np.where(is_one, p_x1_given_y1, 1 - p_x1_given_y1) * p_y1

    # Strict inequality: equal joints are classified as 0.
    return (joint_y1 > joint_y0).astype(int)


#########################################################################################
def generate_observations(n, mu, sigma, a, b, rng_in):
    """
    Generate n samples from a Normal distribution and corresponding Bernoulli outcomes.

    Parameters:
        n (int): Number of observations
        mu (float): Mean of the normal distribution
        sigma (float): Standard deviation of the normal distribution
        a (float): Intercept of the logistic model for P(y=1 | x)
        b (float): Slope of the logistic model for P(y=1 | x)
        rng_in (np.random.Generator): Random generator used for *all* draws,
            so results are reproducible for a seeded generator.

    Returns:
        pd.DataFrame: DataFrame with two columns 'x' (Normal samples) and 'y' (Bernoulli outcomes)
    """
    # Step 1: Draw samples from Normal distribution
    x = rng_in.normal(loc=mu, scale=sigma, size=n)

    # Step 2: Compute Bernoulli probabilities via the logistic link
    logits = a + b * x
    probs = 1 / (1 + np.exp(-logits))

    # Step 3: Draw binary outcomes from Bernoulli, using the same generator
    # that was passed in (not a module-level one).
    y = rng_in.binomial(n=1, p=probs)

    # Return as DataFrame
    return pd.DataFrame({"x": x, "y": y})


###################################################################
def generate_observations_x_ber(n, p_y, p_z, alpha_0, k_classes, rng_in):
    """
    Generate n (x, y) pairs in which x is a noisy copy of y.

    y and an auxiliary variable z are drawn independently from
    Binomial(k_classes, p_y) and Binomial(k_classes, p_z). For each sample,
    x takes the value of y when a uniform draw is <= alpha_0 and the value
    of z otherwise.

    Parameters:
        n (int): Number of observations
        p_y (float): Success probability of the binomial draw for y
        p_z (float): Success probability of the binomial draw for z
        alpha_0 (float): Threshold on the uniform draw; x copies y when the
            draw is <= alpha_0
        k_classes (int): Number of binomial trials for y and z
        rng_in (np.random.Generator): Random generator used for all draws

    Returns:
        pd.DataFrame: DataFrame with two integer columns 'x' and 'y'
    """
    # Draw order (y, z, selector) matters for reproducibility with a seeded
    # generator.
    y = rng_in.binomial(k_classes, p_y, size=n)
    z = rng_in.binomial(k_classes, p_z, size=n)
    copy_from_y = rng_in.random(size=n) <= alpha_0

    # Per sample: take y where selected, otherwise fall back to z.
    x = np.where(copy_from_y, y, z).astype(int)

    return pd.DataFrame({"x": x, "y": y})


##############################################################################################################
def find_pz_for_target_mean(p_y, p_x, alpha_0):
    """Return the p_z for which the mixture mean equals p_x.

    Solves ``p_x = alpha_0 * p_y + (1 - alpha_0) * p_z`` for ``p_z``.
    The result is undefined (division by zero) when ``alpha_0 == 1``.
    """
    share_from_y = alpha_0 * p_y
    weight_on_z = 1 - alpha_0
    return (p_x - share_from_y) / weight_on_z


def find_a_for_target_mean(n, p_x, k_classes_x, b, target_mean=0.5, rng_in=None):
    """Search for the intercept ``a`` that makes the simulated mean of y hit a target.

    Runs Brent's method (``root_scalar`` with ``method="brentq"``) on the
    bracket [-10, 10] for the function ``mean(y) - target_mean``.

    NOTE(review): the call below passes ``p_x``, ``a`` and ``b`` as keywords,
    but the ``generate_observations_x_ber`` visible in this file accepts
    ``(n, p_y, p_z, alpha_0, k_classes, rng_in)``. As written this would
    raise a TypeError on the first evaluation -- it looks like it targets an
    earlier logistic version of that sampler; confirm before use.

    Returns:
        The root reported by ``root_scalar`` if it converged, otherwise None.
    """
    def objective(a):
        # A fresh sample is drawn on every evaluation, so the objective is
        # stochastic unless the generator state is controlled by the caller.
        df = generate_observations_x_ber(
            n=n, p_x=p_x, a=a, b=b, k_classes=k_classes_x, rng_in=rng_in
        )
        return df["y"].mean() - target_mean

    result = root_scalar(objective, bracket=[-10, 10], method="brentq")
    return result.root if result.converged else None


###############################################################################################################
def get_correction_factor_oracle(a_null, b_null, n_test, phi, M, p_x_null):
    """Sum the phi-statistic over M test sets simulated under the null.

    For each of the M repetitions, n_test binary x values are drawn from
    Bernoulli(p_x_null), classified with the logistic model
    ``sigmoid(a_null + b_null * x) > 0.5``, and the per-sample ``calc_phi``
    scores are summed. The function returns the total over all repetitions.

    The generator is created unseeded inside the function, so the result is
    not reproducible across calls.
    """
    rng = np.random.default_rng()

    per_draw_totals = []
    for _ in range(M):
        # Sample directly from the marginal of X.
        x_samples = rng.binomial(1, p_x_null, size=n_test)
        p_y_given_x = 1 / (1 + np.exp(-(a_null + b_null * x_samples)))
        # Hard 0/1 labels, kept as a plain Python list for calc_phi.
        predictions = [int(p > 0.5) for p in p_y_given_x]
        # \phi(f(x_1),...,f(x_n)) = \sum_i \phi(f(x_i))
        per_draw_totals.append(np.array(calc_phi(predictions, phi)).sum())

    return np.array(per_draw_totals).sum()


##########################################################################################################################


# {e_value : np.array(phi_under_null).sum() for e_value, phi_under_null in e_value_to_psi_under_null.items()}
# np.array(phi_under_null).sum()
########################################################################################################
def calc_statistic_finite_sample_oracle(
    a_alt, b_alt, batch_size, normalizing_constant, phi, rng_in, M, p_x_null
):
    """Compute the oracle finite-sample statistic on one simulated batch.

    Draws ``batch_size`` binary x values from Bernoulli(p_x_null), classifies
    them with the logistic model ``sigmoid(a_alt + b_alt * x) > 0.5`` and
    forms

        (M + 1) * T / (T + normalizing_constant),

    where T is the sum of the per-sample ``calc_phi`` scores.

    Parameters:
        a_alt, b_alt (float): Intercept and slope of the logistic model.
        batch_size (int): Number of x samples to draw.
        normalizing_constant (float): Sum of the statistic over the M null
            repetitions.
        phi: Transform selector forwarded to ``calc_phi``.
        rng_in (np.random.Generator): Random generator used for the draw.
        M (int): Number of null repetitions behind normalizing_constant.
        p_x_null (float): P(X = 1) used to draw the batch.

    Returns:
        dict: {"ML_statistic": float, "X_batch_data": DataFrame with column "X"}.
    """
    # The sample size is the batch_size argument; the previous code read an
    # undefined name here and left batch_size unused.
    x_samples = rng_in.binomial(1, p_x_null, size=batch_size)
    logits = a_alt + b_alt * x_samples
    p_y_given_x = 1 / (1 + np.exp(-logits))
    predictions = [1 if p > 0.5 else 0 for p in p_y_given_x]

    # Statistic of the form: \sum_i phi(f(x_i)) / E[\sum_i phi(f(x_i))]
    total = np.array(calc_phi(predictions, phi)).sum()

    return {
        "ML_statistic": (M + 1) * total / (total + normalizing_constant),
        "X_batch_data": pd.DataFrame({"X": x_samples}),
    }


##################################################################################################################################
class StatisticCalcualtor:
    def __init__(
        self,
        n_training,
        n_test,
        null_sampler,
        alt_smapler,
        M,
        p_y_null,
        alg,
        phi_prm=2,
        phi="exp",
        agg="sum",
    ):
        """Store the experiment configuration.

        Parameters:
            n_training: Number of rows requested from the null sampler per
                resample.
            n_test: Number of x values requested when no test frame is given.
            null_sampler / alt_smapler: Data samplers for the null and the
                alternative hypothesis (the second is stored as
                ``self.alt_sampler``).
            M: Number of null resamples used for the normalizing constant.
            p_y_null: P(Y = 1) under the null.
            alg: Name of the prediction algorithm.
            phi_prm: Initial weight of the "exp" transform; also the initial
                weight handed to AdaGrad.
            phi: Transform selector forwarded to ``calc_phi``.
            agg: "sum" aggregates scores by summation, anything else by
                product.
        """
        # Sample sizes and number of null resamples.
        self.n_training, self.n_test, self.M = n_training, n_test, M
        # Data sources for the two hypotheses.
        self.null_sampler, self.alt_sampler = null_sampler, alt_smapler
        # Algorithm and null parameter.
        self.alg, self.p_y_null = alg, p_y_null
        # Score transform configuration.
        self.phi, self.phi_prm, self.agg = phi, phi_prm, agg
        # Optimizer that adapts phi_prm, started at the initial weight.
        self.ada_grad = AdaGrad(weight=phi_prm)

    def set_rng(self, rng):
        """Install ``rng`` on this object and on both samplers; return self."""
        self.rng_np = rng
        for data_source in (self.null_sampler, self.alt_sampler):
            data_source.set_rng(rng)
        return self

    @classmethod
    def of(cls, alg, config, seed):
        """Build a calculator (with its samplers) from a config dictionary.

        Parameters:
            alg (str): Prediction algorithm name; also decides the initial
                ``phi_prm`` (2 for "y_only", 1 otherwise).
            config (dict): Must contain "settings", "n_training", "n_test",
                "M" and "p_y_null", plus the setting-specific keys read
                below.
            seed: Seed for the numpy generator shared by the calculator and
                both samplers; also forwarded to the samplers that take one.

        Returns:
            A new instance with its random generator already installed.

        Raises:
            ValueError: If ``config["settings"]`` is not a supported setting.
        """
        rng = np.random.default_rng(seed=seed)
        settings = config["settings"]
        if settings == "shared_dependency":
            null_sampler = SharedDependencyDataSampler_b(
                p_y_null=config["p_y_null"],
                p_y_alt=config["p_y_alt"],
                p_x=config["p_x"],
                rho_XY1=config["rho_XY1"],
                rho_XY2=None,
                rho_Y1Y2=None,
                bNull=True,
            )
            alt_sampler = SharedDependencyDataSampler_b(
                p_y_null=config["p_y_null"],
                p_y_alt=config["p_y_alt"],
                p_x=config["p_x"],
                rho_XY1=config["rho_XY1"],
                rho_XY2=None,
                rho_Y1Y2=None,
                bNull=False,
            )
        elif settings == "label_shift":
            null_sampler = LabelShiftDataSampler(
                p_y=config["p_y_null"],
                p_x_y_0=config["p_x_y_0"],
                p_x_y_1=config["p_x_y_1"],
            )
            alt_sampler = LabelShiftDataSampler(
                p_y=config["p_y_alt"],
                p_x_y_0=config["p_x_y_0"],
                p_x_y_1=config["p_x_y_1"],
            )
            # Both samplers share the thresholds derived from the null.
            alt_sampler.T_0 = null_sampler.T_0 = null_sampler.tau_0
            alt_sampler.T_1 = null_sampler.T_1 = null_sampler.tau_1
        elif settings == "shared_marginal_x":
            null_sampler = SharedMarginalXDataSampler(
                p_x=config["p_x"],
                p_y_x_0=config["p_y_x_0_null"],
                p_y_x_1=config["p_y_x_1_null"],
            )
            alt_sampler = SharedMarginalXDataSampler(
                p_x=config["p_x"],
                p_y_x_0=config["p_y_x_0_alt"],
                p_y_x_1=config["p_y_x_1_alt"],
            )
        elif settings == "civil_comments":
            null_sampler = CivilCommentsDataSampler(
                hypothesis="null", target_mean=config["p_y_null"], seed=seed
            )
            alt_sampler = CivilCommentsDataSampler(
                hypothesis="alt", target_mean=config["p_y_alt"], seed=seed
            )
            alt_sampler.T_0 = null_sampler.T_0
            alt_sampler.T_1 = null_sampler.T_1
        elif settings == "math_judge_label_shift":
            null_sampler = MathJudgeLabelShiftDataSampler(
                model_name="deepseek-ai_DeepSeek-R1-Distill-Llama-8B",
                seed=seed,
                target_mean=config.get("p_y_null"),
                target_mean_x=config.get("p_x_null"),
            )
            alt_sampler = MathJudgeLabelShiftDataSampler(
                "deepseek-ai_DeepSeek-R1-Distill-Qwen-7B",
                hypothesis="alt",
                seed=seed,
                target_mean=config.get("p_y_alt"),
                target_mean_x=config.get("p_x_alt"),
            )
            alt_sampler.T_0 = null_sampler.T_0 = null_sampler.tau_0
            alt_sampler.T_1 = null_sampler.T_1 = null_sampler.tau_1
            logging.info(f"{null_sampler.data['y'].mean()=}")
            logging.info(f"{alt_sampler.data['y'].mean()=}")
        elif settings == "math_judge_label_shift_validity":
            # Validity check: null and alternative use the same model.
            null_sampler = MathJudgeLabelShiftDataSampler(
                model_name="deepseek-ai_DeepSeek-R1-Distill-Llama-8B",
                seed=seed,
            )
            alt_sampler = MathJudgeLabelShiftDataSampler(
                model_name="deepseek-ai_DeepSeek-R1-Distill-Llama-8B",
                hypothesis="alt",
                seed=seed,
            )
            alt_sampler.T_0 = null_sampler.T_0 = null_sampler.tau_0
            alt_sampler.T_1 = null_sampler.T_1 = null_sampler.tau_1
            logging.info(f"{null_sampler.data['y'].mean()=}")
            logging.info(f"{alt_sampler.data['y'].mean()=}")
        elif settings == "math_judge_concept_shift":
            null_sampler = MathJudgeConceptShiftDataSampler(
                model_name="deepseek-ai_DeepSeek-R1-Distill-Llama-8B",
                target_mean=config["p_y_null"],
                target_p_y_x_1=config["p_y_x_1_null"],
                target_p_y_x_0=config.get("p_y_x_0_null"),
                seed=seed,
                alt_model_name="deepseek-ai_DeepSeek-R1-Distill-Qwen-7B",
                target_mean_alt=config.get("p_y_alt"),
                target_p_y_x_1_alt=config.get("p_y_x_1_alt"),
                target_p_y_x_0_alt=config.get("p_y_x_0_alt"),
            )
            # The alternative sampler is derived from the null one.
            alt_sampler = null_sampler.get_alt_sampler()

            logging.info(f"{null_sampler.data['y'].mean()=}")
            logging.info(f"{alt_sampler.data['y'].mean()=}")
            logging.info(
                f"{null_sampler.data[null_sampler.data['x'] == 0]['y'].mean()=}"
            )
            logging.info(
                f"{null_sampler.data[null_sampler.data['x'] == 1]['y'].mean()=}"
            )
            logging.info(f"{alt_sampler.data[alt_sampler.data['x'] == 0]['y'].mean()=}")
            logging.info(f"{alt_sampler.data[alt_sampler.data['x'] == 1]['y'].mean()=}")
        elif settings == "math_shared_dependency":
            null_sampler = MathSharedDependencyDataSampler(
                "deepseek-ai_DeepSeek-R1-Distill-Llama-8B",
                seed=seed,
                hypothesis="null",
                target_mean=config["p_y_null"],
                target_p_y_x_0=config["p_y_x_0_null"],
                target_p_y_x_1=config["p_y_x_1_null"],
            )

            alt_sampler = MathSharedDependencyDataSampler(
                "deepseek-ai_DeepSeek-R1-Distill-Qwen-7B",
                seed=seed,
                hypothesis="alt",
                target_mean=config["p_y_alt"],
                target_p_y_x_0=config["p_y_x_0_alt"],
                target_p_y_x_1=config["p_y_x_1_alt"],
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the sampler variables further down.
            raise ValueError(f"Unknown settings: {settings!r}")

        phi_prm = 2 if alg == "y_only" else 1

        return cls(
            config["n_training"],
            config["n_test"],
            null_sampler,
            alt_sampler,
            config["M"],
            config["p_y_null"],
            alg,
            phi_prm,
        ).set_rng(rng)

    def train_and_predict_NB(self, df_train, df_test):
        """Sample 0/1 labels for test and training rows from a naive Bayes fit.

        P(y=1 | x) is obtained from a CategoricalNB fitted on ``df_train``;
        labels are then drawn as Bernoulli variables with those
        probabilities using ``self.rng_np`` (test rows first, then training
        rows). Two degenerate training sets skip the fit:

        - only one distinct y value: that value is used as the probability;
        - only one distinct x value: the mean of y is used as the probability.

        Returns:
            tuple: (sampled labels for df_test, sampled labels for df_train).
        """
        y_levels = df_train["y"].unique()
        x_levels = df_train["x"].unique()
        n_test_rows, n_train_rows = len(df_test), len(df_train)

        if len(y_levels) == 1:
            # Single class: predict it with certainty.
            constant = float(y_levels[0])
            prob_test = np.full(n_test_rows, constant)
            prob_train = np.full(n_train_rows, constant)
        elif len(x_levels) == 1:
            # Single feature value: no relationship can be learned, fall back
            # to the observed base rate of y.
            base_rate = df_train["y"].mean()
            prob_test = np.full(n_test_rows, base_rate)
            prob_train = np.full(n_train_rows, base_rate)
        else:
            nb = CategoricalNB()
            nb.fit(df_train[["x"]], df_train["y"])
            prob_test = nb.predict_proba(df_test[["x"]])[:, 1]
            prob_train = nb.predict_proba(df_train[["x"]])[:, 1]

        # Draw order (test, then train) determines the random stream.
        sampled_test = self.rng_np.binomial(n=1, p=prob_test)
        sampled_train = self.rng_np.binomial(n=1, p=prob_train)
        return sampled_test, sampled_train

    def get_psi_of_imputed_data(self, df_train, sampler, df_test=None):
        """Compute the aggregated phi-scores for one training sample.

        Depending on ``self.alg`` the scored data are either model
        predictions on ``df_test`` ("logistic_regression", "NB", "NB_max")
        or the raw training column ("y_only", "x_only").

        Parameters:
            df_train (pd.DataFrame): Labeled data with columns 'x' and 'y'.
            sampler: Data sampler; used to draw test x values when
                ``df_test`` is None and, for "NB_max", for its thresholds
                ``T_0`` / ``T_1``.
            df_test (pd.DataFrame, optional): Test frame with column 'x'.

        Returns:
            dict with keys
              - "e_value_to_psi": mapping from statistic name to its value.
                Statistics whose computation failed are missing (the error
                is printed), so the mapping may be partial or empty.
              - "n1_total" / "n0_total": number of ones / zeros among the
                scored values.

        Raises:
            ValueError: If ``self.alg`` is not a supported algorithm.
        """
        if df_test is None:
            df_test = pd.DataFrame({"x": sampler.sample_x(self.n_test)})

        e_value_to_data = {}
        predictions = None

        if self.alg == "logistic_regression":
            model_reg = LogisticRegression(n_jobs=1, solver="liblinear")
            model_reg.fit(df_train[["x"]], df_train["y"])
            p_y_given_x = model_reg.predict_proba(df_test[["x"]])[:, 1]
            predictions = self.rng_np.binomial(n=1, p=p_y_given_x)
        elif self.alg == "NB":
            # The training-set predictions are not needed here, but the call
            # still draws them, which keeps the random stream unchanged.
            predictions, _ = self.train_and_predict_NB(df_train, df_test)
        elif self.alg == "NB_max":
            # Deterministic rule: compare the training mean of y with the
            # sampler's per-x thresholds.
            y_hat = df_train["y"].mean()
            predictions = np.where(
                df_test["x"] == 1,
                1 if y_hat > sampler.T_1 else 0,
                1 if y_hat > sampler.T_0 else 0,
            )
        elif self.alg == "y_only":
            e_value_to_data["y_only"] = df_train["y"].to_numpy()
            n1_total_predicted = df_train["y"].sum()
            n0_total_predicted = len(df_train) - n1_total_predicted
        elif self.alg == "x_only":
            e_value_to_data["x_only"] = df_train["x"].to_numpy()
            n1_total_predicted = df_train["x"].sum()
            n0_total_predicted = len(df_train) - n1_total_predicted
        else:
            raise ValueError(f"Unknown alg: {self.alg!r}")

        if predictions is not None:
            n1_total_predicted = predictions.sum()
            n0_total_predicted = len(predictions) - n1_total_predicted
            e_value_to_data["predictions_only"] = predictions
            e_value_to_data["combined"] = np.concatenate(
                [predictions, df_train["y"].to_numpy()]
            )

        # Pre-bound so that a failure below yields a (possibly partial)
        # result instead of an UnboundLocalError when building the output.
        e_value_to_psi = {}
        try:
            agg_func = np.sum if self.agg == "sum" else np.prod
            e_value_to_psi = {
                name: agg_func(calc_phi(values, self.phi, self.phi_prm))
                for name, values in e_value_to_data.items()
            }
            if self.alg == "y_only":
                e_value_to_psi["y_only_mult"] = np.array(
                    calc_phi(e_value_to_data["y_only"], self.phi, 1)
                ).prod()
            elif self.alg == "x_only":
                e_value_to_psi["x_only_mult"] = np.array(
                    calc_phi(e_value_to_data["x_only"], self.phi, 1)
                ).prod()
            else:
                # NOTE(review): self.data_old is not set in __init__; it must
                # be assigned by the caller, otherwise the AttributeError is
                # reported below and "pred_only_lrt" is omitted -- confirm.
                e_value_to_psi["pred_only_lrt"] = compute_E_value_LR(
                    predictions.sum(),
                    len(predictions),
                    self.data_old["y"].sum(),
                    len(self.data_old),
                    self.p_y_null,
                )
        except Exception:
            # Best-effort: report and continue, but let KeyboardInterrupt /
            # SystemExit propagate.
            print(self.alg)
            print(self.phi_prm)
            print(f"Error in calc_phi {traceback.format_exc()}")

        return {
            "e_value_to_psi": e_value_to_psi,
            "n1_total": n1_total_predicted,
            "n0_total": n0_total_predicted,
        }

    def calc_statistic_finite_sample(
        self, df_training, data_new, normalizing_constants, df_test=None
    ):
        """Normalize the phi-scores of the observed data by the null totals.

        The training frame and the new data are concatenated and scored with
        ``get_psi_of_imputed_data`` (using the alternative sampler). Each
        score ``psi`` becomes ``(M + 1) * psi / (psi + c)`` where ``c`` is
        the matching entry of ``normalizing_constants``.

        Returns:
            dict with "statistics" (name -> value) and the counts
            "n_1_total_numerator" / "n_0_total_numerator".
        """
        pooled = pd.concat([df_training, data_new])
        imputed = self.get_psi_of_imputed_data(pooled, self.alt_sampler, df_test)

        psi_by_name = imputed["e_value_to_psi"]
        ones_count = imputed["n1_total"]
        zeros_count = imputed["n0_total"]

        scale = self.M + 1
        statistics = {}
        for name, psi in psi_by_name.items():
            share = psi / (psi + normalizing_constants[name])
            statistics[name] = scale * share

        return {
            "statistics": statistics,
            "n_1_total_numerator": ones_count,
            "n_0_total_numerator": zeros_count,
        }

    def get_normalizing_constant(self, df_training):
        """Accumulate the phi-scores over M resamples drawn under the null.

        Each repetition appends ``n_training`` freshly sampled null rows to
        ``df_training`` and scores the result with
        ``get_psi_of_imputed_data``. A resample is redrawn until the
        combined frame contains at least two distinct x values and two
        distinct y values (this loops forever if that can never happen).

        Returns:
            dict with
              - "normalizing constant": defaultdict mapping statistic name to
                the sum of its value over the M repetitions;
              - "n_1_total_null" / "n_0_total_null": per-repetition counts.
        """
        psi_totals = defaultdict(float)
        ones_per_draw = []
        zeros_per_draw = []

        for _ in range(self.M):
            # Rejection loop: keep drawing until both columns are
            # non-degenerate.
            while True:
                df_train = pd.concat(
                    [df_training, self.null_sampler.sample(self.n_training)]
                )
                if (
                    len(df_train["x"].unique()) >= 2
                    and len(df_train["y"].unique()) >= 2
                ):
                    break

            scored = self.get_psi_of_imputed_data(df_train, self.null_sampler)
            ones_per_draw.append(scored["n1_total"])
            zeros_per_draw.append(scored["n0_total"])
            for name, psi in scored["e_value_to_psi"].items():
                psi_totals[name] += psi

        return {
            "normalizing constant": psi_totals,
            "n_1_total_null": ones_per_draw,
            "n_0_total_null": zeros_per_draw,
        }

    def calc_ours(self, df_training, data_new, df_test=None):
        """Compute the normalized statistics for ``data_new``.

        First builds the null normalizing constants from ``df_training``,
        then evaluates the finite-sample statistic of the observed data
        against them.

        Returns:
            dict with "statistics", the observed counts "n_1_numerator" /
            "n_0_numerator", and the per-resample null counts
            "n_1_total_denumerator" / "n_0_total_denumerator".
        """
        null_part = self.get_normalizing_constant(df_training)
        observed_part = self.calc_statistic_finite_sample(
            df_training,
            data_new,
            null_part["normalizing constant"],
            df_test,
        )

        result = {"statistics": observed_part["statistics"]}
        result["n_1_numerator"] = observed_part["n_1_total_numerator"]
        result["n_0_numerator"] = observed_part["n_0_total_numerator"]
        result["n_1_total_denumerator"] = null_part["n_1_total_null"]
        result["n_0_total_denumerator"] = null_part["n_0_total_null"]
        return result

    def calc_PPI(
        self,
        alg,
        config,
        data_old,
        data_new,
        data_old_unlabeled=None,
        data_new_unlabeled=None,
        lamda_PPI_t_minus_1=0,
        a_t_minus_1=0,
        sampler=None,
        no_unlabeled=False,
        positive_lambda=False,
    ):
        """Update a betting wealth with prediction-corrected increments.

        When there is less history than new data (``len(data_old) <
        len(data_new)``) the increments are the centred labels
        ``y - p_y_null``. Otherwise predictions from ``alg`` are used:

            w_i = (y_i - eps * f(x_i) + eps * b_i - p_y_null) / scale,

        where ``b_i`` is the mean prediction over the i-th block of new
        unlabeled rows, ``eps`` is a weight in [0, 1] estimated from the
        historical data, and ``scale = max(1 + eps - p, eps + p)``.

        Parameters:
            alg (str): "NB" (sampled naive Bayes predictions) or "NB_max"
                (threshold rule using ``sampler.T_0`` / ``sampler.T_1``).
            config (dict): Must contain "p_y_null".
            data_old, data_new (pd.DataFrame): Historical and new labeled
                data with columns 'x' and 'y'.
            data_old_unlabeled, data_new_unlabeled (pd.DataFrame):
                Historical and new unlabeled data with column 'x'. The
                number of new unlabeled rows must be a multiple of the
                number of new labeled rows.
            lamda_PPI_t_minus_1, a_t_minus_1: Bet size and accumulated
                squared gradient carried over from the previous call.
            sampler: Only read when ``alg == "NB_max"``.
            no_unlabeled (bool): Force ``eps = 0`` (predictions are ignored).
            positive_lambda (bool): Clamp the bet size to be non-negative.

        Returns:
            dict: {"bett": wealth multiplier for this batch,
                   "lamda_PPI_t": updated bet size,
                   "a_t": updated accumulator}.

        Raises:
            ValueError: If predictions are needed and ``alg`` is unsupported.
        """
        p_y_null = config["p_y_null"]

        if len(data_old) < len(data_new):
            # Not enough history to fit a predictor: plain centred labels.
            w_i = data_new["y"] - p_y_null
        else:

            def estimate_eps(pred_labeled, pred_unlabeled):
                # Prediction weight from historical data, clipped to [0, 1].
                n_unlabeled = len(pred_unlabeled)
                n_labeled = len(pred_labeled)
                cov_Yf_X = np.cov(pred_labeled, np.array(data_old["y"]))[0, 1]
                var_f_X = np.var(np.concatenate([pred_labeled, pred_unlabeled]))
                if no_unlabeled:
                    return 0
                if var_f_X == 0:
                    return 0.99
                raw = cov_Yf_X / ((1 + n_labeled / n_unlabeled) * var_f_X)
                return max(0, min(1, raw))

            def ppi_increments(pred_labeled, pred_unlabeled, eps):
                # One increment per labeled row; the unlabeled predictions
                # are averaged in consecutive blocks of k rows.
                N = len(pred_unlabeled)
                n = len(pred_labeled)
                k = int(N / n)
                b = pred_unlabeled.reshape(N // k, k).mean(axis=1)
                w = np.array(data_new["y"]) - eps * pred_labeled + eps * b - p_y_null
                # Normalize w_i.
                return w / max(1 + eps - p_y_null, eps + p_y_null)

            if alg == "NB_max":
                # Threshold rule learned from the historical label mean.
                y_hat = data_old["y"].mean()

                def predict(frame):
                    return np.where(
                        frame["x"] == 1,
                        1 if y_hat > sampler.T_1 else 0,
                        1 if y_hat > sampler.T_0 else 0,
                    )

                eps = estimate_eps(predict(data_old), predict(data_old_unlabeled))
                predictions_new_data_labeled = predict(data_new)
                predictions_new_data_unlabeled = predict(data_new_unlabeled)
            elif alg == "NB":
                # Fit on historical labeled data; predict the historical
                # unlabeled rows (first result) and labeled rows (second).
                predictions_history_unlabeled, predictions_history_labeled = (
                    self.train_and_predict_NB(data_old, data_old_unlabeled)
                )
                # no_unlabeled takes precedence over the variance check,
                # consistent with the NB_max branch.
                eps = estimate_eps(
                    predictions_history_labeled, predictions_history_unlabeled
                )
                predictions_new_data_labeled, _ = self.train_and_predict_NB(
                    data_old, data_new
                )
                predictions_new_data_unlabeled, _ = self.train_and_predict_NB(
                    data_old, data_new_unlabeled
                )
            else:
                raise ValueError(f"Unsupported alg for calc_PPI: {alg!r}")

            w_i = ppi_increments(
                predictions_new_data_labeled, predictions_new_data_unlabeled, eps
            )

        # Sequential betting: multiply the wealth by (1 + lambda * w), then
        # adapt lambda with a second-order step clipped to [-0.5, 0.5].
        bett = 1
        lamda_PPI_t = lamda_PPI_t_minus_1
        a_t = a_t_minus_1
        for w in w_i:
            bett = bett * (1 + lamda_PPI_t * w)
            z_t = w / (1 + w * lamda_PPI_t)
            a_t = a_t + z_t**2
            lamda_PPI_t = min(
                0.5, max(-0.5, lamda_PPI_t + 2 / (2 - np.log(3)) * z_t / a_t)
            )
            if positive_lambda:
                lamda_PPI_t = max(0, lamda_PPI_t)

        return {"bett": bett, "lamda_PPI_t": lamda_PPI_t, "a_t": a_t}

    def calc_lrt(self, data_new, data_old, var):
        """Return the likelihood-ratio e-value for one variable.

        ``var`` selects which column is counted and which null probability
        is tested:

        - "y": column 'y' against ``self.p_y_null``
        - "x": column 'x' against ``self.null_sampler.p_x``
        - "y_given_x_0": column 'y' against ``self.null_sampler.p_y_x_0``
        - "y_given_x_1": column 'y' against ``self.null_sampler.p_y_x_1``

        Any other value returns None.

        NOTE(review): the conditional variants count 'y' over the whole
        frames passed in; presumably the caller has already filtered them to
        the matching x value -- confirm.
        """
        if var == "y":
            column, p_null = "y", self.p_y_null
        elif var == "x":
            column, p_null = "x", self.null_sampler.p_x
        elif var == "y_given_x_0":
            column, p_null = "y", self.null_sampler.p_y_x_0
        elif var == "y_given_x_1":
            column, p_null = "y", self.null_sampler.p_y_x_1
        else:
            return None

        return compute_E_value_LR(
            data_new[column].sum(),
            len(data_new),
            data_old[column].sum(),
            len(data_old),
            p_null,
        )

    def update_phi_prm(self, n_1, n_0, n1_total, n0_total):
        """Take one AdaGrad step on ``self.phi_prm``.

        Parameters:
            n_1, n_0: Number of ones / zeros in the observed data.
            n1_total, n0_total: Per-resample counts of ones / zeros under
                the null (iterables).

        The gradient depends on ``self.agg``:

        - "sum": difference between the exp-weighted share of ones in the
          observed data and in the observed-plus-null pool; 0.0 when either
          denominator is not positive.
        - "prod": ``n_1`` minus the exp-weighted average count over the
          observed value and the null resamples.

        The step result (clipped by the optimizer to [0.01, 10**6]) replaces
        ``self.phi_prm``.
        """
        alpha = self.phi_prm

        if self.agg == "sum":
            growth = np.exp(alpha)
            null_ones = sum(n1_total)
            null_zeros = sum(n0_total)
            observed_mass = n_1 * growth + n_0
            pooled_mass = (n_1 + null_ones) * growth + (n_0 + null_zeros)
            if observed_mass > 0 and pooled_mass > 0:
                grad_L = (
                    n_1 * growth / observed_mass
                    - (n_1 + null_ones) * growth / pooled_mass
                )
            else:
                # Degenerate data: skip the gradient to avoid dividing by zero.
                grad_L = 0.0
        elif self.agg == "prod":
            null_weights = [np.exp(alpha * n_j) for n_j in n1_total]
            weight_sum = sum(null_weights)
            weighted_counts = sum(
                [n_j * w_j for n_j, w_j in zip(n1_total, null_weights)]
            )
            own_weight = np.exp(alpha * n_1)
            grad_L = n_1 - (
                (n_1 * own_weight + weighted_counts) / (own_weight + weight_sum)
            )

        self.phi_prm = self.ada_grad.step(grad_L, clip_max=10**6, clip_min=0.01)


##########################################################################################################
def compute_E_value_LR_oracle(x_new, n_new, p0, p1):
    """Likelihood-ratio E-value for a binomial count with a known alternative.

    Compares the likelihood of ``x_new`` successes out of ``n_new`` trials
    under the alternative success probability ``p1`` against the null
    probability ``p0``. The test is one-sided: when ``p1`` is not larger than
    ``p0`` the neutral value ``1`` is returned.

    The ratio is formed in log-space and the log-ratio is clipped to
    ``[-700, 700]`` before exponentiating (the same guard used by
    ``compute_E_value_LR``), because exponentiating the two log-likelihoods
    separately underflows to ``0 / 0`` for moderately large ``n_new``.

    Args:
        x_new: number of successes in the new batch.
        n_new: size of the new batch.
        p0: success probability under the null.
        p1: success probability under the alternative.
            NOTE(review): both probabilities are presumably strictly inside
            (0, 1); a value of exactly 0 or 1 reaches ``log(0)`` - confirm
            against callers.

    Returns:
        The E-value ``exp(logL(p1) - logL(p0))`` (clipped as described), or
        ``1`` when ``p1 <= p0``.
    """
    if not p1 > p0:
        return 1
    failures = n_new - x_new
    logL_p0 = x_new * np.log(p0) + failures * np.log(1 - p0)
    logL_p1 = x_new * np.log(p1) + failures * np.log(1 - p1)
    # Clip to prevent overflow when exponentiating.
    log_E_value = np.clip(logL_p1 - logL_p0, -700, 700)
    return np.exp(log_E_value)


######################################################################
def compute_E_value_LR(x_new, n_new, x_old, n_old, p0):
    """Plug-in likelihood-ratio E-value for a binomial count.

    The alternative success probability is estimated from past data only
    (``x_old / n_old``), then the likelihood of the new batch (``x_new``
    successes out of ``n_new``) under that estimate is compared with its
    likelihood under the null probability ``p0``.

    The neutral value ``1`` is returned when there is no past data, when the
    estimate is exactly 0 or 1, when the estimate does not exceed ``p0``, or
    when NumPy raises a ``RuntimeWarning`` during the computation (the warning
    is printed and logged).
    """
    neutral = 1
    if n_old == 0:
        return neutral
    phat = x_old / n_old
    if phat == 0 or phat == 1:
        return neutral

    try:
        with warnings.catch_warnings():
            # Promote numerical warnings (e.g. log(0)) to exceptions so they
            # can be handled below instead of silently yielding inf/nan.
            warnings.simplefilter("error", RuntimeWarning)
            if not phat > p0:
                return neutral
            logL_p0 = x_new * np.log(p0) + (n_new - x_new) * np.log(1 - p0)
            logL_phat = x_new * np.log(phat) + (n_new - x_new) * np.log(1 - phat)
            # Work in log-space and clip so the exponential cannot overflow.
            return np.exp(np.clip(logL_phat - logL_p0, -700, 700))
    except RuntimeWarning as e:
        print(f"Caught RuntimeWarning as an exception: {e}, {os.getpid()}")
        logging.info(f"Runtime warnning {x_new=}, {n_new=}, {x_old=}, {n_old=}, {p0=} ")
        return neutral  # safe default on error


####################################################################################
def first_exceeds(arr, alpha):
    """Return the index of the first entry of ``arr`` strictly above ``alpha``.

    ``arr`` must support elementwise ``>`` (e.g. a NumPy array). Returns
    ``None`` when no entry exceeds ``alpha``.
    """
    above = arr > alpha
    # argmax over a boolean mask gives the position of the first True.
    return np.argmax(above) if np.any(above) else None


####################################################################################
def intialize_list(l):
    """Return ``l`` itself when it is truthy, otherwise a new empty list."""
    if not l:
        return []
    return l


def get_power(bReject, num_trials):
    """Return the fraction of trials that produced a rejection.

    Args:
        bReject: iterable with one entry per trial; ``None`` marks a trial
            without a rejection, any other value (including index ``0``)
            marks a rejection.
        num_trials: denominator used for the fraction.

    Returns:
        ``(number of non-None entries) / num_trials``.
    """
    # Identity test instead of an elementwise ``ndarray != None`` comparison:
    # it does not depend on NumPy's comparison-to-None semantics and still
    # counts a rejection at index 0.
    n_rejected = sum(1 for stop in bReject if stop is not None)
    return n_rejected / num_trials


def print_power(bReject, num_trials, name):
    """Print the power computed by ``get_power`` for the test called ``name``."""
    print(f"power of {name}: {get_power(bReject, num_trials)}")


def get_e_sequences(
    config, seed, settings, nsteps, n_training, history_length=40, min_history_length=10
):
    """Simulate one run and collect the per-step values of every statistic.

    For the single algorithm selected from ``settings`` (``"NB"`` when
    ``settings`` is in ``CONCEPT_SHIFT_SETTINGS``, otherwise ``"NB_max"``) the
    loop draws, at every step, ``n_training`` rows and
    ``statistics_calculator.n_test`` rows from the alternative sampler. Steps
    are skipped (their ``n_training`` rows are only accumulated into the
    training frame) until that frame holds at least ``min_history_length``
    rows. Every later step appends one value per statistic to ``sequences``.

    Args:
        config: configuration mapping forwarded to ``StatisticCalcualtor.of``
            and ``calc_PPI``. When ``settings == "label_shift"`` the keys
            ``p_y_null``, ``p_y_alt``, ``p_x_y_1`` and ``p_x_y_0`` are also
            read here.
        seed: forwarded to ``StatisticCalcualtor.of``.
        settings: name of the setting; compared against
            ``CONCEPT_SHIFT_SETTINGS`` and ``"label_shift"``.
        nsteps: number of loop iterations (skipped warm-up steps included).
        n_training: number of rows drawn per step for the new batch.
        history_length: number of trailing rows of the training frame kept
            before calling ``calc_ours``; ``0`` resets it to an empty frame.
        min_history_length: minimum number of training rows required before
            any statistic is computed.

    Returns:
        dict with keys ``"sequences"`` (``defaultdict`` mapping a statistic
        name to its list of per-step values),
        ``"probs_estimation_error_null_bias"`` and
        ``"probs_estimation_error_alt_std"`` (both ``0`` unless
        ``settings == "label_shift"``).
    """
    sequences = defaultdict(list)
    df_training = pd.DataFrame({"x": [], "y": []})
    data_old = pd.DataFrame({"x": [], "y": []})
    # NOTE(review): the three error accumulators below are never read in this
    # function.
    p_null_estimation_error = 0
    p_alt_estimation_error = 0
    estimation_error = 0

    # Per-step empirical probability estimates; only filled for "label_shift".
    p_y_x_1_null_est = []
    p_y_x_1_alt_est = []
    p_y_x_0_null_est = []
    p_y_x_0_alt_est = []
    p_y_null_est = []
    p_y_alt_est = []
    algs = (
        ["NB"] if settings in CONCEPT_SHIFT_SETTINGS else ["NB_max"]
    )

    for alg in algs:
        # Fresh history frames for this algorithm.
        df_training = pd.DataFrame({"x": [], "y": []})
        data_old = pd.DataFrame({"x": [], "y": []})
        data_old_x = pd.DataFrame({"x": [], "y": []})
        data_old_unlabeled = pd.DataFrame({"x": [], "y": []})

        logging.info(f"Using algorithm: {alg}")
        statistics_calculator = StatisticCalcualtor.of(alg, config, seed)
        statistics_calculator.data_old = data_old
        # NOTE(review): statistics_calculator_x / alt_sampler_x and the local
        # ada_grad are created but not used later in this function; whether
        # constructing them has side effects cannot be seen from here.
        statistics_calculator_x = StatisticCalcualtor.of(alg, config, seed)
        statistics_calculator_prod = StatisticCalcualtor.of(alg, config, seed)
        statistics_calculator_prod.agg = "prod"
        statistics_calculator_prod.data_old = data_old

        alt_sampler = statistics_calculator.alt_sampler
        alt_sampler_x = statistics_calculator_x.alt_sampler
        ada_grad = AdaGrad(weight=statistics_calculator.phi_prm)
        # State carried between steps for the three calc_PPI variants.
        lamda_PPI_t = 0
        lamda_PPI_t_no_unlabeled = 0
        lamda_PPI_t_positive_lambda = 0
        a_t = 1
        a_t_no_unlabeled = 1
        a_t_positive_lambda = 1
        for step in range(nsteps):
            logging.info(f"step {step}")
            data_new = alt_sampler.sample(n_training)
            data_new_unlabeled = alt_sampler.sample(statistics_calculator.n_test)
            # this is here so that the states of all samplers will be the same for fair comparison
            data_new_prod = statistics_calculator_prod.alt_sampler.sample(n_training)
            data_new_unlabeled_prod = statistics_calculator_prod.alt_sampler.sample(
                statistics_calculator_prod.n_test
            )
            data_new_x = pd.concat([data_new, data_new_unlabeled])
            # Warm-up: only accumulate training rows until enough history exists.
            if len(df_training) < min_history_length:
                df_training = pd.concat([df_training, data_new])
                continue
            if history_length == 0:
                df_training = pd.DataFrame({"x": [], "y": []})
            else:
                # Keep only the most recent history_length rows.
                df_training = df_training[-history_length:]

            output_dic = statistics_calculator.calc_ours(
                df_training, data_new, data_new_unlabeled
            )
            # The "prod" calculator is evaluated on the same batches as the
            # default one.
            output_dic_prod = statistics_calculator_prod.calc_ours(
                df_training, data_new, data_new_unlabeled
            )
            statistics = output_dic["statistics"]
            n1_total = output_dic["n_1_total_denumerator"]
            n0_total = output_dic["n_0_total_denumerator"]
            n_1 = output_dic["n_1_numerator"]
            n_0 = output_dic["n_0_numerator"]

            for name, s in statistics.items():
                sequences[f"{alg}_{name}"].append(s)

            lrt_y = statistics_calculator.calc_lrt(data_new, data_old, "y")
            sequences["lrt_y"].append(lrt_y)

            if settings in CONCEPT_SHIFT_SETTINGS:
                sequences["NB_prod_predictions_only"].append(
                    output_dic_prod["statistics"]["predictions_only"]
                )
                if step == 0:
                    lrt_y_given_x_0 = 1
                    lrt_y_given_x_1 = 1
                    sequences["cond_lrt"].append(1)
                else:
                    # Conditional tests: one per value of x, multiplied together.
                    lrt_y_given_x_0 = statistics_calculator.calc_lrt(
                        data_new[data_new.x == 0],
                        data_old[data_old.x == 0],
                        "y_given_x_0",
                    )
                    lrt_y_given_x_1 = statistics_calculator.calc_lrt(
                        data_new[data_new.x == 1],
                        data_old[data_old.x == 1],
                        "y_given_x_1",
                    )
                    sequences["cond_lrt"].append(lrt_y_given_x_0 * lrt_y_given_x_1)
            else:
                sequences["NB_max_prod_predictions_only"].append(
                    output_dic_prod["statistics"]["predictions_only"]
                )
                # The x-test uses the new and unlabeled batches concatenated.
                lrt_x = statistics_calculator.calc_lrt(data_new_x, data_old_x, "x")
                sequences["lrt_x"].append(lrt_x)

            # Default calc_PPI variant.
            out = statistics_calculator.calc_PPI(
                alg,
                config,
                data_old,
                data_new,
                data_old_unlabeled,
                data_new_unlabeled,
                lamda_PPI_t,
                a_t,
                sampler=alt_sampler,
            )
            sequences["PPI"].append(out["bett"])
            lamda_PPI_t = out["lamda_PPI_t"]
            a_t = out["a_t"]

            # calc_PPI variant with no_unlabeled=True.
            out_no_unlabeled = statistics_calculator.calc_PPI(
                alg,
                config,
                data_old,
                data_new,
                data_old_unlabeled,
                data_new_unlabeled,
                lamda_PPI_t_no_unlabeled,
                a_t_no_unlabeled,
                sampler=alt_sampler,
                no_unlabeled=True,
            )
            sequences["PPI_no_unlabeled"].append(out_no_unlabeled["bett"])
            lamda_PPI_t_no_unlabeled = out_no_unlabeled["lamda_PPI_t"]
            a_t_no_unlabeled = out_no_unlabeled["a_t"]

            # calc_PPI variant with positive_lambda=True.
            out_positive_lambda = statistics_calculator.calc_PPI(
                alg,
                config,
                data_old,
                data_new,
                data_old_unlabeled,
                data_new_unlabeled,
                lamda_PPI_t_positive_lambda,
                a_t_positive_lambda,
                sampler=alt_sampler,
                positive_lambda=True,
            )
            sequences["PPI_positive_lambda"].append(out_positive_lambda["bett"])
            lamda_PPI_t_positive_lambda = out_positive_lambda["lamda_PPI_t"]
            a_t_positive_lambda = out_positive_lambda["a_t"]

            # NOTE(review): this whole `label_shift` block is repeated verbatim
            # at the start of the next one, which overwrites every value
            # computed here. It still draws from null_sampler, so removing it
            # would presumably change the sampler's random stream - confirm
            # before deleting.
            if settings == "label_shift":
                p_x_null = config["p_y_null"] * config["p_x_y_1"] + (
                    1 - config["p_y_null"]
                ) * (config["p_x_y_0"])
                p_y_x_0_null = (
                    statistics_calculator.null_sampler.p_y
                    * (1 - statistics_calculator.null_sampler.p_x_y_1)
                    / (1 - p_x_null)
                )
                p_y_x_1_null = (
                    statistics_calculator.null_sampler.p_y
                    * statistics_calculator.null_sampler.p_x_y_1
                    / (p_x_null)
                )

                p_x_alt = config["p_y_alt"] * config["p_x_y_1"] + (
                    1 - config["p_y_alt"]
                ) * (config["p_x_y_0"])
                p_y_x_0_alt = (
                    statistics_calculator.alt_sampler.p_y
                    * (1 - statistics_calculator.alt_sampler.p_x_y_1)
                    / (1 - p_x_alt)
                )
                p_y_x_1_alt = (
                    statistics_calculator.alt_sampler.p_y
                    * statistics_calculator.alt_sampler.p_x_y_1
                    / (p_x_alt)
                )

                df_train_null_tmp = pd.DataFrame({"x": [], "y": []})
                while (
                    len(df_train_null_tmp["x"].unique()) < 2
                    or len(df_train_null_tmp["y"].unique()) < 2
                ):
                    df_train_null_tmp = pd.concat(
                        [
                            df_training,
                            statistics_calculator.null_sampler.sample(
                                statistics_calculator.n_training
                            ),
                        ]
                    )

            if settings == "label_shift":
                # Conditional probabilities P(y=1 | x) implied by the
                # configured label-shift parameters (Bayes' rule).
                p_x_null = config["p_y_null"] * config["p_x_y_1"] + (
                    1 - config["p_y_null"]
                ) * (config["p_x_y_0"])
                p_y_x_0_null = (
                    statistics_calculator.null_sampler.p_y
                    * (1 - statistics_calculator.null_sampler.p_x_y_1)
                    / (1 - p_x_null)
                )
                p_y_x_1_null = (
                    statistics_calculator.null_sampler.p_y
                    * statistics_calculator.null_sampler.p_x_y_1
                    / (p_x_null)
                )

                p_x_alt = config["p_y_alt"] * config["p_x_y_1"] + (
                    1 - config["p_y_alt"]
                ) * (config["p_x_y_0"])
                p_y_x_0_alt = (
                    statistics_calculator.alt_sampler.p_y
                    * (1 - statistics_calculator.alt_sampler.p_x_y_1)
                    / (1 - p_x_alt)
                )
                p_y_x_1_alt = (
                    statistics_calculator.alt_sampler.p_y
                    * statistics_calculator.alt_sampler.p_x_y_1
                    / (p_x_alt)
                )

                # Resample until the combined frame contains both values of x
                # and both values of y.
                df_train_null_tmp = pd.DataFrame({"x": [], "y": []})
                while (
                    len(df_train_null_tmp["x"].unique()) < 2
                    or len(df_train_null_tmp["y"].unique()) < 2
                ):
                    df_train_null_tmp = pd.concat(
                        [
                            df_training,
                            statistics_calculator.null_sampler.sample(
                                statistics_calculator.n_training
                            ),
                        ]
                    )

                # Empirical P(y=1 | x) and P(y=1) on training + null sample.
                p_y_x_1_null_est_t = len(
                    df_train_null_tmp[
                        (df_train_null_tmp.x == 1) & (df_train_null_tmp.y == 1)
                    ]
                ) / max(len(df_train_null_tmp[df_train_null_tmp.x == 1]), 1)
                p_y_x_0_null_est_t = len(
                    df_train_null_tmp[
                        (df_train_null_tmp.x == 0) & (df_train_null_tmp.y == 1)
                    ]
                ) / max(len(df_train_null_tmp[df_train_null_tmp.x == 0]), 1)
                p_y_null_est_t = df_train_null_tmp.y.mean()
                p_y_x_1_null_est.append(p_y_x_1_null_est_t)
                p_y_x_0_null_est.append(p_y_x_0_null_est_t)
                p_y_null_est.append(p_y_null_est_t)

                # Same estimates on training + alternative sample.
                df_train_alt_tmp = pd.DataFrame({"x": [], "y": []})
                while (
                    len(df_train_alt_tmp["x"].unique()) < 2
                    or len(df_train_alt_tmp["y"].unique()) < 2
                ):
                    df_train_alt_tmp = pd.concat(
                        [
                            df_training,
                            statistics_calculator.alt_sampler.sample(
                                statistics_calculator.n_training
                            ),
                        ]
                    )
                p_y_x_1_alt_est_t = len(
                    df_train_alt_tmp[
                        (df_train_alt_tmp.x == 1) & (df_train_alt_tmp.y == 1)
                    ]
                ) / max(len(df_train_alt_tmp[df_train_alt_tmp.x == 1]), 1)
                p_y_x_0_alt_est_t = len(
                    df_train_alt_tmp[
                        (df_train_alt_tmp.x == 0) & (df_train_alt_tmp.y == 1)
                    ]
                ) / max(len(df_train_alt_tmp[df_train_alt_tmp.x == 0]), 1)
                p_y_alt_est_t = df_train_alt_tmp.y.mean()
                p_y_x_1_alt_est.append(p_y_x_1_alt_est_t)
                p_y_x_0_alt_est.append(p_y_x_0_alt_est_t)
                p_y_alt_est.append(p_y_alt_est_t)

            # Fold the current batches into the histories used by the next step.
            df_training = pd.concat([df_training, data_new])
            data_old = pd.concat([data_old, data_new])
            data_old_x = pd.concat([data_old_x, data_new_x])
            data_old_unlabeled = pd.concat([data_old_unlabeled, data_new_unlabeled])

            # Update the phi_prm parameter adaptively. The gradient step is
            # implemented in update_phi_prm.
            # NOTE(review): only statistics_calculator is updated here;
            # statistics_calculator_prod keeps its phi_prm - confirm intended.
            statistics_calculator.update_phi_prm(n_1, n_0, n1_total, n0_total)
            statistics_calculator.data_old = data_old
            statistics_calculator_prod.data_old = data_old

    if settings == "label_shift":
        # Relative bias of the mean alt estimate.
        # NOTE(review): p_alt_bias and p_null_var are computed but not returned.
        p_alt_bias = (
            abs(np.mean(np.array(p_y_alt_est)) - config["p_y_alt"]) / config["p_y_alt"]
        )
        p_alt_std = np.std(np.array(p_y_alt_est))
        # p_alt_MSE = np.mean((np.array(p_y_alt_est) - config['p_y_alt'])**2)
        # np.mean( (np.array(p_y_x_1_alt_est) - p_y_x_1_alt)**2 ) + np.mean( (np.array(p_y_x_0_alt_est) - p_y_x_0_alt)**2 )
        p_null_var = np.std(np.array(p_y_null_est)) / config["p_y_null"]
        # np.var( np.array(p_y_x_1_null_est) ) + np.var( np.array(p_y_x_0_null_est) )
        # p_null_bias_2 = (np.mean( np.array(p_y_null_est) ) - config['p_y_null'])**2
        p_null_bias = abs(
            np.mean(np.array(p_y_null_est)) - config["p_y_null"]
        )  # /config['p_y_null']
        # (np.mean( np.array(p_y_x_1_null_est) ) - p_y_x_1_null)**2 + (np.mean( np.array(p_y_x_0_null_est) ) - p_y_x_0_null)**2
        # MSE: probs_estimation_error
        # output_dic = {'sequences': sequences, 'probs_estimation_error':  p_alt_std+p_null_bias}
        output_dic = {
            "sequences": sequences,
            "probs_estimation_error_null_bias": p_null_bias,
            "probs_estimation_error_alt_std": p_alt_std,
        }
    else:
        output_dic = {
            "sequences": sequences,
            "probs_estimation_error_null_bias": 0,
            "probs_estimation_error_alt_std": 0,
        }

    # return sequences

    return output_dic


def combine_martingales(m1, m2):
    """Mix two sequences step by step with an AdaGrad-tuned weight.

    The weight starts at 0.5 and is stepped with gradient ``s1 - s2``
    (reward = S' + a * (S - S')). Returns a NumPy array holding
    ``w * s1 + (1 - w) * s2`` for every paired step.

    NOTE(review): the weight is stepped before the mixture for the same round
    is formed, so ``w`` has already seen that round's values - confirm this is
    intended.
    """
    learner = AdaGrad(weight=0.5)
    mixed = []
    for first, second in zip(m1, m2):
        learner.step(first - second)
        mixed.append(learner.weight * first + (1 - learner.weight) * second)
    return np.array(mixed)


def combine_martingales_growth_rate(*martingales):
    """Combine several sequences into one list of weighted per-step returns.

    Starts from uniform weights, records ``weights . x_t`` for each step and
    feeds the log-growth gradient ``x_t / (weights . x_t)`` to AdaGrad.

    NOTE(review): the local ``weights`` is never reassigned here, so unless
    ``AdaGrad.step`` mutates the array it was given in place, every step uses
    the initial uniform weights - confirm.
    """
    n_streams = len(martingales)
    weights = np.ones(n_streams) / n_streams  # uniform start
    algo = AdaGrad(weight=weights)
    combined = []
    for steps in zip(*martingales):
        x_t = np.array(steps)
        step_return = np.dot(weights, x_t)
        grad = x_t / step_return
        combined.append(step_return)
        algo.step(grad)
    return combined


def combine_3_martingales(m1, m2, m3):
    """Return the cumulative wealth of an AdaGrad-weighted mix of three sequences.

    Each step feeds ``x_t / (weights . x_t)`` to AdaGrad, then multiplies the
    running wealth (starting at 1.0) by ``algo.weight . x_t``. The result is a
    NumPy array with one cumulative value per step.
    """
    weights = np.ones(3) / 3  # uniform start
    algo = AdaGrad(weight=weights)
    wealth = 1.0
    path = []
    for triple in zip(m1, m2, m3):
        x_t = np.array(triple)
        algo.step(x_t / np.dot(weights, x_t))
        wealth = np.dot(algo.weight, x_t) * wealth
        path.append(wealth)
    return np.array(path)


def old_combine_martingales_growth_rate(*martingales):
    """Earlier variant of the growth-rate combiner.

    For each step the values of all sequences are stacked into ``x_t``; the
    weighted return ``weights . x_t`` is recorded and AdaGrad is stepped with
    ``x_t / (weights . x_t)``. Returns the list of per-step returns.
    """
    count = len(martingales)
    weights = np.ones(count) / count  # start with uniform weights
    optimizer = AdaGrad(weight=weights)
    returns = []
    for values in zip(*martingales):
        x_t = np.array(values)
        mixed = np.dot(weights, x_t)
        direction = x_t / mixed
        returns.append(mixed)
        optimizer.step(direction)
    return returns


def combine_martingales_growth_rate_eg(
    *martingales, eta=0.7, eps=1e-12, clip_grad=None
):
    """Mix several per-step return sequences with Exponentiated-Gradient weights.

    Maximizes log-wealth / growth rate (constant rebalanced portfolio) using
    Exponentiated Gradient (EG) / Hedge updates on the simplex. Equivalent
    objectives: maximize ``prod_t (w^T b_t)`` or ``sum_t log(w^T b_t)``, where
    ``b_t`` is the vector of per-step returns (``x_t`` below) and ``w`` is a
    simplex vector.

    For each step the mixture return ``w_t^T x_t`` (floored at ``eps``) is
    recorded using the weights chosen *before* seeing ``x_t``; the weights are
    then updated as ``w_{t+1,k} ∝ (w_{t,k} + eps) * exp(eta * grad_k)`` with
    ``grad = x_t / (w_t^T x_t)``.

    References:
      - Kivinen & Warmuth (1997), "Exponentiated Gradient versus Gradient
        Descent for Linear Predictors"
      - Helmbold, Schapire, Singer, Warmuth (1998), "On-Line Portfolio
        Selection Using Multiplicative Updates"
      - Freund & Schapire (1999), "Adaptive game playing using multiplicative
        weights"
      - Cover (1991), "Universal Portfolios"

    Args:
        *martingales: one iterable of non-negative per-step values per
            sequence; iteration stops at the shortest one.
        eta: EG learning rate; must be positive.
        eps: floor for the mixture return and offset inside ``log(weights)``.
        clip_grad: optional upper bound applied elementwise to the gradient
            before the weight update.

    Returns:
        List of per-step mixture returns (Python floats).

    Raises:
        ValueError: if no sequence is given, ``eta <= 0``, a step does not
            yield exactly one scalar per sequence, or a value is negative.
    """
    K = len(martingales)
    if K == 0:
        raise ValueError("Need at least one martingale.")
    if eta <= 0:
        raise ValueError("eta must be positive.")

    # Start with uniform weights on the simplex.
    weights = np.ones(K, dtype=float) / K
    combined = []

    for steps in zip(*martingales):
        # x_t is b_t in the portfolio literature (price relatives / bet returns).
        x_t = np.asarray(steps, dtype=float)
        if x_t.shape != (K,):
            raise ValueError(f"Expected steps of length {K}, got {x_t.shape}.")
        if np.any(x_t < 0):
            raise ValueError("This EG implementation assumes x_t >= 0.")

        # Mixture return for this round, floored to avoid division by zero.
        step_return = max(float(np.dot(weights, x_t)), eps)
        # Gradient of log-wealth: d/dw log(w^T x_t) = x_t / (w^T x_t).
        grad = x_t / step_return
        # Lazy %-args: the message is only built when INFO is enabled.
        logging.info("x_t=%r, step_return=%r, grad=%r", x_t, step_return, grad)
        # Optional: stabilize when step_return is tiny (grad can explode).
        if clip_grad is not None:
            grad = np.minimum(grad, clip_grad)
        combined.append(step_return)

        # Multiplicative update carried out in log-space.
        logw = np.log(weights + eps) + eta * grad
        logw -= logw.max()  # numerical stability (prevents overflow)
        weights = np.exp(logw)
        weights /= weights.sum()

        logging.info("weights=%r", weights)
    return combined


# Cumulative product computed through a log-space cumulative sum.
def safe_cumprod(seq):
    """Return the running product of ``seq`` without overflowing.

    Values are floored at 1e-300 before taking logs (so zeros and negatives do
    not reach ``log``), and the running log-sum is clipped to [-700, 700]
    before exponentiating, since exp(709) ~ 1e308 is the float64 maximum.
    """
    log_terms = np.log(np.clip(np.array(seq), 1e-300, None))
    running_log = np.clip(np.cumsum(log_terms), -700, 700)
    return np.exp(running_log)


def simulate_tests(config, seed):
    """Run one simulated trial and report when each test first rejects.

    Steps:
      1. configure per-process file logging under ``logs/``;
      2. generate the per-step sequences with ``get_e_sequences``;
      3. pickle those base sequences to ``sequences/<setting>/.../<seed>.pkl``
         (written before the combined sequences are added);
      4. add EG-combined sequences built with
         ``combine_martingales_growth_rate_eg``;
      5. take the cumulative product of every sequence and record the first
         index at which it exceeds ``1 / config["alpha"]`` (``None`` if never).

    Args:
        config: configuration mapping. Keys read here: ``settings``,
            ``nsteps``, ``n_training``, ``history_length``,
            ``min_history_length``, ``n_test``, ``alpha`` and, depending on
            the setting, ``p_y_x_1_null``, ``p_y_null``, ``p_x_null``,
            ``p_x_y_1``.
        seed: forwarded to ``get_e_sequences`` and used as the pickle file name.

    Returns:
        dict with keys ``"rejects"`` (sequence name -> first rejecting index
        or ``None``), ``"probs_estimation_error_null_bias"`` and
        ``"probs_estimation_error_alt_std"``.
    """
    pid = os.getpid()
    # basicConfig opens the log file immediately and fails with
    # FileNotFoundError if the directory is missing, so create it first.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(
        filename=f"logs/process_{pid}.log",
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )
    logging.info(f"===========starting new sequence=========")
    start_time = time.time()

    out_dic = get_e_sequences(
        config,
        seed,
        config["settings"],
        config["nsteps"],
        config["n_training"],
        config["history_length"],
        config["min_history_length"],
    )

    logging.info(f"sequence simulation time: {time.time() - start_time}")

    sequences = out_dic["sequences"]
    # Name under which get_e_sequences stores the predictions-only statistic.
    ml_martingale = (
        "NB_predictions_only"
        if config["settings"] in CONCEPT_SHIFT_SETTINGS
        else "NB_max_predictions_only"
    )

    # Output directory encodes the setting and its main parameters.
    if config["settings"] in CONCEPT_SHIFT_SETTINGS:
        save_dir = f"sequences/{config['settings']}/p_y_x_1_null_{config['p_y_x_1_null']}/ntest_{config['n_test']}_ntrain_{config['n_training']}"
    elif config["settings"] == "math_judge_label_shift":
        save_dir = f"sequences/{config['settings']}/p_y_null_{config['p_y_null']}_p_x_null_{config['p_x_null']}/ntest_{config['n_test']}_ntrain_{config['n_training']}"
    elif config["settings"] == "math_judge_label_shift_validity":
        save_dir = f"sequences/{config['settings']}/ntest_{config['n_test']}_ntrain_{config['n_training']}"
    else:
        save_dir = f"sequences/{config['settings']}/p_y_null_{config['p_y_null']}_p_x_y_1_{config['p_x_y_1']}/ntest_{config['n_test']}_ntrain_{config['n_training']}"
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{seed}.pkl"), "wb") as f:
        pickle.dump(sequences, f)

    if config["settings"] in CONCEPT_SHIFT_SETTINGS:
        # NB_convex_lrt_growth_rate: combining *ALL* baselines: ML, lrt_y, cond_lrt and PPI (lrt_x is not relevant here)

        sequences["NB_convex_lrt_growth_positive_lambda"] = (
            combine_martingales_growth_rate_eg(
                sequences["cond_lrt"],
                sequences[ml_martingale],
                sequences["lrt_y"],
                sequences["PPI_positive_lambda"],
            )
        )

        sequences["convex_lrt_y_cond_lrt_growth_rate"] = (
            combine_martingales_growth_rate_eg(
                sequences["cond_lrt"],
                sequences["lrt_y"],
            )
        )

        # PPI_convex: combining cond_lrt, PPI and lrt_y (lrt_x is not relevant here)
        sequences["PPI_convex"] = combine_martingales_growth_rate_eg(
            sequences["cond_lrt"],
            sequences["PPI"],
            sequences["lrt_y"],
        )

        sequences["PPI_convex_positive_lambda"] = combine_martingales_growth_rate_eg(
            sequences["cond_lrt"],
            sequences["PPI_positive_lambda"],
            sequences["lrt_y"],
        )

        # Second-level mix: ML statistic with the already-combined PPI_convex.
        sequences["NB_convex_lrt_growth"] = combine_martingales_growth_rate_eg(
            sequences[ml_martingale],
            sequences["PPI_convex"],
        )

        sequences["NB_convex_lrt_y_cond_lrt"] = combine_martingales_growth_rate_eg(
            sequences["cond_lrt"],
            sequences[ml_martingale],
            sequences["lrt_y"],
        )

    else:  # LABEL-SHIFT SETTING
        # NB_convex_lrt_growth_rate: combining *ALL* baselines: ML, lrt_x, lrt_y and PPI

        sequences["NB_convex_lrt_growth_positive_lambda"] = (
            combine_martingales_growth_rate_eg(
                sequences[ml_martingale],
                sequences["lrt_x"],
                sequences["lrt_y"],
                sequences["PPI_positive_lambda"],
            )
        )

        # convex_lrt_x_y_growth_rate: combining lrt_x and lrt_y
        sequences["convex_lrt_x_y_growth_rate"] = combine_martingales_growth_rate_eg(
            sequences["lrt_y"], sequences["lrt_x"]
        )

        # PPI_convex: combining PPI, lrt_y and lrt_x
        sequences["PPI_convex"] = combine_martingales_growth_rate_eg(
            sequences["lrt_x"],
            sequences["lrt_y"],
            sequences["PPI"],
        )

        # Second-level mix: ML statistic with the already-combined PPI_convex.
        sequences["NB_convex_lrt_growth"] = combine_martingales_growth_rate_eg(
            sequences[ml_martingale],
            sequences["PPI_convex"],
        )

        sequences["PPI_convex_positive_lambda"] = combine_martingales_growth_rate_eg(
            sequences["lrt_x"],
            sequences["lrt_y"],
            sequences["PPI_positive_lambda"],
        )
        # NB_convex_lrt_y_lrt_x: combining ML, lrt_y and lrt_x
        sequences["NB_convex_lrt_y_lrt_x"] = combine_martingales_growth_rate_eg(
            sequences[ml_martingale],
            sequences["lrt_x"],
            sequences["lrt_y"],
        )

    # Turn per-step values into running products.
    cummulative_e_values = {name: safe_cumprod(seq) for name, seq in sequences.items()}

    # A test rejects at the first step where its running product exceeds 1/alpha.
    rejects = {}
    for name, cumprod in cummulative_e_values.items():
        rejects[name] = first_exceeds(cumprod, 1 / config["alpha"])

    logging.info(f"total simulation time: {time.time() - start_time}")

    probs_estimation_error_null_bias = out_dic["probs_estimation_error_null_bias"]
    probs_estimation_error_alt_std = out_dic["probs_estimation_error_alt_std"]
    output_dic = {
        "rejects": rejects,
        "probs_estimation_error_null_bias": probs_estimation_error_null_bias,
        "probs_estimation_error_alt_std": probs_estimation_error_alt_std,
    }

    return output_dic


def compare_tests(config, pool=None):
    """Run ``config["num_of_repititions"]`` trials and gather their results.

    Trial ``i`` calls ``simulate_tests(config, i)``. With a ``pool`` all
    trials are submitted through ``pool.apply_async`` and then collected
    (with a tqdm progress bar); without one they run sequentially.

    Returns:
        dict with ``"rejects_lists"`` (``defaultdict`` mapping each sequence
        name to the list of its per-trial rejection indices),
        ``"estimation_error_null_bias"`` and ``"estimation_error_alt_std"``
        (one entry per trial each).
    """
    SEED = 0
    trial_args = [(config, SEED + i) for i in range(config["num_of_repititions"])]

    if pool:
        pending = [pool.apply_async(simulate_tests, args=a) for a in trial_args]
        res = [job.get() for job in tqdm(pending)]
    else:
        res = [simulate_tests(*a) for a in trial_args]

    # Regroup: one list of per-trial rejection indices per sequence name.
    rejects_lists = defaultdict(list)
    for trial in res:
        for name, r in trial["rejects"].items():
            rejects_lists[name].append(r)

    return {
        "rejects_lists": rejects_lists,
        "estimation_error_null_bias": [
            trial["probs_estimation_error_null_bias"] for trial in res
        ],
        "estimation_error_alt_std": [
            trial["probs_estimation_error_alt_std"] for trial in res
        ],
    }

    #######################################################################################################


def run_simulation(config, pool):
    """Run the full comparison for ``config`` and summarize power per test.

    Prints the x-y correlation of 50,000 rows drawn from the null and from the
    alternative sampler, runs ``compare_tests`` and converts every list of
    rejection indices into a power value with ``get_power``.

    Returns:
        Tuple ``(statistic_to_power, corr_null, corr_alt, rejects_lists,
        estimation_error_null_bias, estimation_error_alt_std)``.
    """
    sc = StatisticCalcualtor.of("y_only", config, seed=0)
    null_sampler, alt_sampler = sc.null_sampler, sc.alt_sampler

    # Results are discarded; kept as in the original call sequence.
    null_sampler.sample_x(10)
    alt_sampler.sample_x(10)

    df_0 = null_sampler.sample(50_000)
    df_1 = alt_sampler.sample(50_000)
    corr_null = df_0.corr()["x"]["y"]
    print(corr_null)
    corr_alt = df_1.corr()["x"]["y"]
    print(corr_alt)

    output_dic = compare_tests(config, pool)
    res = output_dic["rejects_lists"]
    n_trials = config["num_of_repititions"]
    statistic_to_power = {}
    for name, rejects_list in res.items():
        statistic_to_power[name] = get_power(rejects_list, n_trials)

    return (
        statistic_to_power,
        corr_null,
        corr_alt,
        res,
        output_dic["estimation_error_null_bias"],
        output_dic["estimation_error_alt_std"],
    )


def plot_power(x, x_label, statistic_to_power):
    """Plot one power curve per statistic against ``x`` and show the figure.

    ``statistic_to_power`` maps a curve label to the y-values plotted over
    ``x``; ``x_label`` is used for the x-axis label and in the title.
    """
    for name in statistic_to_power:
        plt.plot(x, statistic_to_power[name], label=name)

    plt.xlabel(x_label)
    plt.ylabel("Power")
    plt.title(f"Power vs {x_label}")
    plt.legend()
    plt.show()
