import os.path

import numpy as np
import pandas as pd
import torch.utils.data
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import LabelBinarizer

from config.experiments import val_params


def sin_transformer(period):
    """
    Return a stateless transformer mapping ``x`` to ``sin(x / period * 2 * pi)``.

    :param period: Length of one full cycle in the units of ``x`` (e.g. 12 when ``x`` is a month number).
    """
    def _sine(values):
        return np.sin(values / period * 2 * np.pi)

    return FunctionTransformer(_sine)


def cos_transformer(period):
    """
    Return a stateless transformer mapping ``x`` to ``cos(x / period * 2 * pi)``.

    :param period: Length of one full cycle in the units of ``x`` (e.g. 12 when ``x`` is a month number).
    """
    def _cosine(values):
        return np.cos(values / period * 2 * np.pi)

    return FunctionTransformer(_cosine)


def min_max_encode(data, min, max):
    """
    Scale ``data`` linearly so that ``min`` maps to 0 and ``max`` maps to (just under) 1.

    A small epsilon (1e-8) is added to the span so a constant column (``min == max``) encodes to 0 instead of
    dividing by zero.
    """
    shifted = data - min
    span = max - min + 1e-8
    return shifted / span


def min_max_decode(data, min, max):
    """
    Invert ``min_max_encode``: map scaled values back to the original range.

    The span includes the same 1e-8 epsilon the encoder divides by, so decoding an encoded value recovers the
    original up to floating-point rounding. For a constant column (``min == max``) the result is ``min``.
    """
    return data * (max - min + 1e-8) + min


class SalzburgData(torch.utils.data.Dataset):
    """
    In-memory ``torch`` dataset of ``(x, y)`` pairs for the Salzburg ticket sales data.

    Both inputs and targets are converted once, at construction, to float32 tensors placed on ``device``.
    """

    def __init__(self, x, y, device='cpu'):
        self.x, self.y = (torch.tensor(arr, dtype=torch.float32).to(device) for arr in (x, y))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target


def split_sequences(df, seq_len=30, overlap=0, single_y=False, extra_features=False):
    """
    Splits the given dataframe into sequences of length seq_len.

    :param df: The dataframe to split
    :param seq_len: The length of the sequences
    :param overlap: The overlap between consecutive sequences; the window start advances by
        ``max(1, seq_len - overlap)`` rows
    :param single_y: If True, y is a single row per window (the row at offset ``seq_len + 1`` from the window start);
        otherwise y is the input window shifted forward by one row
    :param extra_features: Only applicable to SC data; if True and ``df`` has a ``'venues'`` column group, y is taken
        from that group only, while x keeps every column
    :return: Tuple ``(x, y)`` of numpy arrays. ``x`` has shape ``(n_windows, seq_len, n_columns)``; ``y`` has shape
        ``(n_windows, seq_len, n_target_columns)``, or ``(n_windows, n_target_columns)`` when ``single_y`` is True
    :raises ValueError: If ``df`` has fewer than ``seq_len + 2`` rows, so no window can be produced
    """
    df_y = df['venues'] if extra_features and 'venues' in df.columns else df
    stride = max(1, seq_len - overlap)
    # The upper bound leaves room for the single_y target at offset seq_len + 1.
    starts = range(0, df.shape[0] - seq_len - 1, stride)
    if len(starts) == 0:
        raise ValueError(
            f"DataFrame with {df.shape[0]} rows is too short to produce a sequence of length {seq_len} "
            f"(need at least {seq_len + 2} rows)")

    x_out = []
    y_out = []
    for _t in starts:
        x_out.append(df.iloc[_t:_t + seq_len])
        if single_y:
            # NOTE(review): this is two rows after the last input row (_t + seq_len - 1), whereas the last row of the
            # multi-step target below is _t + seq_len. Confirm the extra step is intended before changing it.
            y_out.append(df_y.iloc[_t + seq_len + 1])
        else:
            y_out.append(df_y.iloc[_t + 1:_t + seq_len + 1])

    return np.stack(x_out, axis=0), np.stack(y_out, axis=0)


