
import os
import glob
import argparse
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import torch
import datasets
from datasets import Dataset, DatasetDict
import pyarrow as pa
import pyarrow.ipc as ipc

from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
from uni2ts.common.torch_util import packed_causal_attention_mask


# ============================================================
# Utilities
# ============================================================
def ceil_to_multiple(x: int, m: int) -> int:
    """Round ``x`` up to the nearest multiple of ``m`` (e.g. 200 -> 208 for m=16)."""
    whole_chunks, _ = divmod(x + m - 1, m)
    return whole_chunks * m


def to_1d_float(x) -> np.ndarray:
    """Convert ``x`` to a flat float32 array; ``None`` becomes an empty array."""
    source = [] if x is None else x
    return np.asarray(source, dtype=np.float32).reshape(-1)


def safe_zscore(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Standardize ``x`` to zero mean / unit std, ignoring NaNs in the statistics.

    The input is flattened to float32. If the (NaN-aware) std is non-finite or
    below ``eps``, the series is only mean-centered instead of divided by ~0.
    NaN entries stay NaN. An empty input is returned as an empty float32 array.
    """
    values = np.asarray(x, dtype=np.float32).reshape(-1)
    if not values.size:
        return values
    centered = values - np.nanmean(values)
    spread = np.nanstd(values)
    degenerate = not np.isfinite(spread) or spread < eps
    scaled = centered if degenerate else centered / spread
    return scaled.astype(np.float32, copy=False)


def pad_left(arr: np.ndarray, L: int, pad_value: float = 0.0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Right-align a 1-D series into a window of length ``L``.

    Series longer than ``L`` keep only their last ``L`` samples; shorter ones are
    left-padded with ``pad_value``. The returned values are always a fresh array
    (never a view of ``arr``).

    Args:
        arr: values to align (flattened to float32).
        L: output length, must be >= 0.
        pad_value: fill value for the padded prefix.

    Returns:
        (padded, is_pad): float32 array of length ``L`` and a bool mask that is
        True on padded positions.

    Raises:
        ValueError: if ``L`` is negative.
    """
    if L < 0:
        raise ValueError(f"pad_left: target length must be >= 0, got {L}")
    values = np.asarray(arr, dtype=np.float32).reshape(-1)
    # Explicit start index instead of ``values[-L:]``: the latter returns the
    # WHOLE array when L == 0.
    tail = values[max(values.size - L, 0):]
    n = tail.size

    out = np.full(L, pad_value, dtype=np.float32)
    is_pad = np.ones(L, dtype=bool)
    if n:
        out[L - n:] = tail
        is_pad[L - n:] = False
    return out, is_pad


def upsert_column(ds: Dataset, name: str, values):
    """Add column ``name`` with ``values`` to ``ds``, replacing any existing column of that name."""
    base = ds.remove_columns([name]) if name in ds.column_names else ds
    return base.add_column(name, values)


# ============================================================
# Minimal filter (NO min_length)
# ============================================================
def filter_has_gr_nonempty(x: Dict[str, Any]) -> bool:
    """
    Return True only if ``x["bands_data"]`` has both "g" and "r" entries whose
    "target" is present and non-empty.

    Any malformed record (missing keys, non-dict entries, unsized targets) is
    treated as a rejection rather than an error.
    """
    def _has_target(entry) -> bool:
        if entry is None:
            return False
        target = entry.get("target", None)
        return target is not None and len(target) > 0

    try:
        bands = x["bands_data"]
        return _has_target(bands.get("g", None)) and _has_target(bands.get("r", None))
    except Exception:
        return False


# ============================================================
# SAFE LOCAL LOADER (Arrow IPC fallback + split alias)
# ============================================================
# Requested split name -> folder names tried in order by _resolve_split_folder.
# Splits not listed here are looked up under their own name.
SPLIT_ALIASES: dict[str, list[str]] = dict(anom=["anom"])


def _resolve_split_folder(dataset_path: str, split: str) -> Optional[str]:
    """
    Find a usable on-disk folder for ``split`` under ``dataset_path``.

    Candidate folder names come from SPLIT_ALIASES (defaulting to ``[split]``)
    and are tried in order. The first existing directory that either opens with
    ``datasets.load_from_disk`` or contains at least one ``.arrow`` file
    (searched recursively) is returned; ``None`` if no candidate qualifies.
    """
    def _contains_arrow_shard(folder: str) -> bool:
        # Stop at the first match instead of listing every shard.
        pattern = os.path.join(folder, "**", "*.arrow")
        return next(glob.iglob(pattern, recursive=True), None) is not None

    for alias in SPLIT_ALIASES.get(split, [split]):
        candidate = os.path.join(dataset_path, alias)
        if not os.path.isdir(candidate):
            continue
        try:
            datasets.load_from_disk(candidate)
        except Exception:
            # Not loadable as a saved dataset; still usable if raw shards exist.
            if _contains_arrow_shard(candidate):
                return candidate
        else:
            return candidate
    return None


def _find_arrow_shards_recursive(split_path: str) -> List[str]:
    """Return every ``*.arrow`` file below ``split_path`` (any depth), de-duplicated and sorted."""
    pattern = os.path.join(split_path, "**", "*.arrow")
    unique_paths = set(glob.iglob(pattern, recursive=True))
    return sorted(unique_paths)


def _read_arrow_ipc_table(path: str) -> pa.Table:
    """
    Read ``path`` as an Arrow IPC table, trying the random-access *file* format
    first and the *stream* format second.

    Raises:
        RuntimeError: if neither format can be read.
    """
    def _as_ipc_file(p: str) -> pa.Table:
        with ipc.open_file(p) as file_reader:
            return file_reader.read_all()

    def _as_ipc_stream(p: str) -> pa.Table:
        with open(p, "rb") as handle:
            return ipc.open_stream(handle).read_all()

    try:
        return _as_ipc_file(path)
    except Exception:
        pass  # not an IPC file (or unreadable as one) -> try the stream format

    try:
        return _as_ipc_stream(path)
    except Exception as e:
        raise RuntimeError(f"Failed to read Arrow IPC (file/stream): {path} | {repr(e)}")


def _load_split_fallback_from_arrow(split_path: str) -> Dataset:
    """
    Build a Dataset directly from the raw ``.arrow`` shards under ``split_path``.

    Every shard found recursively is read as Arrow IPC. Unreadable shards are
    reported and skipped. Schema metadata is cleared on each shard and on the
    merged table before it is wrapped in a ``Dataset``.
    NOTE(review): presumably done so per-shard metadata cannot conflict; confirm.

    Raises:
        RuntimeError: if no shard exists, or none of them could be read.
    """
    shard_paths = _find_arrow_shards_recursive(split_path)
    if not shard_paths:
        raise RuntimeError(f"[safe_loader] no .arrow shards found under: {split_path}")

    readable: List[pa.Table] = []
    n_failed = 0
    for shard in shard_paths:
        try:
            readable.append(_read_arrow_ipc_table(shard).replace_schema_metadata({}))
        except Exception as e:
            n_failed += 1
            print(f"[safe_loader][WARN] failed shard: {shard} | {repr(e)}")

    if not readable:
        raise RuntimeError(
            f"[safe_loader] found .arrow but none readable: {split_path} (bad={n_failed}/{len(shard_paths)})"
        )

    if len(readable) == 1:
        merged = readable[0]
    else:
        # promote_options="default" lets pyarrow unify schemas that are not identical.
        merged = pa.concat_tables(readable, promote_options="default")
    return Dataset(merged.replace_schema_metadata({}))


def safe_load_from_disk(path: str) -> Dataset:
    """
    Load a saved dataset from ``path``.

    If ``datasets.load_from_disk`` fails, a warning is printed and the dataset is
    rebuilt from the raw ``.arrow`` shards instead. Errors from that fallback
    propagate to the caller.
    """
    try:
        loaded = datasets.load_from_disk(path)
    except Exception as e:
        print(f"[safe_loader][WARN] load_from_disk failed: {path} | {repr(e)}")
        loaded = _load_split_fallback_from_arrow(path)
    return loaded


def safe_load_datasetdict(dataset_path: str, splits: List[str]) -> DatasetDict:
    """
    Load the requested ``splits`` from ``dataset_path`` as a DatasetDict.

    Strategy:
      1. Try ``datasets.load_from_disk`` on the root. If it yields a DatasetDict
         that contains at least one requested split, return those splits. Every
         requested split the root lacks is reported with a warning.
      2. Otherwise load each split from its own sub-folder, resolved through
         ``_resolve_split_folder`` and loaded with ``safe_load_from_disk``, which
         can fall back to raw Arrow shards. Splits that cannot be found or
         loaded are skipped with a warning.

    Raises:
        RuntimeError: if none of the requested splits could be loaded.
    """
    # Case 1: DatasetDict root
    try:
        root = datasets.load_from_disk(dataset_path)
    except Exception as e:
        print(f"[safe_loader][WARN] root load_from_disk failed: {dataset_path} | {repr(e)}")
        root = None

    if isinstance(root, datasets.DatasetDict):
        found = {s: root[s] for s in splits if s in root}
        if found:
            # Report dropped splits, consistent with the per-split warnings of Case 2.
            missing = [s for s in splits if s not in root]
            if missing:
                print(f"[safe_loader][WARN] requested splits not in root DatasetDict: {missing}. Skipping.")
            return DatasetDict(found)

    # Case 2: split folders (with alias resolution)
    loaded: dict[str, Dataset] = {}
    for s in splits:
        sp = _resolve_split_folder(dataset_path, s)
        if sp is None:
            print(f"[safe_loader][WARN] split folder not found for '{s}' (tried aliases). Skipping.")
            continue
        try:
            loaded[s] = safe_load_from_disk(sp)
        except Exception as e:
            print(f"[safe_loader][WARN] failed to load split '{s}' at {sp} | {repr(e)}. Skipping.")

    if not loaded:
        raise RuntimeError(f"[safe_loader] could not load any splits from: {dataset_path}")
    return DatasetDict(loaded)


# ============================================================
# Build batch for ONE band
# ============================================================
def _prepare_band_series(
    values,
    ctx_user: int,
    ctx_pad: int,
    normalize: str,
    pad_value: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Truncate, optionally z-score and left-pad one series; return (values, is_pad, observed).

    Non-finite samples (NaN / ±inf) are reported as unobserved AND replaced by
    ``pad_value`` in the returned values, so none of them reaches the model.
    """
    series = to_1d_float(values)
    if series.size > ctx_user:
        # Keep the most recent ctx_user samples (explicit start index: correct for ctx_user == 0).
        series = series[series.size - ctx_user:]
    # Treat ±inf as missing so it cannot poison the NaN-aware statistics of safe_zscore.
    series = np.where(np.isfinite(series), series, np.float32(np.nan)).astype(np.float32, copy=False)
    if normalize == "zscore":
        series = safe_zscore(series)
    padded, is_pad = pad_left(series, ctx_pad, pad_value=pad_value)
    finite = np.isfinite(padded)
    observed = ~is_pad & finite
    # A mask alone does not neutralize NaN (NaN * 0 is NaN), so overwrite the values too.
    cleaned = np.where(finite, padded, np.float32(pad_value)).astype(np.float32, copy=False)
    return cleaned, is_pad, observed


def build_batch(
    ds: Dataset,
    indices: List[int],
    band: str,
    ctx_user: int,
    ctx_pad: int,
    normalize: str,
    pad_value: float,
    use_past_feat: bool,
):
    """
    Build a Moirai2 input batch for one band of the rows ``indices`` of ``ds``.

    For each row, ``rec["bands_data"][band]["target"]`` is flattened to float32,
    truncated to its last ``ctx_user`` samples and optionally z-scored. It is then
    left-padded to ``ctx_pad`` so the newest sample sits at the last time step.
    Non-finite samples are marked unobserved and replaced by ``pad_value``.

    Args:
        ds: dataset whose rows hold a ``bands_data`` mapping.
        indices: row indices forming the batch.
        band: band key inside ``bands_data`` (e.g. "g" or "r").
        ctx_user: maximum number of most-recent samples kept per series.
        ctx_pad: padded context length (expected >= ctx_user).
        normalize: "none" or "zscore" (per-series, NaN-aware).
        pad_value: fill value for padded / non-finite positions.
        use_past_feat: also build ``past_feat_dynamic_real`` from the band's
            ``past_feat_dynamic_real`` entry. It is flattened, so it is presumably
            a single feature aligned with the target; verify against the data.

    Returns:
        dict with ``item_ids`` (list[str]), ``past_target`` float32 [B, ctx_pad, 1],
        ``past_observed_target`` bool [B, ctx_pad, 1], ``past_is_pad`` bool
        [B, ctx_pad], and ``past_feat_dynamic_real`` /
        ``past_observed_feat_dynamic_real`` ([B, ctx_pad, 1], or None when
        ``use_past_feat`` is False).

    Raises:
        ValueError: if ``normalize`` is not "none" or "zscore".
    """
    # Validate once up front (also fails fast for an empty batch).
    if normalize not in ("none", "zscore"):
        raise ValueError("normalize must be one of: none, zscore")

    B = len(indices)
    past_target = np.full((B, ctx_pad), pad_value, dtype=np.float32)
    past_is_pad = np.ones((B, ctx_pad), dtype=bool)
    past_observed = np.zeros((B, ctx_pad), dtype=bool)
    past_feat = np.full((B, ctx_pad), pad_value, dtype=np.float32) if use_past_feat else None
    past_feat_obs = np.zeros((B, ctx_pad), dtype=bool) if use_past_feat else None

    item_ids: List[str] = []
    for row, idx in enumerate(indices):
        rec = ds[idx]
        item_ids.append(str(rec.get("item_id", idx)))
        band_rec = rec["bands_data"][band]

        past_target[row], past_is_pad[row], past_observed[row] = _prepare_band_series(
            band_rec["target"], ctx_user, ctx_pad, normalize, pad_value
        )

        if use_past_feat:
            feat, _, feat_observed = _prepare_band_series(
                band_rec.get("past_feat_dynamic_real"), ctx_user, ctx_pad, normalize, pad_value
            )
            past_feat[row] = feat
            # A feature value only counts as observed where the target is not padding.
            past_feat_obs[row] = feat_observed & ~past_is_pad[row]

    out = {
        "item_ids": item_ids,
        "past_target": torch.from_numpy(past_target).unsqueeze(-1),             # [B,T,1]
        "past_observed_target": torch.from_numpy(past_observed).unsqueeze(-1),  # [B,T,1]
        "past_is_pad": torch.from_numpy(past_is_pad),                           # [B,T]
        "past_feat_dynamic_real": None,
        "past_observed_feat_dynamic_real": None,
    }
    if use_past_feat:
        out["past_feat_dynamic_real"] = torch.from_numpy(past_feat).unsqueeze(-1)
        out["past_observed_feat_dynamic_real"] = torch.from_numpy(past_feat_obs).unsqueeze(-1)

    return out


@torch.no_grad()
def embed_batch(model: Moirai2Forecast, batch: Dict[str, torch.Tensor], pooling: str) -> torch.Tensor:
    """
    Encode one batch with the Moirai2 encoder and pool the past-target tokens.

    Steps:
      1. ``model._convert`` packs the inputs into patch tokens.
      2. ``module.scaler`` computes loc/scale from observed, non-prediction tokens.
      3. ``in_proj`` and ``encoder`` (under a packed causal attention mask) produce
         per-token representations.

    Only the first ``target_dim * context_token_length(patch)`` tokens are pooled.
    NOTE(review): assumes ``_convert`` places the past-target tokens first, which
    mirrors uni2ts internals; re-check when upgrading uni2ts.

    A token is "valid" when its sample id is non-zero and at least one of its
    observed-mask entries is set.

    Args:
        model: Moirai2 forecast wrapper; its first parameter's device is used.
        batch: dict produced by ``build_batch``. Tensor values are copied to the
            model device; the caller's dict is not modified.
        pooling: "last" -> representation at the last valid token position
            (position 0 if a row has no valid token); any other value -> mean
            over valid tokens.

    Returns:
        Tensor of shape [B, hidden] (dtype follows the encoder output).
    """
    device = next(model.parameters()).device
    module = model.module
    patch = module.patch_size

    inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

    target, obs, sid, tid, vid, pmask = model._convert(
        patch_size=patch,
        past_target=inputs["past_target"],
        past_observed_target=inputs["past_observed_target"],
        past_is_pad=inputs["past_is_pad"],
        feat_dynamic_real=None,
        observed_feat_dynamic_real=None,
        past_feat_dynamic_real=inputs["past_feat_dynamic_real"],
        past_observed_feat_dynamic_real=inputs["past_observed_feat_dynamic_real"],
    )

    # Prediction tokens are excluded from the scaling statistics.
    loc, scale = module.scaler(target, obs * ~pmask.unsqueeze(-1), sid, vid)
    x = torch.cat([(target - loc) / scale, obs.float()], dim=-1)

    h = module.in_proj(x)
    h = module.encoder(
        h,
        packed_causal_attention_mask(sid, tid),
        time_id=tid,
        var_id=vid,
    )

    per_var = model.context_token_length(patch)
    T = model.hparams.target_dim * per_var
    h = h[:, :T, :]

    valid = sid[:, :T].bool() & obs[:, :T, :].amax(dim=-1).bool()
    if pooling == "last":
        # Inputs are LEFT-padded, so valid tokens sit at the END of the sequence
        # (and missing values can leave gaps). The last valid token is therefore
        # the largest valid position, not ``valid.sum() - 1``.
        positions = torch.arange(valid.size(1), device=valid.device)
        last_pos = torch.where(valid, positions, torch.full_like(positions, -1)).amax(dim=1)
        idx = last_pos.clamp(min=0)
        emb = h[torch.arange(h.size(0), device=h.device), idx]
    else:
        m = valid.unsqueeze(-1).float()
        emb = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)

    return emb


# ============================================================
# Main
# ============================================================
def _embed_band(
    model: Moirai2Forecast,
    ds: Dataset,
    band: str,
    args: argparse.Namespace,
    ctx_pad: int,
    amp_context,
    emb_dim: int,
) -> np.ndarray:
    """Embed ``band`` for every row of ``ds`` in mini-batches; return float32 [len(ds), D].

    ``amp_context`` is a zero-argument callable returning the context manager
    (autocast or no-op) to run the encoder under.
    """
    n_rows = len(ds)
    chunks: List[np.ndarray] = []
    for start in range(0, n_rows, args.batch_size):
        batch = build_batch(
            ds=ds,
            indices=list(range(start, min(start + args.batch_size, n_rows))),
            band=band,
            ctx_user=args.ctx,
            ctx_pad=ctx_pad,
            normalize=args.normalize,
            pad_value=args.pad_value,
            use_past_feat=args.use_past_feat,
        )
        with amp_context():
            emb = embed_batch(model, batch, args.pooling)
        chunks.append(emb.detach().cpu().to(torch.float32).numpy())
    if not chunks:
        # Empty split: still return a 2-D array so the g/r concatenation works.
        return np.zeros((0, emb_dim), dtype=np.float32)
    return np.vstack(chunks).astype(np.float32, copy=False)


def _attach_embedding_columns(ds: Dataset, g: np.ndarray, r: np.ndarray) -> Dataset:
    """Add/overwrite the g, r and fused (g then r) embedding columns, each under two names."""
    fused = np.concatenate([g, r], axis=1).astype(np.float32, copy=False)
    values = {"g": g.tolist(), "r": r.tolist(), "fused": fused.tolist()}
    # Both naming schemes are written with identical data.
    for column, key in (
        ("embeddings_g", "g"),
        ("embeddings_r", "r"),
        ("combined_embedding", "fused"),
        ("emb_g", "g"),
        ("emb_r", "r"),
        ("emb_gr_fused", "fused"),
    ):
        ds = upsert_column(ds, column, values[key])
    return ds


def main():
    """
    CLI entry point: embed the g and r bands of each requested split with Moirai2.

    For every split:
      - rows without non-empty g and r targets are dropped;
      - each band is embedded in mini-batches;
      - six list columns are added: embeddings_g/emb_g, embeddings_r/emb_r and
        the g-then-r concatenation combined_embedding/emb_gr_fused.
    The resulting DatasetDict is written to --out_dir with ``save_to_disk``.
    """
    import contextlib

    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", required=True)
    ap.add_argument("--out_dir", required=True)

    ap.add_argument("--model_id", default="Salesforce/moirai-2.0-R-small")
    ap.add_argument("--ctx", type=int, default=200)
    ap.add_argument("--batch_size", type=int, default=256)

    ap.add_argument("--pooling", choices=["mean", "last"], default="mean")
    ap.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    ap.add_argument("--device", choices=["cuda", "cpu"], default="cuda")

    ap.add_argument("--splits", type=str, default="train,validation,test,anom")
    ap.add_argument("--normalize", choices=["none", "zscore"], default="none")
    ap.add_argument("--pad_value", type=float, default=0.0)
    ap.add_argument("--use_past_feat", action="store_true")
    args = ap.parse_args()

    # Non-positive sizes would otherwise produce mis-shaped windows (ctx) or a
    # crashing / empty batch loop (batch_size).
    if args.ctx <= 0:
        ap.error(f"--ctx must be > 0 (got {args.ctx})")
    if args.batch_size <= 0:
        ap.error(f"--batch_size must be > 0 (got {args.batch_size})")

    split_list = [s.strip() for s in args.splits.split(",") if s.strip()]

    if args.device == "cuda" and not torch.cuda.is_available():
        print("[WARN] --device cuda requested but CUDA is unavailable; using CPU")
    device = torch.device(args.device if (args.device == "cpu" or torch.cuda.is_available()) else "cpu")
    torch_dtype = {
        "bf16": (torch.bfloat16 if device.type == "cuda" else torch.float32),
        "fp16": (torch.float16 if device.type == "cuda" else torch.float32),
        "fp32": torch.float32,
    }[args.dtype]
    # bf16 is only chosen on CUDA (see mapping above); older GPUs reject it in autocast.
    if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("[WARN] bf16 is not supported on this GPU; using fp16 instead")
        torch_dtype = torch.float16

    # Autocast only matters for reduced precision on CUDA; fp32 / CPU runs skip it.
    use_autocast = device.type == "cuda" and torch_dtype != torch.float32

    def amp_context():
        if use_autocast:
            return torch.autocast(device_type="cuda", dtype=torch_dtype)
        return contextlib.nullcontext()

    print("========================================")
    print("[INFO] dataset_path :", args.dataset_path)
    print("[INFO] out_dir      :", args.out_dir)
    print("[INFO] model_id     :", args.model_id)
    print("[INFO] device       :", device)
    print("[INFO] dtype        :", torch_dtype)
    print("[INFO] ctx          :", args.ctx)
    print("[INFO] batch_size   :", args.batch_size)
    print("[INFO] pooling      :", args.pooling)
    print("[INFO] normalize    :", args.normalize)
    print("[INFO] pad_value    :", args.pad_value)
    print("[INFO] use_past_feat:", args.use_past_feat)
    print("[INFO] splits       :", split_list)
    print("========================================")

    dd_in = safe_load_datasetdict(args.dataset_path, splits=split_list)

    module = Moirai2Module.from_pretrained(args.model_id)
    ctx_pad = ceil_to_multiple(args.ctx, module.patch_size)
    print(f"[model] patch_size={module.patch_size} | ctx_user={args.ctx} | ctx_pad={ctx_pad} | d_model={module.d_model}")

    model = Moirai2Forecast(
        module=module,
        prediction_length=1,
        context_length=ctx_pad,
        feat_dynamic_real_dim=0,
        target_dim=1,
        past_feat_dynamic_real_dim=(1 if args.use_past_feat else 0),
    ).to(device).eval()

    os.makedirs(args.out_dir, exist_ok=True)

    out: Dict[str, Dataset] = {}
    for sp, ds in dd_in.items():
        before = len(ds)
        ds = ds.filter(filter_has_gr_nonempty)
        after = len(ds)
        kept = (after / before) if before > 0 else 0.0
        print(f"[data] {sp}: {after}/{before} kept ({kept:.3f})")

        g = _embed_band(model, ds, "g", args, ctx_pad, amp_context, module.d_model)
        r = _embed_band(model, ds, "r", args, ctx_pad, amp_context, module.d_model)

        ds2 = _attach_embedding_columns(ds, g, r)
        out[sp] = ds2
        print(f"[done] {sp} | out_N={len(ds2)}")

    final = DatasetDict(out)
    final.save_to_disk(args.out_dir)
    print(f"[saved] -> {args.out_dir}")
    print("[expect] dataset_dict.json created at root")


if __name__ == "__main__":
    # Run the embedding pipeline only when executed as a script, not on import.
    main()
