"""
Provides standard metric evaluations for dialog, as well as an aggregator.

Original code from: parlai.core.metrics
"""

from __future__ import annotations

import re
import torch
from abc import ABC, abstractmethod
from collections import Counter
import functools
import datetime
import math
import numpy as np
from typing import (
    Any,
    Counter as TCounter,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
)
from hexa.utils.message import Message


# A single numeric value: a plain Python number or a tensor
# (``Metric.as_number`` unwraps tensors with ``.item()``).
TScalar = Union[int, float, torch.Tensor]
# A collection of scalars: either a list of them or a tensor.
TVector = Union[List[TScalar], torch.Tensor]

# Named groups of metric identifiers. ``TeacherMetrics._infer_metrics`` expands
# the keywords 'default', 'rouge', 'bleu', 'distinct' and 'all' into these sets.
DEFAULT_METRICS = {'bleu-4', 'accuracy', 'f1'}
ROUGE_METRICS = {'rouge-1', 'rouge-2', 'rouge-L'}
# Accepted values for the ``measure`` argument of ``RougeMetric.compute_many``:
# recall ('r'), f1 ('f') and precision ('p').
ROUGE_METRICS_MEASURES = {'r', 'f', 'p'}
BLEU_METRICS = {'bleu-1', 'bleu-2', 'bleu-3', 'bleu-4'}
DISTINCT_METRICS = {
    'interdistinct-1',
    'interdistinct-2',
    'intradistinct-1',
    'intradistinct-2',
}
# Union of every built-in metric name; user-reported metrics that collide with
# one of these are renamed with a ``USER_`` prefix.
ALL_METRICS = DEFAULT_METRICS | ROUGE_METRICS | BLEU_METRICS | DISTINCT_METRICS


@functools.total_ordering  # type: ignore
class Metric(ABC):
    """
    Abstract base for every metric value in this module.

    A concrete metric supplies :meth:`value` (the number that gets reported) and
    :meth:`__add__` (how two partial results are merged). Everything else --
    formatting, numeric conversion, comparison, reflected/in-place addition --
    is derived from those two here.
    """

    @property
    def is_global(self) -> bool:
        """
        Whether the metric is reported globally rather than per-task.
        """
        return False

    @property
    def macro_average(self) -> bool:
        """
        Whether the metric is macro-averaged when it is reported globally.
        """
        return False

    @abstractmethod
    def value(self) -> float:
        """
        Return the current value of the metric as a float.
        """

    @abstractmethod
    def __add__(self, other: Any) -> Metric:
        raise NotImplementedError

    def __iadd__(self, other):
        # ``m += x`` is routed through the reflected add, so it shares the
        # ``None`` handling below and returns a (possibly new) metric.
        return self.__radd__(other)

    def __radd__(self, other: Any):
        # ``None + metric`` is the metric itself; this is what lets an
        # accumulator start out empty.
        return self if other is None else self.__add__(other)

    def __str__(self) -> str:
        current = self.value()
        return f'{current:.4g}'

    def __repr__(self) -> str:
        current = self.value()
        return f'{self.__class__.__name__}({current:.4g})'

    def __float__(self) -> float:
        return float(self.value())

    def __int__(self) -> int:
        return int(self.value())

    def __eq__(self, other: Any) -> bool:
        # Metrics compare by value; anything else is compared to our value as-is.
        mine = self.value()
        theirs = other.value() if isinstance(other, Metric) else other
        return mine == theirs

    def __lt__(self, other: Any) -> bool:
        mine = self.value()
        theirs = other.value() if isinstance(other, Metric) else other
        return mine < theirs

    def __sub__(self, other: Any) -> float:
        """
        Return ``value() - other`` for a float ``other``.

        Used heavily for assertAlmostEqual. Any non-float operand raises TypeError.
        """
        if isinstance(other, float):
            return self.value() - other
        raise TypeError('Metrics.__sub__ is intentionally limited to floats.')

    def __rsub__(self, other: Any) -> float:
        """
        Return ``other - value()`` for a float ``other``.

        Used heavily for assertAlmostEqual. Any non-float operand raises TypeError.

        NOTE: This is not necessary in python 3.7+.
        """
        if isinstance(other, float):
            return other - self.value()
        raise TypeError('Metrics.__rsub__ is intentionally limited to floats.')

    @classmethod
    def as_number(cls, obj: TScalar) -> Union[int, float]:
        """
        Unwrap a tensor or numpy scalar into a plain Python int/float.

        Plain numbers are returned untouched. Anything that is not an int or
        float after unwrapping trips the assertion.
        """
        if isinstance(obj, (torch.Tensor, np.generic)):
            number = obj.item()
        else:
            number = obj  # type: ignore
        assert isinstance(number, (int, float))
        return number

    @classmethod
    def as_float(cls, obj: TScalar) -> float:
        return float(cls.as_number(obj))

    @classmethod
    def as_int(cls, obj: TScalar) -> int:
        return int(cls.as_number(obj))

    @classmethod
    def many(cls, *objs: List[TVector]) -> List[Metric]:
        """
        Build one metric per position from parallel sequences of constructor args.

        Useful if you separately compute numerators and denominators, etc.: the
        i-th metric is ``cls(objs[0][i], objs[1][i], ...)``.

        :raises IndexError: if the sequences do not all have the same length
            (or no sequence is given).
        """
        sizes = [len(part) for part in objs]
        # Tensors/arrays are converted in one ``tolist()`` call: if the tensor
        # is on GPU this transfers it once instead of one element at a time.
        columns = [
            part.tolist() if isinstance(part, (torch.Tensor, np.ndarray)) else part
            for part in objs
        ]
        if len(set(sizes)) != 1:
            raise IndexError(f'Uneven {cls.__name__} constructions: {sizes}')
        return [cls(*row) for row in zip(*columns)]


