import torch
import lmdb
import pickle as pkl
import os
import numpy as np
import random

################## Dataloader
def collate_fn_override(data):
    """Collate a list of 8-field samples into batched tensors.

    Entries that are ``None`` are dropped before batching.  Each remaining
    sample is unpacked positionally as ``(data_arr, count, labels,
    clip_length, start, video_id, labels_present_arr, aug_chunk_size)``.

    Returns:
        An 8-tuple in the same field order.  ``data_arr``, ``labels`` and
        ``labels_present_arr`` are combined with ``torch.stack``; ``count``,
        ``clip_length``, ``start`` and ``aug_chunk_size`` become 1-D tensors
        (the last one with dtype ``torch.int``); ``video_id`` is returned as
        a plain tuple.
    """
    samples = [sample for sample in data if sample is not None]
    (features, counts, label_seqs, clip_lengths,
     starts, video_id, present, chunk_sizes) = zip(*samples)

    batched_features = torch.stack(features)
    batched_counts = torch.tensor(counts)
    batched_labels = torch.stack(label_seqs)
    batched_lengths = torch.tensor(clip_lengths)
    batched_starts = torch.tensor(starts)
    batched_present = torch.stack(present)
    batched_chunk_sizes = torch.tensor(chunk_sizes, dtype=torch.int)

    return (batched_features, batched_counts, batched_labels, batched_lengths,
            batched_starts, video_id, batched_present, batched_chunk_sizes)

