import os
import cv2
import io
import numpy as np
import torch
import decord
from PIL import Image
from decord import VideoReader, cpu
import random

# petrel_client is an optional dependency: ``has_client`` records whether the
# import succeeded so the rest of the module can decide whether a ``Client``
# can be constructed.
try:
    from petrel_client.client import Client
    has_client = True
except ImportError:
    has_client = False


class VideoMAE_multi(torch.utils.data.Dataset):
    """Video dataset that decodes clips with decord from a mixed-source list file.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset. Only stored and used in
        the "found 0 clips" error message.
    setting : str, required.
        A text file describing the dataset, one video sample per line.
        Each line holds six items separated by ``split``:
        (1) source name, (2) video path, (3) total time, (4) start time,
        (5) end time and (6) label. When any of the three time fields is -1
        the whole video is used; otherwise frames are sampled from the
        ``[start time, end time]`` window only.
    prefix : str, default ''.
        Prefix joined in front of every video path.
    split : str, default ' '.
        The separator between the items of a line of ``setting``.
    train : bool, default True.
        Stored on the instance; not read by this class.
    test_mode : bool, default False.
        Stored on the instance; not read by this class.
    name_pattern : str, default 'img_%05d.jpg'.
        Stored on the instance; not read by this class.
    is_color : bool, default True.
        Stored on the instance; not read by this class.
    modality : str, default 'rgb'.
        Stored on the instance; not read by this class.
    num_segments : int, default 1.
        Number of segments to evenly divide the video into clips.
        Any value other than 1 switches to sparse sampling: ``new_length``
        becomes ``num_segments`` and one frame is taken per segment.
    num_crop : int, default 1.
        Stored on the instance; not read by this class.
    new_length : int, default 1.
        Number of frames per clip.
    new_step : int, default 1.
        Temporal sampling rate. new_step=2 means every other frame.
    transform : function, default None.
        Called with ``(images, None)``; must return ``(data, mask)``.
    transform_ssv2 : function, default None.
        Same contract as ``transform``; used for samples whose source is
        "ssv2".
    temporal_jitter : bool, default False.
        Whether to temporally jitter each frame by less than ``new_step``.
    video_loader : bool, default False.
        Must be True for ``__getitem__`` to work; no other loader is
        implemented.
    use_decord : bool, default True.
        Must be True; only decord video reading is supported.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any dataset.
    num_sample : int, default 1.
        Number of independently transformed copies returned per item. When
        greater than 1, ``__getitem__`` returns two lists instead of one pair.
    """

    def __init__(self,
                 root,
                 setting,
                 prefix='',
                 split=' ',
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 transform=None,
                 transform_ssv2=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=True,
                 lazy_init=False,
                 num_sample=1,
                 ):

        super().__init__()
        self.root = root
        self.setting = setting
        self.prefix = prefix
        self.split = split
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length
        self.new_step = new_step
        # Temporal extent (in frames) covered by one clip.
        self.skip_length = self.new_length * self.new_step
        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.use_decord = use_decord
        self.transform = transform
        self.transform_ssv2 = transform_ssv2
        self.lazy_init = lazy_init
        self.num_sample = num_sample

        assert use_decord == True, "Only support to read video now!"

        # sparse sampling, num_segments != 1
        if self.num_segments != 1:
            print('Use sparse sampling, change frame and stride')
            self.new_length = self.num_segments
            self.skip_length = 1

        # The client is only needed for "s3://" paths and only exists when
        # petrel_client could be imported.
        self.client = None
        if has_client:
            self.client = Client()

        if not self.lazy_init:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                   "Check your data directory (opt.data-dir).")

    def __getitem__(self, index):
        """Load, sample and transform the clip at ``index``.

        A failing video is retried; after 11 consecutive failures a random
        other index is tried instead. The loop only ends on success.

        Returns
        -------
        ``(process_data, mask)`` when ``num_sample <= 1``, where
        ``process_data`` has been reshaped from (T*C, H, W) to (C, T, H, W)
        with C fixed to 3. Otherwise ``(process_data_list, mask_list)`` with
        ``num_sample`` entries each.

        Raises
        ------
        NotImplementedError
            If the dataset was not configured with ``use_decord`` and
            ``video_loader`` both enabled; no other loading path exists.
        """
        # Fail fast: without a video reader nothing below can succeed, and
        # retrying would never terminate.
        if not (self.use_decord and self.video_loader):
            raise NotImplementedError(
                "Only decord video loading is supported; "
                "set use_decord=True and video_loader=True.")

        failures = 0
        while True:
            images = None
            # Bound before the try so the failure message can always print it.
            video_name = None
            try:
                source, path, total_time, start_time, end_time, _ = self.clips[index]
                video_name = os.path.join(self.prefix, path)
                images = self._load_frames(video_name, total_time, start_time, end_time)
            except Exception as e:
                print("Failed to load video from {} with error {}".format(
                    video_name, e))

            if images is not None:
                break

            failures += 1
            if failures >= 11:
                failures = 0
                index = random.randint(0, len(self.clips) - 1)
                print(f'retry with new index {index}')
            else:
                print(f'retry with video_name {video_name}')

        if self.num_sample > 1:
            # Each call re-runs the (presumably random) transform on the same frames.
            samples = [self._transform_clip(source, images) for _ in range(self.num_sample)]
            process_data_list = [data for data, _ in samples]
            mask_list = [mask for _, mask in samples]
            return process_data_list, mask_list
        return self._transform_clip(source, images)

    def __len__(self):
        """Return the number of entries read from the setting file."""
        return len(self.clips)

    def _load_frames(self, video_name, total_time, start_time, end_time):
        """Open ``video_name`` and return its sampled frames as RGB PIL images."""
        if "s3://" in video_name:
            video_bytes = self.client.get(video_name)
            decord_vr = VideoReader(io.BytesIO(video_bytes),
                                    num_threads=1,
                                    ctx=cpu(0))
        else:
            decord_vr = VideoReader(video_name, num_threads=1, ctx=cpu(0))

        duration = len(decord_vr)
        start_index = 0
        if total_time != -1 and start_time != -1 and end_time != -1:
            # Convert the annotated time window into a frame window.
            fps = duration / total_time
            duration = int(fps * (end_time - start_time))
            start_index = int(fps * start_time)

        segment_indices, skip_offsets = self._sample_train_indices(duration, start_index)
        return self._video_TSN_decord_batch_loader(
            video_name, decord_vr, duration, segment_indices, skip_offsets, start_index)

    def _transform_clip(self, source, images):
        """Apply the source-specific transform and reshape the result to (C, T, H, W)."""
        transform = self.transform_ssv2 if source == "ssv2" else self.transform
        process_data, mask = transform((images, None))  # T*C,H,W
        # T*C,H,W -> T,C,H,W -> C,T,H,W
        process_data = process_data.view(
            (self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
        return process_data, mask

    def _make_dataset(self, directory, setting):
        """Parse the setting file into a list of clip tuples.

        Each returned item is
        ``(source, path, total_time, start_time, end_time, target)``.
        Blank lines are skipped. ``directory`` is accepted for interface
        compatibility and is not used.

        Raises
        ------
        RuntimeError
            If the setting file does not exist or a line has fewer than six
            fields.
        """
        if not os.path.exists(setting):
            raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))
        clips = []

        print(f'Load dataset using decord: {self.use_decord}')
        with open(setting) as split_f:
            for line in split_f:
                if not line.strip():
                    # Tolerate empty lines (e.g. a trailing newline at end of file).
                    continue
                line_info = line.rstrip('\r\n').split(self.split)
                # All six fields are read below, so require all of them here.
                if len(line_info) < 6:
                    raise RuntimeError('Video input format is not correct, missing one or more element. %s' % line)
                if self.use_decord:
                    # line format: source, path, total_time, start_time, end_time, target
                    item = (
                        line_info[0],
                        line_info[1],
                        float(line_info[2]),
                        float(line_info[3]),
                        float(line_info[4]),
                        int(line_info[5]),
                    )
                else:
                    raise NotImplementedError

                clips.append(item)
        return clips

    def _sample_train_indices(self, num_frames, start_index=0):
        """Randomly pick one start offset per segment plus per-frame jitter.

        Parameters
        ----------
        num_frames : int
            Length of the frame window to sample from.
        start_index : int, default 0.
            0-based index of the first frame of that window in the video.

        Returns
        -------
        (offsets, skip_offsets) : the 0-based absolute start frame of each of
        the ``num_segments`` segments, and one jitter value per frame of a
        clip (all zero unless ``temporal_jitter`` is set).
        """
        average_duration = (num_frames - self.skip_length + 1) // self.num_segments
        if average_duration > 0:
            # One random start inside each equally sized segment.
            offsets = np.arange(self.num_segments) * average_duration
            offsets = offsets + np.random.randint(average_duration,
                                                  size=self.num_segments)
        elif num_frames > max(self.num_segments, self.skip_length):
            offsets = np.sort(np.random.randint(
                num_frames - self.skip_length + 1,
                size=self.num_segments))
        else:
            # Window too short: every segment starts at its first frame.
            offsets = np.zeros((self.num_segments,), dtype=int)

        # Must match the number of frames the frame-id loop emits per segment;
        # ``skip_length // new_step`` would be 0 for sparse sampling with
        # new_step > 1 and make that loop index out of range.
        frames_per_clip = len(range(0, self.skip_length, self.new_step))
        if self.temporal_jitter:
            skip_offsets = np.random.randint(self.new_step, size=frames_per_clip)
        else:
            skip_offsets = np.zeros(frames_per_clip, dtype=int)
        return offsets + start_index, skip_offsets

    def _get_frame_id_list(self, duration, indices, skip_offsets, start_index=0):
        """Expand segment start offsets into the 0-based frame ids to decode.

        Parameters
        ----------
        duration : int
            Length of the sampled frame window.
        indices : sequence of numbers
            0-based absolute start frame of each segment.
        skip_offsets : sequence of int
            Jitter added to each frame of a clip.
        start_index : int, default 0.
            0-based index of the first frame of the window; frame ids stay
            below ``start_index + duration``.

        Returns
        -------
        list of int with ``len(indices) * len(skip_offsets)`` frame ids.
        """
        # Exclusive upper bound of the window, in absolute frame ids. The
        # bound has to include start_index because ``indices`` do.
        end_index = start_index + duration
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
                jittered = offset + int(skip_offsets[i])
                # Offsets are already 0-based, so no "- 1" here: that would
                # produce id -1 for offset 0, which decord is expected to
                # resolve from the end of the video instead of its start.
                frame_id = jittered if jittered < end_index else offset
                frame_id_list.append(frame_id)
                # Stop advancing at the end of the window and repeat the
                # last reachable frame instead.
                if offset + self.new_step < end_index:
                    offset += self.new_step
        return frame_id_list

    def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets,
                                       start_index=0):
        """Decode the sampled frames and return them as a list of RGB PIL images.

        ``start_index`` is the 0-based first frame of the sampled window and
        defaults to 0 (whole video).

        Raises
        ------
        RuntimeError
            If decoding or image conversion fails; the original error is
            attached as the cause.
        """
        frame_id_list = self._get_frame_id_list(duration, indices, skip_offsets, start_index)
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(frame).convert('RGB') for frame in video_data]
        except Exception as e:
            raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration)) from e
        return sampled_list
