


import os.path as osp
from bisect import bisect_right

import torch
import torch.utils.data as data
import numpy as np

from .cls_dataset import DATA_LAYERS as registry        


@registry.register('LSTRTHUMOS')
@registry.register('LSTRTVSeries')
class LSTRDataLayer(data.Dataset):
    """Dataset yielding work-memory windows, optionally preceded by long memory.

    Every item is one window of ``work_memory_length`` consecutive frames of
    a session. When ``long_memory_num_samples`` is positive, frames sampled
    from the ``long_memory_length`` frames before the window are prepended,
    and a padding mask for those long-memory positions is returned as well.

    Per-session arrays are read from
    ``<data_root + 'train'>/<class>/<session>_{target,rgb,flow}.npy``.
    """

    def __init__(self, video_data, data_root,
                 visual_feature, motion_feature, target_perframe, long_memory_length,
                 long_memory_sample_rate, long_memory_num_samples, work_memory_length,
                 work_memory_sample_rate, work_memory_num_samples, memory, phase='train'):
        """Store the configuration and build the initial list of windows.

        ``video_data`` and ``memory`` are both mappings from a class name to
        an iterable of session names. ``phase`` only decides whether the
        dataset runs in training mode; the data directory is always the
        ``'train'`` one.
        """
        # Sources of the sessions.
        self.video_data = video_data
        self.memory = memory

        # The directory suffix is fixed and does not follow `phase`.
        self.data_root = data_root + 'train'
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe

        # Long-memory settings.
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples

        # Work-memory settings.
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples

        self.training = phase == 'train'
        self._init_dataset()

    def shuffle(self):
        """Rebuild the window list (training mode draws fresh random offsets)."""
        self._init_dataset()

    def _session_file(self, class_n, session, suffix):
        """Return the path of ``<session><suffix>`` inside the class folder."""
        return osp.join(self.data_root, class_n, session + suffix)

    def _init_dataset(self):
        """Fill ``self.inputs`` with one entry per work-memory window.

        Each entry is ``[session, class_n, work_start, work_end, target_slice]``.
        """
        self.inputs = []
        # On a duplicate class name the entry from `memory` replaces the one
        # from `video_data`.
        self.new_video_data = {**self.video_data, **self.memory}
        window = self.work_memory_length
        for class_n, sessions in self.new_video_data.items():
            for session in sessions:
                target = np.load(self._session_file(class_n, session, '_target.npy'))
                num_frames = target.shape[0]
                # Training starts the window grid at a random offset; otherwise at 0.
                offset = np.random.randint(window) if self.training else 0
                starts = range(offset, num_frames, window)
                ends = range(offset + window, num_frames, window)
                self.inputs.extend(
                    [session, class_n, lo, hi, target[lo:hi]]
                    for lo, hi in zip(starts, ends)
                )

    def segment_sampler(self, start, end, num_samples):
        """Return ``num_samples`` evenly spaced int32 indices from start to end."""
        spaced = np.linspace(start, end, num_samples)
        spaced = np.sort(spaced)
        return spaced.astype(np.int32)

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Return every ``sample_rate``-th index of ``[start, end]`` as int32.

        If fewer than ``num_samples`` indices result, zeros are added to make
        up the difference; the result is sorted, so they end up in front.
        """
        picked = np.arange(start, end + 1)[::sample_rate]
        missing = num_samples - picked.shape[0]
        if missing > 0:
            picked = np.concatenate((np.zeros(missing), picked))
        return np.sort(picked).astype(np.int32)

    def __getitem__(self, index):
        """Load the features and targets of the window at ``index``.

        Returns ``(visual, motion, memory_key_padding_mask, target)`` when
        long memory is enabled and ``(visual, motion, target)`` otherwise;
        all elements are float32 tensors.
        """
        session, class_n, work_start, work_end, target = self.inputs[index]

        # Memory-map the feature files so only the indexed rows are read.
        rgb = np.load(self._session_file(class_n, session, '_rgb.npy'), mmap_mode='r')
        flow = np.load(self._session_file(class_n, session, '_flow.npy'), mmap_mode='r')

        # Work memory: subsample targets and frame indices with the same stride.
        stride = self.work_memory_sample_rate
        target = target[::stride]
        work_indices = np.arange(work_start, work_end).clip(0)[::stride]
        work_visual = rgb[work_indices]
        work_motion = flow[work_indices]

        mask = None
        if self.long_memory_num_samples > 0:
            # Long memory covers the frames right before the work window.
            long_start = max(0, work_start - self.long_memory_length)
            long_end = work_start - 1
            if self.training:
                long_indices = self.segment_sampler(
                    long_start, long_end, self.long_memory_num_samples)
            else:
                long_indices = self.uniform_sampler(
                    long_start, long_end,
                    self.long_memory_num_samples, self.long_memory_sample_rate)
            long_indices = long_indices.clip(0)

            fused_visual = np.concatenate((rgb[long_indices], work_visual))
            fused_motion = np.concatenate((flow[long_indices], work_motion))

            # Leading positions that point at frame 0 are masked with -inf,
            # except the last of them, which stays at 0.
            mask = np.zeros(long_indices.shape[0])
            last_zero = bisect_right(long_indices, 0) - 1
            if last_zero > 0:
                mask[:last_zero] = float('-inf')
        else:
            fused_visual = work_visual
            fused_motion = work_motion

        visual_tensor = torch.as_tensor(fused_visual.astype(np.float32))
        motion_tensor = torch.as_tensor(fused_motion.astype(np.float32))
        target_tensor = torch.as_tensor(target.astype(np.float32))

        if mask is None:
            return visual_tensor, motion_tensor, target_tensor
        mask_tensor = torch.as_tensor(mask.astype(np.float32))
        return visual_tensor, motion_tensor, mask_tensor, target_tensor

    def __len__(self):
        """Return the number of windows currently in ``self.inputs``."""
        return len(self.inputs)






@registry.register('LSTRBatchInferenceTHUMOS')
@registry.register('LSTRBatchInferenceTVSeries')
class LSTRBatchInferenceDataLayer(data.Dataset):
    """Dataset enumerating every sliding work-memory window for batch inference.

    For each session, one item is produced per window of
    ``work_memory_length`` consecutive frames, with the window start advancing
    by one frame at a time. Per-session arrays are read from
    ``<data_root + phase>/<class>/<session>_{target,rgb,flow}.npy``.
    """

    def __init__(self, video_data,data_root,
                        visual_feature,motion_feature,target_perframe,long_memory_length,
                        long_memory_sample_rate,long_memory_num_samples,work_memory_length,
                        work_memory_sample_rate,work_memory_num_samples, phase='test'):
        """Store the configuration and enumerate all windows.

        ``video_data`` maps a class name to an iterable of session names.

        Raises:
            AssertionError: if ``phase`` is not ``'test'``.
            IndexError: if ``video_data`` is empty.
        """
        # Validate before `phase` is used to build the data directory, and
        # report the offending value in the message.
        assert phase == 'test', \
            'phase must be `test` for batch inference, got {}'.format(phase)

        self.data_root = data_root + phase
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.video_data = video_data

        # Only the first (class, sessions) entry of `video_data` is used; any
        # further entries are ignored.
        # NOTE(review): confirm that callers always pass a single class here.
        class_n, sessions = list(video_data.items())[0]

        # Each entry: [session, class_n, work_start, work_end, target_slice, num_frames].
        self.inputs = []
        for session in sessions:
            target = np.load(osp.join(self.data_root, class_n, session + '_target.npy'))
            num_frames = target.shape[0]
            # zip() stops at the shorter range, so work_end never exceeds
            # num_frames and work_end - work_start == work_memory_length.
            for work_start, work_end in zip(
                    range(0, num_frames + 1),
                    range(self.work_memory_length, num_frames + 1)):
                self.inputs.append([
                    session, class_n, work_start, work_end,
                    target[work_start: work_end], num_frames,
                ])

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Return every ``sample_rate``-th index of ``[start, end]`` as int32.

        If fewer than ``num_samples`` indices result, zeros are added to make
        up the difference; the result is sorted, so they end up in front.
        """
        indices = np.arange(start, end + 1)[::sample_rate]
        padding = num_samples - indices.shape[0]
        if padding > 0:
            indices = np.concatenate((np.zeros(padding), indices))
        return np.sort(indices).astype(np.int32)

    def __getitem__(self, index):
        """Load the features and targets of the window at ``index``.

        Returns ``(visual, motion, memory_key_padding_mask, target, session,
        work_indices, num_frames)`` when long memory is enabled, and the same
        tuple without the mask otherwise. ``visual``, ``motion``, the mask and
        ``target`` are float32 tensors; ``work_indices`` is a NumPy array.
        """
        session, class_n, work_start, work_end, target, num_frames = self.inputs[index]

        # Memory-map the feature files so only the indexed rows are read.
        visual_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_rgb.npy'), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_flow.npy'), mmap_mode='r')

        # Work memory: subsample targets and frame indices with the same stride.
        target = target[::self.work_memory_sample_rate]
        work_indices = np.arange(work_start, work_end).clip(0)
        work_indices = work_indices[::self.work_memory_sample_rate]
        work_visual_inputs = visual_inputs[work_indices]
        work_motion_inputs = motion_inputs[work_indices]

        if self.long_memory_num_samples > 0:
            # Long memory covers the frames right before the work window.
            long_start = max(0, work_start - self.long_memory_length)
            long_end = work_start - 1
            long_indices = self.uniform_sampler(
                long_start,
                long_end,
                self.long_memory_num_samples,
                self.long_memory_sample_rate).clip(0)

            fusion_visual_inputs = np.concatenate(
                (visual_inputs[long_indices], work_visual_inputs))
            fusion_motion_inputs = np.concatenate(
                (motion_inputs[long_indices], work_motion_inputs))

            # Leading positions that point at frame 0 are masked with -inf,
            # except the last of them, which stays at 0.
            memory_key_padding_mask = np.zeros(long_indices.shape[0])
            last_zero = bisect_right(long_indices, 0) - 1
            if last_zero > 0:
                memory_key_padding_mask[:last_zero] = float('-inf')
        else:
            fusion_visual_inputs = work_visual_inputs
            fusion_motion_inputs = work_motion_inputs
            memory_key_padding_mask = None

        fusion_visual_inputs = torch.as_tensor(fusion_visual_inputs.astype(np.float32))
        fusion_motion_inputs = torch.as_tensor(fusion_motion_inputs.astype(np.float32))
        target = torch.as_tensor(target.astype(np.float32))

        if memory_key_padding_mask is not None:
            memory_key_padding_mask = torch.as_tensor(memory_key_padding_mask.astype(np.float32))
            return (fusion_visual_inputs, fusion_motion_inputs, memory_key_padding_mask, target,
                    session, work_indices, num_frames)
        else:
            return (fusion_visual_inputs, fusion_motion_inputs, target,
                    session, work_indices, num_frames)

    def __len__(self):
        """Return the number of enumerated windows."""
        return len(self.inputs)
