import os
from glob import glob
from multiprocessing import Pool

import numpy as np
import pandas as pd


def rolling_window(data, seq_len):
    """
    Cut every sequence in ``data`` into all overlapping windows of length ``seq_len``.

    Input: N sequences, each sliceable along its first axis with length T_i, e.g. an
        (N, T, D) array or a list of (T_i, D) arrays (T_i may differ per sequence).
    Output: array of shape (W, seq_len, ...) with W = sum(max(T_i - seq_len + 1, 0)).
        Windows use stride 1 and are ordered by sequence, then by start offset.

    Raises:
        ValueError: if ``seq_len`` < 1, or if no sequence is long enough to yield
            a single window (including empty ``data``).
    """
    if seq_len < 1:
        raise ValueError("seq_len must be >= 1, got {!r}".format(seq_len))

    windows = []
    for sequence in data:
        # A sequence shorter than seq_len contributes no windows (empty range).
        windows.extend(
            sequence[start:start + seq_len]
            for start in range(len(sequence) - seq_len + 1)
        )

    if not windows:
        # np.stack would otherwise fail with an opaque "need at least one array to stack".
        raise ValueError(
            "No windows of length {} could be extracted: data is empty or every "
            "sequence is shorter than seq_len.".format(seq_len)
        )
    return np.stack(windows, axis=0)


def _minmax_apply(arr, d_min, d_max, min_v, max_v):
    """Map ``arr`` from [d_min, d_max] onto [min_v, max_v] element-wise (broadcasting).

    Where the observed range is zero (d_max == d_min) the span is replaced by 1,
    so such features map to ``min_v`` instead of producing NaN/inf.
    """
    span = d_max - d_min
    span = np.where(span == 0, 1, span)
    return ((arr - d_min) / span) * (max_v - min_v) + min_v


def minmaxscale(data, min_v=0, max_v=1):
    """
    Min-max scale ``data`` per feature (last axis) into the range [min_v, max_v].

    The per-feature minimum and maximum are taken jointly over all sequences and
    time steps, so every sequence is scaled with the same statistics.

    Supported inputs:
        list of arrays: N sequences, each (T_i, D); T_i may differ per sequence.
            Returns a list of scaled arrays with the same shapes.
        2-D ndarray (T, D): treated as a single sequence. NOTE: the result gains a
            leading axis and has shape (1, T, D).
        3-D ndarray (N, T, D): returns an array of the same shape.

    Features that are constant over the data (zero range) are mapped to ``min_v``.

    Raises:
        NotImplementedError: for any other input type or array rank.
    """
    if isinstance(data, list):
        # Reduce each sequence over time first, then across sequences, so ragged
        # lengths are supported.
        d_min = np.min([np.min(a, axis=0) for a in data], axis=0)
        d_max = np.max([np.max(a, axis=0) for a in data], axis=0)
        return [_minmax_apply(arr, d_min, d_max, min_v, max_v) for arr in data]

    if isinstance(data, np.ndarray) and data.ndim in (2, 3):
        if data.ndim == 2:
            # Promote a single (T, D) sequence to (1, T, D); the result keeps this axis.
            data = np.expand_dims(data, 0)
        # keepdims -> (1, 1, D) statistics that broadcast over (N, T, D).
        d_min = np.min(data, axis=(0, 1), keepdims=True)
        d_max = np.max(data, axis=(0, 1), keepdims=True)
        return _minmax_apply(data, d_min, d_max, min_v, max_v)

    raise NotImplementedError("Could not figure out how to min-max scale the data")


def load_dataset_from_str(dataset_str, prefix):
    """
    Dispatch to the loader registered for ``dataset_str`` and return its result.

    Args:
        dataset_str: dataset name; one of 'geolife-25', 'geolife-100',
            'porto-25', 'porto-100'.
        prefix: forwarded unchanged to the selected loader.

    Raises:
        NotImplementedError: if ``dataset_str`` matches none of the supported names.
    """
    # Matched with ``==`` in order rather than via a dict lookup, so values that are
    # unhashable still fall through to the NotImplementedError below.
    registry = (
        ('geolife-25', load_geolife_25),
        ('geolife-100', load_geolife_100),
        ('porto-25', load_porto_25),
        ('porto-100', load_porto_100),
    )
    for known_name, loader in registry:
        if dataset_str == known_name:
            return loader(prefix)

    raise NotImplementedError("Dataset: " + str(dataset_str) + " is not a supported dataset.")

def _load_preprocessed_train(prefix, dataset_name):
    """Return ``np.load`` of ``<prefix>/preprocessed/<dataset_name>-train.npy``.

    ``prefix`` may be a str or os.PathLike, with or without a trailing path
    separator; an empty prefix resolves relative to the current directory.
    """
    path = os.path.join(prefix, 'preprocessed', dataset_name + '-train.npy')
    return np.load(path)


def load_geolife_25(prefix):
    """Load ``preprocessed/geolife-25-train.npy`` located under ``prefix``."""
    return _load_preprocessed_train(prefix, 'geolife-25')


def load_geolife_100(prefix):
    """Load ``preprocessed/geolife-100-train.npy`` located under ``prefix``."""
    return _load_preprocessed_train(prefix, 'geolife-100')


def load_porto_25(prefix):
    """Load ``preprocessed/porto-25-train.npy`` located under ``prefix``."""
    return _load_preprocessed_train(prefix, 'porto-25')


def load_porto_100(prefix):
    """Load ``preprocessed/porto-100-train.npy`` located under ``prefix``."""
    return _load_preprocessed_train(prefix, 'porto-100')