#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: run_saliency_gt.py
# Created Date: Saturday, October 17th 2020, 6:31:38 pm
# Author: Chirag Raman
#
# Copyright (c) 2020 Chirag Raman
###

import argparse
from argparse import Namespace
from collections import OrderedDict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Trainer, seed_everything
from torch.distributions import Normal
from torch.utils.data import DataLoader

from data.datasets import SocialUnpairedContextDataset
from data.loader import BucketSampler, collate_unpaired_context, collate_seq2seq
from data.types import Seq2SeqSamples
from explainability.saliency import sequence_saliency


# Feature column names handed to SocialUnpairedContextDataset when the
# per-turn dataset is built in seqs_for_turn.
# NOTE(review): presumably an orientation quaternion (qw..qz), a position
# (tx..tz) and a speaking flag -- confirm against the dataset schema.
FEATURE_COLUMNS = ["qw", "qx", "qy", "qz", "tx", "ty", "tz", "speaking"]


def init_torch(seed: int) -> None:
    """ Seed torch's global RNG and switch on the cuDNN deterministic flag

    Args:
        seed    --  value forwarded to ``torch.manual_seed``

    """
    # The two settings are independent of each other
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)


def _get_df(
        df: pd.DataFrame, group_ids: List, frame_start: int, frame_end: int
    ) -> pd.DataFrame:
    """ Select the rows of the given groups that fall inside a frame window

    The window is half-open: rows with ``frame == frame_start`` are kept,
    rows with ``frame == frame_end`` are dropped.

    Args:
        df          --  data frame with ``group_id`` and ``frame`` columns
        group_ids   --  group ids to keep
        frame_start --  first frame of the window (inclusive)
        frame_end   --  end of the window (exclusive)

    Returns:
        The matching rows of ``df``, selected with ``df.loc``

    """
    in_groups = df.group_id.isin(group_ids)
    frames = df.frame
    in_window = (frames >= frame_start) & (frames < frame_end)
    return df.loc[in_groups & in_window]


def seqs_for_turn(df: pd.DataFrame, group_id: int, obs_start: int,
                  obs_end: int, fut_start: int, fut_end: int,
                  args: Namespace) -> List[Seq2SeqSamples]:
    """ Build the seq2seq samples of one turn, sorted by observed start

    Args:
        df          --  full data frame to slice the turn from
        group_id    --  group whose rows are used
        obs_start   --  first frame of the observed window (inclusive)
        obs_end     --  end of the observed window (exclusive)
        fut_start   --  first frame of the future window (inclusive)
        fut_end     --  end of the future window (exclusive)
        args        --  parsed arguments; passed on to the dataset, and
                        ``args.batch_size`` configures the sampler

    Returns:
        A list with one Seq2SeqSamples per sequence of the first batch,
        ordered by ``observed_start``

    """
    group = [group_id]
    observed_rows = _get_df(df, group, obs_start, obs_end)
    future_rows = _get_df(df, group, fut_start, fut_end)

    # The observed rows are passed a second time as a dummy context
    dataset = SocialUnpairedContextDataset(
        observed_rows, observed_rows, args, FEATURE_COLUMNS, future_rows
    )
    dataset.compute_samples(fix_future_len=True)

    # Only the first batch yielded by the loader is consumed
    loader = DataLoader(
        dataset,
        batch_sampler=BucketSampler(dataset, args.batch_size),
        collate_fn=collate_unpaired_context
    )
    batch = next(iter(loader)).target

    # Cut the batch into size-1 chunks along dim 1 and drop that dim again,
    # keying every resulting sample by its observed start.
    # NOTE(review): presumably dim 1 indexes the sequences of the batch --
    # confirm against collate_unpaired_context.
    by_start = {
        start: Seq2SeqSamples(
            key=(group_id,), observed_start=start,
            observed=obs.squeeze(1), future_len=batch.future_len,
            offset=off, future=fut.squeeze(1)
        )
        for obs, start, off, fut in zip(
            batch.observed.split(1, dim=1),
            batch.observed_start,
            batch.offset,
            batch.future.split(1, dim=1)
        )
    }

    return [sample for _, sample in sorted(by_start.items())]


def _repack(
        seq_samples: List[Seq2SeqSamples], gts: List[Normal]
    ) -> Tuple[Seq2SeqSamples, Normal]:
    """ Collate the samples and fuse the per-sequence Normals into one

    Args:
        seq_samples --  individual samples, handed to ``collate_seq2seq``
        gts         --  one Normal per sequence

    Returns:
        The collated samples, and a single Normal whose ``loc`` and
        ``scale`` are the per-sequence parameters stacked along dim 1
        with a new leading dim of size 1

    """
    inputs = collate_seq2seq(seq_samples)
    loc = torch.stack([dist.loc for dist in gts], dim=1)
    scale = torch.stack([dist.scale for dist in gts], dim=1)
    return inputs, Normal(loc.unsqueeze(0), scale.unsqueeze(0))


def _round_to(values: torch.Tensor, decimals: int) -> torch.Tensor:
    """ Round a tensor to a fixed number of decimals with torch ops only """
    factor = 10.0 ** decimals
    return torch.round(values * factor) / factor


