import math
from typing import Tuple

import torch
import torch as t
import torch.nn as nn
import torch.nn.functional as F

from .util import ( kernel_to_ldl_param,
                    ldl_param_to_chol,
                    init_eye_ldl,
                    eye_like,
                    retrying_cholesky,
                    sparse_eye_like,
                    add_jitter )

class Id(nn.Module):
    """
    Identity module: hands back its first positional input untouched.

    Any extra constructor or call arguments are swallowed, so this can stand in
    wherever a kernel module is expected (e.g. ``Id(Pi=...)``).
    """

    def __init__(self, *_args, **_kwargs):
        nn.Module.__init__(self)

    def forward(self, x, *_args, **_kwargs):
        # extra arguments exist only for call-signature compatibility
        result = x
        return result

class StructGram:
    """
    A class representing the ii, ti and (diagonal of the) tt blocks of a Gram
    matrix, although the underlying representation may or may not actually be
    these blocks.

    The full Gram matrix is

        [[ ii , ti^T ],
         [ ti , tt   ]]

    Two parameterisations are accepted and converted lazily on first access:

    * block form: ``ii`` (Pi, Pi), ``ti`` (T, Pi) and optionally ``tt_diag`` (T,);
    * factor form: ``i`` a lower-triangular factor with ``ii = i @ i.T`` and
      ``t`` with ``ti = t @ i.T``.

    Representations must always be flat (2-D matrices / 1-D diagonals, no
    batch dimension).
    """
    def __init__(self,
                 ii=None,
                 ti=None,
                 tt_diag=None,
                 i=None,
                 t=None,
                 check=True):
        """
        Args:
            ii: inducing-inducing block.
            ti: test-inducing block.
            tt_diag: diagonal of the test-test block.
            i: lower-triangular factor with ``ii = i @ i.T``.
            t: test features with ``ti = t @ i.T``.
            check: run shape / NaN / PSD sanity checks; disable for faster perf.
        """
        if check:
            # basic checks
            assert i is not None or ii is not None, "Either i or ii must be provided"
            assert t is not None or ti is not None, "Either t or ti must be provided"
            assert i is None or i.dim() == 2
            assert t is None or t.dim() == 2
            assert ii is None or ii.dim() == 2
            assert ti is None or ti.dim() == 2
            assert tt_diag is None or tt_diag.dim() == 1
            # pos def / consistency checks

            assert i is None or not i.isnan().any(), "i must not be NaN"
            assert t is None or not t.isnan().any(), "t must not be NaN"
            assert ii is None or not ii.isnan().any(), "ii must not be NaN"
            assert ti is None or not ti.isnan().any(), "ti must not be NaN"
            assert tt_diag is None or not tt_diag.isnan().any(), "tt_diag must not be NaN"

            assert i is None or torch.allclose(i, i.tril(), rtol=1e-8, atol=1e-8), "i must be lower triangular"
            assert ii is None or ii.diag().min() >= 0, "ii is not PSD"
            assert tt_diag is None or tt_diag.min() >= 0., "tt_diag is not PSD"
        self._i = i
        self._t = t
        self._ii = ii
        self._ti = ti
        self._tt_diag = tt_diag

    @property
    def fi(self):
        """Lower-triangular factor of ``ii``; built with ``retrying_cholesky`` on first use."""
        if self._i is None:
            self._i = retrying_cholesky(self._ii)
        return self._i

    @property
    def ft(self):
        """Test features with ``ti = ft @ fi.T``; built with a triangular solve on first use."""
        if self._t is None:
            # solve X @ fi.T = ti for X (fi.T is upper triangular)
            self._t = t.linalg.solve_triangular(self.fi.T, self.ti, left=False, upper=True)
        return self._t

    @property
    def ii(self):
        """Inducing-inducing block ``fi @ fi.T`` (cached)."""
        if self._ii is None:
            fi = self.fi
            self._ii = fi @ fi.T
        return self._ii

    @property
    def ti(self):
        """Test-inducing block ``ft @ fi.T`` (cached)."""
        if self._ti is None:
            # go through ``fi`` rather than ``_i``: when the object was built
            # from ``ii`` and ``t``, ``_i`` is still None at this point.
            self._ti = self._t @ self.fi.T
        return self._ti

    @property
    def tt_diag(self):
        """Diagonal of the test-test block, ``sum(ft**2, -1)`` (cached)."""
        if self._tt_diag is None:
            self._tt_diag = self.ft.square().sum(-1)
        return self._tt_diag

    def calc_ii_diag(self):
        """Diagonal of ``ii`` without forming the full block when ``i`` is known."""
        if self._i is None:
            return self._ii.diag()
        else:
            return self.fi.square().sum(-1)

    def add(self, other: "StructGram", w=1., v=1.) -> "StructGram":
        """Return the blockwise combination ``w * self + v * other`` (ii and ti only)."""
        return StructGram(
            ii=self.ii * w + other.ii * v,
            ti=self.ti * w + other.ti * v
        )

    def mul(self, w: float) -> "StructGram":
        """Return ``w * self`` (ii and ti only)."""
        return StructGram(
            ii=self.ii * w,
            ti=self.ti * w
        )

    def calc_tt(self):
        """Full test-test block ``ft @ ft.T``. Expect this to be slow!"""
        return self.ft @ self.ft.T

    def calc_full_kernel(self):
        """Assemble the full (Pi + T, Pi + T) Gram matrix. Expect this to be slow!"""
        ii = self.ii
        ti = self.ti
        tt = self.calc_tt()
        top = t.cat([ii, ti.T], dim=1)
        bottom = t.cat([ti, tt], dim=1)
        return t.cat([top, bottom], dim=0)


