from dataclasses import dataclass
from typing import Optional

import torch

from XXX.uib.losses.very_approx_regularizers import get_entropy, get_batch_entropy
from XXX.uib.modules.summarizer import EntropySummarizer, IqBase
from XXX.uib.utils.constants import flt_null_threshold
import XXX.uib.information_quantities as iq
from XXX.uib.utils.safe_module import SafeModule


@dataclass
class L2AndVarianceIqBase(IqBase):
    """Entropy and L2 estimates produced by ``L2AndVarianceSummarizer.compute_iq_base``.

    Fields are filled positionally by the summarizer, so their order is part of the interface.
    """

    H_Y: float
    H_Z: float
    H_Z__Y: float
    H_Z__X: float

    # Entropies estimated from the global Var[Z|Y] / Var[Z|X] rather than from an
    # expectation over per-label / per-input local entropies.
    H_global_Z__Y: float
    H_global_Z__X: float

    L2_Z: float
    L2_mean_Z__X: float

    def get_iq(self, information_quantity: torch.Tensor):
        """Evaluate ``information_quantity`` using the entropy estimates stored here.

        ``information_quantity`` is a vector in the basis defined by the ``iq`` module. Y is
        categorical and Z is continuous. The available terms are H(Y), H(Z), H(Z|Y) and H(Z|X),
        with H(Y|X) taken to be 0.

        Returns:
            A Python float: the coefficient-weighted sum of H_Y, H_Z, H_Z__Y and H_Z__X.
        """
        # Coefficient read off along iq.H_Y_Z (the original author calls this an orthogonal
        # projection). Only H(Z|Y) is directly estimable, so that component is removed before
        # the remaining coefficients are read off.
        coef_Z__Y = iq.H_Y_Z @ information_quantity
        residual = information_quantity - coef_Z__Y * iq.H_Z__Y

        coef_Y = iq.H_Y @ residual
        coef_Z = iq.H_Z @ residual
        coef_Z__X = iq.H_Z__X @ residual

        weighted = coef_Y * self.H_Y + coef_Z * self.H_Z + coef_Z__Y * self.H_Z__Y + coef_Z__X * self.H_Z__X
        return float(weighted)


