# Copyright 2021 The ODE-LSTM Authors. All Rights Reserved.

import numpy as np
import os


class Walker2dImitationData:
    def __init__(self, seq_len):
        self.seq_len = seq_len
        all_files = sorted(
            [
                os.path.join("data/walker", d)
                for d in os.listdir("data/walker")
                if d.endswith(".npy")
            ]
        )

        self.rng = np.random.RandomState(891374)
        np.random.RandomState(125487).shuffle(all_files)
        # 15% test set, 10% validation set, the rest is for training
        test_n = int(0.15 * len(all_files))
        valid_n = int((0.15 + 0.1) * len(all_files))
        test_files = all_files[:test_n]
        valid_files = all_files[test_n:valid_n]
        train_files = all_files[valid_n:]

        train_x, train_t, train_y = self._load_files(train_files)
        valid_x, valid_t, valid_y = self._load_files(valid_files)
        test_x, test_t, test_y = self._load_files(test_files)

        train_x, train_t, train_y = self.perturb_sequences(train_x, train_t, train_y)
        valid_x, valid_t, valid_y = self.perturb_sequences(valid_x, valid_t, valid_y)
        test_x, test_t, test_y = self.perturb_sequences(test_x, test_t, test_y)

        self.train_x, self.train_times, self.train_y = self.align_sequences(
            train_x, train_t, train_y
        )
        self.valid_x, self.valid_times, self.valid_y = self.align_sequences(
            valid_x, valid_t, valid_y
        )
        self.test_x, self.test_times, self.test_y = self.align_sequences(
            test_x, test_t, test_y
        )
        self.input_size = self.train_x.shape[-1]

        print("train_times: ", str(self.train_times.shape))
        print("train_x: ", str(self.train_x.shape))
        print("train_y: ", str(self.train_y.shape))

    def align_sequences(self, set_x, set_t, set_y):

        times = []
        x = []
        y = []
        for i in range(len(set_y)):

            seq_x = set_x[i]
            seq_t = set_t[i]
            seq_y = set_y[i]

            for t in range(0, seq_y.shape[0] - self.seq_len, self.seq_len // 4):
                x.append(seq_x[t : t + self.seq_len])
                times.append(seq_t[t : t + self.seq_len])
                y.append(seq_y[t : t + self.seq_len])

        return (
            np.stack(x, axis=0),
            np.expand_dims(np.stack(times, axis=0), axis=-1),
            np.stack(y, axis=0),
        )

    def perturb_sequences(self, set_x, set_t, set_y):

        x = []
        times = []
        y = []
        for i in range(len(set_y)):

            seq_x = set_x[i]
            seq_y = set_y[i]

            new_x, new_times = [], []
            new_y = []

            skip = 0
            for t in range(seq_y.shape[0]):
                skip += 1
                if self.rng.rand() < 0.9:
                    new_x.append(seq_x[t])
                    new_times.append(skip)
                    new_y.append(seq_y[t])
                    skip = 0

            x.append(np.stack(new_x, axis=0))
            times.append(np.stack(new_times, axis=0))
            y.append(np.stack(new_y, axis=0))

        return x, times, y

    def _load_files(self, files):
        all_x = []
        all_t = []
        all_y = []
        for f in files:

            arr = np.load(f)
            x_state = arr[:-1, :].astype(np.float32)
            y = arr[1:, :].astype(np.float32)

            x_times = np.ones(x_state.shape[0])
            all_x.append(x_state)
            all_t.append(x_times)
            all_y.append(y)

            print("Loaded file '{}' of length {:d}".format(f, x_state.shape[0]))
        return all_x, all_t, all_y





class PersonData:
    class_map = {
        "lying down": 0,
        "lying": 0,
        "sitting down": 1,
        "sitting": 1,
        "standing up from lying": 2,
        "standing up from sitting": 2,
        "standing up from sitting on the ground": 2,
        "walking": 3,
        "falling": 4,
        "on all fours": 5,
        "sitting on the ground": 6,
    }

    sensor_ids = {
        "010-000-024-033": 0,
        "010-000-030-096": 1,
        "020-000-033-111": 2,
        "020-000-032-221": 3,
    }

    def __init__(self, seq_len=32):

        self.seq_len = seq_len
        self.num_classes = 7
        all_x, all_t, all_y = self.load_crappy_formated_csv()
        all_x, all_t, all_y = self.cut_in_sequences(
            all_x, all_t, all_y, seq_len=seq_len, inc=seq_len // 2
        )

        print("all_x.shape: ", str(all_x.shape))
        print("all_t.shape: ", str(all_t.shape))
        print("all_y.shape: ", str(all_y.shape))
        total_seqs = all_x.shape[0]
        print("Total number of sequences: {}".format(total_seqs))
        permutation = np.random.RandomState(98841).permutation(total_seqs)
        test_size = int(0.2 * total_seqs)

        self.test_x = all_x[permutation[:test_size]]
        self.test_y = all_y[permutation[:test_size]]
        self.test_t = all_t[permutation[:test_size]]
        self.train_x = all_x[permutation[test_size:]]
        self.train_t = all_t[permutation[test_size:]]
        self.train_y = all_y[permutation[test_size:]]

        self.feature_size = int(self.train_x.shape[-1])

        print("train_x.shape: ", str(self.train_x.shape))
        print("train_t.shape: ", str(self.train_t.shape))
        print("train_y.shape: ", str(self.train_y.shape))
        print("Total number of train sequences: {}".format(self.train_x.shape[0]))
        print("Total number of test  sequences: {}".format(self.test_x.shape[0]))

    def load_crappy_formated_csv(self):

        all_x = []
        all_y = []
        all_t = []

        series_x = []
        series_t = []
        series_y = []

        last_millis = None
        if not os.path.isfile("data/person/ConfLongDemo_JSI.txt"):
            print("ERROR: File 'data/person/ConfLongDemo_JSI.txt' not found")
            print("Please execute the command")
            print("source download_dataset.sh")
            import sys

            sys.exit(-1)
        with open("data/person/ConfLongDemo_JSI.txt", "r") as f:
            current_person = "A01"

            for line in f:
                arr = line.split(",")
                if len(arr) < 6:
                    break
                if arr[0] != current_person:
                    # Enque and reset
                    series_x = np.stack(series_x, axis=0)
                    series_t = np.stack(series_t, axis=0)
                    series_y = np.array(series_y, dtype=np.int32)
                    all_x.append(series_x)
                    all_t.append(series_t)
                    all_y.append(series_y)
                    last_millis = None
                    series_x = []
                    series_y = []
                    series_t = []

                millis = np.int64(arr[2]) / (100 * 1000)
                # 100ms will be normalized to 1.0
                millis_mapped_to_1 = 10.0
                if last_millis is None:
                    elasped_sec = 0.05
                else:
                    elasped_sec = float(millis - last_millis) / 1000.0
                elasped = elasped_sec * 1000 / millis_mapped_to_1

                last_millis = millis
                current_person = arr[0]
                sensor_id = self.sensor_ids[arr[1]]
                label_col = self.class_map[arr[7].replace("\n", "")]
                feature_col_2 = np.array(arr[4:7], dtype=np.float32)
                # Last 3 entries of the feature vector contain sensor value

                # First 4 entries of the feature vector contain sensor ID
                feature_col_1 = np.zeros(4, dtype=np.float32)
                feature_col_1[sensor_id] = 1

                feature_col = np.concatenate([feature_col_1, feature_col_2])
                series_x.append(feature_col)
                series_t.append(elasped)
                series_y.append(label_col)

        return all_x, all_t, all_y

    def cut_in_sequences(self, all_x, all_t, all_y, seq_len, inc=1):

        sequences_x = []
        sequences_t = []
        sequences_y = []

        for i in range(len(all_x)):
            x, t, y = all_x[i], all_t[i], all_y[i]

            for s in range(0, x.shape[0] - seq_len, inc):
                start = s
                end = start + seq_len
                sequences_x.append(x[start:end])
                sequences_t.append(t[start:end])
                sequences_y.append(y[start:end])

        return (
            np.stack(sequences_x, axis=0),
            np.stack(sequences_t, axis=0).reshape([-1, seq_len, 1]),
            np.stack(sequences_y, axis=0),
        )


