from math import sqrt

import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.roberta.modeling_roberta import (
    RobertaPreTrainedModel as _RobertaPreTrainedModel,
)

from zarya.model.config import DebertaConfig, RobertaConfig
from zarya.model.layers import Zero


class RobertaPreTrainedModel(_RobertaPreTrainedModel):
    """RoBERTa pretrained-model base class bound to the project's ``RobertaConfig``.

    Provides an ``_init_weights`` implementation covering ``nn.Linear``,
    ``nn.Embedding``, ``nn.LayerNorm`` and the project's ``Zero`` layer.
    """

    config_class = RobertaConfig

    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place, based on its type.

        Args:
            module: The submodule to initialize. Modules that match none of
                the handled types are left untouched.
        """
        if isinstance(module, nn.Linear):
            # Gaussian weights, zero bias (when the layer has one).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad_row = module.padding_idx
            if pad_row is not None:
                # The padding row is reset to zero after the random draw.
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # Zero bias and unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, Zero):
            # Both parameters use std = sqrt(initializer_range).
            # NOTE(review): presumably chosen so that a product a * b has std
            # initializer_range -- confirm against zarya.model.layers.Zero.
            zero_std = sqrt(self.config.initializer_range)
            for param in (module.a, module.b):
                param.data.normal_(mean=0.0, std=zero_std)


class DebertaPreTrainedModel(PreTrainedModel):
    """
    Base class for DeBERTa models built on ``PreTrainedModel``.

    Supplies weight initialization (``_init_weights``) and registers a
    state-dict pre-load hook that discards classifier weights whose shape
    differs from this model's classifier.
    """

    config_class = DebertaConfig
    base_model_prefix = "deberta"
    _keys_to_ignore_on_load_missing = ["position_ids"]
    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]

    def __init__(self, config):
        """Build the model from ``config`` and register the pre-load hook."""
        super().__init__(config)
        self._register_load_state_dict_pre_hook(self._pre_load_hook)

    def _init_weights(self, module):
        """Initialize the parameters of ``module`` in place, based on its type.

        Args:
            module: The submodule to initialize. Modules that match none of
                the handled types are left untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad_row = module.padding_idx
            if pad_row is not None:
                # The padding row is reset to zero after the random draw.
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, Zero):
            # Both parameters use std = sqrt(initializer_range).
            # NOTE(review): presumably chosen so that a product a * b has std
            # initializer_range -- confirm against zarya.model.layers.Zero.
            zero_std = sqrt(self.config.initializer_range)
            for param in (module.a, module.b):
                param.data.normal_(mean=0.0, std=zero_std)

    def _pre_load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Drop the incoming classifier parameters when their shape does not match.

        If both this model and ``state_dict`` contain ``classifier.weight`` and
        the two tensors differ in size, ``classifier.weight`` (and
        ``classifier.bias``, when present) are removed from ``state_dict`` in
        place. The remaining hook arguments are accepted but not used.
        """
        own_state = self.state_dict()
        weight_key = "classifier.weight"

        # Nothing to reconcile unless both sides carry a classifier weight.
        if weight_key not in own_state or weight_key not in state_dict:
            return

        # Matching shapes: leave the incoming state dict untouched.
        if own_state[weight_key].size() == state_dict[weight_key].size():
            return

        del state_dict[weight_key]
        # The bias is optional; remove it only if it is there.
        state_dict.pop("classifier.bias", None)
