import numpy as np
import os
import time
from bt_utils import get_top_k, get_boundary_pairs, get_pairs, build_P
from bt_mle import glr_statistic
from samplers import AdaptiveSampler, OracleSampler

def _build_count_laplacian(n, N, pair_i, pair_j):
    """Assemble the n-by-n count-weighted graph Laplacian.

    For every position ``p``, the pair ``(pair_i[p], pair_j[p])`` contributes
    ``N[p]`` to both of its diagonal entries and ``-N[p]`` to the two
    symmetric off-diagonal entries.

    Args:
        n: Size of the (square) output matrix.
        N: Per-pair weights, indexed in step with ``pair_i`` / ``pair_j``.
        pair_i: First endpoint of each pair.
        pair_j: Second endpoint of each pair.

    Returns:
        A dense ``(n, n)`` float array.
    """
    laplacian = np.zeros((n, n))
    for pos, u in enumerate(pair_i):
        v = pair_j[pos]
        weight = N[pos]
        # Degree (diagonal) terms first, then the symmetric edge terms.
        laplacian[u, u] += weight
        laplacian[v, v] += weight
        laplacian[u, v] -= weight
        laplacian[v, u] -= weight
    return laplacian

def threshold_gaussian(t, delta, n, theta_hat, N, pair_i, pair_j, lambda_param=1.0, sigma_bar=0.25):
    """Evaluate the stopping threshold for confidence level ``delta``.

    The value returned is

        log(1/delta) + (lambda/2) * ||theta_hat||^2
                     + 0.5 * sum_m log(1 + (sigma_bar/lambda) * max(e_m, 0))

    where ``e_m`` are the eigenvalues of the count Laplacian built from
    ``N``, ``pair_i`` and ``pair_j``.

    Args:
        t: Accepted for call-site compatibility; not used in the computation.
        delta: Confidence parameter.
        n: Number of items (size of the Laplacian).
        theta_hat: Current parameter estimate (array supporting ``** 2``).
        N: Per-pair counts.
        pair_i: First endpoint of each pair.
        pair_j: Second endpoint of each pair.
        lambda_param: Regularisation weight ``lambda``.
        sigma_bar: Scale applied to the eigenvalues (divided by ``lambda``).

    Returns:
        The scalar threshold.
    """
    laplacian = _build_count_laplacian(n, N, pair_i, pair_j)
    scale = sigma_bar / lambda_param
    spectrum = np.linalg.eigvalsh(laplacian)
    # Negative eigenvalues are clipped to zero before entering the logarithm.
    clipped = np.maximum(spectrum, 0)
    log_det_term = np.sum(np.log(1.0 + scale * clipped))
    ridge_term = (lambda_param / 2.0) * np.sum(theta_hat ** 2)
    return np.log(1.0 / delta) + ridge_term + 0.5 * log_det_term

