from dataclasses import dataclass
from typing import Optional

import torch

from XXX.uib.modules.summarizer import EntropySummarizer
from XXX.uib.utils.constants import dbl_null_threshold
from XXX.uib.utils.safe_module import SafeModule
import XXX.uib.information_quantities as iq
from XXX.uib.utils import nonparametric_mutual_info
from XXX.uib.utils import torch_kraskov_entropy

import numpy as np

from XXX.progress_bar import with_progress_bar

from XXX.uib.utils.tensor_chain import CpuTensorChain


@dataclass
class Estimate:
    """A point estimate paired with its standard deviation."""

    mean: float
    stddev: float

    def __str__(self):
        """Render as ``mean+-stddev``, each with at most four significant digits."""
        rendered = (format(self.mean, ".4"), format(self.stddev, ".4"))
        return "+-".join(rendered)


@dataclass
class ContinuousIqBase:
    """Four base entropy estimates and the covariance between them.

    ``values`` stores, in this order, H(Y), H(Z), H(Z|Y) and H(Z|X) (exposed
    through the properties below). ``covariance`` is the 4x4 covariance of
    those four estimates; it is all zeros when none was supplied.
    """

    values: np.ndarray
    covariance: np.ndarray

    def __init__(self, values, covariance=None):
        """Store the four base values and, optionally, their covariance.

        Args:
            values: sequence whose first axis has length 4.
            covariance: 4x4 covariance matrix, or ``None`` for all zeros.
        """
        # ``np.float`` was only an alias of the builtin ``float`` and has been
        # removed from NumPy; the builtin yields the same float64 dtype.
        self.values = np.array(values, dtype=float)
        assert self.values.shape[0] == 4
        self.covariance = covariance if covariance is not None else np.zeros((4, 4))

    @staticmethod
    def from_multiple(values_list):
        """Build an instance from several estimates of the four base values.

        The mean over ``values_list`` becomes ``values``. With more than one
        entry the sample covariance (rows are observations) is used as
        ``covariance``; with a single entry the covariance defaults to zeros.
        """
        mean_value = np.mean(values_list, axis=0)
        if len(values_list) > 1:
            covariance = np.cov(values_list, rowvar=False)
        else:
            covariance = None
        return ContinuousIqBase(mean_value, covariance)

    @property
    def H_Y(self) -> float:
        return self.values[0]

    @property
    def H_Z(self) -> float:
        return self.values[1]

    @property
    def H_Z__Y(self) -> float:
        return self.values[2]

    @property
    def H_Z__X(self) -> float:
        return self.values[3]

    @staticmethod
    def _base_coefficients(information_quantity):
        """Return the coefficients ``(alpha_Y, alpha_Z, alpha_Z__Y, alpha_Z__X)``.

        These express ``information_quantity`` as a linear combination of the
        four base quantities. The argument is not modified.
        """
        # y is categorical/one_hot
        # Z is continuous

        # We can compute the following measures:
        # H(Y), H(Yhat), H(Yhat|Y), H(Yhat|X)
        # and
        # H(Y|X) = 0

        # Orthogonal projection
        alpha_Z__Y = iq.H_Y_Z @ information_quantity
        # We can only compute H_Z__Y easily. The rest of the code will figure out the difference.
        information_quantity = information_quantity - alpha_Z__Y * iq.H_Z__Y

        alpha_Y = iq.H_Y @ information_quantity
        alpha_Z = iq.H_Z @ information_quantity
        alpha_Z__X = iq.H_Z__X @ information_quantity

        return alpha_Y, alpha_Z, alpha_Z__Y, alpha_Z__X

    def get_iq_estimate(self, information_quantity: torch.Tensor):
        """Estimate ``information_quantity`` with an uncertainty.

        Returns:
            An ``Estimate`` whose mean is the coefficient-weighted sum of the
            base values and whose stddev is ``sqrt(alpha @ covariance @ alpha)``.
        """
        alpha = torch.stack(self._base_coefficients(information_quantity)).numpy()

        mean_iq = alpha @ self.values
        variance_iq = alpha @ self.covariance @ alpha

        return Estimate(mean_iq.item(), np.sqrt(variance_iq).item())

    def get_iq(self, information_quantity: torch.Tensor):
        """Return ``information_quantity`` as a float, without an uncertainty."""
        alpha_Y, alpha_Z, alpha_Z__Y, alpha_Z__X = self._base_coefficients(information_quantity)

        result = float(alpha_Y * self.H_Y + alpha_Z * self.H_Z + alpha_Z__Y * self.H_Z__Y + alpha_Z__X * self.H_Z__X)
        return result


