import os.path as osp
import torch
import torch.utils.data as data
import data.util as util
import torch.nn.functional as F
import random
import cv2
import numpy as np
import glob
import os


class VideoSameSizeDataset(data.Dataset):
    """Validation/test dataset of video clips with paired event data and masks.

    Every sub-folder of ``opt['dataroot']`` whose name is listed in
    ``opt['testing_dir']`` (a comma-separated string) is treated as one clip.
    Each clip folder is expected to contain the sub-directories
    ``input_video``, ``GT_video``, ``input_event``, ``GT_event`` and
    ``SNR_mask``, all holding the same number of files.

    Only file paths are collected at construction time; the files themselves
    are read in ``__getitem__`` through ``util.read_img_seq3``.

    Options read from ``opt``: ``cache_data``, ``N_frames``, ``dataroot``,
    ``data_type``, ``nIE``, ``testing_dir``, ``padding`` and ``test_size``.
    """

    # Only the first 30 files of every modality are used for each clip.
    _MAX_FRAMES_PER_CLIP = 30

    def __init__(self, opt):
        """Scan the data root and index every frame of the selected clips.

        Args:
            opt: option mapping (see the class docstring for the keys read).

        Raises:
            ValueError: if ``opt['data_type']`` is ``'lmdb'``.
            AssertionError: if the five per-clip file lists differ in length.
        """
        super(VideoSameSizeDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.root = opt['dataroot']
        self.data_type = self.opt['data_type']
        self.nIE = self.opt['nIE']
        # Flat, per-sample bookkeeping; all lists are index-aligned.
        self.data_info = {
            'path_LQ': [], 'path_GT': [], 'path_LQ_E': [], 'path_GT_E': [],
            'path_mask': [], 'folder': [], 'idx': [], 'border': [],
        }
        if self.data_type == 'lmdb':
            raise ValueError('No need to use LMDB during validation/test.')

        # Per-clip path lists, keyed by clip folder name.
        self.imgs_LQ, self.imgs_GT, self.events_LQ, self.events_GT, self.mask = {}, {}, {}, {}, {}

        if opt['testing_dir'] is not None:
            testing_dir = opt['testing_dir'].split(',')
        else:
            testing_dir = []
        print('testing_dir', testing_dir)
        selected = set(testing_dir)  # O(1) membership test in the loop below

        limit = self._MAX_FRAMES_PER_CLIP
        for subfolder in util.glob_file_list(self.root):
            subfolder_name = osp.basename(subfolder)  # clip (video sequence) name

            if subfolder_name not in selected:
                continue

            img_paths_LQ = util.glob_file_list(
                os.path.join(subfolder, 'input_video'))[:limit]
            img_paths_GT = util.glob_file_list(
                os.path.join(subfolder, 'GT_video'))[:limit]
            events_paths_LQ = util.glob_file_list(
                os.path.join(subfolder, 'input_event'))[:limit]
            events_paths_GT = util.glob_file_list(
                os.path.join(subfolder, 'GT_event'))[:limit]
            # NOTE: masks are read from 'SNR_mask'; the original code carried a
            # commented-out alternative that read them from 'gray_mask'.
            mask_paths = util.glob_file_list(
                os.path.join(subfolder, 'SNR_mask'))[:limit]

            max_idx = len(img_paths_LQ)
            assert max_idx == len(img_paths_GT) == len(events_paths_LQ) == len(
                events_paths_GT) == len(mask_paths), 'Different number of images in LQ and GT folders'

            self.data_info['path_LQ'].extend(img_paths_LQ)  # list of path str of images
            self.data_info['path_GT'].extend(img_paths_GT)
            self.data_info['path_LQ_E'].extend(events_paths_LQ)
            self.data_info['path_GT_E'].extend(events_paths_GT)
            self.data_info['path_mask'].extend(mask_paths)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            # 'idx' is stored as "<frame index>/<clip length>" and parsed
            # back in __getitem__.
            self.data_info['idx'].extend(
                '{}/{}'.format(i, max_idx) for i in range(max_idx))
            self.data_info['border'].extend(
                self._border_flags(max_idx, self.half_N_frames))

            # __getitem__ looks these tables up unconditionally, so they are
            # filled regardless of opt['cache_data']; they hold only paths,
            # not decoded images, so the memory cost is negligible.
            self.imgs_LQ[subfolder_name] = img_paths_LQ
            self.imgs_GT[subfolder_name] = img_paths_GT
            self.events_LQ[subfolder_name] = events_paths_LQ
            self.events_GT[subfolder_name] = events_paths_GT
            self.mask[subfolder_name] = mask_paths

    @staticmethod
    def _border_flags(num_frames, half_window):
        """Return a 0/1 list marking frames within ``half_window`` of a clip end.

        The window is clamped to ``num_frames`` so that clips shorter than
        the half window are flagged entirely instead of raising IndexError.
        """
        flags = [0] * num_frames
        for i in range(min(half_window, num_frames)):
            flags[i] = 1
            flags[num_frames - 1 - i] = 1
        return flags

    def __getitem__(self, index):
        """Load the frame window around sample ``index``.

        Returns:
            dict with keys ``'LQs'``, ``'GT'``, ``'e_LQS'``, ``'e_GTS'``,
            ``'e_ref'``, ``'masks'``, ``'folder'``, ``'idx'`` and ``'border'``.
            ``'e_ref'`` has the same content as ``'e_LQS'``.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = (int(v) for v in self.data_info['idx'][index].split('/'))
        border = self.data_info['border'][index]

        # Frame indices of the temporal window; the padding behaviour at clip
        # ends is decided by util.index_generation.
        select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                           padding=self.opt['padding'])

        lq_frames = self.imgs_LQ[folder]
        lq_events = self.events_LQ[folder]
        gt_events = self.events_GT[folder]
        clip_masks = self.mask[folder]
        imgs_LQ_path = [lq_frames[i] for i in select_idx]
        events_LQ_path = [lq_events[i] for i in select_idx]
        events_GT_path = [gt_events[i] for i in select_idx]
        masks_path = [clip_masks[i] for i in select_idx]

        if not self.nIE:
            # Without the nIE option the last event file of the window is
            # dropped, leaving one fewer event entry than image frames.
            # NOTE(review): presumably events describe the intervals between
            # consecutive frames -- confirm against the model that consumes them.
            events_LQ_path = events_LQ_path[:-1]
            events_GT_path = events_GT_path[:-1]
        # One-element slice so the sequence reader can be reused for the GT frame.
        img_GT_path = self.imgs_GT[folder][idx:idx + 1]

        test_size = self.opt['test_size']
        imgs_LQ = util.read_img_seq3(imgs_LQ_path, test_size, 'np')
        img_GT = util.read_img_seq3(img_GT_path, test_size, 'np')
        events_LQ = util.read_img_seq3(events_LQ_path, test_size, 'event')
        events_GT = util.read_img_seq3(events_GT_path, test_size, 'event')
        masks = util.read_img_seq3(masks_path, test_size, 'img')

        # Each entry is rebuilt by unbind + stack along dim 0, as the original
        # code did, so every returned tensor is a fresh, independent copy
        # ('e_LQS' and 'e_ref' do not share storage).
        return {
            # presumably shaped [N, C, H, W] -- depends on util.read_img_seq3
            'LQs': torch.stack(imgs_LQ.unbind(0)),
            'GT': img_GT[0],
            'e_LQS': torch.stack(events_LQ.unbind(0)),
            'e_GTS': torch.stack(events_GT.unbind(0)),
            'e_ref': torch.stack(events_LQ.unbind(0)),
            'masks': torch.stack(masks.unbind(0)),
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }

    def __len__(self):
        """Return the total number of indexed frames across all selected clips."""
        return len(self.data_info['path_GT'])
