import torch
import lmdb
import pickle as pkl
import os
import numpy as np
import random

""" group2: replace `number_per_vid` with `stride` to extract sliding windows from each video.
"""

################## Dataloader
def collate_fn_override(data):
    """Collate a list of 8-field samples into a batch, skipping ``None`` samples.

    Each non-``None`` sample must unpack as
    ``(data_arr, count, labels, clip_length, start, video_id,
    labels_present_arr, aug_chunk_size)``.

    Tensor fields are stacked, scalar fields become tensors, and ``video_id``
    is returned as a plain tuple. ``aug_chunk_size`` is converted with
    ``dtype=torch.int``.

    NOTE(review): ``CharadesEgoDataset.__getitem__`` in this file returns
    3-tuples, so this collate function is presumably meant for a different
    dataset. Confirm before pairing them.
    """
    samples = [sample for sample in data if sample is not None]
    (data_arr, count, labels, clip_length, start,
     video_id, labels_present_arr, aug_chunk_size) = zip(*samples)

    batch = (
        torch.stack(data_arr),
        torch.tensor(count),
        torch.stack(labels),
        torch.tensor(clip_length),
        torch.tensor(start),
        video_id,
        torch.stack(labels_present_arr),
        torch.tensor(aug_chunk_size, dtype=torch.int),
    )
    return batch

