#!/usr/bin/env python
# coding: utf-8

# In[43]:
import pandas as pd
import numpy as np
import os
import argparse
from scipy.stats import percentileofscore, rankdata
# In[51]:


parser = argparse.ArgumentParser(
    description='Build user/item popularity-bin features from the w_w and c_c training splits')
parser.add_argument('--dataset', type=str, default='douban')
parser.add_argument('--cold_percent', type=int, default=50,
                    help="cold split percentage; any combination other than 50/0.6 reads "
                         "'/home/data/<dataset>_<cold_percent>_<train_percent>'")
parser.add_argument('--train_percent', type=float, default=0.6,
                    help="training fraction; any combination other than 50/0.6 reads "
                         "'/home/data/<dataset>_<cold_percent>_<train_percent>'")
parser.add_argument('--k', type=int, default=100,
                    help="suffix for the output .npy filenames ('_<k>', omitted when k == 10)")
args = parser.parse_args()
k = args.k

dataset = args.dataset

# Directory holding the full c_c.csv, which is only used to size the feature tables.
if 'yelp_AZ' in dataset:
    alter_data_path = '/home/data/{}'.format('yelp_AZ')
elif 'movie' in dataset:
    alter_data_path = '/home/data/{}'.format('movie_lens')
else:
    alter_data_path = '/home/data/{}'.format(dataset.split('_')[0])

# Ids are used directly as row indices later, hence max id + 1 rows.
cc_df = pd.read_csv(os.path.join(alter_data_path, 'c_c.csv'))
n_users = cc_df.userID.max() + 1
n_items = cc_df.itemID.max() + 1
print(n_users, n_items)

# Non-default split parameters live in a suffixed sibling directory.
if not (args.cold_percent == 50 and args.train_percent == 0.6):
    dataset = '{}_{}_{}'.format(dataset, args.cold_percent, args.train_percent)
dataset_path = '/home/data/{}'.format(dataset)
train_df = pd.read_csv(os.path.join(dataset_path, 'w_w_train.csv'))
test_df = pd.read_csv(os.path.join(dataset_path, 'c_c_train.csv'))
output_path = dataset_path


n_feature = 100


# In[ ]:


def cut(df, n_bins, key, bins=None):
    """Return a copy of ``df`` with a categorical ``bin`` column for ``df[key]``.

    When ``bins`` is given it is passed to ``pd.cut`` as the bin specification;
    otherwise ``n_bins`` equal-width bins are used, with duplicate edges dropped.
    The input frame is not modified.
    """
    column = df[key]
    categories = (
        pd.cut(column, bins)
        if bins is not None
        else pd.cut(column, n_bins, duplicates='drop')
    )
    result = df.copy()
    result['bin'] = categories
    return result

def bin_data(df, key, n_bins, bins=None):
    """Bin ``df[key]`` into a new ``bin`` column; see :func:`cut`.

    Only reorders the arguments before delegating, so callers can pass the
    column name ahead of the bin count.
    """
    return cut(df, n_bins, key, bins=bins)

def get_user_feature(df_mapped, item_bin, user_feature, users, n_feature=10):
    """Fill ``user_feature`` with per-user histograms over item bins.

    For each id in ``users``, count how many of that user's rows in
    ``df_mapped`` point at items of each bin (``item_bin[item]``) and store the
    length-``n_feature`` count vector at ``user_feature[user]``. Users with no
    rows in ``df_mapped`` get an all-zero vector.

    Parameters
    ----------
    df_mapped : pandas.DataFrame
        Interactions with ``userID`` and ``itemID`` columns.
    item_bin : mapping
        Item id -> integer bin index used to address the histogram slot.
    user_feature : mutable sequence
        Indexed by user id; entries for ``users`` are overwritten in place.
    users : iterable
        User ids to compute.
    n_feature : int
        Histogram length.

    Returns
    -------
    The same ``user_feature`` object, updated in place.
    """
    # Group the interactions once (O(N)) instead of scanning the whole frame
    # for every user, which made the loop O(users * rows).
    items_by_user = {
        uid: group.values for uid, group in df_mapped.groupby('userID')['itemID']
    }
    for i in users:
        curr_feature = np.zeros(n_feature)
        for item in items_by_user.get(i, ()):
            curr_feature[int(item_bin[item])] += 1
        user_feature[i] = curr_feature
    return user_feature

