import os
import json
from helper import generate_prompt_for_baseline  

# Root of the dataset tree; build_entry reads the "metadata", "time series"
# and "captions" sub-directories beneath it.
DATA_DIR = "/home/ubuntu/projects/time_series_main/data"
# Directory that receives the generated JSON file.
OUTPUT_DIR = os.path.join(DATA_DIR, "json")
# Runs at import time so the output directory exists before main() writes to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Domain labels. main() uses each one as the "<domain>_" file-name prefix when
# collecting caption files and passes it on to build_entry.
domains = [
    "air quality",
    "crime",
    "border crossing",
    "demography",
    "road injuries",
    "covid",
    "co2",
    "diet",
    "online retail",
    "walmart",
    "agriculture",
]

def build_entry(name, domain):
    """Assemble one training record for the sample called *name*.

    Reads the sample's metadata JSON, time-series text file and caption text
    file from under ``DATA_DIR``, asks ``generate_prompt_for_baseline`` for
    the prompt, and packs prompt and caption into a two-turn conversation.

    Args:
        name: File stem shared by the sample's metadata, time-series and
            caption files.
        domain: Domain label forwarded to ``generate_prompt_for_baseline``.

    Returns:
        A dict with an ``"image"`` file name (``<name>.jpeg``) and a
        ``"conversations"`` list holding the human turn (``<image>`` followed
        by the prompt) and the gpt turn (the caption).

    Raises:
        FileNotFoundError: If any of the three input files is missing.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """
    meta_path = os.path.join(DATA_DIR, "metadata", f"{name}.json")
    ts_path = os.path.join(DATA_DIR, "time series", f"{name}.txt")
    cap_path = os.path.join(DATA_DIR, "captions", f"{name}.txt")

    # Explicit UTF-8 so the result does not depend on the machine's locale.
    with open(meta_path, encoding="utf-8") as f:
        metadata = json.load(f)
    with open(ts_path, encoding="utf-8") as f:
        # One value per line. Blank lines (e.g. an extra newline at the end of
        # the file) are dropped so they do not turn into empty items such as
        # "1, 2, " in the joined series.
        values = [
            stripped
            for stripped in (line.strip() for line in f)
            if stripped
        ]
    ts = ", ".join(values)
    with open(cap_path, encoding="utf-8") as f:
        caption = f.read().strip()

    prompt = generate_prompt_for_baseline(domain, metadata, ts)

    return {
        "image": f"{name}.jpeg",
        "conversations": [
            {"from": "human", "value": f"<image>\n{prompt.strip()}"},
            {"from": "gpt", "value": caption},
        ],
    }

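# Disabled sample filter, kept for reference only (nothing calls it): it would
# keep names whose trailing "_<number>" is not divisible by 3 and reject names
# without a numeric suffix.
# def is_valid(name):
#     try:
#         num = int(name.split("_")[-1])
#         return num % 3 != 0
#     except:
#         return False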

def main():
    """Build the training JSON from every caption file of every domain.

    Lists the ``captions`` directory under ``DATA_DIR`` once, groups the
    ``.txt`` files by their ``"<domain>_"`` prefix (in the order given by
    ``domains``, names sorted within a domain), turns each into a record via
    ``build_entry`` and writes the list to ``tsqa_train_all.json`` in
    ``OUTPUT_DIR``. Prints the number of records written.

    Raises:
        FileNotFoundError: If the captions directory, or any file that
            ``build_entry`` needs for a listed sample, does not exist.
    """
    captions_dir = os.path.join(DATA_DIR, "captions")

    # The listing does not depend on the domain, so read it a single time.
    # Only real ".txt" files are considered and only the suffix is stripped;
    # a blanket replace(".txt", "") would also mangle names that contain
    # ".txt" elsewhere and let unrelated files through.
    caption_stems = [
        os.path.splitext(fname)[0]
        for fname in os.listdir(captions_dir)
        if fname.endswith(".txt")
    ]

    train_data = []
    for domain in domains:
        prefix = f"{domain}_"
        # No hold-out filtering is applied: every matching sample is exported.
        names = sorted(stem for stem in caption_stems if stem.startswith(prefix))
        train_data.extend(build_entry(name, domain) for name in names)

    out_path = os.path.join(OUTPUT_DIR, "tsqa_train_all.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=2)

    print(f"Saved {len(train_data)} training samples.")


if __name__ == "__main__":
    main()