import torch as t
import matplotlib.pyplot as plt
import torch_geometric as TG
import math

def turn_off_axes_ticks(ax):
    """Remove every tick from both the x and the y axis of ``ax``.

    ``ax`` only needs to provide ``set_xticks`` and ``set_yticks``; each is
    called once with an empty list (x first, then y).
    """
    ax.set_xticks([])
    ax.set_yticks([])

# Global matplotlib text sizes; set once when the module is imported and
# therefore applied to every figure created afterwards.
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 16

def get_eig_decomp(eigvs):
    """Pair a vector of eigenvalues with a randomly drawn orthogonal basis.

    Args:
        eigvs: 1-D tensor of eigenvalues (anything else trips the assert).

    Returns:
        ``(eigvs, V)`` where ``eigvs`` is the input object itself and ``V`` is
        a square ``len(eigvs) x len(eigvs)`` tensor with the same dtype and
        device as ``eigvs``, filled by ``torch.nn.init.orthogonal_`` (so it
        consumes the global torch RNG).
    """
    assert eigvs.dim() == 1
    n = eigvs.size(0)
    # new_empty inherits dtype/device from eigvs; orthogonal_ fills in place
    # and returns the same tensor.
    basis = t.nn.init.orthogonal_(eigvs.new_empty(n, n))
    return eigvs, basis

def compute_root(C, pow):
    """Raise a square matrix to the power ``pow`` through its eigendecomposition.

    ``C`` is decomposed with ``torch.linalg.eig``. Only the real part of each
    eigenvalue is used, clamped below at zero, and raised to ``pow``; the
    matrix is then rebuilt as ``V diag(.) V^{-1}``.

    Args:
        C: square matrix (non-symmetric input is allowed).
        pow: exponent applied to the clamped eigenvalues.

    Returns:
        The real part of the rebuilt matrix. A warning is printed when the
        discarded imaginary part exceeds ``1e-6`` in absolute value.
    """
    eigvals, eigvecs = t.linalg.eig(C)
    powered = t.clamp(eigvals.real, min=0.).pow(pow)
    # Promote to complex so the product with the complex eigenvectors is valid.
    spectrum = t.diag_embed(t.complex(powered, t.zeros_like(powered)))
    rebuilt = eigvecs @ spectrum @ t.linalg.inv(eigvecs)
    worst_imag = rebuilt.imag.abs().max()
    if worst_imag > 1e-6:
        print("warning! found large imag parts in D:", worst_imag.item())
    return rebuilt.real

def _from_eig(V, eigvals):
    """Rebuild the symmetric matrix ``V diag(eigvals) V^T``."""
    return V @ t.diag(eigvals) @ V.T


def _unit_diagonal(K, eps=1e-6):
    """Scale entry ``(i, j)`` of ``K`` by ``rsqrt(K_ii * K_jj + eps)``."""
    d = K.diag()
    return K * t.rsqrt(d * d.unsqueeze(1) + eps)


def test(ell=None):
    """Plot kernels at layer ``ell`` on a random graph for several ``lambda``.

    Generates a 50-node Erdos-Renyi graph (edge probability 0.1), adds
    self-loops and normalises it as ``D^{-1/2} A D^{-1/2}``; draws a random
    positive-definite input kernel ``G0``; and builds a two-block target
    kernel ``yyT``. For every ``lambda`` in ``[0.1, 0.3, 0.5, 1.0]`` the mixed
    operator ``(1 - lambda) * A + lambda * I`` is eigendecomposed and used to
    form the kernels ``Gell_graph`` (plotted as "DKM") and ``Gnngp`` (plotted
    as "NNGP"). All kernels are rescaled by their diagonals before plotting.
    The figure is written to ``lineardkm/artefacts/alignment_ell_{ell}.pdf``.

    Side effects: reseeds the global torch RNG with 0, switches the global
    default dtype to float64, and leaves the created figure open.

    Args:
        ell: numeric layer index; used as a power of the mixed operator and in
            the exponent ``ell / (L + 1)`` with ``L = 2``.

    Raises:
        TypeError: if ``ell`` is None (the default), since it is required for
            the arithmetic below.
    """
    import os

    if ell is None:
        # Fail before touching global torch/matplotlib state; the original
        # code raised the same exception type later, from ``ell - 1``.
        raise TypeError("test() requires a numeric `ell`, e.g. test(ell=2)")

    P = 50
    Phalf = P // 2
    L = 2  # NOTE(review): presumably the network depth -- confirm

    t.manual_seed(0)
    t.set_default_dtype(t.float64)

    eye = t.eye(P)

    # --- graph: adjacency with self-loops, symmetrically normalised ---------
    edge_ixs = TG.utils.erdos_renyi_graph(P, 0.1)
    # sparse_coo_tensor replaces the legacy t.sparse.FloatTensor constructor.
    adj_sp = t.sparse_coo_tensor(edge_ixs, t.ones(edge_ixs.size(1)), (P, P))
    A_norm = adj_sp.to_dense() + eye  # add self-loops
    deg_inv_sqrt = A_norm.sum(dim=1).pow(-0.5)  # degrees ** -1/2
    A_norm = A_norm * deg_inv_sqrt[:, None] * deg_inv_sqrt[None, :]

    # --- input kernel G0 (random PSD plus jitter) and its inverse -----------
    G0 = t.randn(P, P) / math.sqrt(P)
    G0 = G0 @ G0.T + 1e-6 * eye
    L0, V0 = t.linalg.eigh(G0)
    G0 = _from_eig(V0, L0)
    G0inv = _from_eig(V0, L0.reciprocal())

    # --- target kernel: two equal blocks of ones plus jitter ----------------
    yyT = t.block_diag(t.ones(Phalf, Phalf), t.ones(Phalf, Phalf)) / Phalf + 1e-4 * eye
    Ly, Vy = t.linalg.eigh(yyT)
    yyT = _from_eig(Vy, Ly)

    # Quantities that do not depend on lambda are computed once.
    root_exponent = ell / (L + 1)
    Gell_fc = _unit_diagonal(compute_root(yyT @ G0inv, root_exponent) @ G0)
    yyT_unit = _unit_diagonal(yyT)

    lmbdas = [0.1, 0.3, 0.5, 1.0]
    graph_data = {}

    for lmbda in lmbdas:
        # Always mix from the *normalised adjacency*; A_norm must not be
        # overwritten here or the lambdas would compound across iterations.
        Alambda = A_norm * (1 - lmbda) + lmbda * eye
        La, Va = t.linalg.eigh(Alambda)

        A_lam = _from_eig(Va, La)
        Ainv = _from_eig(Va, La.reciprocal())
        AinvL = _from_eig(Va, La.pow(-L))
        Aellm1 = _from_eig(Va, La.pow(ell - 1))
        Aell = _from_eig(Va, La.pow(ell))

        C = (AinvL @ yyT @ AinvL) @ (Ainv @ G0inv @ Ainv.T)
        D = compute_root(C, root_exponent)
        Gell = Aellm1 @ D @ (A_lam @ G0 @ A_lam.T) @ Aellm1.T

        Gnngp = Aell @ G0 @ Aell.T

        graph_data[lmbda] = dict(
            Gell_graph=_unit_diagonal(Gell),
            Gnngp=_unit_diagonal(Gnngp),
            yyT=yyT_unit,
            Gell_fc=Gell_fc,
        )

    # --- plotting: one column per lambda, one for yyT, one for colorbars ----
    nkernels = len(lmbdas) + 1
    fig, axes = plt.subplots(
        2, nkernels + 1, figsize=(18, 6),
        gridspec_kw={'width_ratios': [1] * nkernels + [0.05]},
    )

    # Shared colour range over *all* plotted kernels (accumulated, so no
    # single kernel type dictates the limits).
    shown = ['Gell_graph', 'Gnngp', 'yyT']
    vmax = max(graph_data[lmbda][typ].max().item() for typ in shown for lmbda in lmbdas)
    vmin = min(graph_data[lmbda][typ].min().item() for typ in shown for lmbda in lmbdas)

    im_K_list = []
    im_yyT_list = []

    get_label = {'Gell_graph': 'DKM', 'Gnngp': 'NNGP'}
    imshow_kw = dict(cmap='viridis', interpolation='none', vmin=vmin, vmax=vmax)

    for row, typ in enumerate(['Gell_graph', 'Gnngp']):
        axes[row, 0].set_ylabel(get_label[typ])
        for col, lmbda in enumerate(lmbdas):
            data = graph_data[lmbda][typ].detach().numpy()
            im_K_list.append(axes[row, col].imshow(data, **imshow_kw))
            turn_off_axes_ticks(axes[row, col])
            if row == 0:
                axes[row, col].set_title(rf'$\lambda={lmbda}$')

        # yyT goes in the last image column of every row.
        yyT_data = graph_data[lmbdas[-1]]['yyT'].detach().numpy()
        im_yyT_list.append(axes[row, -2].imshow(yyT_data, **imshow_kw))
        turn_off_axes_ticks(axes[row, -2])
        if row == 0:
            axes[row, -2].set_title(r'$\mathbf{Y}\mathbf{Y}^T$')

    # Add colorbars
    fig.colorbar(im_K_list[-1], cax=axes[0, -1], shrink=0.8)
    fig.colorbar(im_yyT_list[-1], cax=axes[1, -1], shrink=0.8)

    plt.tight_layout()
    # makedirs also creates the missing parent directory and tolerates an
    # already existing target.
    os.makedirs('lineardkm/artefacts', exist_ok=True)
    plt.savefig(f'lineardkm/artefacts/alignment_ell_{ell}.pdf')

def expected_homophily_ratio(num_graphs=10000, num_nodes=50, edge_prob=0.5):
    """Print a Monte-Carlo estimate of the homophily ratio of random graphs.

    For each of ``num_graphs`` Erdos-Renyi graphs, nodes are split into the
    first ``num_nodes // 2`` indices and the rest. The ratio is the sum of
    lower-triangular adjacency entries whose endpoints fall in the same half,
    divided by the sum of all lower-triangular entries. The mean and standard
    deviation over all graphs are printed as ``<mean> \\pm <std>``.

    Args:
        num_graphs: number of random graphs to sample.
        num_nodes: number of nodes per graph.
        edge_prob: edge probability passed to ``erdos_renyi_graph``.

    Returns:
        None. A graph whose lower triangle sums to zero contributes NaN.
    """
    half = num_nodes // 2
    hs = []
    for _ in range(num_graphs):
        edge_ixs = TG.utils.erdos_renyi_graph(num_nodes, edge_prob)
        # sparse_coo_tensor replaces the legacy t.sparse.FloatTensor constructor.
        adj = t.sparse_coo_tensor(
            edge_ixs, t.ones(edge_ixs.size(1)), (num_nodes, num_nodes)
        ).to_dense()
        # Lower triangle only -- presumably so each undirected edge counts once.
        lower = adj.tril()
        within = lower[:half, :half].sum() + lower[half:, half:].sum()
        hs.append((within / lower.sum()).item())
    hs = t.tensor(hs)
    print(hs.mean().item(), "\\pm", hs.std().item())


if __name__ == "__main__":
    # Script entry point: print the homophily-ratio estimate.
    # To render the alignment figure instead, call: test(ell=2)
    expected_homophily_ratio()