"""Split videos into train, val and test and save captions to pickle.

The 'filter' argument of find_videos can be used to create a subset of the dataset.
"""
import glob
import json
import os
import pickle
import shutil

import numpy as np


def find_videos(video_dir, filter=None):
    """Return the file names of all ``.avi`` videos directly inside ``video_dir``.

    Args:
        video_dir: Directory to search (non-recursively).
        filter: Optional predicate called with each file name (basename,
            including the ``.avi`` extension). Only names for which it returns
            a truthy value are kept. (Name shadows the builtin but is kept for
            backward compatibility with keyword callers.)

    Returns:
        Sorted list of video file names (basenames, not full paths).
    """
    # Escape the directory so characters like "[" in the path are taken
    # literally. Sort because glob's order is filesystem-dependent, and a
    # stable order is required for seeded shuffles (split_videos) to be
    # reproducible.
    pattern = os.path.join(glob.escape(video_dir), "*.avi")
    videos = sorted(os.path.basename(path) for path in glob.glob(pattern))
    n_total = len(videos)
    if filter:
        videos = [v for v in videos if filter(v)]
        print("Filtered out", n_total - len(videos), "videos")
    return videos


def find_captions(video_dir, pattern):
    """Load and merge caption dictionaries from files matching ``pattern``.

    Args:
        video_dir: Directory in which to look for caption files.
        pattern: Glob pattern relative to ``video_dir`` (e.g. ``"*.json"`` or
            ``"raw-captions.pkl"``). A ``.pkl`` extension selects pickle
            loading; anything else is parsed as JSON.

    Returns:
        Dict mapping each video id (the basename of each key found in the
        files) to its caption. If several files define the same id, the
        collision is printed and the file that sorts last wins.
    """
    # Sort so duplicate resolution does not depend on filesystem order.
    files = sorted(glob.glob(os.path.join(glob.escape(video_dir), pattern)))
    is_pkl = os.path.splitext(pattern)[1] == ".pkl"
    all_captions = {}
    for file in files:
        # Binary mode for both formats: pickle requires it, and json.load on
        # bytes auto-detects UTF-8/16/32 instead of using the locale encoding.
        with open(file, "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            captions = pickle.load(f) if is_pkl else json.load(f)
        for v_id, caption in captions.items():
            v_id = os.path.basename(v_id)
            if v_id in all_captions:
                print("Duplicate video id:", v_id)
                print("Old caption:", all_captions[v_id])
                print("New caption:", caption)
            all_captions[v_id] = caption
    return all_captions


def format_captions(captions):
    """Match the format in previous raw-captions.pkl files.

    Converts ``{video_id: caption_string}`` so that each caption is stripped
    of surrounding whitespace, loses one trailing period, has its first
    character capitalized and is tokenized on whitespace. A trailing ``.avi``
    is removed from each video id.

    Returns:
        Dict mapping video id (without ``.avi``) to a one-element list holding
        the caption's token list, e.g. ``{"v1": [["Pick", "the", "cube"]]}``.
        Empty captions map to ``[[]]`` instead of raising ``IndexError``.
    """
    formatted_captions = {}
    for v_id, caption in captions.items():
        # Strip first so a trailing space does not hide the final period.
        caption = caption.strip()
        if caption.endswith("."):
            caption = caption[:-1].rstrip()
        # Slice instead of index so empty captions don't raise IndexError.
        caption = caption[:1].capitalize() + caption[1:]
        if v_id.endswith(".avi"):
            v_id = v_id[:-4]
        # split() rather than split(" ") so repeated spaces give no "" tokens.
        formatted_captions[v_id] = [caption.split()]
    return formatted_captions


def split_videos(videos, fractions, seed=None):
    """Randomly partition ``videos`` into train, val and test lists.

    Args:
        videos: Sequence of video ids. It is not modified; a shuffled copy is
            split instead.
        fractions: Three fractions ``(train, val, test)`` summing to 1 (up to
            floating-point tolerance). Train gets ``int(n * fractions[0])``
            items, val gets ``int(n * fractions[1])``, and test gets the rest.
        seed: Seed for the shuffle. ``None`` draws fresh OS entropy. A given
            seed yields the same permutation as ``np.random.seed(seed)``
            followed by ``np.random.shuffle``, without touching global state.

    Returns:
        Tuple ``(train_videos, val_videos, test_videos)`` of lists.

    Raises:
        ValueError: If ``fractions`` does not have three entries summing to 1.
    """
    if len(fractions) != 3:
        raise ValueError(f"Expected 3 fractions, got {len(fractions)}")
    # Exact float equality rejects valid inputs like [0.7, 0.2, 0.1]
    # (which sums to 0.9999999999999999), so compare with a tolerance.
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError(f"Fractions must sum to 1, got {sum(fractions)}")

    rng = np.random.RandomState(seed)
    shuffled = list(videos)
    rng.shuffle(shuffled)

    num_videos = len(shuffled)
    n_train = int(num_videos * fractions[0])
    n_val = int(num_videos * fractions[1])
    train_videos = shuffled[:n_train]
    val_videos = shuffled[n_train : n_train + n_val]
    test_videos = shuffled[n_train + n_val :]
    return train_videos, val_videos, test_videos


def load_existing_splits(split_dir):
    """Read ``train_list.txt``, ``val_list.txt`` and ``test_list.txt`` from ``split_dir``.

    Each file lists one video id per line. Surrounding whitespace (including
    the carriage return of CRLF line endings) is stripped, and blank lines,
    such as the one produced by a trailing newline, are skipped.

    Returns:
        Dict with keys ``"train"``, ``"val"`` and ``"test"`` mapping to lists of ids.

    Raises:
        FileNotFoundError: If any of the three list files is missing.
    """
    existing_splits = {}
    for split in ("train", "val", "test"):
        path = os.path.join(split_dir, f"{split}_list.txt")
        with open(path, "r", encoding="utf-8") as f:
            # Splitting on "\n" would yield a spurious "" id for a trailing
            # newline or an empty file.
            existing_splits[split] = [line.strip() for line in f if line.strip()]
    return existing_splits


if __name__ == "__main__":
    top_dir = f"{os.environ['WORK']}/nlr/vlm_dataset"
    in_dir = os.path.join(top_dir, "pick_videos")
    out_dir = os.path.join(top_dir, "pick_videos_small")
    CREATE_SMALL_DATASET = True
    # 'succeeded' corresponds to moving tasks rather than picking.
    videos = find_videos(in_dir, filter=lambda x: not "succeeded" in x)
    videos = [v[:-4] if v[-4:] == ".avi" else v for v in videos]

    captions_file = "raw-captions.pkl"  # *.json
    preformatted = os.path.splitext(captions_file)[1] == ".pkl"
    # captions = find_captions(in_dir, "*.json")
    captions = find_captions(in_dir, "raw-captions.pkl")
    if not preformatted:
        captions = format_captions(captions)
    # captions = {
    #     k: v for k, v in captions.items() if ["Perform", "failed", "grasp"] not in v
    # }

    split_dir = os.path.join(top_dir, "pick_videos")
    if split_dir:
        splits = load_existing_splits(split_dir)
    else:
        splits = split_videos(videos, [0.8, 0.1, 0.1], seed=0)

    videos = [v for v in videos if v in captions]
    for k in splits:
        splits[k] = [v for v in splits[k] if v in captions]
    if CREATE_SMALL_DATASET:
        print("Train size before:", len(splits["train"]))
        splits["train"] = splits["train"][0:-1:20]
        print("Train size after:", len(splits["train"]))
    for k, v in splits.items():
        print(f"{k}: {len(v) / len(videos)}")

    print(list(captions)[:5])
    print(list(videos)[:5])

    for v_id in videos:
        if v_id not in captions:
            print("Missing caption for video:", v_id)
    captions = {k: v for k, v in captions.items() if k in videos}

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, "raw-captions.pkl"), "wb") as f:
        raw_captions = pickle.dump(captions, f)

    for k, v in splits.items():
        with open(os.path.join(out_dir, f"{k}_list.txt"), "w") as f:
            f.write("\n".join(v))

    # This is slow: uncomment if needed.
    # for v_id in videos:
    #     shutil.copyfile(
    #         os.path.join(in_dir, f"{v_id}.avi"), os.path.join(out_dir, f"{v_id}.avi")
    #     )
