# %%
import json
from tqdm import tqdm
import re

import os

# Every dataset file lives in ../data/env_dataset relative to this script.
_dataset_dir = os.path.join(os.path.dirname(__file__), "..", "data", "env_dataset")

# Input: the full dataset that is split further down.
load_path = os.path.join(_dataset_dir, "dataset_8192.jsonl")

# Outputs: one JSONL file per split.
train_split_path = os.path.join(_dataset_dir, "trainset_8192.jsonl")
val_split_path = os.path.join(_dataset_dir, "valset_8192.jsonl")

# Load the JSONL dataset: one JSON document per line, collected into `data`.
print("Start loading data...")
data = []
# Pin the encoding so decoding does not depend on the platform's locale
# default (e.g. cp1252 on Windows), which can fail or mis-decode UTF-8 text.
with open(load_path, 'r', encoding="utf-8") as f:
    for line_no, line in enumerate(tqdm(f, desc="Loading data"), start=1):
        # Blank lines (typically a trailing newline at EOF) are not records,
        # so skip them instead of reporting a spurious decode error.
        if not line.strip():
            continue
        try:
            data.append(json.loads(line))
        except json.JSONDecodeError as e:
            # Best-effort load: report the malformed line and keep going.
            # The decoder's own position is relative to this single line, so
            # the line number within the file is reported alongside it.
            print(f"Error decoding JSON on line {line_no}: {e}")
print("Data loaded successfully. Total items:", len(data))
# %%

# Deterministic train/validation split keyed on subject_id: the split depends
# only on the id, so items sharing a subject_id always land in the same split.
train_split_ratio = 0.8

train_data = []
val_data = []

# The id is read from the first `"subject_id": <digits>` occurrence in the
# content of the second message of each item.
subject_id_pattern = re.compile(r'"subject_id":\s*(\d+)')
# subject_id % 100 lies in [0, 100); buckets below this cutoff go to training.
train_cutoff = 100 * train_split_ratio

for item in tqdm(data, desc="Splitting data"):
    found = subject_id_pattern.search(item['messages'][1]['content'])
    if found is None:
        # Items without a recognizable subject_id end up in neither split.
        print(f"Warning: No subject_id found in item: {item['messages'][1]['content']}")
        continue
    bucket = int(found.group(1)) % 100
    destination = train_data if bucket < train_cutoff else val_data
    destination.append(item)

print(f"Training data size: {len(train_data)}")
print(f"Validation data size: {len(val_data)}")
# %%
# Write each split as JSONL (one JSON document per line): training set first,
# then validation set, announcing progress before and after each file.
for start_msg, done_msg, out_path, records in (
    ("Saving trainset...", "Trainset saved successfully.", train_split_path, train_data),
    ("Saving validation set...", "Validation set saved successfully.", val_split_path, val_data),
):
    print(start_msg)
    with open(out_path, 'w') as out_file:
        out_file.writelines(json.dumps(record) + '\n' for record in records)
    print(done_msg)

# %%
