import numpy as np
from bt_state import BTState
from bt_utils import build_pair_arrays, get_pairs, get_boundary_pairs
from bt_kl import kl_projection_bt, kl_divergence_bt_vec, ETA_CLIP


def sample_pair(P, i, j, rng):
    """Draw a single 0/1 outcome for the pair ``(i, j)``.

    Args:
        P: Object indexable as ``P[i, j]``; that entry is used as the
            probability of returning ``1.0``.
        i: First index into ``P``.
        j: Second index into ``P``.
        rng: Random source exposing ``random()``; exactly one draw is
            consumed per call.

    Returns:
        ``1.0`` when the draw is strictly below ``P[i, j]``, otherwise
        ``0.0``.
    """
    draw = rng.random()
    if draw < P[i, j]:
        return 1.0
    return 0.0


class AdaptiveSampler:
    """Pair sampler whose sampling weights are adapted online.

    While ``state.warmup_complete`` is false, ``step`` only services the
    warmup schedule provided by ``BTState``.  Afterwards every ``step``
    mixes the current weights ``w`` with the uniform distribution, adds the
    mixture to the cumulative target ``P_track`` and samples the pair whose
    target exceeds its draw count ``N_track`` by the largest amount.  When
    ``state.mle_exists()`` is true the weights are refreshed by
    ``_update_weights``.

    Args:
        n: Problem size; forwarded to ``BTState`` and ``kl_projection_bt``.
        k: Forwarded to ``get_boundary_pairs`` together with the MLE.
        gamma: Exponent of the uniform-mixing rate ``t ** (-gamma)``.
        eta_exp_w: Decay exponent of the step size used for ``w``.
        eta_exp_q: Decay exponent of the step size used for ``log_r``.
        eta_const_w: Multiplicative constant of the step size for ``w``.
        eta_const_q: Multiplicative constant of the step size for ``log_r``.
        min_round_robins: Forwarded to ``BTState``.
    """

    def __init__(self, n, k, gamma=0.33, eta_exp_w=0.2, eta_exp_q=0.2,
                 eta_const_w=1.0, eta_const_q=1.0, min_round_robins=1):
        self.n = n
        self.k = k
        self.gamma = gamma
        self.eta_exp_w = eta_exp_w
        self.eta_exp_q = eta_exp_q
        self.eta_const_w = eta_const_w
        self.eta_const_q = eta_const_q

        # Pair bookkeeping is owned by the state object; keep local aliases.
        self.state = BTState(n, min_round_robins=min_round_robins)
        self.pair_i = self.state.pair_i
        self.pair_j = self.state.pair_j
        self.pair_to_idx = self.state.pair_to_idx
        self.num_pairs = self.state.num_pairs

        # Sampling weights start uniform over all pairs.
        self.w = np.ones(self.num_pairs) / self.num_pairs
        # Cumulative per-pair terms driving the update of ``w``.
        self.S_w = np.zeros(self.num_pairs)
        # Cumulative per-pair terms driving the update of ``log_r``.
        self.S_q = np.zeros(self.num_pairs)
        # Unnormalised log-weights from which ``q`` is formed.
        self.log_r = np.zeros(self.num_pairs)

        # Cumulative target mass and realised draw count for each pair.
        self.P_track = np.zeros(self.num_pairs)
        self.N_track = np.zeros(self.num_pairs)

        # Number of post-warmup steps and number of weight updates applied.
        self.t = 0
        self.num_updates = 0

    def step(self, P, rng):
        """Perform one sampling round.

        Args:
            P: Matrix indexable as ``P[i, j]``, passed to ``sample_pair``.
            rng: Random generator; must expose ``random()`` and, once
                weight updates run, ``integers()``.
        """
        # Warmup: observe the pair requested by the state, if there is one.
        if not self.state.warmup_complete:
            pair = self.state.warmup_next_pair()
            if pair is not None:
                i, j = pair
                y = sample_pair(P, i, j, rng)
                self.state.warmup_observe(i, j, y)
            return

        self.t += 1

        # Tracking with mixing: blend ``w`` with the uniform distribution
        # at rate ``rho`` (equal to 1 on the first step, decaying after).
        rho = self.t ** (-self.gamma)
        w_tilde = (1 - rho) * self.w + rho / self.num_pairs
        self.P_track += w_tilde

        # Select the pair that lags furthest behind its cumulative target.
        pair_idx = np.argmax(self.P_track - self.N_track)
        i, j = int(self.pair_i[pair_idx]), int(self.pair_j[pair_idx])

        # Sample and observe.
        y = sample_pair(P, i, j, rng)
        self.state.add_observation(i, j, y)
        self.N_track[pair_idx] += 1

        # Primal-dual update, only once the state reports an MLE.
        if self.state.mle_exists():
            self._update_weights(rng)

    def _update_weights(self, rng):
        """Refresh ``w`` and ``log_r`` from one randomly drawn boundary entry.

        The MLE from the state is passed to ``get_boundary_pairs``; each
        returned entry is unpacked as ``(bi, bj, bu, bv)``.  One entry is
        drawn uniformly, ``kl_projection_bt`` is evaluated for it, and the
        resulting terms are accumulated into ``S_w`` and ``S_q`` scaled by
        the number of entries ``m``.  If there are no boundary entries the
        call returns without changing any state.

        Args:
            rng: Random generator exposing ``integers()``.
        """
        theta_hat = self.state.get_mle()
        boundary = get_boundary_pairs(theta_hat, self.k)
        m = len(boundary)
        if m == 0:
            # Nothing to draw from: ``max()`` on an empty array and
            # ``rng.integers(0)`` would both raise, so skip this update and
            # do not count it towards the step-size schedule.
            return

        self.num_updates += 1

        # Distribution over the boundary entries, from the current
        # ``log_r`` (shifted by the max before exponentiating).
        boundary_indices = [self.pair_to_idx[(bi, bj)] for (bi, bj, _, _) in boundary]
        log_r_boundary = self.log_r[boundary_indices]
        log_r_boundary = log_r_boundary - log_r_boundary.max()
        q = np.exp(log_r_boundary)
        q /= q.sum()

        # Draw one boundary entry uniformly and project for it.
        I_idx = rng.integers(m)
        bi, bj, bu, bv = boundary[I_idx]
        theta_star, gamma_val = kl_projection_bt(
            theta_hat, self.pair_i, self.pair_j, self.w, bu, bv, self.n
        )

        # Per-pair divergence between MLE and projection, computed on the
        # clipped parameter differences and floored at zero.
        eta_hat = theta_hat[self.pair_i] - theta_hat[self.pair_j]
        eta_star = theta_star[self.pair_i] - theta_star[self.pair_j]
        eta_hat = np.clip(eta_hat, -ETA_CLIP, ETA_CLIP)
        eta_star = np.clip(eta_star, -ETA_CLIP, ETA_CLIP)
        d_ab = np.maximum(0, kl_divergence_bt_vec(eta_hat, eta_star))

        # Scale by ``m``, the inverse of the uniform draw probability.
        q_I = q[I_idx]
        self.S_w += m * q_I * d_ab
        I_pair_idx = self.pair_to_idx[(bi, bj)]
        self.S_q[I_pair_idx] += m * gamma_val

        # Polynomially decaying step sizes.
        eta_w = self.eta_const_w * (self.num_updates + 1) ** (-self.eta_exp_w)
        eta_q = self.eta_const_q * (self.num_updates + 1) ** (-self.eta_exp_q)

        # Primal update: normalised exponential of the scaled cumulative
        # terms, starting from the uniform log-weight.
        log_w = -np.log(self.num_pairs) + eta_w * self.S_w
        log_w -= log_w.max()
        self.w = np.exp(log_w)
        self.w /= self.w.sum()

        # Dual update.
        self.log_r = -eta_q * self.S_q


