from typing import List, Tuple
import editdistance
from omegaconf import DictConfig, OmegaConf

from pado.core.base.metric import PadoMetric

# Names exported by `from <this module> import *`.
__all__ = ["EditDistance", "WordErrorRate", "CharErrorRate"]


class EditDistance(PadoMetric):
    """Base class for edit-distance error rates over decoded strings.

    Subclasses implement :meth:`_calculate_distance` to decide which unit
    (e.g. characters or words) the distance is measured in.
    """

    def __init__(self,
                 space_symbol: str = " ",
                 blank_symbol: str = "<b>",
                 unk_symbol: str = "<unk>",
                 *, ignore_unk: bool = True) -> None:
        """
        space_symbol:   separator symbol, stored as ``self.space``
        blank_symbol:   symbol always deleted by ``_remove_symbols``
        unk_symbol:     symbol deleted by ``_remove_symbols`` when
                        ``ignore_unk`` is true
        ignore_unk:     whether ``unk_symbol`` is deleted before scoring
        """
        super().__init__()
        self.space = space_symbol
        self.blank = blank_symbol
        self.unk = unk_symbol
        self.ignore_unk = ignore_unk

    def _remove_symbols(self, s: str) -> str:
        """
        Delete the blank symbol (and the unk symbol when ``ignore_unk`` is
        set) from ``s``, then normalise the spacing that deletion leaves.

        Only the literal " " character is normalised here; the configured
        ``space_symbol`` is left untouched.
        """
        cleaned = s.replace(self.blank, "")
        if self.ignore_unk:
            cleaned = cleaned.replace(self.unk, "")
        # Deleting adjacent symbols can leave a run of three or more spaces.
        # A single `replace("  ", " ")` pass only halves such a run, so
        # rebuild the string from its non-empty pieces instead.
        cleaned = " ".join(piece for piece in cleaned.split(" ") if piece)
        return cleaned.strip()

    def _calculate_distance(self, pred: str, target: str) -> Tuple[int, int]:
        """
        Return ``(distance, target_length)`` for one pred/target pair.

        Must be overridden by subclasses; the base class raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def forward(self, pred: List[str], target: List[str]) -> Tuple[float, int, int]:
        """
        Assume strings are already decoded.
        pred:       list of predicted strings, one per batch item
        target:     list of reference strings, one per batch item

        Return      (distance / length, distance, length), where distance
                    and length are summed over the batch.

        If the summed length is zero (empty batch, or every reference is
        empty) the rate is 0.0 when the summed distance is also zero and
        ``inf`` otherwise, instead of raising ``ZeroDivisionError``.

        Raise       ValueError if pred and target differ in length.
        """
        if len(pred) != len(target):
            raise ValueError(f"EditDistance length mismatch: pred({len(pred)}) vs target({len(target)}).")

        distance_sum = 0
        length_sum = 0
        for p, t in zip(pred, target):
            d, t_l = self._calculate_distance(p, t)
            distance_sum += d
            length_sum += t_l

        if length_sum == 0:
            # No reference units to normalise by: report a perfect score only
            # when nothing was predicted either.
            rate = 0.0 if distance_sum == 0 else float("inf")
        else:
            rate = distance_sum / length_sum
        return rate, distance_sum, length_sum

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """
        Build an instance from an OmegaConf config whose keys are passed to
        the constructor as keyword arguments (interpolations are resolved).
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)


class CharErrorRate(EditDistance):
    """Edit distance measured over characters, with separators removed."""

    def _calculate_distance(self, pred: str, target: str) -> Tuple[int, int]:
        """
        Return (character edit distance, number of target characters).

        Both strings are cleaned with ``_remove_symbols`` and have every
        occurrence of the space symbol deleted before being compared.
        """
        hyp_chars, ref_chars = (
            self._remove_symbols(text).replace(self.space, "")
            for text in (pred, target)
        )
        return editdistance.eval(hyp_chars, ref_chars), len(ref_chars)


class WordErrorRate(EditDistance):
    """Edit distance measured over words split on the space symbol."""

    def _split_words(self, s: str) -> List[str]:
        """Clean ``s`` and split it into its non-empty words."""
        # `"".split(sep)` yields `[""]` and doubled separators yield empty
        # strings; neither is a real word, so drop them. Otherwise an empty
        # reference is counted as one word and stray separators are scored
        # as word errors.
        return [word for word in self._remove_symbols(s).split(self.space) if word]

    def _calculate_distance(self, pred: str, target: str) -> Tuple[int, int]:
        """
        Return (word edit distance, number of target words).

        An empty target has length 0.
        """
        pred_words = self._split_words(pred)
        target_words = self._split_words(target)

        distance = editdistance.eval(pred_words, target_words)
        return distance, len(target_words)
