import numpy as np
import json
from scipy.special import expit
from sklearn.preprocessing import StandardScaler
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch.nn as nn
import torch.optim as optim

from stream import StreamingRFFCoresetXY_LazyPhi
from simple import SimplexLogits
from rff import RFFRBFMap, _PointXYLazyPhi, _LevelBufferLazyPhi

import torch
import time



def run_ipmrff_ate_rff_cached(
    X_t, Y_t,
    X_c, Y_c,
    ratio=0.1,
    rff_map=None,
    epochs=1000,
    lr=0.03,
    seed=0,
):
    """Estimate the ATE from IPM-reweighted random-Fourier-feature coresets.

    Each arm is streamed point by point into its own
    ``StreamingRFFCoresetXY_LazyPhi`` (both share ``rff_map`` and ``seed``).
    ``finalize_with_phi()`` returns ``(x, y, w, phi)`` tuples whose feature
    vector ``phi`` is already cached. Per-point weights are then fit by
    ``learn_two_mmd_separate_from_phi`` so that each arm's weighted feature
    mean matches the pooled coreset's. The estimate is

        ATE = Y_core0 @ v_IPM - Y_core1 @ w_IPM

    Index 0 refers to the TREATED arm and index 1 to the CONTROL arm.

    Args:
        X_t, Y_t: treated covariates (one row per unit) and outcomes.
        X_c, Y_c: control covariates and outcomes.
        ratio: coreset buffer size as a fraction of each arm's size
            (truncated to an int, minimum 2).
        rff_map: the shared, cached RFF map; required.
        epochs, lr: forwarded to the weight optimiser.
        seed: seed passed to both streaming coresets.

    Returns:
        dict with the float "ATE", the unpacked coresets ("X_core*",
        "Y_core*", coreset weights "U*", cached features "Phi*"), the raw
        finalized coresets "S0"/"S1", and the largest learned weight per arm
        ("max_w_t", "max_w_c").

    Raises:
        ValueError: if ``rff_map`` is None, if an arm's covariates and
            outcomes differ in length, or if an arm is empty.
    """
    if rff_map is None:
        raise ValueError("Please pass rff_map (your cached RFFRBFMap)")

    # Validate up front: zip() below would silently truncate mismatched
    # inputs (dropping samples), and an empty arm would only fail later
    # inside np.stack with an unhelpful message.
    for arm, X_arm, Y_arm in (("treated", X_t, Y_t), ("control", X_c, Y_c)):
        if len(X_arm) != len(Y_arm):
            raise ValueError(
                f"{arm} arm: len(X)={len(X_arm)} does not match len(Y)={len(Y_arm)}"
            )
        if len(X_arm) == 0:
            raise ValueError(f"{arm} arm is empty; cannot build a coreset")

    # ---- coreset size ----
    m0 = max(2, int(ratio * len(X_t)))
    m1 = max(2, int(ratio * len(X_c)))

    # ---- build RFF streaming coresets (cache phi) ----
    # NOTE(review): both coresets receive the same seed; if it drives their
    # internal sampling, the two arms share one random stream -- confirm
    # this is intended.
    co_t = StreamingRFFCoresetXY_LazyPhi(
        buffer_size=m0,
        rff_map=rff_map,
        randomized=True,
        seed=seed,
    )
    co_c = StreamingRFFCoresetXY_LazyPhi(
        buffer_size=m1,
        rff_map=rff_map,
        randomized=True,
        seed=seed,
    )

    for x, y in zip(X_t, Y_t):
        co_t.add(x, y)
    for x, y in zip(X_c, Y_c):
        co_c.add(x, y)

    S0 = co_t.finalize_with_phi()   # list of (x, y, w, phi)
    S1 = co_c.finalize_with_phi()

    def _unpack_coreset(S):
        """Split finalized (x, y, w, phi) tuples into stacked arrays."""
        X_core = np.stack([x for (x, _, _, _) in S], axis=0)
        Y_core = np.array([y for (_, y, _, _) in S], dtype=np.float64)
        U = np.array([w for (_, _, w, _) in S], dtype=np.float64)
        # Reuse the cached features instead of re-applying rff_map.
        Phi = np.stack([phi for (_, _, _, phi) in S], axis=0)
        return X_core, Y_core, U, Phi

    X_core0, Y_core0, U0, Phi0 = _unpack_coreset(S0)
    X_core1, Y_core1, U1, Phi1 = _unpack_coreset(S1)

    out = learn_two_mmd_separate_from_phi(
        Phi0, U0,
        Phi1, U1,
        epochs=epochs,
        lr=lr,
    )

    v_IPM = out["z_t"]   # weights on the treated coreset
    w_IPM = out["z_c"]   # weights on the control coreset

    ATE_hat = Y_core0 @ v_IPM - Y_core1 @ w_IPM

    return {
        "ATE": float(ATE_hat),
        "X_core0": X_core0, "Y_core0": Y_core0, "U0": U0, "Phi0": Phi0,
        "X_core1": X_core1, "Y_core1": Y_core1, "U1": U1, "Phi1": Phi1,
        "S0": S0, "S1": S1,
        "max_w_t": out["max_w_t"],
        "max_w_c": out["max_w_c"],
    }


