# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from pathlib import Path

import pandas as pd
from rouge_cli import calculate_rouge_path

from utils import calculate_rouge


PRED = [
    'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
    ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
    " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
    "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
    " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
    " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
    " body.",
    "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
    " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
    " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
    " punishment.",
]

TGT = [
    'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
    ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
    " had informed his Lufthansa training school of an episode of severe depression, airline says .",
    "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
    " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
    " Israelis .",
    "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
    " death . Organization claims that governments around the world are using the threat of terrorism to advance"
    " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
    " sentences up by 28% .",
]


def test_disaggregated_scores_are_determinstic():
    no_aggregation = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
    assert isinstance(no_aggregation, defaultdict)
    no_aggregation_just_r2 = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])
    assert (
        pd.DataFrame(no_aggregation["rouge2"]).fmeasure.mean()
        == pd.DataFrame(no_aggregation_just_r2["rouge2"]).fmeasure.mean()
    )


def test_newline_cnn_improvement():
    k = "rougeLsum"
    score = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[k])[k]
    score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[k])[k]
    assert score > score_no_sep


def test_newline_irrelevant_for_other_metrics():
    k = ["rouge1", "rouge2", "rougeL"]
    score_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=k)
    score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=k)
    assert score_sep == score_no_sep


def test_single_sent_scores_dont_depend_on_newline_sep():
    pred = [
        "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
        'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
    ]
    tgt = [
        "Margot Frank, died in 1945, a month earlier than previously thought.",
        'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
        " the final seconds on board Flight 9525.",
    ]
    assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)


def test_pegasus_newline():
    pred = [
        """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
    ]
    tgt = [
        """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
    ]

    prev_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"], newline_sep=False)["rougeLsum"]
    new_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"])["rougeLsum"]
    assert new_score > prev_score


def test_rouge_cli():
    data_dir = Path("examples/seq2seq/test_data/wmt_en_ro")
    metrics = calculate_rouge_path(data_dir.joinpath("test.source"), data_dir.joinpath("test.target"))
    assert isinstance(metrics, dict)
    metrics_default_dict = calculate_rouge_path(
        data_dir.joinpath("test.source"), data_dir.joinpath("test.target"), bootstrap_aggregation=False
    )
    assert isinstance(metrics_default_dict, defaultdict)
