""" Get slates for each session in movielens
    Ref: https://github.com/tensorflow/recommenders/blob/main/tensorflow_recommenders/examples/movielens.py
"""

import os
import pickle
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.args import ML100K_ITEM_FEATURES, ML100K_USER_FEATURES, DATA_SPLIT, USER_HISTORY_COL_NAME, \
    HISTORY_SIZE

# A rating counts as a click when it is >= this threshold (ratings are on a scale of 5).
CLICK_CRITERION = 4


def get_history_seq(_df: pd.DataFrame, history_size: int):
    """Attach to each interaction the ids of the same user's preceding items.

    Within each ``userId``, interactions are ordered by ``timestamp`` and the
    ``history_size`` previous ``itemId`` values are collected into a list
    column ``hist_seq`` (most recent first). Rows whose user has fewer than
    ``history_size`` earlier interactions are dropped.

    Args:
        _df: interaction log with at least ``userId``, ``itemId`` and
            ``timestamp`` columns. The caller's frame is not modified.
        history_size: number of previous items to collect for each row.

    Returns:
        ``(df, mask)``. ``df`` holds the kept rows with a fresh RangeIndex and a
        ``hist_seq`` column. ``mask`` holds the original index labels of the
        kept rows, so the same removal can be applied to data aligned with
        ``_df``.
    """
    # Work on a copy: callers pass slices of a larger frame, and adding columns
    # to those would mutate the caller's data / raise SettingWithCopyWarning.
    _df = _df.copy()

    # A stable sort keeps interactions that share a timestamp in their original
    # row order, so their relative history order is deterministic.
    gb = _df.sort_values(by='timestamp', kind='mergesort').groupby('userId')['itemId']
    col_names = list()
    for t in range(history_size):
        _col = USER_HISTORY_COL_NAME + str(t + 1)
        # Item seen (t + 1) steps earlier by the same user; aligned back onto
        # _df by index label.
        _df[_col] = gb.shift(periods=(t + 1))
        col_names.append(_col)

    # Remove rows without enough history. Only the history columns are checked,
    # so missing values elsewhere do not silently drop rows.
    _df = _df.dropna(subset=col_names)
    mask = _df.index.values  # original labels of the kept rows
    _df = _df.reset_index(drop=True)
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    _df["hist_seq"] = _df[col_names].astype(int).values.tolist()
    _df = _df.drop(col_names, axis=1)
    return _df, mask


