# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function

import csv
import html
import random
import re
import string
import unicodedata

from sklearn.utils import Bunch

import config
import pandas as pd

def fetch_itaki(filename='util/merged.csv', **kwargs):
    """Load a data set via ``fetch_from_csv`` using 'util/merged.csv' by default.

    Args:
        filename: Path of the CSV file to read.
        **kwargs: Forwarded unchanged to ``fetch_from_csv``.

    Returns:
        Whatever ``fetch_from_csv`` returns for the given arguments.
    """
    dataset = fetch_from_csv(filename, **kwargs)
    return dataset


def fetch_ice(filename='util/merged_ice.csv', **kwargs):
    """Load a data set via ``fetch_from_csv`` using 'util/merged_ice.csv' by default.

    Args:
        filename: Path of the CSV file to read.
        **kwargs: Forwarded unchanged to ``fetch_from_csv``.

    Returns:
        Whatever ``fetch_from_csv`` returns for the given arguments.
    """
    dataset = fetch_from_csv(filename, **kwargs)
    return dataset


# Word-count bounds used by process_row when trim_by_min_max is enabled:
# documents with fewer words than the minimum are skipped, documents with
# more words than the maximum are cut at the maximum.
min_word_limit = 100
max_word_limit = 400

# "www" or "http" followed by a run of non-whitespace characters.
url_pattern = re.compile(r'(www|http)\S+', re.DOTALL)
# E-mail-address-like text: a dotted (or double-quoted) local part, an "@",
# and a dotted domain whose last label has at least two characters.
email_pattern = re.compile(
    r'(([^<>()\[\]\.,;:\s@\"]+(\.[^<>()\[\]\.,;:\s@\"]+)*)|(\".+\"))@(([^<>()[\]\.,;:\s@\"]+\.)+[^<>()[\]\.,;:\s@\"]{2,})',
    re.DOTALL)
# Anything between "<" and ">" that contains no further "<" (XML/HTML tags).
xml_pattern = re.compile(r'<[^<]+>')
# A run of two or more "x"/"X" surrounded by whitespace on both sides.
censored_pattern_1 = re.compile(r'\s[xX]{2,}\s')
# A run of three or more "x"/"X" anywhere in the text.
censored_pattern_2 = re.compile(r'[xX]{3,}')

