#coding: utf-8

from math import ceil
import csv
import numpy as np
import subprocess as sb
from tqdm import tqdm
import os
import pickle
import pandas as pd
from scipy.sparse import csr_array, vstack, hstack
import tensorflow_hub as hub
from sklearn.preprocessing import normalize
from zero.svd import MangakiSVD
from stanscofi.datasets import Dataset
from stanscofi.utils import load_dataset
from copy import deepcopy

import sys
if __name__ == "__main__":
    sys.path.insert(0,"HAN/")
else:
    sys.path.insert(0,"known_setting/HAN/")
from bdivrec.fabaphe.utils import seed_everything, chunks

from HAN import HAN

class Environment(object):
    """Base class for recommendation environments stored in a directory.

    Item and user embeddings are read from ``<name>/items.csv`` and
    ``<name>/users.csv``; the history of user ``u`` is pickled in
    ``<name>/history_<u>.pkl``. Subclasses are expected to define ``nitem``
    and ``nuser`` and to implement :meth:`feedback` and
    :meth:`feedback_slice`.
    """

    def __init__(self, params):
        """Copy every entry of ``params`` onto the instance and prepare the folder.

        ``params`` must contain ``name`` (the environment directory).
        Optional keys handled here: ``nchunks`` (default 10000),
        ``quantize_digit`` (default -1, meaning no rounding) and ``new``
        (default False; when true the directory is emptied first).
        """
        assert "name" in params
        if ("nchunks" not in params):
            self.nchunks = 10000 ## tradeoff
        self.quantize_digit = params.get('quantize_digit', -1) ## quantization to avoid large memory usage
        self.new = False
        for param in params:
            setattr(self, param, params[param])
        self.item_file = self.name+"/items.csv"
        self.user_file = self.name+"/users.csv"
        self.user_history_file = self.name+"/history_%d.pkl"
        if (self.new):
            ## done in-process: splitting a shell command on spaces breaks
            ## (and may delete the wrong paths) when the name contains spaces
            import shutil
            if (os.path.isdir(self.name) and not os.path.islink(self.name)):
                shutil.rmtree(self.name, ignore_errors=True)
            elif (os.path.lexists(self.name)):
                os.remove(self.name)
        os.makedirs(self.name, exist_ok=True)

    def get_user_hist(self, user):
        """Return the stored history of ``user``, creating an empty one if missing."""
        path = self.user_history_file % user
        ## initialize lazily the user history
        if (not os.path.exists(path)):
            history = []
            with open(path, "wb") as f:
                pickle.dump([], f)
        else:
            ## NOTE: pickle must only be used on files written by this class
            with open(path, "rb") as f:
                history = pickle.load(f)
        return history

    def set_user_hist(self, user, S):
        """Overwrite the history of ``user`` with ``S``."""
        ## opening in "wb" mode truncates any previous content
        with open(self.user_history_file % user, "wb") as f:
            pickle.dump(S, f)

    def update_user_hist(self, user, S):
        """Append the items of ``S`` to the history of ``user`` without duplicates.

        Items keep their order of first appearance (stored history first,
        then ``S``).
        """
        path = self.user_history_file % user
        ## initialize lazily the user history
        if (not os.path.exists(path)):
            history = []
        else:
            with open(path, "rb") as f:
                history = pickle.load(f)
        ## unique items, order-preserving (a set would scramble the order)
        history = list(dict.fromkeys(list(history)+list(S)))
        with open(path, "wb") as f:
            pickle.dump(history, f)

    def _npy_cache(self, fname):
        """Return the ``.npy`` file to memory-map for ``fname``.

        A ``.csv`` file is converted once to a sibling ``.npy`` file; the
        conversion is redone when the CSV is newer than the cached array, so
        that a rewritten CSV is never served from a stale cache.
        """
        if fname.endswith(".csv"):
            fname_npy = fname[:-3] + "npy"
            stale = not os.path.exists(fname_npy) or (
                os.path.exists(fname)
                and os.path.getmtime(fname_npy) < os.path.getmtime(fname)
            )
            #TODO: Make this work for really large files
            if stale:
                np.save(fname_npy, np.loadtxt(fname, delimiter=",", ndmin=2))
            fname = fname_npy
        assert fname.endswith(".npy")
        return fname

    def _to_sparse(self, A, nrows_out):
        """Flatten extra dimensions, apply the quantization and return a CSR array."""
        if (len(A.shape)>2):
            A = A.reshape((nrows_out, -1))
        if (self.quantize_digit>=0):
            A = np.round(A, self.quantize_digit)
        return csr_array(A)

    def read_slice_large_file(self, fname, nrows, start_slice, end_slice):
        """Read rows ``start_slice`` (included) to ``end_slice`` (excluded) of ``fname``.

        ``fname`` is a ``.csv`` or ``.npy`` file; ``nrows`` is accepted for
        interface compatibility and is not used. Returns a ``csr_array``.
        Raises ``AssertionError`` when the slice is empty.
        """
        fname = self._npy_cache(fname)
        assert end_slice-start_slice>0
        A = np.load(fname, mmap_mode='r')[start_slice:end_slice]
        return self._to_sparse(A, end_slice-start_slice)

    def read_lines_large_file(self, fname, nrows, rows, verbose=False):
        """Read the rows of ``fname`` listed in ``rows``, in the given order.

        ``rows`` must not contain duplicates. ``nrows`` and ``verbose`` are
        accepted for interface compatibility and are not used. Returns a
        ``csr_array`` with one line per requested row.
        """
        fname = self._npy_cache(fname)
        assert len(rows)==len(set(rows))
        if (len(rows)==1):
            return self.read_slice_large_file(fname, nrows, rows[0], rows[0]+1)
        ## memory guard; a full chunk of ``nchunks`` rows must be accepted
        assert len(rows) <= max(10000, getattr(self, "nchunks", 10000))
        A = np.load(fname, mmap_mode='r')[list(rows)]
        return self._to_sparse(A, len(rows))

    def item_embs(self, item_ls):
        """Return the embeddings of the items in ``item_ls`` as a ``csr_array``."""
        return self.read_lines_large_file(self.item_file, self.nitem, item_ls)

    def user_embs(self, user_ls):
        """Return the embeddings of the users in ``user_ls`` as a ``csr_array``."""
        return self.read_lines_large_file(self.user_file, self.nuser, user_ls)

    def item_embs_slice(self, item_start, item_end):
        """Return the embeddings of items ``item_start`` to ``item_end`` (excluded)."""
        return self.read_slice_large_file(self.item_file, self.nitem, item_start, item_end)

    def user_embs_slice(self, user_start, user_end):
        """Return the embeddings of users ``user_start`` to ``user_end`` (excluded)."""
        return self.read_slice_large_file(self.user_file, self.nuser, user_start, user_end)

    def feedback(self, item, user):
        """Feedback of ``user`` on the items in ``item``; implemented by subclasses."""
        ## ``raise NotImplemented`` would fail with a TypeError
        raise NotImplementedError

    def feedback_slice(self, item_start, item_end, user):
        """Feedback of ``user`` on a contiguous item range; implemented by subclasses."""
        raise NotImplementedError

    def reset_user_hist(self):
        """Delete the stored history file of every user in ``range(self.nuser)``."""
        for user in range(self.nuser):
            path = self.user_history_file % user
            if (os.path.exists(path)):
                os.remove(path)

