import torch

from XXX.uib.modules.decoder_interface import DecoderInterface
from XXX.uib.utils.constants import dbl_null_threshold
from XXX.uib.utils.safe_module import SafeModule


class CategoricalDecoder(DecoderInterface):
    """Count-based decoder relating categorical encodings Z to labels Y.

    The decoder keeps a running ``(out_capacity, in_capacity)`` table ``c_Y_Z`` of soft
    co-occurrence counts: rows index labels, columns index encoding categories. The table is
    accumulated by :meth:`fit`, cleared by :meth:`reset`, and normalized on demand by
    :meth:`get_p_Y_Z` (joint) and :meth:`get_p_Y__Z` (column-wise conditional).
    """

    # Running soft count table, registered as a buffer in ``__init__``.
    c_Y_Z: torch.Tensor

    in_capacity: int
    out_capacity: int

    def __init__(self, in_capacity: int, out_capacity: int, dtype=torch.float64, device=None):
        """Create a decoder with an all-zero count table.

        Args:
            in_capacity: number of encoding categories (columns of ``c_Y_Z``).
            out_capacity: number of labels (rows of ``c_Y_Z``).
            dtype: dtype of the count table; forwarded to the base class.
            device: device of the count table; forwarded to the base class.
        """
        super().__init__(dtype=dtype, device=device)

        self.out_capacity = out_capacity
        self.in_capacity = in_capacity

        self.register_buffer(
            "c_Y_Z", torch.zeros((out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        )

    def safe_forward(self, encodings_x_z):
        """Always raise: the forward pass is no longer served by this class.

        Raises:
            ValueError: unconditionally; per the message, CategoricalEncoderDecoder does the decoding.
        """
        # Former behavior, kept for reference: normalize c_Y_Z column-wise into p(Y|Z) and
        # return ``encodings_x_z @ p_Y__Z.t()`` (predictions of shape BxC).
        raise ValueError("This shouldn't called anymore---CategoricalEncoderDecoder does the lifting!")

    def reset(self):
        """Zero the accumulated count table in place."""
        with torch.no_grad():
            self.c_Y_Z.zero_()

    def train(self, mode: bool = True):
        """Switch training mode, clearing the counts on an eval -> train transition.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``torch.nn.Module.train`` contract so that chained calls
            such as ``model = model.eval()`` keep working.
        """
        # Only a genuine eval -> train switch discards the statistics; calling train()
        # repeatedly while already in training mode keeps them.
        if self.training is False and mode is True:
            self.reset()
        super().train(mode)
        return self

    def fit(self, encodings: torch.Tensor, labels_x: torch.Tensor):
        """Accumulate soft label/encoding co-occurrence counts from one batch.

        Args:
            encodings: unnormalized scores; softmax is applied over the last dimension.
                Presumably of shape ``(N, in_capacity)`` so that the result matches
                ``c_Y_Z`` -- TODO confirm against callers.
            labels_x: integer class labels; flattened to length ``N`` before use.
        """
        with torch.no_grad():
            normalized_encodings = torch.nn.functional.softmax(encodings, dim=-1, dtype=self.dtype)

            # (out_capacity, N) boolean mask: row y marks the samples whose label equals y.
            class_ids = torch.arange(self.out_capacity, device=labels_x.device).reshape((-1, 1))
            i_y_label = class_ids == labels_x.reshape((1, -1))

            # Cast the mask to the encodings' dtype/device instead of hard-coding double,
            # so the matmul also works when the decoder was built with a non-float64 dtype.
            i_y_label = i_y_label.to(dtype=normalized_encodings.dtype, device=normalized_encodings.device)
            new_c_Y_Z = i_y_label @ normalized_encodings

        self.update_c_Y_Z(new_c_Y_Z)

    def update_c_Y_Z(self, new_c_Y_Z):
        """Add ``new_c_Y_Z`` to the running count table in place (no gradient tracking)."""
        with torch.no_grad():
            self.c_Y_Z += new_c_Y_Z

    def get_p_Y_Z(self):
        """Return the joint distribution: counts divided by their grand total.

        Returns an all-zero table (``c_Y_Z * 0``) when the total count is zero, avoiding a
        division by zero.
        """
        C = torch.sum(self.c_Y_Z)
        if C != 0.0:
            return self.c_Y_Z / C
        return self.c_Y_Z * 0

    def get_p_Y__Z(self):
        """Return the conditional p(Y|Z): each column of the counts divided by its column sum.

        NOTE(review): a column whose sum is zero is divided by zero here, so its entries will
        not be finite -- confirm callers handle or never hit that case.
        """
        return self.c_Y_Z / self.c_Y_Z.sum(dim=0, keepdim=True)

    def decoder_uncertainty(self):
        """Return the unweighted sum of ``-log p(y|z)`` over the cells of the count table.

        Cells whose joint probability is at or below ``dbl_null_threshold`` contribute zero.
        """
        p_Y_Z = self.get_p_Y_Z()
        # Marginal over labels, kept 2-D so it broadcasts against the joint table.
        p_1_Z = torch.sum(p_Y_Z, dim=0, keepdim=True)

        # -log p(y, z) + log p(z) == -log p(y | z)
        I_p_Y_Z = -torch.log(p_Y_Z) + torch.log(p_1_Z)
        # Overwrite the (near-)empty cells, whose logs are not meaningful, with zero.
        I_p_Y_Z[p_Y_Z <= dbl_null_threshold] = 0.0
        return torch.sum(I_p_Y_Z)


class CategoricalPermutationDecoder(CategoricalDecoder):
    """Decoder that maps each encoding category to its single highest-count label."""

    def safe_forward(self, encoding):
        """Return log-scores that put (almost) all mass on the decoded label.

        The decoded label index is one-hot encoded over ``out_capacity`` classes, a small
        constant is added so the logarithm stays finite, and the log is returned: roughly
        ``0`` at the decoded label and ``log(1e-22)`` everywhere else.
        """
        floor = 1e-22
        label_index = self.decode_index(encoding)
        indicator = torch.nn.functional.one_hot(label_index, self.out_capacity)
        return torch.log(floor + indicator)

    def decode_index(self, predictions):
        """Return, per sample, the label with the highest count for its top-scoring category.

        Args:
            predictions: scores over encoding categories; argmax is taken over the last dim.
        """
        # Highest-count label for every encoding category (one entry per column of c_Y_Z).
        best_label_per_category = self.c_Y_Z.argmax(dim=0, keepdim=False)
        # Top-scoring encoding category of each sample.
        top_category = predictions.argmax(dim=-1)
        return best_label_per_category[top_category]