def fetch_from_csv(filename, subset='all', subsets=None, categories=None, strip_accents=False, train_precentage=0.7,
                   trim_by_category=False, trim_by_doc_count=False, trim_by_min_max=False,
                   reuse_skipped_docs=False,
                   verbose=False):
    """Read documents from a CSV file and split them into train/test data sets.

    The CSV file must provide a ``document`` and a ``language`` column. Every
    row is cleaned and filtered by ``process_row``; the accepted documents are
    grouped by language, optionally trimmed so that all categories contribute
    the same amount of data, split per language into a train and a test part,
    and finally shuffled with ``deep_shuffle``.

    Args:
        filename: Path of the CSV file to read.
        subset: 'train', 'test' or anything else for the merged data; only
            used when ``subsets`` is None.
        subsets: Optional iterable of subset labels. When given, a dict is
            returned instead of a single Bunch; any label other than 'train'
            or 'test' yields the merged data under the key 'all'.
        categories: Either a flat list of language names, or a list of groups
            shaped like ``[name, [language, ...]]``. With groups, the targets
            are group indices and ``target_names`` are the group names. When
            None or empty, every language found in the file is used, in order
            of first appearance.
        strip_accents: Forwarded to ``process_row``.
        train_precentage: Fraction of each language's documents that goes
            into the train part.
        trim_by_category: When true and grouped categories are given, trim
            the languages so that every group contributes equally.
        trim_by_doc_count: Balance groups by document count instead of by
            word count (only relevant with ``trim_by_category``).
        trim_by_min_max: Forwarded to ``process_row``.
        reuse_skipped_docs: When true, the skipped documents of each language
            are joined into one document and processed a second time.
        verbose: Print a progress line per language.

    Returns:
        A ``Bunch`` with ``data``, ``target`` and ``target_names``, or - when
        ``subsets`` is given - a dict mapping 'train'/'test'/'all' to such
        Bunches.

    Raises:
        Exception: If, after three rebalancing passes, a language still has
            less data than its group requires.
    """
    categories_names = categories
    top_categories = None
    # ``categories`` defaults to None, so make sure there is a first element
    # before inspecting it.
    if categories and isinstance(categories[0], list):
        top_categories = categories
        categories = [item for sublist in categories for subsubList in sublist[1:] for item in subsubList]
        categories_names = [item[0] for item in top_categories]

    train_data = []
    train_target = []
    test_data = []
    test_target = []
    skipped_docs_by_lang = {}
    docs_by_lang = {}
    wc_by_lang = {}

    with open(filename, 'r') as input_csv:
        # Pre-process and categorize docs by language
        for row in csv.DictReader(input_csv, delimiter=","):
            process_row(row['document'], row['language'], categories, strip_accents, skipped_docs_by_lang,
                        docs_by_lang, trim_by_category, top_categories, wc_by_lang, trim_by_min_max)

    if reuse_skipped_docs:
        # Merge skipped docs and re-process by language
        merged_doc_by_lang = {lang: " ".join(docs) for lang, docs in skipped_docs_by_lang.items()}
        skipped_docs_by_lang = {}
        for lang, merged_doc in merged_doc_by_lang.items():
            process_row(merged_doc, lang, categories, strip_accents, skipped_docs_by_lang,
                        docs_by_lang, trim_by_category, top_categories, wc_by_lang, trim_by_min_max)

    if not categories:
        # No category filter was given: every language found becomes a category,
        # so that target indices and target names can still be produced.
        categories = list(docs_by_lang)
        categories_names = categories

    if trim_by_category and top_categories is not None:
        if trim_by_doc_count:
            lowest_dc_for_category = min(
                [sum(len(docs_by_lang[lang]) for lang in category[1]) for category in top_categories])
            dc_for_lang_by_category = calc_dc_for_lang_by_category(top_categories, lowest_dc_for_category)

            # Lower the per-category target until every language can deliver
            # its share; stop as soon as a pass reports no shortfall and give
            # up after three passes.
            dc_cannot_fulfill = True
            for _ in range(3):
                lowest_dc_for_category, dc_cannot_fulfill, dc_for_lang_by_category = \
                    check_for_dc_inbalance(lowest_dc_for_category, top_categories, docs_by_lang,
                                           dc_for_lang_by_category)
                if not dc_cannot_fulfill:
                    break

            if dc_cannot_fulfill:
                dc_for_lang_by_category = calc_dc_for_lang_by_category(top_categories, lowest_dc_for_category)
                # Check doc-count again
                for category in top_categories:
                    for lang in category[1]:
                        if len(docs_by_lang[lang]) < dc_for_lang_by_category[category[0]]:
                            raise Exception(
                                "Doc count of '{:,}' of the '{}' is less than category contribution: {:,}".format(
                                    len(docs_by_lang[lang]), lang, int(dc_for_lang_by_category[category[0]])))

            if not config.is_silent:
                for category in top_categories:
                    for language in category[1]:
                        dc_for_lang = dc_for_lang_by_category[category[0]]
                        print("'{}' with doc count {:,} in category:'{}', effective doc count is {:,}".format(
                            language, len(docs_by_lang[language]), category[0], int(dc_for_lang)))
                    print()

            for language in docs_by_lang:
                category_index = _find_category_index(top_categories, language)
                if category_index is None:
                    continue
                dc_for_lang = int(dc_for_lang_by_category[top_categories[category_index][0]])
                docs_by_lang[language] = docs_by_lang[language][:dc_for_lang]
        else:
            lowest_wc_for_category = min(
                [sum(wc_by_lang[lang] for lang in category[1]) for category in top_categories])
            wc_for_lang_by_category = calc_wc_for_lang_by_category(top_categories, lowest_wc_for_category)

            # Same three-pass rebalancing as above, based on word counts.
            wc_cannot_fulfill = True
            for _ in range(3):
                lowest_wc_for_category, wc_cannot_fulfill, wc_for_lang_by_category = \
                    check_for_wc_inbalance(lowest_wc_for_category, top_categories, wc_by_lang,
                                           wc_for_lang_by_category, docs_by_lang)
                if not wc_cannot_fulfill:
                    break

            if wc_cannot_fulfill:
                wc_for_lang_by_category = calc_wc_for_lang_by_category(top_categories, lowest_wc_for_category)
                # Check word-count again
                for category in top_categories:
                    for lang in category[1]:
                        if wc_by_lang[lang] < wc_for_lang_by_category[category[0]]:
                            raise Exception(
                                "Wordcount of '{:,}' of the '{}' is less than category contribution: {:,}".format(
                                    wc_by_lang[lang], lang, int(wc_for_lang_by_category[category[0]])))

            if not config.is_silent:
                for category in top_categories:
                    for language in category[1]:
                        wc_for_lang = wc_for_lang_by_category[category[0]]
                        print(
                            "'{}' with word count {:,} in category:'{}', effective word count is {:,}".format(
                                language, wc_by_lang[language], category[0], wc_for_lang))
                    print()

            for language in docs_by_lang:
                category_index = _find_category_index(top_categories, language)
                if category_index is None:
                    continue
                wc_for_lang = wc_for_lang_by_category[top_categories[category_index][0]]
                total_word_counter = 0
                docs = []
                for doc in docs_by_lang[language]:
                    words = doc.split()
                    wc_for_doc = len(words)
                    if total_word_counter + wc_for_doc < wc_for_lang:
                        total_word_counter += wc_for_doc
                        docs.append(doc)
                    else:
                        # The word budget runs out inside this document: keep
                        # only the leading words that still fit, then stop.
                        trim_count = wc_for_doc - ((total_word_counter + wc_for_doc) - wc_for_lang)
                        docs.append(" ".join(words[:int(trim_count)]))
                        break
                docs_by_lang[language] = docs

    # Calculate and divide train and test docs per the percentage for each language
    for language in docs_by_lang:
        if verbose:
            print("Processing " + language + "...")

        docs_for_current_lang = docs_by_lang[language]
        total = len(docs_for_current_lang)
        train = int(total * train_precentage)
        test = total - train

        if not config.is_silent:
            # The skipped documents are only needed for this report; drop them afterwards.
            skipped = len(skipped_docs_by_lang.pop(language, []))
            if subset == 'train':
                print("{:,} train documents {:,} skipped".format(train, skipped))
            elif subset == 'test':
                print("{:,} test documents {:,} skipped".format(test, skipped))
            else:
                label = "'{}' with {:,} total documents of {:,} train documents and {:,} test documents found, {:,} skipped"
                print(label.format(language, total, train, test, skipped))

        train_data += docs_for_current_lang[:train]
        test_data += docs_for_current_lang[train:]

        if top_categories is not None:
            # category index
            target = _find_category_index(top_categories, language)
        else:
            # language index
            target = categories.index(language)

        # repeat the target id for train and test amounts
        train_target += [target] * train
        test_target += [target] * test

    # Release the per-language lists before shuffling
    del docs_by_lang

    # shuffle the data
    deep_shuffle(train_data, train_target)
    deep_shuffle(test_data, test_target)

    if subsets is not None:
        rv = {}
        for subset_label in subsets:
            if subset_label == 'train':
                rv["train"] = Bunch(data=train_data, target=train_target, target_names=categories_names)
            elif subset_label == 'test':
                rv["test"] = Bunch(data=test_data, target=test_target, target_names=categories_names)
            else:
                # Merge Train and Test data
                rv["all"] = Bunch(data=train_data + test_data, target=train_target + test_target,
                                  target_names=categories_names)
        return rv

    # Create Data Set
    if subset == 'train':
        return Bunch(data=train_data, target=train_target, target_names=categories_names)
    if subset == 'test':
        return Bunch(data=test_data, target=test_target, target_names=categories_names)
    # Merge Train and Test data
    return Bunch(data=train_data + test_data, target=train_target + test_target, target_names=categories_names)