class CharadesEgoDataset(torch.utils.data.Dataset):  # build dataset here
    """Groups of paired exo/ego feature windows for Charades-Ego.

    Each dataset item is a group of ``num_in_group`` windows. For each window,
    the exo features (``<features_path>/<video_id>.npy``) and the ego features
    (``<features_path>/<video_id>EGO.npy``) are sliced over the same frame
    range. Both are zero-padded along time to a fixed length.

    Args:
        args: two-element sequence ``(args_common, args_diff)``.
            ``args_common`` provides ``train_win``, ``win_dis``, ``stride``,
            ``dilation``, ``num_in_group`` and ``reduce_data``.
            ``args_diff`` provides ``features_path`` and ``statistic_file``.
        fold: split name; used only in the log message.
    """

    # Width of each per-frame feature row expected in the .npy files.
    FEAT_DIM = 2048
    # Minimum padded window length along time (the original fixed size).
    MIN_PAD_LEN = 40

    def __init__(self, args, fold):
        args_common = args[0]
        args_diff = args[1]

        self.fold = fold
        self.fea_path = args_diff.features_path
        self.statistic_file = args_diff.statistic_file

        self.win = args_common.train_win
        self.win_dis = args_common.win_dis
        self.stride = args_common.stride
        self.dilation = args_common.dilation
        self.num_in_group = args_common.num_in_group
        self.reduce_data = args_common.reduce_data  # stored; not read in this class

        self.data_pairs = self.make_clip_pair_group()

    def read_files(self):
        """Parse ``self.statistic_file``, whose lines look like ``video_id,num_frames``.

        Blank lines are skipped.

        Returns:
            tuple[list[str], list[int]]: parallel lists of video ids and frame counts.
        """
        data = []
        length = []
        with open(self.statistic_file, 'r') as file:
            for line in file:
                text = line.replace('\n', '')
                if not text.strip():
                    continue  # e.g. a trailing empty line; previously an IndexError
                fields = text.split(',')
                data.append(fields[0])
                length.append(int(fields[1]))
        return data, length

    def _sample_aug_win(self):
        """Return the window length to use for one video.

        With probability 1/2, the length is drawn roughly log-uniformly
        between ``int(win / 2)`` and ``int(2 * win)``. Otherwise ``self.win``
        is used unchanged.
        """
        if np.random.randint(low=0, high=2) == 0:
            max_win = int(1.0 * self.win / 0.5)
            # Guard against log(0) when train_win < 2.
            min_win = max(1, int(1.0 * self.win / 2))
            return int(np.exp(np.random.uniform(low=np.log(min_win), high=np.log(max_win))))
        return self.win

    @staticmethod
    def _window_bounds(vid_len, aug_win, stride, dilation):
        """Return the ``(start, end)`` frame indices (inclusive) of every sliding window.

        This matches unfolding ``arange(vid_len)`` with kernel ``aug_win`` and
        the given ``stride`` and ``dilation``, without padding. Window ``k``
        starts at ``k * stride``, and its last sampled frame is
        ``start + (aug_win - 1) * dilation``. If the video is shorter than
        one window span, an empty list is returned.
        """
        span = (aug_win - 1) * dilation + 1
        return [(s, s + span - 1) for s in range(0, vid_len - span + 1, stride)]

    def make_clip_pair_group(self):
        """Build the list of window groups that form the dataset items.

        For each video, a window length is sampled and clipped to the video
        length, and sliding windows are enumerated with ``self.stride``.

        Windows are visited in ``win_dis`` interleaved passes: offset
        ``0..win_dis-1``, stepping by ``win_dis``. Visited windows are packed
        in order into groups of ``num_in_group``. A group may span consecutive
        videos, and a trailing incomplete group is dropped.

        Returns:
            list[list[dict]]: each dict has keys ``st_frame``, ``end_frame``
            (inclusive), ``video_id`` and ``aug_win``.
        """
        video, length = self.read_files()

        pairs_collect = []
        pair_group = []
        num_video = 0
        win = self.win  # nominal clip length (augmented per video below)
        num_in_group = self.num_in_group  # windows per group
        for key_id, vid_len in zip(video, length):
            if vid_len <= 0:
                continue  # nothing to window (the unfold-based version crashed here)

            # Augment the window length, then clip it to the video length.
            aug_win = min(self._sample_aug_win(), vid_len)
            bounds = self._window_bounds(vid_len, aug_win, self.stride, self.dilation)
            num_win = len(bounds)

            for offset in range(self.win_dis):
                for win_id in range(offset, num_win, self.win_dis):
                    win_start, win_end = bounds[win_id]
                    # getitem slices contiguous frames, so this only holds when
                    # dilation == 1 (or aug_win == 1).
                    assert win_end + 1 - win_start == aug_win
                    pair_group.append({'st_frame': win_start,
                                       'end_frame': win_end,
                                       'video_id': key_id, 'aug_win': aug_win})
                    if len(pair_group) == num_in_group:
                        pairs_collect.append(pair_group)
                        pair_group = []

            num_video += 1
        print('============================== Prepare Charades-Ego training dataset START! ============================= ')
        print('(Charades-Ego are from {} videos, the win length is {}, {} clip pairs in a group.)'.format( num_video, win, num_in_group))
        print('================== DONE for Charades-Ego ! {}:The number of aligned pairs  is {}. ====================='.format(self.fold, len(pairs_collect)*num_in_group))

        return pairs_collect

    ## for each sample, frame by frame
    def getitem(self, index):  # Try to use this for debugging purpose
        """Load one group of exo/ego windows.

        Returns:
            ``(exo, ego, mask)``. ``exo`` and ``ego`` are float32 arrays of
            shape ``(num_group, FEAT_DIM, pad_len)``. ``mask`` has shape
            ``(num_group, pad_len)`` and is 1 over each window's valid frames
            and 0 over the padding. ``pad_len`` is
            ``max(MIN_PAD_LEN, int(2 * self.win))``, which is large enough for
            any augmented window. Returns ``None`` if the stored item is not a
            list.
        """
        ele_dict = self.data_pairs[index]  # one group of window dicts
        if not isinstance(ele_dict, list):
            return None

        num_group = len(ele_dict)
        # Augmented windows can be up to int(win / 0.5) frames long; a fixed
        # 40 overflowed for train_win > 20.
        pad_len = max(self.MIN_PAD_LEN, int(1.0 * self.win / 0.5))
        elements_c_group = np.zeros((num_group, self.FEAT_DIM, pad_len), dtype=np.float32)
        elements_h_group = np.zeros_like(elements_c_group)
        elements_mask = np.zeros((num_group, pad_len))

        # Many windows in a group share a video; load each file pair only once.
        loaded = {}
        for idx, single_pair in enumerate(ele_dict):
            st_frame = single_pair['st_frame']
            end_frame = single_pair['end_frame'] + 1  # make the end exclusive for slicing
            video_id = single_pair['video_id']
            win = single_pair['aug_win']

            if video_id not in loaded:
                exo_file_path = f'{self.fea_path}/{video_id}.npy'
                ego_file_path = f'{self.fea_path}/{video_id}EGO.npy'
                loaded[video_id] = (np.load(exo_file_path), np.load(ego_file_path))
            exo_feats, ego_feats = loaded[video_id]

            # Files are indexed [frame, feature]; transpose to [feature, frame].
            elements_c = exo_feats[st_frame:end_frame, :].T
            elements_h = ego_feats[st_frame:end_frame, :].T
            assert elements_c.shape[-1] == win and elements_h.shape[-1] == win

            elements_c_group[idx, :, :win] = elements_c
            elements_h_group[idx, :, :win] = elements_h
            elements_mask[idx, :win] = 1

        return elements_c_group, elements_h_group, elements_mask

    def __getitem__(self, index):
        return self.getitem(index)

    def __len__(self):
        return len(self.data_pairs)

