""" Dataset loader for the Charades dataset """
import csv
import json
import os

import numpy as np
import torch
import torch.utils.data as data
import copy
from ..helper_utils import list2tensorpad

class TVInteractionDataset(data.Dataset):
    """Caption-based dataset over the TV human interaction splits.

    Samples are identified by ``"<class name>_<sample id>"`` strings built
    from ``metadata.json``.  Each sample carries a fixed number of generated
    captions (one per selected image) that are tokenized on access.

    Three splits are exposed: ``"train"``, ``"test"`` and ``"trainval"`` (the
    concatenation of the first two).  Select the active one through the
    ``split`` property; ``__len__`` and ``__getitem__`` operate on it.
    """

    def __init__(self, params, tokenizer):
        """Load metadata and generated captions for every split.

        Args:
            params: Mapping-like configuration.  Keys read here are
                ``OVERFIT``, ``GENERATED_CAPTIONS_PATH_TRAIN``,
                ``GENERATED_CAPTIONS_PATH_TEST`` and ``DATA_ROOT``; other
                methods additionally read ``MAX_LEN_CAPTION`` and
                ``IMAGE_ROOT``.
            tokenizer: Object providing ``encode`` and ``decode`` (used by
                ``tokenize`` / ``decode``).
        """
        super().__init__()
        self.num_data_points_per_split = {}
        # Order matters: the position of a name here is the index used into
        # ``self.ids`` and ``self.data``.
        self.subsets = ["train", "test", "trainval"]
        self.overfit = params["OVERFIT"]
        self.params = params
        self._split = "train"
        self._split_idx = 0

        # NOTE(review): hard-coded, independent of len(self.classes).
        self.num_classes = 4

        caption_paths = [
            params["GENERATED_CAPTIONS_PATH_TRAIN"],
            params["GENERATED_CAPTIONS_PATH_TEST"],
        ]

        self.labels = []
        self.sample_ids = []
        self.root = params["DATA_ROOT"]
        self.annotation_root = os.path.join(
            self.root, "tv_human_interaction_annotations"
        )
        self.video_path = os.path.join(self.root, "tv_human_interactions_videos")
        # Use a context manager so the metadata file handle is closed.
        with open(os.path.join(self.root, "metadata.json")) as f_meta:
            self.metadata = json.load(f_meta)
        self.classes = self.metadata["classes"]
        # Per-split list of sample ids / per-split dict of caption records.
        self.ids = []
        self.data = []

        # The last subset ("trainval") has no caption file of its own; it is
        # assembled from the other two below.
        for subset, caption_path in zip(self.subsets[:-1], caption_paths):
            self.data.append(self.process_captions(caption_path))
            subset_ids = []
            for class_name, class_ids in self.metadata["splits"][subset].items():
                for cur_id in class_ids:
                    subset_ids.append("%s_%d" % (class_name, cur_id))
            self.ids.append(subset_ids)
            self.num_data_points_per_split[subset] = len(subset_ids)

        self._tokenizer = tokenizer

        # trainval split: train samples followed by test samples.
        self.ids.append(self.ids[0] + self.ids[1])
        self.num_data_points_per_split["trainval"] = (
            self.num_data_points_per_split["train"]
            + self.num_data_points_per_split["test"]
        )
        # Deep copies keep the merged records independent of the per-split
        # ones (the nested caption/id lists are mutable).
        trainval_data = copy.deepcopy(self.data[0])
        trainval_data.update(copy.deepcopy(self.data[1]))
        self.data.append(trainval_data)

    def __len__(self):
        """Return the number of samples in the currently selected split."""
        return self.num_data_points_per_split[self._split]

    def process_captions(self, caption_path, num_captions=3):
        """Read a caption TSV file and keep ``num_captions`` captions per sample.

        Each non-empty row has ``"<class name>_<sample id>_<image id>"`` in
        its first column and, in its second column, a JSON array whose first
        element is an object with a ``"caption"`` key.

        Args:
            caption_path: Path of the tab-separated caption file.
            num_captions: Number of captions kept per sample, picked at evenly
                spaced positions of the image-id-sorted caption list (indices
                repeat when a sample has fewer captions than requested).

        Returns:
            dict mapping ``"<class name>_<sample id>"`` to
            ``{"captions": [...], "ids": [...]}``, both lists ordered by
            image id and aligned with each other.
        """
        captions = {}
        # newline="" is what the csv module asks for when reading files.
        with open(caption_path, newline="") as f_cap:
            for row in csv.reader(f_cap, delimiter="\t"):
                if not row:
                    # csv.reader yields [] for blank lines; nothing to parse.
                    continue
                # Split from the right so class names may contain "_".
                class_name, cur_id, image_id = row[0].rsplit("_", 2)
                sample_key = class_name + "_" + cur_id
                record = captions.setdefault(
                    sample_key, {"captions": [], "ids": []}
                )
                caption_dict = json.loads(row[1])[0]
                record["captions"].append(caption_dict["caption"])
                record["ids"].append(int(image_id))

        for record in captions.values():
            # Sort (image id, caption) pairs by image id only, so captions
            # with equal ids keep their file order.
            ordered = sorted(
                zip(record["ids"], record["captions"]), key=lambda pair: pair[0]
            )
            # Evenly spaced positions over the sorted list.
            picked = (
                np.round(np.linspace(0, len(ordered) - 1, num_captions))
                .astype(int)
                .tolist()
            )
            record["ids"] = [ordered[j][0] for j in picked]
            record["captions"] = [ordered[j][1] for j in picked]

        return captions

    def __getitem__(self, index):
        """Return the tokenized sample at ``index`` of the current split.

        Args:
            index (int): Position within the current split.

        Returns:
            dict with keys
            ``"captions"`` / ``"captions_length"`` (output of ``tokenize``),
            ``"target"`` (int index of the sample's class in ``self.classes``),
            ``"index"`` (long tensor holding ``index``) and
            ``"image_ids"`` (long tensor of the image ids of the captions).
        """
        sample_key = self.ids[self._split_idx][index]
        sample = self.data[self._split_idx][sample_key]

        # The class name is everything before the trailing "_<sample id>".
        target = self.classes.index(sample_key.rsplit("_", 1)[0])
        captions_encoded, captions_length = self.tokenize(
            sample["captions"], max_len=self.params["MAX_LEN_CAPTION"]
        )

        # Build integer tensors directly: torch.Tensor(...) would go through
        # float32 first and round integers above 2**24.
        return {
            "captions": captions_encoded,
            "captions_length": captions_length,
            "target": target,
            "index": torch.tensor([index], dtype=torch.long),
            "image_ids": torch.tensor(sample["ids"], dtype=torch.long),
        }

    @property
    def tokenizer(self):
        """Tokenizer passed to the constructor."""
        return self._tokenizer

    @property
    def split(self):
        """Name of the currently selected split."""
        return self._split

    @split.setter
    def split(self, split):
        # Kept as an assert so callers see the same exception type as before.
        assert split in self.subsets, "unknown split: %r" % (split,)
        self._split = split
        self._split_idx = self.subsets.index(split)

    def tokenize(self, strings, max_len):
        """Encode each string, truncate to ``max_len`` tokens and pad.

        Args:
            strings: Iterable of caption strings.
            max_len: Maximum number of tokens kept per string; also passed to
                ``list2tensorpad`` as ``max_seq_len``.

        Returns:
            Tuple ``(tokens, lengths)``: the padded token tensors
            concatenated along dim 0 and cast to long, and a 1-D long tensor
            with the unpadded (post-truncation) length of each string.
            Presumably ``tokens`` is (num strings x max_len) -- this depends
            on ``list2tensorpad``; confirm against its definition.
        """
        encoded_strings = []
        encoded_len = []
        for s in strings:
            token_ids = self._tokenizer.encode(
                s, add_special_tokens=False, padding=False
            )[:max_len]
            encoded_strings.append(list2tensorpad(token_ids, max_seq_len=max_len))
            # Already truncated above, so this never exceeds max_len.
            encoded_len.append(len(token_ids))
        return (
            torch.cat(encoded_strings, dim=0).long(),
            torch.tensor(encoded_len, dtype=torch.long),
        )

    def decode(self, tensorized_tokens, seqlen):
        """Decode a batch of token tensors back into strings.

        Args:
            tensorized_tokens: Tokens, (batch x seq len).
            seqlen: Number of valid tokens per row; first dimension must
                match ``tensorized_tokens``.

        Returns:
            list of decoded strings, one per row, using only the first
            ``seqlen[i]`` tokens of row ``i``.
        """
        assert tensorized_tokens.shape[0] == seqlen.shape[0]
        return [
            self._tokenizer.decode(tokens[:length].tolist())
            for tokens, length in zip(tensorized_tokens, seqlen)
        ]

    def get_image_paths(self, index, image_ids):
        """Return image file paths for a sample of the current split.

        Args:
            index: Position of the sample within the current split.
            image_ids: Iterable of integer image ids.

        Returns:
            list of paths of the form
            ``<IMAGE_ROOT>/<class name>_<sample id, 4 digits>/img_<image id, 4 digits>.png``.
        """
        sample_key = self.ids[self._split_idx][index]
        # Split from the right so class names may contain "_".
        class_name, sample_num = sample_key.rsplit("_", 1)
        image_root = os.path.join(
            self.params["IMAGE_ROOT"], "%s_%04d" % (class_name, int(sample_num))
        )
        return [os.path.join(image_root, "img_%04d.png" % i) for i in image_ids]
