import math
from typing import Optional, Tuple

import numpy as np  # type: ignore
import torch
from layers import SpectralNorm
from torch import nn
from torch.nn import functional as F

# Shorthand for torch.Tensor; used in type annotations throughout this file and
# to allocate the (uninitialised) storage of the learned seed/inducing parameters.
T = torch.Tensor

# Names exported by `from <module> import *`.
__all__ = ["DiagCovEncoder", "LowRankCovEncoder", "batched_sherman_morrison_rank_one_inverse"]


class MAB(nn.Module):
    """Multihead attention block adapted from the Set Transformer reference code.

    Source: https://github.com/juho-lee/set_transformer/blob/master/modules.py

    Queries are projected from ``dim_Q`` and keys/values from ``dim_K`` into a
    ``dim_V``-wide space, which is sliced into heads along the feature axis.
    ``pma`` / ``pma_type`` choose how the projected queries are combined with
    the attended values:

    * ``pma=False`` or ``pma_type="additive"``: ``Q + A @ V``
    * ``pma=True, pma_type="multiplicative"``: ``Q * A @ V``
    * ``pma=True, pma_type="no-residual"``: ``A @ V``

    When ``ln`` is true, a LayerNorm is applied before and after the
    feed-forward residual.
    """

    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False, pma: bool = False, bias: bool = False, pma_type: str = "no-residual"):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.pma = pma
        self.pma_type = pma_type

        self.fc_q = nn.Linear(dim_Q, dim_V, bias=bias)
        self.fc_k = nn.Linear(dim_K, dim_V, bias=bias)
        self.fc_v = nn.Linear(dim_K, dim_V, bias=bias)
        if ln:
            # The norms only exist when requested; forward() probes for them.
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V, bias=bias)

    def forward(self, Q: T, K: T) -> T:
        """Attend from ``Q`` over ``K`` and return a tensor with ``dim_V`` features."""
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)
        batch = queries.size(0)

        # Heads are stacked along the batch axis so a single bmm handles all of them.
        head_width = self.dim_V // self.num_heads
        q_heads = torch.cat(queries.split(head_width, 2), 0)
        k_heads = torch.cat(keys.split(head_width, 2), 0)
        v_heads = torch.cat(values.split(head_width, 2), 0)

        # As in the reference code, scores are scaled by sqrt(dim_V), not sqrt(head_width).
        scores = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        attn = torch.softmax(scores, 2)

        if not self.pma or self.pma_type == "additive":
            mixed = q_heads + attn.bmm(v_heads)
        elif self.pma_type == "multiplicative":
            mixed = q_heads * attn.bmm(v_heads)
        elif self.pma_type == "no-residual":
            mixed = attn.bmm(v_heads)
        else:
            raise ValueError(f"got invalid combination of {self.pma=} and {self.pma_type=}")

        # Undo the head stacking: back to (batch, n_queries, dim_V).
        out = torch.cat(mixed.split(batch, 0), 2)

        pre_norm = getattr(self, "ln0", None)
        if pre_norm is not None:
            out = pre_norm(out)

        out = out + F.relu(self.fc_o(out))

        post_norm = getattr(self, "ln1", None)
        if post_norm is not None:
            out = post_norm(out)
        return out


