import torch
import lmdb
import pickle as pkl
import os
import numpy as np
import random

""" group2: change number_per_vid to stride to get windows in each video.
"""

################## Dataloader
def collate_fn_override(data):
    """Merge a list of 8-field samples into one batch, ignoring ``None`` entries.

    Args:
        data: Iterable of samples. Each non-``None`` sample must unpack into
            eight fields, in this order: data array, count, labels, clip
            length, start, video id, labels-present array, augmented chunk
            size.

    Returns:
        An 8-tuple in the same field order. Fields 1, 3 and 7 (tensors) are
        combined with ``torch.stack``; fields 2, 4, 5 and 8 are wrapped with
        ``torch.tensor`` (the last one with ``dtype=torch.int``); the video
        ids are returned as a plain tuple.

    NOTE(review): ``Assembly101.getitem`` in this file returns 2 or 3 arrays,
    not 8-field samples, so this collate function presumably serves another
    dataset -- confirm before pairing the two.
    """
    samples = [sample for sample in data if sample is not None]

    # Transpose "list of samples" into "one tuple per field".
    (features, counts, targets, lengths,
     starts, video_id, present, chunk_sizes) = zip(*samples)

    batch = (
        torch.stack(features),
        torch.tensor(counts),
        torch.stack(targets),
        torch.tensor(lengths),
        torch.tensor(starts),
        video_id,
        torch.stack(present),
        torch.tensor(chunk_sizes, dtype=torch.int),
    )
    return batch

