import copy
from typing import Any, Dict, List, Optional

import rouge
from torch import LongTensor
from transformers import GenerationConfig

from ._base import BaseTask
from ._utils.generation_utils import postprocess_generation_ids


class TextSummarizationTask(BaseTask):
    """Evaluate a generative model on text summarization, scored with ROUGE.

    Prompts are passed to ``model.generate``. Both the decoded generations and
    the decoded reference labels are lower-cased and stripped, then scored with
    ``rouge.Rouge``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        """Create the task.

        Args:
            model: Model exposing ``generate`` (called in ``_predict``).
            tokenizer: Tokenizer used to decode labels and generations.
            data_name_or_path: Dataset identifier or path, forwarded to ``BaseTask``.
            prompt_col_name: Name of the prompt column, forwarded to ``BaseTask``.
            label_col_name: Name of the reference-label column, forwarded to ``BaseTask``.
            device: Optional device string, forwarded to ``BaseTask``.
            **kwargs: Extra ``BaseTask`` options. ``merge_prompt_label`` is always
                forced to ``False``, overriding any caller-supplied value.
                Presumably this keeps labels separate from the prompt during
                generation; confirm against ``BaseTask``.
        """
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs,
        )

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[str]:
        """Generate one summary per example in *batch_data*.

        Args:
            batch_data: Batch holding ``"input_ids"`` and ``"attention_mask"``.
            **kwargs: Must contain ``"generation_config"``.

        Returns:
            The first generated sequence for each example, lower-cased and stripped.
        """
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config,
        )
        # Presumably postprocess_generation_ids yields, per input, a sequence of
        # num_return_sequences decoded strings; `run` forces that to 1, so [0]
        # selects the only candidate.
        return [
            each[0].lower().strip()
            for each in postprocess_generation_ids(
                input_ids=batch_data["input_ids"],
                output_ids=output_ids,
                num_return_sequences=generation_config.num_return_sequences,
                tokenizer=self.tokenizer,
            )
        ]

    def _parse_labels(self, label_ids: LongTensor) -> List[str]:
        """Decode each row of *label_ids* into a lower-cased, stripped reference string.

        Positions equal to -100 (the conventional "ignore" id for masked or padded
        label positions) are removed wherever they occur. This means both left- and
        right-padded labels decode correctly, and -100 is never passed to the tokenizer.
        """
        return [
            self.tokenizer.decode(row[row != -100]).lower().strip()
            for row in label_ids
        ]

    @staticmethod
    def _is_rouge_blank(text: str) -> bool:
        """Return True if *text* has no content once periods and whitespace are ignored.

        ``rouge.Rouge`` splits text on "." and discards empty pieces. It raises
        ``ValueError`` when nothing remains, so such text cannot be scored directly.
        """
        return not text.replace(".", "").strip()

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, Dict[str, float]]:
        """Average ROUGE scores (``rouge.Rouge`` defaults) over prediction/label pairs.

        - A blank prediction scores 0 for every metric and statistic, instead of
          aborting the whole evaluation with ``ValueError``.
        - A pair whose reference is blank cannot be scored at all, so it is left
          out of the average.
        - If no pair can be scored, all values are 0.

        Raises:
            ValueError: If ``pred`` and ``label`` differ in length.
        """
        if len(pred) != len(label):
            raise ValueError(
                f"Got {len(pred)} predictions but {len(label)} labels; they must pair up one-to-one."
            )
        metric = rouge.Rouge()
        pairs = [(hyp, ref) for hyp, ref in zip(pred, label) if not self._is_rouge_blank(ref)]
        scorable = [(hyp, ref) for hyp, ref in pairs if not self._is_rouge_blank(hyp)]
        if not scorable:
            # Score a trivial pair only to learn the result layout, then zero it.
            template = metric.get_scores(hyps=["x"], refs=["x"], avg=True)
            return {name: {stat: 0.0 for stat in stats} for name, stats in template.items()}
        avg_scores = metric.get_scores(
            hyps=[hyp for hyp, _ in scorable],
            refs=[ref for _, ref in scorable],
            avg=True,
        )
        # Blank predictions contribute 0, so the mean over all scorable-reference
        # pairs is the mean over non-blank pairs rescaled by their share.
        weight = len(scorable) / len(pairs)
        return {
            name: {stat: value * weight for stat, value in stats.items()}
            for name, stats in avg_scores.items()
        }

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        """Run the evaluation with greedy decoding by default.

        Args:
            generation_config: Optional generation settings. The config is
                copied, so the caller's object is never modified. The copy always
                gets ``num_return_sequences=1`` and the tokenizer's EOS and PAD
                token ids. If omitted, greedy decoding (``num_beams=1``,
                ``do_sample=False``) with ``max_new_tokens=128`` is used.

        Returns:
            Whatever ``BaseTask.run`` returns for this task.
        """
        if generation_config is None:
            generation_config = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=128)
        else:
            # Copy first: the overrides below must not leak into the caller's config.
            generation_config = copy.deepcopy(generation_config)
        generation_config.num_return_sequences = 1
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = self.tokenizer.pad_token_id
        return super().run(generation_config=generation_config)


# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["TextSummarizationTask"]
