#!/usr/bin/env python3
import math, argparse, pathlib, random, os, tarfile, urllib.request
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List, Tuple

# ---------------- cfg ----------------

# Compute device used throughout the script: the default CUDA device when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(s: int):
    """Seed every RNG this script draws from with ``s``.

    Covers Python's ``random``, NumPy's global RNG, torch's CPU generator and
    the generators of all CUDA devices (in that order).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)

# ---------------- toy datasets (2D/3D) ----------------

def _standardize(x: np.ndarray) -> np.ndarray:
    return (x - x.mean(0, keepdims=True)) / (x.std(0, keepdims=True) + 1e-6)

def make_8gaussians_2d(n=20000, radius=4.0, std=0.1, seed=0) -> torch.Tensor:
    """Sample the classic 2D "8 Gaussians" toy density.

    Eight isotropic Gaussians (std ``std``) sit evenly on a circle of radius
    ``radius``; each of the ``n`` points picks a component uniformly. The
    result is standardized per coordinate and returned as a float32 [n, 2]
    tensor. Fully determined by ``seed`` (private NumPy RandomState).
    """
    rng = np.random.RandomState(seed)
    angles = [k * math.pi / 4 for k in range(8)]
    centers = np.array([[radius * math.cos(a), radius * math.sin(a)] for a in angles])
    which = rng.randint(0, 8, size=(n,))
    jitter = std * rng.randn(n, 2)
    pts = _standardize(centers[which] + jitter)
    return torch.tensor(pts, dtype=torch.float32)

def make_swissroll_2d(n=20000, tmin=0.5*math.pi, tmax=2.5*math.pi, noise=0.2, seed=0) -> torch.Tensor:
    """Sample a noisy 2D Swiss-roll spiral.

    The spiral parameter ``t`` is uniform on [tmin, tmax); each point is
    (t cos t, t sin t) plus independent Gaussian noise of std ``noise`` per
    coordinate. Returned standardized as a float32 [n, 2] tensor.
    """
    rng = np.random.RandomState(seed)
    t = tmin + (tmax - tmin) * rng.rand(n)
    coords = []
    # cos first, then sin: keeps the RNG draw order (x-noise before y-noise)
    for trig in (np.cos, np.sin):
        coords.append(t * trig(t) + noise * rng.randn(n))
    return torch.tensor(_standardize(np.stack(coords, axis=1)), dtype=torch.float32)

def make_swissroll_3d(n=40000, tmin=1.0*math.pi, tmax=3.0*math.pi, height=6.0, noise=0.25, seed=0) -> torch.Tensor:
    """Sample a noisy 3D Swiss roll.

    The spiral lies in the (x, z) plane, (t cos t, t sin t) with ``t`` uniform
    on [tmin, tmax). ``y`` is uniform on [0, height) along the roll's axis.
    Gaussian noise (std ``noise``) is added to every coordinate. Returned
    standardized as a float32 [n, 3] tensor.
    """
    rng = np.random.RandomState(seed)
    t = tmin + (tmax - tmin) * rng.rand(n)
    x = t * np.cos(t) + noise * rng.randn(n)
    y = height * rng.rand(n) + noise * rng.randn(n)
    z = t * np.sin(t) + noise * rng.randn(n)
    return torch.tensor(_standardize(np.column_stack((x, y, z))), dtype=torch.float32)

# ---------------- Tabular (UCI/BSDS) datasets ----------------

_ZENODO_BASE = "https://zenodo.org/record/1161203/files"
_UCI_MAP = {
    "power": "power.tar.gz",
    "gas": "gas.tar.gz",
    "hepmass": "hepmass.tar.gz",
    "miniboone": "miniboone.tar.gz",
    "bsds300": "bsds300.tar.gz",
}

def _ensure_dir(p: str):
    pathlib.Path(p).mkdir(parents=True, exist_ok=True)

class TabularSplits:
    """Plain holder for the train / validation / test tensors of one tabular dataset."""

    def __init__(self, train: torch.Tensor, val: torch.Tensor, test: torch.Tensor):
        self.train = train
        self.val = val
        self.test = test

def _download_and_extract(name: str, data_root: str) -> str:
    """Fetch and unpack the archive for tabular dataset ``name`` under ``data_root``.

    Downloading is skipped when ``<data_root>/<name>`` already exists. The
    downloaded ``.tar.gz`` is always removed afterwards, even if the download
    or extraction fails, so a truncated archive is never left behind.

    Args:
        name: key of ``_UCI_MAP`` (e.g. ``"power"``).
        data_root: cache directory; created if missing.

    Returns:
        The path ``<data_root>/<name>``.

    Raises:
        ValueError: ``name`` is not a known dataset.
        Any error raised by ``urllib.request.urlretrieve`` or ``tarfile``
        propagates unchanged.

    NOTE(review): the cache check assumes the archive unpacks into a top-level
    ``<name>/`` directory; if it does not, every call re-downloads. Confirm
    against the real archive layout.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if name not in _UCI_MAP:
        raise ValueError(f"unknown tabular dataset: {name}")
    _ensure_dir(data_root)
    fname = _UCI_MAP[name]
    url = f"{_ZENODO_BASE}/{fname}"
    tgz_path = os.path.join(data_root, fname)
    out_dir = os.path.join(data_root, name)
    if os.path.exists(out_dir):
        print(f"[tabular] found existing dir: {out_dir}")
        return out_dir

    print(f"[tabular] fetching {name} from {url}")
    try:
        urllib.request.urlretrieve(url, tgz_path)
        with tarfile.open(tgz_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: the "data" filter rejects
                # absolute paths, '..' components, escaping links and device files.
                tar.extractall(path=data_root, filter="data")
            else:
                # Older Python without extraction filters: unfiltered (trusts the archive).
                tar.extractall(path=data_root)
    finally:
        if os.path.exists(tgz_path):
            os.remove(tgz_path)
    return out_dir

@torch.no_grad()
def load_tabular(name: str, data_root: str) -> TabularSplits:
    """Load the train/val/test arrays of tabular dataset ``name`` as float32 tensors.

    The dataset directory is obtained from ``_download_and_extract``. For each
    split, the first existing file among a list of candidate ``.npy`` names is
    loaded with ``np.load``. All three splits are standardized with the
    *training* mean and std (std + 1e-6).

    Raises:
        FileNotFoundError: none of the candidate filenames exists for a split.
    """
    out_dir = _download_and_extract(name, data_root)

    def _first_existing(candidates):
        # return the array stored in the first candidate file present in out_dir
        for cand in candidates:
            path = os.path.join(out_dir, cand)
            if os.path.exists(path):
                return np.load(path)
        raise FileNotFoundError(f"cannot find any of {candidates} in {out_dir}")

    # candidate file names, tried in order, per split
    raw_train = _first_existing(["train.npy", "train_data.npy", f"{name}_train.npy"])
    raw_val = _first_existing(["validation.npy", "val.npy", "valid.npy", f"{name}_val.npy"])
    raw_test = _first_existing(["test.npy", f"{name}_test.npy"])

    mu = raw_train.mean(axis=0, keepdims=True)
    sd = raw_train.std(axis=0, keepdims=True) + 1e-6
    train, val, test = (
        torch.tensor((arr - mu) / sd, dtype=torch.float32)
        for arr in (raw_train, raw_val, raw_test)
    )
    return TabularSplits(train, val, test)

# ---------------- Cauchy N-D feature layer ----------------

class CauchyNDLayer(nn.Module):
    r"""Rational feature layer built from separable products of Cauchy-type poles.

    For each output feature o:

        f_o(x) = (1/sqrt(M)) * sum_{k=1..M} λ_{o,k} / Π_{d=1..D} (ξ_{d,k} - x̃_d + i δ_{d,k})

    Definitions:
    - x̃ = 2 (x - mid) / span maps each input range [lo, hi] (from ``ranges``,
      default [-3, 3]) onto [-1, 1].
    - λ = l_re + i·l_im is a complex weight of shape [out_features, M].
    - ξ (``xi_re``, [D, M]) are real pole locations, initialised uniform on [-1, 1].
    - δ = softplus(raw_delta) + delta_min ([D, M]) are positive pole widths,
      initialised to ``delta0``.

    ξ and δ are shared by all output features; only λ differs per output.
    ``eps`` is added (as a real number) to the complex product before inversion.

    - Input: x ∈ R^{B×D}
    - Output: Re f ∈ R^{B×out_features}, or [Re f, Im f] ∈ R^{B×2·out_features}
      when real_output=False.
    """
    def __init__(self, dim: int, m: int = 64, out_features: int = 1,
                 ranges: List[Tuple[float,float]] = None,
                 delta0: float = 0.08, delta_min: float = 1e-3,
                 init_std: float = 1e-2, real_output: bool = True, eps: float = 1e-8):
        super().__init__()
        assert dim >= 1
        self.dim, self.m, self.out_features = int(dim), int(m), int(out_features)
        self.real_output = bool(real_output)
        self.eps = float(eps)
        self.delta_min = float(delta_min)

        bounds = [(-3.0, 3.0)] * self.dim if ranges is None else ranges
        assert len(bounds) == self.dim
        centre = [0.5 * (lo + hi) for lo, hi in bounds]
        width = [max(hi - lo, 1e-6) for lo, hi in bounds]  # guard degenerate ranges
        self.register_buffer("in_mid", torch.tensor(centre, dtype=torch.float32).view(1, self.dim))
        self.register_buffer("in_span", torch.tensor(width, dtype=torch.float32).view(1, self.dim))

        weight_shape = (self.out_features, self.m)
        self.l_re = nn.Parameter(torch.randn(weight_shape) * init_std)
        self.l_im = nn.Parameter(torch.randn(weight_shape) * init_std)
        self.xi_re = nn.Parameter(torch.empty(self.dim, self.m).uniform_(-1.0, 1.0))
        # inverse softplus, so that softplus(raw) + delta_min == delta0 at init
        raw0 = math.log(math.expm1(max(delta0 - self.delta_min, 1e-6)))
        self.raw_delta = nn.Parameter(torch.full((self.dim, self.m), raw0))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Affinely map each input coordinate from [lo, hi] to [-1, 1]."""
        shifted = x - self.in_mid
        return 2.0 * shifted / self.in_span

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2 and x.size(1) == self.dim, f"expected [B,{self.dim}]"
        xn = self._norm(x)
        lam = torch.complex(self.l_re, self.l_im)  # [O, M]
        prod = torch.ones(xn.size(0), self.m, dtype=torch.complex64, device=xn.device)
        # accumulate the separable denominator Π_d (ξ_d - x_d + i δ_d), shape [B, M]
        for d in range(self.dim):
            pole = self.xi_re[d].view(1, -1)
            width = F.softplus(self.raw_delta[d]).view(1, -1) + self.delta_min
            prod = prod * ((pole - xn[:, [d]]) + 1j * width)
        fz = (1.0 / (prod + self.eps)) @ lam.transpose(0, 1)  # [B, O]
        fz = fz / math.sqrt(self.m)
        if self.real_output:
            return fz.real
        return torch.cat([fz.real, fz.imag], dim=1)

# ---------------- feature net for (s,t) ----------------

class STNetND(nn.Module):
    """Conditioner network mapping the conditioning coordinates to (s, t).

    Architecture: a front-end feature layer followed by ``layers - 1``
    additional Linear+activation stages (all of width ``hidden``), then two
    zero-initialised linear heads. The heads start at zero, so s = t = 0 and
    the coupling begins as the identity map.
    - s = tanh(s_head(h)) * s_clip  (log-scale, bounded to (-s_clip, s_clip))
    - t = t_head(h)                  (shift; also cached in ``self.last_t``)

    ``feat`` picks the front end / activation:
    - "cauchy": CauchyNDLayer (``m`` poles, input range [-3, 3]) + Tanh,
      xavier-initialised hidden Linear layers with Tanh.
    - "tanh": xavier-initialised Linear layers with Tanh.
    - "relu" / "silu" (alias "swish") / "gelu" / "elu": Kaiming-uniform
      (relu gain) initialised Linear layers with that activation.

    Raises:
        ValueError: unknown ``feat``.
    """
    def __init__(self, in_dim, out_dim, hidden=128, layers=2,
                 feat="cauchy", s_clip=2.0, m=64):
        super().__init__()
        self.s_clip = s_clip
        self.last_t = None

        def xavier_linear(n_in, n_out):
            layer = nn.Linear(n_in, n_out)
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
            return layer

        def kaiming_linear(n_in, n_out):
            # The author's choice: Kaiming(relu) init for every non-tanh activation, smooth ones included.
            layer = nn.Linear(n_in, n_out)
            nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
            nn.init.zeros_(layer.bias)
            return layer

        kaiming_acts = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU, "elu": nn.ELU}
        kind = feat.lower()
        if kind == "swish":  # alias
            kind = "silu"

        # Front-end layer and the activation class used after every stage.
        if kind == "cauchy":
            front = CauchyNDLayer(dim=in_dim, m=m, out_features=hidden, ranges=[(-3,3)]*in_dim)
            act_cls = nn.Tanh
        elif kind == "tanh":
            front = xavier_linear(in_dim, hidden)
            act_cls = nn.Tanh
        elif kind in kaiming_acts:
            front = kaiming_linear(in_dim, hidden)
            act_cls = kaiming_acts[kind]
        else:
            raise ValueError(f"unknown feat: {feat}")

        make_hidden = kaiming_linear if kind in kaiming_acts else xavier_linear
        stack = [front, act_cls()]
        for _ in range(layers - 1):  # extra hidden stages when layers >= 2
            stack.append(make_hidden(hidden, hidden))
            stack.append(act_cls())
        self.backbone = nn.Sequential(*stack)

        # Zero-initialised heads: the coupling starts exactly at the identity.
        self.s_head = nn.Linear(hidden, out_dim)
        self.t_head = nn.Linear(hidden, out_dim)
        for head in (self.s_head, self.t_head):
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, x):
        feats = self.backbone(x)
        scale = self.s_clip * torch.tanh(self.s_head(feats))
        shift = self.t_head(feats)
        self.last_t = shift  # read by the t-regulariser in the training loop
        return scale, shift



# ---------------- RealNVP (N-D) ----------------

class AffineCouplingND(nn.Module):
    """RealNVP affine coupling layer driven by a boolean mask.

    Coordinates where ``mask`` is True pass through unchanged and condition an
    ``STNetND`` that outputs (s, t) for the remaining coordinates:
    - forward:  u -> u * exp(s) + t,    log|det J| contribution = sum(s)
    - reverse:  u -> (u - t) * exp(-s), log|det J| contribution = -sum(s)
    """
    def __init__(self, mask: torch.Tensor, hidden=128, layers=2, feat="cauchy", s_clip=2.0, m=64):
        super().__init__()
        assert mask.dtype == torch.bool
        self.register_buffer("mask", mask)
        n_cond = int(mask.sum().item())
        n_free = int((~mask).sum().item())
        self.st = STNetND(n_cond, n_free, hidden, layers, feat=feat, s_clip=s_clip, m=m)

    def forward(self, x: torch.Tensor, logdet=None, reverse=False):
        """Apply the coupling (or its inverse); returns y, or (y, logdet + [B,1] term) if logdet is given."""
        keep = self.mask
        s, t = self.st(x[:, keep])
        free = x[:, ~keep]
        if reverse:
            free, ld = (free - t) * torch.exp(-s), -s
        else:
            free, ld = free * torch.exp(s) + t, s
        out = x.clone()
        out[:, ~keep] = free
        return out if logdet is None else (out, logdet + ld.sum(dim=1, keepdim=True))

class BaseDist:
    """Factorised (independent per dimension) latent base distribution of the flow.

    ``kind="gauss"``: standard normal N(0, I_D).
    ``kind="cauchy"``: standard Cauchy (location 0, scale 1) in every dimension.
    A plain object, not an ``nn.Module``; it has no parameters.

    Raises:
        ValueError: an unsupported ``kind`` is given to the constructor, or is
            found by ``sample`` / ``log_prob`` (e.g. if ``kind`` was reassigned).
    """
    _KINDS = ("gauss", "cauchy")

    def __init__(self, D: int, kind: str = "gauss"):
        # fail fast instead of at the first sample()/log_prob() call
        if kind not in self._KINDS:
            raise ValueError(kind)
        self.D = int(D); self.kind = kind

    def sample(self, n: int, device):
        """Draw ``n`` samples as an [n, D] tensor on ``device``."""
        if self.kind == "gauss":
            return torch.randn(n, self.D, device=device)
        if self.kind == "cauchy":
            # inverse-CDF sampling; clamping u keeps tan() finite in the tails
            u = torch.rand(n, self.D, device=device).clamp(1e-6, 1-1e-6)
            return torch.tan(math.pi*(u - 0.5))
        raise ValueError(self.kind)

    def log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """Return log p(z) summed over dimensions, shape [B, 1]."""
        if self.kind == "gauss":
            return -0.5*(z*z).sum(dim=1, keepdim=True) - 0.5*self.D*math.log(2*math.pi)
        if self.kind == "cauchy":
            return - torch.log1p(z*z).sum(dim=1, keepdim=True) - self.D*math.log(math.pi)
        # previously fell through and silently returned None
        raise ValueError(self.kind)

class FixedOrthogonal(nn.Module):
    """Fixed (non-trainable) orthogonal mixing x -> x @ Q with det(Q) = +1.

    Q comes from the QR decomposition of a seeded Gaussian matrix, drawn from
    a private ``torch.Generator`` so the global RNG is untouched. Since |det Q| = 1,
    the log-determinant is passed through unchanged. Intended to mitigate
    RealNVP "bridging" / "stringing" artefacts between modes.
    """
    def __init__(self, D: int, seed: int = 0):
        super().__init__()
        gen = torch.Generator()
        gen.manual_seed(seed)
        q, _ = torch.linalg.qr(torch.randn(D, D, generator=gen))
        # flip one column if needed so det(Q) = +1 (a rotation, not a reflection)
        if torch.linalg.det(q) < 0:
            q[:, 0] *= -1
        self.register_buffer("Q", q)

    def forward(self, x, logdet=None, reverse=False):
        """Rotate by Q (or by Q^T = Q^{-1} when ``reverse``); logdet is returned unchanged."""
        rot = self.Q.t() if reverse else self.Q
        y = x @ rot
        return y if logdet is None else (y, logdet)


class RealNVPND(nn.Module):
    """RealNVP flow on R^D: K affine couplings interleaved with fixed rotations.

    Coupling k conditions on the coordinates whose index parity equals k % 2
    (alternating checkerboard over dimensions). A ``FixedOrthogonal`` mixer
    (seed 1234 + k) follows every coupling except the last. With D = 1 the
    masks are all-True or all-False, so D >= 2 is needed for every coupling to
    transform something.

    ``fwd`` maps data -> latent with log|det J|, ``inv`` maps latent -> data,
    and ``log_prob`` returns log p(x) of shape [B] under ``base``
    ("gauss" or "cauchy").
    """
    def __init__(self, D, K=8, hidden=128, layers=2, feat="cauchy", s_clip=2.0, m=64, base="gauss"):
        super().__init__()
        self.D = int(D)
        dims = torch.arange(D)
        couplings = []
        for k in range(K):
            mask = dims % 2 == k % 2
            couplings.append(AffineCouplingND(mask, hidden, layers, feat, s_clip, m))
        self.layers = nn.ModuleList(couplings)
        # 1x1 orthogonal mixing after each coupling layer except the last
        self.mixers = nn.ModuleList([FixedOrthogonal(D, seed=1234 + k) for k in range(K - 1)])
        self.base = BaseDist(D, kind=base)

    def fwd(self, x):
        """Data -> latent; returns (z, logdet) with logdet of shape [B, 1]."""
        logdet = torch.zeros(x.size(0), 1, device=x.device)
        n_mix = len(self.mixers)
        for i, coupling in enumerate(self.layers):
            x, logdet = coupling(x, logdet, reverse=False)
            if i < n_mix:
                x, logdet = self.mixers[i](x, logdet, reverse=False)
        return x, logdet

    def inv(self, z):
        """Latent -> data: undo the stages of ``fwd`` in reverse order."""
        n_mix = len(self.mixers)
        for i in range(len(self.layers) - 1, -1, -1):
            if i < n_mix:
                z = self.mixers[i](z, reverse=True)
            z = self.layers[i](z, logdet=None, reverse=True)
        return z

    def log_prob(self, x):
        """Exact log-density log p(x) via change of variables; shape [B]."""
        z, logdet = self.fwd(x)
        return (self.base.log_prob(z) + logdet).squeeze(1)

# ---------------- utils: viz & training (toy & tabular) ----------------

def count_params(m: nn.Module):
    """Return the number of scalar entries in ``m``'s trainable (requires_grad) parameters."""
    total = 0
    for p in m.parameters():
        if p.requires_grad:
            total += p.numel()
    return total

def plot_density_2d(model: RealNVPND, data: torch.Tensor, out_png: str, n=220, lim=3.5, nsamp=2000, title=""):
    """Save a 2D figure: model density p(x) on an n×n grid over [-lim, lim]^2,
    overlaid with up to 2000 data points and up to 2000 model samples.

    Puts ``model`` in eval mode (and leaves it there), evaluates on the
    module-level ``device``, writes ``out_png`` at 240 dpi and prints its path.
    """
    model.eval()
    axis = np.linspace(-lim, lim, n)
    X, Y = np.meshgrid(axis, axis)
    pts = np.stack([X.reshape(-1), Y.reshape(-1)], 1)
    grid = torch.tensor(pts, dtype=torch.float32, device=device)
    with torch.no_grad():
        logp = model.log_prob(grid).view(n, n).cpu().numpy()
        smp = model.inv(model.base.sample(nsamp, device=device)).cpu().numpy()

    plt.figure(figsize=(5.4, 4.8))
    cs = plt.contourf(X, Y, np.exp(logp), levels=30)
    plt.colorbar(cs, shrink=0.85)
    shown = data[:2000]
    plt.scatter(shown[:, 0].cpu(), shown[:, 1].cpu(), s=6, alpha=0.25, label="data", edgecolors='none')
    plt.scatter(smp[:2000, 0], smp[:2000, 1], s=6, alpha=0.25, label="samples", edgecolors='none')
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.gca().set_aspect('equal', 'box')
    plt.legend(loc="upper right", frameon=False)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_png, dpi=240)
    plt.close()
    print(f"saved -> {out_png}")

