#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: sine_dataset.py
# Created Date: Saturday, November 16th 2019, 1:20:51 pm
# Author: <Anonymized for Review>
#
# Copyright (c) 2019 <Anonymized for Review>
###


import abc
import logging
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
from functools import cached_property
from itertools import chain
from typing import Any, Callable, List, NamedTuple, OrderedDict, Sequence, Tuple, Union

import numpy as np
import pandas as pd
from torch.utils import data

from .types import Seq2SeqSamples, BucketType


class GroupSequence(NamedTuple):

    """ Encapsulate start, end, and features for a sequence """

    start: int
    end: int
    features: Any


class SequencePair(NamedTuple):

    """ Local structure to encapsulate a sequence pair """

    key: Tuple[int, int]
    obs_idx: int
    fut_idx: int
    obs_len: int
    fut_len: int
    offset: int


# Type aliases for bucket maps. See `compute_group_buckets` and `compute_sequence_buckets` for details
GroupBucketMap = OrderedDict[Tuple[int, int, int], OrderedDict[int, List]]
SeqBucketMap = OrderedDict[Tuple[int, int], OrderedDict[int, List]]
BucketMap = Union[GroupBucketMap, SeqBucketMap]


def _compute_buckets(pairs: Sequence[SequencePair], key_extractor: Callable) -> BucketMap:
    """ Helper function to compute bucket maps

    Args:
        pairs           --  The list of SequencePair objects to bucket
        key_extractor   --  A function that returns a key given a SequencePair

    Returns a BucketMap object (doesn't do any explicit type checking)
    """
    bucket_map = OrderedDict()
    for idx, pair in enumerate(pairs):
        key = key_extractor(pair)
        if key not in bucket_map:
            # Create a new dictionary for the key and add
            obs_map = OrderedDict()
            bucket_map[key] = obs_map
        else:
            # Access the map for the key
            obs_map = bucket_map[key]
        # Add the idx to the appropriate observed idx list
        if pair.obs_idx not in obs_map:
            obs_map[pair.obs_idx] = [idx]
        else:
            obs_map[pair.obs_idx].append(idx)
    return bucket_map


def compute_group_buckets(pairs: List[SequencePair]) -> GroupBucketMap:
    """ Compute the bucket map for the sequences

    Map a tuple of (group_id, obs_len, fut_len) to a map of observed index to
    list of indices
    """
    key_extractor = lambda pair: (pair.key[0], pair.obs_len, pair.fut_len)
    return _compute_buckets(pairs, key_extractor)


def compute_seq_buckets(pairs: List[SequencePair]) -> SeqBucketMap:
    """ Compute the bucket map for the sequences

    Map a tuple of (obs_len, fut_len) to a map of observed index to
    list of indices
    """
    key_extractor = lambda pair: (pair.obs_len, pair.fut_len)
    return _compute_buckets(pairs, key_extractor)


class SocialDatasetInterface(data.Dataset, metaclass=abc.ABCMeta):
    """ Abstract interface for a SocialDataset """
    @property
    @abc.abstractmethod
    def group_bucket_map(self) -> GroupBucketMap:
        """ Return a GroupBucketMap for aiding generation of batches """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def seq_bucket_map(self) -> SeqBucketMap:
        """ Return a SeqBucketMap for aiding generation of batches """
        raise NotImplementedError


# Map for accessing bucket attribute name by type
BUCKET_ATTR_NAME = {BucketType.GROUP: "group_bucket_map", BucketType.SEQ: "seq_bucket_map"}


