import json
import os

import numpy as np
import scipy.sparse as spp
# from ExtractFeatures import extract_features, get_review_matrix
# from GenerateInfo import generate_info
from sklearn import linear_model
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize


# Module-level index maps. They are filled lazily, in first-seen order, by the
# feature-building functions below and persist for the life of the process.
geoMap = {}   # city name -> position in the city one-hot vector
catMap = {}   # category name -> position in the category multi-hot vector
attrMap = {}  # attribute name -> {attribute value -> integer code}
AttrMap = {}  # attribute name -> position in the attribute vector
busMap = {}   # business_id -> running index assigned while reading business.json


def extract_rows(top_k, sparse_matrix):
    """Return the ``top_k`` rows of ``sparse_matrix`` with the most stored entries.

    Rows are ranked by ``sparse_matrix.getnnz(axis=1)`` and stacked in
    descending order of that count. In this file the rows are users, so the
    count is the number of reviews recorded for each user.

    Args:
        top_k: number of rows to keep. Values larger than the number of rows
            select every row.
        sparse_matrix: a scipy sparse matrix.

    Returns:
        A sparse matrix made of the selected rows, densest row first.

    Raises:
        ValueError: propagated from ``scipy.sparse.vstack`` when no row is
            selected (``top_k <= 0`` or an empty matrix).
    """
    stored_per_row = sparse_matrix.getnnz(axis=1)
    # Reverse the ascending argsort first and truncate afterwards. The former
    # slice ``[n - 1 : n - 1 - top_k : -1]`` produced a negative stop index
    # whenever top_k >= n, which Python interprets as "count from the end" and
    # silently yielded an empty or truncated selection.
    top_k_index = np.argsort(stored_per_row)[::-1][:max(top_k, 0)]
    matrix = spp.vstack([sparse_matrix.getrow(i) for i in top_k_index])
    return matrix


def extract_cols(top_k, sparse_matrix):
    """Return the ``top_k`` columns of ``sparse_matrix`` with the most stored entries.

    Columns are ranked by ``sparse_matrix.getnnz(axis=0)`` and concatenated in
    descending order of that count. In this file the columns are businesses,
    so the count is the number of reviews recorded for each business.

    Args:
        top_k: number of columns to keep. Values larger than the number of
            columns select every column.
        sparse_matrix: a scipy sparse matrix.

    Returns:
        A ``(matrix, top_k_index)`` tuple: the sparse matrix built from the
        selected columns, and the array of their indices in ``sparse_matrix``
        (same order as the columns of ``matrix``).

    Raises:
        ValueError: propagated from ``scipy.sparse.hstack`` when no column is
            selected (``top_k <= 0`` or an empty matrix).
    """
    stored_per_col = sparse_matrix.getnnz(axis=0)
    # Reverse the ascending argsort first and truncate afterwards. The former
    # slice ``[n - 1 : n - 1 - top_k : -1]`` produced a negative stop index
    # whenever top_k >= n, which Python interprets as "count from the end" and
    # silently yielded an empty or truncated selection.
    top_k_index = np.argsort(stored_per_col)[::-1][:max(top_k, 0)]
    matrix = spp.hstack([sparse_matrix.getcol(i) for i in top_k_index])
    return matrix, top_k_index


def get_review_matrix(user_num, item_num, sparse_matrix):
    """Cut ``sparse_matrix`` down to a dense ``user_num`` x ``item_num`` block.

    The reduction happens in three passes:
      1. keep the ``3 * user_num`` rows with the most stored entries,
      2. keep the ``item_num`` columns with the most stored entries among
         those rows,
      3. re-rank the remaining rows on the reduced columns and keep the top
         ``user_num``.

    Returns:
        ``(dense_matrix, top_items)`` where ``dense_matrix`` is a numpy array
        and ``top_items`` holds the selected column indices of
        ``sparse_matrix`` in the order the columns appear in ``dense_matrix``.
    """
    print('in get_review_matrix')
    candidate_rows = extract_rows(user_num * 3, sparse_matrix)
    narrowed, top_items = extract_cols(item_num, candidate_rows)
    dense = extract_rows(user_num, narrowed).toarray()
    return dense, top_items


