# %%
import pandas as pd
import numpy as np


class SharedDependencyDataSampler_b:
    """Sample binary (x, y) pairs whose dependence comes from a shared shift g(X).

    Construction used by ``sample``: X ~ Bern(p_x), Y1 | X ~ Bern(p_y_null + g(X))
    and Y2 | X ~ Bern(p_y_alt + g(X)), where g(1) = lam * (1 - p_x) and
    g(0) = -lam * p_x. ``lam`` is solved from whichever correlation target was
    supplied. ``sample`` returns (x, Y1) when ``bNull`` is truthy and (x, Y2)
    otherwise. A generator must be attached with ``set_rng`` before sampling.
    """

    def __init__(self, p_y_null, p_y_alt, p_x, rho_XY1, rho_XY2, rho_Y1Y2, bNull):
        """Store the marginals, the correlation targets (any may be None) and the null flag."""
        self.p_y_null = p_y_null
        self.p_y_alt = p_y_alt
        self.p_x = p_x
        self.rho_XY1 = rho_XY1
        self.rho_XY2 = rho_XY2
        self.rho_Y1Y2 = rho_Y1Y2
        self.bNull = bNull

    def set_rng(self, rng):
        """Attach the random generator used for sampling and return ``self``."""
        self.rng = rng
        return self

    def lam_bounds(self, p_y_null, p_y_alt, p_x):
        """Return the bound on |lambda| used by the feasibility check in ``sample``.

        Returns 0.0 when ``p_x`` is outside the open interval (0, 1).
        """
        m = min(p_y_null, 1 - p_y_null, p_y_alt, 1 - p_y_alt)
        if p_x <= 0 or p_x >= 1:
            return 0.0  # no dependence possible if X degenerate
        return min(m / p_x, m / (1 - p_x))

    def lam_for_rho_XYi(self, rho_target, p_i, p_x):
        """Solve lambda for a desired Corr(X, Yi)=rho_target.

        Raises:
            ValueError: if ``p_i`` or ``p_x`` is not strictly inside (0, 1).
        """
        if not (0 < p_i < 1) or not (0 < p_x < 1):
            raise ValueError("p_i and p_x must be in (0,1) to define correlation.")
        return rho_target * np.sqrt(p_i * (1 - p_i)) / np.sqrt(p_x * (1 - p_x))

    def lam_for_rho_Y1Y2(self, rho12_target, p_y_null, p_y_alt, p_x, sign=1):
        """Solve lambda (up to sign) for a desired Corr(Y1, Y2)=rho12_target >= 0.

        Raises:
            ValueError: if ``rho12_target`` is negative or ``p_x * (1 - p_x)`` is zero.
        """
        if rho12_target < 0:
            raise ValueError(
                "Negative Corr(Y1,Y2) is impossible under this construction."
            )
        num = rho12_target * np.sqrt(
            p_y_null * (1 - p_y_null) * p_y_alt * (1 - p_y_alt)
        )
        den = p_x * (1 - p_x)
        if den == 0:
            raise ValueError("q must be in (0,1).")
        return float(sign) * np.sqrt(num / den)

    def sample(self, n):
        """
        Sample X~Bern(q), Y1|X~Bern(p1+g(X)), Y2|X~Bern(p2+g(X))
        choosing lambda to hit a specified correlation target.
        - Provide rho_XY1 (or rho_XY2) to target Corr(X,Yi).
        - Or provide rho_Y1Y2 to target Corr(Y1,Y2) (nonnegative).
        If both rho_XY1 and rho_XY2 are given, they must satisfy
        rho_XY1*sqrt(p1(1-p1)) == rho_XY2*sqrt(p2(1-p2)).
        When a Corr(X,Yi) target is present it overrides the lambda derived
        from rho_Y1Y2.

        Returns:
            pandas.DataFrame with columns "x" and "y" (n rows); "y" is Y1 when
            ``bNull`` is truthy, otherwise Y2.

        Raises:
            ValueError: on marginals outside (0,1), missing or inconsistent
                targets, or a lambda beyond the feasibility bound.
            RuntimeError: if a conditional probability still leaves [0,1].
        """
        if not (0 < self.p_y_null < 1 and 0 < self.p_y_alt < 1 and 0 < self.p_x < 1):
            raise ValueError("All p1, p2, q must lie in (0,1).")

        # choose lambda
        lam = None
        if self.rho_Y1Y2 is not None:
            lam = self.lam_for_rho_Y1Y2(
                self.rho_Y1Y2, self.p_y_null, self.p_y_alt, self.p_x, sign=1
            )

        if (self.rho_XY1 is not None) and (self.rho_XY2 is not None):
            lhs = self.rho_XY1 * np.sqrt(self.p_y_null * (1 - self.p_y_null))
            rhs = self.rho_XY2 * np.sqrt(self.p_y_alt * (1 - self.p_y_alt))
            if abs(lhs - rhs) > 1e-10:
                raise ValueError(
                    "Inconsistent targets: rho_XY1*sqrt(p1(1-p1)) must equal rho_XY2*sqrt(p2(1-p2))."
                )
            lam = lhs / np.sqrt(self.p_x * (1 - self.p_x))

        elif self.rho_XY1 is not None:
            lam = self.lam_for_rho_XYi(self.rho_XY1, self.p_y_null, self.p_x)

        elif self.rho_XY2 is not None:
            lam = self.lam_for_rho_XYi(self.rho_XY2, self.p_y_alt, self.p_x)

        if lam is None:
            raise ValueError("Provide one of: rho_XY1, rho_XY2, or rho_Y1Y2.")

        # feasibility check
        L = self.lam_bounds(self.p_y_null, self.p_y_alt, self.p_x)
        if abs(lam) > L + 1e-12:
            # report achievable range for the requested target(s)
            rho1_max = (
                np.sqrt(self.p_x * (1 - self.p_x))
                * L
                / np.sqrt(self.p_y_null * (1 - self.p_y_null))
            )
            rho2_max = (
                np.sqrt(self.p_x * (1 - self.p_x))
                * L
                / np.sqrt(self.p_y_alt * (1 - self.p_y_alt))
            )
            rho12_max = (
                self.p_x
                * (1 - self.p_x)
                * L
                * L
                / np.sqrt(
                    self.p_y_null
                    * (1 - self.p_y_null)
                    * self.p_y_alt
                    * (1 - self.p_y_alt)
                )
            )
            raise ValueError(
                f"Target infeasible. |lambda| <= {L:.6f}. "
                f"Max |Corr(X,Y1)|≈{rho1_max:.6f}, |Corr(X,Y2)|≈{rho2_max:.6f}, Corr(Y1,Y2)∈[0,{rho12_max:.6f}]."
            )

        # sample
        X = (self.rng.random(n) < self.p_x).astype(int)
        # Zero-mean shift shared by both outcomes.
        g = np.where(X == 1, lam * (1 - self.p_x), -lam * self.p_x)

        pi1 = self.p_y_null + g
        pi2 = self.p_y_alt + g
        if (
            (pi1.min() < -1e-12)
            or (pi1.max() > 1 + 1e-12)
            or (pi2.min() < -1e-12)
            or (pi2.max() > 1 + 1e-12)
        ):
            raise RuntimeError(
                "Probabilities fell outside [0,1] despite bounds; numerical issue?"
            )

        # Both outcomes are always drawn, in this order, so the generator is
        # advanced identically whichever one is returned.
        Y1 = (self.rng.random(n) < pi1).astype(int)
        Y2 = (self.rng.random(n) < pi2).astype(int)

        if self.bNull:
            return pd.DataFrame({"x": X, "y": Y1})
        return pd.DataFrame({"x": X, "y": Y2})

    def sample_x(self, n):
        """Draw ``n`` Bernoulli(p_x) values for x only."""
        return self.rng.binomial(1, self.p_x, size=n)


