"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import re

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

from lavis.datasets.datasets.video_vqa_datasets import VideoQADataset


class COINCLSDataset(VideoQADataset):
    """COIN video classification dataset with randomly sampled frames.

    Annotations are read from one or more JSON files whose top-level keys are
    video ids. A video is kept when its ``"subset"`` equals ``split``, it is
    not in the exclusion list, and it carries a ``"num_frames_actual"`` entry.
    Frames are read from ``<vis_root>/<video_id>/frame{index:06d}.jpg``
    (1-based file names).

    NOTE(review): ``VideoQADataset.__init__`` is not invoked; every attribute
    used by this class is assigned directly below. Confirm the parent does not
    require additional state.
    """

    # Video ids that are always dropped, regardless of split.
    _EXCLUDED_VIDEO_IDS = frozenset({"AfiVmAjfTNs", "cjwtcDKTQM8"})

    def __init__(
        self,
        vis_processor,
        text_processor,
        vis_root,
        ann_paths,
        num_frames,
        prompt="",
        split="train",
    ):
        """Load and filter the annotations.

        Args:
            vis_processor: callable applied to the stacked frame tensor.
            text_processor: callable applied to labels and the question text.
            vis_root: directory holding one sub-directory of frames per video.
            ann_paths: iterable of JSON annotation file paths; later files
                override earlier ones on duplicate video ids.
            num_frames: number of frames sampled per video.
            prompt: stored on the instance; not used by this class itself.
            split: value matched against each record's ``"subset"`` field.

        Raises:
            AssertionError: if ``text_processor`` alters a label.
        """
        self.vis_root = vis_root

        self.gt_dict = {}
        for ann_path in ann_paths:
            # Context manager so the annotation file is closed promptly.
            with open(ann_path) as ann_file:
                self.gt_dict.update(json.load(ann_file))

        self.fps = 10
        self.annotation = {}
        for video_id, record in self.gt_dict.items():
            if video_id in self._EXCLUDED_VIDEO_IDS:
                continue
            if record["subset"] != split:
                continue
            if "num_frames_actual" not in record:
                continue

            label = record["target"]
            label_after_process = text_processor(label)
            # Labels must already be in processed form so that they can be
            # compared verbatim against generated text.
            assert label == label_after_process, "{} not equal to {}".format(
                label, label_after_process
            )
            self.annotation[video_id] = {
                "video_id": video_id,
                "frame_length": record["num_frames_actual"],
                "label": label_after_process,
            }

        # Sorted ids give a stable index -> video mapping.
        self.video_id_list = sorted(self.annotation)

        self.num_frames = num_frames
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.prompt = prompt

    def _load_frames(self, video_id, frame_indices):
        """Read the given 0-based frames of a video and stack them along dim 1.

        Each frame is converted to RGB and to a float32 tensor. Per the
        torchvision docs ``pil_to_tensor`` yields (C, H, W), so the result is
        expected to be (C, T, H, W).
        """
        frame_list = []
        for frame_index in frame_indices:
            frame_path = os.path.join(
                self.vis_root, video_id, "frame{:06d}.jpg".format(frame_index + 1)
            )
            # Close the underlying file deterministically once decoded.
            with Image.open(frame_path) as image:
                frame = pil_to_tensor(image.convert("RGB")).to(torch.float32)
            frame_list.append(frame)
        return torch.stack(frame_list, dim=1)

    def __getitem__(self, index):
        """Return one sample with ``num_frames`` randomly jittered frames.

        Returns:
            dict with keys ``"image"`` (processed video), ``"text_input"``
            (processed question), ``"text_output"`` (processed label) and
            ``"image_id"`` (video id).
        """
        video_id = self.video_id_list[index]
        ann = self.annotation[video_id]

        # Split [0, frame_length] into num_frames segments and draw one random
        # frame index from each segment.
        boundaries = np.linspace(
            0, ann["frame_length"], self.num_frames + 1, dtype=int
        )
        selected_frame_index = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            if start == end:
                # Empty segment: np.random.randint needs low < high.
                selected_frame_index.append(start)
            else:
                selected_frame_index.append(np.random.randint(start, end))

        video = self.vis_processor(self._load_frames(video_id, selected_frame_index))

        text_input = self.text_processor("what is the activity in the video?")
        caption = self.text_processor(ann["label"])
        return {
            "image": video,
            "text_input": text_input,
            "text_output": caption,
            "image_id": video_id,
        }

    def __len__(self):
        """Return the number of videos kept after filtering."""
        return len(self.video_id_list)


class COINCLSEvalDataset(COINCLSDataset):
    """Evaluation variant of the COIN classification dataset.

    Differs from the training dataset in two ways visible here: frames are
    picked deterministically (evenly spaced over the video) and each sample
    additionally carries a ``"prompt"`` entry equal to ``"text_input"``.
    """

    def __init__(
        self,
        vis_processor,
        text_processor,
        vis_root,
        ann_paths,
        num_frames,
        prompt,
        split="val",
    ):
        """Initialise the dataset for the requested evaluation split.

        Args:
            vis_processor: callable applied to the stacked frame tensor.
            text_processor: callable applied to labels and the question text.
            vis_root: directory holding one sub-directory of frames per video.
            ann_paths: iterable of JSON annotation file paths.
            num_frames: number of frames taken per video.
            prompt: forwarded to the parent initialiser.
            split: value matched against each record's ``"subset"`` field;
                defaults to ``"val"``.
        """
        # Forward the caller's split rather than pinning it to "val", so the
        # argument is no longer silently ignored.
        super().__init__(
            vis_processor,
            text_processor,
            vis_root,
            ann_paths,
            num_frames,
            prompt,
            split=split,
        )

    def __getitem__(self, index):
        """Return one sample with ``num_frames`` evenly spaced frames.

        Returns:
            dict with keys ``"image"``, ``"text_input"``, ``"prompt"`` (same
            value as ``"text_input"``), ``"text_output"`` and ``"image_id"``.
        """
        video_id = self.video_id_list[index]
        ann = self.annotation[video_id]

        # Deterministic sampling: num_frames indices evenly spaced over
        # [0, frame_length - 1], rounded to the nearest integer.
        selected_frame_index = (
            np.rint(np.linspace(0, ann["frame_length"] - 1, self.num_frames))
            .astype(int)
            .tolist()
        )

        frame_list = []
        for frame_index in selected_frame_index:
            # Frame files are numbered from 1, hence the +1.
            frame_path = os.path.join(
                self.vis_root, video_id, "frame{:06d}.jpg".format(frame_index + 1)
            )
            # Close the underlying file deterministically once decoded.
            with Image.open(frame_path) as image:
                frame = pil_to_tensor(image.convert("RGB")).to(torch.float32)
            frame_list.append(frame)
        video = self.vis_processor(torch.stack(frame_list, dim=1))

        text_input = self.text_processor("what is the activity in the video?")
        caption = self.text_processor(ann["label"])
        return {
            "image": video,
            "text_input": text_input,
            "prompt": text_input,
            "text_output": caption,
            "image_id": video_id,
        }
