import os
import argparse
from typing import List, Tuple
from collections import Counter
import re
import logging

try:
    import spacy
except ModuleNotFoundError:
    print("Please install spacy: pip install spacy")
    spacy = None
import hexa.tasks.constants as CONST
try:
    from nltk.corpus import stopwords

    STOP_WORDS = stopwords.words('english')
    import nltk
except (ModuleNotFoundError, LookupError):
    logging.error("Please install NLTK: pip install nltk")
    logging.error("Then run `nltk.download('stopwords')`")
    nltk = None
    stopwords = None
    STOP_WORDS = []

# Lazily-initialized spaCy pipeline: extract_entities() loads
# "en_core_web_sm" into this global on first use and reuses it afterwards.
nlp = None

def uppercase(string: str) -> str:
    """
    Return ``string`` with only its first character upper-cased.

    The rest of the string is left as-is, and an empty string comes back empty.
    """
    # Slicing is safe on '' (both slices are empty), so no length check is needed.
    return string[:1].upper() + string[1:]


def normalize_reply(text: str, version=1) -> str:
    """
    Standardize the capitalization and punctuation spacing of the input text.

    Version 1: Fix sentence start casing, and punctuation.
    Version 2: Add trailing period, if missing.

    :param text:
        raw reply text; it is lower-cased before being re-capitalized.
    :param version:
        normalization version; values greater than 1 also append a final
        period when the text does not already end in one of ``! . ? ) " '``.
    :return:
        the normalized text, with surrounding whitespace stripped.
    """

    switch_list = [(' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'), (" ' ", "'")]

    # Lower-case, then add spaces in front of punctuation (and around
    # apostrophes) so that words and punctuation split into separate tokens,
    # e.g. "hi." -> "hi .". Double spaces created along the way are collapsed.
    new_text = text.lower()
    for spaced, compact in switch_list:
        new_text = new_text.replace(compact, spaced).replace('  ', ' ')

    # Capitalize the first word, the pronoun "i", and every word following a
    # sentence-ending punctuation token. Empty tokens (left by leading,
    # trailing or repeated spaces) are skipped, so they neither swallow the
    # capitalization of the real first word nor trigger a spurious one.
    sentence_end = ('?', '.', '!')
    capitalize_next = True
    tokens = new_text.split(' ')
    for i, token in enumerate(tokens):
        if not token:
            continue
        # NOTE: apostrophes were split off above, so in practice only 'i'
        # matches this tuple; the contractions are kept for safety.
        if capitalize_next or token in ('i', "i'm", "i've", "i'll", "i'd"):
            tokens[i] = uppercase(token)
        capitalize_next = token in sentence_end
    new_text = ' '.join(tokens)
    # Pad so that punctuation at either end also matches the ' x' patterns.
    new_text = ' ' + new_text + ' '

    # Re-attach punctuation to the preceding word ("hi ." -> "hi.").
    for spaced, compact in switch_list:
        new_text = new_text.replace(spaced, compact)

    # get rid of surrounding whitespace
    new_text = new_text.strip()
    new_text = new_text.replace('  ', ' ')

    if version > 1 and new_text and new_text[-1] not in '!.?)"\'':
        new_text += '.'

    return new_text


def str_to_msg(txt, ignore_fields=''):
    """
    Convert formatted string to ParlAI message dict.

    :param txt:
        formatted string to convert. String format is tab-separated fields,
        with colon separating field name and contents.
    :param ignore_fields:
        (default '') comma-separated field names to not
        include in the msg dict even if they're in the string.
    :return:
        the message dict (always containing ``episode_done``), or ``None``
        when ``txt`` is empty or ``None``.
    """

    def tostr(s):
        # Undo the escaping of tabs, newlines and pipes in field values.
        s = str(s)
        s = s.replace('\\t', '\t')
        s = s.replace('\\n', '\n')
        s = s.replace('__PIPE__', '|')
        return s

    def tolist(s):
        # List-valued fields are '|'-separated; each item is unescaped.
        return [tostr(v) for v in s.split('|')]

    def tobool(s):
        # bool('False') is True, so parse the usual false spellings
        # explicitly; any other non-empty value stays truthy as before.
        return s.strip().lower() not in ('', 'false', 'f', 'no', 'n', '0')

    def convert(key, value):
        if key in ('text', 'id'):
            return tostr(value)
        if key in ('label_candidates', 'labels', 'eval_labels', 'text_candidates'):
            return tolist(value)
        if key == 'reward':
            try:
                return int(value)
            except ValueError:
                return float(value)
        if key == 'episode_done':
            return tobool(value)
        return tostr(value)

    if txt == '' or txt is None:
        return None

    ignored = set(ignore_fields.split(','))
    msg = {}
    for field in txt.split('\t'):
        # Split on the first colon only; a field without a colon becomes a
        # key with an empty value instead of a garbled key/value pair.
        key, _, value = field.partition(':')
        if key not in ignored:
            msg[key] = convert(key, value)
    msg['episode_done'] = msg.get('episode_done', False)
    return msg