class ContinuousLatentLabelEntropiesSummarizer(SafeModule, EntropySummarizer):
    """Keeps track of relevant quantities, so we can compute any kind of information quantity after the fact.

    Encodings and labels are accumulated through ``fit``. The four base
    entropies (H(Y), H(Z), H(Z|Y), H(Z|X)) are computed lazily by
    ``get_iq_base`` and cached until the next ``fit`` or ``reset``.
    """

    # Accumulated encodings; every appended tensor is 3-dimensional (asserted in ``fit``).
    # Presumably (inputs, samples per input, latent dim) - TODO confirm.
    encodings_x_k_Z: CpuTensorChain
    # Accumulated labels, one per entry along the first axis of the encodings.
    labels_x: CpuTensorChain
    # Cached result of ``compute_iq_base``; ``None`` while stale.
    iq_base: Optional[ContinuousIqBase]
    # When not ``None``, used as H(Z|X) in ``compute_iq_values_for`` instead of estimating it.
    fixed_Z__X: Optional[float]
    # Number of chunks the samples are split into in ``compute_iq_base``.
    num_splits: int
    # ``k`` handed to the entropy estimator (presumably the neighbour count - TODO confirm).
    k: int

    def __init__(self, dtype=torch.float64, device="cpu", *, fixed_Z__X=None, num_splits=6, k=1):
        """Create an empty summarizer.

        Args:
            dtype: forwarded to the parent constructor.
            device: forwarded to the parent constructor.
            fixed_Z__X: optional constant to use for H(Z|X).
            num_splits: number of sample chunks used by ``get_iq_base``.
            k: estimator parameter used by ``get_iq_base``.
        """
        super().__init__(dtype=dtype, device=device)

        self.k = k
        self.num_splits = num_splits
        self.fixed_Z__X = fixed_Z__X
        self.encodings_x_k_Z = CpuTensorChain.create()
        self.labels_x = CpuTensorChain.create()
        self.iq_base = None

    def safe_forward(self, encoding):
        """Ignore ``encoding``; this module produces no forward output."""
        return None

    def reset(self):
        """Drop all accumulated encodings/labels and the cached base values."""
        self.encodings_x_k_Z.reset()
        self.labels_x.reset()
        self.iq_base = None

    def fit(self, encodings_x_k_Z: torch.Tensor, labels_x: torch.Tensor):
        """Append a batch of encodings and labels and invalidate the cache.

        Args:
            encodings_x_k_Z: 3-dimensional tensor of encodings.
            labels_x: labels belonging to the first axis of the encodings.
        """
        assert encodings_x_k_Z.dim() == 3

        self.iq_base = None

        with torch.no_grad():
            self.labels_x.append(labels_x.detach())
            self.encodings_x_k_Z.append(encodings_x_k_Z.detach())

    def compute_iq_values_for(self, encodings_x_k_Z, labels_x, k):
        """Estimate ``[H(Y), H(Z), H(Z|Y), H(Z|X)]`` for one set of samples.

        Args:
            encodings_x_k_Z: 3-dimensional CPU tensor of encodings.
            labels_x: labels indexed like the first axis of the encodings.
            k: forwarded as ``k`` to ``nonparametric_mutual_info.entropy``.

        Returns:
            A list of four Python numbers in the order given above.
        """
        # y is categorical/one_hot
        # Z is continuous

        # We can compute the following measures:
        # H(Y), H(Yhat), H(Yhat|Y), H(Yhat|X)
        # and
        # H(Y|X) = 0

        labels, label_count = torch.unique(labels_x, return_counts=True)
        label_count = self.convert_tensor(label_count, non_blocking=True)

        # Empirical label distribution; p_Y[i] belongs to labels[i].
        p_Y = label_count / label_count.sum()
        I_Y = -torch.log(p_Y)
        I_Y[p_Y <= dbl_null_threshold] = 0.0
        H_Y = torch.sum(p_Y * I_Y).item()

        flattened_encodings_x_k_Z = encodings_x_k_Z.flatten(0, 1)
        H_Z = nonparametric_mutual_info.entropy(flattened_encodings_x_k_Z.numpy(), k=k)

        # H(Z|Y): label-probability-weighted entropy of each label's encodings.
        # Non-finite per-label estimates are skipped.
        H_Z__Y = 0.0
        for i, label in enumerate(labels):
            specific_encodings_X_Z = encodings_x_k_Z[labels_x == label].flatten(0, 1)
            specific_h_Z__Y = nonparametric_mutual_info.entropy(specific_encodings_X_Z.numpy(), k=k)
            if np.isfinite(specific_h_Z__Y):
                H_Z__Y += (p_Y[i] * specific_h_Z__Y).item()

        # H(Z|X): either the configured constant, or the mean per-input entropy
        # when there is more than one sample per input; otherwise 0.
        H_Z__X = 0.0
        if self.fixed_Z__X is not None:
            H_Z__X = self.fixed_Z__X
        elif encodings_x_k_Z.shape[1] > 1:
            length = len(encodings_x_k_Z)
            for specific_encodings_k_Z in encodings_x_k_Z:
                specific_h_Z__X = nonparametric_mutual_info.entropy(specific_encodings_k_Z.numpy(), k=k)
                if np.isfinite(specific_h_Z__X):
                    H_Z__X += specific_h_Z__X
            # Skipped (non-finite) inputs still count in the denominator.
            H_Z__X /= length

        return [H_Y, H_Z, H_Z__Y, H_Z__X]

    def compute_iq_base(self, k, num_splits):
        """Compute and cache ``self.iq_base`` from the accumulated samples.

        All (input, sample) pairs are shuffled and split into up to
        ``num_splits`` disjoint chunks; the base values are estimated per chunk
        and combined with ``ContinuousIqBase.from_multiple``. With no data the
        base values are all zero.
        """
        if not len(self.encodings_x_k_Z):
            self.iq_base = ContinuousIqBase([0.0, 0.0, 0.0, 0.0])
            return

        encodings_x_k_Z = self.encodings_x_k_Z.get()
        labels_x = self.labels_x.get()

        encodings_xk_Z = encodings_x_k_Z.flatten(0, 1)
        # Repeat every label once per sample so it lines up with the flattened
        # encodings. This is loop-invariant, so it is built once up front.
        labels_xk = labels_x[:, None].expand(*encodings_x_k_Z.shape[0:2]).flatten()

        values_list = []
        num_total_samples = len(encodings_xk_Z)
        all_indices = torch.randperm(num_total_samples)

        # How do I bootstrap estimates?
        for indices in with_progress_bar(torch.chunk(all_indices, num_splits)):
            # Need to sample *without* replacement because otherwise the estimator blows up ;-/

            # Each sampled row becomes its own "input" with a single sample.
            sampled_encodings_x_k_Z = encodings_xk_Z[indices][:, None, :]
            sampled_labels_x = labels_xk[indices]

            values = self.compute_iq_values_for(sampled_encodings_x_k_Z, sampled_labels_x, k)
            values_list.append(values)

        self.iq_base = ContinuousIqBase.from_multiple(values_list)

    def get_iq_base(self):
        """Return the cached base values, computing them first if stale."""
        if self.iq_base is None:
            self.compute_iq_base(self.k, self.num_splits)
        return self.iq_base

    def get_iq(self, information_quantity: torch.Tensor, k=1):
        """Evaluate ``information_quantity`` directly on all accumulated samples.

        Args:
            information_quantity: coefficient vector of the requested quantity;
                it is cloned, so the caller's tensor is left untouched.
            k: forwarded as ``max_k`` to ``torch_kraskov_entropy.entropy``.

        Returns:
            ``0.0`` when nothing has been fitted, otherwise a tensor produced
            by ``self.convert_tensor``.

        Raises:
            NotImplementedError: if the quantity needs H(Z|X) and there is more
                than one sample per input.
        """
        information_quantity = self.convert_tensor(information_quantity, device_only=True, non_blocking=True).clone()

        if not len(self.encodings_x_k_Z):
            return 0.0

        # y is categorical/one_hot
        # Z is continuous

        # We can compute the following measures:
        # H(Y), H(Yhat), H(Yhat|Y), H(Yhat|X)
        # and
        # H(Y|X) = 0

        # Orthogonal projection
        alpha_Z__Y = iq.H_Y_Z @ information_quantity
        # We can only compute H_Z__Y easily. The rest of the code will figure out the difference.
        information_quantity -= alpha_Z__Y * iq.H_Z__Y

        alpha_Y = iq.H_Y @ information_quantity
        alpha_Z = iq.H_Z @ information_quantity

        alpha_Z__X = iq.H_Z__X @ information_quantity

        encodings_x_k_Z = self.encodings_x_k_Z.get()
        labels_x = self.labels_x.get()

        labels, label_count = torch.unique(labels_x, return_counts=True)
        label_count = self.convert_tensor(label_count, non_blocking=True)
        # p_Y[i] is the empirical probability of labels[i].
        p_Y = label_count / label_count.sum()

        result = 0.0
        if alpha_Y != 0.0:
            I_Y = -torch.log(p_Y)
            I_Y[p_Y <= dbl_null_threshold] = 0.0
            h_Y = torch.sum(p_Y * I_Y)
            result += alpha_Y * h_Y

        if alpha_Z != 0.0:
            flattened_encodings_x_k_Z = encodings_x_k_Z.flatten(0, 1)
            h_Z = torch_kraskov_entropy.entropy(flattened_encodings_x_k_Z, max_k=k)
            if np.isfinite(h_Z):
                result += alpha_Z * h_Z

        if alpha_Z__Y != 0.0:
            h_Z__Y = 0.0
            # p_Y is ordered like ``labels``, so it must be indexed by position
            # and not by the label value (labels need not be 0..C-1).
            for i, label in enumerate(labels):
                specific_encodings_X_Z = encodings_x_k_Z[labels_x == label].flatten(0, 1)
                # Pass the tensor, exactly as the H(Z) estimate above does.
                specific_h_Z__Y = torch_kraskov_entropy.entropy(specific_encodings_X_Z, max_k=k)
                if np.isfinite(specific_h_Z__Y):
                    h_Z__Y += p_Y[i] * specific_h_Z__Y
            result += alpha_Z__Y * h_Z__Y

        if alpha_Z__X != 0.0:
            # H(Z|X) is not estimated on this path: with a single sample per
            # input it contributes nothing, with several it is unsupported.
            # NOTE(review): ``fixed_Z__X`` is not consulted here, unlike in
            # ``compute_iq_values_for`` - confirm whether that is intended.
            num_samples = encodings_x_k_Z.shape[1]
            if num_samples > 1:
                raise NotImplementedError("This is too slow to be useful atm ! :(")

        return self.convert_tensor(torch.as_tensor(result))

    def get_p_Y(self, capacity):
        """Return the empirical label distribution as a vector of length ``capacity``.

        Entry ``p_Y[label]`` is the relative frequency of ``label`` among the
        accumulated labels; labels that never occurred stay at zero.
        """
        labels, label_count = torch.unique(self.labels_x.get(), return_counts=True)
        label_count = self.convert_tensor(label_count, non_blocking=True)
        p_Y = torch.zeros(capacity, dtype=self.dtype, device=self.tdevice)
        p_Y[labels] = label_count / label_count.sum()
        return p_Y


class OldContinuousLatentLabelEntropiesSummarizer(ContinuousLatentLabelEntropiesSummarizer):
    """Variant whose ``fit`` accepts encodings without the middle axis."""

    def fit(self, encodings_X_Z: torch.Tensor, labels_x: torch.Tensor):
        """Insert a size-1 axis at position 1, then delegate to the parent ``fit``."""
        encodings_with_sample_axis = encodings_X_Z[:, None, :]
        super().fit(encodings_with_sample_axis, labels_x)
