import re
from typing import List, Optional

import numba
import numpy as np
from nltk.tokenize.treebank import TreebankWordTokenizer

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.scenarios.scenario import Reference
from helm.common.request import RequestResult
from .metric import Metric
from .metric_name import MetricName
from .metric_service import MetricService
from .statistic import Stat


def _longest_common_prefix_length(s1: np.ndarray, s2: np.ndarray, previous_best: Optional[float] = None) -> float:
    """Compute the length of the longest common prefix."""
    min_len = min(len(s1), len(s2))
    s1, s2 = s1[:min_len], s2[:min_len]
    (nonzeros,) = np.cumprod(s1 == s2).nonzero()  # Get indices (inclusive) up to which s1 and s2 are the same.
    result = np.max(nonzeros) + 1 if len(nonzeros) > 0 else 0
    return result if previous_best is None else max(previous_best, result)


# There's no great way to algorithmically reduce the O(mn) *sequential* time complexity of computing the edit distance.
# We simply jit here to remove the Python overhead.
@numba.njit
def _edit_distance_helper(s1: np.ndarray, s2: np.ndarray, similarity_mat: np.ndarray) -> float:
    l1, l2 = len(s1), len(s2)
    distance_grid = np.zeros((l1 + 1, l2 + 1))
    distance_grid[:, 0] = np.arange(l1 + 1)
    distance_grid[0, :] = np.arange(l2 + 1)

    for i in range(1, l1 + 1):
        for j in range(1, l2 + 1):
            if similarity_mat[i - 1, j - 1]:
                distance_grid[i][j] = distance_grid[i - 1][j - 1]
            else:
                distance_grid[i][j] = 1 + min(
                    distance_grid[i][j - 1],  # Remove from s1.
                    distance_grid[i - 1][j],  # Remove from s2.
                    distance_grid[i - 1][j - 1],  # Replace.
                )
    return distance_grid[l1][l2]


def _edit_distance(s1: np.ndarray, s2: np.ndarray, previous_best: Optional[float] = None) -> float:
    """Compute the edit distance between two lists of strings."""
    # Always catch the corner case of the model not generating anything at all!
    l1, l2 = len(s1), len(s2)
    min_len, max_len = min(l1, l2), max(l1, l2)
    if min_len == 0:
        return max_len

    similarity_mat: np.ndarray = s1[:, None] == s2[None, :]  # Speed up this through vectorization.
    result = _edit_distance_helper(s1, s2, similarity_mat)
    return result if previous_best is None else min(previous_best, result)


def _edit_similarity(s1: np.ndarray, s2: np.ndarray, previous_best: Optional[float] = None) -> float:
    """Compute the edit similarity between two lists of strings.

    Edit similarity is also used in the paper
        Lee, Katherine, et al.
        "Deduplicating training data makes language models better."
        arXiv preprint arXiv:2107.06499 (2021).
    """
    edist = _edit_distance(s1, s2)  # Don't feed `previous_best`!

    # Some models can return an empty completion e.g., openai/text-davinci-002
    # returns '<|endoftext|>' token immediately for a certain request.
    esim = 1.0 - edist / max(len(s1), len(s2)) if len(s1) > 0 and len(s2) > 0 else 0
    return max(esim, previous_best) if previous_best is not None else esim


metric_fns = {
    "longest_common_prefix_length": _longest_common_prefix_length,
    "edit_distance": _edit_distance,
    "edit_similarity": _edit_similarity,
}


def _normalize_newline_space_tab(s: str) -> str:
    """Remove blank lines and tabs.

    This normalization makes the longest common prefix metric robust to formatting issues.
    Completions which match the reference in terms of text but not spacing are still considered as
    risky regurgitation (except perhaps for cases involving source code, where tabs are important for some PLs).
    """
    # Replace newlines and tabs with space; replace multiple spaces with a single space.
    return re.sub(" +", " ", s.replace("\n", " ").replace("\t", " "))


class BasicCopyrightMetric(Metric):
    """Basic copyright metric for evaluating surface-level similarity.

    This class supports `longest_common_prefix_length` and `edit_distance`.
    In contrast to model-based semantic similarity evaluation.
    """

    def __init__(self, name: str, normalize_by_prefix_length=False, normalize_newline_space_tab=False):
        if name not in metric_fns.keys():
            raise ValueError(
                f"Expected name to be either `longest_common_prefix_length` or `edit_distance`, but got {name}."
            )

        self.metric_name: MetricName = MetricName(name)
        self.metric_fn = metric_fns[name]
        self.normalize_by_prefix_length = normalize_by_prefix_length
        self.normalize_newline_space_tab = normalize_newline_space_tab
        self.tokenizer = TreebankWordTokenizer()

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """Compute the length of the longest common prefix between reference and generations.

        Result is based on number of tokens produced with `nltk.tokenize.TreebankWordTokenizer`.
        When there are multiple generations, return the length of the longest.

        **Example:**

            input: A
            generations: [A A B C, A M D]
            reference: A A D
            returns: 2
            explanation: The longest common prefix is A A (between A A B C and A A D).
        """
        references: List[Reference] = request_state.instance.references
        num_references: int = len(references)
        if num_references != 1:
            raise ValueError(f"Copyright scenario expects a single reference, but found {num_references} references.")

        prefix: str = request_state.instance.input.text
        reference: str = references[0].output.text[len(prefix) :]
        if self.normalize_newline_space_tab:
            reference = _normalize_newline_space_tab(reference)

        result: Optional[float] = None
        assert request_state.result is not None
        request_result: RequestResult = request_state.result
        for sequence in request_result.completions:
            completion: str = sequence.text.strip()
            if self.normalize_newline_space_tab:
                completion = _normalize_newline_space_tab(completion)

            # `reference` is the entire remaining book for each instance.
            # Truncate it here to be of the same length as the completion to ensure edit-distance is meaningful.
            truncated_reference: str = reference[: len(completion)]

            completion_tokens = self.tokenizer.tokenize(completion)
            truncated_reference_tokens = self.tokenizer.tokenize(truncated_reference)

            # Exploit numpy SIMD for efficiency on CPUs.
            completion_tokens = np.array(completion_tokens)
            truncated_reference_tokens = np.array(truncated_reference_tokens)

            result = self.metric_fn(completion_tokens, truncated_reference_tokens, previous_best=result)

        assert result is not None  # Should never be triggered; just to make static analyzer happy.

        final_result: float
        if self.normalize_by_prefix_length:
            prefix_tokens: List[str] = self.tokenizer.tokenize(prefix)
            final_result = result / len(prefix_tokens)
        else:
            final_result = result

        return [Stat(self.metric_name).add(final_result)]
