import math
from typing import Any, Tuple

import numpy as np  # type: ignore
import torch
from torch import nn
from torch.utils.data import DataLoader

from deep_deterministic_uncertainty.wide_resnet import wide_sn_resnet28_cifar

from torch import Tensor as T  # short alias for torch.Tensor used in type annotations


class DDU(nn.Module):
    """Feature extractor + linear head + one Gaussian density per class on the features.

    Fitting sequence implied by the methods below:
    ``update_centroids`` (over the data) -> ``compute_centroids`` ->
    ``update_covariance`` (over the data) -> ``compute_covariance`` ->
    ``invert_covariance`` -> ``tune`` (on held-out data) -> ``inference``.
    """

    def __init__(self, base: nn.Module, num_classes: int, h_dim: int = 128, s: float = 0.001, m: float = 0.999):
        """
        Args:
            base: feature extractor; must expose a ``name`` attribute and map inputs to
                (batch, h_dim) features.
            num_classes: number of classes (one Gaussian per class).
            h_dim: feature dimension produced by ``base``.
            s: scale of the identity written to ``prec`` by ``init_sigma_lambda``.
            m: stored on the instance; not used within this class.
        """
        super().__init__()
        self.base = base
        self.classes = num_classes
        self.s = s
        self.h_dim = h_dim

        # NOTE(review): `layers` is the same module object as `base`, so it is registered
        # twice and its parameters appear under both prefixes in state_dict(); left as-is
        # so existing checkpoints keep loading.
        self.layers = base
        self.name = self.base.name  # type: ignore
        self.total = 0  # not used within this class
        self.m = m

        self.out = nn.Linear(h_dim, num_classes)

        # Per-class Gaussian statistics, registered as buffers so they move with .to()
        # and are saved in state_dict().
        self.prec: torch.Tensor  # (classes, h_dim, h_dim) inverse covariances
        self.cov: torch.Tensor  # (classes, h_dim, h_dim) covariances (scatter sums before compute_covariance)
        self.centroids: torch.Tensor  # (classes, h_dim) means (feature sums before compute_centroids)
        self.class_counts: torch.Tensor  # (classes, 1) number of samples seen per class
        self.cov_logdets: torch.Tensor  # (classes, 1) log-determinants of the regularized covariances
        self.px_mean: torch.Tensor  # scalar, mean marginal log-density recorded by tune()
        self.px_std: torch.Tensor  # scalar, std of the marginal log-density recorded by tune()

        self.register_buffer("prec", torch.zeros(num_classes, h_dim, h_dim, requires_grad=False))
        self.register_buffer("centroids", torch.zeros(num_classes, h_dim, requires_grad=False))
        self.register_buffer("class_counts", torch.zeros(num_classes, 1, requires_grad=False))
        self.register_buffer("cov", torch.zeros(num_classes, h_dim, h_dim, requires_grad=False))
        self.register_buffer("cov_logdets", torch.zeros(num_classes, 1, requires_grad=False))

        self.register_buffer("px_mean", torch.tensor(0.0))
        self.register_buffer("px_std", torch.tensor(0.0))

    def init_sigma_lambda(self) -> None:
        """Reset ``prec`` to ``s * I`` for every class and ``cov`` to zeros."""
        eye = torch.eye(self.h_dim, self.h_dim, device=self.prec.device)
        self.prec = eye.unsqueeze(0).repeat(self.classes, 1, 1) * self.s  # type: ignore
        self.cov = torch.zeros((self.classes, self.h_dim, self.h_dim), device=self.cov.device, requires_grad=False)  # type: ignore

    def forward(self, x: torch.Tensor, update_prec: bool = False) -> torch.Tensor:
        """Return classification logits ``out(layers(x))``; ``update_prec`` is accepted but unused."""
        return self.out(self.layers(x))  # type: ignore

    def _class_mask(self, y: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Return a (classes, N) bool mask with ``mask[c, n] == (y[n] == c)``; out-of-range labels match nothing."""
        return y.unsqueeze(0) == torch.arange(self.classes, device=device).unsqueeze(1)

    def update_centroids(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Accumulate per-class feature sums and counts for one batch.

        Call this over the whole training set (after the model has been trained), then
        ``compute_centroids``, before updating the covariances.
        """
        with torch.no_grad():
            phi = self.layers(x)  # (N, h_dim)
            is_class = self._class_mask(y, x.device)  # (classes, N)
            # (classes, N) @ (N, h_dim) sums the features belonging to each class
            self.centroids += is_class.to(phi.dtype) @ phi
            self.class_counts += is_class.sum(dim=1, keepdim=True)

    def compute_centroids(self) -> None:
        """Turn the accumulated feature sums into per-class means.

        Classes with no samples get a zero centroid instead of 0/0 = NaN.
        """
        self.centroids = self.centroids / self.class_counts.clamp(min=1)

    def update_covariance(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Accumulate per-class scatter matrices sum((phi - mu_c)(phi - mu_c)^T) for one batch.

        Requires ``centroids`` to already hold the final class means.
        """
        with torch.no_grad():
            phi = self.layers(x)  # (N, h_dim)
            is_class = self._class_mask(y, x.device)  # (classes, N)
            # center every instance by every centroid (broadcast to (classes, N, h_dim)),
            # then zero out instances that do not belong to that class
            diff = (phi.unsqueeze(0) - self.centroids.unsqueeze(1)) * is_class.unsqueeze(-1)
            self.cov += torch.einsum("cnb,cnd->cbd", diff, diff)

    def compute_covariance(self) -> None:
        """Turn the accumulated scatter matrices into unbiased covariances (divide by n_c - 1).

        The divisor is clamped to at least 1, so classes with 0 or 1 samples yield a zero
        matrix (regularized later by ``invert_covariance``) rather than NaN/negative values.
        """
        self.cov = self.cov / (self.class_counts.unsqueeze(-1) - 1).clamp(min=1)

    def invert_covariance(self, max_tries: int = 5) -> None:
        """Fill ``prec`` and ``cov_logdets`` from ``cov``, one class at a time.

        A diagonal jitter is added before inverting. It starts at 1e-4 and grows tenfold
        after every attempt that raises, or whose inverse or log-determinant is not
        finite, for at most ``max_tries`` attempts per class.

        Raises:
            ValueError: if some class is still not invertible after ``max_tries`` attempts.
        """
        eye = torch.eye(self.cov.size(-1), device=self.cov.device, dtype=self.cov.dtype)
        for i in range(self.classes):
            jitter = 1e-5
            for _ in range(max_tries):
                jitter *= 10
                regularized = self.cov[i] + eye * jitter
                try:
                    prec = torch.inverse(regularized)
                    # the covariance is symmetric PSD, so its singular values are its
                    # eigenvalues and sum(log(s)) is its log-determinant
                    _, s, _ = torch.svd(regularized)
                except RuntimeError:  # e.g. singular matrix: retry with a larger jitter
                    continue

                logdet = torch.sum(torch.log(s))
                # only this class's results decide whether to retry
                if bool(torch.isfinite(prec).all()) and bool(torch.isfinite(logdet)):
                    self.prec[i] = prec
                    self.cov_logdets[i] = logdet
                    break
            else:
                raise ValueError(f"tried to invert covariance {max_tries} times and failed")

    def log_px(self, phi: torch.Tensor) -> torch.Tensor:
        """Per-class joint log-density log N(phi; mu_c, Sigma_c) + log pi_c.

        ``pi_c`` is the empirical class frequency taken from ``class_counts``.

        Args:
            phi: features of shape (batch, h_dim).

        Returns:
            Tensor of shape (batch, classes).
        """
        # log of the Gaussian normalizer (2*pi)^(d/2) * sqrt(det(Sigma_c)); shape (classes, 1)
        c = (self.cov.size(1) / 2) * math.log(2 * math.pi) + 0.5 * self.cov_logdets
        if torch.any(torch.isinf(c)) or torch.any(torch.isnan(c)):
            print(f"c has nans: {c}")

        xmu = phi.unsqueeze(1) - self.centroids.unsqueeze(0)  # (batch, classes, h_dim)
        mahalanobis = torch.einsum("bch,chd->bcd", xmu, self.prec)
        mahalanobis = 0.5 * torch.einsum("bch,bch->bc", mahalanobis, xmu)  # (batch, classes)

        if torch.any(torch.isinf(mahalanobis)) or torch.any(torch.isnan(mahalanobis)):
            print(f"mahalanobis has nans: {mahalanobis}")

        log_prior = torch.log(self.class_counts / self.class_counts.sum()).t()  # (1, classes)
        return -(c.t() + mahalanobis) + log_prior  # type: ignore

    def tune(self, loader: DataLoader) -> None:
        """Record the mean and std of the marginal log-density log p(phi) over ``loader``.

        ``loader`` must yield ``(x, y)`` pairs; labels are ignored. Runs without autograd
        so no computation graph is retained across batches.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        chunks = []
        with torch.no_grad():
            for x, _ in loader:
                log_joint = self.log_px(self.layers(x))  # (batch, classes)
                chunks.append(torch.logsumexp(log_joint, dim=1))  # marginalize over classes

        if not chunks:
            raise ValueError("tune() received an empty loader")

        m_log_px = torch.cat(chunks)
        self.px_mean = m_log_px.mean()
        self.px_std = m_log_px.std()

    def inference(self, x: torch.Tensor, px_thresh: float = 3.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Classify ``x`` and flag inputs with low feature density.

        Args:
            x: model inputs.
            px_thresh: an input is kept when its standardized marginal log-density
                (using the statistics recorded by ``tune``) is above ``-px_thresh``.

        Returns:
            logits: (batch, classes) output of the linear head.
            log_px: (batch, classes) joint log-densities; rows of rejected inputs are multiplied by 0.
            mask: (batch, 1) bool, True for kept inputs.
        """
        phi = self.layers(x)
        logits = self.out(phi)
        log_px = self.log_px(phi)
        m_log_px = torch.logsumexp(log_px, dim=1, keepdim=True)

        m_log_px_normed = (m_log_px - self.px_mean) / self.px_std
        m_log_px_mask = m_log_px_normed > -px_thresh  # type: ignore
        # NOTE(review): zeroing a log-density corresponds to p = 1; confirm this masking is intended
        log_px = log_px * m_log_px_mask

        return logits, log_px, m_log_px_mask


def DDU_WideResNet28_cifar(resnet_kwargs: Any = None, ddu_kwargs: Any = None) -> DDU:
    """Build a DDU on top of the network returned by ``wide_sn_resnet28_cifar``.

    Args:
        resnet_kwargs: keyword arguments forwarded to ``wide_sn_resnet28_cifar``.
        ddu_kwargs: keyword arguments forwarded to ``DDU``; must include ``num_classes``
            (``DDU`` requires it and it is not passed here). ``h_dim`` is taken from the
            network's ``out_dim`` and must not be given here.

    Returns:
        The assembled ``DDU`` module.
    """
    # None sentinels instead of shared mutable ``{}`` defaults
    resnet_kwargs = {} if resnet_kwargs is None else resnet_kwargs
    ddu_kwargs = {} if ddu_kwargs is None else ddu_kwargs

    net = wide_sn_resnet28_cifar(**resnet_kwargs)
    return DDU(net, h_dim=net.out_dim, **ddu_kwargs)
