
"""Auto-models for Score LM which have both score and language model heads."""

from __future__ import annotations

import functools
import importlib
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers.models.auto as auto_module
from torch import distributed as dist
from transformers import PretrainedConfig
from transformers.models.auto.auto_factory import (
    _BaseAutoModelClass,
    _LazyAutoMapping,
    auto_class_update,
    getattribute_from_module,
)
from transformers.models.auto.configuration_auto import (
    CONFIG_MAPPING_NAMES,
    model_type_to_module_name,
)
from transformers.utils.generic import ModelOutput

from safe_rlhf.models.normalizer import NormalizeFunction, Normalizer


class _LazyAutoMappingInScoreLM(_LazyAutoMapping):
    """Lazy auto-mapping whose model classes live in ``safe_rlhf.models.score_lm`` submodules.

    Overrides ``_load_attr_from_module`` so that the submodule for a model type is imported
    relative to the ``safe_rlhf.models.score_lm`` package and memoized in ``self._modules``.
    """

    def _load_attr_from_module(self, model_type: str, attr: str) -> Any:
        """Return ``attr`` from the score-LM submodule of ``model_type``, importing it on first use."""
        submodule_name = model_type_to_module_name(model_type)
        try:
            submodule = self._modules[submodule_name]
        except KeyError:
            # First request for this model type: import it once and cache it.
            submodule = self._modules[submodule_name] = importlib.import_module(
                f'.{submodule_name}',
                'safe_rlhf.models.score_lm',
            )
        return getattribute_from_module(submodule, attr)


# Score model mapping: ``transformers`` model type -> name of the Score LM class that
# ``_LazyAutoMappingInScoreLM`` resolves from the matching ``safe_rlhf.models.score_lm`` submodule.
# Keyword-argument order is preserved, so the mapping order is llama, gpt_neox, gpt2.
MODEL_FOR_SCORE_LM_MAPPING_NAMES: OrderedDict[str, str] = OrderedDict(
    llama='LlamaForScoreLM',
    gpt_neox='GPTNeoXForScoreLM',
    gpt2='GPT2ForScoreLM',
)
# Lazily-resolved config-class -> model-class mapping used by ``AutoModelForScoreLM``.
MODEL_FOR_SCORE_LM_MAPPING: OrderedDict[str, Any] = _LazyAutoMappingInScoreLM(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SCORE_LM_MAPPING_NAMES,
)


# Auto class that dispatches on the config type through ``MODEL_FOR_SCORE_LM_MAPPING``.
# ``auto_class_update`` is applied explicitly below instead of through a decorator.
class AutoModelForScoreLM(_BaseAutoModelClass):
    _model_mapping: OrderedDict[str, Any] = MODEL_FOR_SCORE_LM_MAPPING


AutoModelForScoreLM = auto_class_update(AutoModelForScoreLM, head_doc='score lm')

# Expose the mapping and the auto class on ``transformers.models.auto`` as well.
auto_module.MODEL_FOR_SCORE_LM_MAPPING = MODEL_FOR_SCORE_LM_MAPPING
setattr(auto_module, AutoModelForScoreLM.__name__, AutoModelForScoreLM)


@dataclass
class ScoreLMOutput(ModelOutput):
    """
    Output of the Score LM.

    Args:
        scores (`torch.FloatTensor` of shape `(batch_size, sequence_length, score_dim)`):
            Prediction scores of each token in the sequence.
        end_scores (`torch.FloatTensor` of shape `(batch_size, score_dim)`):
            Prediction scores at the end of each sequence. `ScoreLMMixin.get_scores` takes the end
            to be the last position where the attention mask is nonzero.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`):
            Sequence of hidden-states at the output of the last layer of the model.
        end_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
            Last-layer hidden state at the end of each sequence (the same position as `end_scores`).
        end_index (`torch.LongTensor` of shape `(batch_size,)`):
            Index of the end position of each sequence.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Forwarded unchanged from the caller; presumably the backbone's key/value cache.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Forwarded unchanged from the caller; presumably the backbone's per-layer hidden states.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Forwarded unchanged from the caller; presumably the backbone's attention weights.
    """

    # NOTE: field order matters for `ModelOutput` (it defines tuple/key order); do not reorder.
    scores: torch.FloatTensor | None = None  # size = (B, L, D)
    end_scores: torch.FloatTensor | None = None  # size = (B, D)
    logits: torch.FloatTensor | None = None  # size = (B, L, V)
    last_hidden_state: torch.FloatTensor | None = None  # size = (B, L, E)
    end_last_hidden_state: torch.FloatTensor | None = None  # size = (B, E)
    end_index: torch.LongTensor | None = None  # size = (B,)
    past_key_values: tuple[tuple[torch.FloatTensor, ...], ...] | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None