def load_hetero_batch(seed=1, n=200, standardize=True, batch_id=0):
    """Simulate one batch of the heterogeneous-treatment-effect benchmark.

    Covariates are i.i.d. Uniform(-2, 2) in 6 dimensions. Treatment is
    T ~ Bernoulli(expit(t')) with t' ~ N(mu1(X), variance 0.5), where mu1
    does not depend on ``batch_id``. Outcomes are
    Y ~ N(base_b(X) + tau_b(X) * T, variance 0.1), with
    tau_b(X) = 3 cos(2 x_b) + 2 sin(x_b) and b = (batch_id mod 5) + 1.
    The baseline base_b(X) gains one extra term for each of b = 2, 3, 4;
    b = 5 reuses the b = 4 baseline.

    Args:
        seed: seed for ``np.random.default_rng``; fully determines the batch.
        n: number of units drawn before the treated/control split.
        standardize: if True, z-score covariates with a StandardScaler fit on
            the pooled treated + control rows of THIS batch only.
        batch_id: selects the effect/baseline schedule (period 5).

    Returns:
        (X_t, X_c, Y_t, Y_c, gt_ate): treated/control covariates and
        outcomes, plus the population ATE E[tau_b(X)] = 3 sin(4) / 4. The
        ATE is the same for every batch because each x_j has the same law.
    """
    rng = np.random.default_rng(seed)

    # The four random draws (uniform, normal, binomial, normal) must stay in
    # this order: reordering them changes every batch for a given seed.
    X = rng.uniform(-2.0, 2.0, size=(n, 6)).astype(np.float64)
    x1, x2, x3, x4, x5, x6 = X.T

    # Per-unit maxima over two covariate triples (x1,x3,x6) and (x2,x4,x5).
    m1 = X[:, [0, 2, 5]].max(axis=1)
    m2 = X[:, [1, 3, 4]].max(axis=1)

    # Latent treatment score; identical across batches.
    latent_mean = (
        0.9 * np.tanh(0.8 * m1)
        + 0.6 * np.sin(x2 + 0.5 * x4)
        + 0.4 * np.cos(1.5 * x5)
        + 0.25 * x3
        - 0.20 * (x4 ** 2)
        + 0.35 * np.log1p(m2 ** 2)
        - 2.2
    )
    t_prime = rng.normal(loc=latent_mean, scale=np.sqrt(0.5), size=n)
    T = rng.binomial(1, expit(t_prime), size=n).astype(int)

    # Batch schedule: b in 1..5 selects covariate x_b (column b - 1).
    b = int(batch_id % 5) + 1
    effect_driver = X[:, b - 1]
    tau_x = 3.0 * np.cos(2.0 * effect_driver) + 2.0 * np.sin(effect_driver)

    # Baseline grows with b; `base` is a fresh array, so += never touches X.
    base = 2.6 * (2.0 * x4 - 1.0) ** 2
    if b >= 2:
        base += 1.8 * (x1 - 2.0) ** 2
    if b >= 3:
        base += 0.8 * (x3 ** 2)
    if b >= 4:
        base += 0.7 * np.sin(x2) / (1.0 + x6 ** 2)

    mu2 = base + tau_x * T
    Y = rng.normal(loc=mu2, scale=np.sqrt(0.1), size=n).astype(np.float64)

    treated = T == 1
    X_t, X_c = X[treated], X[~treated]
    Y_t, Y_c = Y[treated], Y[~treated]

    # For x ~ Uniform(-2, 2): E[sin(x)] = 0 and E[cos(2x)] = sin(4) / 4.
    gt_ate = float(3.0 * 0.25 * np.sin(4.0))

    if standardize:
        n_t = X_t.shape[0]
        pooled_std = StandardScaler().fit_transform(np.vstack([X_t, X_c]))
        X_t, X_c = pooled_std[:n_t], pooled_std[n_t:]

    return X_t, X_c, Y_t, Y_c, gt_ate


