

import os.path as osp
from bisect import bisect_right
import torch
import torch.utils.data as data
import numpy as np
import os




class LSTRDataLayer(data.Dataset):
    """Training dataset: a fixed-length work window plus a sparse long memory per session.

    Sessions come from ``video_data`` merged with ``memory`` (each maps a class
    name to a list of session names). For every session the last
    ``_WORK_WINDOW`` frames form the work window, and ``num_sel_frames`` frames
    spread over the whole video are expanded to ``long_memory_num_samples``
    long-memory slots.

    Files are read from ``data_root + 'train'`` / ``<class>`` /
    ``<session>_target.npy``, ``<session>_rgb.npy`` and ``<session>_flow.npy``.
    ``data_root`` is concatenated as a plain string, so it must end with a path
    separator.
    """

    # Length of the trailing work window taken from every session.
    # NOTE(review): presumably meant to match ``work_memory_length``; kept as a
    # literal so existing configurations behave the same.
    _WORK_WINDOW = 32

    def __init__(self, video_data, data_root,
                 visual_feature, motion_feature, target_perframe, long_memory_length,
                 long_memory_sample_rate, long_memory_num_samples, work_memory_length,
                 work_memory_sample_rate, work_memory_num_samples, memory, num_sel_frames,
                 phase='train'):
        """Store the configuration and build the sample list.

        ``phase`` only sets ``self.training``; data is always read from the
        ``train`` sub-directory. ``visual_feature``, ``motion_feature``,
        ``target_perframe``, ``long_memory_length``, ``long_memory_sample_rate``,
        ``work_memory_length`` and ``work_memory_num_samples`` are stored but
        not read by this class.
        """
        self.data_root = data_root + 'train'
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.training = phase == 'train'
        self.video_data = video_data
        self.memory = memory
        self.num_sel_frames = num_sel_frames
        self._init_dataset()

    def shuffle(self):
        """Rebuild ``self.inputs``.

        Window selection is deterministic, so this produces the same samples again.
        """
        self._init_dataset()

    def _init_dataset(self):
        """Build ``self.inputs`` with one work window per session.

        Each entry is ``[session, class_name, work_start, work_end, work_target,
        num_frames]``. For sessions shorter than ``_WORK_WINDOW``, ``work_start``
        is negative. The target rows are then front-padded with frame 0 using
        the same clipped index pattern as the features in ``__getitem__``, so
        targets and features always have the same length.
        """
        self.inputs = []
        # For a class key present in both dicts, the ``memory`` session list wins.
        self.new_video_data = {**self.video_data, **self.memory}
        for class_n, sessions in self.new_video_data.items():
            for session in sessions:
                target = np.load(osp.join(self.data_root, class_n, session + '_target.npy'))
                num_frames = target.shape[0]
                work_end = num_frames
                work_start = num_frames - self._WORK_WINDOW
                if num_frames:
                    work_target = target[np.arange(work_start, work_end).clip(0)]
                else:
                    # Nothing to pad with; keep the (empty) slice as before.
                    work_target = target[:0]
                self.inputs.append([
                    session, class_n, work_start, work_end, work_target, num_frames])

    def segment_sampler(self, start, end, num_samples):
        """Return ``num_samples`` evenly spaced indices in ``[start, end]``, truncated to int32."""
        indices = np.linspace(start, end, num_samples)
        return np.sort(indices).astype(np.int32)

    def down_sampler(self, offsets, num_samples, num_sel_frames):
        """Stretch or shrink ``offsets`` to exactly ``num_samples`` sorted int32 indices.

        If ``num_sel_frames < num_samples``, each of the first ``num_sel_frames``
        offsets is repeated ``num_samples // num_sel_frames`` times, and any
        remaining slots are filled with ``offsets[-1]``. Otherwise the offset at
        position ``int(i * num_sel_frames / num_samples)`` is taken for every
        ``i`` in ``range(num_samples)``.
        """
        offsets = np.asarray(offsets)
        if num_sel_frames < num_samples:
            repeat = int(num_samples / num_sel_frames)
            head = np.repeat(offsets[:num_sel_frames], repeat)
            tail = np.full(num_samples - num_sel_frames * repeat, offsets[-1])
            indices = np.concatenate((head, tail))
        else:
            positions = [int(i * num_sel_frames / num_samples) for i in range(num_samples)]
            indices = offsets[positions]
        return np.sort(indices).astype(np.int32)

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Every ``sample_rate``-th index in ``[start, end]``, front-padded with zeros up to ``num_samples``."""
        indices = np.arange(start, end + 1)[::sample_rate]
        padding = num_samples - indices.shape[0]
        if padding > 0:
            indices = np.concatenate((np.zeros(padding), indices))
        return np.sort(indices).astype(np.int32)

    def __getitem__(self, index):
        """Return fused (long memory + work window) features for one session.

        Returns ``(visual, motion, memory_key_padding_mask, target)`` as float32
        tensors when ``long_memory_num_samples > 0``, otherwise
        ``(visual, motion, target)``.
        """
        session, class_n, work_start, work_end, target, num_frames = self.inputs[index]

        # Memory-mapped: only the rows indexed below are actually read.
        visual_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_rgb.npy'), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_flow.npy'), mmap_mode='r')

        target = target[::self.work_memory_sample_rate]

        # Work window; negative indices are clamped to frame 0 (front padding).
        work_indices = np.arange(work_start, work_end).clip(0)
        work_indices = work_indices[::self.work_memory_sample_rate]
        work_visual_inputs = visual_inputs[work_indices]
        work_motion_inputs = motion_inputs[work_indices]

        if self.long_memory_num_samples > 0:
            # Midpoints (truncated) of num_sel_frames equal segments covering the video.
            tick = num_frames / float(self.num_sel_frames)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_sel_frames)])
            long_indices = self.down_sampler(
                offsets, self.long_memory_num_samples, self.num_sel_frames).clip(0)
            long_visual_inputs = visual_inputs[long_indices]
            long_motion_inputs = motion_inputs[long_indices]

            # long_indices is sorted and non-negative: mask every index-0 slot
            # except the last one with -inf.
            memory_key_padding_mask = np.zeros(long_indices.shape[0])
            last_zero = bisect_right(long_indices, 0) - 1
            if last_zero > 0:
                memory_key_padding_mask[:last_zero] = float('-inf')
        else:
            long_visual_inputs = None
            long_motion_inputs = None
            memory_key_padding_mask = None

        if long_visual_inputs is not None and long_motion_inputs is not None:
            fusion_visual_inputs = np.concatenate((long_visual_inputs, work_visual_inputs))
            fusion_motion_inputs = np.concatenate((long_motion_inputs, work_motion_inputs))
        else:
            fusion_visual_inputs = work_visual_inputs
            fusion_motion_inputs = work_motion_inputs

        fusion_visual_inputs = torch.as_tensor(fusion_visual_inputs.astype(np.float32))
        fusion_motion_inputs = torch.as_tensor(fusion_motion_inputs.astype(np.float32))
        target = torch.as_tensor(target.astype(np.float32))

        if memory_key_padding_mask is not None:
            memory_key_padding_mask = torch.as_tensor(memory_key_padding_mask.astype(np.float32))
            return fusion_visual_inputs, fusion_motion_inputs, memory_key_padding_mask, target
        return fusion_visual_inputs, fusion_motion_inputs, target

    def __len__(self):
        """Number of sessions (one sample each)."""
        return len(self.inputs)








class LSTRBatchInferenceDataLayer(data.Dataset):
    """Test-time dataset that enumerates every sliding work window of each session.

    Only the first class in ``video_data`` is used. For a session with ``T``
    target frames, windows ``[s, s + work_memory_length)`` are produced for
    ``s = 0 .. T - work_memory_length``. Files are read from
    ``data_root + phase`` / ``<class>`` / ``<session>_{target,rgb,flow}.npy``.
    ``data_root`` is concatenated as a plain string, so it must end with a path
    separator.
    """

    def __init__(self, video_data, data_root,
                 visual_feature, motion_feature, target_perframe, long_memory_length,
                 long_memory_sample_rate, long_memory_num_samples, work_memory_length,
                 work_memory_sample_rate, work_memory_num_samples, phase='test',
                 num_sel_frames=None):
        """Store the configuration and enumerate all work windows.

        ``num_sel_frames`` is the number of frames spread over the whole video
        that feed the long memory (same meaning as in the training dataset). It
        defaults to ``long_memory_num_samples``, i.e. one distinct frame per
        long-memory slot.

        Raises:
            AssertionError: if ``phase`` is not ``'test'``.
        """
        assert phase == 'test', 'phase must be `test` for batch inference, got {}'.format(phase)

        self.data_root = data_root + phase
        self.visual_feature = visual_feature
        self.motion_feature = motion_feature
        self.target_perframe = target_perframe
        self.long_memory_length = long_memory_length
        self.long_memory_sample_rate = long_memory_sample_rate
        self.long_memory_num_samples = long_memory_num_samples
        self.work_memory_length = work_memory_length
        self.work_memory_sample_rate = work_memory_sample_rate
        self.work_memory_num_samples = work_memory_num_samples
        self.video_data = video_data
        # Read by __getitem__ whenever long memory is enabled.
        self.num_sel_frames = (long_memory_num_samples if num_sel_frames is None
                               else num_sel_frames)

        self.inputs = []
        # NOTE(review): only the first class of ``video_data`` is evaluated;
        # confirm callers pass a single-class dict.
        class_n = list(video_data.keys())[0]
        sessions = video_data[class_n]
        for session in sessions:
            target = np.load(osp.join(self.data_root, class_n, session + '_target.npy'))
            num_frames = target.shape[0]
            # zip stops at the shorter range, so every window is full length.
            for work_start, work_end in zip(
                    range(0, num_frames + 1),
                    range(self.work_memory_length, num_frames + 1)):
                self.inputs.append([
                    session, class_n, work_start, work_end, target[work_start:work_end], num_frames
                ])

    def uniform_sampler(self, start, end, num_samples, sample_rate):
        """Every ``sample_rate``-th index in ``[start, end]``, front-padded with zeros up to ``num_samples``."""
        indices = np.arange(start, end + 1)[::sample_rate]
        padding = num_samples - indices.shape[0]
        if padding > 0:
            indices = np.concatenate((np.zeros(padding), indices))
        return np.sort(indices).astype(np.int32)

    def down_sampler(self, offsets, num_samples, num_sel_frames):
        """Stretch or shrink ``offsets`` to exactly ``num_samples`` sorted int32 indices.

        If ``num_sel_frames < num_samples``, each of the first ``num_sel_frames``
        offsets is repeated ``num_samples // num_sel_frames`` times, and any
        remaining slots are filled with ``offsets[-1]``. Otherwise the offset at
        position ``int(i * num_sel_frames / num_samples)`` is taken for every
        ``i`` in ``range(num_samples)``.
        """
        offsets = np.asarray(offsets)
        if num_sel_frames < num_samples:
            repeat = int(num_samples / num_sel_frames)
            head = np.repeat(offsets[:num_sel_frames], repeat)
            tail = np.full(num_samples - num_sel_frames * repeat, offsets[-1])
            indices = np.concatenate((head, tail))
        else:
            positions = [int(i * num_sel_frames / num_samples) for i in range(num_samples)]
            indices = offsets[positions]
        return np.sort(indices).astype(np.int32)

    def __getitem__(self, index):
        """Return fused features for one work window plus bookkeeping.

        Returns ``(visual, motion, memory_key_padding_mask, target, session,
        work_indices, num_frames)`` when ``long_memory_num_samples > 0``,
        otherwise ``(visual, motion, target, session, work_indices, num_frames)``.
        The first tensors are float32.
        """
        session, class_n, work_start, work_end, target, num_frames = self.inputs[index]

        # Memory-mapped: only the rows indexed below are actually read.
        visual_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_rgb.npy'), mmap_mode='r')
        motion_inputs = np.load(
            osp.join(self.data_root, class_n, session + '_flow.npy'), mmap_mode='r')

        target = target[::self.work_memory_sample_rate]

        work_indices = np.arange(work_start, work_end).clip(0)
        work_indices = work_indices[::self.work_memory_sample_rate]
        work_visual_inputs = visual_inputs[work_indices]
        work_motion_inputs = motion_inputs[work_indices]

        if self.long_memory_num_samples > 0:
            # Same long-memory sampling as training: truncated midpoints of
            # num_sel_frames equal segments covering the whole video.
            tick = num_frames / float(self.num_sel_frames)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_sel_frames)])
            long_indices = self.down_sampler(
                offsets, self.long_memory_num_samples, self.num_sel_frames).clip(0)
            long_visual_inputs = visual_inputs[long_indices]
            long_motion_inputs = motion_inputs[long_indices]

            # long_indices is sorted and non-negative: mask every index-0 slot
            # except the last one with -inf.
            memory_key_padding_mask = np.zeros(long_indices.shape[0])
            last_zero = bisect_right(long_indices, 0) - 1
            if last_zero > 0:
                memory_key_padding_mask[:last_zero] = float('-inf')
        else:
            long_visual_inputs = None
            long_motion_inputs = None
            memory_key_padding_mask = None

        if long_visual_inputs is not None and long_motion_inputs is not None:
            fusion_visual_inputs = np.concatenate((long_visual_inputs, work_visual_inputs))
            fusion_motion_inputs = np.concatenate((long_motion_inputs, work_motion_inputs))
        else:
            fusion_visual_inputs = work_visual_inputs
            fusion_motion_inputs = work_motion_inputs

        fusion_visual_inputs = torch.as_tensor(fusion_visual_inputs.astype(np.float32))
        fusion_motion_inputs = torch.as_tensor(fusion_motion_inputs.astype(np.float32))
        target = torch.as_tensor(target.astype(np.float32))

        if memory_key_padding_mask is not None:
            memory_key_padding_mask = torch.as_tensor(memory_key_padding_mask.astype(np.float32))
            return (fusion_visual_inputs, fusion_motion_inputs, memory_key_padding_mask, target,
                    session, work_indices, num_frames)
        return (fusion_visual_inputs, fusion_motion_inputs, target,
                session, work_indices, num_frames)

    def __len__(self):
        """Number of work windows across all sessions."""
        return len(self.inputs)