class ScoreLMMixin:
    """Mixin that adds a score head, a language-model head and a score normalizer to a model.

    The host class must provide ``self.config`` (read and written by ``get_scores`` and
    ``set_normalize``) and ``self.training`` (read by ``get_scores``). Presumably the host is a
    ``PreTrainedModel`` / ``nn.Module`` subclass; TODO confirm for every concrete ``*ForScoreLM``.
    ``init_score_head`` must be called before ``get_scores`` or ``get_logits``.
    """

    score_head: nn.Linear
    lm_head: nn.Linear
    normalizer: Normalizer
    do_normalize: bool = False
    normalize_function: NormalizeFunction = 'affine'
    _is_score_head_initialized: bool = False

    def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None:
        """Create the score head, the LM head and the normalizer (no-op if already initialized).

        Each setting is taken from ``kwargs`` (and popped from it) first, then from ``config``, then
        from a hard-coded default. Most resolved values are written back onto ``config``.

        Args:
            config: Model configuration. It must provide ``vocab_size``. It is updated in place with
                ``score_dim``, ``score_bias``, ``score_type``, ``do_normalize``, ``normalizer_type``
                and, for ``'ExponentialMovingAverage'``, ``momentum``.
            hidden_size: Width of the hidden states fed to both heads.
            **kwargs: Optional overrides:
                ``score_dim`` (default 1);
                ``score_bias`` (default True);
                ``freeze_transformer_layers`` (default True; kept on ``self`` only, not on
                ``config``);
                ``score_type`` (``'reward'`` / ``'cost'`` / ``'critic'``, default ``'reward'``);
                ``do_normalize`` (default False);
                ``normalizer_type`` (``'RunningMeanStd'`` / ``'ExponentialMovingAverage'`` /
                None, default None);
                ``momentum`` (read only for ``'ExponentialMovingAverage'``).

        Raises:
            ValueError: If ``score_type`` or ``normalizer_type`` is not one of the allowed values.
        """
        if self._is_score_head_initialized:
            return

        self.score_dim = config.score_dim = kwargs.pop(
            'score_dim',
            getattr(config, 'score_dim', 1),
        )
        self.score_bias = config.score_bias = kwargs.pop(
            'score_bias',
            getattr(config, 'score_bias', True),
        )
        self.freeze_transformer_layers = kwargs.pop('freeze_transformer_layers', True)

        self.score_head = nn.Linear(hidden_size, config.score_dim, bias=config.score_bias)
        self.lm_head = nn.Linear(hidden_size, config.vocab_size, bias=False)

        if config.score_bias:
            nn.init.zeros_(self.score_head.bias)

        # The score type selects which normalization function the normalizer applies.
        config.score_type = kwargs.pop('score_type', getattr(config, 'score_type', 'reward'))
        if config.score_type == 'reward':
            self.normalize_function = 'affine'
        elif config.score_type == 'cost':
            self.normalize_function = 'scale'
        elif config.score_type == 'critic':
            self.normalize_function = 'identity'
        else:
            raise ValueError(
                f"Invalid score type: {config.score_type}. Expected one of 'reward', 'cost', or 'critic'.",
            )

        self.do_normalize = config.do_normalize = kwargs.pop(
            'do_normalize',
            getattr(config, 'do_normalize', False),
        )

        config.normalizer_type = kwargs.pop(
            'normalizer_type',
            getattr(config, 'normalizer_type', None),
        )
        if config.normalizer_type not in {'RunningMeanStd', 'ExponentialMovingAverage', None}:
            raise ValueError(
                f'Invalid norm type: {config.normalizer_type}. '
                "Expected one of 'RunningMeanStd', 'ExponentialMovingAverage', or None.",
            )
        if config.normalizer_type == 'ExponentialMovingAverage':
            config.momentum = kwargs.pop('momentum', getattr(config, 'momentum', None))
        momentum = getattr(config, 'momentum', None)
        self.normalizer = Normalizer.instantiate(
            normalizer_type=config.normalizer_type,
            normalize_function=self.normalize_function,
            shape=(config.score_dim,),
            momentum=momentum,
        )

        # Restore statistics that `get_scores` records on the config during training, if present.
        mean = getattr(config, 'mean', None)
        var = getattr(config, 'var', None)
        self.normalizer.set_mean_var(mean, var)

        self._is_score_head_initialized = True

    def get_logits(
        self,
        last_hidden_state: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply the LM head and return float32 logits of size (B, L, V).

        When ``freeze_transformer_layers`` is set, the hidden state is detached first. As a result,
        gradients from these logits reach only ``lm_head`` and not the layers that produced
        ``last_hidden_state``.
        """
        if self.freeze_transformer_layers:
            last_hidden_state = last_hidden_state.detach()
        return self.lm_head(last_hidden_state).float()  # size = (B, L, V)

    def get_scores(
        self,
        last_hidden_state: torch.FloatTensor,  # size = (B, L, E)
        attention_mask: torch.BoolTensor | None = None,  # size = (B, L)
        return_dict: bool | None = None,
        past_key_values: tuple[tuple[torch.FloatTensor, ...], ...] | None = None,
        hidden_states: tuple[torch.FloatTensor, ...] | None = None,
        attentions: tuple[torch.FloatTensor, ...] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor] | ScoreLMOutput:
        """Compute per-token scores and the score at the end of each sequence.

        Args:
            last_hidden_state: Hidden states of size (B, L, E).
            attention_mask: Mask of size (B, L). The end of each sequence is its last nonzero
                position, so right padding, left padding and interior gaps are all handled. It may
                be omitted only when B == 1; the last position is then used.
            return_dict: If falsy, return ``(scores, end_scores)``. Otherwise return a
                ``ScoreLMOutput``, which also carries the LM-head logits and the end hidden states.
            past_key_values, hidden_states, attentions: Copied unchanged into the
                ``ScoreLMOutput``.

        Returns:
            ``(scores, end_scores)`` of sizes (B, L, D) and (B, D), or a ``ScoreLMOutput``. The
            score-head output is cast to float32. It is passed through ``self.normalizer.normalize``
            when ``do_normalize`` is set.

        Side effects:
            In training mode, updates the normalizer with the (detached) end scores and stores the
            normalizer's ``mean``/``var`` as lists on ``self.config``. When ``torch.distributed``
            is initialized, the end scores are gathered from all ranks first.

        Raises:
            ValueError: If ``attention_mask`` is omitted while B > 1, or if some row of it has no
                nonzero entry.
        """
        B, L, _ = last_hidden_state.size()

        if attention_mask is None:
            if B > 1:
                raise ValueError("'attention_mask' is required when batch size > 1.")
            attention_mask = last_hidden_state.new_ones(B, L, dtype=torch.bool)  # size = (B, L)

        scores = self.score_head(last_hidden_state).float()  # size = (B, L, D)

        # Index of the last nonzero mask entry of every row, computed with one vectorized reduction
        # instead of a per-row `nonzero()` (which synchronizes with the device once per row).
        mask = attention_mask.bool()  # size = (B, L)
        if not bool(mask.any(dim=-1).all()):
            raise ValueError(
                "Every sequence in 'attention_mask' must contain at least one nonzero entry.",
            )
        positions = torch.arange(L, device=mask.device)  # size = (L,)
        # Masked-out positions become 0, and valid positions are >= 0. Because every row has at
        # least one valid entry, the row maximum is the last valid position.
        end_index = (positions * mask).max(dim=-1).values  # size = (B,)

        end_last_hidden_state = last_hidden_state[
            torch.arange(B, device=last_hidden_state.device),
            end_index.to(last_hidden_state.device),
        ]  # size = (B, E)
        end_scores = scores[
            torch.arange(B, device=scores.device),
            end_index.to(scores.device),
        ]  # size = (B, D)

        if self.training:
            # The running statistics should not carry gradients, so update them from a detached copy.
            detached_end_scores = end_scores.detach()
            if dist.is_initialized():
                # NOTE(review): `all_gather` expects equally sized tensors on every rank; an uneven
                # final batch may break here. Confirm that the data loader drops or pads it.
                gathered_end_scores_list = [
                    torch.zeros_like(detached_end_scores) for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_end_scores_list, detached_end_scores)
                self.normalizer.update(torch.cat(gathered_end_scores_list, dim=0))
            else:
                self.normalizer.update(detached_end_scores)
            # Record the statistics on the config, where `init_score_head` looks for them.
            self.config.mean = self.normalizer.mean.tolist()
            self.config.var = self.normalizer.var.tolist()

        if self.do_normalize:
            scores = self.normalizer.normalize(scores)
            end_scores = self.normalizer.normalize(end_scores)

        if not return_dict:
            return scores, end_scores

        return ScoreLMOutput(
            scores=scores,  # size = (B, L, D)
            end_scores=end_scores,  # size = (B, D)
            # The (B, L, V) logits are computed only here, where they are actually returned.
            logits=self.get_logits(last_hidden_state),  # size = (B, L, V)
            last_hidden_state=last_hidden_state,  # size = (B, L, E)
            end_last_hidden_state=end_last_hidden_state,  # size = (B, E)
            end_index=end_index,  # size = (B,)
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attentions,
        )

    def set_normalize(self, mode: bool = True) -> None:
        """Turn score normalization on or off, mirroring the flag onto ``config.do_normalize``."""
        if self.do_normalize == mode:
            return

        self.do_normalize = self.config.do_normalize = mode