class Gram(nn.Module):
    """
    Learned Gram layer of a deep kernel machine.

    The layer's parameters (``trilL``, ``logD``) are turned into a
    lower-triangular factor ``L`` by ``ldl_param_to_chol``. The incoming
    features are right-multiplied by ``L``. A regularisation term weighted by
    ``dof`` is returned alongside the new StructGram.
    """
    def __init__(self, Pi, dof, obj_mode='exact', **kwargs):
        super().__init__()
        self.register_buffer('dof', t.tensor(dof))
        trilL, logD = init_eye_ldl(Pi)
        self.trilL = nn.Parameter(trilL)
        self.logD = nn.Parameter(logD)
        # stored for reference; not read anywhere in this class
        self.obj_mode = obj_mode
        self.Pi = Pi

    def _forward(self, Kfi: t.Tensor, Kft: t.Tensor) -> Tuple[t.Tensor, t.Tensor, t.Tensor]:
        """
        Args:
            Kfi: lower-triangular inducing factor of the incoming kernel.
            Kft: test features of the incoming kernel.

        Returns:
            (Fi, Ft, obj): new inducing factor, new test features, and the
            regularisation term (``dof`` itself when ``dof == 0``).
        """
        L = ldl_param_to_chol(self.trilL, self.logD)  # here L = chol(Kii^-1 @ Gii)
        # product of lower-triangular matrices; tril() cleans numerical noise
        Fi = (Kfi @ L).tril()
        if self.dof == 0.:
            obj = self.dof
        else:
            tr_term = t.sum(L * L)
            # presumably -log det(L L^T), assuming diag(L) == exp(logD) -- TODO
            # confirm against ldl_param_to_chol
            logdet_term = -2*self.logD.sum()
            obj = -self.dof/2. * (tr_term + logdet_term - self.Pi)

        # forward eqns
        Ft = Kft @ L
        return Fi, Ft, obj

    def forward(self, K: StructGram) -> Tuple[StructGram, t.Tensor]:
        """Return the transformed StructGram and the regularisation term."""
        fi, ft, obj = self._forward(K.fi, K.ft)
        return StructGram(i=fi, t=ft), obj

