import numpy as np
import os
import time
from bt_utils import build_pair_arrays, get_boundary_pairs, build_theta
from bt_kl import kl_projection_bt, kl_divergence_bt_vec

def compute_Gamma(theta, pair_i, pair_j, w, boundary):
    """Return the smallest projection value over all boundary pairs.

    For each entry ``(_, _, bu, bv)`` of ``boundary`` this calls
    ``kl_projection_bt(theta, pair_i, pair_j, w, bu, bv, len(theta))`` and
    keeps the second element of its result (the value ``solve_optimal_w``
    names ``gamma``). The minimum of those values is returned.

    Args:
        theta: parameter vector passed through to ``kl_projection_bt``.
        pair_i, pair_j: pair index arrays passed through to ``kl_projection_bt``.
        w: allocation over pairs passed through to ``kl_projection_bt``.
        boundary: iterable of 4-tuples; only the last two fields are used.

    Raises:
        ValueError: if ``boundary`` is empty (there is nothing to minimize over).
    """
    n = len(theta)
    Gamma = min(
        (kl_projection_bt(theta, pair_i, pair_j, w, bu, bv, n)[1]
         for (_, _, bu, bv) in boundary),
        default=None,
    )
    if Gamma is None:
        raise ValueError("compute_Gamma: boundary is empty; nothing to minimize over")
    return Gamma

def _relative_change(new, old):
    """Return ``|new - old| / |new|`` for the checkpoint convergence test.

    When ``new`` is zero the ratio is undefined. In that case this reports no
    change if ``old`` is also zero and an infinite change otherwise, instead
    of dividing by zero.
    """
    scale = abs(new)
    if scale == 0:
        return 0.0 if new == old else float("inf")
    return abs(new - old) / scale