def avgRating_revCount(source_prefix='../yelp'):
    """Compute per-business mean star rating and review count.

    Reads ``<source_prefix>/review.json`` as one JSON object per line and uses
    the ``business_id`` and ``stars`` fields of each record.

    Returns:
        ``(avgRating, revCount)``: two dicts keyed by business id, holding the
        mean of ``stars`` (float) and the number of review records (int).
    """
    star_totals, revCount = {}, {}
    with open(source_prefix + '/review.json', encoding='utf-8') as reviews:
        for raw_line in reviews:
            record = json.loads(raw_line)
            bid = record['business_id']
            stars = float(record['stars'])
            if bid in star_totals:
                star_totals[bid] += stars
                revCount[bid] += 1
            else:
                star_totals[bid] = stars
                revCount[bid] = 1
    # Turn the running sums into means, keeping first-seen key order.
    avgRating = {bid: total / revCount[bid] for bid, total in star_totals.items()}
    return avgRating, revCount


def attr2vec(attrDict):
    """Encode a business ``attributes`` mapping as a length-34 numeric vector.

    Each attribute name is given a vector position on first sight (recorded in
    the module-level ``AttrMap``) and each distinct value of that attribute is
    given an integer code on first sight (recorded in ``attrMap``). The vector
    holds the value code at the attribute's position; attributes that do not
    appear stay 0. Attributes whose position falls outside the vector are
    printed and left out.

    NOTE(review): the first value seen for an attribute gets code 0, which is
    the same as "attribute absent" in the output vector.
    NOTE(review): values are used as dict keys, so they are assumed to be
    hashable (e.g. strings) -- confirm against the dataset version in use.
    """
    attrLen = 34
    attrFea = np.zeros(attrLen)
    for name in attrDict:
        value = attrDict[name]
        # Value code: next free integer within this attribute's own table.
        codes = attrMap.setdefault(name, {})
        if value not in codes:
            codes[value] = len(codes)
        # Vector position: next free integer across all attribute names.
        column = AttrMap.setdefault(name, len(AttrMap))
        if column >= attrLen:
            print(name, value)
            continue
        attrFea[column] = codes[value]
    return attrFea


def rawData2item_fv(rawD, avgRat, revCnt):
    """Build the feature vector of one business record.

    The vector is the concatenation of four parts, in this order:
      * 330 slots: one-hot of ``rawD['city']``,
      * 1004 slots: multi-hot of the comma-separated ``rawD['categories']``,
      * 2 slots: ``avgRat`` and ``revCnt`` as given,
      * 34 slots: ``attr2vec(rawD['attributes'])``, or zeros when it is None.

    City and category positions are handed out in first-seen order through the
    module-level ``geoMap`` / ``catMap``. A city or category whose position
    does not fit in its segment is printed and left out of the vector.

    Returns:
        ``(fullFea, catRet)``: the 1370-element numpy vector and the list of
        category positions that were set for this business.
    """
    geoFea_len, catFea_len, attrLen = 330, 1004, 34

    # --- city ---
    geoFea = np.zeros(geoFea_len)
    city = rawD['city']
    city_idx = geoMap.setdefault(city, len(geoMap))
    if city_idx < geoFea_len:
        geoFea[city_idx] = 1
    else:
        print(city, city_idx)

    # --- categories ---
    catFea = np.zeros(catFea_len)
    catRet = []
    raw_categories = rawD['categories']
    if raw_categories is not None:
        for piece in raw_categories.split(','):
            cat = piece.strip()
            cat_idx = catMap.setdefault(cat, len(catMap))
            if cat_idx >= catFea_len:
                print(cat, cat_idx)
                continue
            catFea[cat_idx] = 1
            catRet.append(cat_idx)

    # --- average rating and total review count ---
    valFea = np.array([avgRat, revCnt])

    # --- attributes ---
    attributes = rawD['attributes']
    attrFea = np.zeros(attrLen) if attributes is None else attr2vec(attributes)

    return np.concatenate((geoFea, catFea, valFea, attrFea)), catRet


