


import os.path as osp
from bisect import bisect_right

import torch
import torch.utils.data as data
import numpy as np

from .cls_dataset import DATA_LAYERS as registry        


@registry.register('LSTRTHUMOS')
@registry.register('LSTRTVSeries')
class LSTRDataLayer(data.Dataset):
    """Dataset of working-memory windows plus sampled long-memory frames.

    Per-session arrays are read from
    ``<data_root>train/<class_n>/<session>_{target,rgb,flow}.npy``.

    Two kinds of items are built by ``_init_dataset``:

    * regular items (``is_memory == 0``), one per session in ``video_data``:
      consecutive working windows of ``work_memory_length`` frames; long
      memory is sampled from up to ``long_memory_length`` frames before the
      window.
    * memory items (``is_memory == 1``), one per session in ``memory``: the
      last 32 frames form the working window; long memory is
      ``num_sel_frames`` evenly spaced frames over the whole session,
      expanded to ``long_memory_num_samples`` entries by ``down_sampler``.
    """

    def __init__(self, video_data,data_root,
                        visual_feature,motion_feature,target_perframe,long_memory_length,
                        long_memory_sample_rate,long_memory_num_samples,work_memory_length,
                        work_memory_sample_rate,work_memory_num_samples, memory,num_sel_frames,phase='train'):
        """Store the sampling configuration and build the item list.

        Args:
            video_data: mapping of class name -> iterable of session names
                used for regular items.
            data_root: path prefix; ``'train'`` is appended by plain string
                concatenation, so it should end with a path separator.
            visual_feature, motion_feature, target_perframe: stored but not
                used by this class.
            long_memory_length: number of frames before a regular window that
                long memory is sampled from.
            long_memory_sample_rate: stride of the non-training long sampler.
            long_memory_num_samples: long-memory entries per item; ``<= 0``
                disables long memory.
            work_memory_length: working window length for regular items.
            work_memory_sample_rate: stride applied to working frames and
                targets.
            work_memory_num_samples: stored but not used by this class.
            memory: optional mapping of class name -> iterable of session
                names used for memory items (falsy to disable).
            num_sel_frames: number of evenly spaced frames per memory item.
            phase: ``'train'`` enables random window offsets and segment
                sampling.
        """
        # NOTE(review): the root is always '<data_root>train' regardless of
        # ``phase``; confirm this is intended for non-train phases.
        self.data_root = data_root+'train'
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.training = phase == 'train'
        self.video_data=video_data
        self.memory=memory
        self.num_sel_frames=num_sel_frames
        self._init_dataset()

    def shuffle(self):
        """Rebuild the item list (draws a new window offset when training)."""
        self._init_dataset()

    def _init_dataset(self):
        """(Re)build ``self.inputs``.

        Each entry is ``[session, class_n, work_start, work_end, target_slice,
        num_frames, is_memory]``; ``num_frames`` is ``None`` for regular items.
        """
        self.inputs = []

        # Kept as a public alias for backward compatibility.
        self.new_video_data = self.video_data
        for class_n, sessions in self.new_video_data.items():
            for session in sessions:
                is_memory = 0
                target = np.load(osp.join(self.data_root, class_n, session + '_target.npy'))
                # A random start offset in training shifts the window grid.
                seed = np.random.randint(self.work_memory_length) if self.training else 0
                # Non-overlapping windows [work_start, work_end). Because the end
                # range excludes target.shape[0], a window ending exactly at the
                # last frame (and any trailing partial window) is not produced.
                for work_start, work_end in zip(
                        range(seed, target.shape[0], self.work_memory_length),
                        range(seed + self.work_memory_length, target.shape[0], self.work_memory_length)):
                    self.inputs.append([
                        session, class_n, work_start, work_end, target[work_start:work_end], None, is_memory
                    ])
        if self.memory:
            for class_n, sessions in self.memory.items():
                for session in sessions:
                    is_memory = 1
                    target = np.load(osp.join(self.data_root, class_n, session + '_target.npy'))
                    work_end = target.shape[0]
                    # NOTE(review): the window length 32 is hard-coded (not
                    # work_memory_length). For sessions shorter than 32 frames
                    # work_start is negative, so the target slice is shorter
                    # than the clipped index range built in __getitem__.
                    work_start = target.shape[0] - 32
                    self.inputs.append([
                        session, class_n, work_start, work_end, target[work_start:work_end], target.shape[0], is_memory])

    def segment_sampler(self, start, end, num_samples):
        """Return ``num_samples`` evenly spaced indices over [start, end].

        The indices are sorted ascending and truncated to int32.
        """
        indices = np.linspace(start, end, num_samples)
        return np.sort(indices).astype(np.int32)

    def down_sampler(self, offsets,num_samples,num_sel_frames):
        """Map ``num_sel_frames`` offsets onto ``num_samples`` index slots.

        If there are fewer offsets than slots, each offset is repeated
        ``int(num_samples / num_sel_frames)`` times and any remaining slots
        take the last offset. Otherwise slot ``i`` takes
        ``offsets[int(i * num_sel_frames / num_samples)]``. The result is
        sorted and cast to int32.
        """
        indices=np.zeros(num_samples)
        tick=int(num_samples/num_sel_frames)
        if num_sel_frames<num_samples:
            for i in range(num_sel_frames):
                for j in range(tick):
                    indices[i*tick+j]=offsets[i]
            res=num_samples-num_sel_frames*tick
            for i in range(res):
                indices[num_sel_frames*tick+i]=offsets[-1]

        else:
            for i in range(num_samples):
                indices[i]=offsets[int(i*num_sel_frames/num_samples)]
        return np.sort(indices).astype(np.int32)

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Return every ``sample_rate``-th index in [start, end].

        The result is left-padded with zeros up to ``num_samples`` entries
        (it is not truncated if longer), sorted and cast to int32.
        """
        indices = np.arange(start, end + 1)[::sample_rate]
        padding = num_samples - indices.shape[0]
        if padding > 0:
            indices = np.concatenate((np.zeros(padding), indices))
        return np.sort(indices).astype(np.int32)

    def _load_features(self, class_n, session):
        """Open one session's RGB and flow feature arrays with ``mmap_mode='r'``."""
        visual_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_rgb.npy'), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_flow.npy'), mmap_mode='r')
        return visual_inputs, motion_inputs

    def _long_indices(self, work_start):
        """Sample long-memory indices from the frames preceding a regular window."""
        long_start, long_end = max(0, work_start - self.long_memory_length), work_start - 1
        if self.training:
            return self.segment_sampler(
                long_start,
                long_end,
                self.long_memory_num_samples).clip(0)
        return self.uniform_sampler(
            long_start,
            long_end,
            self.long_memory_num_samples,
            self.long_memory_sample_rate).clip(0)

    def _memory_long_indices(self, num_frames):
        """Sample long-memory indices spread over a whole memory session."""
        tick = num_frames / float(self.num_sel_frames)
        # Centre of each of num_sel_frames equal segments of the session.
        offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_sel_frames)])
        return self.down_sampler(offsets, self.long_memory_num_samples, self.num_sel_frames).clip(0)

    def _padding_mask(self, long_indices):
        """Build the long-memory key padding mask.

        Every index-0 entry except the last one gets ``-inf``, since repeated
        frame-0 entries are padding. All other entries are 0.
        """
        memory_key_padding_mask = np.zeros(long_indices.shape[0])
        last_zero = bisect_right(long_indices, 0) - 1
        if last_zero > 0:
            memory_key_padding_mask[:last_zero] = float('-inf')
        return memory_key_padding_mask

    def __getitem__(self, index):
        """Return one item as float32 tensors.

        Returns ``(visual, motion, mask, target)``, or ``(visual, motion,
        target)`` when ``long_memory_num_samples <= 0``. ``visual``/``motion``
        hold the long-memory frames (if any) followed by the working frames
        along the first axis.
        """
        session, class_n, work_start, work_end, target, num_frames, is_memory = self.inputs[index]
        visual_inputs, motion_inputs = self._load_features(class_n, session)

        # Targets and working frames share the same stride.
        target = target[::self.work_memory_sample_rate]

        # Negative starts (memory items of short sessions) clip to frame 0.
        work_indices = np.arange(work_start, work_end).clip(0)
        work_indices = work_indices[::self.work_memory_sample_rate]
        fusion_visual_inputs = visual_inputs[work_indices]
        fusion_motion_inputs = motion_inputs[work_indices]

        memory_key_padding_mask = None
        if self.long_memory_num_samples > 0:
            if is_memory == 0:
                long_indices = self._long_indices(work_start)
            else:
                long_indices = self._memory_long_indices(num_frames)
            fusion_visual_inputs = np.concatenate((visual_inputs[long_indices], fusion_visual_inputs))
            fusion_motion_inputs = np.concatenate((motion_inputs[long_indices], fusion_motion_inputs))
            memory_key_padding_mask = self._padding_mask(long_indices)

        fusion_visual_inputs = torch.as_tensor(fusion_visual_inputs.astype(np.float32))
        fusion_motion_inputs = torch.as_tensor(fusion_motion_inputs.astype(np.float32))
        target = torch.as_tensor(target.astype(np.float32))

        if memory_key_padding_mask is not None:
            memory_key_padding_mask = torch.as_tensor(memory_key_padding_mask.astype(np.float32))
            return fusion_visual_inputs, fusion_motion_inputs, memory_key_padding_mask, target
        return fusion_visual_inputs, fusion_motion_inputs, target

    def __len__(self):
        """Number of built items (regular plus memory)."""
        return len(self.inputs)