# Whole-word English articles, removed by normalize_answer.
re_art = re.compile(r'\b(a|an|the)\b')
# ASCII punctuation characters (including '_' and the apostrophe); each one is
# replaced by a space in normalize_answer.
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """
    Canonicalize an answer string for token-overlap scoring.

    The text is lower-cased, punctuation is turned into spaces, the articles
    "a", "an" and "the" are removed, and whitespace runs are collapsed to a
    single space with nothing leading or trailing.
    """
    text = s.lower()
    # Punctuation goes first: '_' is a regex word character, so e.g.
    # "the_cat" only exposes its article once the underscore is replaced.
    for pattern in (re_punc, re_art):
        text = pattern.sub(' ', text)
    return ' '.join(text.split())


def _prec_recall_f1_score(pred_items, gold_items):
    """
    Compute precision, recall and F1 between predicted and gold items.

    Items are compared as multisets, so a repeated item only counts as often
    as it appears on both sides.

    :param pred_items: sized iterable of predicted values (``len`` is taken)
    :param gold_items: sized iterable of gold values (``len`` is taken)
    :return: tuple ``(precision, recall, f1)``; ``(0, 0, 0)`` with no overlap
    """
    overlap = Counter(gold_items) & Counter(pred_items)
    matched = sum(overlap.values())
    if not matched:
        return 0, 0, 0
    precision = matched / len(pred_items)
    recall = matched / len(gold_items)
    return precision, recall, (2 * precision * recall) / (precision + recall)


def f1_metric(guess: str, answers: List[str]):
    """
    Return the best token-level F1 between ``guess`` and any of ``answers``.

    Both sides are passed through ``normalize_answer`` (lower-cased,
    punctuation and articles removed, whitespace collapsed) and split into
    tokens before scoring.

    :param guess: predicted string
    :param answers: candidate reference strings
    :return: the maximum F1 over ``answers``; ``0`` when ``guess`` or
        ``answers`` is ``None`` or ``answers`` is empty
    """
    if guess is None or answers is None:
        return 0
    g_tokens = normalize_answer(guess).split()
    # default=0 keeps an empty answer list from raising ValueError in max().
    return max(
        (
            _prec_recall_f1_score(g_tokens, normalize_answer(a).split())[2]
            for a in answers
        ),
        default=0,
    )


def extract_entities(
    sentence: str,
    pos: Tuple[str, ...] = ('PROPN', 'NOUN'),
    use_named_entities: bool = True,
    use_noun_chunks: bool = True,
) -> List[str]:
    """
    Given a sentence, extract the entities from the sentence.

    Candidates come from tokens whose part-of-speech tag is in ``pos``, from
    named entities, and from noun chunks whose lower-cased text is not in
    ``STOP_WORDS``. The result is their de-duplicated surface text.

    :param sentence:
        provided sentence
    :param pos:
        parts of speech to look at; an empty value disables this source
    :param use_named_entities:
        whether to include named entities
    :param use_noun_chunks:
        whether to include noun chunks.

    :return entities:
        list of unique entity strings, in first-seen order.
    :raises RuntimeError:
        if spaCy is not installed or the ``en_core_web_sm`` model cannot be
        loaded.
    """
    global nlp
    if nlp is None:
        # Load the spaCy pipeline once and cache it in the module-level global.
        if spacy is None:
            raise RuntimeError('Please install spacy: pip install spacy')
        try:
            nlp = spacy.load("en_core_web_sm")
        except Exception as err:
            raise RuntimeError(
                'Please download: python -m spacy download en_core_web_sm'
            ) from err
    doc = nlp(sentence)
    spans = []
    if pos:
        spans.extend(token for token in doc if token.pos_ in pos)
    if use_named_entities:
        spans.extend(doc.ents)
    if use_noun_chunks:
        spans.extend(
            chunk
            for chunk in doc.noun_chunks
            if chunk.text.lower() not in STOP_WORDS
        )
    # De-duplicate by text while keeping first-seen order, so the output is
    # deterministic (set() ordering of str varies with hash randomization).
    return list(dict.fromkeys(span.text for span in spans))