class FixedMetric(Metric):
    """
    A metric holding a single value that is not accumulated.

    FixedMetric is used for things like total_train_updates, which should not be
    combined across different multitasks or different workers.

    NOTE: combining two FixedMetrics performs no equality check in this
    implementation; the right-hand operand is simply returned.
    """

    __slots__ = ('_value',)

    def __init__(self, value: TScalar):
        self._value = self.as_number(value)

    def __add__(self, other: Optional[FixedMetric]) -> FixedMetric:
        # Nothing is summed: keep ourselves when there is no other operand,
        # otherwise the other operand replaces us.
        return self if other is None else other

    def value(self) -> float:
        return self._value


class AverageMetric(Metric):
    """
    Running average stored as a numerator/denominator pair.

    Examples of AverageMetrics include hits@1, F1, accuracy, etc. These metrics all have
    per-example values that can be directly mapped back to a teacher.
    """

    __slots__ = ('_numer', '_denom')

    @property
    def macro_average(self) -> bool:
        """
        Average metrics are macro-averaged when globally reported.
        """
        return True

    def __init__(self, numer: TScalar, denom: TScalar = 1):
        self._numer = self.as_number(numer)
        self._denom = self.as_number(denom)

    def __add__(self, other: Optional[AverageMetric]) -> AverageMetric:
        """
        Merge two averages by summing numerators and denominators.

        ``None`` is treated as "nothing to add". The result is built with
        ``type(self)`` so subclasses keep their own type.
        """
        if other is None:
            return self
        summed_numer: TScalar = self._numer + other._numer
        summed_denom: TScalar = self._denom + other._denom
        return type(self)(numer=summed_numer, denom=summed_denom)

    def value(self) -> float:
        """
        Return numerator / denominator.

        With a zero denominator the result is 0.0 if nothing was counted at all
        (numerator also zero) and NaN otherwise.
        """
        if self._denom == 0:
            # don't nan out if we haven't counted anything
            return 0.0 if self._numer == 0 else float('nan')
        return self._numer / self._denom