class AugmentDataset_uniform(torch.utils.data.Dataset):  # build dataset here
    """Training dataset of time-aligned clip pairs taken from two views.

    For every video listed in the split files, the annotated frame span is
    intersected with the per-view frame ranges stored in
    ``data/statistic_input.pkl``.  ``clip_pervid`` windows of ``win`` frames
    are then placed at a uniform stride inside that span, and each window is
    paired across one view whose name starts with ``'C'`` and one whose name
    starts with ``'H'`` (presumably fixed cameras vs. head-mounted cameras --
    confirm against the dataset documentation).

    ``__getitem__`` returns the per-frame features of both views for one pair.
    """

    def __init__(self, args, fold):
        """Open the per-view LMDB stores and build the list of clip pairs.

        Args:
            args: namespace providing ``vid_list_path``, ``features_path``,
                ``gt_path``, ``split``, ``VIEWS``, ``train_win`` and
                ``train_clip_pervid``.
            fold: ``'train'`` or ``'val'``; selects which split files are read.
        """
        fold_file_name = args.vid_list_path
        self.fold = fold
        self.base_dir_name = args.features_path  # e.g. '.../Assembly101/TSM_features'
        self.frames_format = "/{}/{}_{:010d}.jpg"
        self.ground_truth_files_dir = args.gt_path  # e.g. '.../coarse-annotations/coarse_labels/'
        self.validation = fold == 'val'
        self.split = args.split
        self.VIEWS = args.VIEWS  # list of view names
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        self.win = args.train_win
        self.clip_pervid = args.train_clip_pervid

        # NOTE: pickle must only be used on trusted files; this one is a
        # project-local statistics file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.data_align = self.make_data_set(fold_file_name)
        self.data_pairs = self.make_clip_pair(self.data_align)

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every non-blank line.

        Args:
            list_files: file names (e.g. ``'train_coarse_assembly.txt'``).
            fold_file_name: prefix that is concatenated in front of each name.

        Returns:
            List of strings, in file order.
        """
        data = []
        for file in list_files:
            # Context manager so the handle is closed deterministically.
            with open(fold_file_name + file) as f:
                for l in f:
                    if l.strip():  # ignore blank lines
                        data.append(l.split('\t')[0])
        return data

    def make_clip_pair(self, data_align):
        """Cut uniformly spaced windows from every video and pair the views.

        Args:
            data_align: mapping produced by :meth:`make_data_set`.

        Returns:
            List of dicts with keys ``view_c``, ``view_h``, ``st_frame``,
            ``end_frame`` and ``video_id``.  For each video there are
            ``clip_pervid`` windows and, per window, one pair for every
            ``'C'`` view; the ``'H'`` partner of each ``'C'`` view is drawn
            at random once per video.

        Raises:
            AssertionError: if the aligned span is empty or a window would
                run past its end.
        """
        win = self.win  # the length of clip
        num = self.clip_pervid  # the number of windows for each video
        pairs_coll = []
        num_video = 0
        for key_id, view in data_align.items():
            # The "st_frame"/"end_frame" bookkeeping keys start with neither
            # 'C' nor 'H', so they are not picked up here.
            view_c = [key for key in view if key.startswith('C')]
            view_h = [key for key in view if key.startswith('H')]
            num_c = len(view_c)
            num_h = len(view_h)
            if num_c == 0 or num_h == 0:
                print('Becuase of the lacking of view, remove the file:' + key_id)
                continue

            span_0 = view["st_frame"]
            span_1 = view["end_frame"]
            assert span_0 < span_1

            vid_len = span_1 - span_0
            clip_slide = int((vid_len - win) / num)

            # One random H view per C view, fixed for all windows of the video.
            h_view_id = np.random.randint(0, num_h, size=num_c)
            for i in range(num):
                win_start = i * clip_slide + span_0
                win_end = win_start + win
                assert win_end <= span_1
                for j in range(num_c):
                    pairs_coll.append({'view_c': view_c[j], 'view_h': view_h[h_view_id[j]],
                                       'st_frame': win_start, 'end_frame': win_end,
                                       'video_id': key_id})

            num_video += 1
        print('================== TRAIN:(uniform) The number of aligned pairs  is {}. ====================='.format(len(pairs_coll)))
        print('(all are from {} videos, each video generate {} pairs, the clip length is {}.)'.format( num_video, num, win))
        print('============================== prepare training dataset DONE! ============================= ')

        return pairs_coll

    def make_data_set(self, fold_file_name):
        """Compute, per video, the frame span shared by all available views.

        Args:
            fold_file_name: prefix of the split-list files
                (e.g. '.../coarse-annotations/coarse_splits/').

        Returns:
            Dict ``key_id -> info``.  ``info[view]`` is
            ``[video_start, video_end, span_start, span_end]`` for every
            usable view, and ``info['st_frame']`` / ``info['end_frame']``
            hold the common span.  A video without any usable view maps to an
            empty dict, which :meth:`make_clip_pair` skips.

        Raises:
            SystemExit: if the fold/split combination is not recognised.
        """
        files = None
        if self.fold == 'train':
            if self.split == 'train_val':
                files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt', 'val_coarse_assembly.txt',
                         'val_coarse_disassembly.txt']
            elif self.split == 'train':
                files = ['train_coarse_assembly.txt', 'train_coarse_disassembly.txt']
        elif self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        if files is None:
            print("unknown split, quit")
            raise SystemExit(1)

        video_align = {}
        for video_id in self.read_files(files, fold_file_name):
            video_id = video_id.split(".txt")[0]
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            # Collect every annotated segment boundary of the ground truth.
            indexs = []
            with open(filename, 'r') as f:
                for l in f:
                    if not l.strip():
                        continue
                    tmp = l.split('\t')
                    indexs.extend([int(tmp[0]), int(tmp[1])])
            span = [min(indexs), max(indexs)]

            # video_id is '<action type>_<key_id>'; keep everything after the
            # first underscore.
            key_id = video_id.partition('_')[2]
            video_align[key_id] = {}
            st_frame = []
            end_frame = []
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue

                assert self.statistic[key_id][view][0] <= span[0]
                # The clipped end carries over to the following views.
                span[1] = min(span[1], self.statistic[key_id][view][1])
                if span[1] <= span[0]:
                    # the video only involves preparation, no action before it's end.
                    continue

                video_align[key_id][view] = [
                    self.statistic[key_id][view][0],  # video start
                    self.statistic[key_id][view][1],  # video end
                    span[0],  # video span 0
                    span[1],  # video span 1
                ]
                st_frame.append(span[0])
                end_frame.append(span[1])

            if st_frame:
                video_align[key_id]["st_frame"] = max(st_frame)
                video_align[key_id]["end_frame"] = min(end_frame)
            # else: no usable view; the entry stays empty and is skipped later.
        print("Number of videos logged in {} fold is {}".format(self.fold, len(video_align)))
        return video_align

    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        """Load the features of both views for clip pair ``index``.

        Returns:
            ``(elements_c, elements_h)``: two float32 arrays of shape
            ``(2048, end_frame - st_frame)``, one column per frame.

        Raises:
            SystemExit: if a frame is missing from either LMDB store.
        """
        ele_dict = self.data_pairs[index]
        st_frame = ele_dict['st_frame']
        end_frame = ele_dict['end_frame']
        view_c = ele_dict['view_c']
        view_h = ele_dict['view_h']
        video_id = ele_dict['video_id']

        elements_c = []
        elements_h = []
        with self.env[view_c].begin() as e_c, self.env[view_h].begin() as e_h:
            for i in range(st_frame, end_frame):
                # Key layout: '<video_id>/<view>/<view>_<10-digit frame>.jpg'
                key_c = video_id + self.frames_format.format(view_c, view_c, i)
                key_h = video_id + self.frames_format.format(view_h, view_h, i)
                raw_c = e_c.get(key_c.strip().encode('utf-8'))
                raw_h = e_h.get(key_h.strip().encode('utf-8'))
                if (raw_h is None) or (raw_c is None):
                    print('no available data.')
                    raise SystemExit(2)
                feat_c = np.frombuffer(raw_c, 'float32')
                # Decode the H-view bytes (not the C-view ones) for the H view.
                feat_h = np.frombuffer(raw_h, 'float32')
                assert feat_c.shape[0] == 2048 and feat_h.shape[0] == 2048
                elements_c.append(feat_c)
                elements_h.append(feat_h)
        elements_c = np.array(elements_c).T  # a clip in a view, e.g. [2048, win]
        elements_h = np.array(elements_h).T
        return elements_c, elements_h

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data_pairs)



class AugmentDataset_val_uniform(torch.utils.data.Dataset):  # build dataset here
    """Validation/test dataset of time-aligned clip pairs from two views.

    For every video listed in the split files, the annotated frame span is
    intersected with the per-view frame ranges stored in
    ``data/statistic_input.pkl``.  ``clip_pervid`` windows of ``win`` frames
    are placed at a uniform stride inside that span and paired
    deterministically: the first ``'C'`` view with the first ``'H'`` view and
    the second with the second.

    ``__getitem__`` returns the per-frame features of both views for one pair.
    """

    def __init__(self, args, fold ):
        """Open the per-view LMDB stores and build the list of clip pairs.

        Args:
            args: namespace providing ``vid_list_path``, ``features_path``,
                ``gt_path``, ``VIEWS``, ``val_win`` and ``val_clip_pervid``.
            fold: ``'val'`` or ``'test'``; selects which split files are read.
        """
        fold_file_name = args.vid_list_path
        self.fold = fold
        self.base_dir_name = args.features_path  # e.g. '.../Assembly101/TSM_features'
        self.frames_format = "/{}/{}_{:010d}.jpg"
        self.ground_truth_files_dir = args.gt_path  # e.g. '.../coarse-annotations/coarse_labels/'
        self.VIEWS = args.VIEWS  # list of view names
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        self.win = args.val_win
        self.clip_pervid = args.val_clip_pervid

        # NOTE: pickle must only be used on trusted files; this one is a
        # project-local statistics file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.data_align = self.make_data_set(fold_file_name)
        self.data_pairs = self.make_clip_pair(self.data_align)

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every non-blank line.

        Args:
            list_files: file names (e.g. ``'val_coarse_assembly.txt'``).
            fold_file_name: prefix that is concatenated in front of each name.

        Returns:
            List of strings (the ground-truth file names), in file order.
        """
        data = []
        for file in list_files:
            # Context manager so the handle is closed deterministically.
            with open(fold_file_name + file) as f:
                for l in f:
                    if l.strip():  # ignore blank lines
                        data.append(l.split('\t')[0])
        return data

    def make_clip_pair(self, data_align):
        """Cut uniformly spaced windows from every video and pair the views.

        Args:
            data_align: mapping produced by :meth:`make_data_set`.

        Returns:
            List of dicts with keys ``view_c``, ``view_h``, ``st_frame``,
            ``end_frame`` and ``video_id``.  Per window, view ``j`` of the
            ``'C'`` group is paired with view ``j`` of the ``'H'`` group for
            ``j`` in ``0, 1``; videos offering fewer than two views in either
            group contribute only the pairs that exist.

        Raises:
            AssertionError: if the aligned span is empty or a window would
                run past its end.
        """
        win = self.win  # the length of clip
        num = self.clip_pervid  # the number of windows for each video
        pairs_coll = []
        num_video = 0
        for key_id, view in data_align.items():
            # The "st_frame"/"end_frame" bookkeeping keys start with neither
            # 'C' nor 'H', so they are not picked up here.
            view_c = [key for key in view if key.startswith('C')]
            view_h = [key for key in view if key.startswith('H')]
            if len(view_c) == 0 or len(view_h) == 0:
                print('Becuase of the lacking of view, remove the file:' + key_id)
                continue

            span_0 = view["st_frame"]
            span_1 = view["end_frame"]
            assert span_0 < span_1

            ## average to cut clips
            vid_len = span_1 - span_0
            clip_slide = int((vid_len - win) / num)
            # Fixed choice (0,0) and (1,1); never index past the views that
            # are actually available for this video.
            num_fixed = min(2, len(view_c), len(view_h))
            for i in range(num):
                win_start = i * clip_slide + span_0
                win_end = win_start + win
                assert win_end <= span_1, "Out of video range!"
                for j in range(num_fixed):
                    pairs_coll.append({'view_c': view_c[j], 'view_h': view_h[j],
                                       'st_frame': win_start, 'end_frame': win_end,
                                       'video_id': key_id})

            num_video += 1
        print('====================  VAL: The number of aligned pairs  for val is {}. ==================== '.format(len(pairs_coll)))
        print('(All are from {} videos, each video generate {} pairs, the clip length is {}.)'.format( num_video, num, win))
        print('============================== prepare validation dataset DONE! ============================= ')

        return pairs_coll

    def make_data_set(self, fold_file_name):
        """Compute, per video, the frame span shared by all available views.

        Args:
            fold_file_name: prefix of the split-list files
                (e.g. '.../coarse-annotations/coarse_splits/').

        Returns:
            Dict ``key_id -> info``.  ``info[view]`` is
            ``[video_start, video_end, span_start, span_end]`` for every
            usable view, and ``info['st_frame']`` / ``info['end_frame']``
            hold the common span.  A video without any usable view maps to an
            empty dict, which :meth:`make_clip_pair` skips.

        Raises:
            SystemExit: if ``self.fold`` is neither ``'val'`` nor ``'test'``.
        """
        if self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        elif self.fold == 'test':
            files = ['test_coarse_assembly.txt', 'test_coarse_disassembly.txt']
        else:
            print("unknown split, quit")
            raise SystemExit(1)

        video_align = {}
        for video_id in self.read_files(files, fold_file_name):
            video_id = video_id.split(".txt")[0]
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            # Collect every annotated segment boundary of the ground truth.
            indexs = []
            with open(filename, 'r') as f:
                for l in f:
                    if not l.strip():
                        continue
                    tmp = l.split('\t')
                    indexs.extend([int(tmp[0]), int(tmp[1])])
            span = [min(indexs), max(indexs)]

            # video_id is '<action type>_<key_id>'; keep everything after the
            # first underscore.
            key_id = video_id.partition('_')[2]
            video_align[key_id] = {}
            st_frame = []
            end_frame = []
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue

                assert self.statistic[key_id][view][0] <= span[0]
                # The clipped end carries over to the following views.
                span[1] = min(span[1], self.statistic[key_id][view][1])
                if span[1] <= span[0]:
                    # the video only involves preparation, no action before it's end.
                    continue

                video_align[key_id][view] = [
                    self.statistic[key_id][view][0],  # video start
                    self.statistic[key_id][view][1],  # video end
                    span[0],  # video span 0
                    span[1],  # video span 1
                ]
                st_frame.append(span[0])
                end_frame.append(span[1])

            if st_frame:
                video_align[key_id]["st_frame"] = max(st_frame)
                video_align[key_id]["end_frame"] = min(end_frame)
            # else: no usable view; the entry stays empty and is skipped later.
        print("Number of videos logged in {} fold is {}".format(self.fold, len(video_align)))
        return video_align

    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        """Load the features of both views for clip pair ``index``.

        Returns:
            ``(elements_c, elements_h)``: two float32 arrays of shape
            ``(2048, end_frame - st_frame)``, one column per frame.

        Raises:
            SystemExit: if a frame is missing from either LMDB store.
        """
        ele_dict = self.data_pairs[index]
        st_frame = ele_dict['st_frame']
        end_frame = ele_dict['end_frame']
        view_c = ele_dict['view_c']
        view_h = ele_dict['view_h']
        video_id = ele_dict['video_id']

        elements_c = []
        elements_h = []
        with self.env[view_c].begin() as e_c, self.env[view_h].begin() as e_h:
            for i in range(st_frame, end_frame):
                # Key layout: '<video_id>/<view>/<view>_<10-digit frame>.jpg'
                key_c = video_id + self.frames_format.format(view_c, view_c, i)
                key_h = video_id + self.frames_format.format(view_h, view_h, i)
                raw_c = e_c.get(key_c.strip().encode('utf-8'))
                raw_h = e_h.get(key_h.strip().encode('utf-8'))
                if (raw_h is None) or (raw_c is None):
                    print('no available data.')
                    raise SystemExit(2)
                feat_c = np.frombuffer(raw_c, 'float32')
                # Decode the H-view bytes (not the C-view ones) for the H view.
                feat_h = np.frombuffer(raw_h, 'float32')
                assert feat_c.shape[0] == 2048 and feat_h.shape[0] == 2048
                elements_c.append(feat_c)
                elements_h.append(feat_h)
        elements_c = np.array(elements_c).T  # a clip in a view, e.g. [2048, win]
        elements_h = np.array(elements_h).T
        return elements_c, elements_h

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data_pairs)