class L2AndVarianceSummarizer(SafeModule, EntropySummarizer):
    """Keeps track of relevant quantities, so we can compute any kind of information quantity after the fact.

    ``fit`` accumulates running sums of first and second moments of the latent Z. These cover
    Z overall, Z per label y, and the squares of the per-input means, plus a running sum of
    per-input local entropies. ``compute_iq_base`` turns these sums into the entropy and L2
    estimates stored in an ``L2AndVarianceIqBase``, which ``get_iq_base`` caches.
    """

    # Running sum over inputs of the per-input mean (over dim 1) of Z^2; shape (in_capacity,).
    total_squared_Z: torch.Tensor
    # Running sum over inputs of the per-input mean (over dim 1) of Z; shape (in_capacity,).
    total_Z: torch.Tensor
    # Running sum over inputs of the squared per-input mean of Z; shape (in_capacity,).
    total_squared_mean_Z__X: torch.Tensor
    # Per-label running sums of the per-input mean of Z^2; shape (out_capacity, in_capacity).
    total_squared_Z__Y_y_z: torch.Tensor
    # Per-label running sums of the per-input mean of Z; shape (out_capacity, in_capacity).
    total_Z__Y_y_z: torch.Tensor
    # Number of inputs seen per label; shape (out_capacity,).
    num_y: torch.Tensor
    # Total number of inputs seen.
    num_x: int

    # Running sum of per-input local entropies. Starts as 0 and becomes a 0-d tensor after ``fit``.
    total_H_Z__X: float

    # Cached result of ``compute_iq_base``; ``None`` whenever it needs recomputing.
    iq_base: Optional[IqBase]

    in_capacity: int
    out_capacity: int

    def __init__(self, in_capacity: int, out_capacity: int, dtype=torch.float64, device="cpu"):
        """Allocate zeroed accumulators.

        Args:
            in_capacity: size of the latent dimension (last dim of the latents passed to ``fit``).
            out_capacity: number of label slots; labels passed to ``fit`` index into these.
            dtype: dtype of the accumulators.
            device: device of the accumulators.
        """
        super().__init__(dtype, device)

        self.out_capacity = out_capacity
        self.in_capacity = in_capacity

        self.num_x = 0
        self.total_H_Z__X = 0
        self.iq_base = None

        self.total_squared_Z = torch.zeros(in_capacity, dtype=dtype, device=device, requires_grad=False)
        self.total_Z = torch.zeros(in_capacity, dtype=dtype, device=device, requires_grad=False)
        self.total_squared_mean_Z__X = torch.zeros(in_capacity, dtype=dtype, device=device, requires_grad=False)
        self.total_squared_Z__Y_y_z = torch.zeros(
            (out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False
        )
        self.total_Z__Y_y_z = torch.zeros((out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        self.num_y = torch.zeros(out_capacity, dtype=dtype, device=device, requires_grad=False)

    def reset(self):
        """Zero every accumulator in place and drop the cached ``iq_base``."""
        with torch.no_grad():
            self.total_squared_Z.zero_()
            self.total_Z.zero_()
            self.total_squared_mean_Z__X.zero_()
            self.total_squared_Z__Y_y_z.zero_()
            self.total_Z__Y_y_z.zero_()
            self.num_y.zero_()
            self.num_x = 0
            self.total_H_Z__X = 0
            self.iq_base = None

    def fit(self, latent_x_k_z: torch.Tensor, labels_x: torch.Tensor):
        """Accumulate statistics from one batch and invalidate the cached ``iq_base``.

        Args:
            latent_x_k_z: latent values. They are averaged over dim 1 and summed over dim 0.
                Presumably shaped (inputs x, samples k, latent z); TODO confirm.
            labels_x: integer labels, one per entry of dim 0 of ``latent_x_k_z``. They are used
                directly as row indices, so they must lie in ``[0, out_capacity)``.
        """
        latent_x_k_z = self.convert_tensor(latent_x_k_z, non_blocking=True)
        labels_x = self.convert_tensor(labels_x, device_only=True, non_blocking=True)

        self.iq_base = None

        with torch.no_grad():
            squared_x_k_z = latent_x_k_z ** 2

            mean_squared_Z__X_x_z = torch.mean(squared_x_k_z, dim=1)
            self.total_squared_Z += torch.sum(mean_squared_Z__X_x_z, dim=0)

            mean_Z__X_x_z = torch.mean(latent_x_k_z, dim=1)
            self.total_Z += torch.sum(mean_Z__X_x_z, dim=0)

            self.total_squared_mean_Z__X += torch.sum(mean_Z__X_x_z ** 2, dim=0)

            # Per-input variance across dim 1, turned into a per-input entropy.
            local_variance_Z__X_x_z = mean_squared_Z__X_x_z - mean_Z__X_x_z ** 2
            local_entropy_Z__X_x = get_batch_entropy(local_variance_Z__X_x_z)
            self.total_H_Z__X += torch.sum(local_entropy_Z__X_x, dim=0)

            labels, counts = torch.unique(labels_x, return_counts=True)

            for label, count in zip(labels, counts):
                # Build the selection mask once and reuse it for both accumulators.
                is_label_x = labels_x == label
                self.total_squared_Z__Y_y_z[label] += torch.sum(mean_squared_Z__X_x_z[is_label_x], dim=0)
                self.total_Z__Y_y_z[label] += torch.sum(mean_Z__X_x_z[is_label_x], dim=0)
                self.num_y[label] += count

            self.num_x += len(labels_x)

    def compute_iq_base(self):
        """Compute entropy and L2 estimates from the accumulated sums and store them in ``self.iq_base``.

        With no data seen, every estimate is 0.0.
        """
        if self.num_x == 0:
            self.iq_base = L2AndVarianceIqBase(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
            return

        p_Y = self.num_y / self.num_x
        I_Y = -torch.log(p_Y)
        # Zero-probability labels contribute nothing (avoids 0 * inf = NaN).
        I_Y[p_Y <= flt_null_threshold] = 0.0
        H_Y = torch.sum(p_Y * I_Y)

        variance_Z = self.total_squared_Z / self.num_x - (self.total_Z / self.num_x) ** 2
        H_Z = get_entropy(variance_Z)

        variance_Z__X = self.total_squared_Z / self.num_x - self.total_squared_mean_Z__X / self.num_x
        global_H_Z__X = get_entropy(variance_Z__X)

        H_Z__X = self.total_H_Z__X / self.num_x

        # Only labels that were actually observed have defined conditional statistics. An unseen
        # label slot would give 0/0 = NaN here, and weighting that NaN by p_Y == 0 still yields
        # NaN. Since num_x > 0, at least one label has been seen.
        seen_y = self.num_y > 0
        num_seen_y = self.num_y[seen_y][:, None]
        p_seen_Y = p_Y[seen_y]
        variance_Z__Y_y_z = (
            self.total_squared_Z__Y_y_z[seen_y] / num_seen_y - (self.total_Z__Y_y_z[seen_y] / num_seen_y) ** 2
        )
        global_H_Z__Y = get_entropy(torch.sum(variance_Z__Y_y_z * p_seen_Y[:, None], dim=0))

        H_Z__Y_y = get_batch_entropy(variance_Z__Y_y_z)
        H_Z__Y = torch.sum(H_Z__Y_y * p_seen_Y)

        L2_Z = torch.sum(self.total_squared_Z / self.num_x)
        L2_mean_Z__X = torch.sum(self.total_squared_mean_Z__X / self.num_x)

        self.iq_base = L2AndVarianceIqBase(
            H_Y.item(),
            H_Z.item(),
            H_Z__Y.item(),
            H_Z__X.item(),
            global_H_Z__Y.item(),
            global_H_Z__X.item(),
            L2_Z.item(),
            L2_mean_Z__X.item(),
        )

    def get_iq_base(self) -> L2AndVarianceIqBase:
        """Return the cached estimates, computing them first if the cache is empty."""
        if self.iq_base is None:
            self.compute_iq_base()
        return self.iq_base
