import torch
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.stats import multivariate_normal as mvn
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA


def sbm_prob_mat(B, cluster_sizes):
    """Expand a block matrix into a full node-by-node matrix.

    Nodes are laid out cluster after cluster, so cluster ``r`` occupies a
    contiguous index range whose length is ``cluster_sizes[r]``.  Entry
    ``(u, v)`` of the result equals ``B[r, c]`` where ``r`` and ``c`` are the
    clusters containing ``u`` and ``v``.

    Args:
        B: indexable as ``B[r, c]``; one value per pair of clusters.
        cluster_sizes: sequence (or 1-D tensor) of cluster sizes.

    Returns:
        A ``torch.zeros``-allocated ``(n, n)`` tensor filled block by block,
        where ``n`` is the sum of ``cluster_sizes``.
    """
    # Running totals: bounds[r] and bounds[r + 1] delimit cluster r.
    bounds = [0]
    for size in cluster_sizes:
        bounds.append(size + bounds[-1])

    n = bounds[-1]
    P = torch.zeros((n, n))

    spans = list(zip(bounds[:-1], bounds[1:]))
    for r, (row_lo, row_hi) in enumerate(spans):
        for c, (col_lo, col_hi) in enumerate(spans):
            P[row_lo:row_hi, col_lo:col_hi] = B[r, c]

    return P


def ier_sampler(P):
    """Build a zero-argument sampler of 0/1 matrices with entry probabilities ``P``.

    Every call of the returned function draws a fresh ``(n, n)`` matrix of
    uniforms on ``P``'s device (``n = P.shape[0]``) and returns, as floats,
    the indicator that each uniform falls strictly below the matching entry
    of ``P``.  All entries are drawn independently, so the result is in
    general not symmetric.
    """
    n = P.shape[0]
    device = P.device

    def draw():
        uniforms = torch.rand((n, n), device=device)
        return (uniforms < P).float()

    return draw


