from dataclasses import dataclass
from typing import Optional

import torch

from XXX.uib.modules.summarizer import EntropySummarizer, IqBase
from XXX.uib.utils.constants import dbl_null_threshold
from XXX.uib.utils.safe_module import SafeModule
import XXX.uib.information_quantities as iq


@dataclass
class CategoricalIqBase(IqBase):
    """Precomputed entropies for categorical Y (labels), Z (latents) and X (inputs).

    An information quantity is given as a coefficient vector; ``get_iq`` projects that vector
    onto the basis vector of each entropy (from ``iq``) and returns the weighted sum of the
    stored entropy values as a Python float.
    """

    H_Y: float
    H_Z: float
    H_Y_Z: float
    H_Z__X: float
    H_Y__X: float

    def get_iq(self, information_quantity: torch.Tensor) -> float:
        """Evaluate ``information_quantity`` against the stored entropies."""
        basis_and_value = (
            (iq.H_Y, self.H_Y),
            (iq.H_Z, self.H_Z),
            (iq.H_Y_Z, self.H_Y_Z),
            (iq.H_Z__X, self.H_Z__X),
            (iq.H_Y__X, self.H_Y__X),
        )
        coefficients = [information_quantity @ basis for basis, _ in basis_and_value]
        weighted = [coef * value for coef, (_, value) in zip(coefficients, basis_and_value)]
        # Accumulate left to right starting from the first term (same order as a chained `+`).
        return float(sum(weighted[1:], weighted[0]))