class Output(nn.Module):
    def __init__(self, P, out_features, init_mu=None, mc_samples=1000, returns='samples',
                       initialization='nngp', learn_mu=True, chunk_size=None,
                       **kwargs
                       ):
        """
        Output layer for a deep kernel machine.

        Holds a Gaussian over the (P, out_features) readout weights. The mean is
        ``mu_p``. The covariance factor is parameterised by (``trilL``, ``logD``)
        and shared across output columns. Logits are sampled by Monte Carlo.

        Args:
            P: number of inducing points.
            out_features: number of output classes.
            init_mu: optional initial mean of shape (P, out_features).
            mc_samples: total number of Monte Carlo samples per forward pass.
            returns: only 'samples' is supported.
            initialization: must be 'randn' or 'nngp'. It is only validated
                here; the kernel-based init in ``forward`` runs regardless.
            learn_mu: if False, ``mu_p`` is registered as a buffer instead of a
                parameter.
            chunk_size: if an integer, the MC estimator is computed in chunks of
                this many samples. This can save a little memory, but the best
                way to save memory is to make ``mc_samples`` smaller. Values
                above ``mc_samples`` are clipped to ``mc_samples``. Samples that
                do not fill a complete chunk are dropped.
        """
        super().__init__()
        assert initialization in ['randn', 'nngp'], f"initialization must be one of ['randn', 'nngp'], got {initialization}"
        self.out_features = out_features
        self.Pi = P
        if init_mu is None:
            mu_p = t.randn(P, out_features) / out_features
        else:
            mu_p = init_mu
            assert mu_p.shape == (P, out_features), \
                f"init_mu must have shape {(P, out_features)}, got {tuple(mu_p.shape)}"

        if learn_mu:
            self.mu_p = nn.Parameter(mu_p)
        else:
            self.register_buffer('mu_p', mu_p)
        trilL, logD = init_eye_ldl(P)
        self.trilL = nn.Parameter(trilL)
        self.logD = nn.Parameter(logD)

        # NOTE(review): plain attribute, so it is not saved in state_dict; a
        # fresh module that loads a checkpoint will still run the data-dependent
        # init in forward() on its first call and overwrite trilL / logD.
        self.initialized = False

        self.mc_samples = mc_samples
        self.returns = returns
        assert self.returns in ['samples']
        self.chunk_size = chunk_size

    def _forward(self, Kfi: t.Tensor, Kft: t.Tensor, labels, ixs) -> Tuple:
        """
        Sample logits for the different classes and estimate the log likelihood.

        Returns:
            (lls, mean_logprobs, obj):
              lls: (N,) mean over MC samples of -cross_entropy, or None when
                   ``labels`` is None;
              mean_logprobs: (N, Nout) log of the MC-averaged class probabilities;
              obj: scalar regulariser on the readout distribution.
            N is the number of test rows after optional ``ixs`` selection.
        """
        chol_Sigma = ldl_param_to_chol(self.trilL, self.logD)
        mu = self.mu_p
        # calc obj first! Presumably -KL(q || N(0, I)) up to a constant,
        # assuming diag(chol_Sigma) == exp(logD) -- TODO confirm.
        tr_term = t.sum(chol_Sigma * chol_Sigma)
        logdet_term = -2*self.logD.sum()
        mumu_term = t.sum(mu * mu)
        obj = -0.5 * ((tr_term + logdet_term)*self.out_features + mumu_term)

        # pred and lls
        Hi = Kfi
        if ixs is None:
            Ht = Kft
        else:
            Ht = Kft[ixs]
            if labels is not None:
                labels = labels[ixs]

        # The MC estimator is computed in chunks to bound memory. Clip the chunk
        # so that at least one full chunk is always run (avoids 0/0 below).
        if self.chunk_size is None:
            mc_batch_size = self.mc_samples
        else:
            mc_batch_size = min(self.chunk_size, self.mc_samples)
        niter = self.mc_samples // mc_batch_size
        log_mc = math.log(mc_batch_size)
        sample_shape = (mc_batch_size, *mu.shape)

        if labels is not None:
            logprobs_acc = 0.
            ll_acc = 0.
            for _ in range(niter):
                std_samples = t.randn(sample_shape, device=Hi.device, dtype=mu.dtype)
                ws = mu + t.einsum('ij,mjk->mik', chol_Sigma, std_samples)  # (MC, Pi, Nout)

                f_samples = t.einsum('si,mij->msj', Ht, ws)  # (MC, N, Nout)
                # log of the MC average of softmax probabilities
                mean_logprob = t.logsumexp(F.log_softmax(f_samples, dim=-1), dim=0) - log_mc

                cross_ent = F.cross_entropy(f_samples.permute(1, 2, 0),  # (N, class, MC)
                                            labels.unsqueeze(-1).repeat(1, mc_batch_size),  # (N, MC)
                                            reduction='none').mean(dim=-1)
                ll_acc += -cross_ent
                logprobs_acc += mean_logprob
            return ll_acc / niter, logprobs_acc / niter, obj

        # prediction only: no gradients flow through the sampled logits
        logprobs_acc = 0.
        for _ in range(niter):
            std_samples = t.randn(sample_shape, device=Hi.device, dtype=mu.dtype)
            ws = mu + t.einsum('ij,mjk->mik', chol_Sigma, std_samples).detach()  # (MC, Pi, Nout)

            f_samples = t.einsum('si,mij->msj', Ht, ws).detach()  # (MC, N, Nout)
            mean_logprob = t.logsumexp(F.log_softmax(f_samples, dim=-1), dim=0) - log_mc

            logprobs_acc += mean_logprob
        return None, logprobs_acc / niter, obj

    def forward(self, K, labels=None, ixs=None, **kwargs):
        """
        On the first call, initialise (trilL, logD) from ``K.ii + I`` without
        gradients. Then return ``_forward(K.fi, K.ft, labels, ixs)``. Extra
        keyword arguments (e.g. ``ti_dims``) are ignored.
        """
        if not self.initialized:
            self.initialized = True
            with t.no_grad():
                trilL, logD = kernel_to_ldl_param(K.ii + eye_like(K.ii))
                self.trilL.data = trilL
                self.logD.data = logD
        return self._forward(K.fi, K.ft, labels, ixs=ixs)