def _find_category_index(top_categories, language):
    """Return the index of the first group whose language list holds `language`, else None."""
    for position, category in enumerate(top_categories):
        if language in category[1]:
            return position
    return None


def process_row(document, language, categories, strip_accents, skipped_docs_by_lang, docs_by_lang,
                trim_by_category, top_categories, wc_by_lang, trim_by_min_max):
    """Clean one document and file it under its language, or record it as skipped.

    URLs, e-mail addresses, XML/HTML tags and censored "xx"/"XXX" runs are
    removed first. The document is then skipped when it has no words, when it
    is shorter than ``min_word_limit`` (only with ``trim_by_min_max``), or
    when ``is_english`` rejects it. With ``trim_by_min_max`` a document longer
    than ``max_word_limit`` is cut at the limit; the remainder is processed
    as a document of its own if it has more than ``min_word_limit`` words and
    is recorded as skipped otherwise.

    Args:
        document: Raw document text.
        language: Language label of the document.
        categories: Allowed languages; when non-empty, other languages are
            ignored without being recorded anywhere.
        strip_accents: Accepted but currently unused (accent stripping is
            disabled).
        skipped_docs_by_lang: Dict of language -> list of skipped documents;
            updated in place.
        docs_by_lang: Dict of language -> list of accepted documents; updated
            in place.
        trim_by_category: Together with ``top_categories`` enables word
            counting into ``wc_by_lang``.
        top_categories: Grouped categories, or None.
        wc_by_lang: Dict of language -> accumulated word count; updated in
            place.
        trim_by_min_max: Enforce ``min_word_limit`` / ``max_word_limit``.
    """
    if categories and language not in categories:
        # Skip, if not in provided `category`
        return

    # Strip URLs, e-mails and tags, then censored data (the first censor
    # pattern also consumes its surrounding whitespace, hence the " ").
    cleaners = (
        (url_pattern, ''),
        (email_pattern, ''),
        (xml_pattern, ''),
        (censored_pattern_1, " "),
        (censored_pattern_2, ''),
    )
    for pattern, replacement in cleaners:
        document = pattern.sub(replacement, document)

    words = document.split()

    # Register the skipped-docs list before any recursive call so that the
    # recursion and this call append to the same list instead of one
    # overwriting the other.
    skipped_docs = skipped_docs_by_lang.setdefault(language, [])

    if len(words) == 0:
        # Skip, zero words
        skipped_docs.append(document)
        return
    if trim_by_min_max:
        if len(words) < min_word_limit:
            # Check for min limit
            skipped_docs.append(document)
            return
        if len(words) > max_word_limit:
            # Check for max limit
            remainder = words[max_word_limit:]
            words = words[:max_word_limit]
            document = " ".join(words)  # trim long doc
            if len(remainder) > min_word_limit:
                # Re-use other half of the doc
                process_row(" ".join(remainder), language, categories, strip_accents, skipped_docs_by_lang,
                            docs_by_lang, trim_by_category, top_categories, wc_by_lang, trim_by_min_max)
            else:
                # Other half is not usable
                skipped_docs.append(" ".join(remainder))
    if not is_english(document):
        # Skip, non english
        skipped_docs.append(document)
        return

    docs_by_lang.setdefault(language, []).append(document)

    if trim_by_category and top_categories:
        wc_by_lang[language] = wc_by_lang.get(language, 0) + len(words)