def get_extra_features(df):
    """
    Build the global (non-POI) feature frame used alongside the visitor counts.

    Selects the columns whose second header level is one of ``used_columns``, one-hot encodes
    ``('features', 'weather_main')``, and derives day-of-month, hour and a sin/cos encoding of the month from the
    index. The derived columns are added under the ``'features'`` top level.

    :param df: Frame with two-level columns containing ``('features', 'weather_main')``; its index must be parseable
        by ``pd.to_datetime``.
    :return: New frame indexed like ``df``.
    """
    used_columns = ['daysToWorkday', 'temp', 'feels_like', 'wind_speed', 'precipitation', 'clouds_all']
    df_out = df.loc[:, df.columns.get_level_values(1).isin(used_columns)]

    # One-hot encode weather string
    binarizer = LabelBinarizer()
    weather_codes = binarizer.fit_transform(df[('features', 'weather_main')])
    if weather_codes.shape[1] == 1 and len(binarizer.classes_) == 2:
        # With exactly two classes LabelBinarizer emits a single 0/1 column (1 == classes_[1]); expand it to one
        # column per class so it matches the column labels built from classes_ below.
        weather_codes = np.hstack([1 - weather_codes, weather_codes])
    one_hot_weather = pd.DataFrame(weather_codes, index=df_out.index,
                                   columns=pd.MultiIndex.from_product([('features',), binarizer.classes_]))
    df_out = pd.concat([df_out, one_hot_weather], axis=1)

    # add datetime features
    dt = pd.to_datetime(df.index.to_series(), utc=True).dt
    df_out[('features', 'day')] = dt.day
    df_out[('features', 'hour')] = dt.hour

    # sin-cos transform of the month so December and January end up adjacent
    df_out[('features', 'month_sin')] = sin_transformer(12).fit_transform(dt.month)
    df_out[('features', 'month_cos')] = cos_transformer(12).fit_transform(dt.month)
    return df_out


def load_and_prep_df(filename='data/dataset_bounded.csv', normalize_visitors=False, extra_features=True,
                     start_year=2019, end_year=2021):
    """
    Load the Salzburg ticket CSV and split it into year-based train/test frames.

    :param filename: CSV with a two-level column header; if it does not exist, ``'../' + filename`` is tried.
    :param normalize_visitors: If True, min-max scale each ``'venues'`` column of both splits using training
        statistics.
    :param extra_features: If True, append the columns from ``get_extra_features`` and min-max scale its numeric
        columns using training statistics.
    :param start_year: First year (inclusive) of the training split.
    :param end_year: Year used as the test split; training covers ``[start_year, end_year)``.
    :return: Tuple ``(train, test, mins, maxs)``; ``mins``/``maxs`` are numpy arrays with the per-venue minima and
        maxima of the unscaled training data.
    """
    if not os.path.isfile(filename):
        filename = '../' + filename
    df = pd.read_csv(filename, header=[0, 1], index_col=0)
    df_out = df.xs('venues', axis=1, drop_level=False)

    # Year-based train/test split
    years = pd.to_datetime(df.index.to_series(), utc=True).dt.year.to_numpy()
    train_mask = (years < end_year) & (years >= start_year)
    test_mask = years == end_year

    if extra_features:
        df_out = pd.concat([df_out, get_extra_features(df)], axis=1)

    print(df_out['venues'].columns)

    # Explicit copies: both splits are modified column by column below.
    train = df_out[train_mask].copy()
    test = df_out[test_mask].copy()

    if extra_features:
        min_max_cols = ['daysToWorkday', 'day', 'hour', 'temp', 'feels_like', 'wind_speed', 'precipitation',
                        'clouds_all']
        for col in min_max_cols:
            key = ('features', col)
            minimum, maximum = train[key].min(), train[key].max()
            # Whole-column assignment replaces the (possibly integer) column with the float result instead of
            # writing floats into it in place, which pandas deprecates.
            train[key] = min_max_encode(train[key], minimum, maximum)
            test[key] = min_max_encode(test[key], minimum, maximum)

    # Raw (unscaled) per-venue ranges, returned so callers can decode predictions.
    mins, maxs = train['venues'].min().values, train['venues'].max().values
    if normalize_visitors:
        for col in list(train['venues'].columns):
            key = ('venues', col)
            minimum, maximum = train[key].min(), train[key].max()
            train[key] = min_max_encode(train[key], minimum, maximum)
            test[key] = min_max_encode(test[key], minimum, maximum)

    return train, test, mins, maxs