class JitteryK(nn.Module):
    """
    Base module for kernels that jitter and bias Gram blocks.

    Keeps three scalar tensors as registered buffers:
      mult_eps -- relative jitter on the diagonal
      abs_eps  -- absolute jitter on the diagonal
      eps      -- small threshold passed to clamps by subclasses (e.g. ReluKernel)
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, mult_eps=1e-4, abs_eps=1e-4, eps=1e-6, **kwargs):
        super().__init__()
        settings = (('mult_eps', mult_eps), ('abs_eps', abs_eps), ('eps', eps))
        for buf_name, buf_value in settings:
            self.register_buffer(buf_name, t.tensor(buf_value))

    def stabilise(self, Kii):
        """Return ``Kii * (1 + mult_eps * I) + abs_eps * I``."""
        identity = eye_like(Kii)
        diag_scale = 1. + self.mult_eps * identity
        return Kii * diag_scale + self.abs_eps * identity

def _relu_kernel(ti, i, t, mult_eps, abs_eps, eps):
    t_i = (i * (1 + mult_eps) + abs_eps) * (t[..., None] * (1 + mult_eps) + abs_eps)
    theta = (ti * t_i.rsqrt()).clamp(-1+eps, 1-eps).acos()
    t_i_sin_theta = (t_i - ti ** 2).clamp(min=eps).sqrt()
    K = (t_i_sin_theta + (torch.pi - theta) * ti) / torch.pi
    return K
class ReluKernel(JitteryK):
    """
    Compute K = K_arccos(G) on Gram blocks, with jitter added, where

    K_ij = sqrt(Gii . Gjj) / pi * [sin(theta_ij) + (pi - theta_ij) * cos(theta_ij)]
    theta_ij = arccos(Gij / sqrt(Gii . Gjj))

    The norms use jittered diagonals ``diag * (1 + mult_eps) + abs_eps``.
    """
    def _forward(self, Gii: t.tensor, Gti: t.tensor, Gt) -> Tuple[t.tensor]:
        jitter = (self.mult_eps, self.abs_eps, self.eps)
        inducing_diag = Gii.diagonal(dim1=-1, dim2=-2)

        ii = _relu_kernel(Gii, inducing_diag, inducing_diag, *jitter)

        # Pin the diagonal of the inducing block to the jittered input diagonal.
        # A flat view with step (n + 1) hits exactly the diagonal of a 2-D
        # (n, n) matrix (StructGram blocks are always 2-D).
        n = ii.shape[-1]
        flat = ii.view(*ii.shape[:-2], -1)
        flat[::n + 1] = inducing_diag * (1 + self.mult_eps) + self.abs_eps
        ii = flat.view(*ii.shape)

        ti = _relu_kernel(Gti, inducing_diag, Gt, *jitter)
        return ii, ti

    def forward(self, G: StructGram) -> StructGram:
        new_ii, new_ti = self._forward(G.ii, G.ti, G.tt_diag)
        return StructGram(ii=new_ii, ti=new_ti)

class F2G(nn.Module):
    def __init__(self, Pi=None, N=None, Xi=None, do_learn_Xi=False, **kwargs):
        super().__init__()
        self.N = N
        self.Pi = Pi
        if Xi is None:
            Xi = t.randn((Pi, N)) / math.sqrt(N)
        if do_learn_Xi:
            self.Xi = nn.Parameter(Xi)
        else:
            self.register_buffer("Xi", Xi)
    def _forward(self, Xt: t.tensor) -> Tuple[t.tensor]:
        i = self.Xi # (Pi, N)
        ii = i @ i.T # (Pi, Pi)
        ti = Xt @ i.T
        ii = add_jitter(ii, eps=1e-4) + eye_like(ii) * 1e-4 * ii.trace()
        return ii, ti
    def forward(self, Xt: t.tensor) -> Tuple[StructGram, t.tensor]:
        ii, ti = self._forward(Xt)
        return StructGram(ii=ii, ti=ti)


class GraphMixup(nn.Module):
    def __init__(self, Pi=None, mode='fixed-indep', lmbda=0.):
        """
        cases(mode):
          fixed-indep == ind points are independent and unconnected to each other _and_ test points
          fixed-full == ind points are initially sampled from the test points. they are connected to each other according to the original graph structure
          none == no mixup (equivalent to fully-connected network)
        """
        super().__init__()
        self.mode = mode
        self.lmbda = lmbda; self._eye = None
    def _residual_adj(self, adj):
        """
        returns A = (1-lmbda) * adj_sp + lmbda * eye
        """
        assert 0. <= self.lmbda <= 1., f"convex combo of adj and eye is required! got lmbda={self.lmbda}"
        if self.lmbda <= 0.: return adj
        if self._eye is None:
            self._eye = sparse_eye_like(adj)
        return (1. - self.lmbda) * adj + self.lmbda * self._eye
    def forward(self, G: StructGram, adj_sp=None):
        if adj_sp is None and self.mode == 'none':
            raise ValueError(f"adj_sp must be provided if mode = '{self.mode}'")
        adj_sp = self._residual_adj(adj_sp) # residual with eye, if lmbda > 0
        if self.mode == 'none': return G
        if self.mode == 'fixed-indep':
            return StructGram(i=G.fi, t=t.sparse.mm(adj_sp, G.ft))
        if self.mode == 'learned-indep':
            Aind = self.C @ self.C.T
            Aind = Aind + 1e-4 * eye_like(Aind)
            ii = Aind @ G.ii @ Aind.T
            ti = t.sparse.mm(adj_sp, G.ti) @ Aind.T
            return StructGram(ii=ii, ti=ti)
        if self.mode == 'fixed-full':
            F = t.cat((G.fi, G.ft), dim=0)
            assert F.size(0) == adj_sp.size(0)
            F = t.sparse.mm(adj_sp, F)
            Fi = F[:G.fi.size(0)]; Ft = F[G.fi.size(0):]
            ii = Fi @ Fi.T
            ii = ii + 1e-4*eye_like(Fi)
            return StructGram(ii=ii, ti=Ft @ Fi.T)

class Center(nn.Module):
    def __init__(self, Pi=None, learned=False, mode='id', **kwargs):
        super().__init__()
        self.mode = mode
        self.learned = learned
        if learned:
            self.gamma = nn.Parameter(t.ones(Pi))
            self.beta = nn.Parameter(t.zeros(Pi))
    def _forward(self, Kft: t.tensor) -> Tuple[t.tensor]:
        if self.mode == 'batch':
            ft = (Kft - Kft.mean(0, keepdims=True))
        elif self.mode == 'layer':
            ft =  (Kft - Kft.mean(1, keepdims=True))
        elif self.mode == 'id':
            ft = Kft
        if self.learned:
            ft = ft * self.gamma + self.beta
        return ft
    def forward(self, K: StructGram) -> StructGram:
        if self.mode == 'id' and not self.learned: return K ## do nothing
        ft = self._forward(K.ft)
        return StructGram(i=K.fi, t=ft)

# Registry: maps the ``kernel`` constructor argument of the models to a kernel module class.
kernel_dict = dict(relu=ReluKernel, id=Id)

class MeanPooling(nn.Module):
    """
    Mean-pool consecutive groups of test points (e.g. one group per graph).

    ``Ss`` lists the group sizes along the test dimension, as accepted by
    ``torch.split``. The test features ``ft`` are pooled. Because
    ``ti = ft @ fi.T`` is linear in ``ft``, the resulting ``ti`` is the mean
    of the grouped ``ti`` rows. ``tt_diag`` is recomputed from the pooled
    features rather than averaged.
    """
    def forward(self, G: StructGram, Ss: list[int]=None) -> StructGram:
        if Ss is None:
            raise ValueError("MeanPooling requires Ss (the group sizes)")
        pooled_ft = t.stack([chunk.mean(0) for chunk in t.split(G.ft, Ss)])
        return StructGram(i=G.fi, t=pooled_ft)


class ResGraphDKM(nn.Module):
    """
    Residual graph deep kernel machine.

    Pipeline: F2G -> (Gram) -> kernel -> ``num_layers`` residual blocks ->
    output centring -> kernel -> Output. Each residual block runs
    centre -> graph mixup -> (Gram) -> kernel -> (Gram), and its result is
    averaged with the block input (weights 0.5 / 0.5). When ``dof`` is
    infinite, the Gram layers are skipped.

    ``*_params`` dicts are forwarded as keyword arguments to the corresponding
    sub-module; ``None`` means "no extra arguments".
    """
    def __init__(self, Pi=None, dof=None, Nin=None, Nout=None, num_layers=None,
                       feat_to_gram_params=None,
                       output_params=None,
                       kernel='relu',
                       gram_params=None,
                       center=None,
                       center_learned=False,
                       do_checkpointing=False,
                       mixup_params=None,
                       ):
        super().__init__()
        # normalise None -> {} so the ``**`` expansions below are always valid
        feat_to_gram_params = feat_to_gram_params or {}
        output_params = output_params or {}
        gram_params = gram_params or {}
        mixup_params = mixup_params or {}

        self.num_layers = num_layers
        self.is_nngp = dof == float('inf') or dof == 'inf'
        self.f2g = F2G(Pi=Pi, N=Nin, **feat_to_gram_params)
        self.k = kernel_dict[kernel](Pi=Pi)
        self.mixup = GraphMixup(Pi=Pi, **mixup_params)
        if not self.is_nngp:
            self.grams_1 = nn.ModuleList([Gram(Pi, dof, **gram_params) for _ in range(num_layers)])
            self.grams_2 = nn.ModuleList([Gram(Pi, dof, **gram_params) for _ in range(num_layers)])
            self.g0 = Gram(Pi, dof, **gram_params)

        # centring ("bn") layers
        bn = Center
        bn_params = dict(mode=center, learned=center_learned)

        self.bns = nn.ModuleList([bn(Pi=Pi, **bn_params) for _ in range(num_layers)])
        self.output_bn = bn(Pi=Pi, **bn_params)

        self.output = Output(Pi, Nout, **output_params)
        self.Pi = Pi
        self.do_checkpointing = do_checkpointing

    def _forward(self, X, adj_sp, labels, ti_dims, ixs, returns):
        """Run the pipeline; returns the final StructGram or (lls, mean_logprob, reg)."""
        reg = 0.
        x = self.f2g(X)
        if not self.is_nngp:
            x, o = self.g0(x)
            reg += o
        x = self.k(x)
        for i in range(self.num_layers):
            # norm + mixup
            x_res = self.bns[i](x)
            x_res = self.mixup(x_res, adj_sp=adj_sp)

            # FC layers
            if not self.is_nngp:
                x_res, o = self.grams_1[i](x_res)
                reg += o
            x_res = self.k(x_res)
            if not self.is_nngp:
                x_res, o = self.grams_2[i](x_res)
                reg += o
            x = x.add(x_res, w=0.5, v=0.5)
        x = self.output_bn(x)
        x = self.k(x)

        if returns == 'final-kernel':
            return x
        lls, mean_logprob, o = self.output(x, ti_dims=ti_dims, labels=labels, ixs=ixs)
        reg += o
        return lls, mean_logprob, reg

    def forward(self, X, adj_sp=None, Ss=None, labels=None, ixs=None, returns='samples'):
        assert returns in ['samples', 'final-kernel']
        ti_dims = [len(Ss), self.Pi] if Ss is not None else X.size()
        if self.do_checkpointing:
            return t.utils.checkpoint.checkpoint(self._forward, X, adj_sp, labels, ti_dims, ixs, returns, use_reentrant=False)
        else:
            return self._forward(X, adj_sp, labels, ti_dims, ixs, returns)

class KipfGraphDKM(nn.Module):
    """
    Kipf-style (GCN-like) graph deep kernel machine.

    Pipeline: F2G -> (Gram) -> kernel -> ``num_layers`` blocks -> Output. Each
    block runs graph mixup -> (Gram) -> centre -> kernel. With ``residual=True``,
    each block output is averaged with its input (0.5 / 0.5). When ``dof`` is
    infinite, the Gram layers are skipped.

    ``*_params`` dicts are forwarded as keyword arguments to the corresponding
    sub-module; ``None`` means "no extra arguments".
    """
    def __init__(self, Pi=None, dof=None, Nin=None, Nout=None, num_layers=None,
                       feat_to_gram_params=None,
                       output_params=None,
                       kernel='relu',
                       gram_params=None,
                       mixup_params=None,
                       center=None,
                       center_learned=False,
                       do_checkpointing=False,
                       residual=False,
                       **kwargs):
        super().__init__()
        feat_to_gram_params = feat_to_gram_params or {}
        output_params = output_params or {}
        gram_params = gram_params or {}
        mixup_params = mixup_params or {}

        self.num_layers = num_layers
        self.do_checkpointing = do_checkpointing
        self.is_nngp = float(dof) == float('inf')
        self.k = kernel_dict[kernel]()
        self.mixup = GraphMixup(Pi=Pi, **mixup_params)
        if not self.is_nngp:
            self.gs = nn.ModuleList([Gram(Pi, dof, **gram_params) for _ in range(num_layers)])
            self.g0 = Gram(Pi, dof, **gram_params)

        self.f2g = F2G(Pi=Pi, N=Nin, **feat_to_gram_params)

        bn = Center
        bn_params = dict(mode=center, learned=center_learned)

        self.bns = nn.ModuleList([bn(Pi=Pi, **bn_params) for _ in range(num_layers)])
        self.output = Output(Pi, Nout, **output_params)
        self.residual = residual

    def _block(self, i, x, adj_sp):
        """One hidden block; returns (new StructGram, regulariser for this block)."""
        x = self.mixup(x, adj_sp=adj_sp)
        if not self.is_nngp:
            x, o = self.gs[i](x)
        else:
            o = 0.
        x = self.bns[i](x)
        x = self.k(x)
        return x, o

    def _forward(self, X, adj_sp, labels, ti_dims, ixs, returns):
        """Run the pipeline; returns the final StructGram or (lls, mean_logprob, reg)."""
        reg = 0
        # input
        x = self.f2g(X)
        if not self.is_nngp:
            x, o = self.g0(x)
            reg += o
        x = self.k(x)

        # hidden
        for i in range(self.num_layers):
            x_res, o = self._block(i, x, adj_sp)
            reg += o
            if self.residual:
                x = x.add(x_res, w=0.5, v=0.5)
            else:
                x = x_res

        if returns == 'final-kernel':
            return x
        lls, mean_logprob, o = self.output(x, ti_dims=ti_dims, labels=labels, ixs=ixs)
        reg += o
        return lls, mean_logprob, reg

    def forward(self, X, adj_sp=None, labels=None, ixs=None, returns='samples'):
        ti_dims = X.size()
        if self.do_checkpointing:
            return t.utils.checkpoint.checkpoint(self._forward, X, adj_sp, labels, ti_dims, ixs, returns, use_reentrant=False)
        else:
            return self._forward(X, adj_sp, labels, ti_dims, ixs, returns)