class TabularNoHistory(Environment):
    """Environment backed by explicit item, user and feedback matrices.

    ``params`` must provide ``items``, ``users`` and ``feedbacks``; the
    feedback matrix has one row per item and one column per user. The three
    matrices are written as CSV files in the environment directory.
    """

    def __init__(self, params):
        for required in ("items", "users", "feedbacks"):
            assert required in params
        super().__init__(params)
        ## feedbacks must be (number of items) x (number of users)
        assert self.feedbacks.shape[0] == self.items.shape[0]
        assert self.feedbacks.shape[1] == self.users.shape[0]
        self.feedback_file = self.name+"/feedbacks.csv"
        dumps = {
            self.item_file: self.items,
            self.user_file: self.users,
            self.feedback_file: self.feedbacks,
        }
        for path, matrix in dumps.items():
            np.savetxt(path, matrix, delimiter=",", newline="\n")
        self.nitem, self.d = self.items.shape
        self.nuser, self.d_user = self.users.shape

    def feedback(self, item, user):
        """Return the column of ``user`` restricted to the rows listed in ``item``."""
        selected_rows = self.read_lines_large_file(self.feedback_file, self.nitem, item)
        return selected_rows[:,[user]]

    def feedback_slice(self, item_start, item_end, user):
        """Return the column of ``user`` for rows ``item_start`` to ``item_end`` (excluded)."""
        selected_rows = self.read_slice_large_file(self.feedback_file, self.nitem, item_start, item_end)
        return selected_rows[:,[user]]

