import sys, os
# Extra directory to add to the module search path. It is an empty string here;
# presumably it should point at the folder that provides `mugen_data`
# (imported just below) -- TODO confirm and set before running.
lib_path = ""
# Appended at import time so the project-local import that follows can resolve.
sys.path.append(lib_path)

from mugen_data import MUGENDataset
import os
import torch
import re
import json
from torchvision.io import write_video

# Splits on the conjunction "and" (optionally preceded by ", "), but only when
# it is a standalone word, and never when it merely joins a list such as
# "... and a snail", "... and gems" or "... and coins".
_SUBSENTENCE_SPLIT_RE = re.compile(r"(?:, and|\band)\b(?! a | gems| coins)")


def split_text(text):
    """Split a description into sub-sentences at the conjunction "and".

    The text is cut at every standalone ``and`` / ``, and`` unless the
    conjunction is immediately followed by ``" a "``, ``" gems"`` or
    ``" coins"`` (those cases join nouns rather than clauses). Word
    boundaries are required around ``and`` so that words merely containing
    the letters (e.g. "lands", "stands", "sand") are left intact.

    Args:
        text: The full description string.

    Returns:
        A list of whitespace-stripped sub-sentences, in their original order.
        A text without any splitting conjunction yields a one-element list.
        Segments that are empty after stripping (e.g. when the text starts or
        ends with the conjunction) are kept as empty strings.

    Example:
        >>> split_text("Mugen jumps, and lands on a ledge and collects coins and gems")
        ['Mugen jumps', 'lands on a ledge', 'collects coins and gems']
    """
    # The group is non-capturing, so re.split() does not return the
    # delimiters and no post-filtering of "and" tokens is needed.
    return [piece.strip() for piece in _SUBSENTENCE_SPLIT_RE.split(text)]