class SumMetric(Metric):
    """
    Class that keeps a running sum of some metric.

    Adding two SumMetrics adds their stored values; ``value()`` returns the
    total as-is.
    """

    def __init__(self, val: TScalar, denom: TScalar = 1):
        """
        :param val:
            the amount to start the sum with.
        :param denom:
            accepted but ignored. Presumably kept so a SumMetric can be built
            with the same arguments as an AverageMetric -- confirm with callers.
        """
        self._val = self.as_number(val)

    def __add__(self, other: Optional[SumMetric]) -> SumMetric:
        """
        Return a new metric holding the sum of both values.

        ``None`` is treated as "nothing to add" and returns ``self``.
        """
        if other is None:
            return self
        # always keep the same return type
        return type(self)(val=self._val + other._val)

    def value(self) -> float:
        return self._val

class TimerMetric(Metric):
    """
    A timer metric keeps track of the first/last times it was used.

    It stores an accumulated quantity together with a start and end timestamp.
    ``value()`` reports the quantity divided by the elapsed seconds between the
    earliest start and the latest end.
    """

    __slots__ = ('_value', '_start', '_end')

    @classmethod
    def _now(cls) -> float:
        """
        Return the current time as a POSIX timestamp (seconds since the epoch).
        """
        # An aware UTC datetime is required here: ``utcnow()`` returns a naive
        # datetime and ``timestamp()`` interprets naive datetimes as local time,
        # which skews the result by the local UTC offset.
        return datetime.datetime.now(datetime.timezone.utc).timestamp()

    def __init__(
        self,
        value: TScalar,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
    ):
        """
        :param value:
            the quantity accumulated during the timed span.
        :param start_time:
            timestamp of the start of the span; defaults to now.
        :param end_time:
            timestamp of the end of the span; defaults to now.
        """
        self._value = self.as_number(value)
        if start_time is None:
            start_time = self._now()
        if end_time is None:
            end_time = self._now()
        self._start = start_time
        self._end = end_time

    def __add__(self, other: Optional[TimerMetric]) -> TimerMetric:
        """
        Merge two timers: quantities are summed and the time span is widened to
        cover both. ``None`` is treated as "nothing to add".
        """
        if other is None:
            return self
        total: TScalar = self._value + other._value
        start: float = min(self._start, other._start)
        end: float = max(self._end, other._end)
        return type(self)(total, start, end)

    def value(self) -> float:
        """
        Return quantity per second, or 0 when nothing was counted or no time
        has elapsed (which also avoids dividing by zero).
        """
        if self._value == 0 or self._end == self._start:
            return 0
        return self._value / (self._end - self._start)


class GlobalMetric:
    """
    Mixin marking a metric as global, i.e. not to be aggregated across tasks.

    Examples of global metrics include things like learning rate and updates.
    These need to be accumulated or averaged over multiple parleys, but cannot
    be correlated with a single task.

    Key to it is the notion that any one worker or any one task already has a global
    view of the value, and so no combinations should be done. Note this is different
    than a FixedMetric, in that a GlobalMetric can still be averaged across multiple
    parleys(), but a FixedMetric is always fixed.
    """

    @property
    def is_global(self) -> bool:
        """
        Always True: listing this mixin first among the bases overrides the
        ``is_global`` property of the metric class it is combined with.
        """
        return True


class GlobalFixedMetric(GlobalMetric, FixedMetric):
    """
    Global fixed metric: a FixedMetric whose ``is_global`` is True.

    Used for things like total_train_updates.
    """


class GlobalAverageMetric(GlobalMetric, AverageMetric):
    """
    Global average metric: an AverageMetric whose ``is_global`` is True.

    Used for things like learning rate, and many agent-specific metrics.
    """


class GlobalTimerMetric(GlobalMetric, TimerMetric):
    """
    Global timer metric: a TimerMetric whose ``is_global`` is True.
    """


class PPLMetric(AverageMetric):
    """
    Average metric reported as the exponential of its mean.

    Presumably the accumulated quantity is a per-token negative log-likelihood,
    making the reported value a perplexity -- confirm with callers.
    """

    def value(self):
        mean = super().value()
        return math.exp(mean)


