#!/usr/bin/env python3
import os
import argparse
import numpy as np
from tqdm import tqdm
import random

def load_npz_frames(npz_path):
    """Load every array stored in an ``.npz`` archive as an ordered frame list.

    Keys are ordered numerically when every key parses as an integer;
    otherwise the archive's own key order is kept.

    Args:
        npz_path: Path of the ``.npz`` archive to read.

    Returns:
        A ``(keys_sorted, frames)`` tuple: the archive keys in the order used,
        and the arrays stored under those keys, in the same order.
    """
    # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
    # which can execute arbitrary code. Only load archives from trusted
    # sources; switch to allow_pickle=False if the data never holds objects.
    #
    # The context manager closes the underlying zip file handle; arrays are
    # read lazily, so they must be materialised before leaving the block.
    with np.load(npz_path, allow_pickle=True) as data:
        keys = list(data.keys())
        try:
            keys_sorted = sorted(keys, key=int)
        except ValueError:
            # At least one key is not an integer string: keep stored order.
            keys_sorted = keys
        frames = [data[k] for k in keys_sorted]
    return keys_sorted, frames

def save_npz_frames(frames, out_path):
    """Write a sequence of frames to a compressed ``.npz`` archive.

    Each frame is stored as float32 under the string form of its index
    (``"0"``, ``"1"``, ...). Missing parent directories are created.

    Args:
        frames: Iterable of array-like frames.
        out_path: Destination path. May be a bare filename, in which case the
            archive is written to the current working directory.
    """
    save_dict = {
        str(i): np.asarray(frame, dtype=np.float32)
        for i, frame in enumerate(frames)
    }
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.savez_compressed(out_path, **save_dict)

def _tile_rows(arr, n_rows):
    """Return *arr* cyclically repeated (or truncated) along axis 0 to n_rows rows."""
    have = arr.shape[0]
    if have == n_rows:
        return arr
    if have == 0:
        raise ValueError("cannot tile a frame that contains no points")
    return arr[np.arange(n_rows) % have]


def synthesize_variant(frames, motion_scale_std=0.12, jitter_std=0.002):
    """Build a perturbed copy of a frame sequence by rescaling its motion.

    The first frame is copied unchanged. For each later frame the
    frame-to-frame difference is multiplied by a random scalar drawn from
    ``1 + N(0, motion_scale_std)``, element-wise noise drawn from
    ``N(0, jitter_std)`` is added, and the result is accumulated onto the
    previously synthesized frame. Random numbers come from the global
    ``np.random`` state.

    When consecutive frames differ in row count, the shorter operand is
    repeated cyclically along axis 0 so the shapes line up. As a consequence
    a synthesized frame never has fewer rows than the synthesized frame
    before it.

    Args:
        frames: Sequence of arrays indexed along axis 0.
        motion_scale_std: Standard deviation of the per-step motion scale.
        jitter_std: Standard deviation of the additive per-element noise.

    Returns:
        A list of float32 arrays, one per input frame; empty when ``frames``
        is empty.

    Raises:
        ValueError: If a frame with zero rows has to be tiled to match a
            non-empty one.
    """
    if len(frames) == 0:
        return []

    # np.array always copies, so the result never aliases the caller's data.
    prev_new = np.array(frames[0], dtype=np.float32)
    new_frames = [prev_new.copy()]

    for t in range(1, len(frames)):
        prev_src = np.asarray(frames[t - 1], dtype=np.float32)
        curr_src = np.asarray(frames[t], dtype=np.float32)

        # The displacement always has as many rows as the current source frame.
        delta = curr_src - _tile_rows(prev_src, curr_src.shape[0])

        # Draw order (scale first, then noise) is kept so that a given seed
        # yields the same random stream as before.
        scale = 1.0 + float(np.random.normal(0.0, motion_scale_std))
        noise = np.random.normal(0.0, jitter_std, delta.shape).astype(np.float32)
        delta2 = delta * scale + noise

        # Grow whichever operand is shorter to the larger row count.
        n_out = max(prev_new.shape[0], delta2.shape[0])
        curr = _tile_rows(prev_new, n_out) + _tile_rows(delta2, n_out)
        curr = curr.astype(np.float32, copy=False)

        new_frames.append(curr)
        # curr is never modified in place, so sharing it here is safe.
        prev_new = curr

    return new_frames

