"""Reward model Trainer."""

from __future__ import annotations

from typing import Any

import torch
import torch.distributed as dist
import torch.nn.functional as F
from tqdm import tqdm

from safe_rlhf.datasets import PreferenceDataset
from safe_rlhf.models import AutoModelForScore, ScoreModelOutput
from safe_rlhf.trainers import SupervisedTrainer
from safe_rlhf.utils import get_all_reduce_mean, is_main_process, split_prompt_response, to_device


class RewardTrainer(SupervisedTrainer):
    """Trainer for reward model.

    Each training/evaluation sample is a preference pair made of a "better" and
    a "worse" sequence.  The model is asked to assign the better sequence a
    higher score than the worse one.
    """

    TRAINING_TYPE = 'reward'
    DATASET_TYPE = PreferenceDataset
    MODEL_TYPE = AutoModelForScore

    @property
    def extra_model_kwargs(self) -> dict[str, Any]:
        """Extra keyword arguments for initializing the model.

        Returns:
            dict[str, Any]: A fixed ``score_type`` of ``'reward'`` plus the score
                normalization settings read from ``self.args``.
        """
        return {
            'score_type': 'reward',
            'do_normalize': self.args.normalize_score_during_training,
            'normalizer_type': self.args.normalizer_type,
            'momentum': self.args.normalizer_momentum,
        }

    @torch.no_grad()
    def eval(self) -> dict[str, Any]:
        """Evaluate the model on the evaluation dataset.

        For every batch the end scores of the better and the worse sequences are
        computed; a pair counts as correct when the better sequence scores
        strictly higher.  On the main process a small table with examples from
        the last batch is printed through ``self.logger``.

        Returns:
            dict[str, Any]: ``'eval/accuracy'``, ``'eval/rewards_mean'`` and
                ``'eval/rewards_std'``.  An empty dict is returned when there is
                no evaluation dataloader or when it yields no batch.  On the main
                process the reward statistics cover the rewards gathered from all
                ranks; on the other ranks they cover the local rewards only.
        """
        if self.eval_dataloader is None:
            return {}

        self.set_eval()
        num_correct_predictions = 0
        num_total_predictions = 0

        # Progress bar is only shown on the main process.
        eval_dataloader = tqdm(
            self.eval_dataloader,
            desc='Evaluating',
            disable=not is_main_process(),
            position=1,
            leave=False,
        )

        rewards = []
        # `batch` doubles as an "is the dataloader empty?" sentinel and is reused
        # after the loop to print examples from the last batch.
        batch = None
        for batch in eval_dataloader:
            batch = to_device(batch, self.args.device)
            batch_size = batch['better_input_ids'].size(0)
            higher_end_rewards = self.model(
                batch['better_input_ids'],
                attention_mask=batch['better_attention_mask'],
            ).end_scores.squeeze(dim=-1)
            lower_end_rewards = self.model(
                batch['worse_input_ids'],
                attention_mask=batch['worse_attention_mask'],
            ).end_scores.squeeze(dim=-1)

            # Ties count as wrong predictions (strict comparison).
            num_correct_predictions += (higher_end_rewards > lower_end_rewards).sum()
            num_total_predictions += batch_size

            rewards.extend([higher_end_rewards, lower_end_rewards])

        if batch is None:
            self.logger.print('WARNING: `eval_dataloader` is empty.')
            # Undo the `set_eval()` call above; otherwise an empty dataloader
            # would leave the model in evaluation mode for the following
            # training steps, unlike the regular exit path below.
            self.set_train()
            return {}

        accuracy = num_correct_predictions / num_total_predictions
        # NOTE(review): presumably averages the per-rank accuracy over all
        # processes -- confirm against `get_all_reduce_mean`.
        accuracy = get_all_reduce_mean(accuracy)

        # Gather rewards from all devices for further analysis
        rewards = torch.cat(rewards, dim=0)
        # NOTE(review): the receive buffers are shaped like the local `rewards`,
        # so this presumes every rank evaluated the same number of samples.
        if is_main_process():
            gathered_rewards = [torch.empty_like(rewards) for _ in range(dist.get_world_size())]
        else:
            gathered_rewards = []
        dist.gather(rewards, gathered_rewards, dst=0)
        if is_main_process():
            rewards = torch.cat(gathered_rewards, dim=0)

        self.set_train()

        # Evaluation info
        info = {
            'eval/accuracy': accuracy.item(),
            'eval/rewards_mean': rewards.mean().item(),
            'eval/rewards_std': rewards.std().item(),
        }

        if is_main_process():
            # Print some examples from the last batch
            max_num_rows = 3
            higher_reward_texts = self.tokenizer.batch_decode(
                batch['better_input_ids'][:max_num_rows],
                skip_special_tokens=True,
            )
            lower_reward_texts = self.tokenizer.batch_decode(
                batch['worse_input_ids'][:max_num_rows],
                skip_special_tokens=True,
            )

            h_prompts, h_responses = split_prompt_response(higher_reward_texts)
            l_prompts, l_responses = split_prompt_response(lower_reward_texts)
            # Both sequences of a pair are expected to share the same prompt.
            assert h_prompts == l_prompts, 'prompts are not the same'
            # `higher_end_rewards` / `lower_end_rewards` still hold the values
            # computed for the last batch of the loop above.
            h_rewards = [f'{reward:.6f}' for reward in higher_end_rewards.tolist()]
            l_rewards = [f'{reward:.6f}' for reward in lower_end_rewards.tolist()]

            # Strip the 'eval/' prefix from the metric names for the table title.
            title = ', '.join(
                f'{key.rpartition("/")[-1]} = {value:.6f}' for key, value in info.items()
            )
            self.logger.print_table(
                title=f'Evaluation: {title}',
                columns=[
                    'prompt',
                    'higher-reward response',
                    'reward',
                    'lower-reward response',
                    'reward',
                ],
                rows=tuple(
                    zip(
                        h_prompts[:max_num_rows],
                        h_responses[:max_num_rows],
                        h_rewards[:max_num_rows],
                        l_responses[:max_num_rows],
                        l_rewards[:max_num_rows],
                    ),
                ),
                max_num_rows=max_num_rows,
            )

        return info

    def loss(
        self,
        better_input_ids: torch.LongTensor,  # size = (B, L)
        better_attention_mask: torch.BoolTensor,  # size = (B, L)
        worse_input_ids: torch.LongTensor,  # size = (B, L)
        worse_attention_mask: torch.BoolTensor,  # size = (B, L)
    ) -> dict[str, torch.Tensor]:
        """Loss function for the reward model.

        Both sequences of every pair are scored in a single forward pass.  With
        ``self.args.loss_type == 'sequence-wise'`` the loss is
        ``-logsigmoid(higher_end_reward - lower_end_reward)`` averaged over the
        batch.  With ``'token-wise'`` the same expression is applied to the
        per-token scores between the first position where the two sequences
        differ and the last attended position of the longer one, averaged per
        sample and then over the batch.  When ``self.args.regularization`` is
        positive, the mean squared rewards scaled by that factor are added.

        Args:
            better_input_ids (torch.LongTensor): The input ids of the better answer.
            better_attention_mask (torch.BoolTensor): The attention mask of the better answer.
            worse_input_ids (torch.LongTensor): The input ids of the worse answer.
            worse_attention_mask (torch.BoolTensor): The attention mask of the worse answer.

        Returns:
            dict[str, torch.Tensor]: loss, higher_end_rewards, lower_end_rewards, accuracy

        Raises:
            ValueError: If ``self.args.loss_type`` is neither ``'token-wise'``
                nor ``'sequence-wise'``.
            AssertionError: If the batch sizes differ or, in token-wise mode, a
                better/worse pair is identical.
        """
        assert better_input_ids.size(0) == worse_input_ids.size(0), 'batch size mismatch!'
        batch_size = better_input_ids.size(0)

        # Stack better and worse sequences along the batch dimension so that one
        # forward pass scores both; they are split apart again with `chunk`.
        output: ScoreModelOutput = self.model(
            torch.cat([better_input_ids, worse_input_ids], dim=0),
            attention_mask=torch.cat([better_attention_mask, worse_attention_mask], dim=0),
        )
        scores = output.scores  # size = (2 * B, L, 1)
        end_scores = output.end_scores  # size = (2 * B, 1)
        # size = (B, L)
        higher_rewards, lower_rewards = scores.squeeze(dim=-1).chunk(chunks=2, dim=0)
        # size = (B,)
        higher_end_rewards, lower_end_rewards = end_scores.squeeze(dim=-1).chunk(chunks=2, dim=0)

        if self.args.loss_type == 'token-wise':
            losses = []
            for i in range(batch_size):
                assert not torch.all(
                    torch.eq(better_input_ids[i], worse_input_ids[i]),
                ).item(), 'The better and worse answers are the same!'
                # Index of the last non-zero attention-mask entry of each sequence.
                higher_end_index = better_attention_mask[i].nonzero()[-1]
                lower_end_index = worse_attention_mask[i].nonzero()[-1]
                end_index = max(higher_end_index, lower_end_index)

                # First position at which the two token sequences differ.
                diverge_index = (better_input_ids[i] != worse_input_ids[i]).nonzero()[0]
                assert 0 <= diverge_index <= end_index, 'diverge index is out of range!'

                # size = (end_index - diverge_index + 1,)
                higher_truncated_rewards = higher_rewards[i, diverge_index : end_index + 1]
                lower_truncated_rewards = lower_rewards[i, diverge_index : end_index + 1]

                losses.append(
                    -F.logsigmoid(higher_truncated_rewards - lower_truncated_rewards).mean(),
                )

                if self.args.regularization > 0.0:
                    # Penalize large reward magnitudes on the compared tokens.
                    losses[-1] = losses[-1] + self.args.regularization * (
                        torch.square(lower_truncated_rewards).mean()
                        + torch.square(higher_truncated_rewards).mean()
                    )

            loss = torch.stack(losses).mean()  # size = ()
        elif self.args.loss_type == 'sequence-wise':
            loss = -F.logsigmoid(higher_end_rewards - lower_end_rewards).mean()

            if self.args.regularization > 0.0:
                # Penalize large end-reward magnitudes.
                loss = loss + self.args.regularization * (
                    torch.square(lower_end_rewards).mean() + torch.square(higher_end_rewards).mean()
                )
        else:
            raise ValueError(f'Unknown loss type: {self.args.loss_type}')

        # Accuracy is always based on the end rewards, whatever the loss type.
        accuracy = (higher_end_rewards > lower_end_rewards).float().mean()  # size = ()
        return {
            'loss': loss,  # size = ()
            'higher_end_rewards': higher_end_rewards,  # size = (B,)
            'lower_end_rewards': lower_end_rewards,  # size = (B,)
            'accuracy': accuracy,  # size = ()
        }

    def train_step(
        self,
        better_input_ids: torch.LongTensor,  # size = (B, L)
        better_attention_mask: torch.BoolTensor,  # size = (B, L)
        worse_input_ids: torch.LongTensor,  # size = (B, L)
        worse_attention_mask: torch.BoolTensor,  # size = (B, L)
    ) -> dict[str, Any]:
        """Perform a single training step.

        Computes the loss, runs the backward pass and the optimizer step through
        ``self.model``, then reduces loss and accuracy for logging.

        Args:
            better_input_ids (torch.LongTensor): The input ids of the better answer.
            better_attention_mask (torch.BoolTensor): The attention mask of the better answer.
            worse_input_ids (torch.LongTensor): The input ids of the worse answer.
            worse_attention_mask (torch.BoolTensor): The attention mask of the worse answer.

        Returns:
            dict[str, Any]: training loss, training accuracy, learning rate
        """
        loss_dict = self.loss(
            better_input_ids=better_input_ids,
            better_attention_mask=better_attention_mask,
            worse_input_ids=worse_input_ids,
            worse_attention_mask=worse_attention_mask,
        )
        loss = loss_dict['loss']
        # The model object itself drives backward and the optimizer step.
        self.model.backward(loss)
        self.model.step()

        accuracy = loss_dict['accuracy']

        # Reduced values are used for reporting only; the update above has
        # already been applied with the local loss.
        loss = get_all_reduce_mean(loss)
        accuracy = get_all_reduce_mean(accuracy)

        return {
            'train/loss': loss.item(),
            'train/accuracy': accuracy.item(),
            # Learning rate of the first parameter group.
            'train/lr': self.model.optimizer.param_groups[0]['lr'],
        }
