from typing import Tuple

import logging

import torch
import torch.nn as nn


class Codebook(nn.Module):
    """Codebook that maps each input vector to its nearest code vector.

    Holds ``size`` learnable code vectors of dimension ``dim`` in an
    ``nn.Embedding``. :meth:`forward` computes, for every input vector, the
    index of the nearest code (``torch.cdist`` distances + ``argmin``). When
    the codebook is active the input vectors are replaced by those codes;
    when deactivated the input is returned unchanged (but the indices are
    still computed).
    """

    def __init__(
        self,
        size: int,
        dim: int,
        detach_input_seq: bool = True,
        uniform_range: Tuple[float, float] = (-1.0, 1.0),
    ):
        """
        Args:
            size: number of codes in the codebook
            dim: code dimension
            detach_input_seq: if `True`, the input sequence is detached from
                the autograd graph before encoding
            uniform_range: `(low, high)` bounds of the uniform distribution
                used to initialize the codes
        """
        super().__init__()
        self.logger = logging.getLogger("codebook")
        self.size = size
        self.dim = dim
        self.detach_input_seq = detach_input_seq

        self.logger.info("Creating codebook with size: %d, dimension: %d", size, dim)

        self.codes = nn.Embedding(size, dim)
        self._reset_parameters(uniform_range)
        self.activate()

    def _reset_parameters(self, uniform_range: Tuple[float, float]):
        """Fill the code table from Uniform[uniform_range[0], uniform_range[1]]."""
        self.logger.info("Initializing with Uniform[%.2f, %.2f]", uniform_range[0], uniform_range[1])
        nn.init.uniform_(self.codes.weight, uniform_range[0], uniform_range[1])

    def deactivate(self):
        """Make :meth:`encode` return the input vectors instead of matched codes."""
        self.logger.debug("Deactivated codebook!")
        self._activate = False

    def activate(self):
        """Make :meth:`encode` replace input vectors with their matched codes."""
        self.logger.debug("Activated codebook!")
        self._activate = True

    def initial_codes(self, code_fp: str):
        """Overwrite the codebook with codes loaded from ``code_fp``.

        The file must contain a 2-D tensor of shape ``[k, dim]``:

        * ``k > size``: ``size`` rows are picked at random (without
          replacement) and copied into the codebook;
        * ``k <= size``: the ``k`` codes are copied into the first ``k`` rows;
          any remaining rows keep their current values.

        Args:
            code_fp: path to a file readable by ``torch.load``.

        Raises:
            ValueError: if the file does not hold a tensor of shape ``[k, dim]``.
        """
        self.logger.info("Loading from external codes...")
        # NOTE(review): torch.load unpickles the file -- only load trusted files.
        codes = torch.load(code_fp, map_location="cpu")
        if not isinstance(codes, torch.Tensor) or codes.dim() != 2 or codes.shape[1] != self.dim:
            shape = tuple(codes.shape) if isinstance(codes, torch.Tensor) else type(codes).__name__
            raise ValueError(f"external codes must have shape [k, {self.dim}], got {shape}")

        num_codes = codes.shape[0]
        if num_codes > self.size:
            self.logger.warning("Too much external codes, using random picked codes...")
            # Truncate the permutation first so only the kept rows are gathered.
            codes = codes[torch.randperm(num_codes)[: self.size]]
        elif num_codes < self.size:
            self.logger.warning(
                "Too few external codes (%d < %d), remaining codes keep their current values",
                num_codes,
                self.size,
            )
        with torch.no_grad():
            # Copy into an explicit row slice so copy_ can never broadcast a
            # short tensor over the whole table; copy_ also casts dtype/device.
            self.codes.weight[: codes.shape[0]].copy_(codes)

    def encode(self, seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Match every vector of ``seq`` to its nearest code.

        Args:
            seq: tensor of shape [n, bs, dim]
        Returns:
            ``(encoded, match)`` where ``encoded`` has shape [n, bs, dim] and
            holds the matched codes when the codebook is active (otherwise the
            input, detached if ``detach_input_seq``), and ``match`` has shape
            [n, bs] with the index of the nearest code for each vector.
        """
        if self.detach_input_seq:
            seq = seq.detach()
        n, bs = seq.shape[:2]
        # [n, bs, dim] -> [n * bs, dim]
        flat = seq.reshape(n * bs, self.dim)
        # [n * bs, self.size] distance matrix -> index of the closest code
        match = torch.cdist(flat, self.codes.weight).argmin(dim=1)
        if self._activate:
            flat = self.codes(match)
        return flat.reshape(n, bs, self.dim), match.reshape(n, bs)

    def forward(self, seq: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            seq: [n, bs, dim]
        Return:
            encoded sequence [n, bs, dim], and matched code for each input token [n, bs]
        Raises:
            ValueError: if ``seq`` is not 3-D or its last dimension is not ``dim``.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if seq.dim() != 3:
            raise ValueError(f"expected input of shape [n, bs, {self.dim}], got {tuple(seq.shape)}")
        if int(seq.shape[2]) != self.dim:
            raise ValueError(f"dimension {seq.shape[2]} not match to {self.dim}")
        return self.encode(seq)