def plot_scatter_3d_samples(model: RealNVPND, data: torch.Tensor, out_png: str, nsamp=10000, title=""):
    """Save a 3D scatter of the first ``nsamp`` data points and ``nsamp`` model samples.

    Puts ``model`` in eval mode (and leaves it there), samples on the
    module-level ``device``, writes ``out_png`` at 220 dpi and prints its path.
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection)
    model.eval()
    with torch.no_grad():
        smp = model.inv(model.base.sample(nsamp, device=device)).cpu().numpy()
    ref = data[:nsamp].cpu().numpy()

    fig = plt.figure(figsize=(6.2, 5.4))
    ax = fig.add_subplot(111, projection='3d')
    for pts, tag in ((ref, 'data'), (smp, 'samples')):
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=2, alpha=0.15, label=tag)
    ax.set_title(title)
    ax.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(out_png, dpi=220)
    plt.close()
    print(f"saved -> {out_png}")

@torch.no_grad()
def evaluate_nll(model: "RealNVPND", data: torch.Tensor, bs: int = 4096, target_device=None) -> float:
    """Return the mean negative log-likelihood of ``model`` over every row of ``data``.

    Rows are evaluated in mini-batches of ``bs``. Per-sample NLLs are summed
    and divided by the total row count, so a short final batch counts by its
    real size. Averaging per-batch means would over-weight it.

    Args:
        model: object with ``eval()`` and ``log_prob(x) -> [B]`` (e.g. RealNVPND).
            It is switched to eval mode and left there.
        data: [N, D] tensor; each batch is moved to the evaluation device.
        bs: evaluation batch size.
        target_device: device to evaluate on; defaults to the module-level ``device``.

    Returns:
        Mean NLL as a Python float, or ``nan`` when ``data`` has no rows.
    """
    dev = device if target_device is None else target_device
    model.eval()
    n_rows = data.size(0)
    if n_rows == 0:
        return float("nan")
    total = 0.0
    for start in range(0, n_rows, bs):
        batch = data[start:start + bs].to(dev)
        # accumulate in float64 to limit round-off across many batches
        total += (-model.log_prob(batch)).double().sum().item()
    return total / n_rows

def exp_lr(it, base_lr, every=5000, gamma=0.9):
    """Step-wise exponential decay: ``base_lr * gamma ** (it // every)``."""
    n_decays = it // every
    return base_lr * gamma ** n_decays

# ---------------- training entry ----------------

def _load_training_data(args):
    """Build the dataset named by ``args.dataset`` on ``device``.

    Returns:
        (tabular, train, val, test). For ``uci:<name>`` datasets ``tabular`` is
        True and val/test are the held-out splits. For the toy datasets val and
        test are None and ``train`` is the full generated sample.

    Raises:
        ValueError: unrecognised ``args.dataset``.
    """
    if args.dataset == "8g2d":
        return False, make_8gaussians_2d(n=args.n_data, seed=args.seed).to(device), None, None
    if args.dataset == "swiss2d":
        return False, make_swissroll_2d(n=args.n_data, seed=args.seed).to(device), None, None
    if args.dataset == "swiss3d":
        # the 3D roll is always generated with at least 40k points
        return False, make_swissroll_3d(n=max(args.n_data, 40000), seed=args.seed).to(device), None, None
    if args.dataset.startswith("uci:"):
        name = args.dataset.split(":", 1)[1].lower()
        splits = load_tabular(name, args.data_root)
        return True, splits.train.to(device), splits.val.to(device), splits.test.to(device)
    raise ValueError(args.dataset)


def _t_regularizer(model):
    """Sum over couplings of mean(t*t) from each conditioner's most recent forward pass (a tensor)."""
    total = 0.0
    for layer in model.layers:
        t = layer.st.last_t
        if t is not None:
            total = total + (t * t).mean()
    # as_tensor keeps an existing tensor (and its graph) as-is; wraps 0.0 if no layer had a cached t
    return torch.as_tensor(total)


def _append_results_csv(csv_path, args, best_nll, test_nll):
    """Append one result row to ``csv_path``; write the column header first if the file is new or empty."""
    line = (f"{args.dataset},{args.feat},{args.base},{args.blocks},"
            f"{args.layers},{args.width},{args.m},{args.lr},"
            f"{args.seed},{best_nll:.6f},{test_nll:.6f}\n")
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0
    with open(csv_path, "a", encoding="utf-8") as f:
        if write_header:
            f.write("dataset,feat,base,blocks,layers,width,m,lr,seed,val_best_nll,test_nll\n")
        f.write(line)


def train(args):
    """Train a RealNVPND flow on ``args.dataset`` and write outputs to ``args.outdir``.

    Training details:
    - Adam (weight_decay 1e-6) on minibatch NLL plus ``args.treg`` times the
      t-regulariser.
    - Step-wise exponential LR decay (x0.9 every 5000 steps).
    - Gradient-norm clipping at 5.0.

    Every ``args.log_every`` steps (and at the first and last step) the model
    is evaluated on the validation split (tabular) or the training sample
    (toy). The best-scoring weights are restored at the end. Training stops
    early, before any parameter update, if the loss becomes non-finite.

    Outputs:
    - Tabular: test NLL is printed and a row is appended to ``results.csv``.
    - Toy 2D/3D: a density/scatter plot is saved.
    - Always: the final ``state_dict`` is saved as ``realnvp_<dataset>_<feat>.pt``.
    """
    set_seed(args.seed)
    tabular, data, data_val, data_test = _load_training_data(args)
    D = data.size(1)

    model = RealNVPND(D=D, K=args.blocks, hidden=args.width, layers=args.layers,
                      feat=args.feat, s_clip=args.sclip, m=args.m, base=args.base).to(device)

    print(f"Device={device}, D={D}, dataset={args.dataset}, feat={args.feat}, base={args.base}, params={count_params(model):,}")

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)
    bs, steps = args.batch, args.steps
    best_nll, best_state = float("inf"), None
    # model selection: validation split for tabular data, the training sample for toys
    eval_target = data_val if tabular else data

    for it in range(1, steps + 1):
        lr_now = exp_lr(it, args.lr, every=5000, gamma=0.9)
        for pg in opt.param_groups:
            pg["lr"] = lr_now

        idx = torch.randint(0, data.size(0), (bs,), device=device)
        x = data[idx]

        nll = -model.log_prob(x).mean()
        # weak L2 penalty on the coupling shifts t (stabiliser)
        t_reg = args.treg * _t_regularizer(model)
        loss = nll + t_reg

        # Checked before backward/step so NaN/inf gradients never reach the parameters.
        if not torch.isfinite(loss).item():
            print(f"[WARN] loss exploded at it={it}; stopping early.")
            break

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        opt.step()

        if (it % args.log_every == 0) or it == 1 or it == steps:
            eval_nll = evaluate_nll(model, eval_target)
            model.train()  # evaluate_nll switches the model to eval mode
            print(f"it={it:6d} | lr={lr_now:.2e} | nll={eval_nll:.4f} | loss={loss.item():.4f} | treg={t_reg.item():.3e}")
            if eval_nll < best_nll:
                best_nll = eval_nll
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"[best] nll={best_nll:.4f}")

    out_dir = pathlib.Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)

    title = f"{args.dataset} | feat={args.feat} | base={args.base} | K={args.blocks}, L={args.layers}, W={args.width}"
    if tabular:
        test_nll = evaluate_nll(model, data_test)
        print(f"[test] nll={test_nll:.4f}")
        _append_results_csv(out_dir / "results.csv", args, best_nll, test_nll)
    elif D == 2:
        plot_density_2d(model, data, str(out_dir / f"density_{args.dataset}_{args.feat}.png"),
                        title=title)
    elif D == 3:
        plot_scatter_3d_samples(model, data, str(out_dir / f"scatter_{args.dataset}_{args.feat}.png"),
                                title=title)

    torch.save(model.state_dict(), str(out_dir / f"realnvp_{args.dataset}_{args.feat}.pt"))
    print("done.")

