#!/usr/bin/env python
# coding: utf-8
# In[43]:
import pandas as pd
import numpy as np
import os
import argparse
from scipy.stats import percentileofscore, rankdata
# In[51]:


parser = argparse.ArgumentParser(description='PyTorch Neural Collaborative Filtering')
parser.add_argument('--dataset', type=str, default='douban')
parser.add_argument('--cold_percent', type=int, default=50,
                    help='together with --train_percent, selects the '
                         '<dataset>_<cold_percent>_<train_percent> data folder '
                         '(50 / 0.6 uses the bare <dataset> folder)')
parser.add_argument('--train_percent', type=float, default=0.6,
                    help='see --cold_percent')
parser.add_argument('--k', type=int, default=100,
                    help='number of activity bins per entity')
parser.add_argument('--data_root', type=str, default='/home/data',
                    help='root directory containing the per-dataset folders')
args = parser.parse_args()
k = args.k

# The default split (50 / 0.6) lives in the bare dataset folder; every other
# split lives in a folder suffixed with both percentages.
dataset = args.dataset
if not (args.cold_percent == 50 and args.train_percent == 0.6):
    dataset = '{}_{}_{}'.format(dataset, args.cold_percent, args.train_percent)
dataset_path = '{}/{}'.format(args.data_root, dataset)

# c_c.csv is read from the un-suffixed base folder of the dataset family.
if 'yelp_AZ' in dataset:
    base_name = 'yelp_AZ'
elif 'movie' in dataset:
    base_name = 'movie_lens'
else:
    base_name = dataset.split('_')[0]
alter_data_path = '{}/{}'.format(args.data_root, base_name)

# The output arrays get one row per id in [0, max id] of the full c_c.csv.
c_c_df = pd.read_csv(os.path.join(alter_data_path, 'c_c.csv'))
n_users = c_c_df.userID.max() + 1
n_items = c_c_df.itemID.max() + 1
print(n_users, n_items)

train_df = pd.read_csv(os.path.join(dataset_path, 'w_w_train.csv'))
test_df = pd.read_csv(os.path.join(dataset_path, 'c_c_train.csv'))
# Statistics are written back into the split's own folder.
output_path = dataset_path

# Number of activity bins used when binning users and items.
n_feature = k


def cut(df, n_bins, key, bins=None):
    """Return a copy of ``df`` with a categorical ``bin`` column for ``df[key]``.

    Args:
        df: source DataFrame; it is not modified.
        n_bins: number of equal-width bins, used only when ``bins`` is None
            (duplicate edges are dropped in that case).
        key: name of the column to discretize.
        bins: optional explicit bin edges forwarded to ``pd.cut``.

    Returns:
        A new DataFrame: all columns of ``df`` plus ``bin``. Values outside
        explicit ``bins`` get a missing (NaN) bin.
    """
    column = df[key]
    if bins is not None:
        assignment = pd.cut(column, bins)
    else:
        assignment = pd.cut(column, n_bins, duplicates='drop')
    result = df.copy()
    result['bin'] = assignment
    return result


def bin_data(df, key, n_bins, bins=None):
    """Bin ``df[key]`` into a new ``bin`` column; thin wrapper over ``cut``.

    Takes ``(key, n_bins)`` in the opposite order to ``cut`` and returns its
    result unchanged (a copy of ``df`` with the added column).
    """
    return cut(df, n_bins, key, bins=bins)


def get_user_feature(df_mapped, item_bin, user_feature, users, n_feature=10):
    """Fill ``user_feature`` with a per-user histogram over item bins.

    Args:
        df_mapped: interactions with ``userID`` and ``itemID`` columns.
        item_bin: mapping item id -> bin index (anything indexable by item id).
        user_feature: container updated in place, keyed by user id.
        users: user ids to compute.
        n_feature: histogram length (number of bins).

    Returns:
        ``user_feature`` itself. Each requested user maps to a float array that
        counts that user's interactions per item bin (all zeros if none).
    """
    user_col = df_mapped['userID']
    for user in users:
        histogram = np.zeros(n_feature)
        for item in df_mapped.loc[user_col == user, 'itemID']:
            histogram[int(item_bin[item])] += 1
        user_feature[user] = histogram
    return user_feature


def get_item_feature(df_mapped, user_bin, item_feature, items, n_feature=10):
    """Fill ``item_feature`` with a per-item histogram over user bins.

    Args:
        df_mapped: interactions with ``userID`` and ``itemID`` columns.
        user_bin: mapping user id -> bin index (anything indexable by user id).
        item_feature: container updated in place, keyed by item id.
        items: item ids to compute.
        n_feature: histogram length (number of bins).

    Returns:
        ``item_feature`` itself. Each requested item maps to a float array that
        counts that item's interactions per user bin (all zeros if none).
    """
    for item in items:
        interacting = df_mapped[df_mapped['itemID'] == item].userID.values
        bin_ids = np.array([int(user_bin[u]) for u in interacting], dtype=int)
        histogram = np.zeros(n_feature)
        # ufunc.at accumulates repeated indices (plain fancy-index += would not).
        np.add.at(histogram, bin_ids, 1)
        item_feature[item] = histogram
    return item_feature