def collate_fn_override_test(data):
    """Collate a list of 9-field test samples into batched tensors.

    Entries that are ``None`` are dropped before batching.  Each remaining
    sample is unpacked positionally as ``(data_arr, count, labels, video_len,
    start, video_id, labels_present_arr, chunk_size, chunk_id)``.

    Returns:
        A 9-tuple in the same field order.  ``data_arr``, ``labels`` and
        ``labels_present_arr`` are combined with ``torch.stack``; the scalar
        fields become 1-D tensors; ``video_id`` is returned as a plain tuple.
    """
    samples = [sample for sample in data if sample is not None]
    (features, counts, label_seqs, video_lens, starts,
     video_id, present, chunk_sizes, chunk_ids) = zip(*samples)

    batched_features = torch.stack(features)
    batched_counts = torch.tensor(counts)
    batched_labels = torch.stack(label_seqs)
    batched_lengths = torch.tensor(video_lens)
    batched_starts = torch.tensor(starts)
    batched_present = torch.stack(present)
    batched_chunk_sizes = torch.tensor(chunk_sizes)
    batched_chunk_ids = torch.tensor(chunk_ids)

    return (batched_features, batched_counts, batched_labels, batched_lengths,
            batched_starts, video_id, batched_present, batched_chunk_sizes,
            batched_chunk_ids)