def genItemFea(top_business, source_prefix='../yelp'):
    """Build feature vectors and category lists for the selected businesses.

    Reads ``<source_prefix>/business.json`` (one JSON object per line) and
    review statistics from the same ``source_prefix``. Every business in the
    file is run through ``rawData2item_fv`` in file order, so the module-level
    index maps are filled in that order; only the results for businesses in
    ``top_business`` are kept.

    Args:
        top_business: iterable of business ids, in the desired output order.
        source_prefix: directory containing ``business.json`` and
            ``review.json``.

    Returns:
        ``(itemFea, itemCatList)``: a 2-D numpy array with one feature row per
        entry of ``top_business``, and a list with the matching category
        positions of each business.

    Raises:
        KeyError: if an id in ``top_business`` does not occur in
            ``business.json``.
    """
    print('in genItemFea')
    top_business = list(top_business)
    wanted = set(top_business)
    # Pass the prefix through: previously the review statistics were always
    # read from the default directory, whatever ``source_prefix`` said.
    avgRating, revCount = avgRating_revCount(source_prefix)
    itemFea, itemCatList = {}, {}
    with open(source_prefix + '/business.json', encoding='utf-8') as f:
        for line in f:
            dicts = json.loads(line)
            business_id = dicts['business_id']
            if business_id not in busMap:
                busMap[business_id] = len(busMap)
            busInd = busMap[business_id]
            # A business without any review record gets 0.0 / 0 instead of
            # aborting the whole run with a KeyError.
            fea, cats = rawData2item_fv(
                dicts,
                avgRating.get(business_id, 0.0),
                revCount.get(business_id, 0),
            )
            # Only retain what is returned; holding a dense vector for every
            # business in the file is needless memory.
            if business_id in wanted:
                itemFea[busInd] = fea
                itemCatList[busInd] = cats
    itemFea = np.array([itemFea[busMap[i]] for i in top_business])
    itemCatList = [itemCatList[busMap[i]] for i in top_business]
    return itemFea, itemCatList


def genContextualVec(itemFea, userItemRating, user_n=2000, dim=50):
    """Derive ``dim``-dimensional item vectors and per-user weight vectors.

    Item features are projected with a whitened PCA to ``dim`` components.
    For each of the first ``user_n`` rows of ``userItemRating`` a Ridge
    regression (alpha=0.5) is fitted from the projected item vectors to that
    row, and its coefficient vector is kept. Both result matrices are then
    L2-normalised row by row.

    Args:
        itemFea: item feature matrix, one row per item.
        userItemRating: indexable by user; entry ``i`` is the target vector
            for user ``i`` (one value per item).
        user_n: number of users to fit.
        dim: number of PCA components.

    Returns:
        ``(X, theta)``: normalised item vectors and normalised user
        coefficient vectors.
    """
    print('in genContextualVec')
    projector = PCA(n_components=dim, whiten=True)
    projected = projector.fit_transform(itemFea)

    # One regressor object is reused; fit() returns the estimator itself, so
    # the freshly fitted coefficients can be read straight off the call.
    ridge = linear_model.Ridge(alpha=0.5)
    coefficients = np.array(
        [ridge.fit(projected, userItemRating[u]).coef_ for u in range(user_n)]
    )

    X = normalize(projected, axis=1, norm='l2')
    theta = normalize(coefficients, axis=1, norm='l2')
    return X, theta


def storeInfo(U, V, user_num, attr_list, target_prefix):
    """Write user, arm and arm/super-arm files under ``target_prefix``.

    Three files are produced:
      * ``user_preference.txt``: one JSON object per user with keys ``uid``
        and ``preference_v`` (row ``i`` of ``U`` as a column vector).
      * ``arm_info.txt``: one JSON object per arm with keys ``a_id`` and
        ``fv`` (row ``i`` of ``V`` as a column vector).
      * ``arm_suparm_relation.txt``: one line per arm of the form
        ``<arm id>\\t<id>,<id>,...,`` where the ids are dense indices assigned
        to the entries of ``attr_list`` in first-seen order.

    Args:
        U: array indexable by user, rows are user vectors.
        V: 2-D array, rows are arm vectors (one per entry of ``attr_list``).
        user_num: number of users to write.
        attr_list: for each arm, an iterable of hashable attribute keys.
        target_prefix: output directory; created if it does not exist.
    """
    print('in storeInfo')
    # Opening the output files used to fail when the directory was missing.
    if target_prefix:
        os.makedirs(target_prefix, exist_ok=True)

    # user info: user_preference.txt
    with open(target_prefix + '/user_preference.txt', 'w') as f:
        for i in range(user_num):
            user = {'uid': i, 'preference_v': U[i].reshape((-1, 1)).tolist()}
            f.write(json.dumps(user) + '\n')

    # Re-index the attribute keys densely, in order of first appearance.
    item_attr = {}
    for attr in attr_list:
        for item in attr:
            if item not in item_attr:
                item_attr[item] = len(item_attr)

    # item info: arm_info.txt
    # item-attribute info: arm_suparm_relation.txt
    separate = ','
    with open(target_prefix + '/arm_suparm_relation.txt', 'w') as f, \
            open(target_prefix + '/arm_info.txt', 'w') as g:
        for i, attr in enumerate(attr_list):
            arm = {'a_id': i, 'fv': V[i, :].reshape((-1, 1)).tolist()}
            # The trailing comma after the id list is part of the file format.
            f.write(str(i) + '\t' + separate.join([str(item_attr[item]) for item in attr]) + ',\n')
            g.write(json.dumps(arm) + '\n')


