from typing import List, Callable, Tuple, Dict, Union
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np


def preprocessing_windowing(data: List[pd.DataFrame], labels: List[pd.DataFrame], window_size: int):
    """Cut every frame (and its label frame) into overlapping row windows.

    Frames and label frames are paired positionally. For each pair, a window
    of ``window_size`` consecutive rows is taken at every start offset in
    ``range(rows - window_size)`` (stride 1), where ``rows`` is the row count
    of the data frame.

    NOTE(review): because the range bound is exclusive, the window starting
    at ``rows - window_size`` is not emitted; a frame with exactly
    ``window_size`` rows yields no windows. Confirm this is intended.

    Args:
        data: Data frames to slice.
        labels: Label frames, one per data frame, sliced at the same offsets.
        window_size: Number of rows per window.

    Returns:
        A ``(windowed_data, windowed_labels)`` pair of flat lists holding the
        windows of all frames in input order.
    """
    windowed_data, windowed_labels = [], []

    for frame, frame_labels in zip(data, labels):
        # Start offsets are driven by the data frame's row count only.
        starts = range(frame.shape[0] - window_size)
        windowed_data.extend(frame[start:start + window_size] for start in starts)
        windowed_labels.extend(frame_labels[start:start + window_size] for start in starts)

    return windowed_data, windowed_labels


class PWindowing(object):
    """Callable preprocessing step that windows frames with a fixed length.

    Wraps ``preprocessing_windowing`` so the window length can be configured
    once and the step then invoked as ``step(data, labels)``.
    """

    def __init__(self, window_length: int):
        super().__init__()
        # Number of rows per window, forwarded on every call.
        self.window_length = window_length

    def __call__(self, data: List[pd.DataFrame], labels: List[pd.DataFrame]) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
        """Window ``data`` and ``labels`` using the configured window length."""
        windowed = preprocessing_windowing(data, labels, self.window_length)
        return windowed


def preprocessing_min_max_scaling(data: List[pd.DataFrame], labels: List[pd.DataFrame]):
    """Min-max scale every frame independently, in place.

    A single ``MinMaxScaler`` (default settings) is refit on each frame via
    ``fit_transform``, so the scaling statistics are per frame and are not
    shared across the list. The columns of each input frame are overwritten
    with the scaled values.

    Args:
        data: Frames to scale; they are modified in place.
        labels: Passed through untouched.

    Returns:
        The same ``data`` list (now scaled) and the unchanged ``labels``.
    """
    scaler = MinMaxScaler()

    for frame in data:
        columns = frame.columns
        frame[columns] = scaler.fit_transform(frame[columns])

    return data, labels


class PMinMaxScaler(object):
    """Min-max scaler whose statistics are fitted once and then reused.

    The first call with a non-empty ``data`` list computes the per-column
    minimum and maximum over all frames in that list and stores them in
    ``self.min`` / ``self.max``. Every call (including the first) scales its
    frames with those stored statistics, so later calls are scaled with the
    statistics of the first one.
    """

    def __init__(self):
        super(PMinMaxScaler, self).__init__()

        # Per-column extrema; None until the first non-empty call fits them.
        self.min = None
        self.max = None

    def __call__(self, data : List[pd.DataFrame], labels : List[pd.DataFrame]) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
        """Scale ``data`` column-wise as ``(x - min) / (max - min)``.

        Args:
            data: Frames to scale. All frames are expected to share the
                column layout of ``data[0]``. The inputs are not modified.
            labels: Passed through untouched.

        Returns:
            A list of newly created scaled frames and the unchanged
            ``labels``. An empty ``data`` list yields an empty list and
            leaves the scaler unfitted.
        """
        # Nothing to fit on or to scale; avoid indexing data[0] below.
        if len(data) == 0:
            return [], labels

        if self.min is None:
            n_columns = data[0].shape[1]
            col_min = np.full(n_columns, np.inf)
            col_max = np.full(n_columns, -np.inf)
            for x in data:
                col_min = np.minimum(col_min, x.min(axis=0).to_numpy())
                col_max = np.maximum(col_max, x.max(axis=0).to_numpy())
            self.min, self.max = col_min, col_max

        # (Near-)constant columns would divide by zero and produce NaN/inf;
        # use a unit range for them so they map to 0, matching minmax_scaler.
        value_range = self.max - self.min
        value_range = np.where(value_range < 10 * np.finfo(np.float32).eps, 1.0, value_range)

        scaled_data = [(x - self.min) / value_range for x in data]

        return scaled_data, labels


