"""
verifying the linear gcdkm solution by training a wide BNN on cora with langevin sampling
"""
import os; import sys
sys.path.append(os.getcwd())
import dataset.graph as GD

import torch as t
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import Adam
from dkm.util import retrying_cholesky


# Make float64 the default floating-point dtype for newly created tensors and
# seed torch's global RNG with a fixed value.
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

class LinearGraphDKM(nn.Module):
    """Linear graph deep kernel machine with directly-parameterised Gram matrices.

    Each layer's Gram matrix ``G`` is stored through a square factor ``V`` and
    rebuilt as ``tril(V) @ tril(V).T``.  The loss sums, over layers, a KL-style
    divergence between ``G`` and the graph-mixed Gram matrix of the previous
    layer, plus a final divergence term against the target Gram matrix.

    Args:
        P: expected side length of the (square) Gram matrices; checked in
            ``forward``.
        adj_sp: adjacency operator used by ``_graph_mixup`` (fed to
            ``torch.sparse.mm``).  It is stored as a plain attribute, not a
            buffer, so ``Module.to`` does not move it -- the caller must place
            it on the right device/dtype.
        dof: weight applied to every per-layer divergence term.
        num_layers: number of learned Gram matrices.
    """

    def __init__(self, P=None, adj_sp=None, dof=1., num_layers=2):
        super().__init__()
        self.dof = dof
        self.adj_sp = adj_sp
        self.P = P
        self.num_layers = num_layers

    def initialize_Gs(self, Gs: list[t.Tensor]) -> None:
        """Set the layer parameters to the Cholesky factors of the given Grams."""
        self.Vs = nn.ParameterList([retrying_cholesky(G) for G in Gs])

    def initialize(self, G0: t.Tensor, mode='nngp') -> None:
        """Create the per-layer factors ``self.Vs``.

        Args:
            G0: input Gram matrix; fixes size, device and dtype.
            mode: ``'nngp'`` propagates ``G0`` through the graph mix-up layer
                by layer; ``'randn'`` draws an independent ``V V^T / P`` matrix
                with standard-normal ``V`` for every layer.

        Raises:
            ValueError: if ``mode`` is not one of the two supported values.
        """
        G = G0
        Vs = []
        for _ in range(self.num_layers):
            if mode == 'nngp':
                Gnew = self._graph_mixup(G)
            elif mode == 'randn':
                # No graph mix-up needed here: the random Gram matrix only
                # borrows the shape/device/dtype of the previous one.
                V = t.randn(G.size(0), G.size(0), device=G.device, dtype=G.dtype)
                Gnew = V @ V.T / V.size(0)
            else:
                raise ValueError(f"mode={mode} not recognized")
            Vs.append(retrying_cholesky(Gnew))
            G = Gnew
        self.Vs = nn.ParameterList(Vs)

    def _graph_mixup(self, K):
        """Return ``A K A^T`` for ``A = self.adj_sp`` using two sparse products."""
        return t.sparse.mm(self.adj_sp, t.sparse.mm(self.adj_sp, K).T).T

    def kl_div(self, G, K):
        """Return ``0.5 * (tr(K^-1 G) - logdet(K^-1 G) - P)``.

        Up to the tiny jitter added below, this is the closed form of the KL
        divergence between the zero-mean Gaussians ``N(0, G)`` and ``N(0, K)``.
        """
        cholK = retrying_cholesky(K)
        ## KinvG = K^-1 G, but add some jitter to aid stability of the logdet
        jitter = 1e-10 * t.eye(G.size(0), device=G.device, dtype=G.dtype)
        KinvG = t.cholesky_solve(G, cholK) + jitter
        return 0.5 * (-t.logdet(KinvG) + t.trace(KinvG) - K.size(0))

    def forward(self, G0, YYT):
        """Evaluate the objective for input Gram ``G0`` and target Gram ``YYT``.

        Both matrices must be ``P x P``.  Returns the sum of the ``dof``-weighted
        per-layer divergences plus the (unweighted) divergence of ``YYT`` from
        the graph-mixed last-layer Gram matrix.
        """
        assert G0.size(0) == G0.size(-1) == self.P == YYT.size(0)
        loss = 0.
        K = G0
        for i in range(self.num_layers):
            V = self.Vs[i].tril()  # only the lower triangle acts as the Cholesky factor
            G = V @ V.T
            loss += self.dof * self.kl_div(G, self._graph_mixup(K))
            K = G

        loss += self.kl_div(YYT, self._graph_mixup(K))
        return loss


