import json
from pathlib import Path
from moviepy.editor import ImageSequenceClip
from nuscenes.nuscenes import NuScenes, NuScenesExplorer
from nuscenes.utils.splits import create_splits_scenes
from visual_tokens import viz_scenes
import os

def make_split(nusc: NuScenes):
    split_dict = {"train": {}, "val": {}}
    train_scenes = create_splits_scenes()['train']
    val_scenes = create_splits_scenes()['val']
    for scene in nusc.scene:
        if scene['name'] in train_scenes:
            split_dict['train'][scene['name']] = []
            sample_token = scene['first_sample_token']
            last_sample_token = scene['last_sample_token']
            split_dict['train'][scene['name']].append(sample_token)
            while sample_token != "":
                sample = nusc.get('sample', sample_token)
                sample_token = sample['next']
                split_dict['train'][scene['name']].append(sample_token)
                if sample_token == last_sample_token:
                    break
        elif scene['name'] in val_scenes:
            split_dict['val'][scene['name']] = []
            sample_token = scene['first_sample_token']
            last_sample_token = scene['last_sample_token']
            split_dict['val'][scene['name']].append(sample_token)
            while sample_token != "":
                sample = nusc.get('sample', sample_token)
                sample_token = sample['next']
                split_dict['val'][scene['name']].append(sample_token)
                if sample_token == last_sample_token:
                    break
    json.dump(split_dict, open("data/viz/full_split.json", "w"))
    return split_dict


def make_video(split, post_fix="_all"):
    for scene in split['val'].keys():
        if scene not in viz_scenes:
            continue
        print(scene)
        image_files = []  # Add your image paths
        sample_tokens = split['val'][scene]
        for sample_token in sample_tokens:
            path = Path("experiments/visualization") / Path(sample_token + post_fix + '.jpg')
            if path.exists():
                image_files.append(str(path))
            # else:
                # return
        clip = ImageSequenceClip(image_files, fps=4)  # fps = frames per second
        # Write the video file
        video_path = Path("experiments/visualization/videos") / Path(scene + '.mp4')
        clip.write_videofile(str(video_path))
    return

def make_video_gpt(split, post_fix="_all"):

    # import re
    def custom_sort(path):
        # Split the path into parts using '/'
        parts = path.split('/')
        # Extract the filename
        filename = parts[-1]
        # Extract the numeric part before the file extension
        num = int(filename.split('.')[0])
        return num

    all_images = []
    for scene in split['val'].keys():
        if scene not in [
            "scene-0332", "scene-0925",
            "scene-1072",
        ]:
            continue
        print(scene)
        image_files = []  # Add your image paths
        folder_path = video_path = Path(f"experiments/visualization/seq_images/{scene}")
        for path in folder_path.iterdir():
            image_files.append(str(path))
            
        image_files = sorted(image_files, key=custom_sort)
        print(image_files)
        all_images.extend(image_files)
    clip = ImageSequenceClip(all_images, fps=4)  # fps = frames per second
    # Write the video file
    video_path = Path("experiments/visualization/videos") / Path("demo_gptdriver_small" + '.mp4')
    clip.write_videofile(str(video_path))
    return

def make_seq_images(split, post_fix="_all"):
    for scene in split['val'].keys():
        if scene not in viz_scenes:
            continue
        print(scene)
        image_files = []  # Add your image paths
        sample_tokens = split['val'][scene]
        for sample_token in sample_tokens:
            path = Path("experiments/visualization") / Path(sample_token + post_fix + '.jpg')
            if path.exists():
                image_files.append(str(path))
        os.system("mkdir -p experiments/visualization/seq_images/{}".format(scene))
        for i, image_file in enumerate(image_files):
            os.system("cp {} experiments/visualization/seq_images/{}/{}.jpg".format(image_file, scene, i))
    return

if __name__ == "__main__":
    # nusc = NuScenes(version="v1.0-trainval", dataroot="~/Datasets/nuScenes", verbose=True)
    # split = make_split(nusc)
    split = json.load(open('data/viz/full_split.json', 'r'))
    # make_seq_images(split)
    make_video_gpt(split)
