import shutil
from tqdm import tqdm
from pathlib import Path
from plot.datasets.kitti import read_kitti_calib


def read_mapping(root):
    """Load the KITTI detection-image <-> raw-sequence frame mapping.

    Reads two files from ``<root>/ImageSets``:

    * ``train_rand.txt``: its first line is a comma-separated list of
      1-based indices into ``train_mapping.txt``. Entry ``i`` gives the
      mapping line for detection image ``i``.
    * ``train_mapping.txt``: one whitespace-separated
      ``<ignored> <raw_folder> <frame_idx>`` entry per line.

    Args:
        root: Directory containing ``ImageSets``. May be a ``str`` or ``Path``.

    Returns:
        A tuple ``(imgidx_to_rawidx, rawidx_to_imgidx, imgidx_to_rawframe,
        rawidx_to_rawframe)``:

        * ``imgidx_to_rawidx``: image index -> 0-based mapping line index.
        * ``rawidx_to_imgidx``: the inverse mapping (last occurrence wins).
        * ``imgidx_to_rawframe``: image index -> ``[raw_folder, frame_idx]``,
          with ``frame_idx`` converted to ``int``.
        * ``rawidx_to_rawframe``: mapping line index ->
          ``[raw_folder, frame_idx]``, with ``frame_idx`` kept as the
          original string (kept for backward compatibility).

    Raises:
        IndexError: if ``train_rand.txt`` is empty.
        KeyError: if a non-blank mapping line's index is absent from
            ``train_rand.txt``.
        ValueError: if a non-blank mapping line does not have exactly three
            fields.
    """
    image_sets = Path(root) / "ImageSets"

    with open(image_sets / "train_mapping.txt", encoding="utf-8") as f:
        train_mapping = f.read().splitlines()

    with open(image_sets / "train_rand.txt", encoding="utf-8") as f:
        train_rand = f.read().splitlines()[0]

    # train_rand stores 1-based line numbers; convert to 0-based and ignore
    # empty tokens (e.g. from a trailing comma).
    raw_indices = [int(tok) - 1 for tok in train_rand.split(",") if tok.strip()]
    imgidx_to_rawidx = dict(enumerate(raw_indices))
    rawidx_to_imgidx = {rawidx: imgidx for imgidx, rawidx in imgidx_to_rawidx.items()}

    imgidx_to_rawframe = {}
    rawidx_to_rawframe = {}
    for rawidx, line in enumerate(train_mapping):
        fields = line.split()  # tolerant of repeated spaces / tabs
        if not fields:
            # Blank lines still consume a line number, so skipping them keeps
            # rawidx aligned with the 1-based indices in train_rand.txt.
            continue
        _, raw_folder, frame_idx = fields
        rawidx_to_rawframe[rawidx] = [raw_folder, frame_idx]
        imgidx_to_rawframe[rawidx_to_imgidx[rawidx]] = [raw_folder, int(frame_idx)]

    return imgidx_to_rawidx, rawidx_to_imgidx, imgidx_to_rawframe, rawidx_to_rawframe



if __name__ == '__main__':
    import sys

    # For every KITTI detection training image, copy a clip of raw frames
    # surrounding it into <KITTI_Det>/frames/<img_idx:06d>/. Files are named
    # "<stem>_<f_idx:02d>.png", with the target image at f_idx == num_past_frames.
    # Raw neighbours that do not exist on disk are skipped, but their f_idx
    # slot is still consumed, so positions stay fixed across clips.

    # Parent directory holding KITTI_Det and KITTI_Raw. It may be passed as
    # the first command-line argument.
    root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("PATH_TO_KITTI_PARENTS")
    root_det = root / "KITTI_Det"
    root_raw = root / "KITTI_Raw"
    num_frames = 40
    num_past_frames = num_frames // 2
    num_future_frames = num_frames // 2

    images = sorted((root_det / "training" / "image_2").glob("*.png"))

    _, _, img2rawframe, _ = read_mapping(root_det)
    save_folder = root_det / "frames"
    save_folder.mkdir(parents=True, exist_ok=True)

    # NOTE(review): this directory is created but nothing in this script
    # writes to it.
    pose_save_dir = root_det / "poses"
    pose_save_dir.mkdir(parents=True, exist_ok=True)

    # Offsets relative to the labelled frame: past frames, target (0), future.
    offsets = range(-num_past_frames, num_future_frames + 1)

    # Iterate over the images actually present instead of a hard-coded count.
    for img_idx, image_path in enumerate(tqdm(images)):
        save_img_folder = save_folder / f"{img_idx:06d}"
        if save_img_folder.exists():
            shutil.rmtree(save_img_folder)
        save_img_folder.mkdir(parents=True, exist_ok=True)

        frame_num = image_path.stem
        raw_folder, raw_frame = img2rawframe[img_idx]
        date = raw_folder.split("_drive_")[0]
        raw_dir = root_raw / date / raw_folder / "image_02" / "data"

        for f_idx, offset in enumerate(offsets):
            dest = save_img_folder / f"{frame_num}_{f_idx:02d}.png"
            if offset == 0:
                # The target frame is taken from the detection split itself.
                shutil.copyfile(image_path, dest)
                continue
            raw_idx = raw_frame + offset
            if raw_idx < 0:
                # Before the start of the sequence; no such raw frame exists.
                continue
            source_file = raw_dir / f"{raw_idx:010d}.png"
            if source_file.exists():
                shutil.copyfile(source_file, dest)