def train_gdkm(G0, adj_sp, YYT, device=None, nepoch=10000):
    """Train a two-layer ``LinearGraphDKM`` with Adam from a random initialisation.

    Args:
        G0: square input Gram matrix.
        adj_sp: adjacency operator handed to the model.
        YYT: target Gram matrix with the same side length as ``G0``.
        device: device to train on.  ``None`` (default) picks CUDA when it is
            available and falls back to the CPU otherwise.
        nepoch: number of optimisation steps.

    Returns:
        ``(pretrained_Gs, trained_Gs, losses)``: the per-layer Gram matrices
        (CPU tensors) before and after training, and the list of loss values
        recorded at every step.
    """
    if device is None:
        # The original hard-coded 'cuda', which fails on CPU-only machines.
        device = 'cuda' if t.cuda.is_available() else 'cpu'

    adj_sp = adj_sp.to(device, dtype=t.float64)
    G0 = G0.to(device, dtype=t.float64)
    YYT = YYT.to(device, dtype=t.float64)

    model = LinearGraphDKM(P=G0.size(0), dof=1., num_layers=2, adj_sp=adj_sp)
    model = model.to(device, dtype=t.float64)

    model.initialize(G0, mode='randn')

    def snapshot_Gs():
        """Rebuild the Gram matrices on the CPU exactly as forward() does (lower triangle only)."""
        factors = [V.detach().tril().cpu() for V in model.Vs]
        return [V @ V.T for V in factors]

    pretrained_Gs = snapshot_Gs()

    optimizer = Adam(model.parameters(), lr=1.)

    def get_lr(epoch, a=1e-1, b=10., gamma=0.7):
        """follow decayed polynomial schedule"""
        return a * (b + epoch)**(-gamma)
    # The base lr is 1, so the lambda's value is the effective learning rate.
    scheduler = t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr)

    losses = []
    for epochi in range(nepoch):
        optimizer.zero_grad()
        loss = model(G0, YYT)
        loss.backward()
        losses.append(loss.item())
        if epochi % 100 == 0:
            print(f"{epochi+1}/{nepoch}, loss={loss.item():.4f}, lr={optimizer.param_groups[0]['lr']:.4f}")
        optimizer.step()
        scheduler.step()

    trained_Gs = snapshot_Gs()
    return pretrained_Gs, trained_Gs, losses

def train_on_cora():
    """Compare the analytic and the trained linear graph DKM on a Cora subgraph.

    A random subset of 200 nodes is drawn, its features are scaled to unit
    L2 norm per node, and the induced adjacency (binarised, self-loops removed)
    is turned into the operator returned by ``GD.Ahat_interp_id``.  The
    closed-form solution is evaluated first, then the model is trained.

    Returns:
        A triple ``((G0, YYT), (analytic_Gs, analytic_loss),
        (pretrained_Gs, trained_Gs, losses))`` in which every matrix has its
        rows and columns ordered by class label.  ``G0``/``YYT`` are tensors,
        the Gram-matrix lists hold numpy arrays.
    """
    num_nodes = 200
    cora = GD.get_dataset("cora", split="public")

    # Random node subset: the first `num_nodes` entries of a permutation.
    keep = t.randperm(cora.X.size(0))[:num_nodes]
    feats = cora.X[keep].double()
    labels = cora.y[keep]

    # Binary adjacency restricted to the kept nodes, without self-loops.
    dense_adj = (cora.adj_sp.to_dense() != 0.).double()
    dense_adj = dense_adj[keep, :][:, keep]
    dense_adj.fill_diagonal_(0.)

    # Dense adjacency -> (2, num_edges) edge index; every listed edge must have weight 1.
    edge_ixs = dense_adj.nonzero().t()
    edge_weights = dense_adj[edge_ixs[0], edge_ixs[1]]
    assert (edge_weights == 1.).all(), "edge_attr should be 1"
    adj_sp = GD.Ahat_interp_id(edge_ixs, num_nodes=num_nodes, lmbda=0.5).double()

    # Scale every feature row to unit L2 norm.
    sq_norms = (feats * feats).sum(-1).unsqueeze(-1)
    feats = feats * t.rsqrt(sq_norms)

    class_order = t.argsort(labels)
    onehot = F.one_hot(labels, num_classes=cora.num_classes).double()

    # Gram matrices of targets and inputs, each with a small diagonal jitter.
    target_jitter = 1e-4 * t.eye(onehot.size(0), device=onehot.device, dtype=t.float64)
    YYT = onehot @ onehot.T / onehot.size(-1) + target_jitter
    input_jitter = 1e-4 * t.eye(feats.size(0), device=feats.device, dtype=t.float64)
    G0 = feats @ feats.T / feats.size(-1) + input_jitter

    analytic_Gs, analytic_loss = linear_gcdkm_analytic(G0, adj_sp, YYT)
    pretrained_Gs, trained_Gs, losses = train_gdkm(G0, adj_sp, YYT)

    def by_class(M):
        """Reorder rows and columns of M so that nodes are grouped by label."""
        return M[class_order, :][:, class_order]

    def as_sorted_numpy(mats):
        return [by_class(M).detach().cpu().numpy() for M in mats]

    analytic_Gs = as_sorted_numpy(analytic_Gs)
    pretrained_Gs = as_sorted_numpy(pretrained_Gs)
    trained_Gs = as_sorted_numpy(trained_Gs)
    analytic_loss = analytic_loss.detach().cpu().item()

    inputs = (by_class(G0), by_class(YYT))
    analytic = (analytic_Gs, analytic_loss)
    trained = (pretrained_Gs, trained_Gs, losses)
    return inputs, analytic, trained