def solve_optimal_w(n, k, gap,
                    base_iters=1000, max_doublings=16,
                    eta_exp_w=0.2, eta_exp_q=0.2, eta_const_w=1.0, eta_const_q=1.0,
                    convergence_tol=0.005, seed=1, verbose=True):
    """Iteratively search for an allocation ``w`` over pairs that makes Gamma large.

    Two distributions are updated by exponential weights:

    * ``w`` lives over the pairs from ``build_pair_arrays(n)``. Its log-weights
      are ``+eta_w * S_w``, where ``S_w`` accumulates clipped per-pair
      divergences (``kl_divergence_bt_vec``) between ``theta`` and its
      projection ``theta_star``.
    * ``q`` lives over the boundary pairs from ``get_boundary_pairs(theta, k)``.
      Its log-weights are ``-eta_q * S_q``, where ``S_q`` accumulates the
      projection values ``gamma``.

    Each iteration samples one boundary pair uniformly at random and scales its
    contribution by ``m`` (the number of boundary pairs). Step sizes decay as
    ``eta_const_* * t**(-eta_exp_*)``.

    The running average of the ``w`` iterates is evaluated with
    ``compute_Gamma`` at checkpoints ``base_iters * 2**i`` for
    ``i = 0..max_doublings``. The run stops early once the relative change
    between a checkpoint and the one at half its iteration count falls below
    ``convergence_tol``.

    Returns:
        ``(w_avg, Gamma_final, theta, t)``:

        * ``w_avg`` is the averaged allocation, normalized to sum to 1.
        * ``Gamma_final`` is ``compute_Gamma`` evaluated at ``w_avg``.
        * ``theta`` is the vector from ``build_theta(n, gap)``.
        * ``t`` is the iteration at which the loop stopped.

    Raises:
        ValueError: if ``base_iters < 1``, if ``max_doublings < 0``, or if no
            boundary pairs are found.
    """
    if base_iters < 1:
        raise ValueError(f"base_iters must be >= 1, got {base_iters}")
    if max_doublings < 0:
        raise ValueError(f"max_doublings must be >= 0, got {max_doublings}")

    theta = build_theta(n, gap)
    boundary = get_boundary_pairs(theta, k)
    m = len(boundary)
    if m == 0:
        raise ValueError(f"get_boundary_pairs returned no pairs for n={n}, k={k}")

    pair_i, pair_j, _ = build_pair_arrays(n)
    num_pairs = len(pair_i)

    # Doubling checkpoints: base_iters * 2**i for i = 0..max_doublings.
    checkpoints = {base_iters * (2 ** i) for i in range(max_doublings + 1)}
    max_iters = max(checkpoints)

    w = np.ones(num_pairs) / num_pairs
    S_w = np.zeros(num_pairs)    # accumulated (sampled) gradients for w
    S_q = np.zeros(m)            # accumulated (sampled) gamma per boundary pair
    w_sum = np.zeros(num_pairs)  # running sum of w iterates, for averaging

    # theta is never modified inside the loop, so its pairwise differences
    # are loop-invariant; compute them once.
    eta_vec = theta[pair_i] - theta[pair_j]

    rng = np.random.default_rng(seed)
    results = {}
    t0 = time.time()

    for t in range(1, max_iters + 1):
        eta_w = eta_const_w * (t ** (-eta_exp_w))
        eta_q = eta_const_q * (t ** (-eta_exp_q))
        w_sum += w

        # q = softmax(-eta_q * S_q); subtracting the max keeps exp() stable.
        log_q = -eta_q * S_q
        q = np.exp(log_q - log_q.max())
        q /= q.sum()

        # One boundary pair is drawn uniformly (probability 1/m); the "* m"
        # factors below are the matching importance weights.
        idx = rng.integers(m)
        _, _, bu, bv = boundary[idx]
        theta_star, gamma = kl_projection_bt(theta, pair_i, pair_j, w, bu, bv, n)

        eta_star = theta_star[pair_i] - theta_star[pair_j]
        grad = np.maximum(0.0, kl_divergence_bt_vec(eta_vec, eta_star))

        S_w += q[idx] * grad * m
        S_q[idx] += gamma * m

        # w = softmax(eta_w * S_w), max-shifted for numerical stability.
        log_w = eta_w * S_w
        w = np.exp(log_w - log_w.max())
        w /= w.sum()

        if t in checkpoints:
            w_avg = w_sum / t
            w_avg /= w_avg.sum()
            Gamma = compute_Gamma(theta, pair_i, pair_j, w_avg, boundary)
            results[t] = Gamma

            t_half = t // 2
            elapsed = time.time() - t0
            if t_half in results:
                delta = _relative_change(Gamma, results[t_half])
                if verbose:
                    print(f"T={t:>8}: Gamma={Gamma:.6f}, delta={delta*100:.2f}%, {elapsed:.1f}s")
                if delta < convergence_tol:
                    if verbose:
                        print("Converged")
                    break
            elif verbose:
                print(f"T={t:>8}: Gamma={Gamma:.6f}, {elapsed:.1f}s")

    # The loop exits either through the convergence break (at a checkpoint) or
    # at t == max_iters, which is itself a checkpoint. In both cases w_avg and
    # its Gamma were just computed, so reuse them instead of running all m
    # projections again.
    if t not in results:
        w_avg = w_sum / t
        w_avg /= w_avg.sum()
        results[t] = compute_Gamma(theta, pair_i, pair_j, w_avg, boundary)
    return w_avg, results[t], theta, t


def main():
    """Run the solver on a fixed (n, k, gap) instance, report, and optionally save.

    The result is written as an ``.npz`` file under an ``oracle_allocation``
    directory next to this file when ``save_data`` is true.
    """
    n, k, gap = 100, 5, 0.1
    solver_options = {
        "base_iters": 10000,
        "max_doublings": 16,
        "eta_exp_w": 0.2,
        "eta_exp_q": 0.2,
        "eta_const_w": 1.0,
        "eta_const_q": 1.0,
        "convergence_tol": 0.001,
        "seed": 1,
    }
    save_data = True

    print(f"n={n}, k={k}, gap={gap}\n")

    w_opt, Gamma_opt, theta, converged_at = solve_optimal_w(n, k, gap, **solver_options)

    print(f"\nOptimal Gamma: {Gamma_opt:.6f}")
    print(f"Converged at iteration: {converged_at}")

    if not save_data:
        return

    here = os.path.dirname(__file__)
    out_dir = os.path.join(here, 'oracle_allocation')
    os.makedirs(out_dir, exist_ok=True)
    file_name = f'optimal_w_n{n}_k{k}_gap{gap}.npz'
    path = os.path.join(out_dir, file_name)
    np.savez(
        path,
        w_opt=w_opt,
        Gamma_opt=Gamma_opt,
        theta=theta,
        converged_at=converged_at,
        n=n,
        k=k,
        gap=gap,
    )
    print(f"Saved: {path}")


if __name__ == "__main__":
    main()