def learn_one_side_mmd_optimize_z_from_phi(
    Phi_src: np.ndarray,
    U_src: np.ndarray,
    Phi_tgt: np.ndarray,
    U_tgt: np.ndarray,
    epochs: int = 1000,
    lr: float = 0.03,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Fit weights on source rows so their feature mean matches a target mean.

    Minimises ``|| z^T Phi_src - zt^T Phi_tgt ||^2`` with Adam for ``epochs``
    steps. Here ``zt`` is ``U_tgt`` normalised to sum to 1 (with a 1e-12
    guard), and ``z`` is produced by ``SimplexLogits(U_src)``. That model is
    defined outside this file; it presumably emits weights on the probability
    simplex over the source points.

    Args:
        Phi_src: source features, shape (n_src, Dphi).
        U_src: source weights handed to ``SimplexLogits``.
        Phi_tgt: target features, shape (n_tgt, Dphi).
        U_tgt: unnormalised target weights, length n_tgt.
        epochs: number of optimiser steps.
        lr: Adam learning rate.
        device: torch device; defaults to CUDA when available at import time.

    Returns:
        (z, max_z): learned source weights as a NumPy array and their maximum.
    """
    def as_f64_tensor(arr):
        return torch.tensor(np.asarray(arr, np.float64), device=device, dtype=torch.float64)

    src_feats = as_f64_tensor(Phi_src)
    tgt_feats = as_f64_tensor(Phi_tgt)

    tgt_weights = np.asarray(U_tgt, np.float64)
    tgt_weights = as_f64_tensor(tgt_weights / (tgt_weights.sum() + 1e-12))

    model = SimplexLogits(U_src).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The target mean embedding is fixed; no gradient is needed through it.
    with torch.no_grad():
        target_mean = tgt_weights @ tgt_feats  # (Dphi,)

    for _step in range(epochs):
        gap = model() @ src_feats - target_mean
        loss = (gap ** 2).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    weights = model().detach().cpu().numpy()
    return weights, float(weights.max())

def learn_two_mmd_separate_from_phi(
    Phi0: np.ndarray, U0: np.ndarray,
    Phi1: np.ndarray, U1: np.ndarray,
    epochs: int = 1000,
    lr: float = 0.03,
):
    """Balance each arm's coreset toward the pooled (arm 0 + arm 1) coreset.

    Both one-sided fits use the same target: the stacked features
    ``[Phi0; Phi1]`` weighted by ``[U0, U1]``. The one-sided solver
    normalises these weights jointly. The arm-0 fit runs first, then the
    arm-1 fit.

    Returns:
        dict with "z_t"/"z_c" (weights on arm 0 / arm 1) and
        "max_w_t"/"max_w_c" (their largest entries).
    """
    pooled_phi = np.vstack([Phi0, Phi1])
    pooled_u = np.concatenate([U0, U1]).astype(np.float64)

    def fit_towards_pool(phi_src, u_src):
        return learn_one_side_mmd_optimize_z_from_phi(
            Phi_src=phi_src, U_src=u_src,
            Phi_tgt=pooled_phi, U_tgt=pooled_u,
            epochs=epochs, lr=lr,
        )

    z_t, max_t = fit_towards_pool(Phi0, U0)
    z_c, max_c = fit_towards_pool(Phi1, U1)

    return {"z_t": z_t, "z_c": z_c, "max_w_t": max_t, "max_w_c": max_c}


def signs_from_features(Phi: np.ndarray) -> np.ndarray:
    """Greedily assign a +/-1 sign to each row of ``Phi``.

    Equivalent to running Algorithm-1 on Gram = Phi Phi^T, but without
    constructing the Gram matrix. Row 0 gets +1. Each later row i gets the
    sign opposite to its inner product with the running signed sum
    ``v = sum_{j<i} sigma_j Phi[j]`` (ties go to +1). Since
    ``||v + s*phi||^2 = ||v||^2 + 2 s <v, phi> + ||phi||^2``, this choice
    greedily minimises the norm of the running sum.

    Args:
        Phi: feature matrix of shape (m, D); converted to float64.

    Returns:
        int8 array of length m with entries +1 / -1 (empty when m == 0).

    Raises:
        ValueError: if ``Phi`` is not two-dimensional.
    """
    Phi = np.asarray(Phi, dtype=np.float64)
    if Phi.ndim != 2:
        raise ValueError(f"Phi must be 2-D with shape (m, D); got shape {Phi.shape}")
    m = Phi.shape[0]

    sigma = np.empty(m, dtype=np.int8)
    if m == 0:
        # Nothing to sign; previously this crashed on sigma[0].
        return sigma

    sigma[0] = 1
    v = Phi[0].copy()  # running signed sum, shape (D,)

    for i in range(1, m):
        s = float(v @ Phi[i])
        sigma[i] = 1 if s <= 0.0 else -1
        v += float(sigma[i]) * Phi[i]

    return sigma


# Coreset buffer size as a fraction of each arm's sample count.
# NOTE(review): 1.2 > 1 asks for a buffer larger than the arm itself --
# confirm how the streaming coreset treats that case.
ratios = (0.1, 0.2, 0.4, 0.6, 1.2)
nsim = 100                   # Monte Carlo replications per ratio

B = 5                        # number of streamed batches per replication
n_per_batch = 2000           # units simulated in each batch
n_all = n_per_batch * B      # units seen after the final batch


def summarize(ate_hat: np.ndarray, gt_vec: np.ndarray, n_total: np.ndarray, n_all: int):
    """Aggregate Monte Carlo ATE estimates and coreset sizes into scalar metrics.

    Args:
        ate_hat: one ATE estimate per simulation run.
        gt_vec: ground-truth ATE per run (same length as ``ate_hat``).
        n_total: total coreset size (treated + control) per run.
        n_all: number of units the coresets were built from; the
            denominator of ``keep_ratio_mean``.

    Returns:
        dict of plain Python floats/ints (key order is preserved in the JSON
        output). Both "rmse_std" and "mae_std" are the sample std (ddof=1) of
        the per-run absolute error, because a single run's RMSE is just
        |error|.
    """
    error = ate_hat - gt_vec
    abs_error = np.abs(error)

    def sample_std(values):
        return float(np.std(values, ddof=1))

    abs_error_std = sample_std(abs_error)
    mean_size = np.mean(n_total)

    return dict(
        gt_mean=float(np.mean(gt_vec)),
        bias=float(np.mean(error)),
        rmse=float(np.sqrt(np.mean(error ** 2))),
        rmse_std=abs_error_std,
        mae=float(np.mean(abs_error)),
        mae_std=abs_error_std,
        std=sample_std(ate_hat),
        n_total_mean=float(mean_size),
        n_total_std=sample_std(n_total),
        n_total_min=int(np.min(n_total)),
        n_total_max=int(np.max(n_total)),
        keep_ratio_mean=float(mean_size / n_all),
    )


# ============================================================
# IMPORTANT: create ONE rff_map and reuse it everywhere
# ============================================================
# A tiny probe batch is drawn only to read off the covariate dimension.
# Boolean row masking keeps the column count, so X_t0 still has d_in columns
# even if this 10-unit probe happens to contain no treated units.
X_t0, X_c0, Y_t0, Y_c0, gt0 = load_hetero_batch(
    seed=0,
    n=10,
    batch_id=0,
)
(_, d_in) = X_t0.shape
# D / sigma are recorded as "rff_D" / "rff_sigma" in the results meta below;
# keep both places in sync.
rff_map = RFFRBFMap(
    d_in=d_in,
    D=200,
    sigma=3.0,
    seed=0,
)


def one_run_prefix(ratio: float, seed: int):
    """Run one simulation replicate and evaluate the estimator on every prefix.

    Batch b (b = 0..B-1) is simulated with seed ``seed + 1000 * b``. After each
    batch, the estimator is refit from scratch on batches 0..b combined. The
    coreset seed stays ``seed`` for every prefix of this replicate.

    Returns:
        list of length B; entry t is
        (ate_hat_t, n_total_t, gt_t, n_seen_t), where n_total_t is the combined
        coreset size and n_seen_t the number of units in the prefix.
    """
    results = []
    treated_X, control_X = [], []
    treated_Y, control_Y = [], []

    for batch in range(B):
        Xt_new, Xc_new, Yt_new, Yc_new, gt_ate = load_hetero_batch(
            seed=seed + 1000 * batch,   # independent, reproducible per batch
            n=n_per_batch,
            standardize=True,
            batch_id=batch,
        )
        treated_X.append(Xt_new)
        control_X.append(Xc_new)
        treated_Y.append(Yt_new)
        control_Y.append(Yc_new)

        # Everything observed so far: batches 0..batch.
        Xt_prefix = np.vstack(treated_X)
        Xc_prefix = np.vstack(control_X)
        Yt_prefix = np.concatenate(treated_Y)
        Yc_prefix = np.concatenate(control_Y)

        fit = run_ipmrff_ate_rff_cached(
            Xt_prefix, Yt_prefix,
            Xc_prefix, Yc_prefix,
            ratio=ratio,
            rff_map=rff_map,
            epochs=1000,
            lr=0.03,
            seed=seed,
        )

        coreset_size = len(fit["S0"]) + len(fit["S1"])
        seen = Xt_prefix.shape[0] + Xc_prefix.shape[0]
        # gt_ate is constant under the current design but is kept per prefix.
        results.append((float(fit["ATE"]), int(coreset_size), float(gt_ate), int(seen)))

    return results


all_results = {}

for ratio in ratios:
    # Per-prefix collectors: slot t holds one value per simulation run.
    ate_hat_by_t = [[] for _ in range(B)]
    n_total_by_t = [[] for _ in range(B)]
    gt_by_t = [[] for _ in range(B)]
    n_seen_by_t = [[] for _ in range(B)]  # collected, not yet summarized

    for s in range(nsim):
        prefix_res = one_run_prefix(ratio=ratio, seed=s)
        for t in range(B):
            ate_hat_t, n_total_t, gt_t, n_seen_t = prefix_res[t]
            ate_hat_by_t[t].append(ate_hat_t)
            n_total_by_t[t].append(n_total_t)
            gt_by_t[t].append(gt_t)
            n_seen_by_t[t].append(n_seen_t)

    prefix_metrics = {}
    for t in range(B):
        prefix_len = t + 1  # 1-indexed prefix length, used as the dict key
        ate_hat = np.asarray(ate_hat_by_t[t], dtype=np.float64)
        n_total = np.asarray(n_total_by_t[t], dtype=np.int64)
        gt_vec = np.asarray(gt_by_t[t], dtype=np.float64)

        # keep_ratio denominator: every unit simulated in batches 0..t.
        n_prefix_all = prefix_len * n_per_batch
        metrics = summarize(ate_hat, gt_vec, n_total, n_all=n_prefix_all)
        prefix_metrics[prefix_len] = metrics

        report = [
            f"\n========== ratio={ratio:.1f} | PREFIX t={t+1}/{B} SUMMARY (nsim={nsim}) ==========",
            f"GT mean          = {metrics['gt_mean']:.6f}",
            f"Bias             = {metrics['bias']:.6f}",
            f"RMSE             = {metrics['rmse']:.6f} ± {metrics['rmse_std']:.6f}",
            f"MAE              = {metrics['mae']:.6f} ± {metrics['mae_std']:.6f}",
            f"Std(ATEhat)      = {metrics['std']:.6f}",
            f"n_total(mean±std)= {metrics['n_total_mean']:.1f} ± {metrics['n_total_std']:.1f}",
            f"n_total(min,max) = ({metrics['n_total_min']}, {metrics['n_total_max']})",
            f"keep_ratio_mean  = {metrics['keep_ratio_mean']:.4f}",
        ]
        print("\n".join(report))

    all_results[ratio] = {
        # rff_D / rff_sigma mirror the RFFRBFMap(...) call above; keep in sync.
        "meta": dict(
            B=B,
            n_per_batch=n_per_batch,
            nsim=nsim,
            rff_D=200,
            rff_sigma=3.0,
            tau_schedule="rotate",
            eval_mode="prefix_to_current_batch",
        ),
        "prefix_metrics": prefix_metrics,
    }

out_path = f"all_results_ipmrff_hetero_prefix_B{B}_n{n_per_batch}_nsim{nsim}_rffcached_new1.json"
with open(out_path, "w") as fh:
    fh.write(json.dumps(all_results, indent=2))

print(f"Saved: {out_path}")