class MABSN(nn.Module):
    """Multihead attention block with spectral-normalised value and output projections.

    Adapted from the Set Transformer reference code:
    https://github.com/juho-lee/set_transformer/blob/master/modules.py

    The query and key projections are plain linear layers; only the value and
    output projections are wrapped in ``SpectralNorm`` (coefficient ``c``).
    When ``dim_Q == dim_K == dim_V`` the value path is residual:
    ``V = K + relu(fc_v(K))``. ``pma`` / ``pma_type`` select how the projected
    queries are combined with the attended values:

    * ``pma=False`` or ``pma_type="additive"``: ``Q + A @ V``
    * ``pma=True, pma_type="multiplicative"``: ``Q * A @ V``
    * ``pma=True, pma_type="no-residual"``: ``A @ V``
    """

    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False, c: float = 1.0, pma: bool = False, bias: bool = False, pma_type: str = "no-residual"):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.pma = pma
        self.pma_type = pma_type
        # The residual value path adds K to fc_v(K), so it is only used when the widths match.
        self.dims_equal = dim_Q == dim_K == dim_V

        self.fc_q = nn.Linear(dim_Q, dim_V, bias=bias)
        self.fc_k = nn.Linear(dim_K, dim_V, bias=bias)
        self.fc_v = SpectralNorm(nn.Linear(dim_K, dim_V, bias=bias), c=c)

        if ln:
            # The norms only exist when requested; forward() probes for them.
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

        self.fc_o = SpectralNorm(nn.Linear(dim_V, dim_V, bias=bias), c=c)

    def forward(self, Q: T, K: T) -> T:
        """Attend from ``Q`` over ``K`` and return a tensor with ``dim_V`` features."""
        queries = self.fc_q(Q)
        values = self.fc_v(K)
        if self.dims_equal:
            values = K + F.relu(values)
        keys = self.fc_k(K)
        batch = queries.size(0)

        # Heads are stacked along the batch axis so a single bmm handles all of them.
        head_width = self.dim_V // self.num_heads
        q_heads = torch.cat(queries.split(head_width, 2), 0)
        k_heads = torch.cat(keys.split(head_width, 2), 0)
        v_heads = torch.cat(values.split(head_width, 2), 0)

        # As in the reference code, scores are scaled by sqrt(dim_V), not sqrt(head_width).
        scores = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        attn = torch.softmax(scores, 2)

        if not self.pma or self.pma_type == "additive":
            mixed = q_heads + attn.bmm(v_heads)
        elif self.pma_type == "multiplicative":
            mixed = q_heads * attn.bmm(v_heads)
        elif self.pma_type == "no-residual":
            mixed = attn.bmm(v_heads)
        else:
            raise ValueError(f"got invalid combination of {self.pma=} and {self.pma_type=}")

        # Undo the head stacking: back to (batch, n_queries, dim_V).
        out = torch.cat(mixed.split(batch, 0), 2)

        pre_norm = getattr(self, "ln0", None)
        if pre_norm is not None:
            out = pre_norm(out)

        out = out + F.relu(self.fc_o(out))

        post_norm = getattr(self, "ln1", None)
        if post_norm is not None:
            out = post_norm(out)
        return out


class SAB(nn.Module):
    """Self-attention block: the input set attends to itself through one attention block.

    ``spectral`` selects ``MABSN`` (which also receives ``c``) over the plain ``MAB``.
    """

    def __init__(self, dim_in: int, dim_out: int, num_heads: int, ln: bool = False, spectral: bool = False, c: float = 1.0, bias: bool = False):
        super().__init__()
        block: nn.Module
        if spectral:
            block = MABSN(dim_in, dim_in, dim_out, num_heads, ln=ln, c=c, bias=bias)
        else:
            block = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln, bias=bias)
        self.mab = block

    def forward(self, X: T) -> T:
        """Use ``X`` as both the query set and the key/value set."""
        return self.mab(X, X)  # type: ignore


class ISAB(nn.Module):
    """Induced self-attention block.

    ``num_inds`` learned inducing points (normally initialised) first attend to
    the input set; the input set then attends to that summary. ``spectral``
    selects ``MABSN`` (which also receives ``c``) over the plain ``MAB`` for
    both stages.
    """

    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool = False, spectral: bool = False, c: float = 1.0, bias: bool = False):
        super().__init__()
        # Storage is allocated uninitialised and filled by normal_ right below.
        self.I = nn.Parameter(T(1, num_inds, dim_out))
        nn.init.normal_(self.I)

        to_inducing: nn.Module
        to_inputs: nn.Module
        if spectral:
            to_inducing = MABSN(dim_out, dim_in, dim_out, num_heads, ln=ln, c=c, bias=bias)
        else:
            to_inducing = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln, bias=bias)
        self.mab0 = to_inducing

        if spectral:
            to_inputs = MABSN(dim_in, dim_out, dim_out, num_heads, ln=ln, c=c, bias=bias)
        else:
            to_inputs = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln, bias=bias)
        self.mab1 = to_inputs

    def forward(self, X: T) -> T:
        """Summarise ``X`` through the inducing points, then let ``X`` attend to the summary."""
        batch = X.size(0)
        inducing = self.I.repeat(batch, 1, 1)
        summary = self.mab0(inducing, X)
        return self.mab1(X, summary)  # type: ignore