class Metrics(object):
    """
    Metrics aggregator.

    Keeps two key -> metric mappings: the running totals (which can be shared
    between clones through :meth:`share`) and the "recent" metrics that track
    the latest example and are never shared.
    """

    def __init__(self, threadsafe=False, shared=None):
        """
        :param threadsafe:
            accepted for interface compatibility; not used by this class.
        :param shared:
            optional dict as produced by :meth:`share`. When it contains
            'data', this instance becomes a clone that writes into the very
            same totals dict as the original.
        """
        if shared and 'data' in shared:
            # This is a clone
            self._data = shared['data']
        else:
            # The original
            self._data = {}

        # recent data is to track per-example metrics, and so should never be
        # shared
        self._recent_data = {}

    def __str__(self):
        return str(self._data)

    def __repr__(self):
        return f'Metrics({repr(self._data)})'

    def add(self, key: str, value: Optional[Metric]) -> None:
        """
        Record an accumulation to a metric.

        ``None`` means "nothing was computed" (e.g. an optional dependency is
        missing) and is ignored, so it can never clobber or crash an
        accumulation.
        """
        if value is None:
            return
        for store in (self._data, self._recent_data):
            previous = store.get(key)
            # the first value for a key is stored as-is; later ones are merged
            # through the metric's own ``+``
            store[key] = value if previous is None else previous + value

    def report(self):
        """
        Report the metrics over all data seen so far.

        Returns a shallow copy of the totals.
        """
        return self._data.copy()

    def clear_recent(self):
        """
        Clear recent metrics (latest example).
        """
        self._recent_data.clear()

    def report_recent(self):
        """
        Report recent metrics (latest example).

        Returns a shallow copy of the recent metrics.
        """
        return self._recent_data.copy()

    def clear(self):
        """
        Clear all the metrics.
        """
        self._data.clear()
        self._recent_data.clear()

    def share(self):
        """
        Return the dict a clone needs to share this instance's totals.
        """
        return {'data': self._data}

    def add_metrics(self, other: "Metrics") -> None:
        """
        Aggregate another Metrics objects metrics into this one.

        Note that it is assumed that the keys for metrics are disjoint between Metrics
        objects.
        """
        for k, v in other._data.items():
            self.add(k, v)


