import argparse
import time
import os
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split


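# make_dataset: reads a JSON-lines review file and writes <category>_train.csv, <category>_valid.csv and <category>_test.csv into the save directory (e.g. './data/dataset/<category>')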

def _load_punctuations(path='./data/preprocessing/punctuations.txt'):
    """Return the set of non-empty punctuation strings listed one per line in *path*."""
    with open(path) as f:
        # Empty lines are skipped: replacing '' with '' is a no-op anyway.
        return {line for line in f.read().splitlines() if line}


def _clean_text(review, punctuations):
    """Lower-case *review* and delete every occurrence of each string in *punctuations*."""
    review = review.lower()
    for p in punctuations:
        review = review.replace(p, '')
    return review


def _get_concat_reviews(df, sub, obj, for_train=True, word_limit=10, review_length=400):
    """Append a '<sub[:4]>_reviews_concat' column of concatenated review words.

    For every row, the reviews of all rows in *df* sharing the row's ``sub``
    value are split on whitespace, concatenated, truncated to
    ``review_length`` words and joined with single spaces.

    Args:
        df: frame with a 'review' column plus the ``sub`` and ``obj`` columns.
        sub: column whose value selects the group of reviews to concatenate.
        obj: the paired column; when ``for_train`` is False, reviews of the
            group whose ``obj`` value equals the current row's are left out.
        for_train: if True, use every review of the group (including the
            row's own review).
        word_limit: rows whose concatenated text (before truncation) has
            fewer words than this are dropped from the result.
        review_length: maximum number of words kept per concatenation.

    Returns:
        A new frame with a fresh 0..n-1 index, the extra concat column, and
        the sparse rows removed.
    """
    df = df.reset_index(drop=True)
    df['review_split'] = df['review'].apply(lambda x: x.split())

    # e.g. {userID: DataFrame(itemID | review_split)}
    reviews_all_dict = dict(list(df[[obj, 'review_split']].groupby(by=df[sub])))
    concat_reviews = []
    sparse_idx = []
    # In train mode the result depends only on sub_id, so compute it once per group.
    train_cache = {}

    print(f'get_concat_reviews processing : {sub} > {obj}')

    for idx, (sub_id, obj_id) in enumerate(zip(tqdm(df[sub]), df[obj])):
        if for_train and sub_id in train_cache:
            is_sparse, text = train_cache[sub_id]
        else:
            group = reviews_all_dict[sub_id]
            if for_train:
                reviews = group['review_split']
            else:
                # valid / test: leave out reviews written for the target itself
                reviews = group['review_split'][group[obj] != obj_id]

            # Linear-time flatten (sum(lists, []) is quadratic).
            words = [word for review in reviews for word in review]
            is_sparse = len(words) < word_limit
            text = " ".join(words[:review_length])
            if for_train:
                train_cache[sub_id] = (is_sparse, text)

        if is_sparse:
            sparse_idx.append(idx)
        concat_reviews.append(text)

    # Direct assignment also works for an empty frame, unlike building a
    # one-column DataFrame from an empty list and renaming its columns.
    df[str(sub[:4]) + '_reviews_concat'] = concat_reviews

    # idx values are positions, which equal labels after the reset_index above.
    df = df.drop(index=sparse_idx).reset_index(drop=True)
    return df.drop(columns='review_split')


def make_dataset(json_path, train_rate, save_dir, data_category):
    """Build train/valid/test CSV files from a JSON-lines review file.

    Reads ``json_path``, keeps the reviewerID/asin/reviewText/overall columns
    (renamed to userID/itemID/review/rating), replaces user and item IDs with
    integer group indices, drops rows whose review is not a non-empty string,
    strips the punctuation strings listed in
    './data/preprocessing/punctuations.txt', splits the rows into
    train/valid/test, adds concatenated user and item review columns, and
    writes '<data_category>_{train,valid,test}.csv' into ``save_dir``.

    Args:
        json_path: path of the JSON-lines input file.
        train_rate: fraction of rows used for training (number or numeric
            string); the remainder is split evenly into valid and test.
        save_dir: output directory, created if missing.
        data_category: prefix used for the output file names.

    Raises:
        ValueError: if ``train_rate`` is not strictly between 0 and 1.
    """
    # argparse may hand over a string; coerce before doing arithmetic on it.
    train_rate = float(train_rate)
    if not 0.0 < train_rate < 1.0:
        raise ValueError(f'train_rate must be in (0, 1), got {train_rate}')

    print(f'# Reading {data_category}_json file')

    # json read
    df = pd.read_json(json_path, lines=True)

    # Keep only the needed columns; copy so later assignments do not hit a view.
    df = df[['reviewerID', 'asin', 'reviewText', 'overall']].copy()
    df.columns = ['userID', 'itemID', 'review', 'rating']

    # ID > index
    df['userID'] = df.groupby(df['userID']).ngroup()
    df['itemID'] = df.groupby(df['itemID']).ngroup()

    # (exception) drop rows whose review is missing, not a string, or empty
    is_valid_review = df['review'].map(
        lambda x: isinstance(x, str) and len(x) > 0
    ).astype(bool)
    df = df[is_valid_review].copy()

    # preprocessing: load the punctuation list once, not once per review
    punctuations = _load_punctuations()
    df['review'] = df['review'].apply(lambda x: _clean_text(x, punctuations))

    # Split into train : valid : test = train_rate : rest/2 : rest/2
    train, valid = train_test_split(df, test_size=1 - train_rate, random_state=2024)
    valid, test = train_test_split(valid, test_size=0.5, random_state=2024)

    # dir > data save
    os.makedirs(save_dir, exist_ok=True)

    train = _get_concat_reviews(train, 'userID', 'itemID')
    train = _get_concat_reviews(train, 'itemID', 'userID')
    valid = _get_concat_reviews(valid, 'userID', 'itemID', for_train=False)
    valid = _get_concat_reviews(valid, 'itemID', 'userID', for_train=False)
    test = _get_concat_reviews(test, 'userID', 'itemID', for_train=False)
    test = _get_concat_reviews(test, 'itemID', 'userID', for_train=False)

    train.to_csv(os.path.join(save_dir, data_category + '_train.csv'), index=False)
    valid.to_csv(os.path.join(save_dir, data_category + '_valid.csv'), index=False)
    test.to_csv(os.path.join(save_dir, data_category + '_test.csv'), index=False)

    print(os.path.join(save_dir, data_category + '_train.csv'))

    print(f'## (Dataset) train: {len(train)}, valid: {len(valid)}, test: {len(test)}')
    print(f'## (Total) {len(df)} reviews, {len(df.groupby("userID"))} users, {len(df.groupby("itemID"))} items.')


if __name__ == '__main__':
    # Command-line entry point: parse the arguments, build the dataset and
    # report the elapsed wall-clock time.
    parser = argparse.ArgumentParser()
    # The path/category arguments are required: without them make_dataset
    # would receive None and fail later with an unrelated error.
    parser.add_argument('--json_path', dest='json_path', required=True)
    # type=float is needed because make_dataset computes `1 - train_rate`;
    # without it a user-supplied value arrives as a string.
    parser.add_argument('--train_rate', dest='train_rate', type=float, default=0.8)
    parser.add_argument('--save_dir', dest='save_dir', required=True)
    parser.add_argument('--data_category', dest='data_category', required=True)

    args = parser.parse_args()

    start_time = time.perf_counter()
    make_dataset(args.json_path, args.train_rate, args.save_dir, args.data_category)
    end_time = time.perf_counter()
    print(f'## complete dataset.py: Time used {end_time - start_time:.0f} seconds.')

# python dataset.py --json_path ./raw/Office_Products_5.json --save_dir ./temp/OP --data_category OP