def is_english(document):
    """Tell whether more than 75% of the characters are in ``string.printable``.

    Args:
        document: Text to inspect.

    Returns:
        True for the empty string; otherwise True only when the share of
        characters found in ``string.printable`` is strictly above 0.75.
    """
    if document == "":
        return True
    printable = string.printable
    printable_chars = sum(1 for char in document if char in printable)
    return printable_chars / len(document) > 0.75


def deep_shuffle(data_list, target_list, seed=0):
    """Shuffle two parallel lists in place, keeping their elements paired.

    Uses the Fisher-Yates algorithm. The module-level random generator is
    re-seeded with ``seed`` on every call, so the same input always yields
    the same order.

    Args:
        data_list: List to shuffle; its length drives the shuffle.
        target_list: List shuffled with the same swaps as ``data_list``.
        seed: Seed passed to ``random.seed``.
    """
    random.seed(seed)
    for i in range(len(data_list) - 1, 0, -1):
        # random.randint includes both bounds, so the upper bound must be i:
        # i + 1 would index past the end of the list on the first iteration.
        j = random.randint(0, i)

        # Swap the element at i with the one at the random index
        data_list[i], data_list[j] = data_list[j], data_list[i]
        target_list[i], target_list[j] = target_list[j], target_list[i]


def check_for_wc_inbalance(lowest_wc_for_category, top_categories, wc_by_lang, wc_for_lang_by_category, docs_by_lang):
    """Lower the per-category word target when a language cannot deliver its share.

    For every category, the languages are checked in order. At the first
    language whose word count is below the category's per-language share, the
    category target is reduced to ``word count * number of languages in the
    category``, the shares are recomputed for all categories, and the check
    moves on to the next category.

    Args:
        lowest_wc_for_category: Current word target per category.
        top_categories: Groups shaped like ``[name, [language, ...]]``.
        wc_by_lang: Dict of language -> available word count.
        wc_for_lang_by_category: Dict of category name -> per-language share.
        docs_by_lang: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(lowest_wc_for_category, wc_cannot_fulfill,
        wc_for_lang_by_category)`` where ``wc_cannot_fulfill`` is True when
        at least one shortfall was found during this pass.
    """
    wc_cannot_fulfill = False
    for category in top_categories:
        category_name = category[0]
        languages = category[1]
        for lang in languages:
            available = wc_by_lang[lang]
            required = wc_for_lang_by_category[category_name]
            if not available < required:
                continue

            wc_cannot_fulfill = True
            reduced_target = available * len(languages)
            if not config.is_silent:
                print("WARN: " + "Wordcount of '{:,}' of the '{}' is less than category contribution: {:,}".format(
                    available, lang, int(required)))
                print("WARN: Continuing with {:,} as lowest_wc_for_category, earlier it was {:,}".format(
                    reduced_target, lowest_wc_for_category))
            lowest_wc_for_category = reduced_target
            wc_for_lang_by_category = calc_wc_for_lang_by_category(top_categories, lowest_wc_for_category)
            # Remaining languages of this category are judged on the next pass
            break
    return lowest_wc_for_category, wc_cannot_fulfill, wc_for_lang_by_category


