import math
import pickle
import os
import json
import gzip
from unicodedata import category
import pandas as pd
import swifter

# import requests
import urllib
import numpy as np
import time

from tqdm import tqdm
from PIL import Image
from io import BytesIO
import itertools

from multiprocessing import Pool
from collections import defaultdict

import argparse
# import spacy
'''
from torchtext.data.utils import(
    get_tokenizer,
    ngrams_iterator,
)
'''


# class EnTokenizer:
#     def __init__(self):
#         self.spacy_en = spacy.load('en_core_web_sm')

#     def tokenize_en(self, text):
#         """
#         Tokenizes English text from a string into a list of strings
#         """
#         return [tok.text for tok in self.spacy_en.tokenizer(text)]

# Module-level debug switch: build_dataset stops after 100 ASINs when this is True.
DEBUG = False
# The three constants below are not referenced anywhere else in the visible
# file. NOTE(review): presumably leftovers from an earlier version -- confirm
# before relying on them.
SAVE_CHUNK = 10000
META_FILE_NAME = ''
REVIEW_FILE_NAME = ''


def dl_to_ld(dl):
    """Turn a dict of lists into a list of per-row dicts.

    Row ``i`` of the result maps every key of ``dl`` to the ``i``-th element
    of that key's list. The value lists are zipped together, so the result
    stops at the shortest list; an empty dict produces an empty list.
    """
    keys = list(dl)
    rows = []
    for values in zip(*dl.values()):
        rows.append(dict(zip(keys, values)))
    return rows


def ld_to_dl(ld):
    """Turn a list of dicts into a dict of lists.

    The keys are taken from the first dict; every dict in ``ld`` must contain
    all of them (a missing key raises ``KeyError``).

    An empty ``ld`` returns an empty dict instead of raising ``IndexError``
    on ``ld[0]``, mirroring ``dl_to_ld({}) == []``.
    """
    if not ld:
        return {}
    return {key: [record[key] for record in ld] for key in ld[0]}


def combine_list_of_strs(xs):
    """Join the string items of ``xs`` with newlines, skipping non-strings.

    A bare string is returned unchanged. Without this guard a string would be
    iterated character by character and come back as ``'a\\nb\\nc'``, which
    silently corrupts text fields that hold a single string instead of a
    list of strings.
    """
    if isinstance(xs, str):
        return xs
    return '\n'.join(item for item in xs if isinstance(item, str))

def chunks(a, n):
    """Lazily split sequence ``a`` into ``n`` contiguous, near-equal slices.

    The first ``len(a) % n`` slices receive one extra element. The sizes are
    computed immediately (so ``n == 0`` raises ``ZeroDivisionError`` at call
    time), while the slices themselves are produced by a generator.
    """
    base, extra = divmod(len(a), n)

    def start_of(idx):
        # Start offset of slice ``idx``; earlier slices each absorbed one of
        # the ``extra`` leftover elements.
        return idx * base + min(idx, extra)

    return (a[start_of(idx):start_of(idx + 1)] for idx in range(n))

def get_target_ids(xs):
    """Map each list of target ASINs to their positions in ``asins``.

    ``xs`` is a ``(target_lists, asins)`` pair packed into one argument so
    the function can be used with ``Pool.map``. For every inner list the
    result holds the index of each target that occurs in ``asins``; targets
    that do not occur are dropped. A progress bar is shown via ``tqdm``.

    The lookup table is built once up front. The previous ``in`` +
    ``list.index`` combination scanned the whole ASIN list twice for every
    single target.
    """
    target_lists, asins = xs

    # ``setdefault`` keeps the first position of a repeated ASIN, matching
    # what ``list.index`` returns.
    first_index = {}
    for position, asin in enumerate(asins):
        first_index.setdefault(asin, position)

    target_ids = []
    for targets in tqdm(target_lists):
        target_ids.append(
            [first_index[t_i] for t_i in targets if t_i in first_index])
    return target_ids




