import os.path as osp
import torch
import torch.utils.data as data
import data.util as util
import torch.nn.functional as F
import random
import cv2
import numpy as np
import glob
import os

def get_videos_dir(list_file_dir):
    """Read a text file listing video folder names, one per line.

    Args:
        list_file_dir: Path to the list file.

    Returns:
        List of the file's lines with surrounding whitespace stripped, in file
        order. Blank lines are kept as empty strings.
    """
    # Use a context manager so the file handle is closed even on error
    # (the previous version never closed it).
    with open(list_file_dir, 'r') as list_file:
        return [line.strip() for line in list_file]

class VideoSameSizeDataset(data.Dataset):
    """Dataset yielding windows of LQ frames, event frames and SNR masks.

    Directory layout read under ``opt['dataroot']``::

        <dataroot>/<video_name>/input_video/...
        <dataroot>/<video_name>/input_event/...
        <dataroot>/<video_name>/SNR_mask/...

    Only video folders whose name appears in the list file
    ``opt['video_list']`` are loaded. Every frame of every loaded video is one
    sample. The neighbouring frame indices for a sample are chosen by
    ``util.index_generation`` using ``opt['N_frames']`` and ``opt['padding']``.

    Option keys read: ``cache_data``, ``N_frames``, ``dataroot``, ``data_type``,
    ``nIE``, ``video_list``, ``padding``, ``test_size``, ``phase`` and, when
    ``phase == 'train'``, ``GT_size``, ``use_flip`` and ``use_rot``. The
    optional ``max_frames`` key (default 30, ``None`` = no limit) caps how many
    frames are kept per video.
    """

    def __init__(self, opt):
        super(VideoSameSizeDataset, self).__init__()
        self.opt = opt
        self.cache_data = opt['cache_data']
        self.half_N_frames = opt['N_frames'] // 2
        self.root = opt['dataroot']
        self.data_type = self.opt['data_type']
        self.nIE = self.opt['nIE']
        # Per-video frame cap; 30 reproduces the previously hard-coded limit.
        self.max_frames = opt.get('max_frames', 30)
        self.data_info = {'path_LQ': [], 'path_GT': [], 'path_LQ_E': [],
                          'path_GT_E': [], 'path_mask': [], 'folder': [],
                          'idx': [], 'border': []}
        if self.data_type == 'lmdb':
            raise ValueError('No need to use LMDB during validation/test.')
        # Per-video lists of file paths keyed by folder name. Only imgs_LQ,
        # events_LQ and mask are filled; imgs_GT / events_GT stay empty.
        self.imgs_LQ, self.imgs_GT, self.events_LQ, self.events_GT, self.mask = {}, {}, {}, {}, {}

        video_list = get_videos_dir(opt['video_list'])
        print('testing_dir', video_list)
        wanted = set(video_list)  # O(1) membership tests

        for subfolder in util.glob_file_list(self.root):
            subfolder_name = osp.basename(subfolder)  # video identifier
            if subfolder_name not in wanted:
                continue

            img_paths_LQ = util.glob_file_list(osp.join(subfolder, 'input_video'))
            events_paths_LQ = util.glob_file_list(osp.join(subfolder, 'input_event'))
            mask_paths = util.glob_file_list(osp.join(subfolder, 'SNR_mask'))

            # Truncation is triggered by the frame count only, as before.
            if self.max_frames is not None and len(img_paths_LQ) > self.max_frames:
                img_paths_LQ = img_paths_LQ[:self.max_frames]
                events_paths_LQ = events_paths_LQ[:self.max_frames]
                mask_paths = mask_paths[:self.max_frames]

            max_idx = len(img_paths_LQ)
            if not (max_idx == len(events_paths_LQ) == len(mask_paths)):
                raise ValueError(
                    'Different number of files in input_video ({}), input_event ({}) '
                    'and SNR_mask ({}) of {}'.format(
                        max_idx, len(events_paths_LQ), len(mask_paths), subfolder))

            self.data_info['path_LQ'].extend(img_paths_LQ)  # list of path str of images
            self.data_info['path_LQ_E'].extend(events_paths_LQ)
            self.data_info['path_mask'].extend(mask_paths)
            self.data_info['folder'].extend([subfolder_name] * max_idx)
            self.data_info['idx'].extend('{}/{}'.format(i, max_idx) for i in range(max_idx))
            self.data_info['border'].extend(self._border_flags(max_idx, self.half_N_frames))

            # Always store the path lists: __getitem__ looks frames up here, so
            # filling them only when cache_data was set made every item raise
            # KeyError otherwise.
            self.imgs_LQ[subfolder_name] = img_paths_LQ
            self.events_LQ[subfolder_name] = events_paths_LQ
            self.mask[subfolder_name] = mask_paths

    @staticmethod
    def _border_flags(max_idx, half_n):
        """Return ``max_idx`` flags: 1 for the first and last ``half_n`` frames, else 0.

        The loop is clamped to ``max_idx`` so videos shorter than ``half_n``
        frames (or empty ones) no longer raise IndexError.
        """
        border_l = [0] * max_idx
        for i in range(min(half_n, max_idx)):
            border_l[i] = 1
            border_l[max_idx - i - 1] = 1
        return border_l

    def __getitem__(self, index):
        """Load the frame window for sample ``index``.

        Returns a dict with:
        - ``'LQs'``, ``'e_LQS'`` and ``'masks'``: tensors built with
          ``torch.stack`` over the per-frame tensors;
        - ``'folder'``: the video name;
        - ``'idx'``: the ``"i/max_idx"`` string;
        - ``'border'``: the border flag.
        """
        folder = self.data_info['folder'][index]
        idx, max_idx = map(int, self.data_info['idx'][index].split('/'))
        border = self.data_info['border'][index]

        select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
                                           padding=self.opt['padding'])
        imgs_LQ_path = [self.imgs_LQ[folder][i] for i in select_idx]
        events_LQ_path = [self.events_LQ[folder][i] for i in select_idx]
        masks_path = [self.mask[folder][i] for i in select_idx]

        if not self.nIE:
            # Without nIE the window's last event frame is dropped, giving one
            # fewer event frame than image frames.
            events_LQ_path = events_LQ_path[:-1]

        test_size = self.opt['test_size']
        img_LQ_l = list(util.read_img_seq3(imgs_LQ_path, test_size, 'img').unbind(0))
        events_LQ_l = list(util.read_img_seq3(events_LQ_path, test_size, 'event').unbind(0))
        masks_l = list(util.read_img_seq3(masks_path, test_size, 'img').unbind(0))

        if self.opt['phase'] == 'train':
            GT_size = self.opt['GT_size']
            _, H, W = img_LQ_l[0].shape  # real img size

            # One crop window shared by frames, events and masks so all three
            # stay spatially aligned.
            rnd_h = random.randint(0, max(0, H - GT_size))
            rnd_w = random.randint(0, max(0, W - GT_size))

            def crop(seq):
                return [v[:, rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size] for v in seq]

            img_LQ_l = crop(img_LQ_l)
            events_LQ_l = crop(events_LQ_l)
            masks_l = crop(masks_l)

            # augmentation - flip, rotate. All three lists are passed together,
            # presumably so they receive the same transform (see util.augment_ours1).
            img_LQ_l, events_LQ_l, masks_l = util.augment_ours1(
                [img_LQ_l, events_LQ_l, masks_l], self.opt['use_flip'], self.opt['use_rot'])

        return {
            'LQs': torch.stack(img_LQ_l),  # shape: [N, C, H, W]
            'e_LQS': torch.stack(events_LQ_l),
            'masks': torch.stack(masks_l),
            'folder': folder,
            'idx': self.data_info['idx'][index],
            'border': border
        }

    def __len__(self):
        """Return the total number of frames across all loaded videos."""
        return len(self.data_info['path_LQ'])