class LabelShiftDataSampler:
    """Sample binary (x, y) pairs by drawing y first and then x given y.

    ``sample`` draws y ~ Bern(p_y) and x ~ Bern(p_x_y_1) where y is 1,
    x ~ Bern(p_x_y_0) where y is 0. A generator must be attached with
    ``set_rng`` before sampling.
    """

    def __init__(self, p_y, p_x_y_0, p_x_y_1):
        """Store the generative parameters and the quantities derived from them.

        Args:
            p_y: success probability used for y.
            p_x_y_0: success probability used for x on rows where y is 0.
            p_x_y_1: success probability used for x on rows where y is 1.

        The derived attributes (``p_x``, ``tau_*``, ``p_y_given_x_*``,
        ``th_*``, ``T_*``) are plain ratios of these inputs; a zero
        denominator is not guarded against here.
        """
        self.p_y = p_y
        self.p_x_y_0 = p_x_y_0
        self.p_x_y_1 = p_x_y_1
        # Marginal of x implied by the mixture over y.
        self.p_x = p_y * p_x_y_1 + (1 - p_y) * p_x_y_0
        p_x_0_y_1 = 1 - self.p_x_y_1
        p_x_0_y_0 = 1 - self.p_x_y_0
        self.tau_0 = p_x_0_y_0 / (p_x_0_y_0 + p_x_0_y_1)
        self.tau_1 = self.p_x_y_0 / (self.p_x_y_0 + self.p_x_y_1)
        # Bayes' rule: probability of y = 1 given x = 0 / x = 1.
        self.p_y_given_x_0 = (self.p_y * (1 - self.p_x_y_1)) / (1 - self.p_x)
        self.p_y_given_x_1 = (self.p_y * self.p_x_y_1) / self.p_x
        self.th_0 = self.p_y_given_x_0
        self.th_1 = self.p_y_given_x_1
        self.T_0 = (self.th_0 * self.tau_0) / (
            (1 - self.th_0) * (1 - self.tau_0) + (self.th_0 * self.tau_0)
        )
        self.T_1 = (self.th_1 * self.tau_1) / (
            (1 - self.th_1) * (1 - self.tau_1) + (self.th_1 * self.tau_1)
        )

    def set_rng(self, rng):
        """Attach the random generator used for sampling and return ``self``."""
        self.rng = rng
        return self

    def sample(self, n):
        """Return a DataFrame with ``n`` rows and columns "x" and "y".

        y is drawn first; x is then drawn in a single vectorized call whose
        per-row success probability depends on the row's y.
        """
        y = self.rng.binomial(1, self.p_y, size=n)
        p_x_given_y = np.where(y == 1, self.p_x_y_1, self.p_x_y_0)
        x = self.rng.binomial(1, p_x_given_y, size=n)
        return pd.DataFrame({"x": x, "y": y})

    def sample_x(self, n):
        """Draw ``n`` Bernoulli(p_x) values for x only."""
        return self.rng.binomial(1, self.p_x, size=n)


class SharedMarginalXDataSampler:
    """Sample binary (x, y) pairs by drawing x first and then y given x.

    ``sample`` draws x ~ Bern(p_x) and y ~ Bern(p_y_x_1) where x is 1,
    y ~ Bern(p_y_x_0) where x is 0. A generator must be attached with
    ``set_rng`` before sampling.
    """

    def __init__(self, p_x, p_y_x_0, p_y_x_1):
        """Store the generative parameters and the implied marginal ``p_y``.

        Args:
            p_x: success probability used for x.
            p_y_x_0: success probability used for y on rows where x is 0.
            p_y_x_1: success probability used for y on rows where x is 1.
        """
        self.p_x = p_x
        self.p_y_x_0 = p_y_x_0
        self.p_y_x_1 = p_y_x_1
        # Marginal of y implied by the mixture over x.
        self.p_y = p_x * p_y_x_1 + (1 - p_x) * p_y_x_0

    def set_rng(self, rng):
        """Attach the random generator used for sampling and return ``self``."""
        self.rng = rng
        return self

    def sample(self, n):
        """Return a DataFrame with ``n`` rows and columns "x" and "y".

        x is drawn first; y is then drawn in a single vectorized call whose
        per-row success probability depends on the row's x.
        """
        x = self.rng.binomial(1, self.p_x, size=n)
        p_y_given_x = np.where(x == 1, self.p_y_x_1, self.p_y_x_0)
        y = self.rng.binomial(1, p_y_given_x, size=n)
        return pd.DataFrame({"x": x, "y": y})

    def sample_x(self, n):
        """Draw ``n`` Bernoulli(p_x) values for x only."""
        return self.rng.binomial(1, self.p_x, size=n)


class SharedDependencyDataSampler:
    """Sample binary (x, y) pairs with P(y = 1 | x) = alpha + beta * x.

    x is Bernoulli(p_x). The generator has to be supplied through
    ``set_rng`` before any draw is requested.
    """

    def __init__(self, alpha, beta, p_x):
        """Keep the intercept, the slope and the success probability of x."""
        self.alpha, self.beta, self.p_x = alpha, beta, p_x

    def set_rng(self, rng):
        """Register ``rng`` as the source of randomness; returns the sampler."""
        self.rng = rng
        return self

    def sample(self, n):
        """Draw ``n`` rows and return them as a DataFrame with columns "x", "y"."""
        draws_x = self.rng.binomial(1, self.p_x, size=n)
        # Row-wise success probability for y, driven by the realised x.
        success_prob = self.alpha + self.beta * draws_x
        draws_y = self.rng.binomial(1, success_prob, size=n)
        frame = pd.DataFrame({"x": draws_x, "y": draws_y})
        return frame

    def sample_x(self, n):
        """Draw ``n`` values of x alone."""
        draws = self.rng.binomial(1, self.p_x, size=n)
        return draws