class AmazonDatasetBuilder:

    def __init__(self, meta_file, review_file, save_dir, save_name, debug=False, exclude='Electronics') -> None:
        """Store the input/output locations and options for the build.

        Args:
            meta_file: path of the gzipped JSON-lines metadata file.
            review_file: path of the gzipped JSON-lines review file.
            save_dir: directory used for cached and final pickle files.
            save_name: file name of the final dataset pickle inside save_dir.
            debug: when true, loaders read only a small prefix of each file
                and skip the pickle caches.
            exclude: category name supplied by the caller. It is kept on the
                instance; no method reads it automatically (``build_label``
                takes its own ``exclude`` argument).
        """
        self.meta_file = meta_file
        self.review_file = review_file
        self.save_dir = save_dir
        self.save_name = save_name
        self.debug = debug
        # Previously this argument was accepted and silently discarded.
        self.exclude = exclude

    def get_meta(self):
        ### load the meta data
        meta_file = self.save_dir + '/meta.pkl'
        if not self.debug and os.path.isfile(meta_file):
            print(f'Loaded meta from {meta_file}.')
            return pd.read_pickle(meta_file)
        file_name = self.meta_file
        data = []
        i = 0
        with gzip.open(file_name) as f:
            for l in f:
                x = json.loads(l.strip())
                if not 'title' in x:
                    data.append(x)
                else:
                    if not 'getTime' in x['title']:
                        data.append(x)
                i += 1
                if i > 100 and self.debug:
                    break
        df = pd.DataFrame.from_dict(data)

        df = df.fillna('')
        # filter those unformatted rows
        # df5 = df3[~df3.title.str.contains('getTime')]
        df = df.drop_duplicates(subset='asin')
        print(f"Total {len(df)} unique data points are loaded from the meta file.")
        if not self.debug:
            df.to_pickle(meta_file)
            print(f'Saved to {meta_file}.')
        return df

    def get_review(self):
        review_file = self.save_dir + '/review.pkl'
        if not self.debug and os.path.isfile(review_file):
            print(f'Loaded review_dict from {review_file}.')
            df = pd.read_pickle(review_file)
            return df

        file_name = self.review_file
        data = []
        i = 0
        with gzip.open(file_name) as f:
            for l in f:
                data.append(json.loads(l.strip()))
                i += 1
                if i > 100 and self.debug:
                    break
        # return data
        df = pd.DataFrame.from_dict(data)
        print(f"Total {len(df)} unique data points are loaded from the review file.")
        if not self.debug:
            df.to_pickle(review_file)
            print(f'Saved to {review_file}.')
        return df


    def review_to_dict(self, review):
        # overall Int
        # verified Bool
        # reviewTime MM DD, YYYY
        # reviewText string
        # summary string
        review_dict = defaultdict(list)
        rows = review.shape[0]
        for i in range(rows):
            overall = review.iloc[i]['overall']
            verified = review.iloc[i]['verified']
            time = review.iloc[i]['reviewTime']
            review_text = review.iloc[i]['reviewText']
            review_summary = review.iloc[i]['summary']
            review_dict['overall'].append(float(overall))
            review_dict['verified'].append(bool(verified))
            review_dict['time'].append(time)
            review_dict['review_summary'].append(review_summary)
            review_dict['review_text'].append(review_text)
        return review_dict


    def meta_to_dict(self, meta, asin):
        text_feature_columns = [
            'title', 'brand', 'feature', 'date', 'description', 'price', 'category']
        image_url_column = 'imageURL'
        asin_meta = meta[meta['asin'] == asin]

        if len(asin_meta) > 1:
            print(f"Duplicates found for {asin}, I assume use the first one")
            asin_meta = asin_meta.iloc[0]
            # raise ValueError("Something is wrong here!")

        asin_meta_dict = asin_meta.to_dict()
        asin_meta_features = {k: '' for k in text_feature_columns}

        if asin_meta_dict == {}:
            asin_meta_features['image'] = None
            asin_meta_features['links'] = ''
            return asin_meta_features

        asin_meta_features = {}
        for key in text_feature_columns:
            try:
                value = asin_meta[key]
                if type(value) == list:
                    asin_meta_features[key] = value[0]
                elif type(value) == str:
                    asin_meta_features[key] = value
                else:
                    asin_meta_features[key] = value.to_list()[0]
            except:
                import pdb
                pdb.set_trace()

        image_url = asin_meta[image_url_column]
        if not type(image_url) == list:
            image_url = image_url.to_list()
            urls = image_url[0]
        else:
            urls = image_url

        imgs = []

        for url in urls:
            try:
                response = urllib.request.urlopen(url, timeout=3)
                img = Image.open(BytesIO(response.content))
                imgs.append(img)
            except:
                pass

        asin_meta_features['image'] = imgs

        view_links = asin_meta['also_view']
        buy_links = asin_meta['also_buy']

        if not type(view_links) == list:
            view_links = view_links.to_list()[0]
        if not type(buy_links) == list:
            buy_links = buy_links.to_list()[0]

        view_links = [] if view_links == '' else view_links
        buy_links = [] if buy_links == '' else buy_links
        similar_item_links = [] if buy_links == '' else buy_links

        asin_meta_features['view_links'] = view_links
        asin_meta_features['buy_links'] = buy_links
        asin_meta_features['similar'] = similar_item_links
        return asin_meta_features

    def save_data_chunk(meta, chunk_name):
        with open(chunk_name, 'wb') as f:
            pickle.dump(meta, f)
        print(f"Data chunk {chunk_name} is saved.")

    def build_src_target(self, row):
        """Return the row's co-view, co-buy and similar-item links, in that order, as one sequence."""
        viewed_and_bought = row['meta_co_view'] + row['meta_co_buy']
        return viewed_and_bought + row['meta_similar_item']

    # df.apply (lambda row: label_race(row), axis=1)

    #def yield_tokens(data_iter, ngrams, tokenizer):
    #    for text in data_iter:
    #        # print(text)
    #        yield ngrams_iterator(token_list=tokenizer.tokenize_en(text), ngrams=ngrams)


    # def build_vocab_and_tokenizer(train_iter, ngrams=2):
    #     tokenizer = EnTokenizer()
    #     vocab = build_vocab_from_iterator(
    #         yield_tokens(train_iter, ngrams, tokenizer), specials=["<unk>"])
    #     return vocab, tokenizer


    def tensorlize_text_features(row, vocab, tokenizer):
        item = row['raw_text_features']
        tensor = vocab(tokenizer.tokenize_en(item))
        return tensor

    def build_specific_text_features(self, row, name):
        """Return the entry ``name`` of ``row`` passed through ``combine_list_of_strs``."""
        field_value = row[name]
        return combine_list_of_strs(field_value)

    def build_text_features(self, row):
        """Build one text blob from review texts, review summaries and description.

        Each of ``row['review_text']``, ``row['review_summary']`` and
        ``row['meta_description']`` is collapsed with ``combine_list_of_strs``.
        The non-empty results are then joined with a newline, in that order.

        The sections used to be concatenated with no separator, which fused
        the last word of one section with the first word of the next.
        """
        sections = (
            combine_list_of_strs(row['review_text']),
            combine_list_of_strs(row['review_summary']),
            # meta_title is intentionally not part of the blob.
            combine_list_of_strs(row['meta_description']),
        )
        return '\n'.join(section for section in sections if section)
    
    def build_label(self, row, exclude=None):
        """Translate the row's category names into label ids.

        Every name in ``row['category']`` that differs from ``exclude`` is
        looked up in ``self.label_names``; the ids are returned in the same
        order. An unknown name raises ``KeyError``.
        """
        label_ids = []
        for name in row['category']:
            if name == exclude:
                continue
            label_ids.append(self.label_names[name])
        return label_ids

    def build_links(self, row, asins):
        """Resolve the row's linked ASINs to their positions in ``asins``.

        Args:
            row: mapping with the list entries ``meta_similar_item``,
                ``meta_co_view`` and ``meta_co_buy``.
            asins: either a list of ASINs (a link resolves to the first
                position of that ASIN) or a prebuilt ``{asin: position}``
                dict, which lets a caller build the lookup once and reuse it
                for every row.

        Returns:
            List of positions for the links (similar, then viewed, then
            bought) that occur in ``asins``; unknown links are dropped.

        The previous ``in`` + ``list.index`` pair scanned the whole list
        twice per link.
        """
        links = row['meta_similar_item'] + row['meta_co_view'] + row['meta_co_buy']
        if not links:
            return []

        if isinstance(asins, dict):
            index_of = asins
        else:
            # ``setdefault`` keeps the first position of a repeated ASIN,
            # matching what ``list.index`` returns.
            index_of = {}
            for position, asin in enumerate(asins):
                index_of.setdefault(asin, position)
        return [index_of[link] for link in links if link in index_of]

    def build_label_name_mapping(self, df):
        categories = df['category'].tolist()
        categories = list(set(itertools.chain.from_iterable(categories)))
        self.label_names = {c: i for i, c in enumerate(categories)}
    
    def slice_asins(self, asins, n_chunks):
        """Split ``asins`` into ``n_chunks`` parts and pickle them to ``sliced_asins.pkl``.

        The file is written to the current working directory and holds a
        list with one slice per chunk, as produced by ``chunks``.
        """
        parts = [*chunks(asins, n_chunks)]
        with open('sliced_asins.pkl', 'wb') as out:
            pickle.dump(parts, out)
        print(f"Saved asins sliced to {n_chunks} to sliced_asins.pkl")

    def build_dataset(self, chunk_id=None):
        """Assemble one record per ASIN from the metadata and the reviews.

        This is the method the script's ``build`` mode calls.

        Args:
            chunk_id: when given, only the ASINs of that slice of
                ``sliced_asins.pkl`` are processed and the records are
                pickled to ``asin_dict_<chunk_id>.pkl``. When ``None``, all
                ASINs of the metadata are processed and nothing is written.

        Returns:
            The list of per-ASIN dicts. It used to be discarded, which made
            a run without ``chunk_id`` compute everything for nothing.

        Notes:
            * ``id`` is the position within the processed ASIN list, so it
              restarts at 0 for every chunk.
            * Reviews are grouped by ASIN once; previously the whole review
              frame was scanned again for every ASIN.
            * Processing stops after 100 ASINs when either ``self.debug`` or
              the module-level ``DEBUG`` is set (only ``DEBUG`` was honored
              before).
        """
        df_meta = self.get_meta()
        df_review = self.get_review()
        self.df_meta, self.df_review = df_meta, df_review
        if self.debug:
            print("Debug mode is on, stop point 1")

        if chunk_id is not None:
            # NOTE: unpickling executes arbitrary code; the file is expected
            # to be the one written by slice_asins.
            with open('sliced_asins.pkl', 'rb') as f:
                asins = pickle.load(f)[chunk_id]
        else:
            asins = df_meta['asin'].to_list()

        # {asin: integer row positions}; rows keep their original order.
        review_positions = df_review.groupby('asin', sort=False).indices
        no_reviews = df_review.iloc[0:0]
        stop_early = self.debug or DEBUG

        asin_based_knowledge_list = []
        # The context manager closes the progress bar even if a record fails.
        with tqdm(total=len(asins)) as pbar:
            for i, asin in enumerate(asins):
                if stop_early and i >= 100:
                    break

                positions = review_positions.get(asin)
                raw_reviews = no_reviews if positions is None else df_review.iloc[positions]
                reviews = self.review_to_dict(raw_reviews)
                metas = self.meta_to_dict(df_meta, asin)
                asin_based_knowledge_list.append({
                    'asin': asin,
                    'id': i,
                    'review_text': reviews['review_text'],
                    'review_summary': reviews['review_summary'],
                    'review_time': reviews['time'],
                    'review_score': reviews['overall'],
                    'meta_title': metas['title'],
                    'meta_image': metas['image'],
                    'meta_co_view': metas['view_links'],
                    'meta_co_buy': metas['buy_links'],
                    'meta_similar_item': metas['similar'],
                    'meta_description': metas['description'],
                    'category': metas['category'],
                })
                pbar.update(1)

        if chunk_id is not None:
            with open(f'asin_dict_{chunk_id}.pkl', 'wb') as f:
                pickle.dump(asin_based_knowledge_list, f)
        return asin_based_knowledge_list

    def read_and_combine_asins(self, n_chunks):
        asins = []
        file_names = [f'asin_dict_{n}.pkl'for n in range(n_chunks)]
        for n in file_names:
            with open(n, 'rb') as f:
                lst = pickle.load(f)
            asins += lst
        print('Loaded asins', file_names)
        return asins

    def finalize_dataset(self, asin_based_knowledge_list):
        """Turn the per-ASIN records into the final and the compact DataFrame.

        Steps, in order:
          1. Convert the list of dicts into a DataFrame and pickle it to
             ``<save_dir>/<save_name>``.
          2. Add the columns ``trg`` (from ``build_src_target``) and
             ``raw_text_features`` (from ``build_text_features``).
          3. Resolve every ``trg`` list to row positions with
             ``get_target_ids``, spread over 32 worker processes.
          4. Build a compact frame with the columns ``asin``, ``features``,
             ``targets``, ``target_ids``, ``categories`` and
             ``categoriy_length`` and pickle it to ``<save_dir>/compact.pkl``.
          5. Pickle the full frame to ``<save_dir>/<save_name>`` again, now
             including the columns added in step 2.

        Args:
            asin_based_knowledge_list: list of per-ASIN dicts; all dicts are
                expected to share the keys of the first one (see ``ld_to_dl``).

        Returns:
            None. The results exist only as the pickle files.
        """
        target_dir, target_name = self.save_dir, self.save_name
        df_meta_processed = ld_to_dl(asin_based_knowledge_list)
        df = pd.DataFrame(df_meta_processed)
        # First write: the frame before any derived column is added.
        # NOTE(review): this file is overwritten at the end of the method --
        # presumably kept as a checkpoint; confirm it is still wanted.
        df.to_pickle(f'{target_dir}/{target_name}')
        df['trg'] = df.swifter.apply(lambda row: self.build_src_target(row), axis=1)
        df['raw_text_features'] = df.swifter.apply(
            lambda row: self.build_text_features(row), axis=1)
        # The two bare string literals below hold disabled code; as
        # expression statements they have no effect at runtime.
        '''
        # corpus = df['raw_text_features'].to_list()
        # vocab, tokenier = build_vocab_and_tokenizer(corpus)
        # df['text_features'] = df.apply(
        #     lambda row: tensorlize_text_features(row, vocab, tokenier), axis=1)
        '''
        '''
        print("Built text features")
        for n in ['review_text', 'review_summary', 'meta_title', 'meta_description']:
            print("Dealing with", n)
            df[n] = df.swifter.apply(
                lambda row: self.build_specific_text_features(row, n), axis=1)
            # corpus = df[n].to_list()
            # vocab, tokenizer = build_vocab_and_tokenizer(corpus)
            # df[f'{n}_features'] = df.apply(
            #     lambda row: tensorlize_text_features(row, vocab, tokenizer), axis=1)
            # print(f"Built {n} features")
        # df['image_features'] = df.apply(
        #     lambda row: tensorlize_text_features(row, vocab, tokenier), axis=1)
        asins = df['asin'].to_list()
        df['links'] = df.swifter.apply(
            lambda row: self.build_links(row, asins), axis=1)
        self.build_label_name_mapping(df)
        print("Built label name mapping")
        print(self.label_names)

        df['labels'] = df.swifter.apply(
            lambda row: self.build_label(row), axis=1)
        '''

        # Plain Python lists of the columns needed for the compact frame.
        asins = df['asin'].to_list()
        features = df['raw_text_features'].to_list()
        targets = df['trg'].to_list()
        categories = df['category'].to_list()

        print("Get targets_id")
        # Fixed number of slices and of worker processes.
        N = 32
        c_targets = list(chunks(targets, N))
        # Every task carries the complete ASIN list next to its slice,
        # because get_target_ids takes a single (targets, asins) argument.
        c_targets = [(c, asins) for c in c_targets]
        with Pool(N) as p:
            t_ids = p.map(get_target_ids, c_targets)
        # Pool.map keeps task order, so flattening restores row order.
        t_ids = list(itertools.chain.from_iterable(t_ids))
        print("Build a smaller df")

        new_df = pd.DataFrame({'asin': asins, 'features': features, 'targets': targets, 'target_ids': t_ids, 'categories': categories})

        # Despite the message no row is filtered here; only the number of
        # categories per row is recorded.
        print("Filter out certain categories")
        def list_length(row):
            # Number of entries in the row's 'categories' value.
            return len(row['categories'])
        # The column name's spelling ('categoriy_length') is part of the
        # stored data and is kept as is.
        new_df['categoriy_length'] = new_df.swifter.apply(lambda row: list_length(row), axis=1)

        if self.debug:
            print("Debug mode is on, stop point 2")

        # Second write of the full frame, now with 'trg' and 'raw_text_features'.
        df.to_pickle(f'{target_dir}/{target_name}')
        new_df.to_pickle(f'{target_dir}/compact.pkl')
        print(f"Saved df to {target_dir}/{target_name}")


