import os
import pickle
import numpy as np
import pandas as pd
from tqdm import trange
from surprise.prediction_algorithms.matrix_factorization import NMF
from surprise import Dataset, Reader
from utils import factorize, set_prices, param_str
from collections import Counter


def _load_restaurant_reviews(chunksize=100000):
    """Load the Yelp reviews that belong to restaurant businesses.

    Reads ``yelp_data/yelp_academic_dataset_business.json`` to collect the ids
    of businesses whose ``categories`` mention 'Restaurants', then streams
    ``yelp_data/yelp_academic_dataset_review.json`` in chunks and keeps only
    the reviews of those businesses.

    Yelp dataset: https://www.kaggle.com/datasets/yelp-dataset/yelp-dataset
    (some of the data-building code is adapted from
    https://www.kaggle.com/code/fahd09/yelp-dataset-surpriseme-recommendation-system/notebook)

    Args:
        chunksize: number of review lines parsed per chunk.

    Returns:
        DataFrame with the columns ``user_id``, ``business_id`` and ``stars``.
    """
    businesses = pd.read_json('yelp_data/yelp_academic_dataset_business.json', lines=True)
    # na=False: a business without categories is simply not a restaurant.
    is_restaurant = businesses['categories'].str.contains('Restaurants', na=False)
    restaurant_ids = set(businesses.loc[is_restaurant, 'business_id'])

    review_chunks = pd.read_json('yelp_data/yelp_academic_dataset_review.json',
                                 chunksize=chunksize, lines=True)
    # Keep only the columns used downstream (the review text is not needed)
    # and concatenate once instead of re-copying the frame for every chunk.
    kept = [
        chunk.loc[chunk['business_id'].isin(restaurant_ids),
                  ['user_id', 'business_id', 'stars']]
        for chunk in review_chunks
    ]
    return pd.concat(kept)


def _encode_active_ratings(reviews, min_reviews=20):
    """Keep reviews of active users/businesses and encode ids as integers.

    A review is kept only if both its user and its business have strictly
    more than ``min_reviews`` reviews in ``reviews``. Counts are taken once,
    on the unfiltered frame.

    Args:
        reviews: DataFrame with ``user_id``, ``business_id`` and ``stars``.
        min_reviews: exclusive lower bound on the review count.

    Returns:
        DataFrame with the columns ``user_id_i``, ``business_id_i`` and
        ``stars`` (in that order), where the integer ids are assigned in
        order of first appearance.
    """
    user_counts = reviews['user_id'].map(reviews['user_id'].value_counts())
    business_counts = reviews['business_id'].map(reviews['business_id'].value_counts())
    active = reviews[(user_counts > min_reviews) & (business_counts > min_reviews)]

    # Build a fresh frame rather than assigning columns onto a filtered slice.
    return pd.DataFrame({
        'user_id_i': pd.factorize(active['user_id'])[0],
        'business_id_i': pd.factorize(active['business_id'])[0],
        'stars': active['stars'].to_numpy(),
    })


def create_yelp_data(d=12, gamma=0.5, V_epsilon=0.0, p_epsilon=0.0, rho=0.0):
    """Generate and pickle Yelp-based market data, one file per items seed.

    Fits an NMF model with ``d`` factors on restaurant reviews, samples users
    for 240 markets of 20 users each, factorizes the sampled user factors with
    ``d // 2`` components, and for each of 6 items seeds samples 20 items,
    computes prices per market with ``set_prices`` and writes
    ``(X, all_U, all_p, T, splits)`` to
    ``pickles/data/yelp__items_seed<seed>__d<d><params>.pkl``.

    Files that already exist are skipped; if every file exists, the function
    returns before loading the dataset or fitting the model.

    Args:
        d: number of NMF factors; ``d // 2`` components are used by ``factorize``.
        gamma, V_epsilon, p_epsilon, rho: forwarded unchanged to ``set_prices``
            and encoded in the output filename through ``param_str``.
    """
    seed = 3
    np.random.seed(seed)

    n_markets = 240
    n_splits = 6
    n_seeds = 6
    m = 20  # num items
    n = 20  # num users
    d_prime = d // 2

    out_dir = os.path.join('pickles', 'data')
    os.makedirs(out_dir, exist_ok=True)

    # The parameter suffix does not depend on the items seed, so build it once.
    # NOTE(review): assumes param_str has no side effects - confirm in utils.
    params_str = ''.join([param_str('gamma', gamma, 0.5),
                          param_str('V_eps', V_epsilon, 0.0),
                          param_str('p_eps', p_epsilon, 0.0),
                          param_str('rho', rho, 0.0)
                          ])

    # Decide what is missing before doing any expensive work.
    pending = []
    for items_seed in range(n_seeds):
        filename = f'yelp__items_seed{items_seed}__d{d}{params_str}.pkl'
        path = os.path.join(out_dir, filename)
        if os.path.exists(path):
            print(f'skipped {filename} - already exist')
        else:
            pending.append((items_seed, path))
    if not pending:
        return

    # create custom data for yelp, reformatted for a Surprise Dataset
    ratings = _encode_active_ratings(_load_restaurant_reviews())
    reader = Reader(rating_scale=(1, 5))
    trainset = Dataset.load_from_df(ratings, reader).build_full_trainset()

    # NOTE(review): NMF is created without random_state; presumably it draws
    # from NumPy's global RNG seeded above - verify against the Surprise docs.
    algo = NMF(n_factors=d)
    algo.fit(trainset)

    # Sample (with replacement) the users of all markets at once.
    users_indices = np.random.choice(range(trainset.n_users), n * n_markets)
    all_B = algo.pu[users_indices]
    # derive T from all markets
    all_U, T = factorize(all_B, n_components=d_prime)
    all_U = all_U.reshape((n_markets, n, d_prime))

    for items_seed, path in pending:
        # Reseed per file so each output is independent of which files were skipped.
        np.random.seed(items_seed)
        items_indices = np.random.choice(range(trainset.n_items), m)
        # Item factors scaled by the upper end of the rating scale.
        X = algo.qi[items_indices] / trainset.rating_scale[1]
        all_p = []

        for i in trange(n_markets):
            U = all_U[i]
            B = U @ T
            # calculate V
            V = B @ X.T
            # compute prices according to full information (V)
            p = set_prices(V, gamma, V_epsilon, p_epsilon, rho)
            all_p.append(p)
        all_p = np.array(all_p)

        # create splits: shuffle the market indices and cut them into
        # n_splits equal parts (n_markets must be divisible by n_splits)
        indices = np.arange(n_markets).astype(int)
        np.random.shuffle(indices)
        splits = np.split(indices, n_splits)

        with open(path, "wb") as f:
            pickle.dump((X, all_U, all_p, T, splits), f)
            
            
if __name__ == '__main__':
    os.makedirs(os.path.join('pickles', 'data'), exist_ok=True)

    # Disabled parameter sweeps (all at the default d = 12), kept for reference:
    #
    #   vary gamma
    #     for gamma in [0.0, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 1.0]:
    #         create_yelp_data(gamma=gamma)
    #
    #   vary epsilon
    #     for epsilon in [0.0, 1e-2, 3e-2, 1e-1, 3e-1]:
    #         create_yelp_data(V_epsilon=epsilon)  # V_eps
    #         create_yelp_data(p_epsilon=epsilon)  # p_eps
    #
    #   vary rho
    #     for rho in [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]:
    #         create_yelp_data(rho=rho)

    # Generate the default-parameter datasets for both factor dimensions.
    for d in (12, 100):
        create_yelp_data(d=d)