def load_prep_with_near(normalize_visitors=True, start_year=2019, end_year=2021):
    """
    Load the near-graph node attributes joined with the global features and split them by year.

    The node attributes are read from ``data/graphs/node_attrs.parq`` (falling back to ``../``). The Salzburgcard CSV
    is aligned to their index, and the frame from ``get_extra_features`` is joined on. Missing values are filled
    with ``0.0``. The numeric weather/calendar feature columns are always min-max scaled with training statistics.

    :param normalize_visitors: If True, also min-max scale every non-``'features'`` column (the ``'near'`` and
        ``'venues'`` groups) of BOTH splits using training statistics.
    :param start_year: First year (inclusive) of the training split.
    :param end_year: Year used as the test split; training covers ``[start_year, end_year)``.
    :return: Tuple ``(train, test, num_global_features, mins, maxs)``; ``mins``/``maxs`` are numpy arrays with the
        per-venue minima and maxima of the unscaled training data.
    """
    # Get parquet file containing near (and Salzburgcard) data
    try:
        near_df = pd.read_parquet("data/graphs/node_attrs.parq")
    except FileNotFoundError:
        near_df = pd.read_parquet("../data/graphs/node_attrs.parq")
    # Get Salzburgcard data
    try:
        df = pd.read_csv("data/dataset_bounded.csv", header=[0, 1], index_col=0)
    except FileNotFoundError:
        df = pd.read_csv("../data/dataset_bounded.csv", header=[0, 1], index_col=0)
    df.index = pd.to_datetime(df.index, utc=True)
    df = df.loc[near_df.index]

    features_df = get_extra_features(df)
    num_global_features = features_df.shape[1]
    num_pois = len(df['venues'].columns)

    # NOTE(review): assumes the parquet stores the POI columns last, matching the CSV's 'venues' group in count.
    near_df.columns = pd.MultiIndex.from_arrays(
        [['near'] * (len(near_df.columns) - num_pois) + ['venues'] * num_pois, near_df.columns])

    near_df = near_df.join(features_df, how='inner').fillna(0.0)

    # Year-based split from a separate array, so no helper columns end up in the returned frames.
    years = pd.to_datetime(near_df.index.to_series(), utc=True).dt.year.to_numpy()
    train = near_df[(years < end_year) & (years >= start_year)].copy()
    test = near_df[years == end_year].copy()

    # Raw (unscaled) per-venue ranges, returned so callers can decode predictions.
    mins, maxs = train['venues'].min().values, train['venues'].max().values
    if normalize_visitors:
        # Scale both splits with training statistics; previously only train was scaled (including the one-hot and
        # sin/cos feature columns), leaving test on a different scale.
        visitor_cols = [col for col in train.columns if col[0] != 'features']
        for col in visitor_cols:
            minimum, maximum = train[col].min(), train[col].max()
            train[col] = min_max_encode(train[col], minimum, maximum)
            test[col] = min_max_encode(test[col], minimum, maximum)

    min_max_cols = ['daysToWorkday', 'day', 'hour', 'temp', 'feels_like', 'wind_speed', 'precipitation',
                    'clouds_all']
    for col in min_max_cols:
        key = ('features', col)
        minimum, maximum = train[key].min(), train[key].max()
        # Whole-column assignment avoids pandas' deprecated in-place float-into-int setitem.
        train[key] = min_max_encode(train[key], minimum, maximum)
        test[key] = min_max_encode(test[key], minimum, maximum)

    return train, test, num_global_features, mins, maxs


