import os
import warnings

import numpy as np

# Raw sequential human labels and the location the derived pairs are written to.
HUMAN_SEQ_PATH = "scripts/human-seq.txt"
ENV_NAME = "button-press-topdown-v2"
OUTPUT_ROOT = "pair"
# The same derived pairs are written once per HUMAN-00 .. HUMAN-{NUM_SEEDS-1} directory.
NUM_SEEDS = 10
# Each record refers to segments stored as (start, start + SEGMENT_LENGTH).
SEGMENT_LENGTH = 25
# Raw label z in the file -> preference value mu. 0.5 is treated as a tie by
# build_sequences; NOTE(review): which of 0.0 / 1.0 means "s1 preferred" is
# not visible in this script — confirm against the labelling tool.
LABEL_TO_MU = {1: 0.0, 2: 0.5, 3: 1.0}
# Structured dtype of the saved array: two (start, end) segments plus the label.
SEQ_DTYPE = [("s0", "i4", (2,)), ("s1", "i4", (2,)), ("mu", "f4")]


def parse_line(line, segment_length=SEGMENT_LENGTH):
    """Parse one whitespace-separated ``x y z`` record.

    Args:
        line: Text line containing three integers.
        segment_length: Length added to each start index to form a segment.

    Returns:
        ``((x, x + segment_length), (y, y + segment_length), mu)`` where ``mu``
        is ``LABEL_TO_MU[z]``.

    Raises:
        ValueError: If the line does not hold exactly three integers or ``z``
            is not a known label.
    """
    x, y, z = map(int, line.split())
    mu = LABEL_TO_MU.get(z)
    if mu is None:
        raise ValueError(f"Unexpected z value: {z}")
    return (x, x + segment_length), (y, y + segment_length), mu


def load_pairs(path=HUMAN_SEQ_PATH, segment_length=SEGMENT_LENGTH):
    """Read every labelled record from ``path``.

    Blank or whitespace-only lines (e.g. a trailing newline) are skipped.

    Raises:
        ValueError: On a malformed record; the message is prefixed with
            ``path:line_number`` so the bad row can be located.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                pairs.append(parse_line(line, segment_length))
            except ValueError as err:
                raise ValueError(f"{path}:{line_no}: {err}") from err
    return pairs


def build_sequences(data_list):
    """Expand a chain of sequential labels into all implied preference pairs.

    Records after the first are assumed to compare a segment from the current
    tie group against a new segment ``s1``; only their ``s1`` is used. Two
    running groups are kept:

    * ``last``   - the most recent segment plus anything tied with it.
    * ``before`` - segments already ordered relative to ``last`` in the
      current ``direction``.

    Args:
        data_list: Sequence of ``(s0, s1, mu)`` records as produced by
            ``parse_line``.

    Returns:
        List of ``(s0, s1, mu)`` tuples, starting with the first record.
    """
    seq_data_list = []
    direction = None  # label (0.0 / 0.5 / 1.0) of the current run of records
    before = []
    last = []

    for index, (s0, s1, mu) in enumerate(data_list):
        if direction is None:
            direction = mu
            seq_data_list.append((s0, s1, mu))
            if mu == 0.5:
                # Tie: both segments start the same group.
                last.extend([s0, s1])
            else:
                before.append(s0)
                last.append(s1)
            continue

        if s0 not in last:
            # s0 is discarded below, so a non-chained record would silently
            # produce wrong pairs; surface it instead of failing.
            warnings.warn(
                f"record {index}: s0={s0} is not in the current group {last}; "
                "derived pairs assume consecutive records are chained",
                stacklevel=2,
            )

        if mu == 0.5:
            # s1 ties the last group, so it inherits last's ordering against before.
            for traj in before:
                seq_data_list.append((traj, s1, direction))
            for traj in last:
                seq_data_list.append((traj, s1, 0.5))
            last.append(s1)
        elif mu == direction or direction == 0.5:
            # Same direction as the previous records (or the first strict label
            # after an initial run of ties): s1 is ordered against everything.
            for traj in before:
                seq_data_list.append((traj, s1, mu))
            for traj in last:
                seq_data_list.append((traj, s1, mu))
            before.extend(last)
            last = [s1]
            if direction == 0.5:
                direction = mu
        else:
            # Direction reversed: s1 is only known relative to the last group,
            # not to the older `before` segments.
            for traj in last:
                seq_data_list.append((traj, s1, mu))
            before = last
            last = [s1]
            direction = mu

    return seq_data_list


def save_sequences(seq_data_list, env_name=ENV_NAME, num_seeds=NUM_SEEDS,
                   root=OUTPUT_ROOT):
    """Write the pairs as ``data`` in ``human-seq.npz`` under each seed dir.

    Files go to ``<root>/<env_name>/HUMAN-<seed:02d>/train/human-seq.npz`` for
    ``seed`` in ``range(num_seeds)``.

    Returns:
        The structured array that was saved (dtype ``SEQ_DTYPE``).
    """
    array = np.array(seq_data_list, dtype=SEQ_DTYPE)
    for seed in range(num_seeds):
        save_dir = os.path.join(root, env_name, f"HUMAN-{seed:02d}", "train")
        os.makedirs(save_dir, exist_ok=True)
        np.savez(os.path.join(save_dir, "human-seq.npz"), data=array)
    return array


def main():
    """Load the label file, expand it into sequences and save them."""
    data_list = load_pairs()
    print(f"Total pairs: {len(data_list)}")

    seq_data_list = build_sequences(data_list)
    print(f"Total sequences: {len(seq_data_list)}")

    save_sequences(seq_data_list)


if __name__ == "__main__":
    main()