class CivilCommentsDataSampler:
    """Sampler backed by the parquet splits under ``data/civil_comments``.

    The train/validation/test splits are concatenated and shuffled with a
    generator seeded by ``seed``. With ``hypothesis == "null"`` the first
    200,000 shuffled rows are kept and ``sample`` draws synthetic (x, y)
    pairs from Bernoulli parameters estimated on them. With any other
    hypothesis the remaining rows are kept and served sequentially, each row
    at most once.
    """

    def __init__(self, hypothesis="null", target_mean=None, seed=None):
        """Load, shuffle, partition and optionally rebalance the data.

        Args:
            hypothesis: "null" selects the first 200,000 shuffled rows and the
                parametric sampling path; anything else selects the rest.
            target_mean: if given, rows are dropped (ones or zeros of "y",
                taken from the front) until the mean of "y" is approximately
                this value.
            seed: seed for the generator that shuffles the rows and, for the
                null hypothesis, draws the samples.
        """
        self.hypothesis = hypothesis
        splits = ["train", "validation", "test"]
        self.data = pd.concat(
            [pd.read_parquet(f"data/civil_comments/{split}") for split in splits]
        ).reset_index(drop=True)
        rng = np.random.default_rng(seed=seed)
        self.rng = rng
        shuffled_indices = rng.permutation(len(self.data))
        self.data = self.data.iloc[shuffled_indices].reset_index(drop=True)

        if hypothesis == "null":
            self.data = self.data.iloc[:200_000].reset_index(drop=True)
        else:
            self.data = self.data.iloc[200_000:].reset_index(drop=True)

        if target_mean is not None:
            current_mean = self.data["y"].mean()
            if target_mean < current_mean:
                # drop ones to decrease mean
                num_zeros = (self.data["y"] == 0).sum()
                num_ones_needed = int(target_mean * num_zeros / (1 - target_mean))
                num_ones_to_drop = (self.data["y"] == 1).sum() - num_ones_needed
                indices_to_remove = self.data[self.data["y"] == 1].index[
                    :num_ones_to_drop
                ]
                self.data.drop(indices_to_remove, inplace=True)
            elif target_mean > current_mean:
                # drop zeros to increase mean
                num_ones = (self.data["y"] == 1).sum()
                num_zeros_needed = int(num_ones * (1 - target_mean) / target_mean)
                num_zeros_to_drop = (self.data["y"] == 0).sum() - num_zeros_needed
                indices_to_remove = self.data[self.data["y"] == 0].index[
                    :num_zeros_to_drop
                ]
                self.data.drop(indices_to_remove, inplace=True)

        # Position of the next unserved row (positional, so the gaps that
        # drop() leaves in the index labels do not matter).
        self.current_index = 0
        self.p_y_given_x_0 = self.data[self.data["x"] == 0]["y"].mean()
        self.p_y_given_x_1 = self.data[self.data["x"] == 1]["y"].mean()
        if hypothesis == "null":
            self.p_x = self.data["x"].mean()
            self.p_y = self.data["y"].mean()
            self.p_x_y_0 = self.data[self.data["y"] == 0]["x"].mean()
            self.p_x_y_1 = self.data[self.data["y"] == 1]["x"].mean()
            p_x_0_y_1 = 1 - self.p_x_y_1
            p_x_0_y_0 = 1 - self.p_x_y_0
            self.tau_0 = p_x_0_y_0 / (p_x_0_y_0 + p_x_0_y_1)
            self.tau_1 = self.p_x_y_0 / (self.p_x_y_0 + self.p_x_y_1)
            self.th_0 = self.p_y_given_x_0
            self.th_1 = self.p_y_given_x_1
            self.T_0 = (self.th_0 * self.tau_0) / (
                (1 - self.th_0) * (1 - self.tau_0) + (self.th_0 * self.tau_0)
            )
            self.T_1 = (self.th_1 * self.tau_1) / (
                (1 - self.th_1) * (1 - self.tau_1) + (self.th_1 * self.tau_1)
            )

    def set_rng(self, rng):
        """Accept and ignore ``rng``; the generator built in ``__init__`` is kept.

        Returns ``self`` so the call can be chained.
        """
        return self

    def _next_rows(self, n):
        """Return the next ``n`` unserved rows and advance the cursor.

        Raises:
            ValueError: if fewer than ``n`` unserved rows remain.
        """
        if self.current_index + n > len(self.data):
            raise ValueError("No more unique rows available to sample.")
        ret = self.data.iloc[self.current_index : self.current_index + n].reset_index(
            drop=True
        )
        self.current_index += n
        return ret

    def sample(self, n):
        """Return ``n`` rows.

        Under the null hypothesis the rows are synthetic: x ~ Bern(p_x) and
        y ~ Bern(p_y_given_x_1 / p_y_given_x_0) depending on x, returned as a
        DataFrame with columns "x" and "y". Otherwise the next ``n`` stored
        rows are returned.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            x = self.rng.binomial(1, self.p_x, size=n)
            p_y_given_x = np.where(x == 1, self.p_y_given_x_1, self.p_y_given_x_0)
            y = self.rng.binomial(1, p_y_given_x, size=n)
            return pd.DataFrame({"x": x, "y": y})
        return self._next_rows(n)

    def sample_x(self, n):
        """Return ``n`` values of x as a numpy array.

        Under the null hypothesis they are Bernoulli(p_x) draws; otherwise
        they are the "x" column of the next ``n`` stored rows, which are
        consumed exactly as ``sample`` would consume them.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            return self.rng.binomial(1, self.p_x, size=n)
        return self._next_rows(n)["x"].to_numpy()

    @classmethod
    def get_samplers(cls, seed=None):
        """Build and return the ``(null, alt)`` pair of samplers.

        Both samplers are constructed with the same seed so that they shuffle
        the data identically; the null sampler then keeps the first 200,000
        rows and the alt sampler the rest of that same permutation. When
        ``seed`` is None a fresh seed is generated once and shared.
        """
        if seed is None:
            seed = int(np.random.SeedSequence().entropy)
        return cls(hypothesis="null", seed=seed), cls(hypothesis="alt", seed=seed)