def process_one_file(root, rel_path, variants, out_name_suffix="__synth"):
    """Generate and save synthetic variants of one sample archive.

    The source is read from ``<root>/Sentences/<rel_path>``. Variant ``v`` is
    written next to it as ``<basename><out_name_suffix>_<v>.npz``. Failures
    are reported on stdout and skipped so that one bad sample does not abort
    a batch run.

    Args:
        root: Dataset root directory.
        rel_path: Path of the source archive relative to ``<root>/Sentences``.
        variants: Number of variants to generate.
        out_name_suffix: Text inserted between the base name and the variant
            index in output file names.

    Returns:
        Paths of the successfully written variants, relative to
        ``<root>/Sentences``. Empty if the source is missing, unreadable,
        or contains no frames.
    """
    in_full = os.path.join(root, 'Sentences', rel_path)
    if not os.path.exists(in_full):
        print(f"[WARN] source missing: {in_full}")
        return []

    try:
        _, frames = load_npz_frames(in_full)
    except Exception as e:
        print(f"[ERROR] loading {in_full}: {e}")
        return []

    if len(frames) == 0:
        # Nothing to synthesize from; avoid writing empty archives.
        print(f"[WARN] no frames in: {in_full}")
        return []

    created = []
    base_dir = os.path.dirname(rel_path)
    base_name = os.path.splitext(os.path.basename(rel_path))[0]

    for v in range(variants):
        new_rel = os.path.join(base_dir, f"{base_name}{out_name_suffix}_{v}.npz")
        out_full = os.path.join(root, 'Sentences', new_rel)

        # Synthesis can fail on malformed frames (e.g. incompatible shapes);
        # treat that like a load/save failure instead of crashing the run.
        try:
            new_frames = synthesize_variant(frames)
        except Exception as e:
            print(f"[ERROR] synthesizing from {in_full}: {e}")
            continue

        try:
            save_npz_frames(new_frames, out_full)
        except Exception as e:
            print(f"[ERROR] saving {out_full}: {e}")
            continue

        created.append(new_rel)

    return created

def main():
    """Command-line entry point: synthesize variants and write a merged list.

    Reads the sample list named by ``--input-list`` (optionally a random
    subset of it), generates ``--variants-per-sample`` variants for every
    listed sample, and writes a list containing the used original entries
    followed by the newly created ones to ``--out-list``.

    Raises:
        FileNotFoundError: If the input list file does not exist.
    """
    parser = argparse.ArgumentParser(description="Synthesize point-cloud variants for WatchYourMouth dataset")
    parser.add_argument('--data-path', default='/home/ubuntu-user/Documents/AtticusDon/', help='dataset root')
    parser.add_argument('--input-list', default='Sentences/Train.txt', help='Train list relative to data-path')
    parser.add_argument('--variants-per-sample', type=int, default=1, help='每个原样本生成多少变体')
    parser.add_argument('--subset-size', type=int, default=-1, help='从原始样本中随机抽取多少个用于合成（-1表示全部）')
    parser.add_argument('--out-list', default='Sentences/augment/Train_augmented.txt', help='输出合并后的训练列表')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Seed both generators: `random` drives subset selection, `np.random`
    # drives the synthesis noise.
    random.seed(args.seed)
    np.random.seed(args.seed)

    data_root = args.data_path
    input_list_path = os.path.join(data_root, args.input_list)
    if not os.path.exists(input_list_path):
        raise FileNotFoundError(f"Input list not found: {input_list_path}")

    with open(input_list_path, 'r') as f:
        lines = [ln.strip() for ln in f if ln.strip()]

    if 0 < args.subset_size < len(lines):
        lines = random.sample(lines, args.subset_size)
        print(f"[INFO] Subset mode: using {len(lines)} samples from original list.")

    new_entries = []
    print("[INFO] Starting synthesis ...")
    for rel in tqdm(lines, desc="Files"):
        # argparse stores '--variants-per-sample' as 'variants_per_sample';
        # 'args.variants_per-sample' would parse as a subtraction and fail.
        created = process_one_file(data_root, rel, args.variants_per_sample)
        new_entries.extend(created)

    out_list_path = os.path.join(data_root, args.out_list)
    out_dir = os.path.dirname(out_list_path)
    if out_dir:
        # dirname is '' when the list is written to the working directory.
        os.makedirs(out_dir, exist_ok=True)
    with open(out_list_path, 'w') as f:
        for rel in lines:
            f.write(rel + '\n')
        for rel in new_entries:
            f.write(rel + '\n')

    print(f"[DONE] Generated {len(new_entries)} new samples. Augmented list saved to: {out_list_path}")
    print("Use this list in training via --train-dataset", os.path.relpath(out_list_path, data_root))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
