import torch
import lmdb
import pickle as pkl
import os
import numpy as np
import random

""" aggregate: process all views of a video, aggregating all view1 clips together with all view3 clips.
"""

################## Dataloader
def collate_fn_override(data):
    """Merge a list of 8-field samples into one batch, dropping ``None`` samples.

    Every sample is unpacked positionally as ``(data_arr, count, labels,
    clip_length, start, video_id, labels_present_arr, aug_chunk_size)``
    (field meanings are inferred from these names only).

    Returns:
        An 8-tuple in the same field order: fields 0, 2 and 6 are combined
        with ``torch.stack``; fields 1, 3, 4 and 7 are converted with
        ``torch.tensor`` (field 7 as ``torch.int``); field 5 is returned as a
        plain tuple.
    """
    samples = [sample for sample in data if sample is not None]
    (data_arr, count, labels, clip_length, start,
     video_id, labels_present_arr, aug_chunk_size) = zip(*samples)

    stacked_data = torch.stack(data_arr)
    count_t = torch.tensor(count)
    stacked_labels = torch.stack(labels)
    clip_length_t = torch.tensor(clip_length)
    start_t = torch.tensor(start)
    stacked_present = torch.stack(labels_present_arr)
    chunk_size_t = torch.tensor(aug_chunk_size, dtype=torch.int)

    return (stacked_data, count_t, stacked_labels, clip_length_t,
            start_t, video_id, stacked_present, chunk_size_t)


