"""Auto-models for score models."""

from __future__ import annotations

import functools
import importlib
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
from torch import distributed as dist
from transformers import PretrainedConfig
from transformers.models.auto.auto_factory import (
    _BaseAutoModelClass,
    _LazyAutoMapping,
    auto_class_update,
    getattribute_from_module,
)
from transformers.models.auto.configuration_auto import (
    CONFIG_MAPPING_NAMES,
    model_type_to_module_name,
)
from transformers.utils.generic import ModelOutput

from safe_rlhf.models.normalizer import NormalizeFunction, Normalizer


class _LazyAutoMappingInSafeRLHF(_LazyAutoMapping):
    """Lazy auto-mapping whose model modules live in ``safe_rlhf.models.score_model``."""

    def _load_attr_from_module(self, model_type: str, attr: str) -> Any:
        """Return attribute ``attr`` of the score-model submodule for ``model_type``.

        The submodule is imported relative to ``safe_rlhf.models.score_model`` the
        first time it is requested and kept in ``self._modules`` for later lookups.
        """
        submodule_name = model_type_to_module_name(model_type)
        loaded_modules = self._modules
        if submodule_name in loaded_modules:
            submodule = loaded_modules[submodule_name]
        else:
            # First request for this model type: import the submodule and remember it.
            submodule = importlib.import_module(
                f'.{submodule_name}',
                'safe_rlhf.models.score_model',
            )
            loaded_modules[submodule_name] = submodule
        return getattribute_from_module(submodule, attr)


# Maps each supported model type to the name of its score-model class.
MODEL_FOR_SCORE_MAPPING_NAMES: OrderedDict[str, str] = OrderedDict(
    [
        # Score model mapping
        ('llama', 'LlamaModelForScore'),
        ('bloom', 'BloomModelForScore'),
        ('open_llama', 'OpenLlamaForScore'),
        ('opt', 'OPTForScore'),
        ('gpt_neo', 'GPTNeoForScore'),
        ('gptj', 'GPTJForScore'),
        ('gpt2', 'GPT2ForScore'),
        ('gpt_neox', 'GPTNeoXForScore'),
    ],
)
# Backward-compatible alias: the table was originally published under this
# misspelled name ("SCROE"), so it is kept for code that still refers to it.
MODEL_FOR_SCROE_MAPPING_NAMES: OrderedDict[str, str] = MODEL_FOR_SCORE_MAPPING_NAMES
# Lazy mapping built from the Hugging Face config-name table and the score-model
# name table; score-model modules are imported on demand by
# `_LazyAutoMappingInSafeRLHF._load_attr_from_module`.
MODEL_FOR_SCORE_MAPPING: OrderedDict[str, Any] = _LazyAutoMappingInSafeRLHF(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SCROE_MAPPING_NAMES,
)


# Auto-class for score models. The class is passed through `auto_class_update`
# with `head_doc='score model'`; loading behaviour itself is inherited from
# `_BaseAutoModelClass` (defined in transformers, not visible here).
@functools.partial(auto_class_update, head_doc='score model')
class AutoModelForScore(_BaseAutoModelClass):
    # Lazy mapping consulted to find the concrete score-model class.
    _model_mapping: OrderedDict[str, Any] = MODEL_FOR_SCORE_MAPPING


@dataclass
class ScoreModelOutput(ModelOutput):
    """
    Output of the score model.

    Args:
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, score_dim)`):
            Prediction scores of the score model at every sequence position.
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, score_dim)`):
            Prediction scores of the end of the sequence.
    """

    scores: torch.Tensor | None = None  # size = (B, L, D)
    end_scores: torch.Tensor | None = None  # size = (B, D)