def er_phi():
    """Histogram centred k-hop features on sparse Erdos-Renyi graphs.

    For every graph size in ``n_list`` and expected degree ``nu`` in
    ``d_list`` the function draws ``n_instances`` random 0/1 matrices with
    edge probability ``nu / n`` (diagonal zeroed), applies ``A / nu`` to the
    all-ones vector ``k_max`` times and keeps a running mean of the result
    over the draws.  The last draw minus that mean, both scaled by
    ``sqrt(nu)``, is histogrammed on a grid (rows: degree, columns: graph
    size) and the figure is written to ``nu_CLT.png``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    k_max = 3
    n_list = [300, 3_000, 30_000]
    d_list = [2, 4, 16]
    n_instances = 250

    # One (degree, node) table per graph size: the last draw and the running mean.
    phi_coll = [torch.zeros((len(d_list), n), device=device) for n in n_list]
    phi_mean_coll = [torch.zeros((len(d_list), n), device=device) for n in n_list]

    for i, n_nodei in enumerate(n_list):
        X = torch.ones(n_nodei).to(device)
        for j, nu_j in enumerate(d_list):
            Bj = torch.ones((1, 1)) * nu_j / n_nodei
            Pj = sbm_prob_mat(Bj, torch.tensor([n_nodei]))
            Aj_gen = ier_sampler(Pj.to(device))
            for m in tqdm(range(n_instances)):
                # The sampler returns a fresh tensor, so normalising in place
                # avoids allocating another n x n matrix on every draw.
                Aj = Aj_gen().fill_diagonal_(0).div_(nu_j)

                phi = X
                for _ in range(k_max):
                    phi = Aj @ phi
                phi_coll[i][j] = phi
                # Incremental mean over the first m + 1 draws.
                phi_mean_coll[i][j] = phi_mean_coll[i][j] * (m / (m + 1)) + phi / (m + 1)

            phi_coll[i][j] = phi_coll[i][j] * np.sqrt(nu_j)
            phi_mean_coll[i][j] = phi_mean_coll[i][j] * np.sqrt(nu_j)
        phi_coll[i] = phi_coll[i].cpu()
        phi_mean_coll[i] = phi_mean_coll[i].cpu()

    n_rows, n_cols = len(d_list), len(n_list)
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 5))

    # Hand-tuned histogram bins, indexed as binsij[degree row][size column].
    binsij = [[np.linspace(-2, 4.75, 13), np.linspace(-2, 4.75, 28), np.linspace(-2, 4.75, 42)],
              [np.linspace(-2.25, 4.5, 13), np.linspace(-2.25, 4.5, 28), np.linspace(-2.25, 4.5, 42)],
              [np.linspace(-3, 3.25, 13), np.linspace(-3, 3.25, 28), np.linspace(-3, 3.25, 42)]]
    # y ticks for the top row, indexed by size column.
    top_row_yticks = [[0, 0.15, 0.30, 0.45], [0, 0.25, 0.5, 0.75], [0, 0.25, 0.5, 0.75]]

    for i in range(n_cols):
        for j in range(n_rows):
            ax = axs[j][i]
            cent_feats = (phi_coll[i][j] - phi_mean_coll[i][j]).cpu().numpy()
            ax.hist(cent_feats, density=True, bins=binsij[j][i], alpha=0.7,
                    label=rf'$\nu_n = {d_list[j]}$', color=f'C{j}')
            ax.legend()
            if j == 1:
                ax.set_ylim((0, 0.41))
            if j == 2:
                ax.set_xticks([-3, -1.5, 0, 1.5, 3], [-3, -1.5, 0, 1.5, 3])
                ax.set_xlabel(r'$\xi^{(k)}$', fontsize=12)
            if j == 0:
                ax.set_title(rf'$n={n_list[i]}$', fontsize=14)
                ax.set_yticks(top_row_yticks[i], [str(e) for e in top_row_yticks[i]])
            if i == 0:
                ax.set_ylabel('Density', fontsize=12)

    plt.tight_layout()
    plt.savefig('nu_CLT.png', dpi=300, bbox_inches='tight')


def sbm_plot():
    """Compare empirical and theoretical feature fluctuations on a 2-class SBM.

    Draws ``n_instances`` graphs from a two-block model with expected degree
    scale ``nu = sqrt(n_nodes)``, propagates noisy class-mean features
    ``k_max`` times through ``A / nu`` (diagonal zeroed) and histograms
    ``sqrt(nu) * (phi - mean over draws)`` pooled over all nodes.  A two
    component zero-mean Gaussian mixture with the class proportions as
    weights is overlaid, and the figure is saved to ``SBM_high_res.png``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    k_max = 3
    n_nodes = 32_000

    a = 10
    b = 0.1
    props = torch.Tensor([0.9, 0.1])

    nu = np.sqrt(n_nodes)
    B0 = torch.Tensor([[a, b], [b, a]])
    B = B0 * nu / n_nodes

    mu = torch.Tensor([1, -1]) / 100

    cluster_sizes = torch.tensor([int(n_nodes * prop) for prop in props])
    # int() truncates, so hand any rounding remainder to the last cluster;
    # otherwise y would be shorter than n_nodes and the one-hot fill below fails.
    cluster_sizes[-1] += (n_nodes - cluster_sizes.sum().item())
    y = torch.cat([torch.full((int(size),), label) for label, size in enumerate(cluster_sizes)])

    torch.manual_seed(43031)

    L = B.shape[0]
    Z = torch.zeros((n_nodes, L))  # one-hot class membership
    Z[torch.arange(n_nodes), y] = 1

    P = sbm_prob_mat(B, cluster_sizes)
    A_gen = ier_sampler(P.to(device))

    J = B0 @ torch.diag(props)
    # w[k, l]: std of the sqrt(nu)-scaled (k + 1)-hop feature for a class-l node.
    # Its variance is sum_c J[l, c] * m_c**2 with m = J^k mu, i.e. the 2-norm of
    # the ELEMENTWISE product sqrt(J[l]) * m.  (A matrix product there would
    # collapse to a dot product and give |sum_c sqrt(J[l, c]) m_c| instead.)
    w = torch.stack([
        torch.linalg.norm(torch.sqrt(J) * (torch.matrix_power(J, k) @ mu), dim=1)
        for k in range(k_max)
    ])

    n_instances = 100
    phi_hist = np.zeros((n_instances, n_nodes), dtype=np.float32)
    for i in tqdm(range(n_instances)):
        # Fresh tensor from the sampler, so in-place normalisation is safe.
        A = A_gen().fill_diagonal_(0).div_(float(nu))
        X = (Z @ mu + torch.randn(n_nodes) * 0.01).to(device)

        phi = X
        for _ in range(k_max):
            phi = A @ phi
        phi_hist[i] = phi.cpu().numpy()

    # Centre each node by its mean over the draws, then pool all nodes and draws.
    centered_features = np.sqrt(nu) * (phi_hist - phi_hist.mean(axis=0, keepdims=True)).flatten()

    plt.figure(figsize=(9, 4))
    n_bins = 100
    plt.hist(centered_features, bins=np.linspace(-10, 10, n_bins), density=True, alpha=0.7, label='Empirical')

    plt.title('Distribution of Poly-GNN Features for 2-Class SBM', fontsize=16)
    plt.xlabel(r'$\xi^{(k)}$', fontsize=14)
    plt.ylabel('Density', fontsize=14)

    x_axis = np.linspace(-10, 10, n_bins * 4)
    class_stds = w[k_max - 1].tolist()
    mixture_gaussian = sum(
        float(weight) * st.norm.pdf(x_axis, 0, std)
        for weight, std in zip(props, class_stds)
    )
    plt.plot(x_axis, mixture_gaussian, label='Theoretical')

    plt.grid(True)
    plt.legend(fontsize=14)
    plt.gca().tick_params(axis='both', which='major', labelsize=12)
    plt.gca().tick_params(axis='both', which='minor', labelsize=10)
    plt.savefig('SBM_high_res.png', dpi=300, bbox_inches='tight')


def _sgd_gradient_trace(layer, phi, y, lr, epochs):
    """Run full-batch SGD on ``layer`` and record what is needed to replay it.

    Returns:
        ``(w_init, grads)`` as NumPy arrays: the weight matrix before the
        first step and the weight gradient of every step, stacked along
        axis 0.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(layer.parameters(), lr=lr)

    # clone(): on a CPU device detach().cpu() returns a view of the live
    # parameter, which optimizer.step() would then overwrite in place.
    w_init = layer.weight.detach().cpu().clone().numpy()
    grads = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(layer(phi), y)
        loss.backward()
        # Copy for the same reason: the grad buffer may be reused across steps.
        grads.append(layer.weight.grad.detach().cpu().clone().numpy())
        optimizer.step()
    return w_init, np.array(grads)


def grad_paths():
    """Plot SGD weight paths for graph features versus Gaussian surrogates.

    A linear 3-class classifier is trained twice from the same seeded
    initialisation: once on ``k_max``-hop features of a sampled 3-block SBM
    (labelled P in the figure) and once on samples drawn from per-class
    Gaussians with mean ``J^k_max mu`` and covariance
    ``mu_k^T diag(J[l] / nu) mu_k`` (labelled G).  The first panel shows the
    Gaussian mixture density; the remaining panels replay the SGD steps of
    each weight row.  The figure is saved to ``gradient_paths.png``.

    Note: the Gaussian samples come from SciPy's ``rvs`` without an explicit
    ``random_state``, so they are not controlled by ``torch.manual_seed``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    k_max = 2
    n_nodes = 8192

    a = 0.4
    b = 1.
    props = torch.Tensor([0.25, 0.45, 0.3])

    nu = np.sqrt(n_nodes)
    B0 = torch.Tensor([[a, b, b], [b, a, b], [b, b, a]])
    n_classes = B0.shape[0]

    mu = torch.Tensor([[2, 2], [-1, -3], [-1, 0]])
    J = B0 @ torch.diag(props)

    mu_k = torch.matrix_power(J, k_max - 1) @ mu
    mu_mn = J @ mu_k
    Sigma = torch.stack([mu_k.T @ (torch.diag(J[l] / nu) @ mu_k) for l in range(n_classes)])

    dists = [mvn(mu_mn[l].numpy(), Sigma[l].numpy()) for l in range(n_classes)]
    u1, u2 = np.mgrid[-0.15-0.125:0.15-0.125:.001, -0.3-0.5:0.3-0.5:.001]
    pos = np.dstack((u1, u2))

    lr = 1e1
    epochs = 10

    # ---- Run 1: features from a sampled graph ---------------------------
    torch.manual_seed(43031)
    W = torch.nn.Linear(2, 3, bias=True, device=device)

    B = B0 * nu / n_nodes
    cluster_sizes = torch.tensor([int(n_nodes * prop) for prop in props])
    # int() truncates; give the rounding remainder to the last cluster.
    cluster_sizes[-1] += (n_nodes - cluster_sizes.sum().item())
    y = torch.cat([torch.full((int(size),), label) for label, size in enumerate(cluster_sizes)])

    Z = torch.zeros((n_nodes, n_classes))  # one-hot class membership
    Z[torch.arange(n_nodes), y] = 1
    X = (Z @ mu).to(device) + torch.randn(n_nodes, 2).to(device) * 0.5
    y = y.to(device)

    P = sbm_prob_mat(B, cluster_sizes)
    A_gen = ier_sampler(P.to(device))

    # Apply A / nu to X k_max times instead of forming the dense matrix
    # power: same result up to rounding, without an n x n by n x n product.
    A = A_gen().div_(float(nu))
    phi = X
    for _ in range(k_max):
        phi = A @ phi
    del A

    W_init, grad_list = _sgd_gradient_trace(W, phi, y, lr, epochs)

    # ---- Run 2: Gaussian surrogate features, same initial weights -------
    torch.manual_seed(43031)
    W2 = torch.nn.Linear(2, 3, bias=True, device=device)

    # y is sorted by class, so concatenating per-class batches lines up with it.
    gauss_samples = np.concatenate([
        np.reshape(dists[l].rvs(size=int(size)), (-1, 2))
        for l, size in enumerate(cluster_sizes)
    ])
    phi_gauss = torch.tensor(gauss_samples, dtype=torch.float32, device=device)

    W2_init, grad2_list = _sgd_gradient_trace(W2, phi_gauss, y, lr, epochs)

    # ---- Figure ---------------------------------------------------------
    fig, axs = plt.subplots(1, 4, figsize=(14, 3))

    mixture_density = sum(float(weight) * dist.pdf(pos) for weight, dist in zip(props, dists))
    axs[0].contourf(u1, u2, mixture_density, 15, cmap='cividis')
    axs[0].set_title('Limiting Mixture Distribution')

    a0 = 0.8  # per-step alpha decay of the arrows
    scale = (1, 1, 0.7)
    for i in range(3):
        W_t = W_init
        W2_t = W2_init
        for t in range(epochs):
            # SGD moves against the gradient: W <- W - lr * grad.
            step = -lr * grad_list[t]
            step2 = -lr * grad2_list[t]
            if t + 1 < epochs:
                axs[i+1].arrow(W_t[i][0], W_t[i][1], step[i][0], step[i][1],
                               label=r'$\mathbb{P}$' if t == 2 else None,
                               color='green', alpha=a0**t, head_width=0.05*scale[i], length_includes_head=True)
                axs[i+1].arrow(W2_t[i][0], W2_t[i][1], step2[i][0], step2[i][1],
                               label=r'$\mathbb{G}$' if t == 2 else None,
                               color='orange', alpha=a0**t, head_width=0.05*scale[i], length_includes_head=True)
            W_t = W_t + step
            W2_t = W2_t + step2
        axs[i+1].legend(loc=4)
        axs[i+1].set_title(rf'Gradient Path for $w_{i+1}$')
    plt.savefig('gradient_paths.png', dpi=300, bbox_inches='tight')


def contour_plot():
    """Compare QDA and a trained linear classifier on 2-class SBM features.

    Samples one graph from a two-block model, computes ``k_max``-hop features
    of the noiseless class means, trains a 2-output linear layer with
    cross-entropy (loss printed every 50 steps) and fits a QDA model on the
    same features.  Three panels are drawn over a fixed grid: the analytic
    Gaussian mixture density, the negated QDA decision function, and the
    difference of the two linear scores; the last two share one colour
    scale.  The figure is saved to ``empirical_likelihood_plots.png``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    k_max = 2
    n_nodes = 32000

    a = 0.5
    b = 1
    props = torch.Tensor([0.4, 0.6])

    nu = 30
    B0 = torch.Tensor([[a, b], [b, a]])

    mu = torch.Tensor([[2, 2], [-1, -2]])
    J = B0 @ torch.diag(props)

    mu_k = torch.matrix_power(J, k_max - 1) @ mu
    mu_mn = J @ mu_k
    Sigma = torch.stack([mu_k.T @ (torch.diag(J[l] / nu) @ mu_k) for l in range(2)])

    dist1, dist2 = [mvn(mu_mn[l].numpy(), Sigma[l].numpy()) for l in range(2)]
    u1, u2 = np.mgrid[-0.3+0.175:0.3+0.175:.001, -0.35-0.175:0.35-0.175:.001]
    pos = np.dstack((u1, u2))
    class_means = mu_mn.numpy()

    torch.manual_seed(43031)

    B = B0 * nu / n_nodes
    cluster_sizes = torch.tensor([int(n_nodes * prop) for prop in props])
    # int() truncates; give the rounding remainder to the last cluster.
    cluster_sizes[-1] += (n_nodes - cluster_sizes.sum().item())
    y = torch.cat([torch.full((int(size),), label) for label, size in enumerate(cluster_sizes)])

    L = B.shape[0]
    Z = torch.zeros((n_nodes, L))  # one-hot class membership
    Z[torch.arange(n_nodes), y] = 1
    X = (Z @ mu).to(device)
    y = y.to(device)

    P = sbm_prob_mat(B, cluster_sizes)
    A_gen = ier_sampler(P.to(device))

    # Apply A / nu to X k_max times rather than forming the dense matrix
    # power: O(k n^2) instead of an n x n by n x n product, and one fewer
    # n x n temporary.  The result is the same up to rounding.
    A = A_gen().div_(nu)
    phi = X
    for _ in range(k_max):
        phi = A @ phi
    del A

    W = torch.nn.Linear(2, 2, bias=True, device=device)
    CE = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(W.parameters(), lr=5e-1)

    epochs = 1000
    for t in range(epochs):
        optimizer.zero_grad()
        out = W(phi)
        loss = CE(out, y)
        if (t+1) % 50 == 0:
            print(t+1,':', loss.item())
        loss.backward()
        optimizer.step()

    phi_np = phi.cpu().detach().numpy()

    # Evaluate both classifiers on the plotting grid.
    scores = W.cpu()(torch.Tensor(pos.reshape(-1, 2))).reshape(pos.shape).detach().numpy()
    qda_obj = QDA()
    qda_obj.fit(phi_np, y.cpu().detach().numpy())
    prob_outs = qda_obj.decision_function(pos.reshape(-1, 2)).reshape(pos.shape[:2])

    # Negated so that both fields are oriented as "class 0 minus class 1".
    qda_llr = -prob_outs
    ce_llr = scores[..., 0] - scores[..., 1]

    fig, ax = plt.subplots(1, 3, figsize=(14, 3.5))

    vmin = min(qda_llr.min(), ce_llr.min())
    vmax = max(qda_llr.max(), ce_llr.max())

    mixture_density = float(props[0]) * dist1.pdf(pos) + float(props[1]) * dist2.pdf(pos)
    ax[0].contourf(u1, u2, mixture_density, 12, cmap='cividis')
    ax[0].scatter(class_means[:, 0], class_means[:, 1], 130, 'g', marker='*', edgecolor='black', linewidth=1, label=r'Class Means')
    ax[0].set_xlabel(r'$\phi$ Component 1', fontsize=11)
    ax[0].set_ylabel(r'$\phi$ Component 2', fontsize=11)
    # NOTE(review): this panel shows the analytic Gaussian mixture, not a
    # kernel density estimate -- confirm whether the title should change.
    ax[0].set_title('Kernel Density Estimate')
    ax[0].legend(loc=4)

    im = None
    panels = ((ax[1], qda_llr, 'QDA Log-Likelihood Ratio'),
              (ax[2], ce_llr, 'CE Log-Likelihood Ratio'))
    for axis, field, title in panels:
        contours = axis.contourf(u1, u2, field, 12, cmap='plasma', vmin=vmin, vmax=vmax)
        if im is None:
            im = contours  # the colour bar is built from the first of the two panels
        # Remember the contour extent so the scatter does not widen the view.
        xlim = axis.get_xlim()
        ylim = axis.get_ylim()
        axis.scatter(phi_np[:, 0], phi_np[:, 1], 5, 'black', alpha=0.05, marker='.', label=r'Samples')
        axis.scatter(class_means[:, 0], class_means[:, 1], 130, 'g', marker='*', edgecolor='black', linewidth=1, label=r'Class Means')
        axis.set_xlim(xlim)
        axis.set_ylim(ylim)
        axis.set_xlabel(r'$\phi$ Component 1', fontsize=11)
        axis.set_title(title)
        axis.legend(loc=4)

    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.815, 0.111, 0.02, 0.77])
    fig.colorbar(im, cax=cbar_ax)
    plt.savefig('empirical_likelihood_plots.png', dpi=300, bbox_inches='tight')


def oversmooth_plot():
    """Plot feature densities of a 2-class SBM for several aggregation depths.

    For each ``k`` in ``k_vals`` a fresh graph is sampled, the noisy class
    mean features are propagated ``k`` times through ``A / nu`` and a
    Gaussian KDE of the result is drawn on a hand-picked window.  Green
    arrows point at the class means ``J^k mu``.  The figure is saved to
    ``empirical_oversmooth_plots.png``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_nodes = 32000

    a = 0.5
    b = 1
    props = torch.Tensor([0.4, 0.6])

    nu = np.sqrt(n_nodes) / 5
    B0 = torch.Tensor([[a, b], [b, a]])

    mu = torch.Tensor([[2, 2], [-1, -2]])
    J = B0 @ torch.diag(props)

    torch.manual_seed(43031)

    B = B0 * nu / n_nodes
    cluster_sizes = torch.tensor([int(n_nodes * prop) for prop in props])
    # int() truncates, so hand any rounding remainder to the last cluster;
    # otherwise y would be shorter than n_nodes and the one-hot fill below fails.
    cluster_sizes[-1] += (n_nodes - cluster_sizes.sum().item())
    y = torch.cat([torch.full((int(size),), label) for label, size in enumerate(cluster_sizes)])

    L = B.shape[0]
    Z = torch.zeros((n_nodes, L))  # one-hot class membership
    Z[torch.arange(n_nodes), y] = 1
    X = (Z @ mu).to(device) + torch.randn(n_nodes, 2).to(device) * 0.01

    P = sbm_prob_mat(B, cluster_sizes)
    A_gen = ier_sampler(P.to(device))

    k_vals = [2, 4, 6]

    # Plot window (x_left, x_right, y_low, y_high) for each k.
    bnds = [(-0.3+0.175, 0.3+0.175, -0.375-0.2, 0.375-0.2),
            (-0.06+0.085, 0.06+0.085, -0.07-0.09, 0.07-0.09),
            (-0.031+0.045, 0.031+0.045, -0.031-0.049, 0.031-0.049)]

    fig, ax = plt.subplots(1, 3, figsize=(14, 3.5))
    for i, k in enumerate(k_vals):
        # Apply A / nu to X k times rather than forming the dense matrix
        # power: O(k n^2) instead of repeated n x n by n x n products.
        A = A_gen().div_(float(nu))
        phi = X
        for _ in range(k):
            phi = A @ phi
        del A
        kernel = st.gaussian_kde(phi.cpu().numpy().T)

        mu_k = torch.matrix_power(J, k - 1) @ mu
        mu_mn = (J @ mu_k).numpy()

        xl, xr, yl, yr = bnds[i]
        u1, u2 = np.mgrid[xl:xr:.005/10**i, yl:yr:.005/10**i]
        pos = np.vstack([u1.ravel(), u2.ravel()])
        f = np.reshape(kernel(pos).T, u1.shape)
        ax[i].contourf(u1, u2, f, 12, cmap='cividis')
        xlim = ax[i].get_xlim()
        ylim = ax[i].get_ylim()
        ax[i].set_xlabel(r'$\phi$ Component 1', fontsize=11)

        # Arrow tails: the origin for the first panel; otherwise the point
        # where the line through the origin and each mean meets the left
        # (i == 1) or top (i == 2) edge of the window.
        if i == 0:
            ax[0].set_ylabel(r'$\phi$ Component 2', fontsize=11)
            x0, y0 = 0, 0
            x1, y1 = 0, 0
        elif i == 1:
            x0, x1 = xlim[0], xlim[0]
            y0, y1 = mu_mn[0][1]/mu_mn[0][0] * x0, mu_mn[1][1]/mu_mn[1][0] * x1
        elif i == 2:
            y0, y1 = ylim[1], ylim[1]
            x0, x1 = mu_mn[0][0]/mu_mn[0][1] * y0, mu_mn[1][0]/mu_mn[1][1] * y1

        an = ax[i].annotate("", xytext=(x0, y0), xy=(mu_mn[0][0], mu_mn[0][1]), arrowprops=dict(arrowstyle="->", color='green'))
        ax[i].annotate("", xytext=(x1, y1), xy=(mu_mn[1][0], mu_mn[1][1]), arrowprops=dict(arrowstyle="->", color='green'))
        ax[i].set_title(rf'$k ={k}$')
        ax[i].legend([an.arrow_patch], ['Mean Vectors'], loc=3)
    fig.suptitle(r'Mixture Density for Different Number of Aggregations $k$', y=1.05, fontsize=15)
    plt.savefig('empirical_oversmooth_plots.png', dpi=300, bbox_inches='tight')


if __name__ == '__main__':
    # Produce every figure in turn, announcing each one before it starts.
    figure_jobs = (
        ('ER plot', er_phi),
        ('SBM plot', sbm_plot),
        ('Gradient plot', grad_paths),
        ('Contour plot', contour_plot),
        ('Oversmooth plot', oversmooth_plot),
    )
    for banner, make_figure in figure_jobs:
        print(banner)
        make_figure()
