"""Build one candidate slate per rating (treated as a session) in MovieLens.

    Ref: https://github.com/tensorflow/recommenders/blob/5712f07c8744d2e8e3cc9635f07229167fb8a1cb/tensorflow_recommenders/examples/movielens.py#L102
"""
import sys

# Make the repository root importable so `ganrl` resolves when this script is
# run from its own directory; this must run before the ganrl imports below.
sys.path.insert(0, "../../../../")

import argparse
import pickle
from random import sample, shuffle

import numpy as np
import pandas as pd

from ganrl.commons.multi_threading import RunInParallel
from ganrl.commons.args import PARAMS

parser = argparse.ArgumentParser()
parser.add_argument('--pp_num_threads', type=int, default=5, help='num of threads')
args = parser.parse_args()

# NOTE(review): ratings.csv is read from "../ml-latest-small" while id2idx.pkl
# and release_date.csv are read from "./ml-latest-small" — confirm both
# locations are intended.
df_ratings = pd.read_csv("../ml-latest-small/ratings.csv")

# Maps the `itemId` values of release_date.csv to the reindexed movie ids
# used inside the slates.
with open("./ml-latest-small/id2idx.pkl", "rb") as file:
    id2idx = pickle.load(file)

df_release_date = pd.read_csv("./ml-latest-small/release_date.csv")
# Ratings carry Unix timestamps in seconds; release dates are ISO dates.
df_ratings["timestamp"] = pd.to_datetime(df_ratings["timestamp"], unit="s")
df_release_date["release_date"] = pd.to_datetime(df_release_date["release_date"], format="%Y-%m-%d")
print(df_release_date["release_date"].describe())
print(df_ratings["timestamp"].describe())

# Module-level result holder for the final slate array.
slates = []

# One contiguous, near-equal chunk of rating rows per worker. Splitting the
# positional index (instead of calling np.array_split on the DataFrame) gives
# the same chunk boundaries without going through DataFrame.swapaxes, which
# pandas deprecated in 2.1.
list_df_ratings = [
    df_ratings.iloc[rows]
    for rows in np.array_split(np.arange(len(df_ratings)), args.pp_num_threads)
]


def _fn(_df):
    """Build one candidate slate per rating row in ``_df``.

    We assume the user was shown a list of randomly sampled movies that had
    already been released at the time of the rating. Each slate holds the
    rated ("clicked") movie plus ``PARAMS.GET_SLATE_SLATE_SIZE - 1`` other
    released movies, in shuffled order.

    Args:
        _df: A chunk of the ratings DataFrame. Only its ``timestamp``
            (datetime) and ``new_itemId`` columns are read.

    Returns:
        np.ndarray with ``len(_df)`` rows of ``PARAMS.GET_SLATE_SLATE_SIZE``
        reindexed movie ids each. An empty chunk yields a ``(0, slate_size)``
        array so it can still be concatenated with other chunks.

    Raises:
        ValueError: Raised by ``random.sample`` when fewer than
            ``GET_SLATE_SLATE_SIZE - 1`` other movies were released before a
            rating's timestamp.
    """
    slate_size = PARAMS.GET_SLATE_SLATE_SIZE
    release_item_ids = df_release_date["itemId"]
    release_dates = df_release_date["release_date"]

    # Collect into a LOCAL list. Appending to the module-level `slates` made
    # every call return the rows of all chunks processed so far, so the
    # concatenated output contained duplicated slates.
    chunk_slates = []
    for clicked_movie, timestamp in zip(_df["new_itemId"], _df["timestamp"]):
        # Movies already released at rating time, mapped into the new id space.
        released = release_item_ids[release_dates <= timestamp].values
        movies = [id2idx[item_id] for item_id in released]
        # Sample only the other movies; the clicked one is added back below.
        if clicked_movie in movies:
            movies.remove(clicked_movie)
        slate = sample(movies, k=slate_size - 1)
        slate.append(clicked_movie)
        # Shuffle so the clicked movie's position carries no signal.
        shuffle(slate)
        chunk_slates.append(slate)

    if not chunk_slates:
        # NOTE(review): assumes reindexed ids are integers — confirm id2idx values.
        return np.empty((0, slate_size), dtype=np.int64)
    return np.asarray(chunk_slates)


from functools import partial
import pickle

# One task per ratings chunk; `_df` is passed by keyword to match _fn's signature.
fns = [partial(_fn, _df=chunk) for chunk in list_df_ratings]
result = RunInParallel(fns=fns)

# Stitch the per-chunk slate arrays back together in sorted key order.
# NOTE(review): assumes RunInParallel returns a dict-like whose sorted keys
# follow the order of `fns` — confirm against its implementation.
keys = sorted(result.keys())
for k in keys:
    print(k, result[k].shape)
if keys:
    # slates: num_session x slate_size
    slates = np.concatenate([result[k] for k in keys], axis=0)
    print("slates", slates.shape)

# slates: num_of_session x slate_size
with open("../ml-latest-small/slates.pkl", "wb") as file:
    pickle.dump(slates, file, protocol=pickle.HIGHEST_PROTOCOL)