class MetricLogger:
    """
    Accumulates named metrics, globally and per task, and exports their values.

    Each metric key is bound to a metric class through ``config_dict`` (key ->
    one of the names in ``metric_classes``). Values are accumulated with
    :meth:`log` / :meth:`log_many` and written into a plain dict with
    :meth:`update_logs`.
    """

    # metric type name (as used in the config) -> metric class
    metric_classes = {
        'average': AverageMetric,
        'timer': TimerMetric,
        'fixed': FixedMetric,
        'ppl': PPLMetric,
        'sum': SumMetric
    }

    def __init__(self, mode: str, config_dict: Dict[str, str]):
        """
        :param mode:
            suffix appended to exported log keys (``<key>/<mode>``); a falsy
            mode exports the bare key.
        :param config_dict:
            mapping of metric key -> metric type name.
        :raises KeyError: if a metric type name is not in ``metric_classes``.
        """
        self.mode = mode
        self.config = config_dict
        self.metric_class = {}
        self.metrics = {}
        self.task_metrics = {}
        for key, metric_type in self.config.items():
            self.metric_class[key] = self.metric_classes[metric_type]

    @staticmethod
    def _accumulate(store: Dict, key, item) -> None:
        """
        Add ``item`` to ``store[key]``, starting fresh if the slot is missing
        or was reset to None.
        """
        if store.get(key) is None:
            store[key] = item
        else:
            store[key] += item

    def init(self, keys: Union[List, str]):
        """
        Reset the given key(s) to None, globally and in every task.

        Keys that were never logged are left untouched.
        """
        if isinstance(keys, str):
            keys = [keys]
        for key in keys:
            if key in self.metrics:
                self.metrics[key] = None
            for per_task in self.task_metrics.values():
                if key in per_task:
                    per_task[key] = None

    def log(self, key, *args):
        """
        Build one metric for ``key`` from ``args`` and add it to the global total.
        """
        item = self.metric_class[key](*args)
        self._accumulate(self.metrics, key, item)

    def log_many(self, key, task_ids, *args):
        """
        Log a batch of values for ``key``.

        :param key:
            metric key; must be configured.
        :param task_ids:
            optional sequence with one task id per batch element; when given,
            each element is also accumulated under its task.
        :param args:
            parallel sequences of constructor arguments (see ``Metric.many``).
        """
        metric_class = self.metric_class[key]
        items = metric_class.many(*args)
        # An empty batch contributes nothing. Without this guard ``np.sum([])``
        # would store the float 0.0 in place of a metric.
        if items:
            self._accumulate(self.metrics, key, np.sum(items))

        if task_ids is not None:
            task_args = list(zip(*args))
            for i, task_id in enumerate(task_ids):
                item = metric_class(*task_args[i])
                per_task = self.task_metrics.setdefault(task_id, {})
                self._accumulate(per_task, key, item)

    def keys(self):
        """
        Return the keys of the globally logged metrics.
        """
        return self.metrics.keys()

    def get_log_item(self, key: str, task_id: Optional[str] = None):
        """
        Return ``(logkey, logval)`` for a metric.

        ``logkey`` is ``[<task_id>/]<key>/<mode>`` when a mode is set and the
        bare ``key`` otherwise. ``logval`` is the metric's value, or None when
        nothing is recorded for the key (unknown key or task, or reset slot).
        """
        if self.mode:
            prefix = '' if task_id is None else f'{task_id}/'
            logkey = f'{prefix}{key}/{self.mode}'
        else:
            # NOTE(review): without a mode the task id is not part of the key,
            # so per-task entries share the global key in update_logs.
            logkey = key

        if task_id is None:
            metric = self.metrics.get(key)
        else:
            metric = self.task_metrics.get(task_id, {}).get(key)
        logval = None if metric is None else metric.value()
        return logkey, logval

    def update_logs(self, logs: Dict, init=True):
        """
        Write every available metric value into ``logs`` and return it.

        :param logs:
            dict updated in place with ``logkey -> value``.
        :param init:
            when True, all logged metrics are reset afterwards.
        """
        for key in self.metrics:
            logkey, logval = self.get_log_item(key)
            if logval is not None:
                logs[logkey] = logval

        for task_id, per_task in self.task_metrics.items():
            for key in per_task:
                logkey, logval = self.get_log_item(key, task_id)
                if logval is not None:
                    logs[logkey] = logval

        # FIXME: not sure the logging items should be initialized always after logging
        if init:
            self.init(list(self.metrics))
        return logs


# Standalone English articles, and the punctuation characters stripped by
# ``normalize_answer``.
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """
    Canonicalise text for comparison.

    Lower-cases ``s``, replaces punctuation and then the articles a/an/the with
    spaces, and finally collapses all whitespace runs into single spaces
    (dropping leading/trailing whitespace).
    """
    text = s.lower()
    # punctuation goes first so that articles glued to punctuation are freed
    for pattern in (re_punc, re_art):
        text = pattern.sub(' ', text)
    # TODO: this could almost certainly be faster with a regex \s+ -> ' '
    return ' '.join(text.split())


class F1Metric(AverageMetric):
    """
    Helper class which computes token-level F1.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """
        Compute precision, recall and f1 given a set of gold and prediction items.

        :param pred_items: sequence of predicted values
        :param gold_items: sequence of gold values
        :return: tuple (p, r, f1) for precision, recall, f1
        """
        # multiset intersection: each shared token counts as often as it
        # appears on both sides
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0, 0, 0
        precision = num_same / len(pred_items)
        recall = num_same / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute(
        guess: str, answers: List[str], expose_p_and_r: bool = False
    ) -> Union[F1Metric, Tuple[F1Metric, F1Metric, F1Metric]]:
        """
        Score ``guess`` against every answer and keep the best values.

        :param guess: predicted text
        :param answers: reference texts
        :param expose_p_and_r:
            when True, return ``(precision, recall, f1)`` instead of f1 alone.
            Each of the three is maximised independently over the answers.
        :return:
            a single F1Metric, or a 3-tuple of them when ``expose_p_and_r`` is
            set. If ``guess`` or ``answers`` is None the metrics are empty
            (0 / 0) but keep the same shape.
        """
        if guess is None or answers is None:
            # Same shape as the regular result so callers that unpack three
            # values keep working.
            if expose_p_and_r:
                return (F1Metric(0, 0), F1Metric(0, 0), F1Metric(0, 0))
            return F1Metric(0, 0)

        g_tokens = normalize_answer(guess).split()
        scores = [
            F1Metric._prec_recall_f1_score(g_tokens, normalize_answer(a).split())
            for a in answers
        ]
        max_p = max((p for p, _, _ in scores), default=0)
        max_r = max((r for _, r, _ in scores), default=0)
        max_f1 = max((f1 for _, _, f1 in scores), default=0)
        if expose_p_and_r:
            return (F1Metric(max_p, 1), F1Metric(max_r, 1), F1Metric(max_f1, 1))
        return F1Metric(max_f1, 1)


class RougeMetric(AverageMetric):
    """
    ROUGE scores computed with the optional ``rouge`` package.
    """

    # Scorer created on first use and reused by every later call.
    _evaluator = None

    @staticmethod
    def compute_many(
        guess: str, answers: List[str], measure: str = 'r'
    ) -> Tuple[Optional[RougeMetric], Optional[RougeMetric], Optional[RougeMetric]]:
        """
        Compute ROUGE score between guess and *any* answer.

        All three variants are computed in one go due to increased efficiency;
        for each variant the best score over the answers is kept.

        :param guess: predicted text
        :param answers: reference texts
        :param measure: 'r' (recall, default), 'f' (f1) or 'p' (precision)
        :return:
            (rouge-1, rouge-2, rouge-L), or (None, None, None) when the
            ``rouge`` package cannot be imported or scoring raises LookupError.
        """
        measure = measure.lower()
        assert (
            measure in ROUGE_METRICS_MEASURES
        ), "Use one of recall 'r' (default), f1 'f', or precision 'p'."

        unavailable = (None, None, None)

        # possible global initialization
        try:
            import rouge
        except ImportError:
            # User doesn't have py-rouge installed, so rouge computations are
            # simply turned off.
            return unavailable

        scorer = RougeMetric._evaluator
        if scorer is None:
            scorer = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
            RougeMetric._evaluator = scorer

        try:
            per_answer = [
                scorer.get_scores(normalize_answer(guess), normalize_answer(ref))
                for ref in answers
            ]
        except LookupError:
            # ROUGE requires the nltk punkt tokenizer
            # (`python -c "import nltk; nltk.download('punkt')"`).
            return unavailable

        best = [
            max(result[variant][measure] for result in per_answer)
            for variant in ('rouge-1', 'rouge-2', 'rouge-l')
        ]
        return RougeMetric(best[0]), RougeMetric(best[1]), RougeMetric(best[2])


class ExactMatchMetric(AverageMetric):
    """
    Average metric scoring 1 when the guess equals any answer after
    normalization, and 0 otherwise.
    """

    @staticmethod
    def compute(guess: str, answers: List[str]) -> ExactMatchMetric:
        """
        Compare ``guess`` with each answer after ``normalize_answer``.

        :return:
            ExactMatchMetric(1) on a match, ExactMatchMetric(0) otherwise.
            NOTE: returns None (not a metric) when ``guess`` or ``answers`` is
            None.
        """
        if guess is None or answers is None:
            return None
        target = normalize_answer(guess)
        matched = any(target == normalize_answer(candidate) for candidate in answers)
        return ExactMatchMetric(1 if matched else 0)


class BleuMetric(AverageMetric):
    """
    Average metric holding sentence-level BLEU scores computed with NLTK.
    """

    @staticmethod
    def compute(guess: str, answers: List[str], k: int = 4) -> Optional[BleuMetric]:
        """
        Compute approximate BLEU score between guess and a set of answers.

        :param guess: predicted text
        :param answers: reference texts
        :param k: highest n-gram order; orders 1..k are weighted uniformly
        :return: the BLEU-k metric, or None when nltk is not installed
        """
        try:
            from nltk.translate import bleu_score as nltkbleu
        except ImportError:
            # User doesn't have nltk installed, so we can't use it for bleu
            # We'll just turn off things, but we might want to warn the user
            return None

        # Warning: BLEU calculation *should* include proper tokenization and
        # punctuation etc. We're using the normalize_answer for everything though,
        # so we're over-estimating our BLEU scores.  Also note that NLTK's bleu is
        # going to be slower than fairseq's (which is written in C), but fairseq's
        # requires that everything be in arrays of ints (i.e. as tensors). NLTK's
        # works with strings, which is better suited for this module.
        uniform_weights = [1 / k for _ in range(k)]
        references = [normalize_answer(a).split(" ") for a in answers]
        hypothesis = normalize_answer(guess).split(" ")
        smoother = nltkbleu.SmoothingFunction(epsilon=1e-12).method1
        bleu = nltkbleu.sentence_bleu(
            references,
            hypothesis,
            smoothing_function=smoother,
            weights=uniform_weights,
        )
        return BleuMetric(bleu)


class InterDistinctMetric(Metric):
    """
    Compute inter-distinct metric over corpus-level.

    The metric carries n-gram counts; merging two metrics adds the counts, and
    the value is (number of distinct n-grams) / (total n-gram occurrences).
    """

    def __init__(self, counts: TCounter[Tuple]):
        """
        :param counts:
            collections.Counter of ngram -> frequency
        """
        self._counts = counts

    def __add__(self, other):
        merged = self._counts + other._counts
        return InterDistinctMetric(merged)

    def value(self):
        # both terms are clamped to a tiny positive floor, so an empty counter
        # gives a near-zero value instead of dividing by zero
        distinct = max(len(self._counts), 1e-12)
        total = max(sum(self._counts.values()), 1e-5)
        return distinct / total

    @classmethod
    def _ngram(cls, seq, n):
        """
        Yield every contiguous window of ``n`` items of ``seq`` as a tuple.
        """
        window_count = len(seq) - n + 1
        for start in range(window_count):
            yield tuple(seq[start : start + n])

    @classmethod
    def compute(cls, text, ngram=1):
        """
        Build the metric from the n-grams of the normalized, whitespace-split text.
        """
        tokens = normalize_answer(text).split()
        counts = Counter(cls._ngram(tokens, ngram))
        return InterDistinctMetric(counts)


class TeacherMetrics(Metrics):
    """
    Helper container which encapsulates standard metrics (F1, BLEU, ...).
    """

    def __init__(
        self, metrics_list: str = "default", shared: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        :param metrics_list:
            comma-separated metric names and/or group keywords
            (see :meth:`_infer_metrics`).
        :param shared:
            optional shared state forwarded to ``Metrics``.
        """
        super().__init__(shared=shared)
        self._metrics_list = self._infer_metrics(metrics_list)
        # the k values for which hits@k is reported
        self.eval_pr = [1, 5, 10, 100]

    @staticmethod
    def _infer_metrics(cli_arg: str) -> Set[str]:
        """
        Parse the CLI metric into a list of metrics we wish to compute.

        Group keywords ('default', 'rouge', 'bleu', 'distinct', 'all') expand
        to their metric sets; any other name is taken literally.
        """
        groups = {
            'default': DEFAULT_METRICS,
            'rouge': ROUGE_METRICS,
            'bleu': BLEU_METRICS,
            'distinct': DISTINCT_METRICS,
            'all': ALL_METRICS,
        }
        col: Set[str] = set()
        for name in cli_arg.split(","):
            if name in groups:
                col |= groups[name]
            else:
                col.add(name)
        return col

    def _add_if_computed(self, key: str, metric: Optional[Metric]) -> None:
        """
        Record ``metric`` under ``key`` unless it is None (i.e. not computable).
        """
        if metric is not None:
            self.add(key, metric)

    def _update_ranking_metrics(self, observation, labels):
        """
        Record hits@k for every k in ``eval_pr`` from the observation's
        'text_candidates' (assumed to be sorted best-first). Does nothing when
        the observation has no candidates.
        """
        text_cands = observation.get('text_candidates', None)
        if text_cands is None:
            return

        label_set = {normalize_answer(label) for label in labels}
        # hits@k is 1 iff some label appears among the first k candidates, so
        # the 1-based rank of the first matching candidate decides every k.
        first_hit = next(
            (
                rank
                for rank, cand in enumerate(text_cands, start=1)
                if normalize_answer(cand) in label_set
            ),
            None,
        )
        for k in self.eval_pr:
            hit = first_hit is not None and first_hit <= k
            self.add(f'hits@{k}', AverageMetric(hit))

    def evaluate_response(self, observation: Message, labels: List[str]) -> None:
        """
        Compute all required text-based metrics based on an observation and labels.

        Always counts the example ('exs'). When the observation has a 'text'
        prediction, records accuracy, precision/recall/f1 and whichever of the
        BLEU, ROUGE and distinct metrics were requested. Ranking metrics and
        user-reported metrics are processed afterwards.
        """
        prediction = observation.get('text', None)

        self.add('exs', SumMetric(1))

        if prediction is not None:
            # The compute() helpers may return None (e.g. nltk is missing);
            # such results are skipped instead of being accumulated.
            self._add_if_computed(
                'accuracy', ExactMatchMetric.compute(prediction, labels)
            )
            precision, recall, f1 = F1Metric.compute(
                prediction, labels, expose_p_and_r=True
            )
            self.add('precision', precision)
            self.add('recall', recall)
            self.add('f1', f1)

            for k in range(1, 5):  # 1..4
                if f'bleu-{k}' in self._metrics_list:
                    self._add_if_computed(
                        f'bleu-{k}', BleuMetric.compute(prediction, labels, k)
                    )
            # if any of the rouges are in the list
            if self._metrics_list & ROUGE_METRICS:
                r1, r2, rL = RougeMetric.compute_many(prediction, labels)
                # NOTE: requested as 'rouge-N' but recorded as 'rouge_N'
                if 'rouge-1' in self._metrics_list:
                    self._add_if_computed('rouge_1', r1)
                if 'rouge-2' in self._metrics_list:
                    self._add_if_computed('rouge_2', r2)
                if 'rouge-L' in self._metrics_list:
                    self._add_if_computed('rouge_L', rL)
            # compute distinct-k
            for k in [1, 2]:
                if f'interdistinct-{k}' in self._metrics_list:
                    self.add(
                        f'interdistinct-{k}', InterDistinctMetric.compute(prediction, k)
                    )
                if f'intradistinct-{k}' in self._metrics_list:
                    # NOTE(review): IntraDistinctMetric is not defined or
                    # imported in the visible part of this module -- confirm
                    # it exists before enabling 'distinct' / 'all'.
                    self.add(
                        f'intradistinct-{k}', IntraDistinctMetric.compute(prediction, k)
                    )

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        self._consume_user_metrics(observation)

    def _consume_user_metrics(self, observation):
        """
        Record the user-reported metrics found under observation['metrics'].

        None values are skipped, names colliding with built-in metrics get a
        ``USER_`` prefix, and plain numbers are wrapped in AverageMetric.
        """
        if 'metrics' not in observation:
            return
        for uk, v in observation['metrics'].items():
            if v is None:
                continue
            if uk in ALL_METRICS:
                # don't let the user override our metrics
                uk = f'USER_{uk}'
            assert isinstance(uk, str), f'{type(uk)} is not a str'
            if not isinstance(v, Metric):
                # the value is assumed to be averaged per example
                v = AverageMetric(v)
            assert isinstance(v, Metric)
            self.add(uk, v)