def _dump_pickle(obj, path: str):
    """Pickle ``obj`` to ``path`` using the highest available protocol."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def _preprocess_users(data_dir: str) -> pd.DataFrame:
    """Encode ``u.user`` and write ``users.csv`` and ``user_attr.npy``.

    Shifts ids to start at 0, min-max scales age, one-hot encodes gender and
    occupation, and drops zipCode.
    """
    users_headers = ['userId', 'age', 'gender', 'occupation', 'zipCode']
    df_users = pd.read_csv(os.path.join(data_dir, "u.user"),
                           sep='|',
                           header=None,
                           names=users_headers,
                           encoding='latin-1')
    df_users["userId"] = df_users["userId"] - 1  # because the id starts from 1
    df_users["age"] = MinMaxScaler().fit_transform(df_users["age"].values.reshape(-1, 1))
    # Request integer dummies explicitly. Since pandas 2.0, get_dummies returns
    # bool columns, which (mixed with the float age column) would make the saved
    # feature matrix an object array and write True/False into users.csv.
    gender_dummies = pd.get_dummies(df_users["gender"], dtype=np.uint8)
    occupation_dummies = pd.get_dummies(df_users["occupation"], dtype=np.uint8)
    df_users = pd.concat([df_users, gender_dummies, occupation_dummies], axis=1)
    df_users = df_users.drop(['gender', 'occupation', 'zipCode'], axis=1)
    print("df_users: {}".format(df_users.shape))
    print(df_users.columns)
    df_users.to_csv(os.path.join(data_dir, "users.csv"), index=False)
    np.save(file=os.path.join(data_dir, "user_attr"),
            arr=df_users.drop(["userId"], axis=1).values)
    return df_users


def _preprocess_movies(data_dir: str) -> pd.DataFrame:
    """Clean ``u.item`` and write ``movies.csv`` and ``item_attr.npy``.

    Shifts ids to start at 0, reformats releaseDate as YYYY-MM-DD, and drops
    the title, video-release-date, URL and "unknown" columns.
    """
    movie_headers = ['itemId', 'movieTitle', 'releaseDate', 'videoReleaseDate',
                     'IMDbURL', 'unknown', 'Action', 'Adventure', 'Animation',
                     'Childrens', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',
                     'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
                     'Thriller', 'War', 'Western']
    df_movies = pd.read_csv(os.path.join(data_dir, "u.item"),
                            sep='|',
                            header=None,
                            names=movie_headers,
                            encoding='latin-1')
    df_movies["releaseDate"] = pd.to_datetime(df_movies["releaseDate"]).dt.strftime('%Y-%m-%d')
    df_movies["itemId"] = df_movies["itemId"] - 1  # because the id starts from 1
    df_movies = df_movies.drop(["movieTitle", "videoReleaseDate", "IMDbURL", "unknown"], axis=1)
    print("df_movies: {}".format(df_movies.shape))
    print(df_movies.columns)
    df_movies.to_csv(os.path.join(data_dir, "movies.csv"), index=False)
    np.save(file=os.path.join(data_dir, "item_attr"),
            arr=df_movies.drop(["itemId", "releaseDate"], axis=1).values)
    return df_movies


def _load_ratings(data_dir: str) -> pd.DataFrame:
    """Load ``u.data``, shift ids to start at 0, and derive the ``click`` label."""
    headers = ["userId", "itemId", "rating", "timestamp"]
    df_rating = pd.read_csv(os.path.join(data_dir, "u.data"),
                            sep='\t',
                            header=None,
                            names=headers,
                            encoding='latin-1')
    print(df_rating["rating"].value_counts(), df_rating.shape)
    df_rating["timestamp"] = pd.to_datetime(df_rating["timestamp"], unit="s")
    df_rating["userId"] = df_rating["userId"] - 1  # because the id starts from 1
    df_rating["itemId"] = df_rating["itemId"] - 1  # because the id starts from 1
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    df_rating["click"] = (df_rating["rating"] >= CLICK_CRITERION).astype(int)
    print(df_rating.head())
    print(df_rating["userId"].describe(), df_rating["itemId"].describe())
    return df_rating


def _split_by_time(df_rating: pd.DataFrame):
    """Split ratings at the ``DATA_SPLIT["offline"]`` timestamp quantile.

    Returns ``(df_offline, df_online)``: rows at or before the anchor date, and
    rows strictly after it. The two splits are disjoint.
    """
    offline_date = df_rating["timestamp"].quantile(DATA_SPLIT["offline"])
    # Strict '>' for the online side so rows stamped exactly at the anchor date
    # are not placed in both splits.
    df_offline = df_rating[df_rating["timestamp"] <= offline_date].copy()
    df_online = df_rating[df_rating["timestamp"] > offline_date].copy()
    print(df_offline.shape, df_online.shape)

    # Item coverage of each split.
    offline_itemIds = df_offline["itemId"].unique().tolist()
    test_itemIds = df_online["itemId"].unique().tolist()
    intersection = set(offline_itemIds) & set(test_itemIds)
    print("Unique movies| train: {} test: {}".format(len(offline_itemIds), len(test_itemIds)))
    print("Only movies in test: {}".format(len(test_itemIds) - len(intersection)))
    # Check that itemIds start from 0; skip an empty split (min/max of [] raises).
    for item_ids in (offline_itemIds, test_itemIds):
        if item_ids:
            print(min(item_ids), max(item_ids))
    return df_offline, df_online


def _save_id_maps(data_dir: str, ids: pd.Series, entity: str):
    """Write ``<entity>Id2index.pkl`` and ``index2<entity>Id.pkl``.

    ``index`` is the row position in the corresponding feature table:
    ``<entity>Id2index`` maps an id to its row, and ``index2<entity>Id`` maps
    a row back to its id.
    """
    id2index = {int(_id): index for index, _id in enumerate(ids)}
    index2id = {index: int(_id) for index, _id in enumerate(ids)}
    _dump_pickle(id2index, os.path.join(data_dir, entity + "Id2index.pkl"))
    _dump_pickle(index2id, os.path.join(data_dir, "index2" + entity + "Id.pkl"))


def main(args: dict):
    """Preprocess the raw MovieLens-100K files and split the ratings by time.

    Reads ``u.user``, ``u.item`` and ``u.data`` from ``args["data_dir"]`` and
    writes into that directory:

    * ``users.csv`` / ``user_attr.npy``, the encoded user table
      (``user_attr.npy`` excludes userId);
    * ``movies.csv`` / ``item_attr.npy``, the cleaned item table
      (``item_attr.npy`` excludes itemId and releaseDate);
    * ``userId2index.pkl`` / ``index2userId.pkl`` and the item counterparts,
      which map ids to and from row positions;
    * ``user_attr.pkl`` / ``item_attr.pkl``, holding the
      ``ML100K_USER_FEATURES`` / ``ML100K_ITEM_FEATURES`` columns.

    Args:
        args: configuration dict; only ``"data_dir"`` is read.

    Returns:
        ``(df_offline, df_online)``: the rating rows at or before the
        ``DATA_SPLIT["offline"]`` timestamp quantile, and the rows after it.
    """
    data_dir = args["data_dir"]
    df_users = _preprocess_users(data_dir)
    df_movies = _preprocess_movies(data_dir)
    df_rating = _load_ratings(data_dir)
    df_offline, df_online = _split_by_time(df_rating)

    # === Users ===
    _save_id_maps(data_dir, df_users["userId"], "user")
    _dump_pickle(df_users[ML100K_USER_FEATURES].values, os.path.join(data_dir, "user_attr.pkl"))

    # === Items ===
    _save_id_maps(data_dir, df_movies["itemId"], "item")
    _dump_pickle(df_movies[ML100K_ITEM_FEATURES].values, os.path.join(data_dir, "item_attr.pkl"))

    return df_offline, df_online


if __name__ == '__main__':
    args = {
        "seed": 2023,
        "how_to_split": "time",
        "data_dir": "./ml-100k"
    }

    set_randomSeed(seed=args["seed"])
    df_offline, df_online = main(args=args)

    # ======= GET history
    # Attach per-user history sequences to each split and write the final logs.
    # The returned masks (original row labels that were kept) are not needed here.
    df_offline, _ = get_history_seq(_df=df_offline, history_size=HISTORY_SIZE)
    print(df_offline.shape)
    df_offline.to_csv(os.path.join(args["data_dir"], "offline_log.csv"), index=False)

    df_online, _ = get_history_seq(_df=df_online, history_size=HISTORY_SIZE)
    print(df_online.shape)
    df_online.to_csv(os.path.join(args["data_dir"], "online_log.csv"), index=False)
