import os
import glob
import argparse
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import torch
import datasets
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import AutoModelForCausalLM
import pyarrow as pa
import pyarrow.ipc as ipc


# ============================================================
# SAFE LOCAL LOADER (Arrow IPC fallback)
#   - Handles broken HF dataset_info/features by reading .arrow directly
#   - Strips schema metadata so HF doesn't try to reconstruct features
#   - Supports split aliases via SPLIT_ALIASES (e.g., anom -> anom_backup,
#     once such an alias is added to the table)
# ============================================================

# Maps a requested split name to the ordered list of folder names that
# _resolve_split_folder tries under the dataset root. A split with no entry
# here falls back to a single candidate: its own name.
SPLIT_ALIASES: dict[str, list[str]] = {
    "anom": ["anom"],
}

def _resolve_split_folder(dataset_path: str, split: str) -> str | None:
    """Return the first usable on-disk folder for ``split``, or ``None``.

    Candidate folder names come from ``SPLIT_ALIASES`` (falling back to the
    split's own name) and are tried in order. A candidate is accepted when it
    is an existing directory that either loads via
    ``datasets.load_from_disk`` or contains at least one ``.arrow`` file
    anywhere beneath it.
    """
    for folder_name in SPLIT_ALIASES.get(split, [split]):
        candidate = os.path.join(dataset_path, folder_name)
        if not os.path.isdir(candidate):
            continue

        # Probe only: the loaded object is discarded. A failure here is not
        # fatal because the raw-shard check below can still accept the folder.
        try:
            datasets.load_from_disk(candidate)
            loadable = True
        except Exception:
            loadable = False

        arrow_pattern = os.path.join(candidate, "**", "*.arrow")
        if loadable or glob.glob(arrow_pattern, recursive=True):
            return candidate

    return None


def _find_arrow_shards_recursive(split_path: str) -> List[str]:
    """Return every ``*.arrow`` path at or below ``split_path``, deduplicated and sorted."""
    # NOTE(review): this matches any .arrow file, not only save_to_disk data
    # shards -- confirm no unrelated .arrow files live under the split folder.
    pattern = os.path.join(split_path, "**", "*.arrow")
    unique_paths = {*glob.glob(pattern, recursive=True)}
    return sorted(unique_paths)


def _read_arrow_ipc_table(path: str) -> pa.Table:
    """Read one Arrow IPC file into a table, trying both IPC layouts.

    The random-access *file* format is attempted first; if that fails for any
    reason the same path is re-read as an IPC *stream*.

    Raises:
        RuntimeError: if the path cannot be read in either layout. The
            stream-format error is attached as ``__cause__``.
    """
    try:
        with ipc.open_file(path) as f:
            return f.read_all()
    except Exception:
        # Deliberate best-effort: a file-format failure just means we fall
        # through to the stream reader below.
        pass

    try:
        with open(path, "rb") as f:
            reader = ipc.open_stream(f)
            return reader.read_all()
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying reader error
        # as the direct cause instead of an incidental "during handling" note.
        raise RuntimeError(f"Failed to read Arrow IPC (file/stream): {path} | {repr(e)}") from e


def _load_split_fallback_from_arrow(split_path: str) -> Dataset:
    """Build a ``Dataset`` straight from the ``.arrow`` shards under ``split_path``.

    Every shard is read with ``_read_arrow_ipc_table`` and has its schema
    metadata cleared. Shards that fail to read are reported with a warning
    and skipped; the remaining tables are concatenated (in sorted path order)
    and wrapped in a ``Dataset``.

    Raises:
        RuntimeError: if no ``.arrow`` file exists, or none of them is readable.
    """
    shard_paths = _find_arrow_shards_recursive(split_path)
    if len(shard_paths) == 0:
        raise RuntimeError(f"[safe_loader] no .arrow shards found under: {split_path}")

    loaded: List[pa.Table] = []
    n_failed = 0
    for shard in shard_paths:
        try:
            shard_table = _read_arrow_ipc_table(shard)
            loaded.append(shard_table.replace_schema_metadata({}))
        except Exception as err:
            n_failed += 1
            print(f"[safe_loader][WARN] failed shard: {shard} | {repr(err)}")

    if len(loaded) == 0:
        raise RuntimeError(
            f"[safe_loader] found .arrow but none readable: {split_path} (bad={n_failed}/{len(shard_paths)})"
        )

    if len(loaded) == 1:
        merged = loaded[0]
    else:
        merged = pa.concat_tables(loaded, promote_options="default")

    # Clear metadata once more on the merged table before handing it to HF.
    return Dataset(merged.replace_schema_metadata({}))