def main() -> None:
    """ Run the main experiment

    Parses the command line, seeds torch, reads the dataset csv from the
    ``dataset`` directory next to this file's parent directory and builds
    the sequences of two turns of one group. The sequences of both turns
    are paired in order. For a pair whose observed sequences agree to 4
    decimals, the ground truth future of both turns is a Normal with the
    mean and standard deviation of the two futures; otherwise each turn
    gets a Normal centred on its own future with a tiny scale. Finally the
    entropy and saliency returned by ``sequence_saliency`` for turn 2 are
    printed.

    Raises:
        ValueError  --  if two paired sequences have different offsets

    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-s", "--seed", type=int, default=1234,
                        help="seed for initializing pytorch")
    # The options below are dereferenced unconditionally further down, so
    # let argparse reject a missing one with a usage message instead of
    # failing later on a None value
    parser.add_argument("--csv", type=str, required=True,
                        help="filename of the sythetic dataset csv")
    parser.add_argument("--group_id", type=int, required=True,
                        help="group id of interest")
    parser.add_argument("--obs_window1", type=int, nargs=2, required=True,
                        help="window from which to compute target "
                             "observed sequences")
    parser.add_argument("--fut_window1", type=int, nargs=2, required=True,
                        help="window from which to compute target "
                             "future sequences")
    parser.add_argument("--obs_window2", type=int, nargs=2, required=True,
                        help="window from which to compute target "
                             "observed sequences")
    parser.add_argument("--fut_window2", type=int, nargs=2, required=True,
                        help="window from which to compute target "
                             "future sequences")
    #-----------------------
    parser = Trainer.add_argparse_args(parser)
    parser = SocialUnpairedContextDataset.add_dataset_specific_args(parser)
    args = parser.parse_args()

    # Initialize pytorch
    init_torch(args.seed)
    seed_everything(args.seed)

    # Setup the paths
    dataset_dir = (Path(__file__).resolve().parent.parent / "dataset/")

    # Configure the test data
    data_file = dataset_dir / args.csv
    df = pd.read_csv(str(data_file), sep=",")

    # Get sequences for both turns
    seqs1 = seqs_for_turn(df, args.group_id, args.obs_window1[0],
                          args.obs_window1[1], args.fut_window1[0],
                          args.fut_window1[1], args)

    seqs2 = seqs_for_turn(df, args.group_id, args.obs_window2[0],
                          args.obs_window2[1], args.fut_window2[0],
                          args.fut_window2[1], args)

    # Iterate over the individual sequences and compute the future
    # distributions for observed sequences in both turns.
    # NOTE(review): zip stops at the shorter list; if turn 2 has more
    # sequences than turn 1, the trailing ones get no ground truth.
    # turn1_gts is filled alongside turn2_gts but is not read below.
    turn1_gts = []
    turn2_gts = []
    for s1, s2 in zip(seqs1, seqs2):
        # Explicit check rather than assert, so it survives ``python -O``
        if not (s1.offset == s2.offset):
            raise ValueError(
                "offset mismatch between paired sequences: "
                f"{s1.offset} != {s2.offset}"
            )

        # Compare the observed sequences after rounding to 4 decimals.
        # The rounding stays in torch so it does not depend on numpy being
        # able to wrap tensors.
        same_observed = torch.eq(
            _round_to(s1.observed, 4), _round_to(s2.observed, 4)
        ).all()
        if same_observed:
            # Construct Normal distributions from futures of both turns
            stacked_fut = torch.stack((s1.future, s2.future))
            mean = torch.mean(stacked_fut, dim=0)
            std = torch.std(stacked_fut, dim=0)
            # Add epsilon to make sure no zeros
            std += np.finfo(np.float32).eps
            fut_normal = Normal(mean, std)

            # Append the same distribution to both turns
            turn1_gts.append(fut_normal)
            turn2_gts.append(fut_normal)
        else:
            # Construct Normal distribution from futures of individual
            # turns, using a tiny fixed scale around each future.
            # (An earlier variant used sigma = 1 / np.sqrt(2 * np.pi).)
            sigma = 1e-10
            f1_normal = Normal(loc=s1.future,
                               scale=torch.full_like(s1.future, sigma))
            turn1_gts.append(f1_normal)

            f2_normal = Normal(loc=s2.future,
                               scale=torch.full_like(s2.future, sigma))
            turn2_gts.append(f2_normal)

    turn2_inp, turn2_gt = _repack(seqs2, turn2_gts)
    entropies, saliency, _ = sequence_saliency(turn2_inp, turn2_gt)
    # Pair entropy and saliency by position; this relies on both mappings
    # iterating in the same key order
    values = zip(entropies.values(), saliency.values())
    map_dict = OrderedDict(zip(entropies.keys(), values))

    # Span added to the observed start to display the observed end.
    # NOTE(review): hard-coded; presumably tied to the observed sequence
    # length configured for the dataset -- confirm.
    obs_span = 27

    print("\nSALIENCY TURN 2 : [obs start - obs end], offset : entropy; saliency\n")
    for key, value in sorted(map_dict.items()):
        print(f"[{key[0]} - {key[0] + obs_span}], offset={key[1]} : {value[0]:.5f}; {value[1]:.5f}")


# Run the experiment only when executed as a script, not on import
if __name__ == "__main__":
    main()