def prepare_torch_datasets(device='cpu', seq_len=30, overlap=0, normalize_visitors=False, extra_features=True,
                           use_near=False):
    """
    Build the windowed training and validation ``SalzburgData`` datasets.

    Training windows use ``seq_len``/``overlap``; validation windows use ``val_params['seq_len']`` and
    ``val_params['overlap']`` with a single target row per window.

    :param use_near: If True, load via ``load_prep_with_near`` (``extra_features`` is then ignored); otherwise via
        ``load_and_prep_df``.
    :return: Tuple ``(ds_train, ds_test, mins, maxs)``.
    """
    if not use_near:
        train, test, mins, maxs = load_and_prep_df(normalize_visitors=normalize_visitors,
                                                   extra_features=extra_features)
        restrict_targets = extra_features
    else:
        train, test, _num_global_features, mins, maxs = load_prep_with_near(normalize_visitors)
        # extra_features=True makes split_sequences take targets from the 'venues' group only, so the 'near'
        # columns serve as inputs but not as targets.
        restrict_targets = True

    x_train, y_train = split_sequences(train, seq_len=seq_len, overlap=overlap, extra_features=restrict_targets)
    x_test, y_test = split_sequences(test, seq_len=val_params['seq_len'], overlap=val_params['overlap'],
                                     single_y=True, extra_features=restrict_targets)

    return (SalzburgData(x_train, y_train, device=device),
            SalzburgData(x_test, y_test, device=device),
            mins, maxs)


def load_and_prep_graphs(path="data/graphs", normalize_visitors=False):
    """
    Load per-node attributes from parquet and join the global features onto them.

    :param path: Directory containing the node attribute parquet files; if it does not exist, ``'../' + path`` is
        tried. The global features come from ``dataset.csv`` in the parent of this directory.
    :param normalize_visitors: If True, read ``node_attrs_normalized.parq`` instead of ``node_attrs.parq``, and
        min-max scale the numeric global feature columns over the full date range.
    :return: Tuple ``(node_attrs, num_global_features)``; ``node_attrs`` has the node columns under the ``'venues'``
        top level with the global features (NaNs filled with 0.0) joined on.
    """
    if not os.path.isdir(path):
        path = '../' + path
    attrs_file = "node_attrs_normalized.parq" if normalize_visitors else "node_attrs.parq"
    filename = os.path.abspath(path + "/" + attrs_file)

    print("Reading node attributes")

    node_attrs = pd.read_parquet(filename)

    print("Preparing global features")
    df = pd.read_csv(path + '/../dataset.csv', header=[0, 1], index_col=0)
    # utc=True matches the other loaders and keeps a proper datetime index even if the timestamps carry differing
    # UTC offsets (without it pandas falls back to object dtype or raises).
    df.index = pd.to_datetime(df.index, utc=True)
    df = df.loc[node_attrs.index]
    global_features = get_extra_features(df)
    if normalize_visitors:
        # NOTE(review): these statistics span all dates (no train/test split here); confirm this is intended.
        min_max_cols = ['daysToWorkday', 'day', 'hour', 'temp', 'feels_like', 'wind_speed', 'precipitation',
                        'clouds_all']
        for col in min_max_cols:
            key = ('features', col)
            minimum, maximum = global_features[key].min(), global_features[key].max()
            # Whole-column assignment avoids pandas' deprecated in-place float-into-int setitem.
            global_features[key] = min_max_encode(global_features[key], minimum, maximum)

    num_global_features = global_features.shape[1]
    global_features = global_features.fillna(0.0)
    node_attrs = pd.concat({'venues': node_attrs}, names=['Firstlevel'], axis=1)
    node_attrs = node_attrs.join(global_features)
    return node_attrs, num_global_features


