import argparse
import glob
import json
import pathlib
import pickle


def main():
    """Merge pickled generation batches into one ``generated_sentences.json``.

    Reads every ``*.pkl`` file directly inside
    ``BASE_DIR / VIDEOS_DIR / f"prompt_{PROMPT_NUMBER}" / MODEL_NAME``.
    Each file must unpickle to an iterable of mappings that have
    ``"generated_text"`` and ``"video_path"`` keys. The combined
    ``{video_path: generated_text}`` mapping is written as JSON to
    ``OUTPUT_DIR / "generated_sentences.json"``.

    If the same ``video_path`` occurs more than once, the entry read last wins.
    Batch files are processed in sorted filename order, so the result is
    reproducible.

    Relies on the module-level globals BASE_DIR, VIDEOS_DIR, PROMPT_NUMBER,
    MODEL_NAME and OUTPUT_DIR, which the ``__main__`` block assigns before
    calling this function.

    Raises:
        KeyError: if a batch entry lacks ``"generated_text"`` or ``"video_path"``.
    """
    batch_dir = BASE_DIR.joinpath(
        VIDEOS_DIR,
        f"prompt_{PROMPT_NUMBER}",
        MODEL_NAME,
    )
    # Escape the directory part so folder names containing "[", "*" or "?"
    # are matched literally; only "*.pkl" is meant to be a wildcard.
    pattern = str(pathlib.Path(glob.escape(str(batch_dir))) / "*.pkl")
    # glob order is filesystem-dependent; sort so duplicate video paths are
    # resolved the same way on every run.
    batch_files = sorted(glob.glob(pattern))

    generated_sentences = {}
    for batch_file in batch_files:
        # NOTE: pickle.load can execute arbitrary code; only run this on batch
        # files produced by a trusted run of the generation pipeline.
        with open(batch_file, "rb") as f:
            batches = pickle.load(f)

        for batch in batches:
            generated_sentences[batch["video_path"]] = batch["generated_text"]

    print(f"Found {len(generated_sentences)} sentences")
    output_file = OUTPUT_DIR.joinpath("generated_sentences.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(generated_sentences, f, indent="  ")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="Sentences extracter")

    _ = parser.add_argument(
        "-d",
        "--base-dir",
        # Required: BASE_DIR.joinpath() is called below and in main(), so an
        # omitted flag (None) used to crash with an AttributeError instead of
        # producing a proper usage error.
        required=True,
        type=pathlib.Path,
        help="The path to the directory where all the models outputs will be stored and loaded from",
    )
    _ = parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        required=True,
        help="The model id whose hidden state representations are to be used",
    )
    _ = parser.add_argument(
        "-p",
        "--prompt-number",
        type=int,
        required=True,
        help="The prompt number to use for aligning",
    )
    _ = parser.add_argument(
        "-v",
        "--videos-folder",
        type=str,
        required=True,
        help="The movie folder",
    )

    args = parser.parse_args()

    # Module-level names read as globals by main().
    BASE_DIR: pathlib.Path = args.base_dir
    MODEL_ID: str = args.model_id
    PROMPT_NUMBER: int = args.prompt_number
    # argparse yields a plain str here (type=str); it is only used as a path
    # component via joinpath().
    VIDEOS_DIR: str = args.videos_folder

    # Make the model id safe as a single directory name: "/" (e.g. "org/model")
    # would otherwise create nested directories, and spaces are replaced too.
    MODEL_NAME = MODEL_ID.replace("/", "_").replace(" ", "_")

    OUTPUT_DIR = BASE_DIR.joinpath(
        "generated_sentences",
        VIDEOS_DIR,
        f"prompt_{PROMPT_NUMBER}",
        MODEL_NAME,
    )
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    main()