def main():
    """Command-line entry point for building or finalizing the dataset.

    Options:
        --mode      ``build`` (default) or ``finalize``; any other value does
                    nothing.
        --n_chunks  in ``build`` mode: if given, the ASINs are first sliced
                    into that many chunks; required in ``finalize`` mode.
        --chunk     in ``build`` mode: index of the slice to process.

    Raises:
        ValueError: in ``finalize`` mode when ``--n_chunks`` is missing.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--chunk', default=None, type=int)
    argparser.add_argument('--n_chunks', default=None, type=int)
    argparser.add_argument('--mode', default='build', type=str)
    args = argparser.parse_args()

    if args.mode not in ('build', 'finalize'):
        return

    # Both modes work on the same files.
    builder = AmazonDatasetBuilder(
        meta_file='data/raw/electronics/meta_Electronics.json.gz',
        review_file='data/raw/electronics/Electronics_5.json.gz',
        save_dir='data/processed/electronics',
        save_name='electronics_dataset.pkl',
        exclude='Electronics'
    )

    if args.mode == 'build':
        if args.n_chunks is not None:
            df = builder.get_meta()
            builder.slice_asins(df['asin'].to_list(), args.n_chunks)

        builder.build_dataset(args.chunk)
    else:
        if args.n_chunks is None:
            raise ValueError('n_chunks is required')
        asins = builder.read_and_combine_asins(args.n_chunks)
        builder.finalize_dataset(asins)


# The guard keeps the script from running when the module is imported. That
# matters because finalize_dataset starts a multiprocessing Pool: with the
# "spawn" start method every worker re-imports this module, and unguarded
# top-level code would parse the arguments and start the job again.
if __name__ == '__main__':
    main()




# df_meta = builder.get_meta()
# df_review = builder.get_review()
# review_to_dict = builder.review_to_dict
# meta_to_dict = builder.meta_to_dict

# def get_list(args):
#     asins, my_review, my_meta = args[0], args[1], args[2]
#     # asins, my_review, my_meta = args[0], df_review, df_meta
#     results = []
#     pbar = tqdm(total=len(asins))
#     print("total asins", len(asins))
#     for i, asin in enumerate(asins):
#         raw_reviews = my_review[my_review['asin'] == asin]
#         reviews = review_to_dict(raw_reviews)
#         metas = meta_to_dict(my_meta, asin)
#         asin_dict = {
#             'asin': asin,
#             'id': i,
#             'review_text': reviews['review_text'],
#             'review_summary': reviews['review_summary'],
#             'review_time': reviews['time'],
#             'review_score': reviews['overall'],
#             'meta_title': metas['title'],
#             # 'meta_description': metas['description'],
#             'meta_image': metas['image'],
#             'meta_co_view': metas['view_links'],
#             'meta_co_buy': metas['buy_links'],
#             'meta_similar_item': metas['similar'],
#             'meta_description': metas['description'],
#             'category': metas['category'],
#         }
#         results.append(asin_dict)
#         i += 1
#         pbar.update(1)
#     return results

# def chunks(lst, n):
#     for i in range(0, len(lst), n):
#         yield lst[i:i + n]


# if __name__ == '__main__':
#     n_processes = 1
#     asins = df_meta['asin'].to_list()
#     if debug:
#         asins = asins[:100]
#     tmp = get_list((asins, df_review, df_meta))
#     exit()

#     chunked_asins = chunks(asins, n_processes)
#     assemble_args = [(asin, df_review, df_meta) for asin in chunked_asins]

#     start_time = time.time()
#     # get_list(assemble_args[0])
#     with Pool(n_processes) as p:
#         tmp = p.map(get_list, assemble_args)
#     # with Pool(n_processes) as p:
#     #     tmp = p.map(get_list, chunked_asins)
#     print(f"Time taken: {time.time() - start_time}")

