from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
import glob
import torch
from torch.utils.data import Dataset
import numpy as np
import math


# Width/height of the visual input (the name suggests a square side length;
# not referenced in the visible part of this file).
WH_INPUT = 224
# Seed applied to the global torch and numpy RNGs when a dataset is built.
SEED_NO = 2021

# Lower bound on the sampled visual sequence length:
# minimum number of steps (8) * start/end (2) * 4.
MIN_VIS_LEN = 64
# Upper bound on the visual sequence length (every sample is padded or
# trimmed to this): maximum number of key steps (21) * start/end (2) * 4,
# plus some spare room.
MAX_VIS_LEN = 190

# Absolute frame positions are rescaled into the range [0, MAX_POS_LEN).
MAX_POS_LEN = 5000

# Frame rate; also the width (in frames) of the max-pooling window that is
# centred on every selected frame.
FPS = 30

# The sentence sample covers nearby narrations plus some random other ones.
MAX_TEXT_LEN = MAX_VIS_LEN * 3

# Number of frames taken between the start and the end frame of a step.
INTERMEDIATE_CLUES = 3
# Frames per step: the intermediate clues plus the start and end frames.
DISTANCE_BTW_FEAT = INTERMEDIATE_CLUES + 2

# MAX_VIS_LEN should be divisible by DISTANCE_BTW_FEAT.
# NOTE: this must not be written as ``assert(cond, msg)``; that form asserts
# a non-empty tuple, which is always true, so the check would never fire.
assert MAX_VIS_LEN % DISTANCE_BTW_FEAT == 0, 'MAX_VIS_LEN should be devidable by DISTANCE_BTW_FEAT'

# Amount of the sequence to randomly trim from the head / tail
# (not referenced in the visible part of this file).
RANDOM_HEAD_CUT = 0.4
RANDOM_TAIL_CUT = 0.4


class OurDataset(Dataset):
    r"""Common base for the datasets of our method.

    Collects every ``*.npz`` file found directly inside ``npz_root_path``
    (sorted by path) and reports a length of
    ``len(file_list) * internal_repeat``.

    Attributes set here:
        file_list: sorted list of the discovered ``.npz`` paths.
        internal_repeat: how many times the file list is virtually repeated
            (``1`` unless a subclass changes it).
        single_batch_len: number of files found at construction time.
    """

    def __init__(self, npz_root_path: str):
        super().__init__()
        npz_pattern = os.path.join(npz_root_path, '*.npz')
        self.file_list = glob.glob(npz_pattern)
        self.file_list.sort()

        self.single_batch_len = len(self.file_list)
        self.internal_repeat = 1

    def __len__(self):
        # Uses the live file list, so entries appended later are counted too.
        num_files = len(self.file_list)
        return self.internal_repeat * num_files