class MathJudgeLabelShiftDataSampler:
    """Sampler backed by the ``results_xy.parquet`` files of one judged model.

    Files for the gsm8k, aqua_rat and math datasets (train and test splits)
    are concatenated and shuffled with a generator seeded by ``seed``, then
    optionally subsampled toward target means of the "x" and "y" columns.
    Under the "null" hypothesis ``sample`` draws synthetic pairs from
    Bernoulli parameters estimated on the retained rows; otherwise the
    retained rows are served sequentially, each at most once.
    """

    def __init__(
        self,
        model_name,
        hypothesis="null",
        target_mean=None,
        target_mean_x=None,
        seed=None,
        judge="",
    ):
        """Load, shuffle and optionally subsample the data, then fit statistics.

        Args:
            model_name: directory name of the model under the data directory.
            hypothesis: "null" enables the parametric sampling path.
            target_mean: desired mean of "y" after subsampling, or None.
            target_mean_x: desired mean of "x" after subsampling, or None.
            seed: seed for the shuffling / sampling generator.
            judge: optional suffix; the data directory is
                ``data/math_judge_<judge>`` when non-empty, else
                ``data/math_judge``.

        When both targets are given, the largest subset matching both
        (rounded) counts is selected jointly. Otherwise each given target is
        approached independently by dropping rows, "y" first and then "x".
        """
        self.model_name = model_name
        self.hypothesis = hypothesis
        splits = ["train", "test"]
        datasets = ["gsm8k", "aqua_rat", "math"]
        data_frames = []
        data_dir = f"data/math_judge{('_' + judge) if judge else ''}"
        for dataset in datasets:
            for split in splits:
                df = pd.read_parquet(
                    f"{data_dir}/{model_name}/{dataset}/{split}/results_xy.parquet"
                )
                data_frames.append(df)
        self.data = pd.concat(data_frames).reset_index(drop=True)
        rng = np.random.default_rng(seed=seed)
        self.rng = rng
        shuffled_indices = rng.permutation(len(self.data))
        self.data = self.data.iloc[shuffled_indices].reset_index(drop=True)

        if target_mean is not None and target_mean_x is not None:
            self.data = self._select_joint(self.data, target_mean, target_mean_x)
        else:
            if target_mean is not None:
                self._drop_to_target(self.data, "y", target_mean)
            if target_mean_x is not None:
                self._drop_to_target(self.data, "x", target_mean_x)

        # Position of the next unserved row (positional, so the gaps that
        # drop() leaves in the index labels do not matter).
        self.current_index = 0
        self.p_y_x_0 = self.data[self.data["x"] == 0]["y"].mean()
        self.p_y_x_1 = self.data[self.data["x"] == 1]["y"].mean()
        self.p_x_y_0 = self.data[self.data["y"] == 0]["x"].mean()
        self.p_x_y_1 = self.data[self.data["y"] == 1]["x"].mean()
        if hypothesis == "null":
            self.p_x = self.data["x"].mean()
            self.p_y = self.data["y"].mean()
            p_x_0_y_1 = 1 - self.p_x_y_1
            p_x_0_y_0 = 1 - self.p_x_y_0
            self.tau_0 = p_x_0_y_0 / (p_x_0_y_0 + p_x_0_y_1)
            self.tau_1 = self.p_x_y_0 / (self.p_x_y_0 + self.p_x_y_1)
            self.th_0 = self.p_y_x_0
            self.th_1 = self.p_y_x_1
            self.T_0 = (self.th_0 * self.tau_0) / (
                (1 - self.th_0) * (1 - self.tau_0) + (self.th_0 * self.tau_0)
            )
            self.T_1 = (self.th_1 * self.tau_1) / (
                (1 - self.th_1) * (1 - self.tau_1) + (self.th_1 * self.tau_1)
            )

    @staticmethod
    def _drop_to_target(data, column, target):
        """Drop rows of ``data`` in place so the mean of a 0/1 ``column`` nears ``target``.

        Ones are dropped when the target is below the current mean, zeros
        when it is above; rows are taken from the front of the frame. Nothing
        happens when the target equals the current mean.
        """
        current = data[column].mean()
        if target < current:
            # drop ones to decrease mean
            drop_value = 1
            num_needed = int(target * (data[column] == 0).sum() / (1 - target))
        elif target > current:
            # drop zeros to increase mean
            drop_value = 0
            num_needed = int((data[column] == 1).sum() * (1 - target) / target)
        else:
            return
        num_to_drop = (data[column] == drop_value).sum() - num_needed
        indices_to_remove = data[data[column] == drop_value].index[:num_to_drop]
        data.drop(indices_to_remove, inplace=True)

    @staticmethod
    def _select_joint(data, target_mean, target_mean_x):
        """Return the largest subset of ``data`` matching both target means.

        For each candidate size K (largest first) the required numbers of
        y = 1 rows and x = 1 rows are ``round(target * K)``; K is accepted
        when some count of (x=1, y=1) rows makes all four (x, y) cell counts
        available. The smallest such count is used. If no K is feasible the
        result is empty.
        """
        mask00 = (data["x"] == 0) & (data["y"] == 0)
        mask01 = (data["x"] == 0) & (data["y"] == 1)
        mask10 = (data["x"] == 1) & (data["y"] == 0)
        mask11 = (data["x"] == 1) & (data["y"] == 1)
        n00 = mask00.sum()
        n01 = mask01.sum()
        n10 = mask10.sum()
        n11 = mask11.sum()

        best_counts = (0, 0, 0, 0)
        for K in range(len(data), 0, -1):
            Ny = int(round(target_mean * K))
            Nx = int(round(target_mean_x * K))

            # Feasible range for the (x=1, y=1) count given the other cells.
            z_min = max(0, Nx - n10, Ny - n01, Nx + Ny - K)
            z_max = min(n11, Nx, Ny, n00 + Nx + Ny - K)

            if z_min <= z_max:
                k11 = int(z_min)
                k10 = Nx - k11
                k01 = Ny - k11
                k00 = K - Nx - Ny + k11
                best_counts = (k00, k01, k10, k11)
                break

        k00, k01, k10, k11 = best_counts
        idx00 = data[mask00].index[:k00]
        idx01 = data[mask01].index[:k01]
        idx10 = data[mask10].index[:k10]
        idx11 = data[mask11].index[:k11]

        indices_to_keep = np.concatenate([idx00, idx01, idx10, idx11])
        return data.loc[indices_to_keep].reset_index(drop=True)

    def set_rng(self, rng):
        """Accept and ignore ``rng``; the generator built in ``__init__`` is kept.

        Returns ``self`` so the call can be chained.
        """
        return self

    def _next_rows(self, n):
        """Return the next ``n`` unserved rows and advance the cursor.

        Raises:
            ValueError: if fewer than ``n`` unserved rows remain.
        """
        if self.current_index + n > len(self.data):
            raise ValueError("No more unique rows available to sample.")
        ret = self.data.iloc[self.current_index : self.current_index + n].reset_index(
            drop=True
        )
        self.current_index += n
        return ret

    def sample(self, n):
        """Return ``n`` rows.

        Under the null hypothesis the rows are synthetic: x ~ Bern(p_x) and
        y ~ Bern(p_y_x_1 / p_y_x_0) depending on x, returned as a DataFrame
        with columns "x" and "y". Otherwise the next ``n`` stored rows are
        returned.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            x = self.rng.binomial(1, self.p_x, size=n)
            p_y_given_x = np.where(x == 1, self.p_y_x_1, self.p_y_x_0)
            y = self.rng.binomial(1, p_y_given_x, size=n)
            return pd.DataFrame({"x": x, "y": y})
        return self._next_rows(n)

    def sample_x(self, n):
        """Return ``n`` values of x as a numpy array.

        Under the null hypothesis they are Bernoulli(p_x) draws; otherwise
        they are the "x" column of the next ``n`` stored rows, which are
        consumed exactly as ``sample`` would consume them.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            return self.rng.binomial(1, self.p_x, size=n)
        return self._next_rows(n)["x"].to_numpy()