def preprocessing_standardize(data: List[pd.DataFrame], labels: List[pd.DataFrame]):
    """Standardize every frame independently, in place.

    A single ``StandardScaler`` (default settings) is refit on each frame via
    ``fit_transform``, so the statistics are per frame and are not shared
    across the list. The columns of each input frame are overwritten with the
    standardized values.

    Args:
        data: Frames to standardize; they are modified in place.
        labels: Passed through untouched.

    Returns:
        The same ``data`` list (now standardized) and the unchanged ``labels``.
    """
    scaler = StandardScaler()

    for frame in data:
        columns = frame.columns
        frame[columns] = scaler.fit_transform(frame[columns])

    return data, labels


class CompositePrepsocesser(object):
    """Chain several preprocessing steps into a single callable.

    Each step is called as ``step(data, labels)`` and must return a
    ``(data, labels)`` pair, which is fed to the next step.
    """

    def __init__(self, preprocessing_functions: List[Callable]):
        # Steps are applied in list order.
        self.fns = preprocessing_functions

    def __call__(self, data: List[pd.DataFrame], labels: List[pd.DataFrame]):
        """Run all steps in order and return the final ``(data, labels)``."""
        current = data, labels

        for step in self.fns:
            # Unpack explicitly so a step returning anything other than a
            # two-item result fails right here.
            next_data, next_labels = step(*current)
            current = next_data, next_labels

        return current


def update_statistics_increment(frame: pd.DataFrame, mean: np.ndarray = None, min_val: np.ndarray = None,
                                max_val: np.ndarray = None, old_n: int = None) \
        -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """Fold one frame into running per-column mean / min / max statistics.

    Call it without the optional arguments for the first frame, then feed the
    returned values back in for every following frame.

    Args:
        frame: New batch of rows.
        mean: Running per-column mean so far, or ``None`` if there is none.
        min_val: Running per-column minimum so far, or ``None``.
        max_val: Running per-column maximum so far, or ``None``.
        old_n: Number of rows already folded into the running statistics;
            ``None`` is treated as 0.

    Returns:
        ``(mean, min_val, max_val, n_total)`` where ``n_total`` is ``old_n``
        plus the number of rows in ``frame``.
    """
    n = frame.shape[0]
    # The default of None means "no rows seen yet"; without this the final
    # ``old_n + n`` raised TypeError on the very first call.
    if old_n is None:
        old_n = 0
    total = old_n + n

    frame_mean = frame.mean().to_numpy()
    if mean is not None and old_n > 0:
        # Weighted combination of the previous mean and this frame's mean.
        mean = (old_n / total) * mean + (n / total) * frame_mean
    else:
        # No prior rows carry weight, so the frame's mean is the running mean
        # (this also avoids dividing by a zero total).
        mean = frame_mean

    frame_min = frame.min().to_numpy()
    min_val = frame_min if min_val is None else np.minimum(min_val, frame_min)

    frame_max = frame.max().to_numpy()
    max_val = frame_max if max_val is None else np.maximum(max_val, frame_max)

    return mean, min_val, max_val, total


def save_statistics(frame: pd.DataFrame, path: str):
    """Compute per-column summary statistics of ``frame`` and save them.

    The mean, standard deviation, minimum, maximum and median are computed
    with the corresponding ``DataFrame`` methods (default arguments) and
    written with ``np.savez`` under the keys ``mean``, ``std``, ``min``,
    ``max`` and ``median``.

    Args:
        frame: Data to summarize.
        path: Destination handed to ``np.savez``.
    """
    summaries = {
        'mean': frame.mean(),
        'std': frame.std(),
        'min': frame.min(),
        'max': frame.max(),
        'median': frame.median(),
    }
    arrays = {key: series.to_numpy() for key, series in summaries.items()}

    np.savez(path, **arrays)


def minmax_scaler(frame: Union[pd.DataFrame, np.ndarray], stats: Dict[str, np.ndarray]) -> pd.DataFrame:
    """Min-max scale ``frame`` with precomputed per-column statistics.

    Computes ``(frame - min) / (max - min)`` using ``stats['min']`` and
    ``stats['max']``. Columns whose range is below
    ``10 * np.finfo(np.float32).eps`` are treated as (near-)constant and
    divided by 1 instead, which avoids dividing by zero.

    Args:
        frame: Values to scale.
        stats: Mapping that provides the ``'min'`` and ``'max'`` arrays.

    Returns:
        The scaled values; the arrays stored in ``stats`` are not modified.
    """
    lower = stats['min']
    # The subtraction creates a fresh array, so patching it below does not
    # touch the arrays stored in ``stats``.
    span = stats['max'] - lower

    # Fix (near-)constant features
    tolerance = 10 * np.finfo(np.float32).eps
    span[span < tolerance] = 1

    return (frame - lower) / span