class Assembly101(torch.utils.data.Dataset):  # build dataset here
    def __init__(self, args, fold):
        """Configure the dataset for one fold and pre-compute its window groups.

        Args:
            args: Indexable pair. ``args[0]`` supplies the windowing options
                (``train_win``, ``win_dis``, ``stride``, ``dilation``,
                ``num_in_group``, ``reduce_data``); ``args[1]`` supplies the
                dataset options (``vid_list_path``, ``features_path``,
                ``gt_path``, ``split``, ``VIEWS``).
            fold: Fold name; ``'val'`` turns ``self.validation`` on.

        Side effects: opens one read-only LMDB environment per view, loads
        ``data/statistic_input.pkl`` (relative to the working directory) and
        builds ``self.data_align`` / ``self.data_pairs``.
        """
        args_common, args_diff = args[0], args[1]

        split_list_dir = args_diff.vid_list_path
        self.fold = fold
        # e.g. '/data/peiyao/all_data/Assembly101/TSM_features'
        self.base_dir_name = args_diff.features_path
        self.frames_format = "/{}/{}_{:010d}.jpg"
        # e.g. '.../Assembly101/annotations/coarse-annotations/coarse_labels/'
        self.ground_truth_files_dir = args_diff.gt_path
        self.validation = (fold == 'val')
        self.split = args_diff.split
        self.VIEWS = args_diff.VIEWS  # list of view names

        # One LMDB environment per view, e.g. '<features_path>/C10095_rgb'.
        lmdb_envs = {}
        for view in self.VIEWS:
            lmdb_envs[view] = lmdb.open(f'{args_diff.features_path}/{view}', readonly=True, lock=False)
        self.env = lmdb_envs

        # Windowing options shared across datasets.
        self.win = args_common.train_win
        self.win_dis = args_common.win_dis
        self.stride = args_common.stride
        self.dilation = args_common.dilation
        self.num_in_group = args_common.num_in_group
        self.reduce_data = args_common.reduce_data

        # The statistic table must be loaded before make_data_set, which reads it.
        with open('data/statistic_input.pkl', 'rb') as stat_file:
            self.statistic = pkl.load(stat_file)

        self.data_align = self.make_data_set(split_list_dir)
        self.data_pairs = self.make_clip_pair_group(self.data_align)

    def read_files(self, list_files, fold_file_name):
        """Collect the first tab-separated field of every line of the split files.

        Args:
            list_files: Iterable of file names, e.g.
                ``['train_coarse_assembly.txt']``.
            fold_file_name: String prefix prepended to each file name by plain
                concatenation, so a directory must end with a path separator.

        Returns:
            A list with the first column of every non-blank line, in file
            order. The field is returned as-is (not stripped).
        """
        data = []
        for file_name in list_files:
            # Use a context manager so the handle is closed deterministically.
            with open(fold_file_name + file_name) as split_file:
                for line in split_file:
                    # A blank line would otherwise yield a bogus "\n" entry.
                    if not line.strip():
                        continue
                    data.append(line.split('\t')[0])
        return data

    def make_clip_pair_group(self, data_align):
        """Cut every video into sliding windows and pack C/H view pairs into groups.

        For each video, every view whose name starts with ``'C'`` is paired
        with one randomly drawn view whose name starts with ``'H'``. The
        video's frame span ``[st_frame, end_frame)`` is cut into windows of
        ``aug_win`` frames taken every ``self.stride`` frames. Windows are
        emitted in ``self.win_dis`` interleaved passes (pass ``i`` visits
        windows ``i, i + win_dis, i + 2 * win_dis, ...``) and every
        ``self.num_in_group`` consecutive pairs form one group.

        The partially filled group is *not* reset between views or videos, so
        a group can contain pairs from different views/videos; a trailing
        incomplete group is dropped.

        Outside validation the window length is, with probability 1/2,
        replaced by a log-uniform draw between ``win / 2`` and ``2 * win``.

        Args:
            data_align: Mapping ``video_id -> {"view_list", "st_frame",
                "end_frame"}`` as produced by ``make_data_set``.

        Returns:
            A list of groups; each group is a list of ``num_in_group`` dicts
            with keys ``view_c``, ``view_h``, ``st_frame``, ``end_frame``
            (inclusive), ``video_id`` and ``aug_win``.

        Raises:
            AssertionError: if a video's span is empty or reversed, or if the
                dilation makes a window non-contiguous.
        """
        win = self.win
        stride = self.stride
        num_in_group = self.num_in_group
        dilation_rate = self.dilation

        pairs_collect = []
        pair_group = []  # carried across views and videos, see docstring
        num_video = 0

        for video_id, video_info in data_align.items():
            view_list = video_info["view_list"]
            view_c = [key for key in view_list if key.startswith('C')]
            view_h = [key for key in view_list if key.startswith('H')]
            num_c = len(view_c)
            num_h = len(view_h)

            if num_c == 0 or num_h == 0:
                print('Becuase of the lacking of view, remove the file:' + video_id)
                continue

            span_0 = video_info["st_frame"]
            span_1 = video_info["end_frame"]
            assert span_0 < span_1

            # Window-length augmentation. The random draw happens first (also
            # in validation) so the RNG stream matches the fold-independent
            # order of the original implementation.
            if np.random.randint(low=0, high=2) == 0 and (not self.validation):
                max_win = int(1.0 * win / 0.5)
                min_win = int(1.0 * win / 2)
                aug_win = int(np.exp(np.random.uniform(low=np.log(min_win), high=np.log(max_win))))
            else:
                aug_win = win

            # Number of frames between the first and last frame of a window.
            window_span = dilation_rate * (aug_win - 1) + 1
            # getitem reads the contiguous range [st_frame, end_frame] and
            # expects exactly aug_win frames, so a window must be contiguous.
            assert window_span == aug_win, \
                'windows must span exactly aug_win frames (dilation must be 1)'

            if span_1 - span_0 < window_span:
                # No complete window fits; skip instead of failing mid-build.
                print('The video is shorter than one window, remove the file:' + video_id)
                continue

            # Integer sliding windows: window k covers
            # [span_0 + k * stride, span_0 + k * stride + window_span - 1].
            win_starts = np.arange(span_0, span_1 - window_span + 1, stride, dtype=np.int64)
            win_ends = win_starts + (window_span - 1)  # always <= span_1 - 1
            num_win = len(win_starts)

            # One random H view per C view.
            h_view_id = np.random.randint(0, num_h, size=num_c)
            for c_name, h_id in zip(view_c, h_view_id):
                for offset in range(self.win_dis):
                    for win_id in range(offset, num_win, self.win_dis):
                        pair_group.append({
                            'view_c': c_name,
                            'view_h': view_h[h_id],
                            'st_frame': win_starts[win_id],
                            'end_frame': win_ends[win_id],
                            'video_id': video_id,
                            'aug_win': aug_win,
                        })
                        if len(pair_group) == num_in_group:
                            pairs_collect.append(pair_group)
                            pair_group = []

            num_video += 1

        print('============================== Prepare Assembly101 training dataset START! ============================= ')
        print('(Assembly101 are from {} videos, the win length is {}, {} clip pairs in a group.)'.format( num_video, win, num_in_group))
        print('================== DONE for Assembly101 ! {}:The number of aligned pairs  is {}. ====================='.format(self.fold, len(pairs_collect)*num_in_group))

        return pairs_collect

    def make_data_set(self, fold_file_name):
        """Build the per-video frame span and usable view list for this fold.

        The split files selected by ``self.fold`` / ``self.split`` list the
        videos. For each video the ground-truth file
        ``<self.ground_truth_files_dir>/<video_id>.txt`` is parsed (columns 0
        and 1 of every tab-separated line are integer frame indices) and the
        annotated span is ``[min index, max index]``. The span end is then
        clipped by ``self.statistic[key_id][view][1]`` for every usable view.

        Args:
            fold_file_name: String prefix of the split files, e.g.
                '.../Assembly101/annotations/coarse-annotations/coarse_splits/'.

        Returns:
            Dict ``video_id -> {"st_frame", "end_frame", "view_list"}``.
            ``view_list`` may be empty when no view covers the annotated
            span; in that case the unclipped annotated span is stored.

        Raises:
            SystemExit: for an unknown fold/split combination.
            AssertionError: if a view's first entry in ``self.statistic`` is
                greater than the annotated span start.

        NOTE(review): ``self.statistic[key_id][view]`` is presumably
        ``(first_frame, last_frame)`` of that view's recording -- confirm
        against the code that writes ``statistic_input.pkl``.
        """
        if self.fold == 'train' and self.split == 'train_val':
            files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt', 'val_coarse_assembly.txt',
                     'val_coarse_disassembly.txt']
        elif self.fold == 'train' and self.split == 'train':
            files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt']
        elif self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        else:
            # Also reached for fold == 'train' with an unsupported split,
            # which previously left `files` unbound.
            print("unknown split, quit")
            exit(1)

        data = self.read_files(files, fold_file_name)
        if self.reduce_data:
            # Keep only the first fifth of the videos.
            data = data[:int(len(data) / 5)]

        video_align = {}
        for video_id in data:
            video_id = video_id.split(".txt")[0]
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            indexs = []
            with open(filename, 'r') as f:  # ground-truth segments
                for line in f:
                    if not line.strip():
                        continue
                    tmp = line.split('\t')
                    indexs.extend([int(tmp[0]), int(tmp[1])])
            span_start, span_end = min(indexs), max(indexs)

            # 'assembly_nusar-2021_...' -> 'nusar-2021_...': everything after
            # the first underscore (the action-type prefix is dropped).
            key_id = video_id.partition('_')[2]

            end_frame = []
            view_list = []
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue
                view_stat = self.statistic[key_id][view]

                assert view_stat[0] <= span_start
                # Clip per view without touching the shared span, so that one
                # too-short view cannot cause later views to be rejected.
                view_end = min(span_end, view_stat[1])
                if view_end <= span_start:
                    # This view ends before any annotated action starts.
                    continue

                end_frame.append(view_end)
                view_list.append(view)

            video_align[video_id] = {
                # Every accepted view shares the annotated start frame.
                "st_frame": span_start,
                # Fall back to the annotated end when no view is usable; the
                # empty view_list lets downstream code drop the video.
                "end_frame": min(end_frame, default=span_end),
                "view_list": view_list,
            }

        print("Number of videos logged in {} fold is {}".format(self.fold, len(video_align)))
        return video_align


    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        ele_dict = self.data_pairs[index] # a group data include
        num_group = len(ele_dict)
        elements_c_group = np.zeros((num_group, 2048, 40), dtype=np.float32)
        elements_h_group = np.zeros((num_group, 2048, 40), dtype=np.float32)
        elements_mask = np.zeros((num_group, 40))
        if isinstance(ele_dict, list):
            elements_c_list = []
            elements_h_list = []
            for idx, single_pair in enumerate(ele_dict):
                st_frame = single_pair['st_frame']
                end_frame = single_pair['end_frame']
                view_c = single_pair['view_c']
                view_h = single_pair['view_h']
                video_id = single_pair['video_id']
                win = single_pair['aug_win']

                type_action = video_id.split('_')[0]  # assembly
                key_id = video_id.split(type_action)[1][1:]

                elements_c = []
                elements_h = []
                with self.env[view_c].begin() as e_c:
                    with self.env[view_h].begin() as e_h:
                        for i in range(st_frame, end_frame+1):
                            key_c = key_id + self.frames_format.format(view_c, view_c,
                                                                         i)  ## each vid, each view, final name!!! the name of specific frame, frame by frame extracting: key: 'nusar-2021_action_both_9021-a29_9021_user_id_2021-02-23_094113/HMC_84355350_mono10bit/HMC_84355350_mono10bit_0000000192.jpg'
                            key_h = key_id + self.frames_format.format(view_h, view_h,
                                                                         i)  ## each vid, each view, final name!!! the name of specific frame, frame by frame extracting: key: 'nusar-2021_action_both_9021-a29_9021_user_id_2021-02-23_094113/HMC_84355350_mono10bit/HMC_84355350_mono10bit_0000000192.jpg'
                            data_c = e_c.get(key_c.strip().encode('utf-8'))  # get the specific fram here
                            data_h = e_h.get(key_h.strip().encode('utf-8'))  # get the specific fram here
                            if (data_h is None) or (data_c is None):
                                print('no available data.')
                                exit(2)
                            data_c = np.frombuffer(data_c, 'float32')  # chang to 'float32' as each frame
                            data_h = np.frombuffer(data_c, 'float32')
                            assert data_c.shape[0] == 2048 and data_h.shape[0] == 2048
                            elements_c.append(data_c) # add one frame, elements_c as all window frame
                            elements_h.append(data_h)
                    elements_c = np.array(elements_c).T  # a win in a view [2048, 19]
                    elements_h = np.array(elements_h).T
                    # if elements_c.shape[-1] != win or elements_h.shape[-1] != win:
                    #     a =1
                    assert elements_c.shape[-1] == win and elements_h.shape[-1] == win

                elements_c_group[idx, :, :win] = elements_c
                elements_h_group[idx, :, :win] = elements_h
                elements_mask[idx, :win] = 1
            assert idx== num_group-1

            #         elements_c_list.append(elements_c)
            #         elements_h_list.append(elements_h)
            # elements_c_list = np.stack(elements_c_list, axis=0) # [group, c, win]
            # elements_h_list = np.stack(elements_h_list, axis=0)
            # ##======process win changed among group, to pading the group to adopt to other group
            # n, c, _ = elements_c_list.shape
            # elements_c_group = np.zeros((n, c, 40), dtype=np.float32)
            # elements_c_group[:, :, :win] = elements_c_list
            # elements_h_group = np.zeros((n, c, 40),dtype=np.float32)
            # elements_h_group[:, :, :win] = elements_h_list
            # elements_mask = np.zeros((n, 40))
            # elements_mask[:, :win] = 1

            return elements_c_group, elements_h_group, elements_mask

        else:
            st_frame = ele_dict['st_frame']
            end_frame = ele_dict['end_frame']
            view_c = ele_dict['view_c']
            view_h = ele_dict['view_h']
            video_id = ele_dict['video_id']

            elements_c = []
            elements_h = []
            with self.env[view_c].begin() as e_c:
                with self.env[view_h].begin() as e_h:
                    for i in range(st_frame, end_frame):
                        key_c = video_id + self.frames_format.format(view_c, view_c, i) ## each vid, each view, final name!!! the name of specific frame, frame by frame extracting: key: 'nusar-2021_action_both_9021-a29_9021_user_id_2021-02-23_094113/HMC_84355350_mono10bit/HMC_84355350_mono10bit_0000000192.jpg'
                        key_h = video_id + self.frames_format.format(view_h, view_h, i) ## each vid, each view, final name!!! the name of specific frame, frame by frame extracting: key: 'nusar-2021_action_both_9021-a29_9021_user_id_2021-02-23_094113/HMC_84355350_mono10bit/HMC_84355350_mono10bit_0000000192.jpg'
                        data_c = e_c.get(key_c.strip().encode('utf-8')) #get the specific fram here
                        data_h = e_h.get(key_h.strip().encode('utf-8')) #get the specific fram here
                        if (data_h is None) or (data_c is None):
                            print('no available data.')
                            exit(2)
                        data_c = np.frombuffer(data_c, 'float32') # chang to 'float32' as each frame
                        data_h = np.frombuffer(data_c, 'float32')
                        assert data_c.shape[0] == 2048 and data_h.shape[0] == 2048
                        elements_c.append(data_c)
                        elements_h.append(data_h)
                elements_c = np.array(elements_c).T  #a clip in a view [2048, 100]
                elements_h = np.array(elements_h).T
            return elements_c, elements_h


    def __getitem__(self, index):
        """Return the sample at ``index`` by delegating to ``getitem``."""
        return self.getitem(index)

    def __len__(self):
        """Return the number of entries in ``self.data_pairs``."""
        return len(self.data_pairs)
