import os
import time
import math



from collections import Counter

import nltk
from nltk import StanfordPOSTagger
from datasets import load_dataset
import stanza
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Pool
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm

from cand_generator import ptb_to_wn_pos

from my_utils.data_utils import save_to_pickle, load_from_pickle
import shared_dir

# Locations of the Stanford POS tagger jar and its English model. The defaults are
# placeholders; set STANFORD_POS_JAR / STANFORD_POS_MODEL in the environment to point
# at a real installation without editing this file.
jar = os.environ.get(
    "STANFORD_POS_JAR",
    "/path/to/stanford_postagger/stanford-postagger/stanford-postagger.jar",
)
model = os.environ.get(
    "STANFORD_POS_MODEL",
    "/path/to/stanford_postagger/stanford-postagger/models/english-bidirectional-distsim.tagger",
)

# PTB tags of inflected content-word forms (plural nouns, comparative/superlative
# adjectives and adverbs, inflected verbs).
format_pos = {'NNS', 'NNPS', 'JJR', 'JJS', 'RBR', 'RBS', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'}
# PTB tags of the base (uninflected) content-word forms.
base_pos = {'NN', 'RB', 'VB', 'JJ'}
# Every PTB tag treated as a content word by the lemma / candidate builders.
all_content_pos = format_pos | base_pos

# WordNet POS constants (adjective, adverb, noun, verb); one candidate bucket per entry.
all_wordnet_content_pos = [wn.ADJ, wn.ADV, wn.NOUN, wn.VERB]


def split_list(one_list, split_num):
    """Cut ``one_list`` into ``split_num`` consecutive slices.

    Every slice holds ``len(one_list) // split_num`` items, except the final one,
    which additionally absorbs the ``len(one_list) % split_num`` leftover items.
    Slices may be empty when ``split_num`` exceeds the length. For
    ``split_num >= 1`` concatenating the slices reproduces ``one_list``.
    ``split_num == 0`` raises ZeroDivisionError; a negative value yields ``[]``.
    """
    chunk = len(one_list) // split_num
    pieces = [one_list[k * chunk:(k + 1) * chunk] for k in range(split_num - 1)]
    if split_num >= 1:
        # The last piece runs to the end of the list, picking up the remainder.
        pieces.append(one_list[(split_num - 1) * chunk:])
    return pieces



def run_in_parallel(F, arg_list, num_cpus):
    """Apply ``F`` to each element of ``arg_list`` in a process pool and flatten.

    ``F`` must be picklable and return an iterable. ``Pool.map`` keeps input
    order, so the items of every per-argument output are concatenated in the
    order of ``arg_list``.
    """
    with Pool(num_cpus) as pool:
        per_arg_outputs = pool.map(F, arg_list)

    return [item for output in per_arg_outputs for item in output]





# ============ process the dataset

def get_dataset(dataset_name):
    """Load the raw texts of a supported Hugging Face dataset, de-duplicated.

    Args:
        dataset_name: one of
            - 'eli5': question titles plus answer texts from the
              'train_eli5' and 'validation_eli5' splits;
            - 'squad': passage contexts from 'train' and 'validation';
            - 'yelp': review texts from the 'yelp_review_full' train split.

    Returns:
        A list of unique strings in first-occurrence order.

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
    """
    all_string_list = []

    if dataset_name == 'eli5':
        for split in ['train_eli5', 'validation_eli5']:
            orig_dataset = load_dataset("eli5", split=split)
            for d in tqdm(orig_dataset):
                all_string_list.append(d['title'])
                all_string_list.extend(d['answers']['text'])

    elif dataset_name == 'squad':
        for split in ['train', 'validation']:
            orig_dataset = load_dataset('squad', split=split)
            all_string_list.extend(d['context'] for d in orig_dataset)

    elif dataset_name == 'yelp':
        orig_dataset = load_dataset("yelp_review_full", split='train')
        all_string_list.extend(d['text'] for d in orig_dataset)

    else:
        raise ValueError(
            f"Unsupported dataset: {dataset_name!r} (expected 'eli5', 'squad' or 'yelp')"
        )

    # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
    # returned order (and every cache built from it downstream) is reproducible;
    # set() order depends on per-process string hash randomization.
    return list(dict.fromkeys(all_string_list))



# def _tokenize_worker(text_sublist, results, idx):
#     nlp = stanza.Pipeline(lang='en', processors='tokenize', use_gpu=False, download_method=None)
#     if idx == 0:
#         results[idx] = [_tokenize(text, nlp) for text in tqdm(text_sublist)]
#     else:
#         results[idx] = [_tokenize(text, nlp) for text in text_sublist]
#
# def _tokenize(text, nlp):
#     doc = nlp(text)
#     return [word.text for sent in doc.sentences for word in sent.words]

def _chunk_sentences(sentence_tokens, max_length):
    """Greedily pack consecutive sentences into chunks of at most ``max_length`` tokens.

    A sentence is added to the current chunk unless that would push the chunk
    past ``max_length``, in which case a new chunk is started. A single sentence
    longer than ``max_length`` is kept whole as its own chunk. Empty chunks are
    never returned.
    """
    chunks = []
    current = []
    for tokens in sentence_tokens:
        if current and len(current) + len(tokens) > max_length:
            chunks.append(current)
            current = []
        current.extend(tokens)
    if current:
        chunks.append(current)
    return chunks


def get_tokens(dataset, text_list, max_length=150):
    """Tokenize ``text_list`` with stanza and pack each text's sentences into chunks.

    Each text is sentence-split and tokenized by a stanza ``tokenize,mwt``
    pipeline (GPU enabled). Its sentences are then packed greedily into chunks
    of at most ``max_length`` tokens (a longer single sentence forms its own
    chunk). Chunks from all texts are returned as one flat list of token lists.

    The result is pickled to ``{cand_preprocess_dir}{dataset}_tokens.pkl``. If
    that file already exists it is returned as-is. After changing ``max_length``,
    or when upgrading from the previous chunking rule, delete the stale cache
    to regenerate.
    """
    print('Generate token list ...')
    cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_tokens.pkl'
    if os.path.exists(cache_path):
        return load_from_pickle(cache_path)

    nlp = stanza.Pipeline(lang='en', processors='tokenize,mwt', download_method=None, use_gpu=True)

    token_ll = []
    for text in tqdm(text_list):
        doc = nlp(text)
        sentence_tokens = ([word.text for word in sent.words] for sent in doc.sentences)
        # Previously the length test ran *before* adding a sentence, so chunks only
        # closed after exceeding max_length, and texts without sentences produced
        # an empty chunk that was still sent to the POS tagger.
        token_ll.extend(_chunk_sentences(sentence_tokens, max_length))

    save_to_pickle(token_ll, cache_path)

    return token_ll


def _get_pos_subprocess(args):
    """Worker: POS-tag one shard of token sequences with its own tagger.

    ``args`` is a ``(token_sequences, tagger)`` pair. The sequences are passed
    to ``tagger.tag_sents`` 32 at a time (with a tqdm bar over the batches), and
    the tagged sequences are returned as a single list in input order.
    """
    sequences, tagger = args

    batch_size = 32
    offsets = range(0, len(sequences), batch_size)
    batches = [sequences[offset:offset + batch_size] for offset in offsets]

    tagged = []
    for batch in tqdm(batches):
        tagged += tagger.tag_sents(batch)
    return tagged



def get_pos(dataset, tokens_list, cpu_number=64, java_options='-Xmx40g'):
    """POS-tag every token sequence with the Stanford tagger across worker processes.

    ``tokens_list`` is split into ``cpu_number`` consecutive shards with
    ``split_list``. Each shard is tagged in its own pool process by
    ``_get_pos_subprocess`` using its own ``StanfordPOSTagger``. The per-shard
    outputs are concatenated in shard order.

    The result is pickled to ``{cand_preprocess_dir}{dataset}_pos_info.pkl``. If
    that file already exists it is loaded and returned without tagging.

    Args:
        dataset: dataset name; only used to build the cache file name.
        tokens_list: list of token sequences to tag.
        cpu_number: number of shards and of pool processes (default 64).
        java_options: JVM options given to every tagger (default '-Xmx40g').
            NOTE(review): each of the ``cpu_number`` JVMs may grow up to this heap
            limit, so the defaults permit 64 x 40 GB; lower them on smaller hosts.

    Returns:
        A list of tagged sequences (each a list of ``(word, tag)`` pairs), which
        should contain one entry per input sequence.

    Raises:
        ValueError: if ``cpu_number`` is smaller than 1.
    """
    print('Generate POS ...')
    cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_pos_info.pkl'
    if os.path.exists(cache_path):
        return load_from_pickle(cache_path)

    if cpu_number < 1:
        raise ValueError(f'cpu_number must be >= 1, got {cpu_number}')

    sub_tokens_list = split_list(tokens_list, cpu_number)

    # One tagger per shard; each (shard, tagger) pair is pickled to a worker.
    arg_list = [
        (sub_tokens, StanfordPOSTagger(model, jar, encoding="utf-8", java_options=java_options))
        for sub_tokens in sub_tokens_list
    ]

    all_pos_list = run_in_parallel(_get_pos_subprocess, arg_list, cpu_number)

    if len(all_pos_list) != len(tokens_list):
        # Output should align 1:1 with tokens_list; a mismatch presumably means the
        # tagger dropped or merged sequences. Cache anyway (previous behavior) but
        # make the problem visible instead of silently persisting it.
        print(f'Warning: {len(tokens_list)} token sequences in, '
              f'{len(all_pos_list)} tagged sequences out; POS output is misaligned.')

    save_to_pickle(all_pos_list, cache_path)
    print(f'Save to {cache_path}!')

    return all_pos_list

# def get_pos(dataset, tokens_list):
#     print('Generate POS ...')
#
#     cpu_number = 3
#
#     sub_tokens_list = split_list(tokens_list, cpu_number)
#
#     arg_list = []
#     for i in range(cpu_number):
#         cur_post_tagger = StanfordPOSTagger(model, jar, encoding="utf-8", java_options='-Xmx30g')
#         arg_list.append((sub_tokens_list[i], cur_post_tagger))
#
#
#     all_pos_list = run_in_parallel(_get_pos_subprocess, arg_list, cpu_number)
#
#
#     return all_pos_list

def build_token_dict(dataset, tokens_list):
    """Count lower-cased token frequencies and pickle them by descending count.

    The counts are always recomputed; no existing cache is read. The resulting
    dict is written to ``{cand_preprocess_dir}{dataset}_freq_dict.pkl``. Tokens
    with equal counts keep their first-seen order.
    """
    print('Generate token dict ...')
    cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_freq_dict.pkl'

    counts = Counter(token.lower() for sentence in tokens_list for token in sentence)
    # most_common() with no argument stably sorts all items by count, descending.
    sorted_token2freq = dict(counts.most_common())

    save_to_pickle(sorted_token2freq, cache_path)
    print(f'Save to {cache_path}!')

    return sorted_token2freq




def build_lemma(dataset, sorted_token2freq, all_pos_list, min_freq=3):
    """Lemmatize every content-word (token, PTB tag) pair seen in ``all_pos_list``.

    A pair is kept only when its tag is in ``all_content_pos`` and its token is
    a key of ``sorted_token2freq`` with frequency >= ``min_freq``. The membership
    test is case-sensitive; if the frequency keys are lower-cased, capitalized
    surface forms are therefore skipped.

    Args:
        dataset: dataset name; only used to build the cache file name.
        sorted_token2freq: mapping token -> corpus frequency.
        all_pos_list: iterable of tagged sequences of ``(token, ptb_tag)`` pairs.
        min_freq: minimum frequency for a token to be lemmatized (default 3).

    Returns:
        ``(lemma_dict, reverse_lemma_dict)``, where ``lemma_dict`` maps
        ``(token, ptb_tag) -> lemma`` and ``reverse_lemma_dict`` maps
        ``(lemma, ptb_tag) -> token``. When several tokens share a
        ``(lemma, ptb_tag)``, the last one encountered wins.
        The tuple is pickled to ``{cand_preprocess_dir}{dataset}_lemma_info.pkl``
        and returned directly if that file already exists.

    Raises:
        ValueError: if ``ptb_to_wn_pos`` has no WordNet POS for a content tag
            (previously an ``assert``, which ``python -O`` strips).
    """
    print('Build lemma dict ...')
    cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_lemma_info.pkl'
    if os.path.exists(cache_path):
        return load_from_pickle(cache_path)

    print('total tokens:', len(sorted_token2freq))

    valid_tokens = {k for k, v in sorted_token2freq.items() if v >= min_freq}

    print('Filterd token num:', len(valid_tokens))

    wnl = WordNetLemmatizer()
    # PTB -> WordNet tag mapping, computed once instead of per token.
    ptb2wn = {ptb_pos: ptb_to_wn_pos(ptb_pos) for ptb_pos in all_content_pos}

    lemma_dict = {}
    reverse_lemma_dict = {}

    for token_pos_list in tqdm(all_pos_list):
        for token, ptb_pos in token_pos_list:
            if (token, ptb_pos) in lemma_dict:
                continue  # already lemmatized

            if ptb_pos not in all_content_pos or token not in valid_tokens:
                continue

            wn_pos = ptb2wn[ptb_pos]
            if wn_pos is None:
                raise ValueError(f'No WordNet POS mapping for PTB tag {ptb_pos!r}')

            lemma = wnl.lemmatize(token, wn_pos)

            lemma_dict[(token, ptb_pos)] = lemma
            reverse_lemma_dict[(lemma, ptb_pos)] = token

    ret = (lemma_dict, reverse_lemma_dict)

    save_to_pickle(ret, cache_path)
    print('Save to', cache_path)

    return ret



# def fix_error(dataset):
#
#     tokens_list = get_tokens(dataset, None)
#     all_pos_list = get_pos(dataset, tokens_list)
#
#     print(len(tokens_list))
#     print(len(all_pos_list))
#
#     cpu_num = int(math.sqrt(len(all_pos_list)))
#
#
#
#     # check
#     real_all_pos_list = all_pos_list[:cpu_num]
#
#     for i in range(cpu_num):
#         error_pos_list = all_pos_list[i * cpu_num]
#         assert error_pos_list[0] == real_all_pos_list[0][0]
#
#         error_str = " ".join([w for w, pos in error_pos_list[0]])
#         real_str = " ".join(tokens_list[0])
#         assert error_str == real_str, error_str + '\n' + real_str
#
#     correct_pos_list = []
#     for b_pos in real_all_pos_list:
#         correct_pos_list.extend(b_pos)
#
#     assert len(correct_pos_list) == len(tokens_list)
#
#     cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_pos_info.pkl'
#     save_to_pickle(correct_pos_list, cache_path)
#
#     print('Correct ', dataset)





def _wordnet_synonym_lemmas(orig_lemma, wn_pos):
    """Return single-word WordNet synonym lemma names of ``orig_lemma`` for ``wn_pos``.

    Collects the lemma names of every synset of ``orig_lemma`` with POS
    ``wn_pos``. It skips ``orig_lemma`` itself and multi-word names (those
    containing '_'). Duplicates across synsets are kept; callers de-duplicate.
    """
    synonyms = []
    for synset in wn.synsets(orig_lemma, pos=wn_pos):
        for syn_lemma in synset.lemmas():
            name = syn_lemma.name()
            if name != orig_lemma and '_' not in name:
                synonyms.append(name)
    return synonyms


def build_cand_dict(dataset, sorted_token2freq, lemma_info, min_freq=3):
    """Map each frequent token to its WordNet synonyms, re-inflected to corpus forms.

    For every token in ``sorted_token2freq`` with frequency >= ``min_freq``, and
    every content PTB tag the token was lemmatized under (a key of
    ``lemma_dict``), look up the token's lemma in WordNet. Each single-word
    synonym lemma is mapped back through ``reverse_lemma_dict`` to a surface
    form observed with the same PTB tag; synonyms without one are dropped.

    Args:
        dataset: dataset name; only used to build the cache file name.
        sorted_token2freq: mapping token -> frequency, ordered by descending
            frequency (the loop stops at the first token below ``min_freq``).
        lemma_info: ``(lemma_dict, reverse_lemma_dict)`` as built by ``build_lemma``.
        min_freq: minimum token frequency (default 3).

    Returns:
        dict ``token -> {wordnet_pos: [surface synonyms]}`` with a (possibly
        empty) de-duplicated list for every POS in ``all_wordnet_content_pos``.
        It is pickled to ``{cand_preprocess_dir}{dataset}_cand_dict.pkl`` and
        returned directly if that file already exists.

    Raises:
        ValueError: if a content tag has no WordNet POS mapping.
    """
    print('Build candidate dict ...')
    cache_path = f'{shared_dir.cand_preprocess_dir}{dataset}_cand_dict.pkl'
    if os.path.exists(cache_path):
        return load_from_pickle(cache_path)

    lemma_dict, reverse_lemma_dict = lemma_info

    # Fixed tag order: iterating the set directly depends on string hash
    # randomization, which made the output differ between runs.
    content_pos_list = sorted(all_content_pos)
    # PTB -> WordNet tag mapping, computed once instead of per token.
    ptb2wn = {ptb_pos: ptb_to_wn_pos(ptb_pos) for ptb_pos in content_pos_list}

    cand_dict = {}

    for orig_token, freq in tqdm(sorted_token2freq.items()):
        # Input is sorted by descending frequency, so every later token is rarer too.
        if freq < min_freq:
            break

        buckets = {wn_pos_: [] for wn_pos_ in all_wordnet_content_pos}

        for ptb_pos in content_pos_list:
            if (orig_token, ptb_pos) not in lemma_dict:
                continue

            wn_pos = ptb2wn[ptb_pos]
            if wn_pos is None:
                raise ValueError(f'No WordNet POS mapping for PTB tag {ptb_pos!r}')
            bucket = buckets[wn_pos]

            orig_lemma = lemma_dict[(orig_token, ptb_pos)]
            for syn_lemma in _wordnet_synonym_lemmas(orig_lemma, wn_pos):
                # Re-inflect: surface form seen in the corpus for this lemma + PTB tag.
                surface = reverse_lemma_dict.get((syn_lemma, ptb_pos))
                if surface is not None:
                    bucket.append(surface)

        # Several PTB tags can map to the same WordNet POS, so each bucket accumulates
        # across tags. Previously the bucket was overwritten with only the last tag's
        # synonyms. De-duplicate in first-seen order.
        cand_dict[orig_token] = {
            wn_pos_: list(dict.fromkeys(synonyms)) for wn_pos_, synonyms in buckets.items()
        }

    save_to_pickle(cand_dict, cache_path)
    print('Save to', cache_path)

    return cand_dict













if __name__ == '__main__':
    # End-to-end run for one dataset; the stages write their results to pickle files.
    dataset = 'eli5'

    text_list = get_dataset(dataset)
    print('Load dataset complete!')

    tokens_list = get_tokens(dataset, text_list)
    print('Tokenization finished!')
    print(len(tokens_list))

    # Drop token sequences longer than 200 tokens before POS tagging.
    max_tokens = 200
    tokens_list = list(filter(lambda toks: len(toks) <= max_tokens, tokens_list))
    print(len(tokens_list))

    all_pos_list = get_pos(dataset, tokens_list)
    freq_dict = build_token_dict(dataset, tokens_list)
    lemma_info = build_lemma(dataset, freq_dict, all_pos_list)
    build_cand_dict(dataset, freq_dict, lemma_info)