class SyntheticCosine(Environment):
    """Synthetic environment with randomly generated, L2-normalised embeddings.

    Items and users are generated chunk by chunk and appended to the CSV
    files of the environment. The feedback of a user on an item is
    ``(<item, user> + 1) / 2`` (see :meth:`score`).
    """

    def __init__(self, params):
        assert "nitem" in params ## number of items
        assert "nuser" in params ## number of users
        assert "d" in params ## dimension of item and user embeddings
        assert "name" in params ## name of the data set
        assert "seed" in params ## random seed
        if ("ngroup" not in params):
            self.ngroup = 3 ## number of groups of shifted copies
        if ("nvar" not in params):
            self.nvar = 0.01
        super().__init__(params)
        self.rng = np.random.default_rng(self.seed)
        self.d_user = self.d

        def item_gen_fun(n):
            ## ``ngroup`` shifted copies of the same Gaussian vectors.
            ## The base block is rounded up and the result truncated so that
            ## exactly ``n`` rows are produced even when ``n`` is not a
            ## multiple of ``ngroup`` (otherwise the file would hold fewer
            ## rows than ``nitem``). Identical to the plain ``n//ngroup``
            ## construction whenever ``n`` is divisible by ``ngroup``.
            n_base = -(-n // self.ngroup)
            base = self.rng.normal(0, 2, size=(n_base, self.d))
            Phi = np.concatenate([base + self.nvar*i for i in range(self.ngroup)])
            return Phi[:n]

        def user_gen_fun(n):
            return self.rng.normal(0, 1, size=(n, self.d))

        ## existing files are reused as is
        if (not os.path.exists(self.item_file)):
            self.nitem = self.add_elements(self.item_file, params["nitem"], item_gen_fun)
        if (not os.path.exists(self.user_file)):
            self.nuser = self.add_elements(self.user_file, params["nuser"], user_gen_fun)

    def _progress_chunks(self, N, label, verbose):
        """Yield the index chunks of ``range(N)`` while updating a progress bar."""
        pbar = tqdm(
            chunks(N, self.nchunks),
            position=4,
            leave=False,
            disable=not verbose or ceil(N/self.nchunks) < 2
        )
        for ii, lst in enumerate(pbar):
            pbar.set_description(f"{label} {min((ii+1)*self.nchunks,N)}/{N}")
            yield lst

    def _sparse_scores(self, Phi, h):
        """Score ``Phi`` against ``h``, apply the quantization, return a ``csr_array``."""
        qs = self.score(Phi, h)
        if (self.quantize_digit>=0):
            qs = np.round(qs, self.quantize_digit)
        return csr_array(qs)

    @staticmethod
    def _stack(parts):
        """Stack per-chunk results; ``None`` when there was no chunk at all."""
        if (len(parts)==0):
            return None
        if (len(parts)==1):
            return parts[0]
        ## a single stack avoids re-copying the accumulated result per chunk
        return vstack(parts)

    def feedback(self, item, user, verbose=False):
        """Return the feedback of ``user`` on the items listed in ``item``.

        The result has one row per entry of ``item``; ``None`` is returned
        when no chunk is produced.
        """
        N = len(item)
        h = self.user_embs([user])
        parts = []
        for item_lst in self._progress_chunks(N, "Computing feedback", verbose):
            Phi = self.item_embs([item[i] for i in item_lst])
            parts.append(self._sparse_scores(Phi, h))
        return self._stack(parts)

    def feedback_fast(self, Phi, user):
        """Return the feedback of ``user`` on already loaded item embeddings ``Phi``."""
        h = self.user_embs([user])
        return self._sparse_scores(Phi, h)

    def feedback_slice(self, item_start, item_end, user, verbose=False):
        """Return the feedback of ``user`` on a range of items, chunk by chunk.

        NOTE(review): chunks are read from the start of the item file
        (``item_start`` only enters through the number of items), exactly as
        before this rewrite -- confirm that callers always pass
        ``item_start=0``.
        """
        N = item_end-item_start
        h = self.user_embs([user])
        parts = []
        for item_lst in self._progress_chunks(N, "Computing feedback (slice)", verbose):
            Phi = self.item_embs_slice(item_lst[0], item_lst[-1]+1)
            parts.append(self._sparse_scores(Phi, h))
        return self._stack(parts)

    def add_elements(self, fname, N, gen_fun, verbose=False):
        """Append ``N`` generated, L2-normalised rows to ``fname``.

        ``gen_fun(n)`` must return a matrix with ``n`` rows. Returns the
        number of rows actually written.
        """
        written = 0
        with open(fname, "a") as f:
            for lst in self._progress_chunks(N, "Building elements", verbose):
                Mat = normalize(gen_fun(len(lst)), norm="l2", axis=1)
                np.savetxt(f, Mat, delimiter=",", newline="\n")
                written += Mat.shape[0]
        return written

    def score(self, Phi, h):
        """Map the inner products ``Phi @ h.T`` affinely through ``(x+1)/2``."""
        return (((Phi @ h.T).toarray()+1)/2)

class MovieLens(Environment):
    """Environment built from MovieLens files.

    Item embeddings are sentence embeddings of ``title + genres``; feedback
    comes from a ``MangakiSVD`` model fitted on the ratings at construction
    time.

    NOTE(review): items are addressed through ``self.item_indices[i-1]``
    (movie ids in order of first appearance in the ratings) while the stored
    histories use the ``item`` column (rank among sorted movie ids) --
    confirm that these two numberings are meant to differ.
    """

    def __init__(self, params):
        assert "movielens_filepath" in params ## path to MovieLens files
        assert "seed" in params
        ## item_file, user_file and user_history_file are set by the base class
        super().__init__(params)
        self.rng = np.random.default_rng(self.seed)
        self.model = None
        if (not os.path.exists(self.item_file)):
            self.compute_item_embs()
        self.compute_user_hist_score()
        self.reset_user_hist()
        ## NOTE(review): the item dimension is exposed as ``ditem`` and ``d``
        ## is cleared -- confirm this is intended
        self.d = None
        ratings = pd.read_csv(self.user_file)
        self.item_indices = ratings['movieId'].unique()
        embeddings = pd.read_csv(self.item_file, index_col=0)
        self.nitem, self.ditem = embeddings.shape

    def compute_item_embs(self):
        """Embed every rated movie and write the normalised vectors to the item file."""
        movies = pd.read_csv(f'{self.movielens_filepath}/movies.csv')
        links = pd.read_csv(f'{self.movielens_filepath}/links.csv')
        links['url'] = links['imdbId'].map(lambda x: f'https://www.imdb.com/title/tt{x:07d}/')
        movies["url"] = links.loc[movies.index]["url"]
        movies["genres"] = [",".join(g.split("|")) for g in movies["genres"]]
        db = links.merge(movies, on='url')
        db['content'] = db['title'] + ' Keywords: ' + db['genres']
        db = db.dropna(subset=['content'])
        embed = hub.load("https://www.kaggle.com/models/google/universal-sentence-encoder/TensorFlow2/universal-sentence-encoder/2")
        embeddings = embed(db['content'].tolist())
        embeddings = pd.DataFrame(embeddings, index=db['movieId_x'].unique(), columns=range(embeddings.shape[1]))
        ratings = pd.read_csv(f'{self.movielens_filepath}/ratings.csv')
        self.item_indices = ratings['movieId'].unique()
        ## keep only rated movies, in order of first appearance in the ratings
        embeddings = embeddings.loc[self.item_indices]
        Mat = pd.DataFrame(normalize(embeddings.values, norm="l2", axis=1), index=embeddings.index, columns=embeddings.columns)
        Mat.to_csv(self.item_file)
        self.nitem, self.d = embeddings.shape

    def compute_user_hist_score(self):
        """Index users and items, write the ratings to the user file, fit the SVD model."""
        ratings = pd.read_csv(f'{self.movielens_filepath}/ratings.csv')
        ## contiguous indices: rank of the id among the sorted unique ids
        ratings['user'] = np.unique(ratings['userId'], return_inverse=True)[1]
        ratings['item'] = np.unique(ratings['movieId'], return_inverse=True)[1]
        ratings.to_csv(self.user_file)
        self.model = MangakiSVD()
        self.model.nb_users = ratings['user'].nunique()
        self.model.nb_works = ratings['item'].nunique()
        self.model.fit(ratings[['user', 'item']].values, ratings['rating'])
        self.nuser = self.model.nb_users
        self.d_user = 1

    def reset_user_hist(self):
        """Set the history of every user to the items they rated."""
        ratings = pd.read_csv(self.user_file)
        ## one pass over the table instead of one full scan per user;
        ## sort=False keeps users in order of first appearance
        for user, rated in ratings.groupby("user", sort=False)["item"]:
            self.set_user_hist(user, rated.tolist())

    def feedback(self, item, user):
        """Return one feedback value per entry of ``item`` for ``user``.

        The value is the model prediction when the user rated the movie and
        1 otherwise. Returns ``None`` when ``item`` is empty.
        """
        ratings = pd.read_csv(self.user_file, index_col=0)
        ## the rows of the user do not depend on the item: filter them once
        user_rows = ratings[ratings["user"] == user]
        parts = []
        for i in item:
            qs = None
            if (user_rows.shape[0]>0):
                ii_id = self.item_indices[i-1]
                hits = user_rows[user_rows["movieId"] == ii_id]
                if (hits.shape[0]>0):
                    pairs = hits[['user', 'item']].values
                    qs = csr_array(self.model.predict(pairs).reshape(-1, 1))
            if (qs is None):
                ## unknown user or unrated movie
                qs = csr_array(np.ones((1, 1)))
            parts.append(qs)
        if (len(parts)==0):
            return None
        if (len(parts)==1):
            return parts[0]
        return vstack(parts)

    def feedback_slice(self, item_start, item_end, user):
        """Feedback of ``user`` on items ``item_start`` to ``item_end`` (excluded)."""
        return self.feedback(range(item_start, item_end), user)

    def item_embs(self, item_ls):
        """Return the embeddings of the items in ``item_ls`` as a ``csr_array``."""
        embeddings = pd.read_csv(self.item_file, index_col=0)
        ids = [self.item_indices[i-1] for i in item_ls]
        return csr_array(embeddings.loc[ids].values)

    def item_embs_slice(self, item_start, item_end):
        """Embeddings of items ``item_start`` to ``item_end`` (excluded)."""
        return self.item_embs(range(item_start, item_end))

class DrugRepurposing(Environment):
    """Drug repurposing environment: drugs are items, diseases are users.

    The data come either from CSV files under ``filepath`` or from a named
    data set loaded with ``load_dataset``. Feedback values are the scores
    predicted by a ``HAN`` model fitted on the rating matrix.
    """

    def __init__(self, params):
        assert ("filepath" in params) or ("dataset_name" in params) ## path to data set files
        assert ("dataset_name" not in params) or (params["dataset_name"] in ["Gottlieb", "Cdataset", "DNdataset", "LRSSL",
    "PREDICT_Gottlieb", "TRANSCRIPT", "PREDICT"])
        assert "seed" in params
        ## defaults of the HAN model; entries of ``params`` override them
        def_params = { "k": 15, "learning_rate": 1e-3, "epoch": 1000,
                "weight_decay": 0.0, "decision_threshold": 0
        }
        for a in def_params:
            setattr(self, a, def_params[a])
        ## item_file, user_file and user_history_file are set by the base class
        super().__init__(params)
        self.rng = np.random.default_rng(self.seed)
        self.rating_file = self.name+"/ratings.csv"
        self.score_file = self.name+"/scores.csv"
        if ("dataset_name" in params):
            data_args = load_dataset(self.dataset_name, save_folder="/".join(self.name.split("/")[:-1]))
        else:
            data_args = {}
        ## each file is only computed when missing
        if (not os.path.exists(self.item_file)):
            self.compute_item_embs(data_args)
        if (not os.path.exists(self.user_file)):
            self.compute_user_embs(data_args)
        if (not os.path.exists(self.score_file)):
            self.compute_embs_hist_score(data_args)
        self.reset_user_hist()
        drugs = pd.read_csv(self.item_file,index_col=0)
        self.item_list = list(drugs.index)
        diseases = pd.read_csv(self.user_file,index_col=0)
        self.nuser, self.d_user = diseases.shape
        self.nitem, self.d = drugs.shape

    def compute_item_embs(self, data_args):
        """Clean the drug feature matrix and write its transpose to the item file."""
        if (len(data_args)==0):
            drugs = pd.read_csv(f'{self.filepath}/items.csv',index_col=0)
        else:
            drugs = data_args["items"]
        ## infinite and missing values become 0
        drugs.replace([np.inf, -np.inf], np.nan, inplace=True)
        drugs = drugs.fillna(0.)
        self.item_list = list(drugs.columns)
        self.d, self.nitem = drugs.shape
        ## NOTE(review): normalisation runs along the rows of the
        ## (features x drugs) matrix, i.e. before the transpose -- confirm
        ## that per-feature (not per-drug) normalisation is intended
        drugs = pd.DataFrame(normalize(drugs.values, norm="l2", axis=1), index=drugs.index, columns=drugs.columns)
        drugs.T.to_csv(self.item_file)

    def compute_user_embs(self, data_args):
        """Clean the disease feature matrix and write its transpose to the user file."""
        if (len(data_args)==0):
            diseases = pd.read_csv(f'{self.filepath}/users.csv',index_col=0)
        else:
            diseases = data_args["users"]
        diseases.replace([np.inf, -np.inf], np.nan, inplace=True)
        diseases = diseases.fillna(0.)
        self.d_user, self.nuser = diseases.shape
        diseases = pd.DataFrame(normalize(diseases.values, norm="l2", axis=1), index=diseases.index, columns=diseases.columns)
        diseases.T.to_csv(self.user_file)

    def compute_embs_hist_score(self, data_args):
        """Save the rating matrix, fit a HAN model and save its predicted scores."""
        if (len(data_args)==0):
            ratings = pd.read_csv(f'{self.filepath}/ratings_mat.csv', index_col=0)
            drugs = pd.read_csv(f'{self.filepath}/items.csv',index_col=0)
            diseases = pd.read_csv(f'{self.filepath}/users.csv',index_col=0)
            data_args = {"items": drugs, "users": diseases, "ratings": ratings}
        else:
            ratings = data_args["ratings"]
        ratings.to_csv(self.rating_file)
        dataset = Dataset(**data_args)
        params = { "k": self.k, "learning_rate": self.learning_rate,
            "epoch": self.epoch, "weight_decay":self.weight_decay,
            "decision_threshold": self.decision_threshold,
            "seed": self.seed
        }
        model = HAN(params)
        model.fit(dataset)
        scores = model.predict_proba(dataset)
        scores = pd.DataFrame(scores.toarray(), index=ratings.index, columns=ratings.columns)
        scores.to_csv(self.score_file)

    def reset_user_hist(self):
        """Reset every user's history to the items with a nonzero rating.

        Users without any nonzero rating get their stored history removed,
        so that a history left by a previous session does not survive.
        """
        ratings = pd.read_csv(self.rating_file, index_col=0)
        for uu, user in enumerate(ratings.columns):
            H = np.argwhere(ratings[user].values.ravel() != 0).ravel().tolist()
            if (len(H)>0):
                self.set_user_hist(uu, H)
            else:
                stale = self.user_history_file % uu
                if (os.path.exists(stale)):
                    os.remove(stale)

    def feedback(self, item, user):
        """Return the predicted scores of ``user`` for the item positions in ``item``."""
        scores = pd.read_csv(self.score_file, index_col=0)
        ## purely positional lookup (rows = items, columns = users)
        picked = scores.iloc[list(item), [user]].values.reshape((len(item),1))
        return csr_array(picked)

    def feedback_slice(self, item_start, item_end, user):
        """Predicted scores of ``user`` for items ``item_start`` to ``item_end`` (excluded)."""
        return self.feedback(range(item_start, item_end), user)

    def item_embs(self, item_ls):
        """Return the embeddings at item positions ``item_ls`` as a ``csr_array``."""
        embeddings = pd.read_csv(self.item_file, index_col=0)
        return csr_array(embeddings.iloc[list(item_ls)].values)

    def item_embs_slice(self, item_start, item_end):
        """Embeddings of items ``item_start`` to ``item_end`` (excluded)."""
        return self.item_embs(range(item_start, item_end))

    def user_embs(self, user_ls):
        """Return the embeddings at user positions ``user_ls`` as a ``csr_array``."""
        embeddings = pd.read_csv(self.user_file, index_col=0)
        return csr_array(embeddings.iloc[list(user_ls)].values)

    def user_embs_slice(self, user_start, user_end):
        """Embeddings of users ``user_start`` to ``user_end`` (excluded)."""
        return self.user_embs(range(user_start, user_end))

## Smoke tests of the environments. They rely on data folders located at the
## relative paths used below and write into "../../datasets/".
if __name__ == "__main__":
    from time import time

    ###### TRANSCRIPT DATA SET
    ## environment built from a named data set (loaded with load_dataset)
    env = DrugRepurposing(dict(name="../../dw_datasets/TRANSCRIPT", dataset_name="TRANSCRIPT", seed=1234))
    print("TRANSCRIPT (drug repurposing)")
    scores = env.feedback([0], 0)
    print(scores.shape)
    scores = env.feedback_slice(0, env.nitem, 0)
    print(scores.shape)
    embs = env.item_embs([0,10,2,30])
    print(embs.shape)
    embs2 = env.item_embs_slice(0, env.nitem)
    print(embs2.shape)

    ###### MOVIELENS DATA SET
    movielens_filepath = '../../dw_datasets/ml-latest-small/'
    assert os.path.exists(movielens_filepath)
    env = MovieLens(dict(name="../../datasets/MovieLens", movielens_filepath=movielens_filepath, seed=1234))
    print("MovieLens (movie recommendation)")
    scores = env.feedback([0], 0)
    print(scores.shape) ## (1,1)
    scores = env.feedback_slice(0, env.nitem, 0)
    print(scores.shape) ## (9724, 1)
    embs = env.item_embs([0,10,2,30])
    print(embs.shape) ## (4, 512)
    embs2 = env.item_embs_slice(0, env.nitem)
    print(embs2.shape) ## (9724, 512)

    ###### PREDICT DATA SET
    ## environment built from CSV files found under ``filepath``
    predict_filepath = "../../dw_datasets/PREDICT_private/"
    env = DrugRepurposing(dict(name="../../datasets/PREDICT", epoch=100, filepath=predict_filepath, seed=1234))
    print("PREDICT (drug repurposing)")
    scores = env.feedback([0], 0)
    print(scores.shape)
    scores = env.feedback_slice(0, env.nitem, 0)
    print(scores.shape)
    embs = env.item_embs([0,10,2,30])
    print(embs.shape)
    embs2 = env.item_embs_slice(0, env.nitem)
    print(embs2.shape)

    ###### TABULAR DATA SET (3 items, 1 user, explicit feedback matrix)
    q=3
    folder_name = "../../datasets/Tabular_test_data"
    items = np.array([[1,0,1], [-1,0,-1], [0,1,0]])
    users = np.array([[1,0,1]])
    feedbacks = np.array([[1], [-1], [0]])
    env = TabularNoHistory(dict(name=folder_name, items=items, users=users, feedbacks=feedbacks, seed=1234, quantize_digit=q, new=True))
    ## the feedback must follow the order of the requested items
    val = env.feedback([0,2], 0).toarray().ravel()
    assert (val == np.array([[1, 0]])).all()
    val = env.feedback_slice(0, 1+1, 0).toarray().ravel()
    assert (val == np.array([[1, -1]])).all()
    val = env.feedback([2,0], 0).toarray().ravel()
    assert (val == np.array([[0, 1]])).all()

    ###### SYNTHETIC DATA SET
    start_T = time()
    q=3
    folder_name = "../../datasets/Synthetic_test_data"
    ## reference reader: extracts line ``idx`` (0-based) of a CSV file with
    ## head/tail, independently of the Environment readers
    def get_line(fname, idx):
        rd = sb.check_output(f"head -n{idx+1} {folder_name}/{fname}.csv | tail -n1", shell=True)
        clean_rd = rd.decode("utf-8").split("\n")[0].split(",")
        a = np.array([clean_rd],dtype=float)
        a = csr_array(a)
        if (q>=0):
            a = np.round(a, q)
        return a
    ## reference reader for several lines, stacked in the order of ``idx_lst``
    def get_lines(fname, idx_lst):
        #return np.concatenate([get_line(fname, idx) for idx in idx_lst], axis=0)
        #return vstack(tuple([get_line(fname, idx) for idx in idx_lst]))
        A = None
        for idx in idx_lst:
            a = get_line(fname, idx)
            if (A is None):
                A = a.copy()
            else:
                A = vstack((A, a))
        return A
    ## change d, quantize_digit
    env = SyntheticCosine(dict(name=folder_name, nitem=1000000, nuser=100, d=10, seed=1234, quantize_digit=q, new=True))
    ## single row: item then user embeddings must match the reference reader
    idx = 10
    out = get_lines("items", [idx])
    out1 = env.item_embs([idx])
    #print((out.shape, out1.shape))
    #print(vstack((out, out1)).T.toarray())
    assert np.isclose(np.abs(out-out1).sum(),0)
    out = get_lines("users", [idx])
    out1 = env.user_embs([idx])
    assert np.isclose(np.abs(out-out1).sum(),0)
    ## several unsorted rows
    idx_lst = [99094,3242,5000]
    out = get_lines("items", idx_lst)
    out1 = env.item_embs(idx_lst)
    assert np.isclose(np.abs(out-out1).sum(),0)
    idx_lst = [99,0,50]
    out = get_lines("users", idx_lst)
    out1 = env.user_embs(idx_lst)
    assert np.isclose(np.abs(out-out1).sum(),0)
    ## contiguous slices (bounds included on the reference side)
    idx_lst = [5000, 5500]
    out = get_lines("items", range(idx_lst[0],idx_lst[1]+1,1))
    out1 = env.item_embs_slice(idx_lst[0], idx_lst[1]+1)
    #print((out.shape, out1.shape))
    assert np.isclose(np.abs(out-out1).sum(),0)
    idx_lst = [10, 20]
    out = get_lines("users", range(idx_lst[0],idx_lst[1]+1,1))
    out1 = env.user_embs_slice(idx_lst[0], idx_lst[1]+1)
    #print((out.shape, out1.shape))
    assert np.isclose(np.abs(out-out1).sum(),0)
    ## feedback on a list of items must agree with the feedback on one item
    idx_lst = [99094,3242,5000]
    user = 0
    ls = env.feedback(idx_lst, user)
    assert ls.shape==(len(idx_lst),1)
    out = env.feedback([5000], user)
    out1 = ls[[-1],:]
    assert np.isclose(np.abs(out-out1).sum(),0)
    print(f"Time = {np.round(time()-start_T, 3)} sec")
    ## ... and with the feedback computed on the whole range of items
    lss = env.feedback_slice(0, env.nitem, user)
    assert lss.shape==(env.nitem,1)
    out2 = lss[[5000],:]
    assert np.isclose(np.abs(out-out2).sum(),0)
    print(f"Time = {np.round(time()-start_T, 3)} sec")
    ## user histories: empty at first, then equal to what was added
    user = 0
    H = env.get_user_hist(user)
    assert len(H)==0
    H =[5000, 5500]
    env.update_user_hist(user, H)
    H1 = env.get_user_hist(user)
    assert len(H)==len(H1)
    assert all([H[i]==H1[i] for i in range(len(H))])
    user = 1
    H = env.get_user_hist(user)
    assert len(H)==0
    ## clean up the synthetic test folder
    proc = sb.Popen(f"rm -rf {folder_name}".split(" "))
    proc.wait()