class NpzDataset(OurDataset):
    r"""Dataset that turns one ``.npz`` record into a fixed-length sample.

    Every record is expected to provide the arrays ``clip``, ``label``,
    ``availability``, ``vid_id``, ``subtask_ids`` and ``start_end`` (and
    ``text_clip`` when ``with_text`` is set). For each record a subset of
    frames is selected around the annotated steps, the ``clip`` features are
    max-pooled in a window of ``FPS`` frames around every selected frame, and
    all per-frame outputs are zero-padded / trimmed to ``MAX_VIS_LEN``.

    Args:
        npz_root_path: directory that holds the ``*.npz`` files.
        is_train: if True, frames (and optionally steps) are sampled
            randomly; otherwise frames are placed deterministically.
        with_text: if True, sentence features from ``text_clip`` are added
            to the returned sample.
        expected_batch_size: when positive, the dataset is virtually repeated
            so that ``len(self)`` is at least this value.
        resample_lowerbound: when below ``1.0`` (training only, and only if
            ``nextstep_pred`` is off), a random subset of the valid steps is
            kept; the subset size is drawn between
            ``int(num_steps * resample_lowerbound)`` and ``num_steps``.
        extract_ilp: if True, files of the sibling split are appended
            (``/train`` <-> ``/val`` substituted in ``npz_root_path``).
        nextstep_pred: if True, ``subtask_ids`` is replaced by the id of the
            temporally next step (``-1`` from the start of the last step on)
            and, during training, only a random prefix of the steps is kept.

    Note:
        Constructing the dataset reseeds the *global* torch and numpy RNGs
        with ``SEED_NO``.
    """

    def __init__(self, npz_root_path: str, is_train: bool = True,
                 with_text: bool = False, expected_batch_size: int = -1,
                 resample_lowerbound: float = 1.0, extract_ilp: bool = False, nextstep_pred: bool = False):
        super(NpzDataset, self).__init__(npz_root_path)

        if extract_ilp:
            # Also pull in the files of the other split.
            additional_path = None
            if is_train and ('/train' in npz_root_path):
                additional_path = npz_root_path.replace('/train', '/val')
            elif (not is_train) and ('/val' in npz_root_path):
                additional_path = npz_root_path.replace('/val', '/train')
            if additional_path is not None:
                list_of_files = glob.glob(os.path.join(additional_path, '*.npz'))
                self.file_list.extend(sorted(list_of_files))

        # Global side effect: both RNGs are reseeded on every construction.
        torch.manual_seed(SEED_NO)
        np.random.seed(SEED_NO)
        self.nextstep_pred = nextstep_pred
        self.is_train = is_train
        self.with_text = with_text
        self.resample_lowerbound = resample_lowerbound

        # Recomputed here because the file list may have grown above.
        self.internal_repeat = 1
        self.single_batch_len = len(self.file_list)
        if expected_batch_size > 0:
            self.internal_repeat = math.ceil((expected_batch_size * 1.0) / self.single_batch_len)

    def __getitem__(self, index):
        """Return the sample for ``index``, wrapping around the file list."""
        return self.get_single_item(index % self.single_batch_len)

    @staticmethod
    def _fit_length(array, target_len):
        """Zero-pad or truncate ``array`` along axis 0 to ``target_len`` rows."""
        array = np.asarray(array)
        current_len = len(array)
        if current_len >= target_len:
            return array[:target_len]
        pad_spec = [(0, target_len - current_len)] + [(0, 0)] * (array.ndim - 1)
        return np.pad(array, pad_spec, mode='constant')

    def get_single_item(self, index):
        """Load one ``.npz`` record and build the padded sample dictionary.

        Args:
            index: position in ``self.file_list`` (taken modulo
                ``self.single_batch_len``).

        Returns:
            A dict with the tensors ``label``, ``subtask_id``,
            ``vis_feature``, ``absolute_pos``, ``begin_mask`` and
            ``end_mask`` (all of length ``MAX_VIS_LEN`` along the first
            axis), ``pred_mask`` (from ``availability``), plus
            ``vis_seq_len`` (number of non-padded frames) and ``vid_id``.
            Without text the dict also holds ``max_vid_idx`` (last selected
            frame index after clamping). With text it holds
            ``text_feature``, ``text_seq_len``, ``text_base_idx`` and
            ``vis_max_text_idx`` instead.

        Raises:
            Exception: if the record contains no valid step.
        """
        file_path = self.file_list[index % self.single_batch_len]

        # An NpzFile keeps the archive open until it is closed, so read
        # everything that is needed inside the ``with`` block.
        with np.load(file_path) as single_npz_file:
            # CLIP features (one row per frame)
            clip_seq = single_npz_file['clip'].astype(np.float32)
            # labels
            label_seq = single_npz_file['label']
            # availability mask
            availability = single_npz_file['availability']
            # video id
            vid_id = single_npz_file['vid_id']
            # subtask index of each frame
            subtask_id_seq = single_npz_file['subtask_ids']
            # (start, end) per subtask, shifted down by one
            start_end_frame = single_npz_file['start_end'].astype(np.int32) - 1
            # sentence rows: [start, end, feature...]; only needed with text
            sent_with_frame_idx = None
            if self.with_text:
                sent_with_frame_idx = single_npz_file['text_clip'].astype(np.float32)
        clip_seq_len = len(clip_seq)

        # initial filtering
        # a stored 0 became -1 after the shift: treat it as frame zero
        start_end_frame[start_end_frame == -1] = 0
        # if it is out of bound, then move it to the last frame
        start_end_frame[start_end_frame >= clip_seq_len] = clip_seq_len - 1

        # keep only subtasks that have both boundaries and a real duration
        valid_subtask_ids = []
        for subtask_id, (start_idx, end_idx) in enumerate(start_end_frame):
            if (start_idx < 0) or (end_idx < 0):
                continue
            if end_idx - start_idx <= 1:  # 1 frame diff is mainly nothing
                continue
            valid_subtask_ids.append(subtask_id)

        if not valid_subtask_ids:
            print(file_path)
            raise Exception('not expected, fix me')

        start_end_frame = start_end_frame[valid_subtask_ids]

        # Sort the steps by start frame and reorder the ids the same way, so
        # that ``valid_subtask_ids[k]`` still belongs to ``start_end_frame[k]``.
        temporal_order = start_end_frame[:, 0].argsort()
        start_end_frame = start_end_frame[temporal_order]
        valid_subtask_ids = [valid_subtask_ids[pos] for pos in temporal_order]

        if self.nextstep_pred:
            # From the start of step k onwards the target is the id of step
            # k+1; from the start of the last step onwards it is -1.
            subtask_id_seq = np.zeros_like(subtask_id_seq)
            for idx, start_idx in enumerate(start_end_frame[:-1, 0]):
                subtask_id_seq[start_idx:] = valid_subtask_ids[idx+1]
            subtask_id_seq[start_end_frame[-1][0]:] = -1

        # then first resample few subtask upon request
        if self.is_train:
            if self.nextstep_pred:
                # keep a random, non-complete prefix of the steps
                num_valid_subtask = len(valid_subtask_ids)
                if num_valid_subtask >= 2:
                    sample_size = np.random.randint(1, high=num_valid_subtask, size=None, dtype=int)
                else:
                    sample_size = 1
                start_end_frame = start_end_frame[:sample_size]

            elif self.resample_lowerbound < 1.0:
                # keep a random subset of the steps (order is restored by
                # the np.unique call on the frame numbers below)
                num_valid_subtask = len(valid_subtask_ids)
                sample_size = np.random.randint(int(num_valid_subtask * self.resample_lowerbound), high=num_valid_subtask+1, size=None, dtype=int)
                if sample_size == 0:
                    sample_size = 1
                subsampled_ids = np.random.choice(np.arange(num_valid_subtask), sample_size, replace=False)
                start_end_frame = start_end_frame[subsampled_ids]

            # subsample random sequence, but in sorted manner
            num_rough_seq_item = np.random.randint(low=MIN_VIS_LEN - 2 * len(start_end_frame), high=MAX_VIS_LEN - 2 * len(start_end_frame), size=1)[0]
            # number of inner frames drawn per step, never below INTERMEDIATE_CLUES
            max_internal = num_rough_seq_item // len(start_end_frame)
            if max_internal <= INTERMEDIATE_CLUES:
                max_internal = INTERMEDIATE_CLUES

            # random sample
            frame_numbers = []
            mask_begin_set = set()
            mask_end_set = set()
            for start_idx, end_idx in start_end_frame:
                duration = end_idx - start_idx
                # duration is greater than 1 here (filtered above)
                if duration < FPS:
                    # short step: only the frames next to the boundaries
                    frame_numbers.append(np.array([start_idx+1, end_idx-1], dtype=np.int32))
                else:
                    # add internal frames (drawn with replacement)
                    frame_numbers.append(np.random.choice(np.arange(start=start_idx+1, stop=end_idx), max_internal))
                # add start and end
                frame_numbers.append(np.array([start_idx, end_idx], dtype=np.int32))
                # remember boundary frame numbers for the masks
                mask_begin_set.add(start_idx)
                mask_end_set.add(end_idx)
            # sorted and de-duplicated
            frame_numbers = np.unique(np.concatenate(frame_numbers, axis=None), axis=None)
        else:
            frame_numbers = []
            mask_begin_set = set()
            mask_end_set = set()
            for start_idx, end_idx in start_end_frame:
                duration = end_idx - start_idx - 2
                # start, INTERMEDIATE_CLUES evenly spread inner frames, end
                frame_numbers.append(start_idx)
                for idx in range(1, INTERMEDIATE_CLUES+1):
                    frame_numbers.append(start_idx + 1 + int(1.0 * idx * duration / (INTERMEDIATE_CLUES + 1.0)))
                frame_numbers.append(end_idx)
                # remember boundary frame numbers for the masks
                mask_begin_set.add(start_idx)
                mask_end_set.add(end_idx)
            # sorted, duplicates are kept
            frame_numbers = np.sort(np.array(frame_numbers, dtype=np.int32), axis=None)

        begin_mask = np.array([frame in mask_begin_set for frame in frame_numbers], dtype=np.bool_)
        end_mask = np.array([frame in mask_end_set for frame in frame_numbers], dtype=np.bool_)

        # select subseq
        label_seq = label_seq[frame_numbers]
        subtask_id_seq = subtask_id_seq[frame_numbers]
        seq_len = len(frame_numbers)

        # Clamp the frame numbers so the pooling window of FPS frames stays
        # inside the clip.
        # NOTE(review): for clips shorter than about FPS frames the second
        # clamp can push indices below FPS // 2 (even negative) — confirm
        # such clips do not occur.
        half_window = FPS // 2
        frame_numbers[frame_numbers <= half_window] = half_window
        frame_numbers[frame_numbers >= clip_seq_len - half_window] = clip_seq_len - half_window - 1
        # Stack into one array right away: torch.from_numpy rejects a list,
        # which is what the trimmed (non-padded) path used to hand over.
        pooled_seq = np.stack(
            [clip_seq[idx - half_window:idx + half_window].max(axis=0) for idx in frame_numbers],
            axis=0)
        # Widen before multiplying so long videos cannot overflow int32.
        absolute_pos = (frame_numbers.astype(np.int64) * MAX_POS_LEN) // clip_seq_len

        # bring every per-frame output to exactly MAX_VIS_LEN rows
        pooled_seq = self._fit_length(pooled_seq, MAX_VIS_LEN)
        label_seq = self._fit_length(label_seq, MAX_VIS_LEN)
        subtask_id_seq = self._fit_length(subtask_id_seq, MAX_VIS_LEN)
        absolute_pos = self._fit_length(absolute_pos, MAX_VIS_LEN)
        begin_mask = self._fit_length(begin_mask, MAX_VIS_LEN)
        end_mask = self._fit_length(end_mask, MAX_VIS_LEN)
        seq_len = min(seq_len, MAX_VIS_LEN)

        if not self.with_text:
            return {
                    'label': torch.from_numpy(label_seq).to(torch.float32),
                    'subtask_id': torch.from_numpy(subtask_id_seq).to(torch.int64),
                    'vis_feature': torch.from_numpy(pooled_seq),
                    'vis_seq_len': seq_len,
                    'absolute_pos': torch.from_numpy(absolute_pos).to(torch.int64),
                    'pred_mask': torch.from_numpy(availability).to(torch.float32),
                    'vid_id': vid_id,
                    'max_vid_idx': frame_numbers[-1],
                    'begin_mask': torch.from_numpy(begin_mask).to(torch.bool),
                    'end_mask': torch.from_numpy(end_mask).to(torch.bool),
                    }

        # Frames are chosen. Now we choose the language counterpart.
        # Columns 0-1 of each sentence row are its start/end frame, the
        # remaining columns are the sentence feature.
        num_sents = len(sent_with_frame_idx)
        idx_start_end = sent_with_frame_idx[:, 0:2].astype(np.int32)
        sent_features = sent_with_frame_idx[:, 2:]

        # random sentence indexes to add on top of the relevant ones
        random_idx = np.random.choice(num_sents, min(num_sents, MAX_TEXT_LEN - 3 * seq_len), replace=False)

        # Relevant sentence per frame: the first sentence whose start (resp.
        # end) is not before the frame; argmax yields 0 when there is none.
        start_idx = (np.expand_dims(frame_numbers, -1) <= np.expand_dims(idx_start_end[:, 0], 0)).argmax(axis=1)
        end_idx = (np.expand_dims(frame_numbers, -1) <= np.expand_dims(idx_start_end[:, 1], 0)).argmax(axis=1)
        base_idx = np.stack([start_idx, end_idx], axis=0).min(axis=0)
        # neighbours of the relevant sentence, clipped to the valid range
        base_minus_1 = base_idx - 1
        base_minus_1[base_minus_1 < 0] = 0
        base_plus_1 = base_idx + 1
        base_plus_1[base_plus_1 >= num_sents] = num_sents - 1

        # input language seq (sorted, de-duplicated sentence indices)
        language_base_idx = np.unique(np.concatenate([base_minus_1, base_idx, base_plus_1, random_idx], axis=-1))

        text_seq = sent_features[language_base_idx]
        text_seq_len = len(text_seq)

        # per-frame index of the last relevant sentence, MAX_VIS_LEN long
        vis_max_text_idx = self._fit_length(base_plus_1, MAX_VIS_LEN)

        # now pad / trim language to the maximum size
        text_seq = self._fit_length(text_seq, MAX_TEXT_LEN)
        text_base_idx = self._fit_length(language_base_idx, MAX_TEXT_LEN)
        text_seq_len = min(text_seq_len, MAX_TEXT_LEN)

        return {
                'label': torch.from_numpy(label_seq).to(torch.float32),
                'subtask_id': torch.from_numpy(subtask_id_seq).to(torch.int64),
                'vis_feature': torch.from_numpy(pooled_seq),
                'vis_seq_len': seq_len,
                'text_feature': torch.from_numpy(text_seq),
                'text_seq_len': text_seq_len,
                'text_base_idx': torch.from_numpy(text_base_idx).to(torch.int64),
                'vis_max_text_idx': torch.from_numpy(vis_max_text_idx).to(torch.int64),
                'absolute_pos': torch.from_numpy(absolute_pos).to(torch.int64),
                'pred_mask': torch.from_numpy(availability).to(torch.float32),
                'vid_id': vid_id,
                'begin_mask': torch.from_numpy(begin_mask).to(torch.bool),
                'end_mask': torch.from_numpy(end_mask).to(torch.bool),
                }
