# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_PAD_TOKEN, DEFAULT_UNK_TOKEN


# Reference: https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py
def resize_tokenizer_embedding(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Register missing special tokens and resize the model embeddings to match.

    Every one of the pad / eos / bos / unk tokens that ``tokenizer`` does not
    define yet is set to the corresponding ``DEFAULT_*_TOKEN``. The model's
    token embeddings are then resized to ``len(tokenizer)``, and the bos / eos /
    pad token ids in ``model.config`` are copied from the tokenizer. When tokens
    were actually added, the trailing rows of the input embedding matrix and,
    if present, the output embedding matrix are filled with the mean of the
    rows that precede them.

    Both ``model`` and ``tokenizer`` are modified in place.

    Args:
        model: Model whose token embeddings and config ids are updated.
        tokenizer: Tokenizer that receives any missing special tokens.
    """
    fallbacks = (
        ('pad_token', DEFAULT_PAD_TOKEN),
        ('eos_token', DEFAULT_EOS_TOKEN),
        ('bos_token', DEFAULT_BOS_TOKEN),
        ('unk_token', DEFAULT_UNK_TOKEN),
    )
    # The dict keeps the pad, eos, bos, unk order of ``fallbacks``; that is the
    # order in which the entries are handed to ``add_special_tokens``.
    missing = {
        attribute: token
        for attribute, token in fallbacks
        if getattr(tokenizer, attribute) is None
    }

    added = tokenizer.add_special_tokens(missing)
    model.resize_token_embeddings(len(tokenizer))

    config = model.config
    config.bos_token_id = tokenizer.bos_token_id
    config.eos_token_id = tokenizer.eos_token_id
    config.pad_token_id = tokenizer.pad_token_id

    if added <= 0:
        return

    for get_embeddings in (model.get_input_embeddings, model.get_output_embeddings):
        embeddings = get_embeddings()
        if embeddings is None:
            continue
        weight = embeddings.weight.data
        # Assumes the newly added ids map to the last ``added`` rows (the slicing
        # relies on it); seed those rows with the average of all earlier rows.
        weight[-added:] = weight[:-added].mean(dim=0, keepdim=True)


def calculate_binary_classification_metrics(
    labels: torch.Tensor,
    predictions: torch.Tensor,
    epsilon: float = 1e-8,
) -> dict[str, float]:
    """Calculate binary classification metrics.

    Builds the confusion matrix of ``predictions`` against ``labels``, with 1 as
    the positive class and 0 as the negative class. From it, derives accuracy,
    precision, recall and F1.

    Args:
        labels: Ground-truth labels. Elements are compared against 0 and 1; an
            element with any other value is counted in no confusion-matrix cell.
        predictions: Predicted labels; must have the same shape as ``labels``.
        epsilon: Small constant added to the precision, recall and F1
            denominators to avoid division by zero.

    Returns:
        A dict with float entries ``'accuracy'``, ``'precision'``, ``'recall'``
        and ``'f1'``. All four are ``0.0`` when no element pair falls into the
        confusion matrix (e.g. empty inputs).

    Raises:
        AssertionError: If ``labels`` and ``predictions`` have different shapes.
    """
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``. Otherwise mismatched shapes would silently broadcast into
    # wrong counts. AssertionError is kept as the exception type for existing callers.
    if labels.shape != predictions.shape:
        raise AssertionError('The shapes of labels and predictions should be the same.')

    # Build each comparison mask once and reuse it for the four cells.
    label_pos = labels == 1
    label_neg = labels == 0
    pred_pos = predictions == 1
    pred_neg = predictions == 0

    tp = (label_pos & pred_pos).sum().item()  # pylint: disable=invalid-name
    fp = (label_neg & pred_pos).sum().item()  # pylint: disable=invalid-name
    tn = (label_neg & pred_neg).sum().item()  # pylint: disable=invalid-name
    fn = (label_pos & pred_neg).sum().item()  # pylint: disable=invalid-name

    total = tp + fp + tn + fn
    # An unguarded division raises ZeroDivisionError on empty input; 0.0 matches
    # what the epsilon-smoothed metrics below yield in that case.
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)
    f1 = 2 * precision * recall / (precision + recall + epsilon)  # pylint: disable=invalid-name
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }