import pickle
import os
# Action names used to build the per-action file names below.
all_actions = [
    "basketball",
    "basketball_signal",
    "directing_traffic",
    "jumping",
    "running",
    "soccer",
    "walking",
    "washwindow",
]
# File-name prefixes, one per split.
name_list = ['cmu_train', 'cmu_val', 'cmu_test']

# Print the length of every "<prefix>_<action>_new.pt" file, with a separator
# line after each action's group of splits.
for a in all_actions:
    for n in name_list:
        path = f"/nlp/scr/jiangm/wproject/CGeoDM/data/cmu/{n}_{a}_new.pt"

        # NOTE: unpickling can execute arbitrary code; only read trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f)

        # The dataset is the first element of the pickled object.
        dataset = data[0]
        print(f"{n} {a} Dataset length:", len(dataset))
    print("---")



def load_dataset_part(file_path):
    """Unpickle ``file_path`` and return element 0 of the stored object.

    The file is opened in binary mode and closed again before returning.

    NOTE: ``pickle`` can execute arbitrary code while loading, so this must
    only be pointed at files from a trusted source.
    """
    with open(file_path, "rb") as handle:
        # Index inside the ``with`` block; the file is still closed on exit.
        return pickle.load(handle)[0]

def save_dataset(data_list, file_path):
    """Pickle ``data_list`` to ``file_path``, wrapped in a one-element tuple.

    The tuple wrapper means the stored object's element 0 is ``data_list``,
    which is what ``load_dataset_part`` reads back. An existing file at
    ``file_path`` is overwritten.
    """
    wrapped = (data_list,)
    with open(file_path, "wb") as handle:
        pickle.dump(wrapped, handle)

def merge_split(split, class_names, data_dir, output_dir, summary_lines):
    """Concatenate the per-class datasets of one split into a single file.

    For each entry of ``class_names`` (in order), the file
    ``cmu_<split>_<class_name>.pt`` in ``data_dir`` is loaded with
    ``load_dataset_part`` and its samples are appended to one list. A class
    whose file does not exist is reported and skipped. The combined list is
    then written with ``save_dataset`` to ``cmu_<split>_pretrain.pt`` in
    ``output_dir``; this happens even when no input file was found.

    ``summary_lines`` is accepted but neither read nor modified here.
    Returns ``None``.
    """
    merged_data = []
    # Per-class sample tally; recorded but not consumed anywhere in this function.
    per_class_counts = {}

    for class_name in class_names:
        path = os.path.join(data_dir, f"cmu_{split}_{class_name}.pt")

        if os.path.exists(path):
            print(f"Loading {path} ...")
            dataset = load_dataset_part(path)
            count = len(dataset)
            print(f" - Loaded {count} samples from {class_name} [{split}]")

            merged_data += dataset
            per_class_counts[class_name] = count
        else:
            print(f"⚠️ File not found: {path}")

    out_path = os.path.join(output_dir, f"cmu_{split}_pretrain.pt")
    save_dataset(merged_data, out_path)
    print(f"✅ Saved {split} set: {len(merged_data)} samples to {out_path}\n")


def main():
    """Merge the pretraining classes for every split, then report merged sizes.

    For each of the train/val/test splits, ``merge_split`` concatenates the
    per-class files found in ``data_dir`` into ``cmu_<split>_pretrain.pt``
    under ``output_dir``. Each merged file is then read back from that same
    location and its length is printed as a sanity check.
    """
    data_dir = "/nlp/scr/jiangm/wproject/CGeoDM/data/cmu"  # input directory
    output_dir = "/nlp/scr/jiangm/wproject/CGeoDM/data/cmu/"  # output directory
    os.makedirs(output_dir, exist_ok=True)

    # The "_new" suffix is carried in the class names because merge_split
    # builds file names as cmu_<split>_<class_name>.pt.
    pretrain_classes = ["washwindow_new", "directing_traffic_new", "basketball_signal_new"]
    splits = ["train", "val", "test"]
    summary_lines = []  # required positional argument of merge_split

    for split in splits:
        merge_split(split, pretrain_classes, data_dir, output_dir, summary_lines)

    # Read each merged file back from where merge_split wrote it, so the
    # check stays correct if output_dir is ever changed.
    for split in splits:
        merged_path = os.path.join(output_dir, f"cmu_{split}_pretrain.pt")
        dataset = load_dataset_part(merged_path)
        print(f"{split} Dataset length:", len(dataset))
    print("---")

# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
