# Original implementation: https://github.com/BorgwardtLab/Set_Functions_for_Time_Series
#
# Copyright 2020 Max Horn
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
# TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from collections import defaultdict
from collections.abc import Sequence

import os
import json

import numpy as np
import tensorflow_datasets as tfds
import medical_ts_datasets

from tqdm.auto import tqdm


# Repository root: the parent of the directory that contains this script.
PROJECT_DIR = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)

# Default destination for the generated per-dataset statistics JSON files.
STATISTICS_DIR = os.path.join(PROJECT_DIR, "missing", "data", "statistics")


def _accumulate_statistics(examples):
    """Accumulate per-feature counts, sums and sums of squares over a dataset.

    Args:
        examples: Iterable of ``(features, label)`` pairs. ``features`` is a
            5-tuple whose first element is the demographics vector and whose
            third element is the time-series value array; the remaining
            elements are ignored. The value array is reduced over axis 0 with
            features on its last axis (presumably time steps x features --
            confirm against the dataset builder); non-finite entries are
            treated as unobserved.

    Returns:
        A tuple ``(ts_stats, demo_stats, class_samples)``. ``ts_stats`` and
        ``demo_stats`` are ``(count, sum_x, sum_sq_x)`` triples. For the time
        series the count is a per-feature array of finite observations; for
        the demographics it is the number of examples. ``class_samples`` maps
        each integer label to its number of examples.

    Raises:
        ValueError: If ``examples`` yields nothing.
    """
    ts_count = ts_sum_x = ts_sum_sq_x = None
    demo_count = 0
    demo_sum_x = demo_sum_sq_x = None
    class_samples = defaultdict(int)

    for (demo, _, ts, _, _), label in examples:
        # Array/sequence labels are collapsed to a single class via their maximum.
        if isinstance(label, (Sequence, np.ndarray)):
            label = np.amax(label)
        class_samples[int(label)] += 1

        if ts_count is None:
            # Size the accumulators from the first example.
            ts_count = np.zeros(ts.shape[-1])
            ts_sum_x = np.zeros(ts.shape[-1])
            ts_sum_sq_x = np.zeros(ts.shape[-1])
            demo_sum_x = np.zeros(demo.shape[-1])
            demo_sum_sq_x = np.zeros(demo.shape[-1])

        # Every example -- including the one used to size the buffers -- is
        # accumulated. Non-finite entries are excluded from the count and from
        # the sums alike, so the two always describe the same values.
        finite = np.isfinite(ts)
        observed = np.where(finite, ts, 0.0)
        ts_count += np.sum(finite, axis=0)
        ts_sum_x += np.sum(observed, axis=0)
        ts_sum_sq_x += np.sum(observed ** 2, axis=0)

        demo_count += 1
        demo_sum_x += demo
        demo_sum_sq_x += demo ** 2

    if ts_count is None:
        raise ValueError("Cannot compute statistics of an empty dataset.")

    return (
        (ts_count, ts_sum_x, ts_sum_sq_x),
        (demo_count, demo_sum_x, demo_sum_sq_x),
        dict(class_samples),
    )


def _mean_and_std(count, sum_x, sum_sq_x, categorical, eps=1e-7):
    """Convert accumulated moments into per-feature means and standard deviations.

    The standard deviation uses the unbiased (``n - 1``) estimator. Any std
    that is below ``eps`` or not finite is replaced by ``eps``, so that later
    division by it is safe. A non-finite std arises for features observed
    fewer than twice, where the ``n - 1`` denominator is zero. Features never
    observed keep a NaN mean (0 / 0). Entries selected by ``categorical`` get
    mean 0 and std 1 so that normalization leaves them unchanged.

    Args:
        count: Number of observations, scalar or per feature.
        sum_x: Per-feature sum of observations.
        sum_sq_x: Per-feature sum of squared observations.
        categorical: Index into the feature axis (e.g. a boolean mask)
            selecting categorical features.
        eps: Lower bound for the returned standard deviations.

    Returns:
        ``(means, stds)`` as float arrays shaped like ``sum_x``.
    """
    means = sum_x / count
    # Algebraically sum((x - mean)^2) / (n - 1).
    variance = (sum_sq_x - 2.0 * sum_x * means + count * means ** 2) / (count - 1)
    # Rounding in the difference of large sums can make the variance slightly
    # negative; clamp it so sqrt does not produce NaN.
    stds = np.sqrt(np.maximum(variance, 0.0))
    stds[~np.isfinite(stds) | (stds < eps)] = eps

    # Categorical variables are not normalized.
    means[categorical] = 0.
    stds[categorical] = 1.
    return means, stds


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Compute per-feature normalization statistics and the class "
                    "balance of a tensorflow_datasets split and write them as JSON."
    )
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="Name of the registered tfds dataset.")
    parser.add_argument("-o", "--output", type=str, default=None,
                        help="Output JSON path (default: <STATISTICS_DIR>/<name>_v<version>.json).")
    parser.add_argument("-s", "--split", type=str, default=tfds.Split.TRAIN,
                        help="Dataset split to compute the statistics on.")
    args = parser.parse_args()

    train_ds, info = tfds.load(
        args.dataset,
        split=(tfds.Split.TRAIN if args.split is None else args.split),
        as_supervised=True,
        with_info=True,
    )

    if args.output is None:
        args.output = os.path.join(STATISTICS_DIR, f"{info.name}_v{info.version}.json")

    # os.makedirs("") raises, so only create a directory when the path has one
    # (a bare filename means the current directory).
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    examples = tqdm(tfds.as_numpy(train_ds), unit=" ex", desc="Computing", ncols=80)
    ts_stats, demo_stats, class_samples = _accumulate_statistics(examples)

    num_total_samples = sum(class_samples.values())
    class_balance = {
        label: num_samples / num_total_samples
        for label, num_samples in class_samples.items()
    }

    ts_means, ts_stds = _mean_and_std(
        *ts_stats, categorical=info.metadata["combined_categorical_indicator"]
    )
    demo_means, demo_stds = _mean_and_std(
        *demo_stats, categorical=info.metadata["demographics_categorical_indicator"]
    )

    with open(args.output, "w") as f:
        json.dump({
            "series_means": ts_means.tolist(),
            "series_stds": ts_stds.tolist(),
            "demo_means": demo_means.tolist(),
            "demo_stds": demo_stds.tolist(),
            "class_balance": class_balance
        }, f, indent=4)