class AugmentDataset_test(torch.utils.data.Dataset):
    """Evaluation dataset yielding temporally pooled features with labels.

    Every (video, view, chunk size) combination is cut into pieces of at most
    ``max_frames_per_video * chunk_size`` frames.  ``__getitem__`` loads the
    per-frame features of one piece from LMDB, max-pools them over
    consecutive groups of ``chunk_size`` frames and assigns each group the
    most frequent frame label.
    """

    def __init__(self, args, fold, fold_file_name, actions_dict, chunk_size):
        """Open the per-view LMDB stores and enumerate all data points.

        Args:
            args: namespace providing ``max_frames_per_video``,
                ``features_path``, ``gt_path``, ``num_class`` and ``VIEWS``.
            fold: ``'val'`` or ``'test'``.
            fold_file_name: prefix of the split-list files.
            actions_dict: mapping from label name to integer label id.
            chunk_size: iterable of pooling sizes; one set of data points is
                generated per entry.
        """
        self.fold = fold
        self.max_frames_per_video = args.max_frames_per_video
        self.base_dir_name = args.features_path
        self.frames_format = "/{}/{}_{:010d}.jpg"  # per-frame key suffix
        self.ground_truth_files_dir = args.gt_path
        self.num_class = args.num_class
        self.VIEWS = args.VIEWS
        self.actions_dict = actions_dict
        # One read-only LMDB environment per view.
        self.env = {view: lmdb.open(f'{args.features_path}/{view}', readonly=True, lock=False) for view in args.VIEWS}
        # NOTE: pickle must only be used on trusted files; this one is a
        # project-local statistics file.
        with open('data/statistic_input.pkl', 'rb') as f:
            self.statistic = pkl.load(f)
        self.chunk_size_arr = chunk_size
        self.data = self.make_data_set(fold_file_name)

    def read_files(self, list_files, fold_file_name):
        """Return the first tab-separated field of every non-blank line.

        Args:
            list_files: file names (e.g. ``'test_coarse_assembly.txt'``).
            fold_file_name: prefix that is concatenated in front of each name.

        Returns:
            List of strings, in file order.
        """
        data = []
        for file in list_files:
            # Context manager so the handle is closed deterministically.
            with open(fold_file_name + file) as f:
                for l in f:
                    if l.strip():  # ignore blank lines
                        data.append(l.split('\t')[0])
        return data

    def make_data_set(self, fold_file_name):
        """Enumerate every (video, view, chunk size, piece) data point.

        Returns:
            List of dicts with keys ``type``, ``view``, ``st_frame``,
            ``end_frame``, ``chunk_id``, ``chunk_size``, ``video_id``,
            ``tot_frames`` and ``labels`` (integer array with one label id
            per frame of the piece).

        Raises:
            SystemExit: if ``self.fold`` is neither ``'val'`` nor ``'test'``.
            AssertionError: if the annotated segments do not add up to the
                annotated span, or a view starts after the span does.
        """
        label_name_to_label_id_dict = self.actions_dict
        if self.fold == 'val':
            files = ['val_coarse_assembly.txt', 'val_coarse_disassembly.txt']
        elif self.fold == 'test':
            files = ['test_coarse_assembly.txt', 'test_coarse_disassembly.txt']
        else:
            print("Unknown data folder")
            raise SystemExit(3)

        chunk_size_arr = self.chunk_size_arr
        data_arr = []
        for video_id in self.read_files(files, fold_file_name):
            video_id = video_id.split(".txt")[0]
            if 'disassembly' in video_id:
                # The ground-truth files are looked up under this spelling.
                video_id = video_id.replace('disassembly', 'disassebly')
            filename = os.path.join(self.ground_truth_files_dir, video_id + ".txt")

            recog_content, indexs = [], []
            with open(filename, 'r') as f:
                for l in f:
                    if not l.strip():
                        continue
                    tmp = l.split('\t')
                    start_l, end_l, label_l = int(tmp[0]), int(tmp[1]), tmp[2]
                    indexs.extend([start_l, end_l])
                    # One label entry per frame of the segment.
                    recog_content.extend([label_l] * (end_l - start_l))

            recog_content = [label_name_to_label_id_dict[e] for e in recog_content]
            span = [min(indexs), max(indexs)]  # [start end)

            len_video = len(recog_content)
            assert len_video == (span[1] - span[0])

            # video_id is '<action type>_<key_id>'; both parts are the same
            # for every view, so derive them once per video.
            type_action, _, key_id = video_id.partition('_')
            for view in self.VIEWS:
                if view not in self.statistic[key_id]:
                    continue
                assert self.statistic[key_id][view][0] <= span[0]
                # NOTE(review): the clipped end carries over to the following
                # views of the same video -- confirm this is intended.
                span[1] = min(span[1], self.statistic[key_id][view][1])
                if span[1] <= span[0]:
                    continue
                for j, chunk_size in enumerate(chunk_size_arr):
                    piece_len = self.max_frames_per_video * chunk_size
                    for st_frame in range(span[0], span[1], piece_len):
                        end_frame = min(st_frame + piece_len, span[1])
                        ele_dict = {'type': type_action, 'view': view, 'st_frame': st_frame, 'end_frame': end_frame,
                                    'chunk_id': j, 'chunk_size': chunk_size, 'video_id': key_id,
                                    'tot_frames': (end_frame - st_frame) // chunk_size}

                        ele_dict["labels"] = np.array(recog_content[st_frame - span[0]:end_frame - span[0]], dtype=int)
                        data_arr.append(ele_dict)

        print("Number of datapoints logged in {} fold is {}".format(self.fold, len(data_arr)))
        return data_arr

    def getitem(self, index):  # Try to use this for debugging purpose
        """Load, pool and label data point ``index``.

        Returns:
            A 9-tuple ``(data_arr, count, label_arr, num_frames, st_frame,
            name, labels_present_arr, chunk_size, chunk_id)`` where

            * ``data_arr`` is a ``(max_frames_per_video, feature_size)`` float
              tensor whose first ``count`` rows hold the max-pooled features
              (remaining rows are zero),
            * ``label_arr`` is a long tensor of length
              ``max_frames_per_video`` filled with ``-100`` after the first
              ``count`` entries,
            * ``num_frames`` is the number of frames loaded,
            * ``name`` is ``'<type>_<video_id>%<view>'``,
            * ``labels_present_arr`` is a multi-hot float tensor of length
              ``num_class``.

        Raises:
            SystemExit: if a frame is missing from the LMDB store.
        """
        ele_dict = self.data[index]
        st_frame = ele_dict['st_frame']
        end_frame = ele_dict['end_frame']
        aug_chunk_size = ele_dict['chunk_size']
        view = ele_dict['view']
        vid_type = ele_dict['type']

        elements = []
        with self.env[view].begin() as e:
            for i in range(st_frame, end_frame):
                # Key layout: '<video_id>/<view>/<view>_<10-digit frame>.jpg'
                key = ele_dict['video_id'] + self.frames_format.format(view, view, i)
                data = e.get(key.strip().encode('utf-8'))
                if data is None:
                    print('no available data.')
                    raise SystemExit(2)
                data = np.frombuffer(data, 'float32')
                assert data.shape[0] == 2048
                elements.append(data)

        elements = np.array(elements).T  # (feature, frame)

        # __init__ never sets ``feature_size``; honour an externally assigned
        # value and otherwise use the dimension of the loaded features.
        feature_size = getattr(self, 'feature_size', elements.shape[0])

        count = 0
        labels_present_arr = torch.zeros(self.num_class, dtype=torch.float32)
        data_arr = torch.zeros((self.max_frames_per_video, feature_size))
        # -100 marks padded positions.
        label_arr = torch.ones(self.max_frames_per_video, dtype=torch.long) * -100
        for i in range(st_frame, end_frame, aug_chunk_size):
            end = min(end_frame, i + aug_chunk_size)
            key = elements[:, i - st_frame: end - st_frame]
            # Majority vote over the frame labels of this group.
            values, counts = np.unique(ele_dict["labels"][i - st_frame: end - st_frame], return_counts=True)
            label_arr[count] = values[np.argmax(counts)]
            labels_present_arr[label_arr[count]] = 1
            # Max-pool the features over the frames of this group.
            data_arr[count, :] = torch.tensor(np.max(key, axis=-1), dtype=torch.float32)
            count += 1

        return data_arr, count, label_arr, elements.shape[1], st_frame, vid_type + '_' + ele_dict['video_id'] \
               + '%{}'.format(view), labels_present_arr, aug_chunk_size, ele_dict['chunk_id']

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data)