"""Replace each movie id with its index in the item embedding, for each session in MovieLens."""
import sys

sys.path.insert(0, "../")

import argparse
import numpy as np
import pandas as pd

from multi_threading import RunInParallel

# Command-line options. --pp_num_threads sets how many pieces the ratings
# table is split into later in this script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pp_num_threads",
    type=int,
    default=5,
    help="num of threads",
)
args = parser.parse_args()

# Map every raw movie id to its row position in movies.csv.
df_movies = pd.read_csv("./ml-latest-small/movies.csv")
id2idx = {v: k for k, v in enumerate(df_movies["itemId"])}

# Update the movie indices in the dataset
df_ratings = pd.read_csv("./ml-latest-small/ratings.csv")

# Split row *positions* instead of the DataFrame itself: np.array_split on a
# DataFrame goes through DataFrame.swapaxes, which pandas deprecated in 2.1.
# Chunk sizes and row order match np.array_split(df_ratings, n), and each
# chunk is an independent copy (iloc with an integer array), so adding a
# column to a chunk never writes through to df_ratings.
_row_chunks = np.array_split(np.arange(len(df_ratings)), args.pp_num_threads)
list_df_ratings = [df_ratings.iloc[rows] for rows in _row_chunks]


def _fn(_df):
    """Add a ``new_itemId`` column holding the ``id2idx`` index of each ``itemId``.

    Args:
        _df: DataFrame with an ``itemId`` column. It is modified in place.

    Returns:
        The same ``_df`` object, now carrying the integer ``new_itemId`` column.

    Raises:
        KeyError: if any ``itemId`` value is not a key of ``id2idx``.
    """
    # Series.map with a dict is a vectorised lookup. Ids absent from the dict
    # come back as NaN instead of raising, so they are checked explicitly to
    # keep the KeyError that a plain dict lookup would produce.
    new_ids = _df["itemId"].map(id2idx)
    unknown = new_ids.isna()
    if unknown.any():
        missing = _df.loc[unknown, "itemId"].unique().tolist()
        raise KeyError("itemId values missing from id2idx: {}".format(missing))
    # np.int was only an alias of the builtin int and was removed in
    # NumPy 1.24; the builtin keeps the same resulting dtype.
    _df["new_itemId"] = new_ids.astype(int)
    return _df


from functools import partial

# One zero-argument callable per ratings chunk, each bound to its own chunk.
fns = [partial(_fn, _df=chunk) for chunk in list_df_ratings]
# The returned object is read below through .keys() and item access.
result = RunInParallel(fns=fns)

# Reassemble the processed chunks in sorted-key order.
# NOTE(review): assumes sorting RunInParallel's keys reproduces the original
# chunk order (true for integer keys) -- confirm against multi_threading.
ordered_keys = sorted(result.keys())
if not ordered_keys:
    # Fail loudly instead of hitting an unbound `df` further down.
    raise RuntimeError("RunInParallel returned no results; ratings.csv left untouched")

chunks = []
rows_so_far = 0
for k in ordered_keys:
    chunks.append(result[k])
    rows_so_far += len(result[k])
    # Progress: key, shape of this chunk, cumulative number of rows.
    print(k, result[k].shape, rows_so_far)

# A single concat avoids re-copying the accumulated frame once per chunk.
df = pd.concat(chunks, ignore_index=True)

print(df)
# Overwrites the ratings file that was read as input, now including whatever
# columns the workers added.
df.to_csv("./ml-latest-small/ratings.csv", index=False)

import pickle

# Persist the itemId -> index lookup alongside the dataset files.
with open("./ml-latest-small/id2idx.pkl", mode="wb") as pkl_out:
    pickle.dump(id2idx, pkl_out, protocol=pickle.HIGHEST_PROTOCOL)
