# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cost model Trainer."""

from __future__ import annotations

from typing import Any

import torch
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm

from safe_rlhf.datasets import SafetyPreferenceDataset
from safe_rlhf.models import AutoModelForScore, ScoreModelOutput
from safe_rlhf.trainers import SupervisedTrainer
from safe_rlhf.utils import get_all_reduce_mean, is_main_process, split_prompt_response, to_device


class CostTrainer(SupervisedTrainer):
    """Trainer for cost model.

    The model is built from ``AutoModelForScore`` with ``score_type='cost'`` and is
    trained on batches that pair a *safer* response with an *unsafer* response. The
    loss encourages the unsafer response to get the higher cost and the sign of each
    cost to agree with the safety label (negative cost for safe, positive for unsafe).
    """

    TRAINING_TYPE = 'cost'
    DATASET_TYPE = SafetyPreferenceDataset
    MODEL_TYPE = AutoModelForScore

    @property
    def extra_model_kwargs(self) -> dict[str, Any]:
        """Extra keyword arguments for initializing the model.

        Returns:
            dict[str, Any]: The score type (always ``'cost'``) together with the
                normalizer settings read from ``self.args``.
        """
        return {
            'score_type': 'cost',
            'do_normalize': self.args.normalize_score_during_training,
            'normalizer_type': self.args.normalizer_type,
            'momentum': self.args.normalizer_momentum,
        }

    @torch.no_grad()
    def eval(self) -> dict[str, Any]:  # pylint: disable=too-many-locals
        """Evaluate the model on the evaluation dataset.

        Two accuracies are computed over the evaluation dataloader:

        - ``accuracy``: fraction of pairs where the unsafer response gets a strictly
          higher end cost than the safer response.
        - ``accuracy_sign``: fraction of responses (both sides of every pair, hence
          the ``2 *`` in the denominator) whose end cost has the sign opposite to
          its safety sign.

        On the main process, a table with up to three examples from the last batch
        is printed through ``self.logger``.

        The model is switched to evaluation mode for the duration of the call and
        switched back to training mode before returning, including when the
        dataloader turns out to be empty.

        Returns:
            dict[str, Any]: ``eval/accuracy``, ``eval/accuracy_sign``,
                ``eval/cost_mean`` and ``eval/cost_std`` as Python floats. On the
                main process of a distributed run, the cost statistics cover the
                costs gathered from all processes; on other processes they cover
                the local costs only. An empty dict is returned when there is no
                evaluation dataloader or it yields no batches.
        """
        if self.eval_dataloader is None:
            return {}

        self.set_eval()
        num_correct_predictions = 0
        num_correct_sign_predictions = 0
        num_total_predictions = 0

        eval_dataloader = tqdm(
            self.eval_dataloader,
            desc='Evaluating',
            disable=not is_main_process(),
            position=1,
            leave=False,
        )

        costs = []
        # Stays `None` only if the dataloader yields nothing; otherwise it holds the
        # last batch, which is reused below for the example table.
        batch = None
        for batch in eval_dataloader:
            batch = to_device(batch, self.args.device)
            batch_size = batch['safer_input_ids'].size(0)

            # HINT: safer samples are supposed to have lower costs
            #       unsafer samples are supposed to have higher costs
            # size = (B,)
            lower_end_costs = self.model(
                batch['safer_input_ids'],
                attention_mask=batch['safer_attention_mask'],
            ).end_scores.squeeze(dim=-1)
            higher_end_costs = self.model(
                batch['unsafer_input_ids'],
                attention_mask=batch['unsafer_attention_mask'],
            ).end_scores.squeeze(dim=-1)

            num_correct_predictions += (higher_end_costs > lower_end_costs).sum()

            # HINT: safe samples are supposed to have negative costs
            #       unsafe samples are supposed to have positive costs
            #       safety sign: +1 for safe / -1 for unsafe
            # A prediction is correct when cost and safety sign have opposite signs,
            # i.e. their product is negative.
            num_correct_sign_predictions += (
                lower_end_costs * batch['safer_safety_sign'] < 0.0
            ).sum()
            num_correct_sign_predictions += (
                higher_end_costs * batch['unsafer_safety_sign'] < 0.0
            ).sum()
            num_total_predictions += batch_size

            costs.extend([lower_end_costs, higher_end_costs])

        if batch is None:
            self.logger.print('WARNING: `eval_dataloader` is empty.')
            # The model was put into evaluation mode above; restore training mode so
            # that the caller does not keep training with an eval-mode model.
            self.set_train()
            return {}

        accuracy = num_correct_predictions / num_total_predictions
        accuracy = get_all_reduce_mean(accuracy)
        accuracy_sign = num_correct_sign_predictions / (2 * num_total_predictions)
        accuracy_sign = get_all_reduce_mean(accuracy_sign)

        # Gather costs from all devices for further analysis
        costs = torch.cat(costs, dim=0)
        # Collective calls require an initialized process group; in a plain
        # single-process run the local costs already are the complete set.
        if dist.is_available() and dist.is_initialized():
            if is_main_process():
                # Only the destination rank provides receive buffers.
                gathered_costs = [torch.empty_like(costs) for _ in range(dist.get_world_size())]
            else:
                gathered_costs = []
            dist.gather(costs, gathered_costs, dst=0)
            if is_main_process():
                costs = torch.cat(gathered_costs, dim=0)

        self.set_train()

        # Evaluation info
        info = {
            'eval/accuracy': accuracy.item(),
            'eval/accuracy_sign': accuracy_sign.item(),
            'eval/cost_mean': costs.mean().item(),
            'eval/cost_std': costs.std().item(),
        }

        if is_main_process():
            # Print some examples from the last batch
            max_num_rows = 3
            lower_cost_texts = self.tokenizer.batch_decode(
                batch['safer_input_ids'][:max_num_rows],
                skip_special_tokens=True,
            )
            higher_cost_texts = self.tokenizer.batch_decode(
                batch['unsafer_input_ids'][:max_num_rows],
                skip_special_tokens=True,
            )

            l_prompts, l_responses = split_prompt_response(lower_cost_texts)
            h_prompts, h_responses = split_prompt_response(higher_cost_texts)
            # Both responses of a pair are expected to answer the same prompt.
            assert l_prompts == h_prompts, 'prompts do not match'
            l_is_safe = [str(l_sign >= 0) for l_sign in batch['safer_safety_sign'].tolist()]
            l_costs = [f'{cost:.6f}' for cost in lower_end_costs.tolist()]
            h_is_safe = [str(h_sign >= 0) for h_sign in batch['unsafer_safety_sign'].tolist()]
            h_costs = [f'{cost:.6f}' for cost in higher_end_costs.tolist()]

            # Use the metric names without the `eval/` prefix in the table title.
            title = ', '.join(
                f'{key.rpartition("/")[-1]} = {value:.6f}' for key, value in info.items()
            )
            self.logger.print_table(
                title=f'Evaluation: {title}',
                columns=[
                    'prompt',
                    'safer (lower-cost) response',
                    'is_safe',
                    'cost',
                    'unsafer (higher-cost) response',
                    'is_safe',
                    'cost',
                ],
                rows=tuple(
                    zip(
                        l_prompts[:max_num_rows],
                        l_responses[:max_num_rows],
                        l_is_safe[:max_num_rows],
                        l_costs[:max_num_rows],
                        h_responses[:max_num_rows],
                        h_is_safe[:max_num_rows],
                        h_costs[:max_num_rows],
                    ),
                ),
                max_num_rows=max_num_rows,
            )

        return info

    def loss(
        self,
        safer_input_ids: torch.LongTensor,  # size = (B, L)
        safer_attention_mask: torch.BoolTensor,  # size = (B, L)
        safer_safety_sign: torch.LongTensor,  # size = (B,) # +1 for safe / -1 for unsafe
        unsafer_input_ids: torch.LongTensor,  # size = (B, L)
        unsafer_attention_mask: torch.BoolTensor,  # size = (B, L)
        unsafer_safety_sign: torch.LongTensor,  # size = (B,) # +1 for safe / -1 for unsafe
    ) -> dict[str, torch.Tensor]:
        """Loss function for the cost model.

        Safer and unsafer examples are concatenated along the batch dimension and
        scored in a single forward pass. Each loss term is a negative log-sigmoid:
        one on the cost difference (unsafer minus safer) and one per response on
        the cost multiplied by its cost sign (the negated safety sign).

        With ``self.args.loss_type == 'token-wise'`` the terms are averaged, per
        pair, over the token positions from the first position where the two
        sequences differ up to the later of the two last attended positions. With
        ``'sequence-wise'`` only the end scores are used. When
        ``self.args.regularization`` is positive, the mean squared cost scaled by
        that factor is added.

        Args:
            safer_input_ids (torch.LongTensor): The input ids of the safer examples.
            safer_attention_mask (torch.BoolTensor): The attention mask of the safer examples.
            safer_safety_sign (torch.LongTensor): The safety sign of the safer examples.
            unsafer_input_ids (torch.LongTensor): The input ids of the unsafer examples.
            unsafer_attention_mask (torch.BoolTensor): The attention mask of the unsafer examples.
            unsafer_safety_sign (torch.LongTensor): The safety sign of the unsafer examples.

        Returns:
            dict[str, torch.Tensor]: loss, higher_end_cost, lower_end_cost,
                higher_costs, lower_costs, accuracy, accuracy_sign

        Raises:
            AssertionError: If the two batch sizes differ, or (token-wise loss only)
                if a safer/unsafer pair has identical input ids or the diverge
                index falls outside the attended range.
            ValueError: If ``self.args.loss_type`` is neither ``'token-wise'`` nor
                ``'sequence-wise'``.
        """
        assert safer_input_ids.size(0) == unsafer_input_ids.size(0), 'batch size mismatch!'
        batch_size = safer_input_ids.size(0)

        # Score both halves in one forward pass: first B rows are the safer examples,
        # last B rows are the unsafer examples.
        output: ScoreModelOutput = self.model(
            torch.cat([safer_input_ids, unsafer_input_ids], dim=0),
            attention_mask=torch.cat([safer_attention_mask, unsafer_attention_mask], dim=0),
        )
        scores = output.scores  # size = (2 * B, L, 1)
        end_scores = output.end_scores  # size = (2 * B, 1)

        # Hints: safer examples are supposed to have lower costs
        #        unsafer examples are supposed to have higher costs
        # size = (B, L)
        lower_costs, higher_costs = scores.squeeze(dim=-1).chunk(chunks=2, dim=0)
        # size = (B,)
        lower_end_cost, higher_end_cost = end_scores.squeeze(dim=-1).chunk(chunks=2, dim=0)
        # safety_sign: +1 for safe / -1 for unsafe
        # cost_sign: -1 for safe / +1 for unsafe
        lower_cost_sign = -safer_safety_sign  # size = (B,)
        higher_cost_sign = -unsafer_safety_sign  # size = (B,)

        if self.args.loss_type == 'token-wise':
            losses = []
            for i in range(batch_size):
                assert not torch.all(
                    torch.eq(safer_input_ids[i], unsafer_input_ids[i]),
                ).item(), 'The safer and unsafer answers are the same!'
                # Index of the last attended token of each sequence.
                lower_end_index = safer_attention_mask[i].nonzero()[-1].squeeze().item()
                higher_end_index = unsafer_attention_mask[i].nonzero()[-1].squeeze().item()
                end_index = max(higher_end_index, lower_end_index)

                # First position at which the two sequences differ.
                diverge_index = (
                    (safer_input_ids[i] != unsafer_input_ids[i]).nonzero()[0].squeeze().item()
                )
                assert 0 <= diverge_index <= end_index, 'diverge index is out of range!'

                # size = (end_index - diverge_index + 1,)
                lower_truncated_costs = lower_costs[i, diverge_index : end_index + 1]
                higher_truncated_costs = higher_costs[i, diverge_index : end_index + 1]

                # safety_sign: +1 for safe / -1 for unsafe
                losses.append(
                    -F.logsigmoid(higher_truncated_costs - lower_truncated_costs).mean()
                    - F.logsigmoid(lower_cost_sign[i] * lower_truncated_costs).mean()
                    - F.logsigmoid(higher_cost_sign[i] * higher_truncated_costs).mean(),
                )

                if self.args.regularization > 0.0:
                    losses[-1] = (
                        losses[-1]
                        + self.args.regularization
                        * torch.stack([lower_truncated_costs, higher_truncated_costs])
                        .square()
                        .mean()
                    )

            loss = torch.stack(losses).mean()  # size = ()
        elif self.args.loss_type == 'sequence-wise':
            loss = (
                -F.logsigmoid(higher_end_cost - lower_end_cost)
                - F.logsigmoid(lower_cost_sign * lower_end_cost)
                - F.logsigmoid(higher_cost_sign * higher_end_cost)
            ).mean()

            if self.args.regularization > 0.0:
                loss = (
                    loss
                    + self.args.regularization
                    * torch.stack([lower_end_cost, higher_end_cost]).square().mean()
                )
        else:
            raise ValueError(f'Unknown loss type: {self.args.loss_type}')

        # Both accuracies are always based on the end costs, whatever the loss type.
        accuracy = (higher_end_cost > lower_end_cost).float().mean()  # size = ()
        accuracy_sign = (  # size = ()
            torch.stack(
                [
                    lower_cost_sign * lower_end_cost > 0.0,
                    higher_cost_sign * higher_end_cost > 0.0,
                ],
            )
            .float()
            .mean()
        )
        return {
            'loss': loss,  # size = ()
            'higher_end_cost': higher_end_cost,  # size = (B,)
            'lower_end_cost': lower_end_cost,  # size = (B,)
            'higher_costs': higher_costs,  # size = (B, L)
            'lower_costs': lower_costs,  # size = (B, L)
            'accuracy': accuracy,  # size = ()
            'accuracy_sign': accuracy_sign,  # size = ()
        }

    def train_step(
        self,
        safer_input_ids: torch.LongTensor,  # size = (B, L)
        safer_attention_mask: torch.BoolTensor,  # size = (B, L)
        safer_safety_sign: torch.LongTensor,  # size = (B,) # +1 for safe / -1 for unsafe
        unsafer_input_ids: torch.LongTensor,  # size = (B, L)
        unsafer_attention_mask: torch.BoolTensor,  # size = (B, L)
        unsafer_safety_sign: torch.LongTensor,  # size = (B,) # +1 for safe / -1 for unsafe
    ) -> dict[str, Any]:
        """Perform a single training step.

        Computes the loss with :meth:`loss`, then calls ``backward`` and ``step`` on
        ``self.model``. The reported metrics are passed through
        ``get_all_reduce_mean`` before being converted to Python floats
        (presumably averaging them across processes -- confirm against the helper).

        Args:
            safer_input_ids (torch.LongTensor): The input ids of the safer examples.
            safer_attention_mask (torch.BoolTensor): The attention mask of the safer examples.
            safer_safety_sign (torch.LongTensor): The safety sign of the safer examples.
            unsafer_input_ids (torch.LongTensor): The input ids of the unsafer examples.
            unsafer_attention_mask (torch.BoolTensor): The attention mask of the unsafer examples.
            unsafer_safety_sign (torch.LongTensor): The safety sign of the unsafer examples.

        Returns:
            dict[str, Any]: ``train/loss``, ``train/accuracy``, ``train/accuracy_sign``
                and ``train/lr`` (the learning rate of the optimizer's first
                parameter group).
        """
        loss_dict = self.loss(
            safer_input_ids=safer_input_ids,
            safer_attention_mask=safer_attention_mask,
            safer_safety_sign=safer_safety_sign,
            unsafer_input_ids=unsafer_input_ids,
            unsafer_attention_mask=unsafer_attention_mask,
            unsafer_safety_sign=unsafer_safety_sign,
        )
        loss = loss_dict['loss']
        self.model.backward(loss)
        self.model.step()

        accuracy = loss_dict['accuracy']
        accuracy_sign = loss_dict['accuracy_sign']

        # Reduce only after the optimizer step; these values are for logging.
        loss = get_all_reduce_mean(loss)
        accuracy = get_all_reduce_mean(accuracy)
        accuracy_sign = get_all_reduce_mean(accuracy_sign)

        return {
            'train/loss': loss.item(),
            'train/accuracy': accuracy.item(),
            'train/accuracy_sign': accuracy_sign.item(),
            'train/lr': self.model.optimizer.param_groups[0]['lr'],
        }