def safe_load_from_disk(path: str):
    """Load ``path`` with ``datasets.load_from_disk``, falling back to raw Arrow shards.

    Any exception from the regular loader is printed as a warning, after
    which the shards under ``path`` are read directly. Errors from the
    fallback itself propagate to the caller.
    """
    try:
        loaded = datasets.load_from_disk(path)
    except Exception as err:
        print(f"[safe_loader][WARN] load_from_disk failed: {path} | {repr(err)}")
        loaded = _load_split_fallback_from_arrow(path)
    return loaded


def safe_load_datasetdict(dataset_path: str, splits: List[str]) -> DatasetDict:
    """Load the requested ``splits`` from ``dataset_path`` as a ``DatasetDict``.

    First the root is loaded as a whole; if that yields a ``DatasetDict``
    containing at least one requested split, only those splits are returned.
    Otherwise each split is resolved to a folder and loaded on its own, and
    splits that cannot be found or loaded are skipped with a warning.

    Raises:
        RuntimeError: if not a single split could be loaded.
    """
    # Attempt 1: the root is a well-formed saved DatasetDict.
    try:
        root = datasets.load_from_disk(dataset_path)
        if isinstance(root, datasets.DatasetDict):
            picked = {name: root[name] for name in splits if name in root}
            if picked:
                return DatasetDict(picked)
    except Exception as err:
        print(f"[safe_loader][WARN] root load_from_disk failed: {dataset_path} | {repr(err)}")

    # Attempt 2: load split folders one at a time.
    per_split: dict[str, Dataset] = {}
    for name in splits:
        folder = _resolve_split_folder(dataset_path, name)
        if folder is None:
            print(f"[safe_loader][WARN] split folder not found for '{name}' (tried aliases). Skipping.")
            continue

        try:
            per_split[name] = safe_load_from_disk(folder)
        except Exception as err:
            print(f"[safe_loader][WARN] failed to load split '{name}' at {folder} | {repr(err)}. Skipping.")

    if len(per_split) == 0:
        raise RuntimeError(f"[safe_loader] could not load any splits from: {dataset_path}")
    return DatasetDict(per_split)

# ============================================================
# Basic helpers
# ============================================================
def to_1d_float(x) -> np.ndarray:
    """Convert ``x`` to a flat float32 array; ``None`` becomes an empty array."""
    if x is not None:
        as_array = np.asarray(x, dtype=np.float32)
        return as_array.reshape(-1)
    return np.zeros((0,), dtype=np.float32)


def has_nonempty_band(rec: Dict[str, Any], band: str) -> bool:
    """Tell whether ``rec["bands_data"][band]["target"]`` exists and has length > 0.

    Any failure while reaching or measuring the value (missing key, wrong
    type, no ``len``) counts as "no data" and yields ``False``.
    """
    try:
        target = rec["bands_data"][band]["target"]
        if target is None:
            return False
        return len(target) > 0
    except Exception:
        return False