def normalize_act(pert, act):
    """Return, per group, the total activity divided by that group's activity.

    Args:
        pert: mapping entity -> group (e.g. a bin index).
        act: mapping entity -> activity value; must contain every key of ``pert``.

    Returns:
        dict group -> (sum of ``act`` over all entities in ``pert``) /
        (sum of ``act`` over the entities in that group), in first-seen
        group order. Prints the number of groups as a side effect.
    """
    group_totals = {}
    for member, group in pert.items():
        if group in group_totals:
            group_totals[group] += act[member]
        else:
            group_totals[group] = act[member]
    grand_total = sum(act[member] for member in pert)
    print(len(group_totals))
    return {group: grand_total / subtotal for group, subtotal in group_totals.items()}


def _bin_index_map(activity, count_col, id_col, n_feature, bins):
    """Map each id in ``activity``'s index to the integer index of its activity bin."""
    binned = bin_data(activity, count_col, n_feature, bins)
    binned[id_col] = binned.index
    # observed=False keeps every bin category, including empty ones, so the
    # index is the bin's position among all bins. It then stays comparable
    # across datasets regardless of pandas' changing groupby default.
    idx = binned.groupby('bin', observed=False).size().reset_index()['bin'].reset_index()
    print(idx)
    # Rows whose bin is missing (value outside explicit edges) are not in idx
    # and are dropped by the inner merge.
    binned = binned.merge(idx, on='bin', how='inner')
    binned = binned.rename(columns={'index': 'bin_index'})
    return dict(zip(binned[id_col], binned.bin_index))


def get_feature(df_mapped, user_bin=None, item_bin=None, n_feature=10):
    """Bin users and items by their number of interactions in ``df_mapped``.

    Args:
        df_mapped: interaction DataFrame with ``userID`` and ``itemID`` columns.
        user_bin: optional sequence of bin edges for user activity. Its last
            edge is replaced by (max user activity + 1) on a private copy; the
            caller's sequence is not modified. None -> ``n_feature``
            equal-width bins.
        item_bin: same as ``user_bin``, for item activity.
        n_feature: number of equal-width bins used when edges are not given.

    Returns:
        ``(item_bin, user_bin, item_act, user_act)``:
        ``item_bin`` / ``user_bin`` map id -> bin index. Ids whose count falls
        outside explicit edges are omitted. ``item_act`` / ``user_act`` map
        id -> number of rows (non-null ``userID`` / ``itemID``) for that id.
    """
    user_activity = df_mapped.groupby('userID').count()
    item_activity = df_mapped.groupby('itemID').count()
    user_act = dict(zip(user_activity.index, user_activity.itemID))
    print(user_activity.index)
    item_act = dict(zip(item_activity.index, item_activity.userID))

    # Each edge list is adjusted independently and on a copy, so the caller's
    # list is never mutated and passing only one of them works.
    if user_bin is not None:
        user_bin = list(user_bin)
        user_bin[-1] = user_activity.itemID.max() + 1
    if item_bin is not None:
        item_bin = list(item_bin)
        item_bin[-1] = item_activity.userID.max() + 1

    user_bin = _bin_index_map(user_activity, 'itemID', 'userID', n_feature, user_bin)
    item_bin = _bin_index_map(item_activity, 'userID', 'itemID', n_feature, item_bin)

    return item_bin, user_bin, item_act, user_act


def _stat_rows(n_rows, bin_a, act_a, bin_b, act_b):
    """Return one ``[bin_index, activity]`` row per id in ``range(n_rows)``.

    Ids binned in the first split take precedence over the second. Ids binned
    in neither get the placeholder ``[1, 1]``, as a fresh list per row so that
    rows never alias one another.
    """
    rows = []
    for i in range(n_rows):
        if i in bin_a:
            rows.append([bin_a[i], act_a[i]])
        elif i in bin_b:
            rows.append([bin_b[i], act_b[i]])
        else:
            rows.append([1, 1])
    return rows


# train_df / test_df were loaded from w_w_train.csv / c_c_train.csv in the same
# folder as output_path, so reuse them rather than re-reading the CSVs.
item_bin1, user_bin1, item_act1, user_act1 = get_feature(train_df, n_feature=n_feature)
item_bin2, user_bin2, item_act2, user_act2 = get_feature(test_df, n_feature=n_feature)
# Alternative (disabled): replace raw activity with normalize_act(<bin map>, <act map>).

# Debug checks: True when every entity with an activity count got a bin.
print(user_bin1.keys() == user_act1.keys())
print(user_bin2.keys() == user_act2.keys())
user_feature = _stat_rows(n_users, user_bin1, user_act1, user_bin2, user_act2)

print(n_items, len(item_bin1) + len(item_bin2))
# default=None avoids a ValueError when the first split has no items.
print(max(item_bin1, default=None), len(item_bin1))
item_feature = _stat_rows(n_items, item_bin1, item_act1, item_bin2, item_act2)

print(user_feature[:100])

# k == 10 writes unsuffixed file names; any other k gets a "_<k>" suffix.
add = "" if k == 10 else '_{}'.format(k)
with open(os.path.join(output_path, 'user_stat{}.npy'.format(add)), 'wb') as f:
    np.save(f, np.array(user_feature))
with open(os.path.join(output_path, 'item_stat{}.npy'.format(add)), 'wb') as f:
    np.save(f, np.array(item_feature))