class ScoreModelMixin:
    """Mixin that equips a model with a linear score head and a score normalizer.

    :meth:`init_score_head` must be called before :meth:`get_score`.
    ``get_score`` and ``set_normalize`` also read ``self.config`` and
    ``self.training``; this mixin does not define them, so they are expected
    to be provided by the class it is mixed into.
    """

    score_head: nn.Linear
    normalizer: Normalizer
    do_normalize: bool = False
    normalize_function: NormalizeFunction = 'affine'
    _initialized: bool = False  # guards against re-running ``init_score_head``

    def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None:
        """Create the score head and normalizer, recording the settings on ``config``.

        Each setting is taken from ``kwargs`` when given, otherwise from the
        attribute of the same name on ``config``, otherwise from a default; the
        resolved value is written back to ``config``. Calling this method again
        after a successful initialization is a no-op.

        Args:
            config: Model configuration; read for existing settings and updated in place.
            hidden_size: Input feature size of the score head.
            **kwargs: Optional overrides: ``score_dim`` (default ``1``), ``bias``
                (default ``False``), ``score_type`` (default ``'reward'``),
                ``do_normalize`` (default ``False``), ``normalizer_type``
                (default ``None``) and, only for the ``'ExponentialMovingAverage'``
                normalizer, ``momentum`` (default ``None``).

        Raises:
            ValueError: If ``score_type`` or ``normalizer_type`` is not a supported value.
        """
        if self._initialized:
            return

        config.score_dim = kwargs.pop('score_dim', getattr(config, 'score_dim', 1))
        config.bias = kwargs.pop('bias', getattr(config, 'bias', False))

        # The score type selects which normalize function the normalizer applies.
        config.score_type = kwargs.pop('score_type', getattr(config, 'score_type', 'reward'))
        if config.score_type == 'reward':
            self.normalize_function = 'affine'
        elif config.score_type == 'cost':
            self.normalize_function = 'scale'
        elif config.score_type == 'critic':
            self.normalize_function = 'identity'
        else:
            raise ValueError(
                f"Invalid score type: {config.score_type}. Expected one of 'reward', 'cost', or 'critic'.",
            )

        config.do_normalize = kwargs.pop(
            'do_normalize',
            getattr(config, 'do_normalize', False),
        )
        self.do_normalize = config.do_normalize

        config.normalizer_type = kwargs.pop(
            'normalizer_type',
            getattr(config, 'normalizer_type', None),
        )
        if config.normalizer_type not in {'RunningMeanStd', 'ExponentialMovingAverage', None}:
            raise ValueError(
                f'Invalid norm type: {config.normalizer_type}. '
                "Expected one of 'RunningMeanStd', 'ExponentialMovingAverage', or None.",
            )
        # ``momentum`` is only accepted as an override for the moving-average
        # normalizer; otherwise whatever ``config`` already holds (if anything) is used.
        if config.normalizer_type == 'ExponentialMovingAverage':
            config.momentum = kwargs.pop('momentum', getattr(config, 'momentum', None))
        momentum = getattr(config, 'momentum', None)

        self.score_head = nn.Linear(hidden_size, config.score_dim, bias=config.bias)
        self.normalizer = Normalizer.instantiate(
            normalizer_type=config.normalizer_type,
            normalize_function=self.normalize_function,
            shape=(config.score_dim,),
            momentum=momentum,
        )

        # Seed the normalizer with the statistics stored on the config, if any
        # (``get_score`` writes them there while training).
        mean = getattr(config, 'mean', None)
        var = getattr(config, 'var', None)
        self.normalizer.set_mean_var(mean, var)

        self._initialized = True

    def get_score(
        self,
        hidden_state: torch.Tensor,  # size = (B, L, E)
        attention_mask: torch.BoolTensor,  # size = (B, L)
        return_dict: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput:
        """Forward pass of the score model.

        Args:
            hidden_state: Hidden states fed to the score head, size = (B, L, E).
            attention_mask: Mask of size = (B, L); the last non-zero position of
                each row marks the end of that sequence, so every row must
                contain at least one non-zero entry.
            return_dict: If truthy, return a :class:`ScoreModelOutput`; otherwise
                return the tuple ``(scores, end_score)``.

        Returns:
            The per-position scores, size = (B, L, D), and the scores at the end
            of each sequence, size = (B, D). Both are normalized when
            ``self.do_normalize`` is set.
        """
        scores = self.score_head(hidden_state)  # size = (B, L, D)

        # For every sequence, pick the score at its last attended position.
        end_score = torch.stack(
            [
                scores[i, attention_mask[i].nonzero()[-1].item()]  # size = (D,)
                for i in range(scores.size(0))
            ],
            dim=0,
        )  # size = (B, D)

        if self.training:
            if dist.is_initialized():
                # Update the normalizer with the end scores of all ranks, not
                # only the local ones.
                gathered_end_score_list = [
                    torch.zeros_like(end_score) for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_end_score_list, end_score)
                gathered_end_score = torch.cat(gathered_end_score_list, dim=0)
                self.normalizer.update(gathered_end_score)
            else:
                self.normalizer.update(end_score)
            # Mirror the running statistics onto the config as plain lists;
            # ``init_score_head`` reads them back from there.
            self.config.mean = self.normalizer.mean.tolist()
            self.config.var = self.normalizer.var.tolist()

        if self.do_normalize:
            scores = self.normalizer.normalize(scores)
            end_score = self.normalizer.normalize(end_score)

        if not return_dict:
            return scores, end_score

        return ScoreModelOutput(
            scores=scores,  # size = (B, L, D)
            end_scores=end_score,  # size = (B, D)
        )

    def set_normalize(self, mode: bool = True) -> None:
        """Enable or disable score normalization on both the model and its config."""
        if self.do_normalize == mode:
            return

        self.do_normalize = self.config.do_normalize = mode