def _normalize_persona_line(x: str) -> str:
    """
    Normalize the text of a persona line while keeping its speaker prefix.

    Lines that begin with ``'your persona: '`` or ``"partner's persona: "`` have
    the text after that prefix passed through ``normalize_reply``. Any other
    line is returned unchanged.
    """
    for prefix in ('your persona: ', "partner's persona: "):
        if x.startswith(prefix):
            body = x[len(prefix):]
            return prefix + normalize_reply(body)
    return x


def calc_f1_msc(pred, gold_items):
    """
    Calculate F1 overlap between prediction sentence and gold labels.

    The prediction is tokenized with ``nltk.word_tokenize`` and compared
    against ``gold_items`` as a multiset of tokens. A non-zero score is kept
    only if the prediction and the gold items share at least one entity, as
    found by ``extract_entities``.

    :param pred:
        prediction string
    :param gold_items:
        list of gold items, compared token-for-token against the tokenized
        prediction

    :return f1_overlap:
        the F1 score, or ``0`` when there is no token overlap, no shared
        entity, or the prediction cannot be tokenized.
    :raises RuntimeError:
        if NLTK (or its stopwords data) was unavailable at import time.
    """
    # An explicit check instead of `assert`, which is stripped under -O.
    if nltk is None:
        raise RuntimeError(
            "NLTK is unavailable: pip install nltk, "
            "then run `nltk.download('stopwords')`"
        )
    try:
        pred_items = nltk.word_tokenize(pred)
    except IndexError:
        # malformed prediction; return 0
        return 0

    _, _, f1 = _prec_recall_f1_score(pred_items, gold_items)
    if f1 == 0:
        # No token overlap: skip the (expensive) entity extraction.
        return 0

    # Only credit the overlap when it includes at least one shared entity.
    context_ent = extract_entities(' '.join(gold_items))
    label_ent = extract_entities(pred)
    if not set(context_ent).intersection(label_ent):
        return 0
    return f1


def str2bool(v):
    """
    Parse a yes/no style value into a ``bool``.

    ``bool`` inputs are returned unchanged. Strings are matched
    case-insensitively: yes/true/t/y/1 give ``True`` and no/false/f/n/0 give
    ``False``.

    :raises argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    verdict = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict


def create_message_from_tuple(entry):
    """
    Build a message dict from a ``(text, labels, reward, label_candidates)`` tuple.

    Positions that are ``None``, or missing because the tuple is shorter, are
    left out of the message.

    :param entry:
        indexable sequence whose items are, in order: the text (stored under
        ``CONST.MESSAGE_TEXT``), the label(s) (stored under ``CONST.LABELS``; a
        single string is wrapped in a list), the reward, and the label
        candidates.
    :return:
        the message dict.
    :raises ValueError:
        if a fifth item is present and not ``None``.
    """
    message = {}
    if entry[0] is not None:
        message[CONST.MESSAGE_TEXT] = entry[0]
    if len(entry) > 1 and entry[1] is not None:
        labels = entry[1]
        if isinstance(labels, str):
            # A bare string is a single label.
            labels = [labels]
        message[CONST.LABELS] = labels
    if len(entry) > 2 and entry[2] is not None:
        message['reward'] = entry[2]
    if len(entry) > 3 and entry[3] is not None:
        message['label_candidates'] = entry[3]
    if len(entry) > 4 and entry[4] is not None:
        raise ValueError(
            'create_message_from_tuple supports at most 4 fields '
            '(text, labels, reward, label_candidates); got a non-None '
            '5th field: {!r}'.format(entry[4])
        )
    return message


def build_all_personas(config):
    """
    Build the personas list from which we sample memories for MDM.

    Reads ``<config['datapath']>/blended_skill_talk/persona_list.txt`` and
    returns one string per line, with every ``'||'`` separator and newline
    removed. The file's last line is dropped.

    :param config:
        mapping that provides the ``'datapath'`` root directory.
    :return:
        list of persona strings.
    :raises NotImplementedError:
        if the persona file does not exist (type kept for existing callers).
    """
    personas_path = os.path.join(
        config['datapath'], 'blended_skill_talk', 'persona_list.txt'
    )
    if not os.path.exists(personas_path):
        raise NotImplementedError(
            'Persona list not found at {}'.format(personas_path)
        )

    # Decode explicitly instead of with the platform's locale encoding, so the
    # result does not depend on the machine.
    # NOTE(review): assumes the data file is UTF-8; confirm.
    with open(personas_path, encoding='utf-8') as f:
        all_personas = [line.replace('||', '').replace('\n', '') for line in f]
    # NOTE(review): the last line is always discarded, presumably a trailing
    # blank/trailer line in the data file; confirm against the file.
    return all_personas[:-1]