class PMA(nn.Module):
    """Pooling by multihead attention.

    ``num_seeds`` learned seed vectors (normally initialised) act as queries
    over the input set, so the output has one row per seed. The attention block
    is built with ``pma=True`` and the given ``pma_type``; ``spectral`` selects
    ``MABSN`` (which also receives ``c``) over the plain ``MAB``.
    """

    def __init__(self, dim: int, num_heads: int, num_seeds: int, ln: bool = False, spectral: bool = False, c: float = 1.0, bias: bool = False, pma_type: str = "no-residual"):
        super().__init__()
        # Storage is allocated uninitialised and filled by normal_ right below.
        self.S = nn.Parameter(T(1, num_seeds, dim))
        nn.init.normal_(self.S)

        pooler: nn.Module
        if spectral:
            pooler = MABSN(dim, dim, dim, num_heads, ln=ln, c=c, pma=True, bias=bias, pma_type=pma_type)
        else:
            pooler = MAB(dim, dim, dim, num_heads, ln=ln, pma=True, bias=bias, pma_type=pma_type)
        self.mab = pooler

    def forward(self, X: T) -> T:
        """Pool ``X`` by attending from the (batch-replicated) seeds."""
        seeds = self.S.repeat(X.size(0), 1, 1)
        return self.mab(seeds, X)  # type: ignore


def softplus(x: T) -> T:
    """Return a softplus rescaled to ``0.01 + 0.99 * softplus(x)``, element-wise."""
    offset, scale = 0.01, 0.99
    return offset + scale * F.softplus(x)  # type: ignore


class SigmaEncoder(nn.Module):
    """Shared backbone of the covariance encoders.

    The network is two self-attention blocks, an attention pooling layer with
    ``num_outputs`` seeds, and a residual linear head. Subclasses implement
    ``forward`` on top of :meth:`base`.

    ``dim_input``, ``n_inds`` and ``p`` are accepted but not used by this class.
    ``sigma_spectral`` switches the attention blocks to their spectral-norm
    variants, while ``spectral`` only wraps the output head.
    """

    def __init__(
        self,
        dim_input: int,
        num_outputs: int,
        dim_output: int,
        dim_hidden: int = 128,
        n_inds: int = 8,
        num_heads: int = 4,
        ln: bool = False,
        p: float = 0.1,
        spectral: bool = False,
        sigma_spectral: bool = False,
        c: float = 1.0,
        bias: bool = True,
        pma_type: str = "no-residual"
    ):
        super().__init__()

        encoder_blocks = [
            SAB(dim_hidden, dim_hidden, num_heads, ln=ln, spectral=sigma_spectral, c=c, bias=bias)
            for _ in range(2)
        ]
        self.sigma_enc = nn.Sequential(*encoder_blocks)

        # Kept inside a Sequential (index 0) so parameter names stay `sigma_pool.0.*`.
        pooling = PMA(dim_hidden, num_heads, num_outputs, ln=ln, spectral=sigma_spectral, c=c, bias=bias, pma_type=pma_type)
        self.sigma_pool = nn.Sequential(pooling)

        head = nn.Linear(dim_hidden, dim_output)
        self.sigma_out = SpectralNorm(head) if spectral else head

    def centroid_enc_lyr(self, z: T) -> T:
        """Return ``z`` unchanged (currently an identity hook)."""
        return z

    def base(self, z: T, n_way: int = 5, k_shot: int = 5) -> Tuple[T, T]:
        """Compute per-class centroids and the pooled, residual-projected features.

        ``z`` is viewed as ``(n_way, k_shot, -1)``. The centroid of each class
        is the mean over the shot axis (kept as a singleton dimension). When
        ``k_shot > 1`` the features are centred on their class centroid before
        being encoded and pooled.

        Returns:
            ``(centroids, features)`` where ``features`` is the pooled output
            plus ``relu(sigma_out(pooled))``.
        """
        per_class = z.view(n_way, k_shot, -1)
        centroids = per_class.mean(dim=1, keepdim=True)

        # With a single shot the centred features would all be zero, so centring is skipped.
        centred = per_class - centroids if k_shot > 1 else per_class

        encoded = self.sigma_enc(centred)
        pooled = self.sigma_pool(encoded)
        features = pooled + F.relu(self.sigma_out(pooled))

        return centroids, features

    def forward(self, z: T, n_way: int = 5, k_shot: int = 5, compute_cov: bool = True) -> Tuple[T, T, Optional[T], T]:
        """Not implemented here; subclasses provide the covariance parameterisation."""
        raise NotImplementedError()