class OracleSampler:
    """Pair sampler that tracks a fixed, externally supplied weight vector.

    After the warmup phase managed by ``BTState``, every ``step`` adds the
    normalised weights to the cumulative target ``P_track`` and samples
    the pair whose target exceeds its draw count ``N_track`` by the
    largest amount.

    Args:
        n: Problem size; forwarded to ``BTState`` and ``get_pairs``.
        w_opt: Array-like with one weight per entry of ``get_pairs(n)``,
            in the same order.  It is normalised to sum to one.
        min_round_robins: Forwarded to ``BTState``.

    Raises:
        ValueError: If ``w_opt`` does not have exactly one entry per pair,
            or if its sum is not a finite positive number.
    """

    def __init__(self, n, w_opt, min_round_robins=1):
        self.n = n
        self.state = BTState(n, min_round_robins=min_round_robins)
        self.pairs = get_pairs(n)
        self.num_pairs = len(self.pairs)

        # Accept any array-like (list, tuple, ndarray) and fail fast on
        # bad input instead of producing NaN weights or a broadcasting
        # error inside ``step``.
        weights = np.asarray(w_opt)
        if weights.shape != (self.num_pairs,):
            raise ValueError(
                f"w_opt must have shape ({self.num_pairs},), got {weights.shape}"
            )
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                f"w_opt must have a finite positive sum, got {total}"
            )

        self.w = weights / total
        # Cumulative target mass and realised draw count for each pair.
        self.P_track = np.zeros(self.num_pairs)
        self.N_track = np.zeros(self.num_pairs)
        # Number of post-warmup steps taken.
        self.t = 0

    def step(self, P, rng):
        """Perform one sampling round.

        Args:
            P: Matrix indexable as ``P[i, j]``, passed to ``sample_pair``.
            rng: Random generator exposing ``random()``.
        """
        # Warmup: observe the pair requested by the state, if there is one.
        if not self.state.warmup_complete:
            pair = self.state.warmup_next_pair()
            if pair is not None:
                i, j = pair
                y = sample_pair(P, i, j, rng)
                self.state.warmup_observe(i, j, y)
            return

        self.t += 1
        self.P_track += self.w
        # Select the pair that lags furthest behind its cumulative target.
        pair_idx = np.argmax(self.P_track - self.N_track)
        i, j = self.pairs[pair_idx]
        y = sample_pair(P, i, j, rng)
        self.state.add_observation(i, j, y)
        self.N_track[pair_idx] += 1