def linear_gcdkm_analytic(G0, adj_sp, YYT):
    r"""Evaluate the closed-form linear GCDKM Gram matrices and their loss.

    Args:
        G0: square input Gram matrix.
        adj_sp: adjacency operator; densified here via ``to_dense()`` for the
            linear algebra and also handed unchanged to ``LinearGraphDKM``.
        YYT: target Gram matrix, used as the final-layer Gram.

    Returns:
        ``([G1, G2], loss)`` where ``G1``/``G2`` are the analytic Gram matrices
        of a two-layer model and ``loss`` is the ``LinearGraphDKM`` objective
        evaluated at them.
    """
    A = adj_sp.to_dense()

    # None of these depend on the layer index, so compute them once.
    G0inv = t.linalg.pinv(G0)
    Ainv = t.linalg.pinv(A)
    AG0At = A @ G0 @ A.T

    def linear_gcdkm(ell, L, Gfinal):
        r"""
        calculates G^\ell, for the linear GCDKM

        this is computed as

        G^\ell = A^{\ell-1} E (A^{\ell-1})^T
        where E = C^{\ell/(L+1)} (A G0 A^T),
              C = (A^{-L} Gfinal A^{-L}) (A^{-1} G0^{-1} A^{-T})

        with all inverses taken as pseudo-inverses.
        NOTE(review): the placement of transposes is immaterial when A is
        symmetric -- presumably the case for this adjacency; confirm.
        """
        AinvL = t.linalg.matrix_power(Ainv, L)
        C = (AinvL @ Gfinal @ AinvL) @ (Ainv @ G0inv @ Ainv.T)

        ## compute C^(ell/(L+1)) through an eigendecomposition; only the real
        ## parts of the eigenvalues are kept, clamped at zero before the power
        evals, evecs = t.linalg.eig(C)
        evals_pow = evals.real.clamp(min=0.).pow(ell / (L+1))
        evals_pow = t.complex(evals_pow, t.zeros_like(evals_pow))
        D = evecs @ t.diag_embed(evals_pow) @ t.linalg.inv(evecs)
        if D.imag.abs().max() > 1e-6:
            print("warning! found large imag parts in D:", D.imag.abs().max().item())
        D = D.real

        E = D @ AG0At
        Aellm1 = t.linalg.matrix_power(A, ell-1)
        return Aellm1 @ E @ Aellm1.T

    G1 = linear_gcdkm(1, 2, YYT)
    G2 = linear_gcdkm(2, 2, YYT)
    trained_Gs = [G1, G2]

    # Plug the analytic Grams into the model to get the objective value there.
    model = LinearGraphDKM(P=G0.size(0), dof=1., num_layers=2, adj_sp=adj_sp)
    model = model.to(dtype=t.float64)
    model.initialize_Gs([G1, G2])
    loss = model(G0, YYT)

    return trained_Gs, loss


def mk_artefacts():
    """Produce the loss curve and Gram-matrix figures under ``./lineardkm/artefacts``.

    Results of ``train_on_cora`` are cached in ``dump.pkl`` inside that
    directory; when the file already exists it is loaded instead of
    re-running the experiment.  Two PDFs are written: ``losses.pdf`` (training
    loss with the analytic loss as a dashed line) and ``Gs.pdf`` (input,
    target, initial, analytic and trained Gram matrices, each rescaled to unit
    diagonal and drawn on a shared colour scale).
    """
    import pickle
    from pathlib import Path
    import matplotlib.pyplot as plt

    artefacts_path = Path("./lineardkm/artefacts")
    artefacts_path.mkdir(parents=True, exist_ok=True)
    dump_path = artefacts_path / "dump.pkl"
    if dump_path.exists():
        # NOTE: unpickling executes arbitrary code -- only load dumps written by this script.
        with open(dump_path, "rb") as f:
            res = pickle.load(f)
    else:
        res = train_on_cora()
        with open(dump_path, "wb") as f:
            pickle.dump(res, f)
    (G0, YYT), (analytic_Gs, analytic_loss), (pretrained_Gs, trained_Gs, losses) = res

    def to_correlation(G):
        """Rescale a Gram matrix so that its diagonal is one."""
        diag = G.diag()
        return G * t.rsqrt(diag.unsqueeze(-1) * diag)

    analytic_Gs = [to_correlation(t.tensor(G)) for G in analytic_Gs]
    trained_Gs = [to_correlation(t.tensor(G)) for G in trained_Gs]
    pretrained_Gs = [to_correlation(t.tensor(G)) for G in pretrained_Gs]
    G0 = to_correlation(G0)
    YYT = to_correlation(YYT)

    # One colour scale shared by every panel; plain floats for matplotlib.
    all_mats = [G0, YYT] + analytic_Gs + pretrained_Gs + trained_Gs
    vmax = max(G.max().item() for G in all_mats)
    vmin = min(G.min().item() for G in all_mats)
    imshow_kwargs = dict(cmap='viridis', interpolation='none', vmin=vmin, vmax=vmax)

    # --- loss curve ---------------------------------------------------------
    fig, axes = plt.subplots(1, 1, figsize=(6, 3))
    # Skip the first 2000 steps to zoom in on the tail; show everything if the
    # run was too short for that (otherwise the plot would be empty).
    zoom = 2000 if len(losses) > 2000 else 0
    axes.plot(list(range(zoom, len(losses))), losses[zoom:])
    axes.axhline(analytic_loss, color='red', linestyle='--')
    axes.set_xlabel("train iter")
    axes.set_ylabel("loss")
    plt.tight_layout()
    plt.savefig(artefacts_path / "losses.pdf")
    plt.close(fig)  # release the figure instead of merely clearing it

    # --- Gram matrices ------------------------------------------------------
    # 3 rows (initialization / analytic / trained) x 4 image columns + 1 slim colourbar column.
    fig, axes = plt.subplots(3, 5, figsize=(8, 6), gridspec_kw=dict(width_ratios=[1,1,1,1,0.05]))

    for i in range(3):
        for j in range(4):
            axes[i, j].tick_params(axis='both', which='both', length=0, labelbottom=False, labelleft=False)

    # Input/target only appear in the middle row; hide the unused corner cells.
    for row, col in [(0, 0), (2, 0), (0, 3), (2, 3), (0, 4), (2, 4)]:
        axes[row, col].axis('off')

    axes[0,2].set_title(r"$\mathbf{G}^2$")
    axes[1,0].set_title(r"$\mathbf{X}\mathbf{X}^T/N_{\text{in}}$")
    axes[1,3].set_title(r"$\mathbf{Y}\mathbf{Y}^T/N_{\text{out}}$")
    axes[1, 1].set_ylabel("analytic")
    axes[0, 1].set_title(r"$\mathbf{G}^1$")
    axes[0, 1].set_ylabel("initialization")
    axes[2, 1].set_ylabel("trained")

    panels = [
        ((1, 0), G0),
        ((1, 3), YYT),
        ((1, 1), analytic_Gs[0]),
        ((1, 2), analytic_Gs[1]),
        ((0, 1), pretrained_Gs[0]),
        ((0, 2), pretrained_Gs[1]),
        ((2, 1), trained_Gs[0]),
        ((2, 2), trained_Gs[1]),
    ]
    for (row, col), mat in panels:
        # All panels share vmin/vmax, so the last image can drive the colourbar.
        im = axes[row, col].imshow(mat, **imshow_kwargs)

    plt.colorbar(im, cax=axes[1, 4], shrink=0.5)
    plt.tight_layout()
    plt.savefig(artefacts_path / "Gs.pdf")
    plt.close(fig)

# Script entry point: running this file directly generates the artefacts.
if __name__ == '__main__':
    mk_artefacts()