def check_for_dc_inbalance(lowest_dc_for_category, top_categories, docs_by_lang, dc_for_lang_by_category):
    """Lower the per-category document target when a language cannot deliver its share.

    For every category, the languages are checked in order. At the first
    language whose document count is below the category's per-language share,
    the category target is reduced to ``document count * number of languages
    in the category``, the shares are recomputed for all categories, and the
    check moves on to the next category.

    Args:
        lowest_dc_for_category: Current document target per category.
        top_categories: Groups shaped like ``[name, [language, ...]]``.
        docs_by_lang: Dict of language -> list of documents.
        dc_for_lang_by_category: Dict of category name -> per-language share.

    Returns:
        Tuple ``(lowest_dc_for_category, dc_cannot_fulfill,
        dc_for_lang_by_category)`` where ``dc_cannot_fulfill`` is True when
        at least one shortfall was found during this pass.
    """
    dc_cannot_fulfill = False
    for category in top_categories:
        category_name = category[0]
        languages = category[1]
        for lang in languages:
            available = len(docs_by_lang[lang])
            required = dc_for_lang_by_category[category_name]
            if not available < required:
                continue

            dc_cannot_fulfill = True
            reduced_target = available * len(languages)
            if not config.is_silent:
                print("WARN: " + "Doc count of '{:,}' of the '{}' is less than category contribution: {:,}".format(
                    available, lang, int(required)))
                print("WARN: Continuing with {:,} as lowest_dc_for_category, earlier it was {:,}".format(
                    reduced_target, lowest_dc_for_category))
            lowest_dc_for_category = reduced_target
            dc_for_lang_by_category = calc_dc_for_lang_by_category(top_categories, lowest_dc_for_category)
            # Remaining languages of this category are judged on the next pass
            break
    return lowest_dc_for_category, dc_cannot_fulfill, dc_for_lang_by_category


def calc_wc_for_lang_by_category(categories_list, lowest_wc_for_category):
    """Split the category word target evenly over each category's languages.

    Args:
        categories_list: Groups shaped like ``[name, [language, ...]]``.
        lowest_wc_for_category: Word target shared by every category.

    Returns:
        Dict of category name -> word target per language of that category.
    """
    return {
        group[0]: lowest_wc_for_category / len(group[1])
        for group in categories_list
    }


def calc_dc_for_lang_by_category(categories_list, lowest_dc_for_category):
    """Split the category document target evenly over each category's languages.

    Args:
        categories_list: Groups shaped like ``[name, [language, ...]]``.
        lowest_dc_for_category: Document target shared by every category.

    Returns:
        Dict of category name -> document target per language of that category.
    """
    return {
        group[0]: lowest_dc_for_category / len(group[1])
        for group in categories_list
    }

# Corpus-specific entity names and the characters they stand for. They are
# applied in this order, before the standard HTML entities are resolved.
_CUSTOM_ENTITIES = (
    ("&acircumflex;", "â"),
    ("&icircumflex;", "î"),
    ("&ecircumflex;", "ê"),
    ("&ocircumflex;", "ô"),
    ("&ucircumflex;", "û"),
    ("&A-ACUTE;", "Â"),
    ("&I-ACUTE;", "Î"),
    ("&E-ACUTE;", "Ê"),
    ("&O-ACUTE;", "Ó"),
    ("&U-ACUTE;", "Û"),
    ("&eumlaut;", "ë"),
    ("&iumlaut;", "ï"),
    ("&oumlaut;", "ö"),
    ("&uumlaut;", "ü"),
    ("&aeligature;", "Æ"),
    ("&oeligature;", "Œ"),
    ("&ccedille;", "Ç"),
    ("&ntidle;", "ñ"),
    ("&obrack;", "["),
    ("&cbrack;", "]"),
    ("&lsqbrack;", "{"),
    ("&rsqbrack;", "}"),
    ("&ampersand;", "&"),
    ("&degree;", "°"),
    ("&degree-sign;", "°"),
    ("&percent;", "%"),
    ("&scol;", ";"),
    ("&plus-or-minus;", "±"),
    ("&curved-dash;", "~"),
    ("&very-long-dash;", "—"),
    ("&long-dash;", "—"),
    ("&dotted-line;", "┄"),
    ("&arrowhead;", "➤"),
    ("&right-arrow;", "→"),
    ("&black-square;", "■"),
    ("&peso;", "₱"),
    ("&centavo;", "￠"),
    ("&pound-sign;", "£"),
    ("&club;", "♣"),
    ("&heart;", "♥"),
    ("&spade;", "♠"),
    ("&diamond;", "♦"),
)