def load_sparse_matrix(source_prefix='../yelp'):
    """Load the binarised user-by-business review matrix from the JSON dumps.

    Reads ``<source_prefix>/business.json`` for categories and
    ``<source_prefix>/review.json`` for reviews (one JSON object per line in
    both). Users and businesses are numbered in the order they first appear
    in the review file. A review contributes 1 when ``stars`` is greater than
    3 and 0 otherwise.

    NOTE(review): several reviews by the same user for the same business are
    presumably summed by the CSR constructor rather than overwritten --
    confirm this is intended.

    Returns:
        ``(matrix, category_list, busIdList)``:
          * ``matrix``: ``scipy.sparse.csr_matrix`` with users as rows and
            businesses as columns,
          * ``category_list``: per column, the list of stripped category
            names, or None when the business has no categories or is not
            present in ``business.json``,
          * ``busIdList``: per column, the business id.
    """
    print('in load_sparse_matrix')
    # key: business_id, item: list of categories or None
    categories = {}
    with open(source_prefix + '/business.json', encoding='utf-8') as bs:
        for line in bs:
            dicts = json.loads(line)
            if dicts['categories'] is None:
                categories[dicts['business_id']] = None
            else:
                categories[dicts['business_id']] = [
                    item.strip() for item in dicts['categories'].split(',')
                ]

    # get tuples for sparse matrix
    user_dict, business_dict = {}, {}
    category_list = []
    data, rows, cols = [], [], []
    busIdList = []
    with open(source_prefix + '/review.json', encoding='utf-8') as rf:
        for line in rf:
            dicts = json.loads(line)
            user_id = dicts['user_id']
            business_id = dicts['business_id']
            rating = float(dicts['stars'])
            rating = 1 if rating > 3 else 0

            # key: user id, item: user index
            if user_id not in user_dict:
                user_dict[user_id] = len(user_dict)
            row_index = user_dict[user_id]

            # key: business id, item: business index
            if business_id not in business_dict:
                business_dict[business_id] = len(business_dict)
                # a list of categories in the order of business index; a
                # business missing from business.json is recorded as None
                # instead of raising KeyError
                category_list.append(categories.get(business_id))
                busIdList.append(business_id)
            col_index = business_dict[business_id]

            # add tuple to list
            data.append(rating)
            rows.append(row_index)
            cols.append(col_index)

    # generate sparse matrix
    data = np.asarray(data)
    rows = np.asarray(rows)
    cols = np.asarray(cols)
    print(len(data), len(rows), len(cols))
    print(len(user_dict), len(business_dict), len(category_list))
    return spp.csr_matrix((data, (rows, cols))), category_list, busIdList


if __name__ == '__main__':
    # Output directory and problem size: number of users, number of
    # businesses (arms) and dimensionality of the contextual vectors.
    target_prefix = 'conUCB_yelp_data'
    user_num, business_num, d = 2000, 25000, 50

    # Full review matrix, then its dense user_num x business_num reduction.
    M, category_list, busIdList = load_sparse_matrix()
    M, top_business_ind = get_review_matrix(user_num, business_num, M)

    # NOTE: the former extract_features / generate_info path (and saving M to
    # user_item.npy) is disabled; features now come from business metadata.
    top_business_ids = [busIdList[i] for i in top_business_ind]
    itemFea, itemCatList = genItemFea(top_business_ids)
    X, theta = genContextualVec(itemFea, M, user_n=user_num, dim=d)
    storeInfo(theta, X, user_num, itemCatList, target_prefix)