class SocialDataset(SocialDatasetInterface):

    """ Encapsulate the synthetic social dataset created in Blender3D """

    @staticmethod
    def group_fields() -> List:
        return ["group_id", "group_size"]

    def __init__(self, obs_df: pd.DataFrame, hparams: Namespace,
                 feature_fields: List, fut_df: pd.DataFrame = None) -> None:
        """ Initialize the dataset object

        Args:
            obs_df          --  Dataframe consisting of behavioral data for
                                groups comprising observed sequences
            hparams         --  Parameters to process the data.
                                Refer `add_dataset_specific_args`
            feature_fields  --  The list of column names for the features
            fut_df          --  Dataframe consisting of behavioral data for
                                groups comprising future sequences

        """
        assert (0 <= hparams.overlap < 1), (
            "overlap must be between 0 inclusive and 1 exclusive"
        )
        self.obs_len = hparams.observed_len
        self.future_len = hparams.future_len
        self.hparams = hparams
        self.obs_df = obs_df
        self.fut_df = obs_df if fut_df is None else fut_df
        assert (not self.obs_df.isnull().values.any()) and (not self.fut_df.isnull().values.any()), \
            "SocialDataset init; Either the observed or future dataframe has NAN values!"
        self.feature_fields = feature_fields
        # Compute the keys and make sure they are the same for observed
        # and future dfs
        obs_keys = obs_df.groupby(SocialDataset.group_fields()).groups.keys()
        fut_keys = self.fut_df.groupby(SocialDataset.group_fields()).groups.keys()
        assert (obs_keys == fut_keys), ("dataframes must contain same groups")
        self.group_keys = obs_keys
        # Map group keys to (observed sequences, future_seqs)
        self.group_seqs = None
        self.pairs = []

    def _construct_pairs(
            self, group_seqs: OrderedDict, fix_future_len=False
        ) -> OrderedDict[Tuple[str, int], SequencePair]:
        """ Construct observed, future pairs from all sequences

        Args:
            group_seqs      --  the OrderedDict mapping group keys to
                                (observed sequences, future_seqs) where the
                                sequences are of type List[GroupSequence]
            fix_future_len  --  restrict future sequences to future_len if True

        Returns
            An ordered dict mapping group key to the list of SequencePair
            objects for that group

        """
        logging.info("[*] Constructing observed-future pairs")
        stride = self.hparams.time_stride
        group_pairs = OrderedDict()
        for key in self.group_keys:
            logging.debug("Constructing seq pairs for group: {}".format(key))
            s_pairs = []
            obs_seqs, fut_seqs = group_seqs[key]
            logging.debug(f"Number of unique seqs for key {key}: obs - "
                         f"{len(obs_seqs)}, fut - {len(fut_seqs)}")
            for i, s_obs in enumerate(obs_seqs):
                for j, s_fut in enumerate(fut_seqs):
                    s_obslen = ((s_obs.end - s_obs.start) // stride) + 1
                    s_futlen = ((s_fut.end - s_fut.start) // stride) + 1
                    offset = s_fut.start - s_obs.end
                    logging.debug(
                        "OBS [{}-{}; len-{}], FUT [{}-{}; len-{}]".format(
                            s_obs.start, s_obs.end, s_obslen,
                            s_fut.start, s_fut.end, s_futlen)
                    )
                    # Add pair of obs sequence is of length self.obs_len
                    # and the future sequence starts after obs seq ends
                    predicate = (
                        (s_obslen == self.obs_len)
                        and (0 < offset <= self.hparams.max_future_offset)
                        and (s_futlen == self.future_len or not fix_future_len)
                        # Return true for the last check when filter_futures
                        # is False.
                    )
                    if predicate:
                        logging.debug(
                            "Adding pair: "
                            "OBS [{}-{}; len-{}], FUT [{}-{}; len-{}]".format(
                                s_obs.start, s_obs.end, s_obslen,
                                s_fut.start, s_fut.end, s_futlen)
                        )
                        s_pairs.append(
                            SequencePair(key, i, j, s_obslen, s_futlen, offset)
                        )
            group_pairs[key] = s_pairs
        return group_pairs

    def _sequences_for_group(
            self, key: Tuple[int, int], df: pd.DataFrame,
            seq_len: int, overlap: float
        ) -> List[GroupSequence]:
        """ Compute samples for a unique group (assumes contiguous chunk of data) """
        # key is a tuple of (group_id, group_size)
        _, gsize = key
        nrows = np.arange(len(df))
        # Minimum step size is equal to 1 timestep or gsize rows
        step_size = max(int(np.rint(((1 - overlap) * seq_len) * gsize)), gsize)
        nseq_rows = seq_len * gsize
        starts = nrows[::step_size]
        bounds = list(zip(starts, starts+nseq_rows))
        s_data = lambda b: df.iloc[b[0]:b[1]]
        # A sequence comprises of start_frame, end_frame, and data
        seqs = []
        for s in map(s_data, bounds):
            if len(s) >= nseq_rows: # Check there are sufficient rows
                values = s[self.feature_fields].values.astype(np.float32)
                values = values.reshape(-1, gsize, values.shape[-1])
                gseq = GroupSequence(
                    s["frame"].iloc[0], s["frame"].iloc[-1], values
                )
                seqs.append(gseq)
        return seqs

    def _compute_samples_for_df(
            self, df: pd.DataFrame, seq_len: int, overlap: float
        ) -> List[GroupSequence]:
        """ Compute sequences for specific dataframe """
        stride = self.hparams.time_stride
        # Sample every `stride` frames, starting from the frame in the
        # first row of the dataframe
        dfs = df[(df.frame - int(df.frame.iloc[0])) % stride == 0]
        seq_dict = OrderedDict()
        for g in dfs.groupby(SocialDataset.group_fields()):
            key, group_df = g
            seqs_for_group = []
            # Process contiguous chunks
            predicate = (group_df.frame - group_df.frame.shift() > stride)
            for _, chunk in group_df.groupby(predicate.cumsum()):
                seqs = self._sequences_for_group(key, chunk, seq_len, overlap)
                seqs_for_group.extend(seqs)
            seq_dict[key] = seqs_for_group
        return seq_dict

    def compute_samples(self, fix_future_len=False) -> None:
        """ Constructs the sample sequences from the dataframe """
        overlap = self.hparams.overlap
        obs_dict = self._compute_samples_for_df(self.obs_df, self.obs_len, overlap)
        overlap = 1 if self.hparams.all_futures else overlap
        fut_dict = self._compute_samples_for_df(self.fut_df, self.future_len, overlap)
        self.group_seqs = OrderedDict(
            {k:(obs_dict.get(k), fut_dict.get(k)) for k in self.group_keys}
        )
        # Construct (obs, future) pairs for each group
        group_pairs = self._construct_pairs(self.group_seqs, fix_future_len)
        self.pairs = list(chain(*group_pairs.values()))

    @cached_property
    def group_bucket_map(self) -> GroupBucketMap:
        """ Override SocialDatasetInterface.group_bucket_map()

        Map a tuple of (group_id, obs_len, fut_len) to a map of observed idx to sample indices
        """
        return compute_group_buckets(self.pairs)

    @cached_property
    def seq_bucket_map(self) -> SeqBucketMap:
        """ Override SocialDatasetInterface.seq_bucket_map()

        Map a tuple of (obs_len, fut_len) to a map of observed idx to sample indices
        """
        return compute_seq_buckets(self.pairs)

    def _convert_pair(self, pair: SequencePair, seqs: OrderedDict) -> Seq2SeqSamples:
        """ Convert a SequencePair to Seq2SeqSamples """
        # key is a tuple of (group_id, group_size)
        obs = seqs[pair.key][0][pair.obs_idx]
        fut = seqs[pair.key][1][pair.fut_idx]

        return Seq2SeqSamples(
            key=pair.key,
            observed_start=obs.start,
            observed=obs.features,
            future_len=fut.features.shape[0],
            offset=pair.offset,
            future=fut.features
        )

    def __getitem__(self, idx: int) -> Seq2SeqSamples:
        """ Returns Seq2SeqSamples for a single group

        Each tensor in the sample is of shape (seq_len, npeople, data_dim)

        """
        pair = self.pairs[idx]
        return self._convert_pair(pair, self.group_seqs)

    def __len__(self) -> int:
        return len(self.pairs)

    @staticmethod
    def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """ Add args pertaining to the model and training of the process """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--observed_len", type=int, default=10,
                            help="number of observed timesteps (# rows from the dataframe)")
        parser.add_argument("--future_len", type=int, default=10,
                            help="number of future timesteps to predict (# rows from the dataframe)")
        parser.add_argument("--time_stride", type=int, default=1,
                            help="sampling rate of rows in the dataset, expressed as multiple of frame difference "
                                 "between rows; for eg."
                                 "for a dataset where each row is 20 frames apart, a stride of 60 means every third "
                                 "row will be taken.")
        parser.add_argument("--overlap", type=float, default=0.8,
                            help="Overlap between observed sequences [0, 1)")
        parser.add_argument("--all_futures", default=False, action="store_true",
                            help="Take all future sequences within max offset "
                                 "instead of applying overlap")
        parser.add_argument("--max_future_offset", type=int, default=150,
                            help="maximum offset in frame values between the end of the observed "
                            "and beginning of the future sequence, inclusive")
        parser.add_argument("--batch_size", type=int, default=128,
                            help="size of the mini-batch")
        parser.add_argument("--fix_future_len", default=False,
                            action="store_true",
                            help="discard future sequences shorter than "
                            "specified future length")

        return parser


class SocialUnpairedContextDataset(SocialDataset):

    """ Compute seq2seq samples with past sequences as context

    The context sequences are not split into observed and future pairs. This
    is meant to be used for the SocialProcess model where the context future
    sequences are not used for encoding the latent representations. Hence, only
    unique context sequences are required. In the case where multiple futures
    are possible for a given observed sequence, splitting the context into
    (observed, future) pairs would result in duplicate observed sequences which
    is not desirable in this case. Refer `SocialPairedContextDataset` if
    (observed, future) pairs for the context sequences is desirable.

    """

    def __init__(
            self, obs_df: pd.DataFrame, ctx_df: pd.DataFrame,
            hparams: Namespace, feature_fields: List,
            fut_df: pd.DataFrame = None
        ) -> None:
        """ Initialize the dataset object

        Args:
            obs_df          --  Dataframe consisting of behavioral data for
                                groups comprising observed sequences for which
                                to predict the future
            ctx_df          --  Dataframe consisting of past behavioral data
                                for groups comprising the context at evaluation
            hparams         --  Parameters to process the data.
                                Refer `add_dataset_specific_args`
            feature_fields  --  The list of column names for the features
            fut_df          --  Dataframe consisting of behavioral data for
                                groups comprising future sequences

        """
        super().__init__(obs_df, hparams, feature_fields, fut_df)
        self.ctx_df = ctx_df
        ctx_keys = ctx_df.groupby(SocialDataset.group_fields()).groups.keys()
        assert (ctx_keys == self.group_keys), (
            "dataframes must contain same groups"
        )
        # Map group ids to list of context seqeuences
        self.ctx_data = OrderedDict()

    def compute_samples(self, fix_future_len=False) -> None:
        """ Compute observed, future target pairs and context sequences """
        ctx_seqs = self._compute_samples_for_df(self.ctx_df, self.obs_len,
                                                self.hparams.overlap)
        for key in ctx_seqs:
            # Filter context sequences to make sure they're of the same length
            seqs = ctx_seqs[key]
            filtered = [s for s in seqs if s.features.shape[0] ==  self.obs_len]
            ctx_seq_data = [s.features for s in filtered]
            ctx = np.stack(ctx_seq_data, axis=1)
            self.ctx_data[key] = ctx
        super().compute_samples(fix_future_len=fix_future_len)

    def __getitem__(self, idx: int) -> Tuple[Seq2SeqSamples, np.ndarray]:
        """ Returns Seq2SeqSamples and corresponding context sequences

        Each tensor in the sample is of shape (seq_len, group_size, data_dim)

        Returns:
            A tuple of the Seq2Seq samples and context data of shape
            (seq_len, ctx_size, group_size, data_dim)
        """
        samples = super().__getitem__(idx)
        context = self.ctx_data[samples.key]
        return (samples, context)


class SocialPairedContextDataset(SocialDataset):

    """ Compute evaluation seq2seq samples with past sequences as context

    The context sequences are split into observed and future pairs. Refer to
    `SocialEvalDataset` for a discussion about this.

    """

    def __init__(
            self, obs_df: pd.DataFrame, ctx_df: pd.DataFrame,
            hparams: Namespace, feature_fields: List,
            fut_df: pd.DataFrame = None
        ) -> None:
        """ Initialize the dataset object

        Args:
            obs_df          --  Dataframe consisting of behavioral data for
                                groups comprising observed sequences for which
                                to predict the future
            ctx_df          --  Dataframe consisting of past behavioral data
                                for groups comprising the context at evaluation
            hparams         --  Parameters to process the data.
                                Refer `add_dataset_specific_args`
            feature_fields  --  The list of column names for the features
            fut_df          --  Dataframe consisting of behavioral data for
                                groups comprising future sequences

        """
        super().__init__(obs_df, hparams, feature_fields, fut_df)
        self.ctx_df = ctx_df
        ctx_keys = ctx_df.groupby(SocialDataset.group_fields()).groups.keys()
        assert (ctx_keys == self.group_keys), (
            "dataframes must contain same groups"
        )
        # Map group ids to a single Seq2SeqSamples of all context sequences
        # for that group
        self.group_ctx = OrderedDict()

    def compute_samples(self, fix_future_len=False) -> None:
        """ Compute observed, future target pairs and context sequences """
        # Compute context sequences
        overlap = self.hparams.overlap
        ctx_obs = self._compute_samples_for_df(self.ctx_df, self.obs_len, overlap)
        overlap = 1 if self.hparams.all_futures else overlap
        ctx_fut = self._compute_samples_for_df(self.ctx_df, self.future_len, overlap)
        ctx_seqs = OrderedDict(
            {k:(ctx_obs.get(k), ctx_fut.get(k)) for k in self.group_keys}
        )
        # Construct context (obs, future) pairs for each group
        # ctx_pairs is an ordered dict of group key to list of SequencePair
        # ensuring all observed and future sequences are of same length
        ctx_pairs = self._construct_pairs(ctx_seqs, fix_future_len=True)
        # Collate all context pairs into a single Seq2SeqSample for each group
        for key in ctx_pairs:
            samples = [self._convert_pair(p, ctx_seqs) for p in ctx_pairs[key]]
            self.group_ctx[key] = samples
        super().compute_samples(fix_future_len=fix_future_len)

    def __getitem__(self, idx: int) -> Tuple[Seq2SeqSamples, List[Seq2SeqSamples]]:
        """ Returns Seq2SeqSamples and corresponding context sequences

        Each tensor in the sample is of shape (seq_len, group_size, data_dim)

        Returns:
            A tuple of the target and context Seq2Seq samples.
            The sequence of context observed and future pairs

        """
        samples = super().__getitem__(idx)
        context = self.group_ctx[samples.key]
        return (samples, context)
