import os
import math
import pickle
import numpy as np

import matplotlib
from matplotlib import pyplot as plt


def gauss(n=11, sigma=1):
    """
    Build a discrete Gaussian kernel used to smooth out a distribution.

    The kernel is sampled at the integer offsets -int(n/2) .. int(n/2)
    inclusive, so it always holds 2 * int(n/2) + 1 taps. Each tap is the
    Gaussian density with standard deviation ``sigma`` evaluated at that
    offset; the taps are not re-normalised to sum to one.

    :param n: nominal kernel width
    :param sigma: standard deviation of the Gaussian
    :return: list of kernel taps (floats), centred on offset 0
    """

    def density(offset):
        # Gaussian density at a single integer offset from the centre.
        return 1 / (sigma * math.sqrt(2 * math.pi)) * math.exp(-float(offset) ** 2 / (2 * sigma ** 2))

    half_width = int(n / 2)
    return list(map(density, range(-half_width, half_width + 1)))



def estimate_kalman_filter(history, prediction_horizon):
    """
    Predict future values of a batch of 1-d series with a Kalman filter.

    Each series is filtered over its observed history using a fixed velocity
    (the mean first difference of that series). The filtered state at the last
    observed step is then extrapolated linearly, one velocity step per future
    time step.

    :param history: 2d array of shape (batch, length_of_history) with
        length_of_history >= 2
    :param prediction_horizon: how many steps in the future to predict
    :return: array of shape (batch, prediction_horizon); column ``i`` holds
        the prediction ``i + 1`` steps after the last observation
    :raises ValueError: if ``history`` is not 2d or has fewer than two steps
    """

    history = np.asarray(history)
    if history.ndim != 2:
        raise ValueError(
            'history must be 2d (batch, length_of_history), got shape %s' % (history.shape,))

    batch_size, length_history = history.shape
    if length_history < 2:
        # A velocity cannot be estimated from a single observation.
        raise ValueError('history must contain at least two time steps')

    z_x = history

    # Per-series velocity: mean of the consecutive differences. It is kept
    # fixed for the whole run; the filter never corrects it.
    v_x = np.sum(z_x[:, 1:] - z_x[:, :-1], axis=1) / (length_history - 1)

    # Filtered state and the position / velocity variances, one column per step.
    x_x = np.zeros((batch_size, length_history), np.float32)
    P_x = np.zeros((batch_size, length_history), np.float32)
    P_vx = np.zeros((batch_size, length_history), np.float32)

    # we initialize the uncertainty to one (unit gaussian)
    P_x[:, 0] = 1.0
    P_vx[:, 0] = 1.0
    x_x[:, 0] = z_x[:, 0]

    # Q is added to both variances on every predict step; R enters the gain
    # denominators (process / measurement noise in the usual Kalman notation).
    Q = 0.00001
    R = 0.0001

    K_x = np.zeros((batch_size, length_history), np.float32)
    K_vx = np.zeros((batch_size, length_history), np.float32)

    for k in range(length_history - 1):
        # Predict: advance the state by one velocity step and grow the variances.
        x_x[:, k + 1] = x_x[:, k] + v_x
        P_x[:, k + 1] = P_x[:, k] + P_vx[:, k] + Q
        P_vx[:, k + 1] = P_vx[:, k] + Q

        # Update: blend the prediction with the observation z_x[:, k + 1].
        K_x[:, k + 1] = P_x[:, k + 1] / (P_x[:, k + 1] + R)
        x_x[:, k + 1] = x_x[:, k + 1] + K_x[:, k + 1] * (z_x[:, k + 1] - x_x[:, k + 1])
        P_x[:, k + 1] = P_x[:, k + 1] - K_x[:, k + 1] * P_x[:, k + 1]

        # Only the velocity variance is updated; it feeds P_x on the next step.
        K_vx[:, k + 1] = P_vx[:, k + 1] / (P_vx[:, k + 1] + R)
        P_vx[:, k + 1] = P_vx[:, k + 1] - K_vx[:, k + 1] * P_vx[:, k + 1]

    last = length_history - 1
    predsX = np.zeros((batch_size, prediction_horizon))

    # Extrapolate i steps ahead for column i - 1, so each future step gets its
    # own prediction instead of every column repeating the final-horizon value.
    for i in range(1, prediction_horizon + 1):
        predsX[:, i - 1] = x_x[:, last] + v_x * i

    return predsX


if __name__ == '__main__':

    data_path = 'data/traffic/'

    data_name = 'traffic'

    # Load the data

    train_len = 168
    pred_len = 24

    # Inputs keep the first train_len steps of channel 0; labels keep every
    # step from train_len onwards.
    train_x = np.load(os.path.join(data_path, f'train_data_{data_name}.npy'))[:, :train_len, 0]
    train_v = np.load(os.path.join(data_path, f'train_v_{data_name}.npy'))
    train_y = np.load(os.path.join(data_path, f'train_label_{data_name}.npy'))[:, train_len:]

    print('Data loaded', train_x.shape, train_v.shape, train_y.shape)

    # Run Kalman filter and compute MAE

    kal_res = estimate_kalman_filter(train_x, pred_len)
    # Multiply every predicted step by the first column of train_v
    # (presumably a per-series scale factor -- TODO confirm against the data prep).
    kal_res = np.tile(train_v[:, 0], reps=(pred_len, 1)).transpose() * kal_res

    # One mean absolute error per row, averaged over the prediction steps.
    all_mae = np.mean(np.abs(train_y - kal_res), axis=1)
    with open('kal_mae.pb', 'wb') as mae_file:
        pickle.dump(all_mae, mae_file)

    print('Kalman filter MAE generation complete')

    # Run LDS on label distribution and dump map

    # Only strictly positive labels take part in the distribution.
    all_labels = train_y[:, -pred_len:].flatten()
    all_labels = all_labels[all_labels > 0]

    # Empirical label distribution over equal-width bins.
    num_bins = 100
    counts, bins = np.histogram(all_labels, bins=num_bins)
    counts = counts / np.sum(counts)

    # Smooth the histogram with a small Gaussian kernel.
    gauss_filt = gauss(5, 2)
    gauss_labels = np.convolve(counts, gauss_filt, mode='same')

    # Add a constant floor before re-normalising so no bin ends up with zero mass.
    base_offset = 0.001
    gauss_labels_off = (gauss_labels + base_offset) / np.sum(gauss_labels + base_offset)

    # Bin index of every label. np.digitize places values equal to the last
    # edge one past the final bin, so clip them back into it; the histogram's
    # last bin is closed on the right, and a strict "< upper edge" test would
    # leave the largest labels without any bin.
    bin_idx = np.clip(np.digitize(all_labels, bins) - 1, 0, num_bins - 1)
    label_map_val = gauss_labels_off[bin_idx]

    # Map each label value to the smoothed probability of its bin.
    label_prob_map = dict(zip(all_labels, label_map_val))

    with open('lds_prob_map.pb', 'wb') as map_file:
        pickle.dump(label_prob_map, map_file)

    print('LDS probability map generation complete')