class DiagCovEncoder(SigmaEncoder):
    """Encoder that predicts one diagonal covariance per class.

    The pooled features from :meth:`SigmaEncoder.base` are squashed with a
    sigmoid and floored at ``0.1``; the result is used as the diagonal of the
    covariance matrix.

    NOTE(review): ``forward`` squeezes dimension 1 of the pooled features, so
    it presumably expects ``num_outputs == 1`` -- confirm against callers.
    """

    def __init__(
        self,
        dim_input: int,
        num_outputs: int,
        dim_output: int,
        dim_hidden: int = 128,
        n_inds: int = 8,
        num_heads: int = 4,
        ln: bool = False,
        p: float = 0.1,
        spectral: bool = False,
        c: float = 1.0,
        pma_type: str = "no-residual"
    ):
        super().__init__(
            dim_input=dim_input,
            num_outputs=num_outputs,
            dim_output=dim_output,
            dim_hidden=dim_hidden,
            n_inds=n_inds,
            num_heads=num_heads,
            ln=ln,
            p=p,
            spectral=spectral,
            c=c,
            pma_type=pma_type
        )

        # Not read by forward(); kept so existing checkpoints/state dicts still load.
        self.shared_cov = nn.Parameter(torch.zeros(1, dim_hidden, requires_grad=True))

    def forward(self, z: T, n_way: int = 5, k_shot: int = 5, compute_cov: bool = False) -> Tuple[T, T, Optional[T], T]:
        """Return ``(centroids, precision, covariance, cov_logdet)``.

        Args:
            z: features that :meth:`base` views as ``(n_way, k_shot, -1)``.
            n_way: number of classes.
            k_shot: number of examples per class.
            compute_cov: when true, also materialise the dense covariance
                matrices; otherwise ``covariance`` is ``None``.

        Returns:
            Class centroids (singleton axis squeezed), the dense diagonal
            precision matrices, the optional dense covariance matrices, and
            the log-determinant of each covariance.
        """
        centroids, z = self.base(z, n_way=n_way, k_shot=k_shot)

        # Sigmoid output floored at 0.1, so variances are >= 0.1 and the reciprocal/log
        # below need no further clamping.
        variances = torch.clamp(torch.sigmoid(z.squeeze(1)), 0.1)

        # diag_embed builds every diagonal matrix in one vectorised call.
        covariance = torch.diag_embed(variances) if compute_cov else None
        precision = torch.diag_embed(variances.reciprocal())  # inverse of a diagonal matrix
        cov_logdet = variances.log().sum(dim=-1)

        return centroids.squeeze(1), precision, covariance, cov_logdet