class SVMugenDataset(MUGENDataset):
    """`MUGENDataset` variant whose text is split into sub-sentences.

    Each item is a dict that always contains ``'game_name'`` and, depending
    on the flags on ``args``, ``'audio'``, ``'video'``, ``'video_smap'`` and
    ``'text'`` (a list of sub-sentences produced by ``split_text``).

    When ``save_video`` is true, every accessed item additionally has its
    annotations (and its game video, if ``args.get_game_frame`` is set)
    written into ``video_save_dir``.
    """

    # Number of consecutive frames grouped into one clip by ``collate_fn``.
    _FRAMES_PER_CLIP = 2

    def __init__(
            self,
            args,
            split='train',
            save_video = False,
            video_save_dir = None
    ):
        """Initialise the dataset.

        Args:
            args: Configuration object forwarded to ``MUGENDataset``; this
                class reads ``get_audio``, ``get_game_frame``,
                ``get_seg_map``, ``get_text_desc``, ``use_auto_annotation``
                and ``use_manual_annotation`` from it.
            split: Dataset split name forwarded to ``MUGENDataset``.
            save_video: If true, ``__getitem__`` writes files to
                ``video_save_dir`` as a side effect.
            video_save_dir: Output directory used when ``save_video`` is
                true. It must be a valid path in that case; it is joined
                with file names via ``os.path.join``.
        """
        super().__init__(args, split)
        self.save_video = save_video
        self.video_save_dir = video_save_dir

    def _select_annotation_idx(self, idx):
        """Choose which annotation of sample ``idx`` supplies the text.

        Index 0 is treated as the auto-generated annotation and the remaining
        indices as manual ones (as stated by the original inline comment).

        Returns:
            The integer index into ``self.data[idx]["annotations"]``.

        Raises:
            AssertionError: If neither annotation source is enabled, or if
                only manual annotations are requested but none is available.
        """
        annotations = self.data[idx]["annotations"]
        assert self.args.use_auto_annotation or self.args.use_manual_annotation
        use_manual = self.args.use_manual_annotation
        use_auto = self.args.use_auto_annotation

        if use_manual and not use_auto:
            assert len(annotations) > 1, "need at least one manual annotation if using only manual annotations"
            # exclude the auto-text, which is always index 0
            if self.train:
                chosen = torch.randint(low=1, high=len(annotations), size=(1,)).item()
            else:
                # Outside training the first manual annotation is always used.
                chosen = 1
            assert annotations[chosen]["type"] == "manual", "Should only be sampling manual annotations"
            return chosen

        if use_auto and not use_manual:
            return 0

        # Both sources enabled: sample uniformly over all annotations.
        return torch.randint(low=0, high=len(annotations), size=(1,)).item()

    def _save_annotations(self, idx, game_name):
        """Write all annotations of sample ``idx`` to ``<game_name>.json``.

        The file maps each annotation's ``type`` to its ``text``; if several
        annotations share a type, the last one wins.
        """
        dp_info = {}
        for annotation in self.data[idx]["annotations"]:
            dp_info[annotation['type']] = annotation['text']
        save_text_file = os.path.join(self.video_save_dir, f"{game_name}.json")
        # Context manager guarantees the handle is flushed and closed.
        with open(save_text_file, 'w') as json_file:
            json.dump(dp_info, json_file)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the requested modalities.

        Returns:
            A dict with ``'game_name'`` plus whichever of ``'audio'``,
            ``'video'``, ``'video_smap'`` and ``'text'`` are enabled on
            ``self.args``. ``'text'`` is a list of sub-sentences.

        Side effects:
            If ``self.save_video`` is true, writes ``<game_name>.json`` (all
            annotations) and, when game frames are requested,
            ``<game_name>.mp4`` into ``self.video_save_dir``.
        """
        self.load_json_file(idx)
        # File name of the sample's json file with its last five characters
        # (presumably the ".json" extension) stripped.
        game_name = self.data[idx]["video"]["json_file"].split('/')[-1][:-5]

        start_idx, end_idx = self.get_start_end_idx()
        alien_name = 'Mugen'

        result_dict = {'game_name': game_name}

        if self.args.get_audio:
            wav_file = os.path.join(self.dataset_metadata["data_folder"], self.data[idx]["video"]["video_file"])
            result_dict['audio'] = self.get_game_audio(wav_file)

        if self.args.get_game_frame:
            game_video = self.get_game_video(start_idx, end_idx, alien_name=alien_name)
            if self.save_video:
                save_video_file = os.path.join(self.video_save_dir, f"{game_name}.mp4")
                # fps is derived from the frame count assuming the clip spans
                # 3.2 seconds (constant taken from the original code).
                write_video(save_video_file, game_video, fps=int(game_video.shape[0] / 3.2))
            result_dict['video'] = game_video

        if self.args.get_seg_map:
            result_dict["video_smap"] = self.get_smap_video(start_idx, end_idx, alien_name=alien_name)

        if self.args.get_text_desc:
            # text description will be generated in the range of start and end frames
            # this means we can use full json and auto-text to train transformer too
            annotation_idx = self._select_annotation_idx(idx)
            text_desc = self.data[idx]["annotations"][annotation_idx]["text"]
            result_dict['text'] = split_text(text_desc)

        if self.save_video:
            self._save_annotations(idx, game_name)

        return result_dict

    def collate_fn(self, result_dict_ls):
        """Merge per-sample dicts from ``__getitem__`` into one batch dict.

        * ``'game_name'`` values are collected into a list.
        * Every other key is concatenated element-wise across samples (video
          frames of all samples end up in one flat sequence; sub-sentences of
          all samples end up in one flat list).
        * ``'text_idx'`` holds, per sample, the half-open ``(start, end)``
          range of that sample's sub-sentences inside the flat ``'text'``
          list. It stays empty when samples carry no ``'text'``.
        * ``'video'`` is stacked and regrouped into clips of
          ``_FRAMES_PER_CLIP`` consecutive frames, giving shape
          ``(num_clips, _FRAMES_PER_CLIP, *frame_shape)``.

        Args:
            result_dict_ls: Non-empty list of dicts returned by
                ``__getitem__``; the keys of the first dict are used for all.

        Raises:
            NotImplementedError: If the samples contain ``'audio'`` or
                ``'video_smap'``, which this collate function cannot batch.
        """
        result_keys = list(result_dict_ls[0].keys())
        # Fail before doing any work: these modalities are not supported.
        if 'audio' in result_keys or 'video_smap' in result_keys:
            raise NotImplementedError("collate_fn does not support 'audio' or 'video_smap'")

        combined_result_dict = {key: [] for key in result_keys}
        combined_result_dict['text_idx'] = []
        current_text_id = 0

        for result_dict in result_dict_ls:
            for key in result_keys:
                if key == "game_name":
                    combined_result_dict[key].append(result_dict[key])
                else:
                    # Extends by iteration: tensors contribute their entries
                    # along dim 0 (frames), lists contribute their items.
                    combined_result_dict[key] += result_dict[key]
                if key == 'text':
                    next_text_id = current_text_id + len(result_dict[key])
                    combined_result_dict['text_idx'].append((current_text_id, next_text_id))
                    current_text_id = next_text_id

        if 'video' in combined_result_dict:
            # Group the video
            frames = torch.stack(combined_result_dict['video'])
            # combined_result_dict['video'] = combined_result_dict['video'].unfold(0, 4, 2).reshape(-1, 4, 256, 256, 3) // learns slower
            # Frame shape comes from the data instead of a hard-coded
            # (256, 256, 3); the result is identical for 256x256x3 frames.
            combined_result_dict['video'] = frames.reshape(-1, self._FRAMES_PER_CLIP, *frames.shape[1:])

        return combined_result_dict