def load_optimal_w(n, k, gap):
    """Load the precomputed optimal allocation from the oracle_allocation folder.

    The archive is expected at
    ``<directory of this file>/oracle_allocation/optimal_w_n{n}_k{k}_gap{gap}.npz``
    and must contain the entries ``w_opt``, ``Gamma_opt`` and ``theta``.

    Args:
        n: Value substituted for ``n`` in the file name.
        k: Value substituted for ``k`` in the file name.
        gap: Value substituted for ``gap`` in the file name.

    Returns:
        Tuple ``(w_opt, Gamma_opt, theta)`` where ``Gamma_opt`` is converted
        to a Python float and the other two are the stored arrays.

    Raises:
        FileNotFoundError: If the archive does not exist.
    """
    filename = f'optimal_w_n{n}_k{k}_gap{gap}.npz'
    path = os.path.join(os.path.dirname(__file__), 'oracle_allocation', filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find oracle file: {path}")
    # np.load keeps the .npz archive open until it is closed; the context
    # manager releases the file handle once the arrays have been read.
    with np.load(path) as data:
        return data['w_opt'], float(data['Gamma_opt']), data['theta']

def shuffle_w_opt(w_opt, perm, n):
    """Re-index a per-pair vector according to an item relabelling.

    For each pair ``(i, j)`` yielded by ``get_pairs(n)``, the output entry at
    that pair's position is taken from ``w_opt`` at the flat index of the
    relabelled pair ``(perm[i], perm[j])``, ordered so the smaller label
    comes first.

    Args:
        w_opt: Vector with one entry per pair.
        perm: Mapping (indexable by item) from item to relabelled item.
        n: Number of items.

    Returns:
        Array shaped like ``w_opt`` holding the re-indexed values.
    """
    shuffled = np.zeros_like(w_opt)
    for position, (a, b) in enumerate(get_pairs(n)):
        lo = min(perm[a], perm[b])
        hi = max(perm[a], perm[b])
        # Flat index of (lo, hi), lo < hi, in row-major upper-triangular order.
        # NOTE(review): assumes get_pairs(n) enumerates pairs in that same order.
        source = lo * n - lo * (lo + 1) // 2 + (hi - lo - 1)
        shuffled[position] = w_opt[source]
    return shuffled

def compute_Z_min(state, theta_hat, k, threshold=None):
    """Return the smallest GLR statistic over the boundary pairs.

    Iterates over ``get_boundary_pairs(theta_hat, k)`` and evaluates
    ``glr_statistic`` for the last two members of each entry.

    Args:
        state: Object exposing ``n``, ``N``, ``S``, ``pair_i`` and ``pair_j``.
        theta_hat: Current parameter estimate.
        k: Size of the top set.
        threshold: Optional early-exit level. As soon as the running minimum
            drops below it the scan stops, so the value returned is then only
            guaranteed to be below ``threshold``, not the true minimum.

    Returns:
        The running minimum (``inf`` when there are no boundary pairs).
    """
    smallest = float('inf')
    for pair in get_boundary_pairs(theta_hat, k):
        _first, _second, bu, bv = pair
        stat = glr_statistic(state.n, state.N, state.S,
                             state.pair_i, state.pair_j, theta_hat, bu, bv)
        smallest = min(smallest, stat)
        if threshold is not None and smallest < threshold:
            break
    return smallest

def _resolve_baseline(alg):
    """Return the baseline routine named ``alg``, or None for GLR-based methods.

    The imports stay inside the branches so a baseline module is only needed
    when that algorithm is actually requested.
    """
    if alg == 'seeks':
        from ren20 import seeks
        return seeks
    if alg == 'seeks_v2':
        from ren20 import seeks_v2
        return seeks_v2
    if alg == 'active_ranking':
        from active_ranking import active_ranking
        return active_ranking
    return None


def _run_baseline(baseline, seed, alg, n, k, delta_values, P, true_top_k):
    """Run ``baseline`` once per delta and package the result dict."""
    items = list(range(n))
    stopped = {}
    correct = {}
    for delta in delta_values:
        # Reset the global NumPy RNG before every run.
        # NOTE(review): presumably the baselines draw from np.random -- confirm.
        np.random.seed(seed)
        top_k_out, num_comps = baseline(items, P, k, delta)
        stopped[delta] = num_comps
        correct[delta] = 1 if set(top_k_out) == true_top_k else 0
    return {'seed': seed, 'alg': alg, 'stopped': stopped, 'correct': correct}


def run_single(seed, alg, n, k, delta_values,
               gamma=0.33, eta_exp_w=0.2, eta_exp_q=0.2, eta_const_w=1.0, eta_const_q=1.0,
               min_round_robins=1, glr_min_t=500000, glr_freq=10000, lambda_param=1.0,
               theta_mode="gap",
               gap=0.1,
               theta_min=-5.0, theta_max=5.0,
               sst_spread=0.05, boundary_min=0.01):
    """Run one (seed, algorithm) combination for the stopping experiment.

    Args:
        seed: Seed passed to ``build_P`` and used for the algorithm's RNG.
        alg: One of ``'seeks'``, ``'seeks_v2'``, ``'active_ranking'``
            (baselines with their own stopping rule) or ``'oracle'`` /
            ``'adaptive'`` (GLR-based stopping).
        n: Number of items.
        k: Size of the top set.
        delta_values: Confidence levels to evaluate.
        gamma, eta_exp_w, eta_exp_q, eta_const_w, eta_const_q: Forwarded to
            ``AdaptiveSampler``.
        min_round_robins: Forwarded to the sampler.
        glr_min_t: First post-warmup step at which stopping may be checked.
        glr_freq: Stopping is checked every ``glr_freq`` post-warmup steps.
        lambda_param: Regularisation weight used when no warmup-based value
            is available.
        theta_mode, gap, theta_min, theta_max, sst_spread, boundary_min:
            Forwarded to ``build_P``.

    Returns:
        Dict with keys ``'seed'``, ``'alg'``, ``'stopped'`` and ``'correct'``.
        ``stopped[delta]`` is the number of comparisons when delta's rule
        fired. ``correct[delta]`` is 1/0 according to whether the top-k
        estimate *at that moment* equals the true top-k.

    Raises:
        ValueError: For an unknown ``alg``, or ``'oracle'`` outside gap mode.

    Note:
        For GLR-based methods the sampling loop has no step budget; it runs
        until every delta has stopped.
    """
    # Build P matrix and ground truth (the returned theta is not needed here).
    P, true_top_k, perm, _theta = build_P(
        n, k, theta_mode, seed,
        gap=gap,
        theta_min=theta_min, theta_max=theta_max,
        sst_spread=sst_spread,
        boundary_min=boundary_min
    )

    # Handle SEEKS, SEEKS_V2, active_ranking separately (no GLR stopping)
    baseline = _resolve_baseline(alg)
    if baseline is not None:
        return _run_baseline(baseline, seed, alg, n, k, delta_values, P, true_top_k)

    # Create sampler for GLR-based methods
    rng_alg = np.random.default_rng(seed)
    if alg == 'oracle':
        if theta_mode != "gap":
            raise ValueError("Oracle sampler only available for gap mode (requires precomputed w*)")
        w_opt, _, _ = load_optimal_w(n, k, gap)
        w_shuffled = shuffle_w_opt(w_opt, perm, n)
        sampler = OracleSampler(n, w_shuffled, min_round_robins)
    elif alg == 'adaptive':
        sampler = AdaptiveSampler(n, k, gamma, eta_exp_w, eta_exp_q,
                                  eta_const_w, eta_const_q, min_round_robins)
    else:
        raise ValueError(f"Unknown algorithm: {alg}")

    stopped = {d: None for d in delta_values}
    correct = {d: None for d in delta_values}
    t_post_warmup = 0
    lambda_warmup = None

    while True:
        was_warmup = not sampler.state.warmup_complete
        sampler.step(P, rng_alg)

        if was_warmup:
            # On the step that completes warmup, set
            # lambda = (n-1) / ||theta_hat_warmup||^2.
            if sampler.state.warmup_complete and sampler.state.mle_exists():
                theta_hat_warmup = sampler.state.get_mle()
                norm_sq_warmup = np.sum(theta_hat_warmup ** 2)
                # A zero (or NaN) norm would give an infinite/NaN lambda and a
                # threshold that can never be met; keep the fallback instead.
                if norm_sq_warmup > 0:
                    lambda_warmup = (n - 1) / norm_sq_warmup
            continue
        t_post_warmup += 1

        # Check stopping only on the configured schedule.
        if t_post_warmup < glr_min_t or t_post_warmup % glr_freq != 0:
            continue
        if not sampler.state.mle_exists():
            continue

        theta_hat = sampler.state.get_mle()
        largest_unstopped = max((d for d in delta_values if stopped[d] is None), default=None)
        if largest_unstopped is None:
            break

        # Use warmup-based lambda if available, else fall back to lambda_param
        lam = lambda_warmup if lambda_warmup is not None else lambda_param
        state = sampler.state
        min_beta = threshold_gaussian(t_post_warmup, largest_unstopped, n, theta_hat,
                                      state.N, state.pair_i, state.pair_j, lam)
        # The scan may exit early once Z drops below the smallest threshold;
        # in that case no delta can stop at this check anyway.
        Z_min = compute_Z_min(state, theta_hat, k, threshold=min_beta)

        is_correct = None  # computed at most once per check, on first stop
        for delta in sorted(delta_values, reverse=True):
            if stopped[delta] is not None:
                continue
            if delta == largest_unstopped:
                # Same inputs as min_beta: reuse it instead of recomputing
                # the eigen-decomposition.
                beta = min_beta
            else:
                beta = threshold_gaussian(t_post_warmup, delta, n, theta_hat,
                                          state.N, state.pair_i, state.pair_j, lam)
            if Z_min >= beta:
                stopped[delta] = state.total_comparisons()
                # Judge the answer that would be returned at this stopping
                # time, not a later estimate.
                if is_correct is None:
                    estimated_top_k = set(np.argsort(theta_hat)[-k:])
                    is_correct = estimated_top_k == true_top_k
                correct[delta] = 1 if is_correct else 0
            else:
                # Deltas are visited in decreasing order; stop at the first
                # one whose threshold is not met.
                break

        if all(v is not None for v in stopped.values()):
            break

    return {'seed': seed, 'alg': alg, 'stopped': stopped, 'correct': correct}


def main():
    """Run the stopping experiment with the settings hard-coded below.

    Prints the configuration, calls ``run_single`` for every seed and
    algorithm, then prints the mean stopping time and the number of correct
    answers per algorithm and delta.

    Returns:
        Tuple ``(all_stopped, all_correct)``: two dicts keyed by algorithm and
        then by delta, each holding one entry per simulation.
    """
    # ---- experiment configuration ----
    n, k = 20, 5
    num_sims, seed_base = 1, 1
    # Available names: 'adaptive', 'oracle', 'seeks', 'seeks_v2', 'active_ranking'
    algorithms = ['oracle', 'adaptive', 'seeks_v2']
    delta_values = [0.01]
    glr_min_t, glr_freq = 1000, 1000

    # MODE: "gap", "uniform", or "sst_boundary_min"
    theta_mode = "gap"
    gap = 0.1                              # gap mode
    theta_min, theta_max = -5.0, 5.0       # uniform mode
    sst_spread, boundary_min = 0.05, 0.01  # SST boundary_min mode

    # ---- report the configuration ----
    print(f"n={n}, k={k}, mode={theta_mode}")
    if theta_mode == "gap":
        print(f"gap={gap}")
    elif theta_mode == "uniform":
        print(f"theta ~ U[{theta_min}, {theta_max}]")
    elif theta_mode == "sst_boundary_min":
        print(f"SST boundary_min: P[k,k+1] in [{0.5+boundary_min}, {0.5+sst_spread}]")
    delta_labels = [f'{d:.0e}' for d in delta_values]
    print(f"Delta values: {delta_labels}")

    oracle_requested = 'oracle' in algorithms
    if oracle_requested and theta_mode != "gap":
        print("WARNING: Oracle only available for gap mode, removing from algorithms")
        algorithms = [name for name in algorithms if name != 'oracle']
    elif oracle_requested:
        _, Gamma_opt, _ = load_optimal_w(n, k, gap)
        print(f"Gamma*={Gamma_opt:.6f}")

    print(f"Algorithms: {algorithms}")
    print(f"num_sims={num_sims}\n")

    # Keyword arguments shared by every run_single call.
    run_kwargs = dict(
        gamma=0.33, eta_exp_w=0.2, eta_exp_q=0.2,
        eta_const_w=1.0, eta_const_q=1.0,
        min_round_robins=1, glr_min_t=glr_min_t,
        glr_freq=glr_freq, lambda_param=1.0,
        theta_mode=theta_mode,
        gap=gap,
        theta_min=theta_min, theta_max=theta_max,
        sst_spread=sst_spread,
        boundary_min=boundary_min,
    )

    # ---- run the simulations ----
    all_stopped = {alg: {d: [] for d in delta_values} for alg in algorithms}
    all_correct = {alg: {d: [] for d in delta_values} for alg in algorithms}
    t0 = time.time()

    for seed in range(seed_base, seed_base + num_sims):
        for alg in algorithms:
            outcome = run_single(seed, alg, n, k, delta_values, **run_kwargs)
            for d in delta_values:
                all_stopped[alg][d].append(outcome['stopped'][d])
                all_correct[alg][d].append(outcome['correct'][d])

    # ---- print results ----
    print(f"\nResults:")
    for alg in algorithms:
        print(f"\n{alg}:")
        for delta in delta_values:
            times = all_stopped[alg][delta]
            n_correct = sum(all_correct[alg][delta])
            print(f"  delta={delta:.0e}: E[tau]={np.mean(times):.0f}, correct={n_correct}/{len(times)}")

    print(f"\nTotal time: {time.time()-t0:.1f}s")
    return all_stopped, all_correct


# Run the experiment when this file is executed as a script; main()'s two
# result dicts are bound to module-level names.
if __name__ == "__main__":
    all_stopped, all_correct = main()
