import numpy as np
from tqdm.auto import tqdm
from datasets import load_from_disk
from ASTROMER.models import SingleBandEncoder
from datasets.utils.logging import enable_progress_bar
import argparse


if __name__ == "__main__":
    # Two mandatory string options: the dataset to read and the location
    # where the dataset with embeddings is saved.
    parser = argparse.ArgumentParser()
    for flag in ("--dataset_path", "--embs_path"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    # Switch on the progress bars of the `datasets` library.
    enable_progress_bar()

    # --- 1) CONFIG & MODEL INIT ---
    BANDS         = ["g", "r"]     # bands to process; add "i" if desired
    DURATION      = 200            # number of rows in every window
    ENC_BATCH     = 1024           # encoding batch size; tune to fit CPU/GPU memory
    PREPROC_PROCS = 16             # worker processes used for window building
    SPLIT_PATH    = args.dataset_path
    OUT_PATH      = args.embs_path

    # Encoder obtained through from_pretraining("macho").
    model = SingleBandEncoder().from_pretraining("macho")

    # --- 2) LOAD SPLITS ---
    splits = load_from_disk(SPLIT_PATH)

    # --- 3) PREPROCESS: BUILD WINDOWS FOR EACH BAND ---
    def make_windows(example, band, duration=None):
        """Attach a fixed-length, mean-centred window for one band to ``example``.

        The band's ``mjd``, ``target`` and ``past_feat_dynamic_real`` sequences
        are stacked column-wise into an ``(L, 3)`` float32 array. Only the last
        ``duration`` rows are kept; each column is centred on the mean of those
        real observations, and shorter series are zero-padded at the end.

        The mean is computed over the real observations only. Computing it
        after padding would let the padding zeros bias the centring and would
        turn the padded rows into non-zero values.

        Args:
            example: Row holding a ``bands_data`` mapping from band name to a
                dict with the three sequences named above.
            band: Name of the band to process.
            duration: Number of rows in the output window. When omitted, the
                script-level ``DURATION`` is used.

        Returns:
            The same ``example`` object, with ``windows_<band>`` set to a
            ``(duration, 3)`` float32 array, or to ``None`` when the band has
            no data.
        """
        if duration is None:
            duration = DURATION

        key = f"windows_{band}"
        bd = example["bands_data"].get(band)
        if bd is None:
            example[key] = None
            return example

        mjd = np.asarray(bd["mjd"],                    dtype=np.float32)
        mag = np.asarray(bd["target"],                 dtype=np.float32)
        err = np.asarray(bd["past_feat_dynamic_real"], dtype=np.float32)
        arr = np.stack([mjd, mag, err], axis=1)        # (L, 3)

        # Keep at most the last `duration` observations.
        tail = arr[max(arr.shape[0] - duration, 0):]
        n_obs = tail.shape[0]

        # Start from zeros so that any padding rows remain exactly zero.
        win = np.zeros((duration, 3), dtype=np.float32)
        if n_obs:
            # Guarded: the mean of an empty series is undefined (NaN).
            win[:n_obs] = tail - tail.mean(axis=0, keepdims=True)

        example[key] = win
        return example

    for band in tqdm(BANDS, desc="Building windows for bands"):
        # One pass over the dataset per band; each pass adds that band's
        # window column. The band is forwarded to make_windows by keyword.
        splits = splits.map(
            make_windows,
            fn_kwargs={"band": band},
            num_proc=PREPROC_PROCS,
            desc=f"make_windows[{band}]"
        )

    # --- 4) ENCODE WINDOWS (no avg) ---
    def encode_batch(batch, band):
        """Encode the ``windows_<band>`` column of one batch with ``model``.

        Rows whose window is ``None`` are skipped by the encoder and receive
        an all-zero embedding, so the output has one entry per input row.

        Args:
            batch: Batched rows; must provide ``windows_<band>`` and
                ``item_id`` columns.
            band: Name of the band whose windows are encoded.

        Returns:
            A dict with the single key ``embeddings_<band>`` holding a nested
            list with one ``DURATION x D`` entry per row of the batch.
        """
        column = batch[f"windows_{band}"]
        item_ids = batch["item_id"]

        # Positions inside the batch of the rows that carry a window.
        present = [pos for pos, win in enumerate(column) if win is not None]

        if present:
            arrays = [np.array(column[pos], dtype=np.float32) for pos in present]
            # All present windows are sent as one encoder batch and the first
            # returned element is used.
            # NOTE(review): assumed to be shaped (V, DURATION, D) — confirm
            # against the encoder's documentation.
            encoded = model.encode(
                arrays,
                oids_list=[item_ids[pos] for pos in present],
                batch_size=len(arrays),
                concatenate=False
            )[0].cpu().numpy()
            width = encoded.shape[2]
        else:
            # Nothing to encode: take the embedding width from the model.
            encoded = None
            width = model.hidden_size

        # Zero-filled buffer; rows without a window keep their zeros.
        result = np.zeros((len(column), DURATION, width), dtype=np.float32)
        for src, dst in enumerate(present):
            result[dst] = encoded[src]

        return {f"embeddings_{band}": result.tolist()}

    for band in tqdm(BANDS, desc="Encoding bands"):
        # Encoding stays in a single process. The window column of the band
        # is dropped once its embeddings have been produced.
        splits = splits.map(
            encode_batch,
            fn_kwargs={"band": band},
            batched=True,
            batch_size=ENC_BATCH,
            num_proc=1,
            remove_columns=[f"windows_{band}"],
            desc=f"encode[{band}]"
        )

    # --- 5) SAVE FINAL DATASET ---
    splits.save_to_disk(OUT_PATH)

    # Report the number of rows in each split.
    split_sizes = {}
    for split_name in splits:
        split_sizes[split_name] = len(splits[split_name])
    print("Done. Split sizes:", split_sizes)