class AugmentDataset_uniform_group_multiview(torch.utils.data.Dataset):  # build dataset here
    """Dataset of grouped, time-aligned multi-view clips.

    Each item is one group of ``num_in_group`` windows.  For every window the
    per-frame features of all views whose name starts with ``'C'`` and of all
    views whose name starts with ``'H'`` are read from per-view LMDB
    environments.

    Args:
        args: configuration object.  The attributes read are
            ``vid_list_path``, ``features_path``, ``gt_path``, ``split``,
            ``VIEWS``, ``train_win``, ``win_dis``, ``stride``, ``dilation``,
            ``num_in_group`` and ``reduce_data``.
        fold: ``'train'`` or ``'val'``; selects which split files are listed.
    """

    def __init__(self, args, fold):
        fold_file_name = args.vid_list_path
        self.fold = fold
        self.base_dir_name = args.features_path
        self.frames_format = "/{}/{}_{:010d}.jpg"
        self.ground_truth_files_dir = args.gt_path
        self.validation = fold == 'val'
        self.split = args.split
        self.VIEWS = args.VIEWS  # list of view names
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        self.win = args.train_win

        self.win_dis = args.win_dis
        self.stride = args.stride
        self.dilation = args.dilation
        self.num_in_group = args.num_in_group
        self.reduce_data = args.reduce_data

        # NOTE(review): pickle must only be loaded from a trusted file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.data_align = self.make_data_set(fold_file_name)
        self.data_pairs = self.make_clip_pair_group(self.data_align)

    @staticmethod
    def _key_id_of(video_id):
        """Split ``'<type>_<key_id>'`` into ``(type_action, key_id)``."""
        type_action = video_id.split('_')[0]  # e.g. 'assembly'
        key_id = video_id.split(type_action)[1][1:]
        return type_action, key_id

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every line of every listed file.

        Args:
            list_files: file names, each appended directly to ``fold_file_name``.
            fold_file_name: path prefix (string concatenation, so it must end
                with a path separator if it is a directory).
        """
        data = []
        for file in list_files:
            # The context manager closes each list file (the handle used to leak).
            with open(fold_file_name + file) as f:
                for line in f:
                    data.append(line.split('\t')[0])
        return data

    def make_clip_pair_group(self, data_align):
        """Slice every aligned video into windows and pack them into groups.

        Args:
            data_align: mapping ``video_id -> {"st_frame", "end_frame", "view", ...}``
                as produced by :meth:`make_data_set`.

        Returns:
            A list of groups; each group is a list of ``num_in_group`` dicts
            with keys ``view_c``, ``view_h``, ``st_frame``, ``end_frame``
            (inclusive) and ``video_id``.  The running group is *not* reset
            between videos, so a group can contain windows of two consecutive
            videos; a trailing incomplete group is dropped.
        """
        pairs_collect = []
        num_video = 0
        win = self.win  # length of a clip
        stride = self.stride
        num_in_group = self.num_in_group
        idx = 0
        pair_group = []

        # Loop-invariant sliding-window extractor over a [1, 1, 1, L] tensor.
        unfold = torch.nn.Unfold(kernel_size=(1, win), dilation=(1, self.dilation),
                                 padding=0, stride=(1, stride))

        for video_id in data_align:
            span_0 = data_align[video_id]["st_frame"]
            span_1 = data_align[video_id]["end_frame"]
            view_list = data_align[video_id]["view"]
            assert span_0 < span_1

            view_c = [key for key in view_list if key.startswith('C')]
            view_h = [key for key in view_list if key.startswith('H')]
            # Only videos with exactly 8 'C' views and 4 'H' views are kept;
            # checking this first avoids windowing videos that are discarded.
            if len(view_c) != 8 or len(view_h) != 4:
                print('Becuase of the lacking of view, remove the file:' + video_id)
                continue

            # Enumerate window frame ids; float32 is only a carrier for Unfold.
            frame_ids = torch.from_numpy(np.arange(span_0, span_1).astype(np.float32))
            unfold_frame = unfold(frame_ids[None, None, None, :]).squeeze(0).transpose(1, 0)  # [num_win, win]
            num_win = unfold_frame.shape[0]
            unfold_frame = unfold_frame.numpy().astype(np.int64)

            # Visit windows with a step of win_dis, once per starting offset,
            # so consecutive entries of a group are win_dis windows apart.
            for offset in range(self.win_dis):
                for win_id in range(offset, num_win, self.win_dis):
                    win_clip = unfold_frame[win_id, :]
                    win_start = win_clip[0]
                    win_end = win_clip[-1]

                    # getitem reads the contiguous range [win_start, win_end],
                    # so this only holds for dilation == 1.
                    assert win_end + 1 - win_start == win
                    assert win_end <= span_1
                    pair_group.append({'view_c': view_c, 'view_h': view_h, 'st_frame': win_start,
                                       'end_frame': win_end,
                                       'video_id': video_id})
                    idx += 1

                    if idx == num_in_group:
                        pairs_collect.append(pair_group)
                        idx = 0
                        pair_group = []

            num_video += 1
        print('================== {}:(uniform_group2) The number of aligned pairs  is {}, and are divide into {} group. ====================='.format(self.fold, len(pairs_collect)*num_in_group, len(pairs_collect)))
        print('(all are from {} videos, the win length is {}, {} clip pairs in a group.)'.format( num_video, win, num_in_group))
        print('============================== prepare training dataset DONE! ============================= ')

        return pairs_collect

    def make_data_set(self, fold_file_name):
        """Compute, for every listed video, the frame span shared by its views.

        Args:
            fold_file_name: path prefix of the split list files.

        Returns:
            ``video_id -> {"st_frame", "end_frame", "type", "view"}``.  Videos
            for which no view yields a non-empty span are skipped.

        Raises:
            SystemExit: with code 1 when the fold/split combination is unknown.
        """
        if self.fold == 'train' and self.split == 'train_val':
            files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt', 'val_coarse_assembly.txt',
                     'val_coarse_disassembly.txt']
        elif self.fold == 'train' and self.split == 'train':
            files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt']
        elif self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        else:
            # Also covers fold == 'train' with an unknown split, which used to
            # fall through and fail later with UnboundLocalError.
            print("unknown split, quit")
            raise SystemExit(1)

        data = self.read_files(files, fold_file_name)
        if self.reduce_data:
            # Keep only the first fifth of the list.
            data = data[:int(len(data) / 5)]

        video_align = {}
        for video_id in data:
            video_id = video_id.split(".txt")[0]
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            # Annotated span = [smallest start, largest end] over all segments.
            indexs = []
            with open(filename, 'r') as f:
                for line in f:
                    tmp = line.split('\t')
                    indexs.extend([int(tmp[0]), int(tmp[1])])
            span = [min(indexs), max(indexs)]

            type_action, key_id = self._key_id_of(video_id)

            end_frame = []
            all_view = []
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue

                # statistic[key_id][view] is indexed as (first, last) frame.
                assert self.statistic[key_id][view][0] <= span[0]
                span[1] = min(span[1], self.statistic[key_id][view][1])
                if span[1] <= span[0]:
                    # The view ends before any annotated action starts.
                    continue

                end_frame.append(span[1])
                all_view.append(view)

            if not all_view:
                # Previously max()/min() of an empty list raised ValueError here.
                print('No usable view, remove the file:' + video_id)
                continue

            # Keyed by video_id (not key_id): assembly and disassembly videos
            # share the same key_id.
            video_align[video_id] = {
                "st_frame": span[0],
                "end_frame": min(end_frame),
                "type": type_action,
                "view": all_view,
            }
        print("Number of videos logged in {} fold is {}".format(self.fold, len(video_align)))
        return video_align

    def _read_view_clip(self, key_id, view, st_frame, end_frame):
        """Read frames ``st_frame..end_frame`` (inclusive) of one view as a 2-D float32 array."""
        frames = []
        with self.env[view].begin() as txn:
            for i in range(st_frame, end_frame + 1):
                key = key_id + self.frames_format.format(view, view, i)
                raw = txn.get(key.strip().encode('utf-8'))
                if raw is None:
                    print('no available data.')
                    raise SystemExit(2)
                frames.append(np.frombuffer(raw, 'float32'))
        return np.array(frames)  # [win, feature_dim]

    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        """Load the features of one group.

        Returns:
            ``(elements_c, elements_h)``: two arrays laid out as
            ``[group, num_view, win, feature_dim]`` for the 'C' and the 'H'
            views respectively.

        Raises:
            SystemExit: with code 2 when a frame key is missing from LMDB.
        """
        elements_c_list = []
        elements_h_list = []
        for single_pair in self.data_pairs[index]:
            st_frame = single_pair['st_frame']
            end_frame = single_pair['end_frame']
            _, key_id = self._key_id_of(single_pair['video_id'])

            elements_c_list.append(np.stack(
                [self._read_view_clip(key_id, view_c, st_frame, end_frame)
                 for view_c in single_pair['view_c']], axis=0))
            elements_h_list.append(np.stack(
                [self._read_view_clip(key_id, view_h, st_frame, end_frame)
                 for view_h in single_pair['view_h']], axis=0))

        elements_c_list = np.stack(elements_c_list, axis=0)
        elements_h_list = np.stack(elements_h_list, axis=0)
        assert elements_c_list.shape[-2] == self.win and elements_h_list.shape[-2] == self.win

        return elements_c_list, elements_h_list

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data_pairs)


class AugmentDataset_val_uniform_group(torch.utils.data.Dataset):  # build dataset here
    """Validation/test dataset of paired single-view clips.

    For every video, ``val_clip_pervid`` evenly spaced windows of ``val_win``
    frames are cut, and each window is paired twice: the first 'C' view with
    the first 'H' view, and the second 'C' view with the second 'H' view.

    Args:
        args: configuration object.  The attributes read are
            ``vid_list_path``, ``features_path``, ``gt_path``, ``VIEWS``,
            ``val_win`` and ``val_clip_pervid``.
        fold: ``'val'`` or ``'test'``.
    """

    def __init__(self, args, fold ):
        fold_file_name = args.vid_list_path
        self.fold = fold
        self.base_dir_name = args.features_path
        self.frames_format = "/{}/{}_{:010d}.jpg"
        self.ground_truth_files_dir = args.gt_path
        self.VIEWS = args.VIEWS  # list of view names
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        self.win = args.val_win
        self.clip_pervid = args.val_clip_pervid

        # NOTE(review): pickle must only be loaded from a trusted file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.data_align = self.make_data_set(fold_file_name)
        self.data_pairs = self.make_clip_pair(self.data_align)

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every line of every listed file.

        ``fold_file_name`` is a path prefix that each file name is appended to.
        """
        data = []
        for file in list_files:
            # The context manager closes each list file (the handle used to leak).
            with open(fold_file_name + file) as f:
                for line in f:
                    data.append(line.split('\t')[0])  # txt file name
        return data

    def make_clip_pair(self, data_align):
        """Cut evenly spaced windows from every video and pair C/H views.

        Args:
            data_align: mapping produced by :meth:`make_data_set`; its per-video
                dict holds one entry per usable view plus ``"st_frame"`` and
                ``"end_frame"``.

        Returns:
            A list of dicts with keys ``view_c``, ``view_h``, ``st_frame``,
            ``end_frame`` (exclusive) and ``video_id``.
        """
        pairs_coll = []
        num_video = 0
        # Bound before the loop so the summary print works for an empty input.
        win = self.win  # the length of a clip
        num = self.clip_pervid  # the number of windows per video
        for key_id, view in data_align.items():
            # View names are told apart from the "st_frame"/"end_frame" keys
            # by their leading capital letter.
            view_c = [key for key in view if key.startswith('C')]
            view_h = [key for key in view if key.startswith('H')]
            # Views 0 and 1 of each kind are indexed below, so two are needed
            # (the previous "== 0" guard let a single view raise IndexError).
            if len(view_c) < 2 or len(view_h) < 2:
                print('Becuase of the lacking of view, remove the file:' + key_id)
                continue

            span_0 = view["st_frame"]
            span_1 = view["end_frame"]
            assert span_0 < span_1

            # Spread the window starts evenly over the usable span.
            vid_len = span_1 - span_0
            clip_slide = int((vid_len - win)/num)
            for i in range(num):
                win_start = i * clip_slide + span_0
                win_end = win_start + win
                assert win_end <= span_1, "Out of video range!"
                for j in range(2):  # fixed pairing: (C0, H0) and (C1, H1)
                    pairs_coll.append({'view_c': view_c[j], 'view_h': view_h[j], 'st_frame': win_start,
                                       'end_frame': win_end, 'video_id': key_id})

            num_video += 1
        print('====================  VAL: The number of aligned pairs  for val is {}. ==================== '.format(len(pairs_coll)))
        print('(All are from {} videos, each video generate {} pairs, the clip length is {}.)'.format( num_video, num, win))
        print('============================== prepare validation dataset DONE! ============================= ')

        return pairs_coll

    def make_data_set(self, fold_file_name):
        """Compute, for every listed video, the frame span shared by its views.

        Returns:
            ``key_id -> {view: [view_start, view_end, span_0, span_1], ...,
            "st_frame": ..., "end_frame": ...}``.  Videos for which no view
            yields a non-empty span are skipped.

        Raises:
            SystemExit: with code 1 when ``self.fold`` is neither 'val' nor 'test'.
        """
        if self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        elif self.fold == 'test':
            files = ['test_coarse_assembly.txt', 'test_coarse_disassembly.txt']
        else:
            print("unknown split, quit")
            raise SystemExit(1)

        video_align = {}
        data = self.read_files(files, fold_file_name)
        for video_id in data:
            video_id = video_id.split(".txt")[0]
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            # Annotated span = [smallest start, largest end] over all segments.
            indexs = []
            with open(filename, 'r') as f:
                for line in f:
                    tmp = line.split('\t')
                    indexs.extend([int(tmp[0]), int(tmp[1])])
            span = [min(indexs), max(indexs)]

            type_action = video_id.split('_')[0]  # e.g. 'assembly'
            key_id = video_id.split(type_action)[1][1:]

            entry = {}
            end_frame = []
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue

                # statistic[key_id][view] is indexed as (first, last) frame.
                view_start = self.statistic[key_id][view][0]
                view_end = self.statistic[key_id][view][1]
                assert view_start <= span[0]
                span[1] = min(span[1], view_end)
                if span[1] <= span[0]:
                    # The view ends before any annotated action starts.
                    continue

                entry[view] = [view_start, view_end, span[0], span[1]]
                end_frame.append(span[1])

            if not end_frame:
                # Previously max()/min() of an empty list raised ValueError here.
                print('No usable view, remove the file:' + video_id)
                continue

            entry["st_frame"] = span[0]
            entry["end_frame"] = min(end_frame)
            # NOTE(review): entries are keyed by key_id, so two listed videos
            # that share a key_id overwrite each other — confirm this is intended.
            video_align[key_id] = entry
        print("Number of videos logged in {} fold is {}".format(self.fold, len(video_align)))
        return video_align


    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        """Load one C/H clip pair.

        Returns:
            ``(elements_c, elements_h)``, each of shape ``[2048, win]``.

        Raises:
            SystemExit: with code 2 when a frame key is missing from LMDB.
        """
        ele_dict = self.data_pairs[index]
        st_frame = ele_dict['st_frame']
        end_frame = ele_dict['end_frame']
        view_c = ele_dict['view_c']
        view_h = ele_dict['view_h']
        video_id = ele_dict['video_id']

        elements_c = []
        elements_h = []
        with self.env[view_c].begin() as e_c, self.env[view_h].begin() as e_h:
            for i in range(st_frame, end_frame):
                # LMDB key: '<video_id>/<view>/<view>_<10-digit frame>.jpg'
                key_c = video_id + self.frames_format.format(view_c, view_c, i)
                key_h = video_id + self.frames_format.format(view_h, view_h, i)
                data_c = e_c.get(key_c.strip().encode('utf-8'))
                data_h = e_h.get(key_h.strip().encode('utf-8'))
                if (data_h is None) or (data_c is None):
                    print('no available data.')
                    raise SystemExit(2)
                data_c = np.frombuffer(data_c, 'float32')
                # Decode the H buffer itself (it used to be decoded from data_c,
                # which made the H features a copy of the C features).
                data_h = np.frombuffer(data_h, 'float32')
                assert data_c.shape[0] == 2048 and data_h.shape[0] == 2048
                elements_c.append(data_c)
                elements_h.append(data_h)

        elements_c = np.array(elements_c).T  # one clip of one view: [2048, win]
        elements_h = np.array(elements_h).T

        # `self.win`, not the bare name `win`, which was undefined here.
        assert elements_c.shape[-1] == self.win and elements_h.shape[-1] == self.win

        return elements_c, elements_h


    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data_pairs)