# TODO: use the summarizer in the bayes decoder!
class CategoricalEntropiesSummarizer(SafeModule, EntropySummarizer):
    """Accumulates statistics over categorical latents Z and hard labels Y.

    ``fit`` accumulates the unnormalized joint ``p(Y, Z)`` and the summed per-sample
    entropy ``H(Z | X = x)``. Entropies, and linear combinations of them (information
    quantities), can then be computed after the fact.
    """

    # Unnormalized joint of labels and latents, shape (out_capacity, in_capacity);
    # p(Y, Z) = acc_p_Y_Z / num_x.
    acc_p_Y_Z: torch.Tensor
    # Sum over samples of H(Z | X = x); h_Z__X = acc_h_Z__X / num_x
    acc_h_Z__X: torch.Tensor
    # Number of samples passed to `fit` since construction or the last `reset`.
    num_x: int

    # Cached entropies; cleared by `fit` and `reset`, rebuilt lazily by `get_iq_base`.
    iq_base: Optional[IqBase]

    in_capacity: int
    out_capacity: int

    def __init__(self, in_capacity: int, out_capacity: int, dtype=torch.float64, device=None):
        super().__init__(dtype=dtype, device=device)

        self.out_capacity = out_capacity
        self.in_capacity = in_capacity
        self.num_x = 0
        self.iq_base = None

        self.register_buffer(
            "acc_p_Y_Z", torch.zeros((out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        )
        self.register_buffer("acc_h_Z__X", torch.zeros((), dtype=dtype, device=device, requires_grad=False))

    @staticmethod
    def _masked_entropy(p: torch.Tensor) -> torch.Tensor:
        """Return ``sum(p * -log(p))`` over all entries as a 0-d tensor.

        Entries with ``p <= dbl_null_threshold`` contribute 0; otherwise ``0 * -log(0)``
        would produce nan.
        """
        information = -torch.log(p)
        information[p <= dbl_null_threshold] = 0.0
        return torch.sum(p * information)

    def safe_forward(self, encoding):
        """Return ``log p(Y | X)`` by marginalizing the latent through ``p(Y | Z)``.

        Computes ``log sum_z p(y | z) * softmax(encoding)[z]`` for every row of ``encoding``.
        For a 2-D ``encoding`` of shape (B, in_capacity) the result has shape (B, out_capacity).
        """
        # assumes encoding holds latent logits, one row per sample — TODO confirm against callers
        normalized_predictions = torch.nn.functional.softmax(encoding, dim=-1, dtype=self.dtype)
        p_Y__Z = self.get_p_Y__Z()
        # (out, in) @ (in, B) -> (out, B), then transpose back to (B, out).
        return torch.log(p_Y__Z @ normalized_predictions.t()).t()

    def reset(self):
        """Clear all accumulated statistics and the cached entropies."""
        with torch.no_grad():
            self.acc_p_Y_Z.zero_()
            self.acc_h_Z__X.zero_()
            self.num_x = 0
            self.iq_base = None

    def fit(self, latent_x_k_z: torch.Tensor, labels_x: torch.Tensor):
        """Accumulate statistics from a batch of latent logits and hard labels.

        Args:
            latent_x_k_z: logits of shape (N, 1, in_capacity); only a single latent sample
                per input (dimension 1 of size 1) is supported.
            labels_x: one class label per sample. Labels outside ``[0, out_capacity)`` add
                nothing to ``acc_p_Y_Z`` but are still counted in ``num_x``.
        """
        assert latent_x_k_z.shape[1] == 1
        logits_X_Z = latent_x_k_z[:, 0, :]

        # New data invalidates any cached entropies.
        self.iq_base = None

        with torch.no_grad():
            logits_X_Z = logits_X_Z.to(dtype=self.dtype, device=self.tdevice)
            p_X_Z = torch.nn.functional.softmax(logits_X_Z, dim=1, dtype=self.dtype)

            # Keep the labels' own dtype: the comparison below does not need a float cast,
            # and casting to float32 would make large labels inexact.
            labels_x = labels_x.to(device=self.tdevice)

            # Indicator i_y_x[y, x] = (labels_x[x] == y), shape (out_capacity, N).
            # (one_hot(labels_x, num_classes=self.out_capacity) would raise on out-of-range
            # labels, whereas this comparison silently ignores them.)
            classes = torch.arange(self.out_capacity, device=labels_x.device)
            i_y_x = classes.reshape((-1, 1)) == labels_x.reshape((1, -1))
            # Cast to the summarizer's dtype (not hard-coded float64) so the matmul with
            # p_X_Z also works when dtype != torch.float64.
            self.acc_p_Y_Z += i_y_x.to(dtype=self.dtype) @ p_X_Z

            # Sum over samples of the per-sample entropy H(Z | X = x).
            self.acc_h_Z__X += self._masked_entropy(p_X_Z)
            self.num_x += len(labels_x)

    def get_p_Y_Z(self):
        """Return the empirical joint p(Y, Z); nan everywhere if nothing has been fit (0 / 0)."""
        return self.acc_p_Y_Z / self.num_x

    def get_h_Z__X(self):
        """Return the average conditional entropy H(Z | X); nan if nothing has been fit."""
        return self.acc_h_Z__X / self.num_x

    def get_p_Y__Z(self):
        """Return p(Y | Z) by normalizing each latent column of the joint.

        Columns whose total accumulated mass is zero come out as nan (0 / 0).
        """
        return self.acc_p_Y_Z / self.acc_p_Y_Z.sum(dim=0, keepdim=True)

    def compute_iq_base(self):
        """Compute and cache H(Y), H(Z), H(Y, Z) and H(Z | X) from the accumulated statistics.

        H(Y | X) is stored as 0 (``fit`` receives one hard label per sample). All entropies
        are 0 when no data has been fit.
        """
        if self.num_x == 0:
            self.iq_base = CategoricalIqBase(0.0, 0.0, 0.0, 0.0, 0.0)
            return

        p_Y_Z = self.get_p_Y_Z()

        H_Y = self._masked_entropy(torch.sum(p_Y_Z, dim=1, keepdim=False)).item()
        H_Y_Z = self._masked_entropy(p_Y_Z).item()
        H_Z = self._masked_entropy(torch.sum(p_Y_Z, dim=0, keepdim=False)).item()
        H_Z__X = self.get_h_Z__X().item()

        self.iq_base = CategoricalIqBase(H_Y, H_Z, H_Y_Z, H_Z__X, 0.0)

    def get_iq_base(self):
        """Return the cached entropies, computing them first if necessary."""
        if self.iq_base is None:
            self.compute_iq_base()
        return self.iq_base

    def get_iq(self, information_quantity: torch.Tensor):
        """Evaluate an information quantity directly from the accumulated joint p(Y, Z).

        ``information_quantity`` is projected onto the H(Y), H(Z), H(Y, Z) and H(Z | X)
        basis vectors from ``iq``; terms with a zero coefficient are skipped. The H(Y | X)
        coefficient is not used, consistent with ``compute_iq_base`` storing H(Y | X) as 0.

        Returns:
            ``0.0`` (a Python float) if nothing has been fit, otherwise a 0-d tensor.
        """
        # NOTE(review): the two branches return different types (float vs tensor); kept
        # as-is for backward compatibility with existing callers.
        if self.num_x == 0:
            return 0.0

        alpha_Y = information_quantity @ iq.H_Y
        alpha_Z = information_quantity @ iq.H_Z
        alpha_Y_Z = information_quantity @ iq.H_Y_Z
        alpha_Z__X = information_quantity @ iq.H_Z__X

        p_Y_Z = self.get_p_Y_Z()

        # Per-cell weighted information; sum(p_Y_Z * J_Y_Z) yields the weighted entropies.
        J_Y_Z = torch.zeros_like(p_Y_Z)
        if alpha_Y != 0:
            p_Y_1 = torch.sum(p_Y_Z, dim=1, keepdim=True)
            J_Y_Z += -alpha_Y * torch.log(p_Y_1)

        if alpha_Y_Z != 0:
            J_Y_Z += -alpha_Y_Z * torch.log(p_Y_Z)

        if alpha_Z != 0:
            p_1_Z = torch.sum(p_Y_Z, dim=0, keepdim=True)
            J_Y_Z += -alpha_Z * torch.log(p_1_Z)

        # Drop (near-)zero-probability cells, whose log terms are infinite.
        J_Y_Z[p_Y_Z <= dbl_null_threshold] = 0.0

        result = torch.sum(p_Y_Z * J_Y_Z)

        if alpha_Z__X != 0:
            h_Z__X = self.get_h_Z__X()
            result += alpha_Z__X * h_Z__X

        return result