# ---------------- main ----------------

def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this training script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, default="8g2d",
                    help="toy: {8g2d, swiss2d, swiss3d} | tabular: uci:{power,gas,hepmass,miniboone,bsds300}")
    ap.add_argument("--feat", type=str, default="cauchy", choices=["cauchy","tanh","relu","elu","silu","gelu"])
    ap.add_argument("--base", type=str, default="gauss", choices=["gauss","cauchy"], help="latent base distribution")

    ap.add_argument("--blocks", type=int, default=6, help="number of coupling layers (K)")
    ap.add_argument("--layers", type=int, default=2, help="hidden layers in s,t nets (>=1)")
    ap.add_argument("--width", type=int, default=128, help="hidden width in s,t nets")
    # The pole locations/widths are shared by all outputs of a CauchyNDLayer; only the weights differ per output.
    ap.add_argument("--m", type=int, default=64, help="number of Cauchy poles per Cauchy feature layer (shared across its outputs)")
    ap.add_argument("--sclip", type=float, default=2.0, help="log-scale clamp via tanh*sclip")

    ap.add_argument("--n_data", type=int, default=30000)
    ap.add_argument("--batch", type=int, default=512)
    ap.add_argument("--steps", type=int, default=20000)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--treg", type=float, default=1e-4, help="L2 on t outputs (stabilizer)")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--log_every", type=int, default=1000)
    ap.add_argument("--outdir", type=str, default="flow_out")
    ap.add_argument("--data_root", type=str, default="data_tabular", help="cache/where to download UCI/BSDS")
    return ap


def parse_args(argv=None) -> argparse.Namespace:
    """Parse ``argv`` (default: ``sys.argv[1:]``) and reject values that break training.

    Each of the following must be >= 1:
    - ``--log_every``: 0 causes a modulo-by-zero in the logging check.
    - ``--blocks``: 0 leaves the optimizer with no parameters.
    - ``--m``: 0 divides by sqrt(0) in the Cauchy layer.
    - ``--layers``, ``--width``, ``--batch``.

    Exits via ``ArgumentParser.error`` (SystemExit) on invalid input.
    """
    ap = build_arg_parser()
    args = ap.parse_args(argv)
    for opt in ("blocks", "layers", "width", "m", "batch", "log_every"):
        value = getattr(args, opt)
        if value < 1:
            ap.error(f"--{opt} must be >= 1 (got {value})")
    return args


if __name__ == "__main__":
    train(parse_args())