class MathJudgeConceptShiftDataSampler:
    """Sampler over judged math results with control of P(y | x) and the mean of y.

    ``results_xy.parquet`` files for the gsm8k, aqua_rat and math datasets
    (train and test splits) of ``model_name`` are concatenated and shuffled.
    The rows are then either restricted to explicit ``indices``, filtered
    jointly against a second model's labels (``alt_model_name``), or filtered
    on this model's labels alone. Under the "null" hypothesis ``sample``
    draws synthetic pairs from Bernoulli parameters estimated on the
    retained rows; otherwise the retained rows are served sequentially.
    """

    def __init__(
        self,
        model_name,
        hypothesis="null",
        target_mean=None,
        target_p_y_x_1=None,
        target_p_y_x_0=None,
        seed=None,
        judge="",
        indices=None,
        alt_model_name=None,
        target_mean_alt=None,
        target_p_y_x_1_alt=None,
        target_p_y_x_0_alt=None,
    ):
        """Load, shuffle and filter the data, then compute summary statistics.

        Args:
            model_name: directory name of the model under the data directory.
            hypothesis: "null" enables the parametric sampling path.
            target_mean: desired mean of "y", or None.
            target_p_y_x_1: desired mean of "y" among rows with x == 1, or None.
            target_p_y_x_0: desired mean of "y" among rows with x == 0, or None.
            seed: seed for the shuffling / sampling generator.
            judge: optional suffix; the data directory is
                ``data/math_judge_<judge>`` when non-empty, else
                ``data/math_judge``.
            indices: if given, exactly these labels of the shuffled frame are
                kept and no other filtering is applied.
            alt_model_name: if given, that model's "y" column is attached as
                "y_alt" (row alignment between the two models is assumed)
                and the joint filtering path is used.
            target_mean_alt: forwarded to ``_apply_joint_constraints``.
            target_p_y_x_1_alt: desired mean of "y_alt" among rows with x == 1.
            target_p_y_x_0_alt: desired mean of "y_alt" among rows with x == 0.
        """
        self.model_name = model_name
        self.hypothesis = hypothesis
        splits = ["train", "test"]
        datasets = ["gsm8k", "aqua_rat", "math"]

        # Helper to load data
        def load_data(model):
            data_frames = []
            data_dir = f"data/math_judge{('_' + judge) if judge else ''}"
            for dataset in datasets:
                for split in splits:
                    df = pd.read_parquet(
                        f"{data_dir}/{model}/{dataset}/{split}/results_xy.parquet"
                    )
                    data_frames.append(df)
            return pd.concat(data_frames).reset_index(drop=True)

        self.data = load_data(model_name)

        if alt_model_name:
            self.alt_data_raw = load_data(alt_model_name)
            # Assume alignment and merge y_alt
            self.data["y_alt"] = self.alt_data_raw["y"]
            # x values for alt are taken from null (self.data["x"]) as per requirement

        rng = np.random.default_rng(seed=seed)
        self.rng = rng

        # Shuffle consistently
        shuffled_indices = rng.permutation(len(self.data))
        self.data = self.data.iloc[shuffled_indices].reset_index(drop=True)

        if indices is not None:
            self.data = self.data.loc[indices].reset_index(drop=True)
            self.kept_indices = indices
        elif alt_model_name:
            # Joint filtering logic
            self._apply_joint_constraints(
                target_mean,
                target_p_y_x_1,
                target_p_y_x_0,
                target_mean_alt,
                target_p_y_x_1_alt,
                target_p_y_x_0_alt,
            )
            # The joint selection groups rows by cell, so shuffle again.
            shuffled_indices = rng.permutation(len(self.data))
            self.data = self.data.iloc[shuffled_indices].reset_index(drop=True)
        else:
            # Original single-model filtering logic
            self._apply_single_constraints(target_mean, target_p_y_x_1, target_p_y_x_0)

        self.current_index = 0
        self._update_stats()

    def _update_stats(self):
        """Recompute the summary statistics from ``self.data``.

        ``p_y_x_0``, ``p_y_x_1``, ``p_x`` and ``p_y`` are always set. The
        remaining attributes (``p_x_y_*``, ``tau_*``, ``th_*``, ``T_*``) are
        set only under the null hypothesis; ratios with a non-positive
        denominator are replaced by 0.
        """
        self.p_y_x_0 = self.data[self.data["x"] == 0]["y"].mean()
        self.p_y_x_1 = self.data[self.data["x"] == 1]["y"].mean()
        self.p_x = self.data["x"].mean()
        self.p_y = self.data["y"].mean()

        if self.hypothesis == "null":
            self.p_x_y_0 = self.data[self.data["y"] == 0]["x"].mean()
            self.p_x_y_1 = self.data[self.data["y"] == 1]["x"].mean()
            p_x_0_y_1 = 1 - self.p_x_y_1
            p_x_0_y_0 = 1 - self.p_x_y_0
            self.tau_0 = (
                p_x_0_y_0 / (p_x_0_y_0 + p_x_0_y_1)
                if (p_x_0_y_0 + p_x_0_y_1) > 0
                else 0
            )
            self.tau_1 = (
                self.p_x_y_0 / (self.p_x_y_0 + self.p_x_y_1)
                if (self.p_x_y_0 + self.p_x_y_1) > 0
                else 0
            )
            self.th_0 = self.p_y_x_0
            self.th_1 = self.p_y_x_1

            denom0 = (1 - self.th_0) * (1 - self.tau_0) + (self.th_0 * self.tau_0)
            self.T_0 = (self.th_0 * self.tau_0) / denom0 if denom0 > 0 else 0

            denom1 = (1 - self.th_1) * (1 - self.tau_1) + (self.th_1 * self.tau_1)
            self.T_1 = (self.th_1 * self.tau_1) / denom1 if denom1 > 0 else 0

    def _apply_single_constraints(self, target_mean, target_p_y_x_1, target_p_y_x_0):
        """Filter ``self.data`` in place using this model's labels only.

        Step 1 moves the mean of "y" within each x group toward its target by
        dropping ones or zeros from the front of the group. Step 2 moves the
        overall mean of "y" toward ``target_mean`` by trimming one x group
        from its tail so the x = 1 share matches the value implied by the two
        group means. Afterwards ``kept_indices`` holds the surviving labels
        of the shuffled frame and the frame's index is reset.
        """
        # 1. Adjust Conditional Probabilities
        for x_val, target_p in [(1, target_p_y_x_1), (0, target_p_y_x_0)]:
            if target_p is not None:
                mask = self.data["x"] == x_val
                df_group = self.data[mask]
                current_p = df_group["y"].mean()

                if target_p < current_p:
                    # Drop ones
                    n_zeros = (df_group["y"] == 0).sum()
                    if abs(1 - target_p) < 1e-9:
                        n_ones_needed = len(
                            df_group
                        )  # Should not happen if target < current <= 1
                    else:
                        n_ones_needed = int(target_p * n_zeros / (1 - target_p))

                    n_ones_current = (df_group["y"] == 1).sum()
                    n_drop = n_ones_current - n_ones_needed
                    if n_drop > 0:
                        drop_idx = df_group[df_group["y"] == 1].index[:n_drop]
                        self.data.drop(drop_idx, inplace=True)

                elif target_p > current_p:
                    # Drop zeros
                    n_ones = (df_group["y"] == 1).sum()
                    if abs(target_p) < 1e-9:
                        n_zeros_needed = len(df_group)
                    else:
                        n_zeros_needed = int(n_ones * (1 - target_p) / target_p)

                    n_zeros_current = (df_group["y"] == 0).sum()
                    n_drop = n_zeros_current - n_zeros_needed
                    if n_drop > 0:
                        drop_idx = df_group[df_group["y"] == 0].index[:n_drop]
                        self.data.drop(drop_idx, inplace=True)

        # 2. Adjust Marginal Mean (by balancing X=0 and X=1 groups)
        if target_mean is not None:
            # Calculate current stats after conditional adjustments
            mask_x1 = self.data["x"] == 1
            n1 = mask_x1.sum()
            p1 = self.data.loc[mask_x1, "y"].mean() if n1 > 0 else 0

            mask_x0 = self.data["x"] == 0
            n0 = mask_x0.sum()
            p0 = self.data.loc[mask_x0, "y"].mean() if n0 > 0 else 0

            # target_mean = p_x * p1 + (1-p_x) * p0
            # p_x = (target_mean - p0) / (p1 - p0)
            if abs(p1 - p0) > 1e-9:
                target_px = (target_mean - p0) / (p1 - p0)
                if 0 <= target_px <= 1:
                    if target_px >= 1:
                        # The x = 1 share must be 1: keep every x = 1 row and
                        # no x = 0 row. Handled separately because the ratio
                        # below would be infinite and int(inf) raises.
                        n1_keep = n1
                        n0_keep = 0
                    else:
                        # We need n1' / (n1' + n0') = target_px, i.e.
                        # n1' / n0' = target_px / (1 - target_px) = R
                        R = target_px / (1 - target_px)

                        # Try keeping all n0, calculate needed n1
                        n1_needed = int(n0 * R)
                        if n1_needed <= n1:
                            # Reduce n1
                            n1_keep = n1_needed
                            n0_keep = n0
                        else:
                            # Reduce n0
                            n0_needed = int(n1 / R) if R > 0 else 0
                            n1_keep = n1
                            n0_keep = n0_needed

                    # Apply drops
                    if n1 > n1_keep:
                        drop_idx = self.data[self.data["x"] == 1].index[n1_keep:]
                        self.data.drop(drop_idx, inplace=True)
                    if n0 > n0_keep:
                        drop_idx = self.data[self.data["x"] == 0].index[n0_keep:]
                        self.data.drop(drop_idx, inplace=True)
            else:
                # p1 == p0. If target_mean != p1, impossible to achieve by rebalancing X.
                # We assume conditional adjustments already handled it or it's impossible.
                pass

        self.kept_indices = self.data.index
        self.data = self.data.reset_index(drop=True)

    def _apply_joint_constraints(
        self,
        target_mean_null,
        target_p1_null,
        target_p0_null,
        target_mean_alt,
        target_p1_alt,
        target_p0_alt,
    ):
        """Filter ``self.data`` so "y" and "y_alt" both meet per-group targets.

        Within each x group the largest subset is found whose (rounded)
        counts of y = 1 and y_alt = 1 match the group's targets; the two
        groups are then scaled down so the x = 1 share matches the value
        implied by the null targets. Missing conditional targets default to
        the current group means. ``target_mean_alt`` is accepted but not used
        by this method. Afterwards ``kept_indices`` holds the selected labels
        of the shuffled frame and the frame's index is reset.
        """
        # Defaults
        if target_p1_null is None:
            target_p1_null = self.data[self.data["x"] == 1]["y"].mean()
        if target_p0_null is None:
            target_p0_null = self.data[self.data["x"] == 0]["y"].mean()
        if target_p1_alt is None:
            target_p1_alt = self.data[self.data["x"] == 1]["y_alt"].mean()
        if target_p0_alt is None:
            target_p0_alt = self.data[self.data["x"] == 0]["y_alt"].mean()

        # Determine target p_x from Null targets
        # p_y = p_x * p1 + (1-p_x) * p0
        if target_mean_null is not None and abs(target_p1_null - target_p0_null) > 1e-9:
            target_px = (target_mean_null - target_p0_null) / (
                target_p1_null - target_p0_null
            )
            target_px = max(0.0, min(1.0, target_px))
        else:
            target_px = self.data["x"].mean()

        # Solve for X=0 and X=1 separately
        def solve_group(x_val, p_null_target, p_alt_target):
            mask = self.data["x"] == x_val
            df = self.data[mask]
            n00 = ((df["y"] == 0) & (df["y_alt"] == 0)).sum()
            n01 = ((df["y"] == 0) & (df["y_alt"] == 1)).sum()
            n10 = ((df["y"] == 1) & (df["y_alt"] == 0)).sum()
            n11 = ((df["y"] == 1) & (df["y_alt"] == 1)).sum()

            # Linear scan from the largest group size downward; the first
            # feasible K is the maximum.
            best_K = 0
            best_counts = (0, 0, 0, 0)

            for K in range(len(df), 0, -1):
                # Target counts of ones
                Cn = int(round(p_null_target * K))
                Ca = int(round(p_alt_target * K))

                # Constraints on k11 (z)
                # k10 = Cn - z
                # k01 = Ca - z
                # k00 = K - Ca - Cn + z

                z_min = max(0, Cn - n10, Ca - n01, Ca + Cn - K)
                z_max = min(n11, Cn, Ca, n00 + Ca + Cn - K)

                if z_min <= z_max:
                    best_K = K
                    z = int(z_min)  # Pick smallest valid k11
                    k11 = z
                    k10 = Cn - k11
                    k01 = Ca - k11
                    k00 = K - Ca - Cn + k11
                    best_counts = (k00, k01, k10, k11)
                    break
            return best_K, best_counts

        K0, counts0 = solve_group(0, target_p0_null, target_p0_alt)
        K1, counts1 = solve_group(1, target_p1_null, target_p1_alt)

        # Adjust K0, K1 to satisfy target_px
        # K1 / (K0 + K1) = target_px
        if target_px > 0 and target_px < 1:
            R = target_px / (1 - target_px)
            if K1 > K0 * R:
                K1 = int(K0 * R)
            else:
                K0 = int(K1 / R)
        elif target_px == 0:
            K1 = 0
        elif target_px == 1:
            K0 = 0

        # Scale down counts
        def scale_counts(counts, target_K):
            current_K = sum(counts)
            if current_K == 0:
                return (0, 0, 0, 0)
            factor = target_K / current_K
            return tuple(int(c * factor) for c in counts)

        final_counts0 = scale_counts(counts0, K0)
        final_counts1 = scale_counts(counts1, K1)

        # Select indices
        indices_to_keep = []

        def collect_indices(x_val, counts):
            k00, k01, k10, k11 = counts
            mask = self.data["x"] == x_val
            df = self.data[mask]

            idx00 = df[(df["y"] == 0) & (df["y_alt"] == 0)].index[:k00]
            idx01 = df[(df["y"] == 0) & (df["y_alt"] == 1)].index[:k01]
            idx10 = df[(df["y"] == 1) & (df["y_alt"] == 0)].index[:k10]
            idx11 = df[(df["y"] == 1) & (df["y_alt"] == 1)].index[:k11]

            return np.concatenate([idx00, idx01, idx10, idx11])

        indices_to_keep.extend(collect_indices(0, final_counts0))
        indices_to_keep.extend(collect_indices(1, final_counts1))

        self.data = self.data.loc[indices_to_keep].reset_index(drop=True)
        self.kept_indices = indices_to_keep  # These are indices into the shuffled data

    def get_alt_sampler(self):
        """Return an "alt" sampler serving the same rows with "y_alt" as "y".

        The new sampler is built without running ``__init__`` so no files are
        read again; it gets a copy of this sampler's rows (with "y" replaced
        by "y_alt" and the "y_alt" column removed), the same
        ``kept_indices``, a fresh unseeded generator and a cursor at 0.

        Raises:
            ValueError: if this sampler has no "y_alt" column.
        """
        if "y_alt" not in self.data.columns:
            raise ValueError("Alt data not loaded. Provide alt_model_name in init.")

        alt_df = self.data.copy()
        alt_df["y"] = alt_df["y_alt"]
        alt_df = alt_df.drop(columns=["y_alt"])

        sampler_cls = type(self)
        alt_sampler = sampler_cls.__new__(sampler_cls)
        alt_sampler.model_name = self.model_name
        alt_sampler.hypothesis = "alt"
        alt_sampler.rng = np.random.default_rng(seed=None)
        alt_sampler.data = alt_df
        alt_sampler.kept_indices = self.kept_indices
        alt_sampler.current_index = 0
        alt_sampler._update_stats()

        return alt_sampler

    def set_rng(self, rng):
        """Accept and ignore ``rng``; the sampler's own generator is kept.

        Returns ``self`` so the call can be chained.
        """
        return self

    def sample(self, n):
        """Return ``n`` rows.

        Under the null hypothesis the rows are synthetic: x ~ Bern(p_x) and
        y ~ Bern(p_y_x_1 / p_y_x_0) depending on x, returned as a DataFrame
        with columns "x" and "y". Otherwise the next ``n`` stored rows are
        returned and the cursor advances.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            x = self.rng.binomial(1, self.p_x, size=n)
            p_y_given_x = np.where(x == 1, self.p_y_x_1, self.p_y_x_0)
            y = self.rng.binomial(1, p_y_given_x, size=n)
            return pd.DataFrame({"x": x, "y": y})
        if self.current_index + n > len(self.data):
            raise ValueError("No more unique rows available to sample.")
        ret = self.data.iloc[self.current_index : self.current_index + n].reset_index(
            drop=True
        )
        self.current_index += n
        return ret

    def sample_x(self, n):
        """Return ``n`` values of x as a numpy array.

        Under the null hypothesis they are Bernoulli(p_x) draws; otherwise
        they are the "x" column of the next ``n`` stored rows, which are
        consumed exactly as ``sample`` would consume them.

        Raises:
            ValueError: if the stored rows are exhausted (non-null only).
        """
        if self.hypothesis == "null":
            return self.rng.binomial(1, self.p_x, size=n)
        if self.current_index + n > len(self.data):
            raise ValueError("No more unique rows available to sample.")
        ret = self.data.iloc[self.current_index : self.current_index + n].reset_index(
            drop=True
        )
        self.current_index += n
        return ret["x"].to_numpy()


class MathSharedDependencyDataSampler:
    """Sample (x, y) pairs from stored math results, with the dataset as ``x``.

    On construction the result files of ``model_name`` are loaded for every
    dataset/split, tagged with ``x`` (0 for gsm8k, 1 for aqua_rat), shuffled,
    and subsampled so that the empirical ``P(y=1 | x)`` and, optionally, the
    overall mean of ``y`` match the requested targets as closely as the
    available rows allow. Only rows with ``x`` in {0, 1} and ``y`` in {0, 1}
    survive the subsampling.

    Under ``hypothesis == "null"`` draws are synthetic Bernoulli samples based
    on the fitted ``p_x``, ``p_y_x_0`` and ``p_y_x_1``; under any other
    hypothesis draws consume the stored rows sequentially, without reuse.
    """

    # Slack added before flooring a float count. Without it, products such as
    # 100 * 0.29 (28.999999999999996 in binary floating point) would truncate
    # to 28 and the requested proportion would be undershot by one row.
    _COUNT_TOL = 1e-9

    # Dataset name -> integer code stored in the ``x`` column.
    _DATASET_CODES = {"gsm8k": 0, "aqua_rat": 1, "math": 2}
    # Datasets actually loaded ("math" is currently left out).
    _DATASETS = ("gsm8k", "aqua_rat")
    _SPLITS = ("train", "test")

    def __init__(
        self,
        model_name,
        seed,
        hypothesis="null",
        target_mean=None,
        target_p_y_x_0=None,
        target_p_y_x_1=None,
    ):
        """Load, shuffle and subsample the results for ``model_name``.

        Args:
            model_name: Directory name under ``data/math/`` holding the results.
            seed: Seed for the sampler's own ``numpy`` generator.
            hypothesis: ``"null"`` for synthetic draws; any other value makes
                ``sample``/``sample_x`` read the stored rows.
            target_mean: Desired overall mean of ``y``; reached by changing the
                ratio of x=1 to x=0 rows. Ignored when both conditional
                probabilities coincide.
            target_p_y_x_0: Desired ``P(y=1 | x=0)``; defaults to the empirical
                value in the loaded data.
            target_p_y_x_1: Desired ``P(y=1 | x=1)``; defaults to the empirical
                value in the loaded data.

        Raises:
            ValueError: If a target conditional probability is outside [0, 1].
        """
        self.model_name = model_name
        self.hypothesis = hypothesis
        self.data = self._load_results(model_name)
        self.rng = np.random.default_rng(seed=seed)
        # Shuffle before subsampling so the rows kept per cell are a random
        # subset rather than the first rows of the files.
        self._shuffle_rows()
        self._fit_to_targets(target_mean, target_p_y_x_0, target_p_y_x_1)

    @classmethod
    def _load_results(cls, model_name):
        """Read every dataset/split parquet file and tag rows with their ``x`` code."""
        frames = []
        for dataset in cls._DATASETS:
            for split in cls._SPLITS:
                frame = pd.read_parquet(
                    f"data/math/{model_name}/{dataset}/{split}/results.parquet"
                )
                frame["x"] = cls._DATASET_CODES[dataset]
                frames.append(frame)
        return pd.concat(frames).reset_index(drop=True)

    @classmethod
    def _floor_count(cls, value):
        """Floor a non-negative float to an int, absorbing tiny float error."""
        return int(np.floor(value + cls._COUNT_TOL))

    @classmethod
    def _max_group_size(cls, n_neg, n_pos, p):
        """Largest group size N with N*p <= n_pos and N*(1-p) <= n_neg."""
        if p == 0:
            return n_neg
        if p == 1:
            return n_pos
        return min(cls._floor_count(n_pos / p), cls._floor_count(n_neg / (1 - p)))

    def _shuffle_rows(self):
        """Permute the rows of ``self.data`` with the sampler's own generator."""
        order = self.rng.permutation(len(self.data))
        self.data = self.data.iloc[order].reset_index(drop=True)

    def _fit_to_targets(self, target_mean, target_p_y_x_0, target_p_y_x_1):
        """Subsample ``self.data`` to the targets, set the statistics, reshuffle."""
        x = self.data["x"]
        y = self.data["y"]
        # Row labels of each (x, y) cell, in the current (shuffled) order.
        cells = {
            (xv, yv): self.data.index[((x == xv) & (y == yv)).to_numpy()]
            for xv in (0, 1)
            for yv in (0, 1)
        }
        a00, a01 = len(cells[(0, 0)]), len(cells[(0, 1)])
        a10, a11 = len(cells[(1, 0)]), len(cells[(1, 1)])

        p0 = target_p_y_x_0
        if p0 is None:
            p0 = a01 / (a00 + a01) if (a00 + a01) > 0 else 0
        p1 = target_p_y_x_1
        if p1 is None:
            p1 = a11 / (a10 + a11) if (a10 + a11) > 0 else 0

        # A probability outside [0, 1] would yield negative counts, which slice
        # rows from the wrong end instead of failing; reject it up front.
        for label, prob in (("target_p_y_x_0", p0), ("target_p_y_x_1", p1)):
            if not 0 <= prob <= 1:
                raise ValueError(
                    f"{label} must be a probability in [0, 1], got {prob!r}."
                )

        n0_max = self._max_group_size(a00, a01, p0)
        n1_max = self._max_group_size(a10, a11, p1)
        n0, n1 = n0_max, n1_max

        # If p1 == p0 the mean cannot be moved through P(x); keep all rows.
        if target_mean is not None and abs(p1 - p0) > 1e-9:
            # mean = p0 + P(x=1) * (p1 - p0), solved for P(x=1), clipped to [0, 1].
            target_px = (target_mean - p0) / (p1 - p0)
            target_px = max(0.0, min(1.0, target_px))
            if target_px == 0:
                n1 = 0
            elif target_px == 1:
                n0 = 0
            else:
                # Want n1 = ratio * n0 subject to n0 <= n0_max, n1 <= n1_max.
                ratio = target_px / (1 - target_px)
                n0 = min(n0_max, self._floor_count(n1_max / ratio))
                n1 = min(n1_max, self._floor_count(n0 * ratio))

        # Rows to keep per cell.
        k01 = self._floor_count(n0 * p0)
        k00 = n0 - k01
        k11 = self._floor_count(n1 * p1)
        k10 = n1 - k11

        keep = np.concatenate(
            [
                cells[(0, 0)][:k00],
                cells[(0, 1)][:k01],
                cells[(1, 0)][:k10],
                cells[(1, 1)][:k11],
            ]
        )
        self.data = self.data.loc[keep].reset_index(drop=True)

        self.current_index = 0
        # Use the realised conditional means; fall back to the target when a
        # group is empty and has no mean.
        self.p_y_x_0 = self.data.loc[self.data["x"] == 0, "y"].mean() if n0 > 0 else p0
        self.p_y_x_1 = self.data.loc[self.data["x"] == 1, "y"].mean() if n1 > 0 else p1
        if self.hypothesis == "null":
            self.p_x = self.data["x"].mean()
            self.p_y = self.data["y"].mean()

        # The concatenation above grouped rows by cell; mix them again.
        self._shuffle_rows()

    def _next_rows(self, n):
        """Return the next ``n`` unread rows (index reset) and advance the cursor."""
        stop = self.current_index + n
        if stop > len(self.data):
            raise ValueError("No more unique rows available to sample.")
        rows = self.data.iloc[self.current_index : stop].reset_index(drop=True)
        self.current_index = stop
        return rows

    def set_rng(self, rng):
        """Accept a generator without using it, and return ``self``.

        The sampler keeps the generator seeded in ``__init__`` and its current
        row order; ``rng`` is ignored.
        """
        return self

    def sample(self, n):
        """Return ``n`` observations as a DataFrame.

        Under the null hypothesis, ``x`` ~ Bernoulli(``p_x``) and ``y`` ~
        Bernoulli(``p_y_x_1`` if x == 1 else ``p_y_x_0``), returned with
        columns ``x`` and ``y``. Otherwise the next ``n`` stored rows are
        returned and consumed.

        Raises:
            ValueError: If fewer than ``n`` unread rows remain (non-null case).
        """
        if self.hypothesis == "null":
            x = self.rng.binomial(1, self.p_x, size=n)
            p_y_given_x = np.where(x == 1, self.p_y_x_1, self.p_y_x_0)
            y = self.rng.binomial(1, p_y_given_x, size=n)
            return pd.DataFrame({"x": x, "y": y})
        return self._next_rows(n)

    def sample_x(self, n):
        """Return ``n`` values of ``x`` as a NumPy array.

        Under the null hypothesis these are Bernoulli(``p_x``) draws; otherwise
        they are the ``x`` column of the next ``n`` stored rows, which are
        consumed exactly as in ``sample``.

        Raises:
            ValueError: If fewer than ``n`` unread rows remain (non-null case).
        """
        if self.hypothesis == "null":
            return self.rng.binomial(1, self.p_x, size=n)
        return self._next_rows(n)["x"].to_numpy()


# %%

# rng_np.binomial(1, 0.5, size=150)
# sampler.sample_x(150)

# %%
