import pandas as pd


def _drop_short_histories(df, min_events):
    """Return the rows of *df* whose ``userId`` occurs more than *min_events* times in *df*."""
    # Per-row size of the row's user group; a vectorised equivalent of
    # ``df.groupby("userId").filter(lambda g: len(g) > min_events)``.
    sizes = df.groupby("userId")["userId"].transform("size")
    return df.loc[sizes > min_events]


def preprocess(df_ratings, split=None, *, min_events=5):
    """Split a ratings table chronologically into train/test/val session frames.

    Original code: https://github.com/paxcema/KerasGRU4Rec/blob/master/preprocess/movieLens20M.py

    Parameters
    ----------
    df_ratings : pandas.DataFrame
        Must contain ``userId``, ``itemId`` and ``timestamp`` (Unix seconds)
        columns. Any other columns (e.g. ``rating``, ``new_itemId``) are
        ignored. The caller's frame is not modified.
    split : dict, optional
        Fractions keyed by ``"train"`` and ``"test"``. Defaults to
        ``{"train": 0.8, "test": 0.1, "val": 0.1}``. Only ``"train"`` and
        ``"test"`` are read; the validation split receives every row at or
        after the ``train + test`` time quantile.
    min_events : int, default 5
        Within each split, users with ``min_events`` or fewer rows are dropped.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_train, df_test, df_val)``. Each has the columns
        ``SessionId`` (from ``userId``), ``ItemId`` (int) and
        ``Time`` (datetime64).
    """
    if split is None:
        split = {"train": 0.8, "test": 0.1, "val": 0.1}

    # Keep only the needed columns. assign() returns a new frame, so the
    # caller's DataFrame is never mutated (the old code ``del``-ed a column in place).
    df = df_ratings[["userId", "itemId", "timestamp"]].assign(
        itemId=lambda d: d["itemId"].astype(int),
        timestamp=lambda d: pd.to_datetime(d["timestamp"], unit="s"),
    )
    ts = df["timestamp"]
    # describe() no longer reports "first"/"last" for datetimes in pandas 2.x;
    # min()/max() give the same information on every version.
    print(ts.min(), ts.max())

    # Time cut-offs at the train and train+test quantiles of all timestamps.
    train_date = ts.quantile(split["train"])
    test_date = ts.quantile(split["train"] + split["test"])

    df_train = df[ts <= train_date]
    df_test = df[(train_date < ts) & (ts < test_date)]
    df_val = df[test_date <= ts]
    print(df_train.shape, df_test.shape, df_val.shape)

    # Remove the users whose history is too short
    df_train, df_test, df_val = (
        _drop_short_histories(part, min_events) for part in (df_train, df_test, df_val)
    )
    print(df_train.shape, df_test.shape, df_val.shape)

    # Rename by name (not by position) so input column order cannot mislabel data.
    renames = {"userId": "SessionId", "itemId": "ItemId", "timestamp": "Time"}
    return tuple(part.rename(columns=renames) for part in (df_train, df_test, df_val))


if __name__ == '__main__':
    from pathlib import Path

    # Load the data: cols -> (userId, itemId, rating, timestamp)
    df_ratings = pd.read_csv('../ml-latest-small/ratings.csv')
    # The stock MovieLens ratings.csv calls the item column "movieId", while
    # preprocess() reads "itemId"; rename only when "itemId" is not already present.
    if "itemId" not in df_ratings.columns:
        df_ratings = df_ratings.rename(columns={"movieId": "itemId"})

    df_train, df_test, df_val = preprocess(df_ratings=df_ratings)

    # DataFrame.to_csv does not create missing parent directories.
    out_dir = Path('./ml-latest-small')
    out_dir.mkdir(parents=True, exist_ok=True)
    df_train.to_csv(out_dir / 'train.csv', sep='\t', index=False)
    df_test.to_csv(out_dir / 'test.csv', sep='\t', index=False)
    df_val.to_csv(out_dir / 'val.csv', sep='\t', index=False)
