from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from datetime import datetime, timedelta
import pandas as pd
import math
import numpy as np
import random
from tqdm import trange
from copy import deepcopy

from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

from math import sqrt
from pandas import read_csv, DataFrame
from scipy import stats

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def prep_data(data, covariates, data_start, train = True):
    """Cut every series into fixed-length windows and save them as .npy files.

    Reads the module-level names ``window_size``, ``stride_size``,
    ``num_series``, ``num_covariates``, ``total_time``, ``save_path`` and
    ``save_name``.

    Args:
        data: 2-D array indexed as ``data[time, series]``.
        covariates: 2-D array indexed as ``covariates[time, covariate]``.
            Column 0 is overwritten in place with a per-series z-scored
            "age" ramp. In test mode only the last ``data.shape[0]`` rows
            are used (assumes a NumPy array, so that slice is a view and
            the caller's array is modified as well - TODO confirm).
        data_start: per-series index of the first usable time step.
        train: when True, windows start at ``data_start[series]`` and the
            labels are scaled; when False, windows start at row 0 and the
            labels are left unscaled.

    Writes three arrays under ``save_path`` with a ``train_``/``test_``
    prefix:
        ``data_``  (windows, window_size, num_covariates + 2): feature 0 is
                   the series value lagged by one step (position 0 stays 0),
                   features 1..num_covariates are the covariates, and the
                   last feature is the series index.
        ``v_``     (windows, 2): column 0 is the scale factor of the window;
                   column 1 is never written and stays 0.
        ``label_`` (windows, window_size): the un-lagged series values.
    """
    n_steps = data.shape[0]
    history_len = window_size - stride_size
    windows_per_series = np.full((num_series), (n_steps - history_len) // stride_size)
    if train:
        # Drop the windows that would begin before each series' first usable step.
        windows_per_series -= (data_start + stride_size - 1) // stride_size

    n_windows = np.sum(windows_per_series)
    n_features = 1 + num_covariates + 1
    x_input = np.zeros((n_windows, window_size, n_features), dtype='float32')
    label = np.zeros((n_windows, window_size), dtype='float32')
    v_input = np.zeros((n_windows, 2), dtype='float32')

    if not train:
        covariates = covariates[-n_steps:]

    row = 0
    for series in trange(num_series):
        first_step = data_start[series]
        age = stats.zscore(np.arange(total_time - first_step))

        if train:
            covariates[first_step:n_steps, 0] = age[:n_steps - first_step]
        else:
            covariates[:, 0] = age[-n_steps:]

        # Training windows are anchored at the series start, test windows at row 0.
        offset = first_step if train else 0

        for i in range(windows_per_series[series]):
            begin = stride_size * i + offset
            end = begin + window_size

            x_input[row, 1:, 0] = data[begin:end - 1, series]
            x_input[row, :, 1:1 + num_covariates] = covariates[begin:end, :]
            x_input[row, :, -1] = series
            label[row, :] = data[begin:end, series]

            # The scale is taken from the lagged values in positions 1..history_len-1.
            history = x_input[row, 1:history_len, 0]
            n_nonzero = (history != 0).sum()

            # With no nonzero history the scale stays at its initial 0 and
            # nothing is rescaled.
            if n_nonzero != 0:
                v_input[row, 0] = np.true_divide(history.sum(), n_nonzero) + 1
                x_input[row, :, 0] = x_input[row, :, 0] / v_input[row, 0]
                if train:
                    label[row, :] = label[row, :] / v_input[row, 0]
            row += 1

    prefix = os.path.join(save_path, 'train_' if train else 'test_')
    for tag, array in (('data_', x_input), ('v_', v_input), ('label_', label)):
        np.save(prefix + tag + save_name, array)

def gen_covariates(times, num_covariates):
    """Build standardized calendar covariates, one row per timestamp.

    Args:
        times: sequence with a ``shape`` attribute whose items expose
            ``weekday()``, ``hour`` and ``month`` (e.g. a pandas
            DatetimeIndex).
        num_covariates: number of columns to allocate. Columns 1, 2 and 3
            are always written, so a non-empty ``times`` needs at least 4
            columns (fewer raises IndexError).

    Returns:
        Float array of shape ``(len(times), num_covariates)``. Column 0 is
        left at zero here (prep_data fills it per series). Columns 1-3 hold
        weekday, hour and month, and every column from 1 upward is z-scored.
        A column with no variation is returned as all zeros instead of the
        NaN that ``stats.zscore`` produces for zero variance.
    """
    covariates = np.zeros((times.shape[0], num_covariates))

    for i, input_time in enumerate(times):
        covariates[i, 1] = input_time.weekday()
        covariates[i, 2] = input_time.hour
        covariates[i, 3] = input_time.month

    for col in range(1, num_covariates):
        column = covariates[:, col]
        if column.size and column.min() == column.max():
            # Zero variance: z-scoring would divide by zero and yield NaN.
            covariates[:, col] = 0.0
            continue
        covariates[:, col] = stats.zscore(column)

    return covariates

def visualize(data, week_start):
    """Plot one window of ``data`` and save the figure as ``visual.png``.

    The window covers ``window_size`` (module-level) consecutive entries
    beginning at index ``week_start``.
    """
    window = data[week_start:week_start + window_size]
    steps = np.arange(window_size)

    figure = plt.figure()
    plt.plot(steps, window, color='b')
    figure.savefig("visual.png")
    plt.close()


def process_list(s, variable_type=int, delimiter=None):
    """Parse one bracketed line of PEMS-format text into a list.

    All ``[`` and ``]`` characters are removed, the remainder is split on
    ``delimiter`` (any run of whitespace when ``delimiter`` is None), and each
    piece is converted with ``variable_type``.
    """
    unbracketed = s.replace('[', '').replace(']', '')
    # str.split(None) splits on whitespace, exactly like a bare str.split().
    return [variable_type(token) for token in unbracketed.split(delimiter)]

def read_single_list(filename):
    """Return the first line of a PEMS-format file parsed as a list of ints."""
    with open(filename, 'r') as handle:
        first_line = handle.readlines()[0]
    return process_list(first_line)


def read_matrix(filename):
    """Read a PEMS-format file into a NumPy array.

    Every line of the file is one matrix: its rows are separated by ``;``
    and the values within a row by whitespace. The matrices are stacked in
    file order. Progress is printed every 50 lines.
    """
    with open(filename, 'r') as handle:
        lines = handle.readlines()

    total = len(lines)
    matrices = []
    for line_no, line in enumerate(lines, start=1):
        if line_no % 50 == 0:
            print('Completed {} of {} rows for {}'.format(line_no, total, filename))

        row_strings = process_list(line, variable_type=str, delimiter=';')
        matrices.append([
            process_list(row, variable_type=float, delimiter=None)
            for row in row_strings
        ])

    return np.array(matrices)


if __name__ == '__main__':

    # Script settings. These are module-level names: prep_data() and
    # visualize() read window_size, stride_size, num_covariates, save_path
    # and save_name as globals.
    save_name = 'traffic'
    window_size = 192
    stride_size = 24
    num_covariates = 4
    train_start = '2008-01-01 00:00:00'
    train_end = '2008-06-15 23:00:00'
    test_start = '2008-06-09 00:00:00' #need additional 7 days as given info
    test_end = '2008-07-15 23:00:00'

    save_path = os.path.join('data', save_name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Load raw data from files

    shuffle_order = np.array(read_single_list(os.path.join(save_path, 'randperm'))) - 1  # index from 0
    train_dayofweek = read_single_list(os.path.join(save_path, 'PEMS_trainlabels'))
    train_tensor = read_matrix(os.path.join(save_path, 'PEMS_train'))
    test_dayofweek = read_single_list(os.path.join(save_path, 'PEMS_testlabels'))
    test_tensor = read_matrix(os.path.join(save_path, 'PEMS_test'))

    # Invert the shuffle permutation: map each new location back to the
    # position it occupies in shuffle_order.
    print('Shuffling')
    inverse_mapping = {
      new_location: previous_location
      for previous_location, new_location in enumerate(shuffle_order)
    }
    reverse_shuffle_order = np.array([
      inverse_mapping[new_location]
      for new_location, _ in enumerate(shuffle_order)
    ])

    # Concatenate train and test, then reorder with the inverse permutation
    print('Reodering')
    day_of_week = np.array(train_dayofweek + test_dayofweek)
    combined_tensor = np.array(np.r_[train_tensor, test_tensor])
    print(combined_tensor.shape)

    day_of_week = day_of_week[reverse_shuffle_order]
    combined_tensor = combined_tensor[reverse_shuffle_order]

    # Put everything back into a dataframe
    print('Parsing as dataframe')
    # NOTE: `labels` is not used further down; the read is kept so the
    # script still requires and parses the 'stations_list' file.
    labels = ['traj_{}'.format(i) for i in read_single_list(os.path.join(save_path, 'stations_list'))]

    current_date = datetime.strptime("01/01/08", "%m/%d/%y")

    # 144 ten-minute offsets covering one day.
    hr_td_arr = [timedelta(minutes=10*i) for i in range(144)]

    # Collect the per-day pieces in lists and concatenate once after the
    # loop; appending with np.r_ on every iteration re-copied the whole
    # accumulated array each time.
    df_index_arr = []
    day_blocks = [np.zeros(shape=(0, 963))]  # empty seed keeps the 963-column check

    for day, day_matrix in enumerate(combined_tensor):
        if day > 0:
            # Advance the calendar by the day-of-week difference between
            # consecutive entries; a non-positive difference wraps into the
            # following week.
            td = day_of_week[day] - day_of_week[day-1]
            if td <= 0:
                td = 7 + td
            current_date = current_date + timedelta(days=int(td))

        df_index_arr.extend(current_date + hr_td for hr_td in hr_td_arr)
        day_blocks.append(np.transpose(day_matrix))

    df_arr = np.concatenate(day_blocks, axis=0)

    data_frame = pd.DataFrame(df_arr, index=pd.DatetimeIndex(df_index_arr))
    data_frame = data_frame.resample('1H',label = 'left',closed = 'right').mean()[train_start:test_end]

    data_frame.fillna(0, inplace=True)

    covariates = gen_covariates(data_frame[train_start:test_end].index, num_covariates)
    train_data = data_frame[train_start:train_end].values
    test_data = data_frame[test_start:test_end].values
    data_start = (train_data!=0).argmax(axis=0) #find first nonzero value in each time series
    total_time = data_frame.shape[0]  # number of hourly rows kept after resampling
    num_series = data_frame.shape[1]  # one column per series

    prep_data(train_data, covariates, data_start)
    prep_data(test_data, covariates, data_start, train=False)