def strip_accent(text):
    """Resolve entities in `text` and reduce the result to plain ASCII.

    The custom entity names are replaced first, then standard HTML entities
    are unescaped, and finally the text is NFKD-normalized and every
    character that cannot be encoded as ASCII is dropped. Accented letters
    therefore lose their accent; symbols without an ASCII decomposition
    disappear.

    Args:
        text: Text to convert; UTF-8 encoded bytes are decoded first.

    Returns:
        The ASCII-only text as ``str``.
    """
    if isinstance(text, bytes):
        text = text.decode('utf-8')

    # The custom names must be handled before html.unescape: it resolves the
    # longest known prefix of an unknown name (e.g. "&euml" inside
    # "&eumlaut;"), which would otherwise mangle them beyond recognition.
    for entity, replacement in _CUSTOM_ENTITIES:
        text = text.replace(entity, replacement)
    text = html.unescape(text)

    text = unicodedata.normalize('NFKD', text)
    return text.encode('ascii', 'ignore').decode('ascii')


def addPullStop(doc):
    """Return `doc` terminated by a full stop, appending "." only when missing."""
    if doc.endswith("."):
        return doc
    return doc + "."


def print_dataset(dataset):
    """Print a summary of a data set.

    Reports the number of documents, words and categories, the category
    names, and - per target, in order of first appearance - the word and
    document counts.

    Args:
        dataset: Object with ``data`` (documents), ``target`` (one index into
            ``target_names`` per document) and ``target_names``.
    """
    documents = dataset.data
    names = dataset.target_names

    print("{:,} documents".format(len(documents)))
    print("{:,} words ".format(sum(len(doc.split()) for doc in documents)))
    print("{:,} categories".format(len(names)))
    print("categories {} ".format(names))

    # target -> [word count, document count]
    stats_by_target = {}
    for doc, target in zip(documents, dataset.target):
        stats = stats_by_target.setdefault(target, [0, 0])
        stats[0] += len(doc.split())
        stats[1] += 1

    for target, (word_count, docs_count) in stats_by_target.items():
        print("{} - {:,} words in {:,} docs".format(names[target], word_count, docs_count))
    print()


if __name__ == "__main__":
    # Manual smoke run: load a corpus and print a summary of each subset.
    report_sections = (
        ("Train_Dataset", "train"),
        ("Test_Dataset", "test"),
        ("All_Dataset", "all"),
    )

    print("------Italki-------")
    # The Italki run is disabled. To re-enable it, fetch the data set with
    # the language list below and report it like the ICE corpus further down:
    # all_cats = ["Telugu", "Romanian", "Sinhala", "Javanese", "German", "Indonesian", "Hindi", "Ukrainian", "Turkish",
    #             "Malayalam", "Punjabi", "Bulgarian", "French", "Portuguese", "Arabic", "English", "Korean", "Malay",
    #             "Italian", "Hungarian", "Japanese", "Tamil", "Spanish", "Thai", "Chinese", "Vietnamese"]
    # dataset = fetch_itaki(filename='merged.csv', subsets=["train", "test", "all"], categories=all_cats, verbose=True)

    print("------ICE Corpus-------")
    all_cats = ["Canada", "HongKong", "India", "Ireland", "Jamaica", "Nigeria", "Philippines", "Singapore", "SriLanka",
                "USA"]
    dataset = fetch_ice(filename='merged_ice.csv', subsets=["train", "test", "all"], categories=all_cats, verbose=True)

    for title, subset_key in report_sections:
        print(title)
        # Underline the title with a dash per character
        print("-" * len(title))
        print_dataset(dataset[subset_key])

    print("\n\nDONE!")
