from typing import List, Optional

import torch

from XXX.uib.modules.decoder_interface import DecoderInterface
from experiments.models.zero_entropy_noise import zero_entropy_noise_var


class ClusterDecoder(DecoderInterface):
    """Decoder that scores each class by the negative distance to its running mean.

    One mean encoding per class is kept in the ``means_Y_Z`` buffer. ``fit``
    blends the per-class mean of a new batch into that buffer, and
    ``safe_forward`` returns the negated p-norm distance (``torch.cdist``
    default) between each prediction and every class mean, so that a softmax
    over the result behaves like a softmin over distances.

    Args:
        in_capacity: Size of one encoding vector (the ``Z`` axis).
        out_capacity: Number of classes (the ``Y`` axis).
        dtype: Dtype of the ``means_Y_Z`` buffer; also forwarded to the base class.
        device: Device of the ``means_Y_Z`` buffer; also forwarded to the base class.
    """

    # (out_capacity, in_capacity) buffer holding one running mean per class.
    means_Y_Z: torch.Tensor

    in_capacity: int
    out_capacity: int

    def __init__(self, in_capacity, out_capacity, dtype=torch.float64, device=None):
        super().__init__(dtype=dtype, device=device)

        self.out_capacity = out_capacity
        self.in_capacity = in_capacity

        self.register_buffer(
            "means_Y_Z", torch.zeros((out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        )

    def safe_forward(self, predictions: torch.Tensor) -> torch.Tensor:
        """Return the negative distance from each prediction to each class mean.

        Args:
            predictions: Tensor indexed as ``(batch, in_capacity)``.

        Returns:
            A float64 tensor of shape ``(batch, out_capacity)``; larger values
            mean the prediction is closer to that class mean.
        """
        # Cast both operands to double so they always share a dtype, even when
        # the buffer was created with a non-default dtype.
        means_1_Y_Z = self.means_Y_Z[None, ...].double()
        predictions_B_1_Z = predictions[:, None, :].double()
        dist_B_Y_1 = torch.cdist(means_1_Y_Z, predictions_B_1_Z)
        # Softmax will turn this into a softmin.
        return -dist_B_Y_1[:, :, 0]

    def reset(self) -> None:
        """Set every class mean back to zero."""
        with torch.no_grad():
            self.means_Y_Z.zero_()

    def fit(self, encodings: torch.Tensor, targets: torch.Tensor) -> None:
        """Blend the per-class mean of a batch into the running class means.

        Classes that have no sample in the batch keep their current mean; the
        mean of an empty selection is NaN and would otherwise poison the
        buffer for every later update.

        Args:
            encodings: Batch of encodings, indexed by sample along dim 0.
            targets: Class index of every sample in ``encodings``.
        """
        with torch.no_grad():
            for y in range(self.out_capacity):
                y_encodings = encodings[targets == y]
                if len(y_encodings) == 0:
                    continue

                y_new_mean_Z = torch.mean(y_encodings.double(), dim=0)
                # TODO: fix this! (note carried over from the original author;
                # presumably about this ad-hoc weighting, which gives the old
                # mean a weight of 0.90 / 1.90 and the batch mean 1 / 1.90.)
                # Written in place so the registered buffer keeps its identity,
                # dtype and device instead of being rebound to a new tensor.
                self.means_Y_Z[y] = (0.90 * self.means_Y_Z[y] + y_new_mean_Z) / 1.90


class GaussianMixtureDecoder(DecoderInterface):
    """Decoder that fits one multivariate Gaussian per class to the encodings.

    ``fit`` accumulates, per class, the sample count, the sum of encodings and
    the sum of their outer products. ``update_gaussians`` turns those running
    statistics into a batch of ``MultivariateNormal`` distributions (one per
    class), and ``safe_forward`` returns the per-class log-density of each
    prediction. If no distribution can be built, ``safe_forward`` returns
    all-zero scores instead.

    Args:
        in_capacity: Size of one encoding vector (the ``Z`` axis).
        out_capacity: Number of classes (the ``Y`` axis).
        dtype: Dtype of the floating-point statistic buffers; also forwarded
            to the base class.
        device: Device of all buffers; also forwarded to the base class.
    """

    # (out_capacity, in_capacity): per-class sum of encodings.
    sum_Y_Z: torch.Tensor
    # (out_capacity, in_capacity, in_capacity): per-class sum of z z^T.
    sum_outer_product_Y_Z_Z: torch.Tensor
    # (out_capacity,): per-class number of accumulated encodings (long).
    num_samples_Y: torch.Tensor

    # Distributions built from the current statistics, or None when they have
    # to be (re)built or could not be built.
    cached_gaussians: Optional[torch.distributions.MultivariateNormal]

    in_capacity: int
    out_capacity: int

    def __init__(self, in_capacity, out_capacity, dtype=torch.float64, device=None):
        super().__init__(dtype=dtype, device=device)

        self.out_capacity = out_capacity
        self.in_capacity = in_capacity

        self.register_buffer(
            "sum_Y_Z", torch.zeros((out_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        )
        self.register_buffer(
            "sum_outer_product_Y_Z_Z",
            torch.zeros((out_capacity, in_capacity, in_capacity), dtype=dtype, device=device, requires_grad=False)
        )
        self.register_buffer(
            "num_samples_Y", torch.zeros((out_capacity,), dtype=torch.long, device=device, requires_grad=False)
        )

        self.cached_gaussians = None

    def fit(self, encodings: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate the sufficient statistics of a batch, class by class.

        Samples whose target is not in ``range(out_capacity)`` match no class
        and are not accumulated.

        Args:
            encodings: Either ``(batch, in_capacity)`` or, for stochastic
                encoders, a 3-D tensor whose first two dims are flattened so
                that every draw counts as one sample of its target class.
            targets: Class index of every entry along dim 0 of ``encodings``.
        """
        # The statistics are about to change, so the distributions are stale.
        self.cached_gaussians = None

        with torch.no_grad():
            for y in range(self.out_capacity):
                y_encodings_B_ = encodings[targets == y].double()
                # Handle stochasticity
                if y_encodings_B_.dim() == 3:
                    y_encodings_B_Z = y_encodings_B_.flatten(0, 1)
                else:
                    y_encodings_B_Z = y_encodings_B_

                # Count exactly the rows that go into the sums below, so the
                # count stays consistent with them for stochastic encodings.
                self.num_samples_Y[y] += len(y_encodings_B_Z)
                self.sum_Y_Z[y] += torch.sum(y_encodings_B_Z, dim=0)
                # X^T X equals the sum of the per-sample outer products and is
                # a zero matrix when the class has no sample in this batch.
                self.sum_outer_product_Y_Z_Z[y] += y_encodings_B_Z.t() @ y_encodings_B_Z

    def update_gaussians(self) -> None:
        """Rebuild ``cached_gaussians`` from the accumulated statistics.

        The covariance of each class is ``E[z z^T] - E[z] E[z]^T`` plus a
        small multiple of the identity. ``cached_gaussians`` is left as
        ``None`` when any class has no samples yet; if the Cholesky
        factorisation fails it is left unchanged (``None`` when called via
        ``safe_forward``).

        Raises:
            RuntimeError: Any runtime error raised while building the
                distribution whose message does not mention ``cholesky``.
        """
        if bool((self.num_samples_Y == 0).any()):
            # Dividing by a zero count would produce NaN parameters.
            self.cached_gaussians = None
            return

        mean_Y_Z = self.sum_Y_Z / self.num_samples_Y[:, None]
        mean_outer_Y_Z_Z = self.sum_outer_product_Y_Z_Z / self.num_samples_Y[:, None, None]
        outer_of_mean_Y_Z_Z = mean_Y_Z[:, :, None] @ mean_Y_Z[:, None, :]

        covariance_Y_Z_Z = mean_outer_Y_Z_Z - outer_of_mean_Y_Z_Z

        # Tiny diagonal jitter to help keep the covariance positive definite.
        jitter_Z_Z = torch.eye(covariance_Y_Z_Z.shape[1],
                               device=covariance_Y_Z_Z.device,
                               dtype=covariance_Y_Z_Z.dtype) * 1e-9
        try:
            self.cached_gaussians = torch.distributions.MultivariateNormal(
                loc=mean_Y_Z,
                covariance_matrix=covariance_Y_Z_Z + jitter_Z_Z
            )
        except RuntimeError as error:
            # Best effort: a failed Cholesky factorisation (on any device)
            # leaves the cache as it was; everything else is a real error.
            if "cholesky" not in str(error):
                raise

    def safe_forward(self, predictions_B_Z: torch.Tensor) -> torch.Tensor:
        """Return the per-class Gaussian log-density of each prediction.

        Args:
            predictions_B_Z: ``(batch, in_capacity)``, or
                ``(batch, 1, in_capacity)`` whose singleton dim is dropped.

        Returns:
            A ``(batch, out_capacity)`` tensor of log-densities, or all zeros
            of that shape when no distributions are available.
        """
        if predictions_B_Z.dim() == 3:
            assert predictions_B_Z.shape[1] == 1
            predictions_B_Z = predictions_B_Z[:, 0, :]

        if self.cached_gaussians is None:
            self.update_gaussians()

        if self.cached_gaussians is None:
            # No usable distributions: score every class equally.
            return torch.zeros((len(predictions_B_Z), self.out_capacity), dtype=predictions_B_Z.dtype,
                               device=predictions_B_Z.device)

        # The inserted dim broadcasts each prediction against all classes.
        return self.cached_gaussians.log_prob(predictions_B_Z[:, None, :])

    def reset(self) -> None:
        """Clear all accumulated statistics and the cached distributions."""
        # Drop the cache too; otherwise safe_forward would keep using the
        # distributions fitted before the reset.
        self.cached_gaussians = None
        with torch.no_grad():
            self.sum_Y_Z.zero_()
            self.sum_outer_product_Y_Z_Z.zero_()
            self.num_samples_Y.zero_()
