import sys

sys.path.insert(0, "../../../../")

import pickle
import argparse
import numpy as np
import pandas as pd

from ganrl.commons.multi_threading import RunInParallel

# Command-line interface: the single option controls into how many chunks
# the sessions are split for parallel preprocessing.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pp_num_threads",
    default=5,
    type=int,
    help="num of threads",
)
args = parser.parse_args()

# Load datasets and convert the data type
# NOTE(review): slates.pkl and ratings.csv are read from "../ml-latest-small/",
# but the feature CSVs come from "./ml-latest-small/" -- confirm that the two
# different directories are intentional.
# NOTE: pickle.load can execute arbitrary code; only use a trusted slates.pkl.
with open("../ml-latest-small/slates.pkl", "rb") as file:
    slates = pickle.load(file)  # num_of_session x slate_size


def _parse_feature_vector(text):
    """Turn a stringified array (e.g. "[0. 1. 0.]") back into a 1-D float array.

    Newlines and square brackets are dropped and double spaces collapsed
    before the remaining whitespace-separated numbers are parsed.
    """
    cleaned = text.replace('\n', '').replace('[', '').replace(']', '').replace('  ', ' ')
    return np.fromstring(cleaned, sep=' ')


df_ratings = pd.read_csv("../ml-latest-small/ratings.csv")
df_sparse = pd.read_csv("./ml-latest-small/sparse_feature.csv")  # num_movie x dim_sparse
df_sparse["sparse_f"] = df_sparse["sparse_f"].apply(_parse_feature_vector)
df_dense = pd.read_csv("./ml-latest-small/dense_feature.csv")  # num_movie x dim_dense
df_dense["dense_f"] = df_dense["dense_f"].apply(_parse_feature_vector)
print(df_ratings.shape, df_sparse.shape, df_dense.shape)
print(df_ratings.columns, df_sparse.columns, df_dense.columns)

# Raise instead of assert: the check must survive `python -O`, because the
# per-chunk worker zips ratings rows with slates and would silently truncate.
if slates.shape[0] != df_ratings.shape[0]:
    raise ValueError("slates: num_session x slate_size | df_ratings: num_session x some_cols")

# Split the slates with NumPy, then cut the ratings at exactly the same row
# boundaries. Slicing with iloc avoids np.array_split on a DataFrame, which
# relies on DataFrame.swapaxes (deprecated in pandas 2.1).
list_slates = np.array_split(slates, args.pp_num_threads)
chunk_bounds = np.cumsum([0] + [len(chunk) for chunk in list_slates])
list_df_ratings = [
    df_ratings.iloc[start:stop]
    for start, stop in zip(chunk_bounds[:-1], chunk_bounds[1:])
]


def _fn(_df, _slates, _sparse_col=None, _dense_col=None):
    """Build click-labelled feature rows for one chunk of sessions.

    Each session contributes one positive row (the clicked item, looked up
    positionally by its ``new_itemId``) followed by one negative row for
    every slate entry that differs from the session's ``itemId``.

    Args:
        _df: DataFrame chunk with ``new_itemId`` and ``itemId`` columns,
            one row per session.
        _slates: Iterable of slates aligned with the rows of ``_df``. Every
            slate entry is used as a positional index into the feature
            columns. Rows and slates are paired with ``zip``, so the longer
            of the two is truncated.
        _sparse_col: Optional Series of sparse feature vectors. Defaults to
            the module-level ``df_sparse["sparse_f"]``.
        _dense_col: Optional Series of dense feature vectors. Defaults to
            the module-level ``df_dense["dense_f"]``.

    Returns:
        Tuple ``(sparse_f, dense_f, label)`` of NumPy arrays with one entry
        per emitted item; ``label`` is 1 for the clicked item and 0 for the
        other slate items. An empty chunk yields three empty arrays.
    """
    sparse_col = df_sparse["sparse_f"] if _sparse_col is None else _sparse_col
    dense_col = df_dense["dense_f"] if _dense_col is None else _dense_col

    # Collect the positions to look up first, so the feature columns are
    # indexed once per chunk instead of once per emitted row.
    positions, label = [], []
    # Reading the two needed columns directly avoids the per-row Series that
    # iterrows() would build.
    for clicked_pos, clicked_id, slate in zip(_df["new_itemId"], _df["itemId"], _slates):
        # clicked movie's feature
        positions.append(int(clicked_pos))
        label.append(1)
        # NOTE(review): the clicked item is looked up via "new_itemId" but
        # slate entries are filtered against "itemId", although they are used
        # as positions just like "new_itemId" -- confirm which id space the
        # slates use.
        for candidate in slate:
            if candidate != clicked_id:
                positions.append(candidate)
                label.append(0)

    sparse_f = sparse_col.iloc[positions].tolist()
    dense_f = dense_col.iloc[positions].tolist()
    return np.asarray(sparse_f), np.asarray(dense_f), np.asarray(label)


from functools import partial

# One zero-argument task per chunk: every task binds its own slice of the
# ratings together with the matching slice of the slates.
fns = [
    partial(_fn, _df=ratings_chunk, _slates=slates_chunk)
    for ratings_chunk, slates_chunk in zip(list_df_ratings, list_slates)
]
result = RunInParallel(fns=fns)

# Merge the per-chunk outputs in sorted-key order.
if not result:
    raise ValueError("RunInParallel returned no results; nothing to merge")

sparse_parts, dense_parts, label_parts = [], [], []
rows_merged = 0
for k in sorted(result.keys()):
    _sparse_f, _dense_f, _label = result[k]
    rows_merged += len(_label)
    if len(_label) > 0:
        # An empty chunk comes back as shape-(0,) arrays, which cannot be
        # concatenated with the 2-D feature arrays of the other chunks, so
        # only non-empty chunks are kept for the merge.
        sparse_parts.append(_sparse_f)
        dense_parts.append(_dense_f)
        label_parts.append(_label)
    # Progress line: shapes of everything merged so far.
    print(k,
          (rows_merged,) + _sparse_f.shape[1:],
          (rows_merged,) + _dense_f.shape[1:],
          (rows_merged,) + _label.shape[1:])

if label_parts:
    # Concatenate once at the end instead of re-copying the accumulated
    # arrays for every chunk.
    sparse_f = np.concatenate(sparse_parts, axis=0)
    dense_f = np.concatenate(dense_parts, axis=0)
    label = np.concatenate(label_parts, axis=0)
else:
    # Every chunk was empty: fall back to the (empty) arrays of the last one.
    sparse_f, dense_f, label = _sparse_f, _dense_f, _label

print(np.asarray(sparse_f).shape, np.asarray(dense_f).shape, np.asarray(label).shape)

# Persist the three merged arrays, one pickle file each.
outputs = (
    ("./ml-latest-small/sparse_f.pkl", sparse_f),
    ("./ml-latest-small/dense_f.pkl", dense_f),
    ("./ml-latest-small/label.pkl", label),
)
for out_path, payload in outputs:
    with open(out_path, "wb") as file:
        pickle.dump(np.asarray(payload), file, protocol=pickle.HIGHEST_PROTOCOL)