def get_item_feature(df_mapped, user_bin, item_feature, items, n_feature=10):
    """Fill ``item_feature`` with per-item histograms over user bins.

    For each id in ``items``, count how many of that item's rows in
    ``df_mapped`` come from users of each bin (``user_bin[user]``) and store the
    length-``n_feature`` count vector at ``item_feature[item]``. Items with no
    rows in ``df_mapped`` get an all-zero vector.

    Parameters
    ----------
    df_mapped : pandas.DataFrame
        Interactions with ``userID`` and ``itemID`` columns.
    user_bin : mapping
        User id -> integer bin index used to address the histogram slot.
    item_feature : mutable sequence
        Indexed by item id; entries for ``items`` are overwritten in place.
    items : iterable
        Item ids to compute.
    n_feature : int
        Histogram length.

    Returns
    -------
    The same ``item_feature`` object, updated in place.
    """
    # Group the interactions once (O(N)) instead of scanning the whole frame
    # for every item, which made the loop O(items * rows).
    users_by_item = {
        iid: group.values for iid, group in df_mapped.groupby('itemID')['userID']
    }
    for i in items:
        curr_feature = np.zeros(n_feature)
        for user in users_by_item.get(i, ()):
            curr_feature[int(user_bin[user])] += 1
        item_feature[i] = curr_feature
    return item_feature

def get_percentile_dict(activity, act, percentile, n_feature):
    """Map each index label of ``activity`` to the group index of its value.

    ``percentile`` is a sequence of value groups; a value's group index is the
    position of the last group that contains it. ``act`` is read by position in
    lock-step with ``activity.index``. ``n_feature`` is accepted but unused.
    """
    group_of = {value: pos for pos, group in enumerate(percentile) for value in group}
    return {label: group_of[act[pos]] for pos, label in enumerate(activity.index)}

def get_dict(act, n_feature):
    """Replace every value of ``act`` by the index of its rank group.

    The sorted distinct values of ``act`` are split into ``n_feature`` nearly
    equal consecutive chunks with ``np.array_split``; each element becomes the
    index of the chunk holding its value. Returns a list as long as ``act``.
    """
    chunks = np.array_split(np.unique(act), n_feature)
    chunk_of = {value: idx for idx, chunk in enumerate(chunks) for value in chunk}
    return [chunk_of[value] for value in act]