def collate_fn_override_test(data):
    """Merge a list of 9-field test samples into one batch, dropping ``None`` samples.

    Every sample is unpacked positionally as ``(data_arr, count, labels,
    video_len, start, video_id, labels_present_arr, chunk_size, chunk_id)``
    (field meanings are inferred from these names only).

    Returns:
        A 9-tuple in the same field order: fields 0, 2 and 6 are combined
        with ``torch.stack``; fields 1, 3, 4, 7 and 8 are converted with
        ``torch.tensor``; field 5 is returned as a plain tuple.
    """
    samples = [sample for sample in data if sample is not None]
    (data_arr, count, labels, video_len, start,
     video_id, labels_present_arr, chunk_size, chunk_id) = zip(*samples)

    stacked_data = torch.stack(data_arr)
    count_t = torch.tensor(count)
    stacked_labels = torch.stack(labels)
    video_len_t = torch.tensor(video_len)
    start_t = torch.tensor(start)
    stacked_present = torch.stack(labels_present_arr)
    chunk_size_t = torch.tensor(chunk_size)
    chunk_id_t = torch.tensor(chunk_id)

    return (stacked_data, count_t, stacked_labels, video_len_t, start_t,
            video_id, stacked_present, chunk_size_t, chunk_id_t)


class AugmentDataset_test(torch.utils.data.Dataset):
    """Evaluation dataset yielding chunk-pooled per-view features with labels.

    Every (video, view, chunk size) combination is cut into pieces of at most
    ``max_frames_per_video * chunk_size`` frames; ``getitem`` max-pools each
    run of ``chunk_size`` frames into one feature vector and assigns it the
    majority label of that run.

    Args:
        args: configuration object.  The attributes read are
            ``max_frames_per_video``, ``features_path``, ``gt_path``,
            ``num_class`` and ``VIEWS``.
        fold: ``'val'`` or ``'test'``.
        fold_file_name: path prefix of the split list files.
        actions_dict: mapping from label name to integer label id.
        chunk_size: iterable of chunk sizes to generate samples for.
    """

    # Per-frame feature length; `getitem` asserts every stored frame has it.
    # Defined on the class because `getitem` read `self.feature_size`, which
    # was never assigned anywhere (AttributeError on every call).
    feature_size = 2048

    def __init__(self, args, fold, fold_file_name, actions_dict, chunk_size):
        self.fold = fold
        self.max_frames_per_video = args.max_frames_per_video
        self.base_dir_name = args.features_path
        self.frames_format = "/{}/{}_{:010d}.jpg"  # one key per frame
        self.ground_truth_files_dir = args.gt_path
        self.num_class = args.num_class
        self.VIEWS = args.VIEWS
        self.actions_dict = actions_dict
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        # NOTE(review): pickle must only be loaded from a trusted file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.chunk_size_arr = chunk_size
        self.data = self.make_data_set(fold_file_name)

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every line of every listed file.

        ``fold_file_name`` is a path prefix that each file name is appended to.
        """
        data = []
        for file in list_files:
            # The context manager closes each list file (the handle used to leak).
            with open(fold_file_name + file) as f:
                for line in f:
                    data.append(line.split('\t')[0])
        return data

    def make_data_set(self, fold_file_name):
        """Build the list of sample descriptors.

        Returns:
            A list of dicts with keys ``type``, ``view``, ``st_frame``,
            ``end_frame`` (exclusive), ``chunk_id``, ``chunk_size``,
            ``video_id``, ``tot_frames`` and ``labels`` (per-frame label ids).

        Raises:
            SystemExit: with code 3 when ``self.fold`` is neither 'val' nor 'test'.
        """
        label_name_to_label_id_dict = self.actions_dict
        if self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        elif self.fold == 'test':
            files = ['test_coarse_assembly.txt', 'test_coarse_disassembly.txt']
        else:
            print("Unknown data folder")
            raise SystemExit(3)
        data = self.read_files(files, fold_file_name)

        data_arr = []
        for video_id in data:
            video_id = video_id.split(".txt")[0]
            if 'disassembly' in video_id:
                # NOTE(review): presumably the ground-truth files on disk use
                # this misspelling — confirm before "fixing" it.
                video_id = video_id.replace('disassembly', 'disassebly')
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            # Expand the segment annotations into one label per frame.
            recog_content, indexs = [], []
            with open(filename, 'r') as f:
                for line in f:
                    tmp = line.split('\t')
                    start_l, end_l, label_l = int(tmp[0]), int(tmp[1]), tmp[2]
                    indexs.extend([start_l, end_l])
                    recog_content.extend([label_l] * (end_l - start_l))

            recog_content = [label_name_to_label_id_dict[e] for e in recog_content]
            span = [min(indexs), max(indexs)]  # [start end)

            # Segments must tile the span without gaps or overlaps.
            assert len(recog_content) == (span[1] - span[0])

            # Both values depend only on the video, not on the view.
            type_action = video_id.split('_')[0]
            key_id = video_id.split(type_action)[1][1:]

            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue
                # statistic[key_id][view] is indexed as (first, last) frame.
                assert self.statistic[key_id][view][0] <= span[0]
                # The end is clipped cumulatively: later views also inherit
                # the clipping caused by earlier, shorter views.
                span[1] = min(span[1], self.statistic[key_id][view][1])
                if span[1] <= span[0]:
                    continue
                for j, chunk_size in enumerate(self.chunk_size_arr):
                    piece_len = self.max_frames_per_video * chunk_size
                    for st_frame in range(span[0], span[1], piece_len):
                        end_frame = min(st_frame + piece_len, span[1])
                        ele_dict = {'type': type_action, 'view': view, 'st_frame': st_frame, 'end_frame': end_frame,
                                    'chunk_id': j, 'chunk_size': chunk_size, 'video_id': key_id,
                                    'tot_frames': (end_frame - st_frame) // chunk_size}

                        ele_dict["labels"] = np.array(recog_content[st_frame - span[0]:end_frame - span[0]], dtype=int)
                        data_arr.append(ele_dict)

        print("Number of datapoints logged in {} fold is {}".format(self.fold, len(data_arr)))
        return data_arr

    def getitem(self, index):  # Try to use this for debugging purpose
        """Load and chunk-pool one sample.

        Returns:
            A 9-tuple ``(data_arr, count, label_arr, num_frames, st_frame,
            name, labels_present_arr, chunk_size, chunk_id)`` where
            ``data_arr`` is ``[max_frames_per_video, feature_size]`` with the
            first ``count`` rows filled, ``label_arr`` holds the majority label
            per chunk (``-100`` in unused slots), ``labels_present_arr`` is a
            multi-hot vector over ``num_class`` and ``name`` is
            ``'<type>_<video_id>%<view>'``.

        Raises:
            SystemExit: with code 2 when a frame key is missing from LMDB.
        """
        ele_dict = self.data[index]
        st_frame = ele_dict['st_frame']
        end_frame = ele_dict['end_frame']
        aug_chunk_size = ele_dict['chunk_size']
        view = ele_dict['view']
        vid_type = ele_dict['type']

        elements = []
        with self.env[view].begin() as e:
            for i in range(st_frame, end_frame):
                key = ele_dict['video_id'] + self.frames_format.format(view, view, i)
                data = e.get(key.strip().encode('utf-8'))
                if data is None:
                    print('no available data.')
                    raise SystemExit(2)
                data = np.frombuffer(data, 'float32')
                assert data.shape[0] == self.feature_size
                elements.append(data)

        elements = np.array(elements).T  # [feature_size, num_frames]

        count = 0
        labels_present_arr = torch.zeros(self.num_class, dtype=torch.float32)
        data_arr = torch.zeros((self.max_frames_per_video, self.feature_size))
        label_arr = torch.ones(self.max_frames_per_video, dtype=torch.long) * -100
        for i in range(st_frame, end_frame, aug_chunk_size):
            end = min(end_frame, i + aug_chunk_size)
            chunk = elements[:, i - st_frame: end - st_frame]
            # Majority label of the chunk.
            values, counts = np.unique(ele_dict["labels"][i - st_frame: end - st_frame], return_counts=True)
            label_arr[count] = values[np.argmax(counts)]
            labels_present_arr[label_arr[count]] = 1
            # Max-pool the chunk over time.
            data_arr[count, :] = torch.tensor(np.max(chunk, axis=-1), dtype=torch.float32)
            count += 1

        return data_arr, count, label_arr, elements.shape[1], st_frame, vid_type + '_' + ele_dict['video_id'] \
               + '%{}'.format(view), labels_present_arr, aug_chunk_size, ele_dict['chunk_id']

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data)