# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Moderation Class"""

from __future__ import annotations

import logging
import os
from typing import Callable, Literal, overload

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.trainer_utils import EvalPrediction

from constants import PROMPT_INPUT
from utils import calculate_binary_classification_metrics, resize_tokenizer_embedding


__all__ = ['Moderation']


ProblemType = Literal[
    'regression',
    'single_label_classification',
    'multi_label_classification',
]


class Moderation(nn.Module):
    """Moderation"""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device | str | int | None = None,
    ) -> None:
        """Initialize the moderation model."""
        super().__init__()
        self.model: PreTrainedModel = model.to(device) if device is not None else model
        self.tokenizer: PreTrainedTokenizerBase = tokenizer

        self.id2labels: dict[int, str] = self.model.config.id2label
        self.problem_type: ProblemType = self.model.config.problem_type

    @property
    def device(self) -> torch.device:
        """the device of the model."""
        return next(self.parameters()).device

    @property
    def num_labels(self) -> int:
        """Number of labels."""
        return len(self.id2labels)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.BoolTensor,
        labels: torch.LongTensor | None = None,
        return_dict: bool | None = None,
    ) -> SequenceClassifierOutputWithPast | tuple[torch.Tensor, ...]:
        """Forward pass of the moderation model."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=return_dict,
        )

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str | os.PathLike,
        /,
        model_max_length: int = 512,
        padding_side: Literal['left', 'right'] = 'right',
        num_labels: int | None = None,
        id2label: dict[int, str] | None = None,
        problem_type: ProblemType | None = None,
        device_map: str | dict[str, torch.device | str | int] | None = None,
        device: torch.device | str | int | None = None,
    ) -> Moderation:
        """Initialize the moderation model."""
        model_name_or_path = os.path.expanduser(model_name_or_path)

        if device_map is not None and device is not None:
            raise ValueError(
                '`device_map` and `device` cannot be specified at the same time.',
            )

        if num_labels is not None and id2label is not None and len(id2label) != num_labels:
            logging.warning(
                'You passed along `num_labels=%d` with an incompatible id to label map: %s. '
                'The number of labels will be overwritten to %d.',
                num_labels,
                id2label,
                len(id2label),
            )
            num_labels = len(id2label)

        model_kwargs = {}
        if num_labels is not None:
            model_kwargs['num_labels'] = num_labels
        if id2label is not None:
            model_kwargs['id2label'] = id2label
        if problem_type is not None:
            model_kwargs['problem_type'] = problem_type
        if device_map is not None:
            model_kwargs['device_map'] = device_map
        model_kwargs['cache_dir'] ="../../cache"
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path,
            **model_kwargs,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            model_max_length=model_max_length,
            padding_side=padding_side,
            use_fast=(model.config.model_type != 'llama'),
            cache_dir="../../cache",
        )
        resize_tokenizer_embedding(model, tokenizer)
        return cls(model, tokenizer, device)

    def compute_metrics(self, pred: EvalPrediction) -> dict[str, float]:
        """Default compute_metrics function."""
        if self.problem_type == 'multi_label_classification':
            labels = torch.from_numpy(pred.label_ids)
            predictions = torch.sigmoid(torch.from_numpy(pred.predictions)) > 0.5

            flagged_labels = labels.any(dim=-1)
            flagged_predictions = predictions.any(dim=-1)
            metrics = calculate_binary_classification_metrics(
                labels=flagged_labels,
                predictions=flagged_predictions,
            )
            metric_dict = {f'flagged/{k}': v for k, v in metrics.items()}

            for i, label_name in self.id2labels.items():
                metrics = calculate_binary_classification_metrics(
                    labels=labels[:, i],
                    predictions=predictions[:, i],
                )
                metric_dict.update({f'{label_name}/{k}': v for k, v in metrics.items()})
            return metric_dict

        return {}

    def fit(
        self,
        training_args: TrainingArguments,
        train_dataset: Dataset,
        eval_dataset: Dataset | None = None,
        data_collator: Callable | None = None,
        compute_metrics: Callable | None = None,
    ) -> None:
        """Train the model."""
        if compute_metrics is None:
            compute_metrics = self.compute_metrics

        self.model.train()

        trainer = Trainer(
            model=self.model,
            tokenizer=self.tokenizer,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )
        trainer.train(
            ignore_keys_for_eval=(
                # This is required if you have transformers < 4.30.0
                ['past_key_values']
                if self.model.config.model_type == 'llama'
                else None
            ),
        )
        trainer.evaluate(
            ignore_keys=(
                # This is required if you have transformers < 4.30.0
                ['past_key_values']
                if self.model.config.model_type == 'llama'
                else None
            ),
        )
        trainer.save_state()
        trainer.save_model(output_dir=training_args.output_dir)

    @overload
    def predict(
        self,
        text: list[str],
        batch_size: int,
        return_bool: Literal[False],
        threshold: float,
    ) -> list[dict[str, float]]:
        ...

    @overload
    def predict(
        self,
        text: list[str],
        batch_size: int,
        return_bool: Literal[True],
        threshold: float,
    ) -> list[dict[str, bool]]:
        ...

    @overload
    def predict(
        self,
        text: str,
        batch_size: int,
        return_bool: Literal[False],
        threshold: float,
    ) -> dict[str, float]:
        ...

    @overload
    def predict(
        self,
        text: str,
        batch_size: int,
        return_bool: Literal[True],
        threshold: float,
    ) -> dict[str, bool]:
        ...

    @torch.inference_mode()
    def predict(
        self,
        text: list[str] | str,
        batch_size: int = 16,
        return_bool: bool = False,
        threshold: float = 0.4,
    ) -> list[dict[str, float | bool]] | dict[str, float | bool]:
        """Predict the moderation result of the input text."""
        batched_input = not isinstance(text, str)
        if not batched_input:
            text = [text]

        text = [
            t + self.tokenizer.eos_token if not t.endswith(self.tokenizer.eos_token) else t
            for t in text
        ]

        logging.info('Tokenizing the input text...')
        model_inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        )
        dataset = TensorDataset(model_inputs.input_ids, model_inputs.attention_mask)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        predictions = []
        for input_ids, attention_mask in tqdm(dataloader, desc='Predicting'):
            outputs = self.model(
                input_ids=input_ids.to(self.device),
                attention_mask=attention_mask.to(self.device),
            )
            predictions.append(outputs.logits)
        predictions = torch.cat(predictions, dim=0)
        predictions = torch.sigmoid(predictions)
        flagged = predictions.max(dim=1).values > threshold

        if return_bool:
            predictions = predictions > threshold

        outputs = []
        for i, t in enumerate(text):  # pylint: disable=invalid-name
            formatted_predictions = {
                'text': t[: -len(self.tokenizer.eos_token)],
                'flagged': flagged[i].item(),
                'categories': {
                    label_name: predictions[i, label_id].item()
                    for label_id, label_name in self.id2labels.items()
                },
            }
            outputs.append(formatted_predictions)

        return outputs if batched_input else outputs[0]


class QAModeration(Moderation):
    @overload
    def predict(  # pylint: disable=arguments-differ
        self,
        question: list[str],
        answer: list[str],
        batch_size: int,
        return_bool: Literal[False],
        threshold: float,
    ) -> list[dict[str, float]]:
        ...

    @overload
    def predict(  # pylint: disable=arguments-differ
        self,
        question: list[str],
        answer: list[str],
        batch_size: int,
        return_bool: Literal[True],
        threshold: float,
    ) -> list[dict[str, bool]]:
        ...

    @overload
    def predict(  # pylint: disable=arguments-differ
        self,
        question: str,
        answer: str,
        batch_size: int,
        return_bool: Literal[False],
        threshold: float,
    ) -> dict[str, float]:
        ...

    @overload
    def predict(  # pylint: disable=arguments-differ
        self,
        question: str,
        answer: str,
        batch_size: int,
        return_bool: Literal[True],
        threshold: float,
    ) -> dict[str, bool]:
        ...

    @torch.inference_mode()
    def predict(  # pylint: disable=arguments-differ,arguments-renamed
        self,
        question: list[str] | str,
        answer: list[str] | str,
        batch_size: int = 16,
        return_bool: bool = False,
        threshold: float = 0.4,
    ) -> list[dict[str, float | bool]] | dict[str, float | bool]:
        if isinstance(question, str) != isinstance(answer, str):
            raise ValueError('`question` and `answer` must be both str or be both list of str')

        batched_input = not isinstance(question, str)
        if batched_input:
            if len(question) != len(answer):
                raise ValueError('The `question` and `answer` lists must have the same length.')
            text = [PROMPT_INPUT.format(input=q) + a for q, a in zip(question, answer)]
        else:
            text = PROMPT_INPUT.format(input=question) + answer

        return super().predict(
            text,
            batch_size=batch_size,
            return_bool=return_bool,
            threshold=threshold,
        )