@registry.register('LSTRBatchInferenceTHUMOS')
@registry.register('LSTRBatchInferenceTVSeries')
class LSTRBatchInferenceDataLayer(data.Dataset):
    """Test-time dataset yielding every sliding working window of each session.

    Windows of ``work_memory_length`` frames start at every frame (stride 1),
    from frame 0 up to ``num_frames - work_memory_length``. Per-session arrays
    are read from ``<data_root><phase>/<class_n>/<session>_{target,rgb,flow}.npy``.
    """

    def __init__(self, video_data,data_root,
                        visual_feature,motion_feature,target_perframe,long_memory_length,
                        long_memory_sample_rate,long_memory_num_samples,work_memory_length,
                        work_memory_sample_rate,work_memory_num_samples, phase='test'):
        """Store the sampling configuration and enumerate all windows.

        Args:
            video_data: mapping of class name -> iterable of session names.
                Only the first entry is used.
            data_root: path prefix; ``phase`` is appended by string
                concatenation, so it should end with a path separator.
            visual_feature, motion_feature, target_perframe,
            work_memory_num_samples: stored but not used by this class.
            long_memory_length: number of frames before a window that long
                memory is sampled from.
            long_memory_sample_rate: stride of the long-memory sampler.
            long_memory_num_samples: long-memory entries per item; ``<= 0``
                disables long memory.
            work_memory_length: working window length.
            work_memory_sample_rate: stride applied to working frames and
                targets.
            phase: must be ``'test'``.

        Raises:
            AssertionError: if ``phase != 'test'``.
        """
        # Validate before any work; the message now reports the bad value.
        assert phase == 'test', 'phase must be `test` for batch inference, got {}'.format(phase)

        self.data_root = data_root+phase
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.video_data=video_data

        # Each entry: [session, class_n, work_start, work_end, target_slice, num_frames].
        self.inputs = []
        # NOTE(review): only the first class of ``video_data`` is used; further
        # classes are ignored, and an empty mapping raises IndexError here.
        classes=list(video_data.keys())
        sessions_s=list(video_data.values())
        sessions=sessions_s[0]
        class_n=classes[0]
        for session in sessions:
            target = np.load(osp.join(self.data_root,class_n, session + '_target.npy'))
            # Stride-1 windows [s, s + work_memory_length); none if the session
            # is shorter than work_memory_length.
            for work_start, work_end in zip(
                range(0, target.shape[0] + 1),
                range(self.work_memory_length, target.shape[0] + 1)):
                self.inputs.append([
                    session, class_n,work_start, work_end, target[work_start: work_end], target.shape[0]
                ])

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Return every ``sample_rate``-th index in [start, end].

        The result is left-padded with zeros up to ``num_samples`` entries
        (it is not truncated if longer), sorted and cast to int32.
        """
        indices = np.arange(start, end + 1)[::sample_rate]
        padding = num_samples - indices.shape[0]
        if padding > 0:
            indices = np.concatenate((np.zeros(padding), indices))
        return np.sort(indices).astype(np.int32)

    def _padding_mask(self, long_indices):
        """Build the long-memory key padding mask.

        Every index-0 entry except the last one gets ``-inf``, since repeated
        frame-0 entries are padding. All other entries are 0.
        """
        memory_key_padding_mask = np.zeros(long_indices.shape[0])
        last_zero = bisect_right(long_indices, 0) - 1
        if last_zero > 0:
            memory_key_padding_mask[:last_zero] = float('-inf')
        return memory_key_padding_mask

    def __getitem__(self, index):
        """Return one window together with its bookkeeping.

        Returns ``(visual, motion, mask, target, session, work_indices,
        num_frames)``. When ``long_memory_num_samples <= 0``, ``mask`` is
        omitted. ``visual``/``motion`` are float32 tensors holding the
        long-memory frames (if any) followed by the working frames along the
        first axis.
        """
        session,class_n, work_start, work_end, target, num_frames = self.inputs[index]

        visual_inputs = np.load(
            osp.join(self.data_root,class_n, session + '_rgb.npy'), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(self.data_root,class_n, session + '_flow.npy'), mmap_mode='r')

        # Targets and working frames share the same stride.
        target = target[::self.work_memory_sample_rate]

        work_indices = np.arange(work_start, work_end).clip(0)
        work_indices = work_indices[::self.work_memory_sample_rate]
        fusion_visual_inputs = visual_inputs[work_indices]
        fusion_motion_inputs = motion_inputs[work_indices]

        memory_key_padding_mask = None
        if self.long_memory_num_samples > 0:
            long_start, long_end = max(0, work_start - self.long_memory_length), work_start - 1
            long_indices = self.uniform_sampler(
                long_start,
                long_end,
                self.long_memory_num_samples,
                self.long_memory_sample_rate).clip(0)
            fusion_visual_inputs = np.concatenate((visual_inputs[long_indices], fusion_visual_inputs))
            fusion_motion_inputs = np.concatenate((motion_inputs[long_indices], fusion_motion_inputs))
            memory_key_padding_mask = self._padding_mask(long_indices)

        fusion_visual_inputs = torch.as_tensor(fusion_visual_inputs.astype(np.float32))
        fusion_motion_inputs = torch.as_tensor(fusion_motion_inputs.astype(np.float32))
        target = torch.as_tensor(target.astype(np.float32))

        if memory_key_padding_mask is not None:
            memory_key_padding_mask = torch.as_tensor(memory_key_padding_mask.astype(np.float32))
            return (fusion_visual_inputs, fusion_motion_inputs, memory_key_padding_mask, target,
                    session, work_indices, num_frames)
        return (fusion_visual_inputs, fusion_motion_inputs, target,
                session, work_indices, num_frames)

    def __len__(self):
        """Number of enumerated windows."""
        return len(self.inputs)
