#! /usr/bin/env python3
# coding=utf-8


import spacy
import torch
import torch.nn as nn
import numpy as np
from datasets import load_metric
from spacy.tokens import Token
from spacy.lang.en.stop_words import STOP_WORDS

# Weights used by compute_new_reward: EXT_COFF scales the L1 distance between the
# entity-presence vectors, ODR_COFF scales the edit distance between the entity
# sequences. TASK_ALPHA is the constant added to the margin in compute_new_reward.
EXT_COFF = 0.8
ODR_COFF = 0.2
TASK_ALPHA = 0

# Shared spaCy English pipeline, loaded once at import time with the 'ner' and
# 'parser' components disabled.
nlp = spacy.load("en_core_web_sm", disable=['ner', 'parser'])

# A token counts as a stop word when spaCy already flags it, or when its lowercased
# text or its lemma appears in spaCy's English STOP_WORDS set.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
# Register the getter as a custom Token extension; force=True overwrites any
# extension previously registered under the same name.
# NOTE(review): custom extensions are read as `token._.is_stop`, while
# compute_new_reward reads the built-in `token.is_stop` -- confirm which is intended.
Token.set_extension('is_stop', getter=stop_words_getter, force=True)  # set attribute with getter


# sacreBLEU metric object shared by compute_reward; created once at import time.
sacrebleu_metric = load_metric('sacrebleu', keep_in_memory=True)


def compute_reward(gen_tokens, labels_tokens, spice_scores):
    """
    compute a SPICE-weighted sacreBLEU reward for one generated sentence
    :param gen_tokens: the generated sentence, passed to sacreBLEU as the single prediction
                       (presumably text rather than token ids -- TODO confirm against the caller)
    :param labels_tokens: iterable of reference sentences; each one is scored on its own
    :param spice_scores: per-reference weights, paired with labels_tokens by position
                         (zip stops at the shorter of the two)
    :return: <float> mean over the references of spice * sacreBLEU 'score',
             or 0.0 when there is no (reference, spice) pair to score
    """
    scores = [
        spice * sacrebleu_metric.compute(predictions=[gen_tokens],
                                         references=[[ref]])['score']
        for ref, spice in zip(labels_tokens, spice_scores)
    ]
    if not scores:
        # np.mean([]) emits a RuntimeWarning and yields nan; return a neutral
        # reward instead, like compute_new_reward does for its empty case.
        return 0.0
    return np.mean(scores)


def _get_marker(entity, all_entities):
    ret = [0] * len(all_entities)
    for idx, ent in enumerate(all_entities):
        if ent in entity:
            ret[idx] = 1
    return ret


def _get_edit_dist(entity_a, entity_b):
    # This version is commutative, so as an optimization we force |a|>=|b|
    if len(entity_a) < len(entity_b):
        return _get_edit_dist(entity_b, entity_a)
    if len(entity_b) == 0:  # Can deal with empty sequences faster
        return len(entity_a)
    # Only two rows are really needed: the one currently filled in, and the previous
    distances = []
    distances.append([i for i in range(len(entity_b) + 1)])
    distances.append([0 for _ in range(len(entity_b) + 1)])
    # We can prefill the first row:
    costs = [0 for _ in range(3)]
    for i, a_token in enumerate(entity_a, start=1):
        distances[1][0] += 1  # Deals with the first column.
        for j, b_token in enumerate(entity_b, start=1):
            costs[0] = distances[1][j - 1] + 1
            costs[1] = distances[0][j] + 1
            costs[2] = distances[0][j - 1] + (0 if a_token == b_token else 1)
            distances[1][j] = min(costs)
        # Move to the next row:
        distances[0][:] = distances[1][:]
    return distances[1][len(entity_b)]  # normalize


def _content_lemmas(sent):
    """Return the lowercased lemmas of the non-stop-word tokens of sent, in order."""
    # NOTE(review): this reads spaCy's built-in `token.is_stop`; the `is_stop`
    # extension registered at module level would be `token._.is_stop` -- confirm
    # which one is intended before changing it, as it alters the reward values.
    return [token.lemma_.lower() for token in nlp(sent) if not token.is_stop]


def compute_new_reward(pred_sent, label_sents, device):
    """
    compute a margin-style reward from entity distances between a prediction and its references

    For every reference the distance is
        EXT_COFF * (L1 distance between the entity-presence vectors)
        + ODR_COFF * (edit distance between the entity sequences),
    where the "entities" are the lowercased lemmas of the non-stop-word tokens.
    The smallest distance is the positive, the mean of all strictly larger
    distances is the negative (0 if there is none), and the reward is
    min(positive - negative + TASK_ALPHA, 0).

    :param pred_sent: predicted sentence text, parsed with the module-level nlp pipeline
    :param label_sents: iterable of reference sentence texts
    :param device: torch device on which all tensors, including the result, are created
    :return: 0-dim tensor on device, never greater than 0; exactly 0 when there
             is no reference or every distance is nan
    """
    label_sents = list(label_sents)
    if not label_sents:
        return torch.tensor(0.0, device=device)

    # The prediction does not change across references: parse it only once.
    entities_pred = _content_lemmas(pred_sent)

    dists = []
    for label_sent in label_sents:
        entities_label = _content_lemmas(label_sent)

        # Shared vocabulary for this pair; its order is irrelevant to the L1 distance.
        all_entities = list(set(entities_pred + entities_label))
        pred_seq = torch.tensor(_get_marker(entities_pred, all_entities),
                                dtype=torch.float32, device=device)
        label_seq = torch.tensor(_get_marker(entities_label, all_entities),
                                 dtype=torch.float32, device=device)

        # Existence term: how many entities appear on one side only.
        ext_dist = EXT_COFF * torch.dist(pred_seq, label_seq, p=1)
        # Order term: edit distance between the two entity sequences.
        ord_dist = ODR_COFF * _get_edit_dist(entities_pred, entities_label)

        dists.append(ext_dist + torch.tensor(ord_dist, device=device))

    dists = [dist for dist in dists if not torch.isnan(dist)]
    if not dists:
        return torch.tensor(0.0, device=device)

    dists = torch.stack(dists)
    pos_dist = torch.min(dists)

    # Every reference strictly farther than the closest one acts as a negative.
    all_negs = dists[dists > pos_dist]
    if len(all_negs) != 0:
        neg_dist = torch.mean(all_negs)
    else:
        neg_dist = torch.tensor(0.0, device=device)

    margin = pos_dist - neg_dist + TASK_ALPHA
    # Keep only the non-positive part of the margin.
    return torch.clamp(margin, max=0.0)