def safe_zscore_1d(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Z-score a series without ever dividing by a degenerate standard deviation.

    The input is flattened to float32 and NaN/+-inf are replaced by 0 before
    the statistics are computed. When the standard deviation is non-finite or
    below ``eps`` the series is only mean-centred. An empty input is returned
    as an empty array.
    """
    values = to_1d_float(x)
    if values.size == 0:
        return values

    values = np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
    mean = float(values.mean())
    std = float(values.std())

    result = values - mean
    if np.isfinite(std) and std >= eps:
        result = result / std
    return result.astype(np.float32, copy=False)


def pad_or_trim_right_aligned(x: np.ndarray, ctx: int, pad_value: float = 0.0) -> Tuple[np.ndarray, int]:
    """Fit a series into a length-``ctx`` float32 window, data flush to the right.

    Longer series keep only their last ``ctx`` samples; shorter ones are
    left-padded with ``pad_value``. NaN/+-inf samples are replaced by
    ``pad_value`` as well.

    Returns:
        ``(window, n_valid)`` where ``window`` has shape ``(ctx,)`` and
        ``n_valid`` is the number of real samples at its right end (0 for an
        empty input).
    """
    # Written for ctx > 0: with ctx == 0 the slice ``[-ctx:]`` keeps everything.
    series = to_1d_float(x)
    window = np.full((ctx,), pad_value, dtype=np.float32)
    if series.size == 0:
        return window, 0

    series = np.nan_to_num(series, nan=pad_value, posinf=pad_value, neginf=pad_value).astype(np.float32, copy=False)

    tail = series[-ctx:] if series.size > ctx else series
    n_valid = int(tail.size)
    window[-n_valid:] = tail
    return window, n_valid

def upsert_column(ds: Dataset, name: str, values):
    """Return ``ds`` with column ``name`` set to ``values``, replacing any existing one."""
    # add_column is called on a dataset that no longer has the column, so an
    # existing column of the same name is dropped first.
    base = ds.remove_columns([name]) if name in ds.column_names else ds
    return base.add_column(name, values)


# ============================================================
# TimeMoE embedder (FIXED: use model's own learned input embedding)
# ============================================================
class TimeMoeEmbedder:
    """Turn right-aligned, padded series into one pooled embedding per series.

    The wrapped model is called with the raw series as float ``input_ids`` so
    that the model's own input embedding layer is used, and its last hidden
    state is pooled over time (``"last"`` token, or masked mean for any other
    ``pooling`` value).
    """

    def __init__(self, model, device: torch.device, dtype: torch.dtype, pooling: str):
        """Store the model (switched to eval mode) and the run configuration."""
        self.model = model.eval()
        self.device = device
        self.dtype = dtype
        self.pooling = pooling

    @classmethod
    def from_pretrained(cls, model_id: str, device: torch.device, dtype: torch.dtype, pooling: str):
        """Load ``model_id`` with ``AutoModelForCausalLM`` and wrap it.

        The whole model is placed on a single device via ``device_map``.
        """
        if device.type == "cuda":
            # Honour an explicit index such as "cuda:1" so the model lands on
            # the same GPU the inputs are moved to; a bare "cuda" keeps GPU 0.
            gpu_index = device.index if device.index is not None else 0
            device_map = {"": gpu_index}
        else:
            device_map = {"": "cpu"}
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=device_map,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        print("[TimeMoE] loaded:", type(model))
        return cls(model=model, device=device, dtype=dtype, pooling=pooling)

    @torch.no_grad()
    def embed(self, X: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        X: [B,T] float32 CPU ok
        lengths: [B] int, number of valid (right-aligned) steps per row
        returns: [B,H] float32 CPU
        """
        # TimeMoE expects float "input_ids" shaped [B,T,1] (it will call embed_layer(input_ids))
        X = X.to(self.device, dtype=self.dtype).unsqueeze(-1)  # [B,T,1]
        B, T, _ = X.shape

        lengths = lengths.to(self.device)
        ar = torch.arange(T, device=self.device).unsqueeze(0).expand(B, T)  # [B,T]
        thr = (T - lengths).unsqueeze(1)                                     # [B,1]
        attn_mask = (ar >= thr).to(torch.long)                               # [B,T], 1=valid (right-aligned)

        # Key point: pass the series as input_ids rather than inputs_embeds so
        # the model's own TimeMoeInputEmbedding is applied.
        out = self.model(
            input_ids=X,
            attention_mask=attn_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        h = out.hidden_states[-1]  # [B,T,H]

        if self.pooling == "last":
            # With right-aligned padding the last valid token is always at
            # position T-1 (assuming length > 0).
            emb = h[:, -1, :].to(torch.float32)
        else:
            # Pool in float32: summing over T in fp16/bf16 can overflow to inf
            # or lose precision before the division.
            h32 = h.to(torch.float32)
            m = attn_mask.to(torch.float32).unsqueeze(-1)  # [B,T,1]
            # clamp avoids 0/0 for rows whose mask is entirely zero.
            emb = (h32 * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)

        return emb.cpu()


# ============================================================
# Embedding routines
# ============================================================
def embed_one_band(
    ds: Dataset,
    band: str,
    ctx: int,
    batch_size: int,
    pad_value: float,
    pipe: TimeMoeEmbedder,
    normalize: str,
) -> Tuple[List[int], List[List[float]]]:
    """Embed the ``band`` series of every row of ``ds`` that has one.

    Rows whose ``bands_data[band]["target"]`` is missing or empty are
    skipped. Each remaining series is optionally z-scored, fitted into a
    right-aligned window of length ``ctx`` and embedded in batches of
    ``batch_size`` through ``pipe.embed``.

    Returns:
        ``(kept_idx, embs)``: the positions in ``ds`` of the rows that were
        embedded, and the matching embeddings as lists of floats.

    Raises:
        ValueError: if ``normalize`` is neither ``"none"`` nor ``"zscore"``.
    """
    # Fail fast instead of only noticing a bad mode at the first usable row.
    if normalize not in ("none", "zscore"):
        raise ValueError("normalize must be one of: none, zscore")

    kept_idx: List[int] = []
    embs: List[List[float]] = []

    # Pending batch: padded windows, their valid lengths, and row positions.
    buf_x: List[np.ndarray] = []
    buf_L: List[int] = []
    buf_i: List[int] = []

    def flush():
        """Embed the pending batch and append its results to the outputs."""
        if not buf_x:
            return
        X = torch.from_numpy(np.stack(buf_x, axis=0).astype(np.float32))  # [B,T]
        L = torch.tensor(np.asarray(buf_L, dtype=np.int64))
        E = pipe.embed(X, L).numpy()  # [B,H]
        for ii, ee in zip(buf_i, E):
            kept_idx.append(ii)
            embs.append(ee.astype(np.float32, copy=False).tolist())
        buf_x.clear()
        buf_L.clear()
        buf_i.clear()

    # Iterate the dataset directly rather than indexing ds[i] row by row.
    for i, rec in enumerate(ds):
        if not has_nonempty_band(rec, band):
            continue

        x = to_1d_float(rec["bands_data"][band]["target"])

        if normalize == "zscore":
            x = safe_zscore_1d(x)
        else:
            x = np.nan_to_num(x, nan=pad_value, posinf=pad_value, neginf=pad_value).astype(np.float32, copy=False)

        x_pad, L = pad_or_trim_right_aligned(x, ctx=ctx, pad_value=pad_value)
        if L <= 0:
            continue

        buf_x.append(x_pad)
        buf_L.append(L)
        buf_i.append(i)
        if len(buf_x) >= batch_size:
            flush()

    # Embed whatever is left in the final, possibly partial, batch.
    flush()
    return kept_idx, embs


def fuse_gr(
    g_idx: List[int],
    g_embs: List[List[float]],
    r_idx: List[int],
    r_embs: List[List[float]],
    method: str,
) -> Tuple[List[int], List[List[float]], List[List[float]], List[List[float]]]:
    """Combine g-band and r-band embeddings for the rows that have both.

    Returns:
        ``(common, out_g, out_r, out_f)``: the sorted indices present in both
        bands, followed by the g, r and fused embeddings for those indices
        (all as float32-valued lists). ``method`` is ``"concat"`` (g followed
        by r) or ``"avg"`` (element-wise mean).

    Raises:
        ValueError: for an unknown ``method``, or for ``"avg"`` when the two
            embeddings differ in shape. Both are only detected while fusing
            an actual common row.
    """
    g_lookup = dict(zip(g_idx, g_embs))
    r_lookup = dict(zip(r_idx, r_embs))
    shared = sorted(g_lookup.keys() & r_lookup.keys())

    g_rows: list = []
    r_rows: list = []
    fused_rows: list = []
    for key in shared:
        vec_g = np.asarray(g_lookup[key], dtype=np.float32)
        vec_r = np.asarray(r_lookup[key], dtype=np.float32)

        if method == "avg":
            if vec_g.shape != vec_r.shape:
                raise ValueError(f"avg fuse requires same dims, got {vec_g.shape} vs {vec_r.shape}")
            fused = (vec_g + vec_r) * 0.5
        elif method == "concat":
            fused = np.concatenate((vec_g, vec_r), axis=0)
        else:
            raise ValueError("fuse must be 'concat' or 'avg'")

        g_rows.append(vec_g.tolist())
        r_rows.append(vec_r.tolist())
        fused_rows.append(fused.astype(np.float32, copy=False).tolist())

    return shared, g_rows, r_rows, fused_rows


# ============================================================
# Main
# ============================================================
def main():
    """Command-line entry point: embed the g and r bands of every split and save.

    Parses the arguments, loads the requested splits, embeds both bands with
    the TimeMoE model, keeps only rows that have an embedding for both bands,
    attaches the per-band and fused embeddings as columns, and writes the
    resulting ``DatasetDict`` to ``--out_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", required=True, help="Local dataset root (DatasetDict or split folders).")
    parser.add_argument("--out_dir", required=True, help="Output DatasetDict root (save_to_disk).")

    parser.add_argument("--model_id", default="Maple728/TimeMoE-200M")
    parser.add_argument("--ctx", type=int, default=200)
    parser.add_argument("--splits", type=str, default="train,validation,test,anom")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--pad_value", type=float, default=0.0)

    parser.add_argument("--pooling", choices=["mean", "last"], default="mean")
    parser.add_argument("--fuse", choices=["concat", "avg"], default="concat")

    parser.add_argument("--normalize", choices=["none", "zscore"], default="none")
    parser.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
    args = parser.parse_args()

    stripped = (part.strip() for part in args.splits.split(","))
    split_list = [name for name in stripped if name]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Reduced precision is only used on CUDA; on CPU every choice means float32.
    if device.type == "cuda":
        torch_dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
            "fp32": torch.float32,
        }[args.dtype]
    else:
        torch_dtype = torch.float32

    banner = "========================================"
    settings = (
        ("[INFO] dataset_path :", args.dataset_path),
        ("[INFO] out_dir      :", args.out_dir),
        ("[INFO] model_id     :", args.model_id),
        ("[INFO] device       :", device),
        ("[INFO] dtype        :", torch_dtype),
        ("[INFO] ctx          :", args.ctx),
        ("[INFO] splits       :", split_list),
        ("[INFO] batch_size   :", args.batch_size),
        ("[INFO] pooling      :", args.pooling),
        ("[INFO] fuse         :", args.fuse),
        ("[INFO] normalize    :", args.normalize),
    )
    print(banner)
    for label, value in settings:
        print(label, value)
    print(banner)

    dd_in = safe_load_datasetdict(args.dataset_path, splits=split_list)

    pipe = TimeMoeEmbedder.from_pretrained(
        model_id=args.model_id,
        device=device,
        dtype=torch_dtype,
        pooling=args.pooling,
    )

    out_splits = {}
    for sp, ds in dd_in.items():
        print(f"\n=== [Split: {sp}] raw N={len(ds)} ===")

        # Embed g first, then r (tuple order), with identical settings.
        per_band = {
            band: embed_one_band(
                ds=ds,
                band=band,
                ctx=args.ctx,
                batch_size=args.batch_size,
                pad_value=args.pad_value,
                pipe=pipe,
                normalize=args.normalize,
            )
            for band in ("g", "r")
        }
        g_idx, g_embs = per_band["g"]
        r_idx, r_embs = per_band["r"]

        common_idx, gg, rr, ff = fuse_gr(g_idx, g_embs, r_idx, r_embs, method=args.fuse)
        print(f"[Split: {sp}] embedded g={len(g_idx)} r={len(r_idx)} | common={len(common_idx)}")

        # Keep only rows embedded in both bands, then attach the embeddings.
        # Each embedding is written under two column names.
        ds_sub = ds.select(common_idx)
        for column, values in (
            ("embeddings_g", gg),
            ("embeddings_r", rr),
            ("combined_embedding", ff),
            ("emb_g", gg),
            ("emb_r", rr),
            ("emb_gr_fused", ff),
        ):
            ds_sub = upsert_column(ds_sub, column, values)

        out_splits[sp] = ds_sub

    final = DatasetDict(out_splits)
    os.makedirs(args.out_dir, exist_ok=True)
    final.save_to_disk(args.out_dir)

    print("\n[SAVED] DatasetDict ->", args.out_dir)
    print("[EXPECT] dataset_dict.json at root, plus split folders.")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
