import torch
import os
import json

from tqdm import tqdm
from transformers import PreTrainedTokenizer

from data.utils import DictWithTo


class EgoExo4DSegmentSummary(torch.utils.data.Dataset):
    """Dataset with one sample per annotated video for streaming narration.

    Each sample pairs a fixed "narrate the video" conversation with the slice
    of the per-video tensor that spans from the earliest to the latest
    narration timestamp found in the annotation file.
    """

    # Consumed by evaluation code outside this file.
    evaluation_kwargs = DictWithTo(evaluator='stream_generate')

    def __init__(self, frame_fps: int, system_prompt: str, tokenizer: PreTrainedTokenizer, local_debug: bool, vision_pretrained: str, embed_mark: str, **kwargs):
        """Read the annotation file and build the per-video index.

        Args:
            frame_fps: Factor used to convert narration timestamps into frame
                indices (presumably timestamps are in seconds -- confirm).
            system_prompt: Content of the system message of every sample.
            tokenizer: Tokenizer whose chat template renders the conversation.
            local_debug: When true, the metadata file is not read; placeholder
                metadata is fabricated per video and ``__getitem__`` returns
                random frames instead of loading tensors from disk.
            vision_pretrained: Vision model identifier; used (with ``/``
                replaced by ``--``) in the embedding directory name.
            embed_mark: Tag used in the embedding directory name.
            **kwargs: Accepted and ignored.
        """
        # Both branches currently hold the same (empty) root; the split is kept
        # so debug and non-debug roots can be configured independently.
        if local_debug:
            self.root = ""
        else:
            self.root = ""
        self.video_root = os.path.join(self.root, 'egocentric')
        self.anno_root = os.path.join(self.root, 'annotations')
        self.frame_fps = frame_fps
        self.local_debug = local_debug
        self.embed_dir = f"{self.video_root}_{embed_mark}_{vision_pretrained.replace('/', '--')}"
        self.metadata = self.get_metadata() if not local_debug else None

        self.system_prompt = system_prompt
        self.tokenizer = tokenizer

        anno_path = os.path.join(self.anno_root, 'refined_egocentric_atomic_descriptions_gpt_val.json')
        # Context manager so the file handle is closed deterministically.
        with open(anno_path) as anno_file:
            annos = json.load(anno_file)

        # Entries are (video_path, first_frame, end_frame_exclusive).
        self.annos = []
        for video_uid, _annotation_uid_narrations in tqdm(annos.items(), desc='Constructing dataset for real stream generation...'):
            if self.local_debug:
                # Fabricated entry; the '_tmp.pt' path is never opened in debug mode.
                self.metadata = {video_uid: {"duration": 10000000, "path": f"{video_uid}_tmp.pt"}}

            # Widen [start, end] over the first/last description of every annotation.
            start, end = float("inf"), 0
            for annotation in _annotation_uid_narrations:
                narrations = annotation['descriptions']
                if not narrations:
                    continue
                start = min(start, narrations[0]['timestamp'])
                end = max(end, narrations[-1]['timestamp'])

            # `end` is still 0 when no annotation carried any description
            # (or when every last timestamp is 0); such videos are skipped.
            if end == 0:
                continue

            video_path = self.metadata[video_uid]['path']
            # +1 makes the end index exclusive while still covering the last frame.
            self.annos.append((video_path, int(start * frame_fps), int(end * frame_fps) + 1))

    def get_metadata(self):
        """Load and return the JSON stored at ``f'{self.embed_dir}_metadata.json'``."""
        metadata_path = f'{self.embed_dir}_metadata.json'
        with open(metadata_path) as metadata_file:
            return json.load(metadata_file)

    def __getitem__(self, idx, add_generation_prompt=False, **kwargs):
        """Return the sample at ``idx``.

        Args:
            idx: Index into the per-video entries built in ``__init__``.
            add_generation_prompt: Forwarded to the tokenizer's
                ``apply_chat_template``.
            **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(text, frames, sample_name, idx)`` where ``text`` is the
            untokenized chat-template rendering of the conversation,
            ``frames`` is the ``[start:end]`` slice of the video tensor, and
            ``sample_name`` is ``'<video file stem>_<start>_<end>'``.
        """
        video_path, start, end = self.annos[idx]
        if self.local_debug:
            # Debug mode has no tensors on disk, so synthesize random frames.
            # Gated on the flag rather than on 'tmp' appearing in the path, so
            # a real path that happens to contain "tmp" is still loaded.
            frames = torch.randn(end, 10, 1024).bfloat16()
            frames = frames[start:end]
        else:
            frames = torch.load(video_path, weights_only=True)[start:end]
        conversation = [
            {"role": "system", "content": self.system_prompt},
            {'role': 'stream', 'num_frames': 1, 'learn': False},  # Prompts always come after the current frame.
            {"role": "user", "content": "Please concisely narrate the video in real time."}
        ]
        video_name = os.path.basename(video_path).split('.')[0]
        sample_name = f'{video_name}_{start}_{end}'

        text = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=add_generation_prompt)
        return text, frames, sample_name, idx

    def __len__(self):
        """Return the number of videos that contributed a sample."""
        return len(self.annos)


def build_egoexo4d_segment_summary_val(**kwargs):
    """Construct an ``EgoExo4DSegmentSummary`` from keyword arguments.

    Every keyword argument is forwarded unchanged to the dataset constructor.
    """
    dataset = EgoExo4DSegmentSummary(**kwargs)
    return dataset

