from typing import Optional

import torch
import torch.nn as nn


class DtiTokenEmbedding(nn.Embedding):
    """Embedding table with an optional per-token scalar for out-of-vocab ids.

    Ids below ``original_vocab_size`` (the ``num_embeddings`` passed at
    construction) are looked up exactly like a plain ``nn.Embedding``. When
    ``reparameterize`` is true, ids at or above that threshold have their
    embedding vector multiplied by a scalar read from ``self.scales``; with
    ``legacy`` set, the result is additionally divided by the L2 norm of the
    unscaled embedding vector.

    The scale table is zero-initialised, so rescaled ids embed to zero until
    the scales are assigned or trained.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
        reparameterize: bool = True,
        legacy: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Build the embedding table and, if requested, the scale table.

        Args:
            num_embeddings: Number of rows of the embedding table; also stored
                as ``original_vocab_size``, the rescaling threshold.
            embedding_dim: Size of each embedding vector.
            padding_idx: Padding index, used for both tables.
            reparameterize: Whether to create and apply per-token scales.
            legacy: Whether ``forward`` divides by the embedding norm.
            *args: Extra positional arguments for ``nn.Embedding.__init__``
                (they follow ``padding_idx`` in its signature).
            **kwargs: Extra keyword arguments for ``nn.Embedding.__init__``.
        """
        # ``padding_idx`` is passed positionally on purpose: as a keyword it
        # would collide with the first element of ``*args``, which lands in
        # the same (third) positional slot of ``nn.Embedding.__init__``.
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            *args,
            **kwargs,
        )
        self.original_vocab_size = num_embeddings
        self.reparameterize = reparameterize
        self.legacy = legacy

        # Lazily computed statistics of the embedding row norms.
        self.mean_scale: Optional[float] = None
        self.max_scale: Optional[float] = None
        self.min_scale: Optional[float] = None

        if self.reparameterize:
            # Create the scale table with the same device/dtype the caller
            # requested for the main table, when given as keywords.
            factory_kwargs = {
                key: kwargs[key] for key in ("device", "dtype") if key in kwargs
            }
            self.scales = nn.Embedding(
                num_embeddings, 1, padding_idx, **factory_kwargs
            )
            nn.init.zeros_(self.scales.weight)

    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """Embed ``input_ids``, rescaling ids >= ``original_vocab_size``.

        Args:
            input_ids: Integer token ids.

        Returns:
            The embeddings, with the rows selected by the threshold replaced
            by their rescaled version when ``reparameterize`` is enabled.
        """
        inputs_embeds = super().forward(input_ids)
        if not self.reparameterize:
            return inputs_embeds

        index_rescale = input_ids >= self.original_vocab_size
        if not index_rescale.any():
            return inputs_embeds

        # Only the selected ids need a scale lookup.
        selected_embeds = inputs_embeds[index_rescale]
        selected_scales = self.scales(input_ids[index_rescale])
        rescaled = selected_scales * selected_embeds
        if self.legacy:
            # NOTE: an all-zero embedding row has norm 0 and yields NaN here.
            norm = torch.linalg.vector_norm(
                selected_embeds,
                dim=-1,
                keepdim=True,
            ).to(dtype=inputs_embeds.dtype)
            rescaled = rescaled / norm
        inputs_embeds[index_rescale] = rescaled
        return inputs_embeds

    def _weight_row_norms(self, nonzero_only: bool) -> torch.Tensor:
        """Return the L2 norm of every row of ``self.weight``.

        With ``nonzero_only`` the zero norms are dropped (which flattens the
        result) and the remaining shape is printed.
        """
        norms = torch.linalg.vector_norm(self.weight, dim=-1, keepdim=True)
        if nonzero_only:
            # Discard 0 scale
            norms = norms[norms > 0]
            print("Non-zero scale:", norms.shape)
        return norms

    @torch.no_grad()
    def get_min_scale(self) -> float:
        """Return the smallest non-zero row norm of ``self.weight`` (cached)."""
        if self.min_scale is None:
            self.min_scale = self._weight_row_norms(nonzero_only=True).min().item()
        return self.min_scale

    @torch.no_grad()
    def get_mean_scale(self) -> float:
        """Return the mean non-zero row norm of ``self.weight`` (cached)."""
        if self.mean_scale is None:
            self.mean_scale = self._weight_row_norms(nonzero_only=True).mean().item()
        return self.mean_scale

    @torch.no_grad()
    def get_max_scale(self) -> float:
        """Return the largest row norm of ``self.weight`` (cached)."""
        if self.max_scale is None:
            self.max_scale = self._weight_row_norms(nonzero_only=False).max().item()
        return self.max_scale

    def resize_token_scales(
        self,
        new_num_tokens: Optional[int] = None,
    ) -> None:
        """Grow or shrink the scale table to ``new_num_tokens`` rows.

        Does nothing when ``reparameterize`` is false. New rows are filled
        with zeros; shrinking keeps the leading rows.

        Args:
            new_num_tokens: Target number of rows. Defaults to the current
                number of rows of ``self.weight``.
        """
        if not self.reparameterize:
            return

        if new_num_tokens is None:
            new_num_tokens = len(self.weight)

        # The scale table is a separate module from the token embedding
        # weight, so it has to be resized explicitly to stay in sync.
        scales = self.scales.weight.data
        old_num_tokens, old_dim = scales.shape
        if new_num_tokens > old_num_tokens:
            scales = torch.cat(
                [scales, scales.new_zeros(new_num_tokens - old_num_tokens, old_dim)]
            )
        elif new_num_tokens < old_num_tokens:
            scales = scales[:new_num_tokens]
        self.scales.weight.data = scales
        # Keep the module metadata consistent with the new weight shape.
        self.scales.num_embeddings = new_num_tokens