def _to_graph_tensor_dataset(x, y, num_global_features):
    """Split windowed inputs into per-node and global parts and wrap them, with targets, in a TensorDataset."""
    # x comes from split_sequences with shape (B, T, C); its trailing num_global_features columns are global
    # features. Slicing by an explicit index (not -num_global_features) stays correct when that count is 0.
    num_node_cols = x.shape[-1] - num_global_features
    glob = x[..., num_node_cols:]  # (B, T, G)
    nodes = x.transpose(0, 2, 1)[:, :num_node_cols, None, :]  # (B, N, F=1, T)
    return torch.utils.data.TensorDataset(
        torch.from_numpy(nodes).type(torch.FloatTensor),  # (B, N, F, T)
        torch.from_numpy(glob).type(torch.FloatTensor),  # (B, T, G)
        torch.from_numpy(y).type(torch.FloatTensor),  # (B, N_targets): single_y windows
    )


def prepare_graph_datasets(seq_len=30, normalize_visitors=False):
    """
    Build the graph-model training and validation ``TensorDataset``s plus the edge list.

    Each dataset yields ``(node_inputs, global_inputs, targets)``. Training windows use ``seq_len``; validation
    windows use ``val_params``. Both use a single target row per window.

    :return: Tuple ``(train_ds, val_ds, edges, num_targets, mins, maxs)``, where ``num_targets`` is the number of POI
        columns kept in the validation targets.
    """
    edges_file = 'data/graphs/edge_weights.csv'
    if not os.path.isfile(edges_file):
        edges_file = '../' + edges_file
    edges = pd.read_csv(edges_file, header=0)

    train, test, num_global_features, mins, maxs = load_prep_with_near(normalize_visitors)
    # Only the training frame is relabelled, so split_sequences draws training targets from every node ('near' and
    # 'venues'), while test targets come from the 'venues' group alone. Presumably intentional; confirm.
    train = train.rename(columns={'near': 'venues'})
    x_train, y_train = split_sequences(train, seq_len, extra_features=True, single_y=True)
    x_test, y_test = split_sequences(test, seq_len=val_params['seq_len'],
                                     overlap=val_params['overlap'], extra_features=True, single_y=True)

    # HACK: names of the POIs inside the BBox. Only their count is used: the last len(indices) target columns of the
    # validation set are kept.
    indices = ['Eintritt Domgrabungsmuseum', 'Eintritt Domquartier',
               'Eintritt Festspielhaus Salzburg',
               'Eintritt Festungseintritt - Torbogenkassa XML', 'Eintritt Geburtshaus',
               'Eintritt Georg-Trakl-Gedenkstätte', 'Eintritt Haus der Natur',
               'Eintritt MdM - Rupertinum', 'Eintritt MdM - am Mönchsberg',
               'Eintritt Neue Residenz', 'Eintritt Panorama Museum',
               'Eintritt Spielzeug Museum', 'Eintritt St. Peter Katakomben',
               'Eintritt Wohnhaus', 'Fahrt Salzburg Stadt Schiff-fahrt',
               'Festungsbahn Bergfahrt', 'Festungsbahn Talfahrt',
               'Mönchsbergaufzug Bergfahrt', 'Mönchsbergaufzug Talfahrt',
               'Sound of Music World', 'Weihnachtsmuseum']
    num_targets = len(indices)

    y_test = y_test[:, -num_targets:]

    train_ds = _to_graph_tensor_dataset(x_train, y_train, num_global_features)
    val_ds = _to_graph_tensor_dataset(x_test, y_test, num_global_features)
    return train_ds, val_ds, edges, num_targets, mins, maxs


if __name__ == "__main__":
    # Smoke-run the default (non-graph) loading pipeline when executed as a script.
    prepare_torch_datasets()