def batched_sherman_morrison_rank_one_inverse(A_diag: T, B_factors: T) -> Tuple[T, T]:
    """Invert ``diag(A_diag) + sum_i u_i u_i^T`` for every batch element.

    The inverse starts as ``diag(1 / A_diag)`` and is updated once per rank-one
    factor with the Sherman-Morrison formula. The log-determinant is tracked
    alongside using the matrix determinant lemma:
    ``log|A + u u^T| = log|A| + log(1 + u^T A^{-1} u)``.

    Args:
        A_diag: ``(batch, dim)`` diagonal entries of the base matrix.
        B_factors: ``(batch, rank, dim)``; ``B_factors[:, i]`` holds the
            vectors ``u_i`` of the ``i``-th rank-one term.

    Returns:
        ``(inverse, logdet)`` with shapes ``(batch, dim, dim)`` and ``(batch,)``.

    Raises:
        ValueError: if ``1 + u^T A^{-1} u`` is negative for any batch element.
            The offending tensors are printed before raising.
    """
    Ainv = torch.diag_embed(1 / A_diag)
    A_logdet = torch.log(A_diag).sum(dim=-1, keepdim=True)

    for i in range(B_factors.size(1)):
        # Sherman-Morrison recursion, see:
        # https://math.stackexchange.com/questions/17776/inverse-of-the-sum-of-matrices
        # https://timvieira.github.io/blog/post/2021/03/25/fast-rank-one-updates-to-matrix-inverse/

        B = B_factors[:, i : i + 1]  # (batch, 1, dim): the row vector u^T
        B_t = B.transpose(1, 2)  # (batch, dim, 1): the column vector u
        Bu = torch.bmm(Ainv, B_t)  # A^{-1} u
        u_tB = torch.bmm(B, Ainv)  # u^T A^{-1}

        # 1 + u^T A^{-1} u. This scalar is both the Sherman-Morrison denominator
        # (it equals 1 + trace(A^{-1} u u^T)) and the determinant-lemma factor, so it
        # is computed once from the matrix-vector product instead of forming the
        # dense (dim x dim) product A^{-1} u u^T just to take its trace.
        denom = 1 + torch.bmm(u_tB, B_t)  # (batch, 1, 1)
        logdet_update = denom.squeeze(-1)  # (batch, 1)

        # Early in training this factor can turn negative; fail loudly rather than
        # propagate NaNs from the log below.
        if torch.any(logdet_update < 0):
            print(f"{B=}")
            print(f"{Ainv=}")
            print(f"{A_logdet=}")
            raise ValueError(f"this value shouldn't be less than zero {i=}: {logdet_update=}")

        A_logdet = A_logdet + torch.log(logdet_update)
        Ainv = Ainv - (torch.bmm(Bu, u_tB) / denom)

    return Ainv, A_logdet.squeeze(-1)  # type: ignore


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Log-determinant of ``W @ W.T + D`` via the matrix determinant lemma::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C = I + W.T @ inv(D) @ W` is the capacitance matrix.

    ``capacitance_tril`` is the lower Cholesky factor of :math:`C`, so
    ``log|C|`` is twice the sum of the logs of its diagonal. ``D`` holds the
    diagonal entries. ``W`` itself is not read here; its contribution enters
    through ``capacitance_tril``.
    """
    chol_diagonal = capacitance_tril.diagonal(dim1=-2, dim2=-1)
    logdet_capacitance = 2 * chol_diagonal.log().sum(-1)
    logdet_diag = D.log().sum(-1)
    return logdet_capacitance + logdet_diag


class LowRankCovEncoder(SigmaEncoder):
    """Encoder that predicts a diagonal-plus-low-rank covariance per class.

    The backbone is built with ``num_outputs + 1`` pooling seeds. In
    ``forward`` the last pooled row becomes the diagonal term (sigmoid, floored
    at ``0.1``) and the remaining ``num_outputs`` rows become the rank-one
    factors (tanh).
    """

    def __init__(
        self,
        dim_input: int,
        num_outputs: int,
        dim_output: int,
        dim_hidden: int = 128,
        n_inds: int = 8,
        num_heads: int = 4,
        ln: bool = True,
        p: float = 0.1,
        spectral: bool = False,
        c: float = 1.0,
        pma_type: str = "no-residual"
    ):
        super().__init__(
            dim_input=dim_input,
            # One extra seed: its pooled output is used for the diagonal term.
            num_outputs=num_outputs + 1,
            dim_output=dim_output,
            dim_hidden=dim_hidden,
            n_inds=n_inds,
            num_heads=num_heads,
            ln=ln,
            p=p,
            spectral=spectral,
            c=c,
            pma_type=pma_type
        )

        self.rank = num_outputs

    def forward(self, z: T, n_way: int = 5, k_shot: int = 5, compute_cov: bool = False) -> Tuple[T, T, Optional[T], T]:
        """Return ``(centroids, precision, covariance, cov_logdet)``.

        Args:
            z: features that :meth:`base` views as ``(n_way, k_shot, -1)``.
            n_way: number of classes.
            k_shot: number of examples per class.
            compute_cov: when true, also materialise the dense covariance
                ``diag(prior_diag) + factors^T @ factors``; otherwise
                ``covariance`` is ``None``.

        Returns:
            Class centroids (singleton axis squeezed), the precision matrices
            and covariance log-determinants obtained from
            ``batched_sherman_morrison_rank_one_inverse``, and the optional
            dense covariance matrices.
        """
        centroids, z = self.base(z, n_way=n_way, k_shot=k_shot)

        factors = torch.tanh(z[:, :-1])
        prior_diag = torch.clamp(torch.sigmoid(z[:, -1]), 0.1)

        # NOTE: an alternative based on torch.distributions.LowRankMultivariateNormal
        # (together with _batch_lowrank_logdet) was left disabled by the original author.
        precision, cov_logdet = batched_sherman_morrison_rank_one_inverse(A_diag=prior_diag, B_factors=factors)

        covariance = None
        if compute_cov:
            # diag_embed builds every diagonal matrix in one vectorised call.
            covariance = torch.diag_embed(prior_diag) + factors.transpose(1, 2).bmm(factors)

        return centroids.squeeze(1), precision, covariance, cov_logdet