def pop_embed(perc):
    """Encode a percentile in [0, 100] as an 11-slot soft one-hot list.

    Slots correspond to 0, 10, ..., 100. A non-zero multiple of 10 sets a single
    slot to 1; any other value is split linearly between its two neighbouring
    slots. ``perc == 0`` returns all zeros rather than lighting slot 0 (kept as
    in the original; the intent is not visible here).

    Raises
    ------
    ValueError
        If ``perc`` is outside [0, 100]. Such inputs used to return a list of the
        wrong length instead of failing.
    """
    if not 0 <= perc <= 100:
        raise ValueError('perc must be in [0, 100], got {!r}'.format(perc))
    if perc == 0:
        return [0] * 11
    loc = int(perc // 10)
    frac = perc % 10
    if frac == 0:
        return [0] * loc + [1] + [0] * (10 - loc)
    return [0] * loc + [1 - frac / 10] + [frac / 10] + [0] * (9 - loc)


def _activity_bin_index(activity, key, n_feature, bins):
    """Map each index label of ``activity`` to the position of its ``key`` bin."""
    binned = bin_data(activity, key, n_feature, bins)
    print(binned['bin'].cat.categories)
    # Category codes are bin positions in edge order, so empty bins keep their
    # slot regardless of groupby's ``observed`` default. Code -1 marks values
    # outside explicit edges; those rows were dropped before as well.
    codes = binned['bin'].cat.codes
    codes = codes[codes >= 0]
    return dict(zip(codes.index, codes.astype(int)))


def get_feature(df_mapped, user_feature, item_feature, user_bin=None, item_bin=None, n_feature=10):
    """Compute bin histograms for the users and items in ``df_mapped``.

    Users are binned by their count of (non-null) ``itemID`` rows and items by
    their count of ``userID`` rows. Each user then gets a histogram over the bins
    of the items it appears with, and each item a histogram over the bins of its
    users (via ``get_user_feature`` / ``get_item_feature``).

    Parameters
    ----------
    df_mapped : pandas.DataFrame
        Interactions with ``userID`` and ``itemID`` columns.
    user_feature, item_feature : mutable sequences indexed by id
        Entries for ids present in ``df_mapped`` are overwritten in place.
    user_bin, item_bin : sequence of bin edges, optional
        Explicit edges for ``pd.cut``. A copy is used, with its last edge raised
        to ``max(count) + 1`` so the largest count falls inside a bin; the
        caller's sequence is left untouched. When omitted, ``n_feature``
        equal-width bins are used.
    n_feature : int
        Number of bins and histogram length.

    Returns
    -------
    tuple
        ``(user_feature, item_feature)``.
    """
    user_activity = df_mapped.groupby('userID').count()
    item_activity = df_mapped.groupby('itemID').count()
    print(user_activity['itemID'].nunique())
    print(item_activity['userID'].nunique())

    # Work on copies so the caller's edge lists are not mutated, and handle each
    # side independently (a user_bin without an item_bin used to raise TypeError).
    if user_bin is not None:
        user_bin = list(user_bin)
        user_bin[-1] = user_activity['itemID'].max() + 1
    if item_bin is not None:
        item_bin = list(item_bin)
        item_bin[-1] = item_activity['userID'].max() + 1

    user_bin_index = _activity_bin_index(user_activity, 'itemID', n_feature, user_bin)
    item_bin_index = _activity_bin_index(item_activity, 'userID', n_feature, item_bin)

    user_feature = get_user_feature(df_mapped, item_bin_index, user_feature,
                                    df_mapped.userID.unique().tolist(), n_feature)
    item_feature = get_item_feature(df_mapped, user_bin_index, item_feature,
                                    df_mapped.itemID.unique().tolist(), n_feature)
    return user_feature, item_feature

# Histogram width. ``--k`` already selects the output-file suffix below, so the
# feature width follows it too; with the default k=100 this matches the
# previous fixed width of 100.
n_feature = k

# Read each training split once and reuse it below.
w_w_train_df = pd.read_csv(os.path.join(output_path, 'w_w_train.csv'))
c_c_train_df = pd.read_csv(os.path.join(output_path, 'c_c_train.csv'))

# One zero row per id (ids index rows directly); rows for ids absent from both
# splits stay zero.
user_feature = np.zeros((n_users, n_feature))
item_feature = np.zeros((n_items, n_feature))

# w_w split first; the c_c split then overwrites the rows of the ids it contains.
user_feature, item_feature = get_feature(w_w_train_df, user_feature, item_feature, n_feature=n_feature)
user_feature, item_feature = get_feature(c_c_train_df, user_feature, item_feature, n_feature=n_feature)

# Diagnostics: item ids in each split and their overlap.
c_c_items = c_c_train_df.itemID.unique()
w_w_items = w_w_train_df.itemID.unique()
print(c_c_items)
print(w_w_items)
print(np.intersect1d(c_c_items, w_w_items))

print(np.asarray(user_feature).shape, np.asarray(item_feature).shape)
print(user_feature[0])

# k == 10 keeps the unsuffixed filenames.
add = "" if k == 10 else '_{}'.format(k)
with open(os.path.join(output_path, 'user_popularity_new{}.npy'.format(add)), 'wb') as f:
    np.save(f, np.asarray(user_feature))
with open(os.path.join(output_path, 'item_popularity_new{}.npy'.format(add)), 'wb') as f:
    np.save(f, np.asarray(item_feature))


# In[ ]:






# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:




