import argparse
import copy
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.utils.data
from os import listdir


def load_json(file):
    """Open ``file`` and return its parsed JSON content.

    :param file: path of the JSON document to read.
    :return: whatever ``json.load`` produces for that document.
    """
    with open(file) as handle:
        return json.load(handle)

def load_txt(file):
    """Open ``file`` as text and return all of its lines.

    :param file: path of the text file to read.
    :return: list with one string per line; line terminators are kept.
    """
    with open(file) as handle:
        # Iterating a text file yields exactly the lines readlines() returns.
        return list(handle)
        
class VideoRecord:
    """One temporal window of a video together with its ground truth.

    Attributes set by the constructor:
        id: the video identifier passed in as ``vid``.
        locations: the window's frame positions, stored exactly as given.
        base: the first frame position of the window, as a Python float.
        window_size, interval, absolute_position: copied from ``args``.
        locations_norm: ``(location - base) / (window_size * interval)``
            for every location.
        locations_offset: ``location - base`` for every location.
        num_frames, fps: stored as given.
        duration: ``num_frames / fps``.
        gt: the ground truth exactly as passed in (never modified).
        gt_norm: deep copy of ``gt`` whose segment start/end are normalised
            the same way as ``locations_norm``.
        gt_s_e_frames: the normalised ``[start, end]`` segments only
            (element 0 of every ``gt_norm`` entry).
    """

    def __init__(self, vid, num_frames, locations, gt, fps, args):
        """Build the record and pre-compute the normalised positions.

        :param vid: video identifier.
        :param num_frames: total number of frames of the video.
        :param locations: indexable sequence of frame positions; callers in
            this file pass an ``(N, 1)`` numpy array.
        :param gt: list of ``(segment, label)`` pairs where ``segment`` is a
            mutable ``[start, end]`` sequence expressed in frames.
        :param fps: frames per second, used only to derive ``duration``.
        :param args: object exposing ``window_size``, ``interval`` and
            ``absolute_position``.
        """
        self.id = vid
        self.locations = locations
        # ``locations[0]`` is typically a one-element array. Calling float()
        # directly on an array with ndim > 0 is deprecated since NumPy 1.25,
        # so unwrap the single value explicitly first.
        self.base = float(np.asarray(locations[0]).item())
        self.window_size = args.window_size
        self.interval = args.interval

        # Number of frames spanned by a full window; used as the
        # normalisation denominator for both locations and ground truth.
        span = self.window_size * self.interval

        self.locations_norm = [
            (location - self.base) / span for location in locations
        ]
        self.locations_offset = [
            location - self.base for location in locations
        ]
        self.num_frames = num_frames
        self.absolute_position = args.absolute_position

        self.gt = gt
        # Deep copy so that normalising in place leaves ``gt`` untouched.
        self.gt_norm = copy.deepcopy(gt)

        # normalize gt start and end
        for item in self.gt_norm:
            segment = item[0]
            segment[0] = (segment[0] - self.base) / span
            segment[1] = (segment[1] - self.base) / span

        self.gt_s_e_frames = [item[0] for item in self.gt_norm]
        self.fps = fps
        self.duration = num_frames / fps


class ThumosDetection(torch.utils.data.Dataset):
    """Dataset of fixed-length temporal windows over annotated videos.

    Frame positions are sampled every ``args.interval`` frames. A video with
    at most ``args.window_size`` positions yields one zero-padded window;
    longer videos are cut into overlapping windows. Each item additionally
    loads per-video feature, skeleton, element and movement files from the
    folders given to the constructor.
    """

    # Divisor used to map a frame position to a row of the stored features.
    FEATURE_STRIDE = 8
    # Width every stacked feature window is asserted to have.
    FEATURE_DIM = 2048

    def __init__(self, feature_folder, skeleton_folder, elements_folder, movements_folder, anno_file, split, args):
        """Index every window of the videos belonging to ``split``.

        :param feature_folder: folder with one ``torch.load``-able file per video id.
        :param skeleton_folder: folder with ``<video id>.npy`` skeleton files.
        :param elements_folder: folder with one ``torch.load``-able file per video id.
        :param movements_folder: folder with one ``torch.load``-able file per video id.
        :param anno_file: JSON file mapping video id to a dict with
            ``duration_frame``, ``fps`` and ``annotations`` (each annotation
            has ``segment_frame`` and ``label``).
        :param split: ``'val'`` selects videos whose id contains ``'test'``;
            any other value selects ids containing ``'val'``.
            NOTE(review): presumably the THUMOS14 convention of training on
            the validation videos - confirm.
        :param args: object exposing ``window_size``, ``gt_size``,
            ``interval`` and whatever ``VideoRecord`` reads.
        """
        self.window_size = args.window_size
        self.feature_folder = feature_folder
        self.skeleton_folder = skeleton_folder
        self.elements_folder = elements_folder
        self.movements_folder = movements_folder
        # Parse the annotation file once and reuse it everywhere below.
        self.anno_file = load_json(anno_file)
        self.num_gt = args.gt_size
        self.split = 'test' if split == 'val' else 'val'

        # Stable integer id per video: its rank among the sorted video names.
        self.video_dict = {
            name: index for index, name in enumerate(sorted(self.anno_file))
        }

        self.video_list = []
        for vid, meta in self.anno_file.items():
            # Split membership is decided by substring match on the video id.
            if self.split not in vid:
                continue

            num_frames = int(meta['duration_frame'])
            fps = int(meta['fps'])
            segments = [item['segment_frame'] for item in meta['annotations']]
            labels = [int(item['label']) for item in meta['annotations']]

            # Column vector of sampled frame positions, shape (seq_len, 1).
            frames = np.expand_dims(np.array(range(0, num_frames, args.interval)), 1)
            seq_len = len(frames)

            if seq_len <= self.window_size:
                # Short video: a single window, zero-padded at the end, that
                # keeps every annotation.
                locations = np.zeros((self.window_size, 1))
                locations[:seq_len, :] = frames
                gt = list(zip(segments, labels))
                self.video_list.append(
                    VideoRecord(vid, num_frames, locations, gt, fps, args))
                continue

            # Long video: sliding windows; the 'test' split uses a larger
            # stride (half a window) than the other split (a quarter).
            overlap_ratio = 2 if self.split == 'test' else 4
            stride = self.window_size // overlap_ratio
            num_regular = (seq_len // self.window_size - 1) * overlap_ratio + 1
            ws_starts = [i * stride for i in range(num_regular)]
            # Extra final window aligned with the end of the video.
            ws_starts.append(seq_len - self.window_size)

            for ws in ws_starts:
                locations = frames[ws:ws + self.window_size]
                first, last = locations[0], locations[-1]
                # Keep only annotations lying entirely inside the window.
                gt = [
                    (anno, label)
                    for anno, label in zip(segments, labels)
                    if anno[0] >= first and anno[1] <= last
                ]
                # 'test' keeps every window; otherwise windows without any
                # ground truth are dropped.
                if self.split == 'test' or gt:
                    self.video_list.append(
                        VideoRecord(vid, num_frames, locations, gt, fps, args))
        print(split, len(self.video_list))

    def get_data(self, video: VideoRecord):
        '''
        Load and assemble everything belonging to one window.

        :param video: the VideoRecord describing the window.
        :return: tuple of
            vid: the video id,
            locations: tensor of window positions (absolute when
                ``video.absolute_position`` is set, otherwise offsets from
                ``video.base``),
            snippet_fts: [window_size, 2048] feature rows (asserted),
            snippet_skl, snippet_ele, snippet_mov: one stacked row per
                location; their widths depend on the stored files,
            targets: dict with 'labels' (LongTensor, all zeros - the class
                labels kept in ``video.gt`` are not used here), 'boxes'
                (normalised (start, end) pairs) and 'video_id' (one-element
                tensor holding the video's index in ``self.video_dict``),
            num_frames: total frame count of the video,
            base: first frame position of the window.
        '''

        vid = video.id
        num_frames = video.num_frames
        base = video.base

        og_locations = torch.Tensor(list(video.locations))

        # SECURITY NOTE: torch.load and np.load(allow_pickle=True) unpickle
        # arbitrary objects - only point these folders at trusted files.
        vid_feature = torch.load(os.path.join(self.feature_folder, vid))
        vid_skels = np.load(os.path.join(self.skeleton_folder, vid) + '.npy', allow_pickle=True).item()
        # Average the per-person skeleton arrays (each reshaped to 39 columns).
        vid_skeleton = np.mean([np.array(list(person.values())).reshape(-1, 39) for person in vid_skels.values()], axis=0)
        vid_elements = torch.load(os.path.join(self.elements_folder, vid))
        vid_movements = torch.load(os.path.join(self.movements_folder, vid))

        # Map every frame position to a feature row, clamped to the last row.
        last_feature = len(vid_feature) - 1
        ft_idxes = [
            min(torch.floor_divide(i, self.FEATURE_STRIDE), last_feature)
            for i in og_locations
        ]
        snippet_fts = []
        snippet_skl = []
        snippet_ele = []
        snippet_mov = []
        for i in ft_idxes:
            i = int(i)
            snippet_fts.append(vid_feature[i].squeeze())
            snippet_skl.append(vid_skeleton[i].squeeze())
            # Elements and movements are indexed along their second axis.
            snippet_ele.append(vid_elements[:, i].squeeze())
            snippet_mov.append(vid_movements[:, i].squeeze())

        snippet_fts = torch.from_numpy(np.stack(snippet_fts))
        snippet_skl = torch.from_numpy(np.stack(snippet_skl))
        snippet_ele = torch.from_numpy(np.stack(snippet_ele))
        snippet_mov = torch.from_numpy(np.stack(snippet_mov))

        expected_shape = (self.window_size, self.FEATURE_DIM)
        assert snippet_fts.shape == expected_shape, (
            f'unexpected feature shape {tuple(snippet_fts.shape)} '
            f'for video {vid}; expected {expected_shape}')

        if video.absolute_position:
            locations = torch.Tensor(list(video.locations))
        else:
            locations = torch.Tensor(list(video.locations_offset))

        # Every segment gets label 0 here.
        gt_s_e_frames = [(s, e, 0) for (s, e) in video.gt_s_e_frames]
        for (s, e, _) in gt_s_e_frames:
            assert s >= 0 and s <= 1 and e >= 0 and e <= 1, '{} {}'.format(s, e)

        targets = {
            'labels': torch.LongTensor(
                [int(label) for (_, _, label) in gt_s_e_frames]),
            'boxes': torch.Tensor(
                [(start, end) for (start, end, _) in gt_s_e_frames]),
            'video_id': torch.Tensor([self.video_dict[vid]]),
        }

        return vid, locations, snippet_fts, snippet_skl, snippet_ele, snippet_mov, targets, num_frames, base

    def __getitem__(self, idx):
        """Return the assembled sample for the ``idx``-th window."""
        return self.get_data(self.video_list[idx])

    def __len__(self):
        """Return the number of indexed windows."""
        return len(self.video_list)


def collate_fn(batch):
    """Merge a list of dataset samples into batched tensors.

    Each sample is the 9-tuple ``(vid, locations, snippet_fts, snippet_skl,
    snippet_ele, snippet_mov, targets, num_frames, base)``. All samples are
    expected to have the same number of locations as the first one.

    :param batch: non-empty list of samples.
    :return: tuple of
        list of video ids,
        locations ``[B, N, 1]`` (double),
        snippet_fts / snippet_skl / snippet_ele / snippet_mov ``[B, N, dim]``
            where each ``dim`` is taken from the first sample,
        list of per-sample target dicts (left unbatched),
        tensor of frame counts,
        tensor of window bases (samples whose base is ``None`` are skipped).
    """
    vid_name_list, target_list, num_frames_list, base_list = [], [], [], []

    first = batch[0]
    batch_size = len(batch)
    max_props_num = first[1].shape[0]

    # Pre-allocate the batched buffers; feature widths come from sample 0.
    locations = torch.zeros(batch_size, max_props_num, 1, dtype=torch.double)
    snippet_fts = torch.zeros(batch_size, max_props_num, first[2].shape[-1])
    snippet_skl = torch.zeros(batch_size, max_props_num, first[3].shape[-1])
    snippet_ele = torch.zeros(batch_size, max_props_num, first[4].shape[-1])
    snippet_mov = torch.zeros(batch_size, max_props_num, first[5].shape[-1])

    for i, sample in enumerate(batch):
        vid_name_list.append(sample[0])
        locations[i] = sample[1].reshape((-1, 1))
        snippet_fts[i] = sample[2]
        snippet_skl[i] = sample[3]
        snippet_ele[i] = sample[4]
        snippet_mov[i] = sample[5]
        target_list.append(sample[6])
        num_frames_list.append(sample[7])
        # Guard the value that is actually appended: the base (index 8),
        # not the frame count (index 7).
        if sample[8] is not None:
            base_list.append(sample[8])

    num_frames = torch.from_numpy(np.array(num_frames_list))
    base = torch.from_numpy(np.array(base_list))

    return vid_name_list, locations, snippet_fts, snippet_skl, snippet_ele, snippet_mov, target_list, num_frames, base


def build(split, args):
    """Create the ThumosDetection dataset described by ``args``.

    :param split: split name forwarded unchanged to the dataset.
    :param args: object exposing ``feature_path``, ``skeleton_path``,
        ``elements_path``, ``movements_path`` and ``annotation_path``, plus
        whatever the dataset itself reads.
    :return: the constructed dataset.
    :raises AssertionError: when ``args.feature_path`` does not exist.
    """
    # split = train/val
    feature_folder = Path(args.feature_path)
    assert feature_folder.exists(), \
        f'provided thumos14 feature path {feature_folder} does not exist'

    return ThumosDetection(
        feature_folder,
        Path(args.skeleton_path),
        Path(args.elements_path),
        Path(args.movements_path),
        Path(args.annotation_path),
        split,
        args,
    )


def get_args_parser():
    """Return the argument parser holding the dataset-related options.

    The parser is created with ``add_help=False``. Options are registered in
    the order listed below.
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)

    # (flag, keyword arguments for add_argument), in registration order.
    option_table = (
        ('--batch_size', {'default': 2, 'type': int}),
        # dataset parameters
        ('--dataset_file', {'default': 'thumos14'}),
        ('--window_size', {'default': 100, 'type': int}),
        ('--gt_size', {'default': 100, 'type': int}),
        ('--feature_path', {'default': '/data1/tj/thumos_2048/', 'type': str}),
        ('--tem_path', {'default': '/data1/tj/BSN_share/output/TEM_results', 'type': str}),
        ('--annotation_path', {'default': 'thumos14_anno_action_v0.json', 'type': str}),
        ('--remove_difficult', {'action': 'store_true'}),
        ('--num_workers', {'default': 2, 'type': int}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser
