# -*- coding: utf-8 -*-

"""F1 Scorer."""

from typing import List

from scorers.base_scorer import BaseScorer


class FScorer(BaseScorer):
    """Corpus level F1 Score evaluator for key/value annotations.

    Each call to :meth:`add` compares the predicted annotations of one
    document with its reference annotations using a greedy one-to-one
    matching. It records one 0/1 indicator per prediction (precision
    support) and one 0/1 indicator per reference (recall support).
    Precision and recall are then averaged over all recorded indicators of
    the corpus.
    """

    def __init__(self):
        """Initialize an empty scorer."""
        # One entry per prediction: 1 = matched a reference, 0 = false positive.
        self.__precision = []
        # One entry per reference: 1 = matched by a prediction, 0 = false negative.
        self.__recall = []

    @classmethod
    def from_scorers(cls, scorers: List['FScorer']) -> 'FScorer':
        """Return a new scorer that pools the statistics of ``scorers``.

        Args:
            scorers: list of scorers whose indicators are concatenated

        Returns:
            FScorer: a new FScorer

        """
        new_scorer = cls()
        for scorer in scorers:
            new_scorer.__precision.extend(scorer.__precision)
            new_scorer.__recall.extend(scorer.__recall)
        return new_scorer

    @staticmethod
    def _is_match(pred: dict, ref: dict) -> bool:
        """Return True if prediction ``pred`` matches reference ``ref``.

        The keys must be equal, and the predicted value must equal the
        reference value or be listed in the reference's ``value_variants``
        (when that field is present). Variants are always read from the
        reference side, so the result does not depend on which loop asks.

        Raises:
            AssertionError: if a same-key pair is not single-valued.
        """
        if pred['key'] != ref['key']:
            return False
        # Only single-valued annotations are supported.
        assert len(ref['values']) == len(pred['values']) == 1
        ref_value = ref['values'][0]
        pred_value = pred['values'][0]['value']
        return (pred_value == ref_value['value']
                or ('value_variants' in ref_value
                    and pred_value in ref_value['value_variants']))

    def in_annotations(self, ann: dict, reference_anns: List[dict]) -> bool:
        """Check if annotation matches any of the reference annotations.

        ``ann`` is treated as the prediction and each element of
        ``reference_anns`` as a reference. Every reference with the same key
        is considered, not only the first one.

        Args:
            ann: annotation to check against
            reference_anns: list of reference annotations

        Returns:
            bool: True if some reference matches ``ann``

        """
        return any(self._is_match(ann, item) for item in reference_anns)

    def remove_ann(self, anns: List[dict], ann: dict) -> List[dict]:
        """Return annotations excluding provided as the second argument.

        Note that ``anns`` is not modified: a new list is returned, and
        every element equal to ``ann`` is dropped.

        Args:
            anns: annotations to filter
            ann: annotation to exclude

        Returns:
            annotations: anns without ann

        """
        return [a for a in anns if a != ann]

    def add(self, out_items: dict, ref_items: dict):
        """Add more items for computing corpus level scores.

        Predictions are matched to references greedily and one-to-one. Each
        prediction claims the first still-unmatched reference it matches, so
        a reference can back at most one true positive. Precision and recall
        therefore always agree on the number of true positives.

        Args:
            out_items: outs from a single document (line); its
                ``'annotations'`` entry is a list of annotation dicts
            ref_items: reference of the evaluated document (line), with the
                same structure

        """
        predictions = out_items['annotations']
        references = ref_items['annotations']

        ref_matched = [False] * len(references)
        pred_indicators = []
        for pred in predictions:
            hit = 0
            for idx, ref in enumerate(references):
                if not ref_matched[idx] and self._is_match(pred, ref):
                    ref_matched[idx] = True  # consume: no reuse by later preds
                    hit = 1
                    break
            pred_indicators.append(hit)

        self.__add_to_precision(pred_indicators)
        self.__add_to_recall([int(matched) for matched in ref_matched])

    def __add_to_precision(self, item: List[int]):
        # Accepts either a list of indicators or a single indicator.
        if isinstance(item, list):
            self.__precision.extend(item)
        else:
            self.__precision.append(item)

    def __add_to_recall(self, item: List[int]):
        # Accepts either a list of indicators or a single indicator.
        if isinstance(item, list):
            self.__recall.extend(item)
        else:
            self.__recall.append(item)

    def precision(self) -> float:
        """Compute precision.

        Returns:
            float: corpus level precision (0.0 when nothing was predicted)

        """
        if self.__precision:
            precision = sum(self.__precision) / len(self.__precision)
        else:
            precision = 0.0
        return precision

    @property
    def precision_support(self):
        """List of per-prediction 0/1 indicators (1 = true positive)."""
        return self.__precision

    @property
    def recall_support(self):
        """List of per-reference 0/1 indicators (1 = matched)."""
        return self.__recall

    def recall(self) -> float:
        """Compute recall.

        Returns:
            float: corpus level recall (0.0 when there were no references)

        """
        if self.__recall:
            recall = sum(self.__recall) / len(self.__recall)
        else:
            recall = 0.0
        return recall

    def f_score(self) -> float:
        """Compute F1 score.

        Returns:
            float: corpus level F1 score (harmonic mean of precision and
            recall; 0.0 when both are zero).

        """
        precision = self.precision()
        recall = self.recall()
        if precision or recall:
            fscore = 2 * precision * recall / (precision + recall)
        else:
            fscore = 0.0
        return fscore

    def false_negative(self) -> int:
        """Return the number of false negatives.

        Returns:
            int: number of false negatives.

        """
        return len(self.__recall) - sum(self.__recall)

    def false_positive(self) -> int:
        """Return the number of false positives.

        Returns:
            int: number of false positives.

        """
        return len(self.__precision) - sum(self.__precision)

    def true_positive(self) -> int:
        """Return number of true positives.

        Returns:
            int: number of true positives.

        """
        return sum(self.__precision)

    def condition_positive(self) -> int:
        """Return number of condition positives.

        Condition positives are the actual (reference) positives, i.e.
        true positives plus false negatives. That is one recall indicator
        per reference annotation.

        Returns:
            int: number of condition positives.

        """
        return len(self.__recall)

    def score(self):
        """Return the main metric of this scorer (the F1 score)."""
        return self.f_score()

    @classmethod
    def support_feature_scores(cls) -> bool:
        """Return True: this scorer supports per-feature scores."""
        return True

    @classmethod
    def metric_name(cls) -> str:
        """Return the display name of the metric."""
        return "F1"
