"""Build the reward-model training data (X, Y) for the new MovieLens (ml-latest-small) dataset."""
import sys

sys.path.insert(0, "../../../../")

import pickle
import argparse
import numpy as np
import pandas as pd

from ganrl.encoder.obs_encoder import BasicObsEncoder
from ganrl.commons.multi_threading import RunInParallel
from ganrl.commons.args import PARAMS

parser = argparse.ArgumentParser()
parser.add_argument('--pp_num_threads', type=int, default=5, help='num of threads')
args = parser.parse_args()

df_ratings = pd.read_csv("../ml-latest-small/ratings.csv")  # num_session x some_data
obs_encoder = BasicObsEncoder(device="cpu")

""" === begin: Preprocess the dataset === """
# History columns "t-1" ... "t-H" (H = PARAMS.GET_DATASET_REWARD_MODEL_HISTORY_SIZE):
# column "t-i" holds the new_itemId the same user interacted with i steps
# earlier, ordering each user's rows by timestamp.
col_names = ["t-{}".format(i) for i in range(1, PARAMS.GET_DATASET_REWARD_MODEL_HISTORY_SIZE + 1)]

# Sort and group once and reuse the grouping for every lag. shift() returns a
# Series keyed by the original row labels, so each assignment aligns back onto
# df_ratings regardless of the temporary sort order.
per_user_items = df_ratings.sort_values(by="timestamp").groupby("userId")["new_itemId"]
for lag, col_name in enumerate(col_names, start=1):
    df_ratings[col_name] = per_user_items.shift(periods=lag)

# Rows with any missing value (in particular, rows without a full history) are
# dropped. Record their *positions*, not index labels, because the caller
# removes the matching entries of the positionally-aligned slates array with
# np.delete.
missing_mask = df_ratings.isna().any(axis=1).to_numpy()
drop_index = np.flatnonzero(missing_mask)
df_ratings = df_ratings.dropna()
# ``np.int`` (an alias of the builtin ``int``) was removed in NumPy 1.24;
# ``int`` is the exact replacement.
df_ratings[col_names] = df_ratings[col_names].astype(int)
print(df_ratings.sort_values(by="timestamp").head())
""" === end: Preprocess the dataset === """

# NOTE: pickle.load executes arbitrary code from the file; these are assumed to
# be locally generated dataset artifacts, not untrusted input.
with open("../ml-latest-small/item_embedding.pkl", "rb") as file:
    item_embeddings = pickle.load(file)  # num_movie x dim_movie

with open("../ml-latest-small/slates.pkl", "rb") as file:
    slates = pickle.load(file)  # num_session x slate_size

# Drop the slates of the sessions removed during preprocessing so that slates
# stay aligned row-for-row with df_ratings.
slates = np.delete(slates, drop_index, axis=0)
print(df_ratings.shape, slates.shape)
# Rows and slates are later paired with zip(), which silently truncates on a
# length mismatch; fail loudly here instead of producing misaligned samples.
if len(slates) != len(df_ratings):
    raise ValueError(
        "slates ({}) and preprocessed ratings ({}) are misaligned".format(len(slates), len(df_ratings))
    )

# Split row positions into pp_num_threads contiguous chunks and apply the same
# positions to both containers. This avoids calling np.array_split directly on
# a DataFrame, which goes through the deprecated DataFrame.swapaxes in recent
# pandas, and guarantees identical chunk extents for ratings and slates.
chunk_positions = np.array_split(np.arange(len(df_ratings)), args.pp_num_threads)
list_df_ratings = [df_ratings.iloc[positions] for positions in chunk_positions]
list_slates = [slates[positions] for positions in chunk_positions]


def _fn(_df, _slates):
    """Build reward-model samples for one chunk of sessions.

    Rows of ``_df`` and entries of ``_slates`` are paired positionally. For
    each pair:

    * the history item ids in the ``col_names`` columns are looked up in
      ``item_embeddings`` and encoded into a state with ``obs_encoder``;
    * the clicked item (``new_itemId``) yields one sample labelled 1;
    * every other item in the slate yields one sample labelled 0.

    Each feature vector is ``state + item_embedding`` (Python list
    concatenation of the two ``tolist()`` results).

    Args:
        _df: chunk of the preprocessed ratings frame; must contain
            ``new_itemId`` and the ``col_names`` history columns.
        _slates: slates aligned row-for-row with ``_df``.

    Returns:
        Tuple ``(X, Y)`` of numpy arrays built from the feature lists and
        the 0/1 labels.

    Raises:
        ValueError: if ``_df`` and ``_slates`` differ in length (zip would
            otherwise silently drop the surplus).
    """
    if len(_df) != len(_slates):
        raise ValueError("got {} rating rows but {} slates".format(len(_df), len(_slates)))

    X, Y = list(), list()
    # Extract the needed columns once. iterrows() would build a Series per row
    # and upcast every value to a common dtype. ``np.int`` was removed in
    # NumPy 1.24, and builtin ``int`` is its exact replacement.
    history_ids = _df[col_names].to_numpy().astype(int)  # num_rows x history_size
    clicked_ids = _df["new_itemId"].to_numpy().astype(int)  # num_rows
    for history, clicked, slate in zip(history_ids, clicked_ids, _slates):
        # get the state for this session
        history_seq = item_embeddings[history]  # history_size x dim_item
        history_seq = np.expand_dims(history_seq, axis=0)  # 1 x history_size x dim_item
        state = obs_encoder.encode(obs=history_seq).tolist()

        # clicked movie's feature -> positive sample
        X.append(state + item_embeddings[clicked].tolist())
        Y.append(1)

        # every other movie shown in the slate -> negative sample
        for item_id in slate:
            if item_id != clicked:
                X.append(state + item_embeddings[item_id].tolist())
                Y.append(0)
    return np.asarray(X), np.asarray(Y)


from functools import partial

# One job per chunk; each job builds the samples for its slice of sessions.
fns = [
    partial(_fn, _df=df_chunk, _slates=slate_chunk)
    for df_chunk, slate_chunk in zip(list_df_ratings, list_slates)
]
result = RunInParallel(fns=fns)

# Gather the per-job outputs in key order and concatenate once at the end.
# Concatenating inside the loop re-copies the growing array on every step.
# A chunk with no sessions comes back as a shape-(0,) array, which cannot be
# stacked with the 2-D feature matrices, so such chunks are skipped.
X_parts, Y_parts = list(), list()
for k in sorted(result.keys()):
    _X, _Y = result[k]
    print(k, _X.shape, _Y.shape)
    if _X.size == 0:
        continue
    X_parts.append(_X)
    Y_parts.append(_Y)

if not X_parts:
    raise RuntimeError("No reward-model samples were produced; nothing to save.")
X = np.concatenate(X_parts, axis=0)
Y = np.concatenate(Y_parts, axis=0)
print(X.shape, Y.shape)

with open("../ml-latest-small/reward_model_X.pkl", "wb") as file:
    pickle.dump(X, file, protocol=pickle.HIGHEST_PROTOCOL)

with open("../ml-latest-small/reward_model_Y.pkl", "wb") as file:
    pickle.dump(Y, file, protocol=pickle.HIGHEST_PROTOCOL)
