# Copied and edited from https://github.com/jctops/understanding-oversquashing/tree/main
# Original author: Jake Topping and Francesco Di Giovanni and Benjamin Paul Chamberlain and Xiaowen Dong and Michael M. Bronstein
# Description: This class implements SDRF as described in [Understanding over-squashing and bottlenecks on graphs via curvature, 2021]
import math
import os
import pathlib
from typing import Any

from numba import cuda
import numpy as np
import torch
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import (
    to_networkx,
    from_networkx,
    to_dense_adj,
    remove_self_loops,
    to_undirected,
)

# Turn off numba warnings
import warnings
from numba.core.errors import NumbaPerformanceWarning
warnings.filterwarnings("ignore", category=NumbaPerformanceWarning)


from src.utils.path_io import get_path_up_to

# Base directory under which SDRF stores its cached rewired graphs
# (SDRF joins 'data/graphs/<graph_name>/sdrf' onto it).
# NOTE(review): presumably the path of this file truncated at its 'src'
# component -- confirm against src.utils.path_io.get_path_up_to.
ROOT_PATH = get_path_up_to(__file__, "src")

def softmax(a, tau=1):
    """Return the softmax of ``a * tau`` as a probability vector.

    The scores are shifted by their maximum before exponentiation. The shift
    cancels in the normalisation, so the result is mathematically unchanged,
    but large scores (or a large ``tau``) no longer overflow to ``inf`` and
    produce ``nan`` probabilities.

    Args:
        a: Array-like of scores.
        tau: Multiplier applied to the scores before exponentiation.

    Returns:
        ``numpy.ndarray`` of the same shape as ``a`` whose entries sum to 1.
        An empty input yields an empty array.
    """
    scaled = np.asarray(a) * tau
    if scaled.size == 0:
        # Nothing to normalise; avoid calling max() on an empty array.
        return np.exp(scaled)
    exp_a = np.exp(scaled - scaled.max())
    return exp_a / exp_a.sum()


@cuda.jit(
    "void(float32[:,:], float32[:,:], float32[:], float32[:], int32, float32[:,:])"
)
def _balanced_forman_curvature(A, A2, d_in, d_out, N, C):
    """CUDA kernel: write a curvature value into ``C[i, j]`` for each pair ``(i, j)``.

    Each thread handles the single ``(i, j)`` pair given by ``cuda.grid(2)``.
    Pairs with ``A[i, j] == 0`` or with a zero degree product are assigned 0.

    Args:
        A: N x N float32 matrix (the host wrapper passes the adjacency matrix).
        A2: N x N float32 matrix (the host wrapper passes ``A @ A``).
        d_in: Length-N float32 vector, read at index ``i``.
        d_out: Length-N float32 vector, read at index ``j``.
        N: Number of rows/columns to process.
        C: N x N float32 output matrix, written in place.
    """
    i, j = cuda.grid(2)

    # Threads that fall outside the N x N matrix do nothing.
    if (i < N) and (j < N):
        # No edge (i, j): curvature is defined as 0 here.
        if A[i, j] == 0:
            C[i, j] = 0
            return

        # Larger / smaller of the two endpoint degrees.
        if d_in[i] > d_out[j]:
            d_max = d_in[i]
            d_min = d_out[j]
        else:
            d_max = d_out[j]
            d_min = d_in[i]

        # Guard against division by zero in the formula below.
        if d_max * d_min == 0:
            C[i, j] = 0
            return

        # sharp_ij counts the positive per-node contributions found in the
        # loop; lambda_ij tracks the largest such contribution.
        sharp_ij = 0
        lambda_ij = 0
        for k in range(N):
            # Contribution of node k through column j of A and row i of (A2 - A).
            TMP = A[k, j] * (A2[i, k] - A[i, k]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            # Mirror-image contribution through row i of A and column j of (A2 - A).
            TMP = A[i, k] * (A2[k, j] - A[k, j]) * A[i, j]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        # Degree terms plus the A2[i, j] * A[i, j] term.
        C[i, j] = (
                (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2[i, j] * A[i, j]
        )
        # The sharp/lambda term is added only when lambda_ij is positive,
        # which also avoids dividing by zero.
        if lambda_ij > 0:
            C[i, j] += sharp_ij / (d_max * lambda_ij)


def balanced_forman_curvature(A, C=None):
    """Launch the curvature kernel for every entry of ``A`` and return the result.

    Args:
        A: Square adjacency tensor; it is handed directly to the CUDA kernel.
        C: Optional N x N output tensor that is overwritten in place. When
            omitted, a zero tensor is allocated with ``.cuda()``.

    Returns:
        The tensor ``C`` holding the value computed by the kernel for each
        ``(i, j)`` pair.
    """
    num_nodes = A.shape[0]
    squared = A @ A
    col_sums = A.sum(dim=0)
    row_sums = A.sum(dim=1)

    if C is None:
        C = torch.zeros(num_nodes, num_nodes).cuda()

    # 16 x 16 threads per block; enough blocks to cover the N x N matrix.
    block_dim = (16, 16)
    grid_dim = tuple(math.ceil(num_nodes / side) for side in block_dim)

    kernel = _balanced_forman_curvature[grid_dim, block_dim]
    kernel(A, squared, col_sums, row_sums, num_nodes, C)
    return C


@cuda.jit(
    "void(float32[:,:], float32[:,:], float32, float32, int32, float32[:,:], int32, int32, int32[:], int32[:], int32, int32)"
)
def _balanced_forman_post_delta(
        A, A2, d_in_x, d_out_y, N, D, x, y, i_neighbors, j_neighbors, dim_i, dim_j
):
    """CUDA kernel: curvature of ``(x, y)`` after hypothetically adding edge ``(i, j)``.

    Each thread handles one candidate pair ``i = i_neighbors[I]``,
    ``j = j_neighbors[J]`` and writes the resulting value to ``D[I, J]``.
    ``A`` and ``A2`` are not modified; the effect of the extra edge is
    applied to local copies of the values that are read.

    Args:
        A: N x N float32 matrix (the host wrapper passes the adjacency matrix).
        A2: N x N float32 matrix (the host wrapper passes ``A @ A``).
        d_in_x: float32 scalar (the host wrapper passes ``A[:, x].sum()``).
        d_out_y: float32 scalar (the host wrapper passes ``A[y].sum()``).
        N: Number of nodes iterated over in the inner loop.
        D: dim_i x dim_j float32 output matrix, written in place.
        x: First index of the pair whose curvature is recomputed.
        y: Second index of the pair whose curvature is recomputed.
        i_neighbors: int32 vector of candidate source nodes.
        j_neighbors: int32 vector of candidate target nodes.
        dim_i: Number of entries of ``i_neighbors`` to process.
        dim_j: Number of entries of ``j_neighbors`` to process.
    """
    I, J = cuda.grid(2)

    # Threads outside the dim_i x dim_j candidate grid do nothing.
    if (I < dim_i) and (J < dim_j):
        i = i_neighbors[I]
        j = j_neighbors[J]

        # Self-loops and already-present edges are not valid additions;
        # they are marked with the sentinel value -1000.
        if (i == j) or (A[i, j] != 0):
            D[I, J] = -1000
            return

        # Difference in degree terms
        # d_in_x / d_out_y are scalar arguments, so the increment affects
        # only this thread's copy. Because of the elif, at most one of the
        # two is incremented.
        if j == x:
            d_in_x += 1
        elif i == y:
            d_out_y += 1

        # Guard against division by zero in the formula below.
        if d_in_x * d_out_y == 0:
            D[I, J] = 0
            return

        if d_in_x > d_out_y:
            d_max = d_in_x
            d_min = d_out_y
        else:
            d_max = d_out_y
            d_min = d_in_x

        # Difference in triangles term
        # Local copy of A2[x, y], increased by the contribution of the
        # hypothetical edge (i, j) when it starts at x or ends at y.
        A2_x_y = A2[x, y]
        if (x == i) and (A[j, y] != 0):
            A2_x_y += A[j, y]
        elif (y == j) and (A[x, i] != 0):
            A2_x_y += A[x, i]

        # Difference in four-cycles term
        sharp_ij = 0
        lambda_ij = 0
        for z in range(N):
            # Local copies of the entries involving z (the "+ 0" creates a
            # new value rather than referring to the matrix entry).
            A_z_y = A[z, y] + 0
            A_x_z = A[x, z] + 0
            A2_z_y = A2[z, y] + 0
            A2_x_z = A2[x, z] + 0

            # Apply the hypothetical edge (i, j) to the local copies.
            if (z == i) and (y == j):
                A_z_y += 1
            if (x == i) and (z == j):
                A_x_z += 1
            if (z == i) and (A[j, y] != 0):
                A2_z_y += A[j, y]
            if (x == i) and (A[j, z] != 0):
                A2_x_z += A[j, z]
            if (y == j) and (A[z, i] != 0):
                A2_z_y += A[z, i]
            if (z == j) and (A[x, i] != 0):
                A2_x_z += A[x, i]

            # Same two per-node contributions as in the curvature kernel,
            # evaluated on the adjusted values.
            TMP = A_z_y * (A2_x_z - A_x_z) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

            TMP = A_x_z * (A2_z_y - A_z_y) * A[x, y]
            if TMP > 0:
                sharp_ij += 1
                if TMP > lambda_ij:
                    lambda_ij = TMP

        # Curvature formula for (x, y) using the adjusted degree and
        # triangle values.
        D[I, J] = (
                (2 / d_max) + (2 / d_min) - 2 + (2 / d_max + 1 / d_min) * A2_x_y * A[x, y]
        )
        # The sharp/lambda term is added only when lambda_ij is positive,
        # which also avoids dividing by zero.
        if lambda_ij > 0:
            D[I, J] += sharp_ij / (d_max * lambda_ij)


def balanced_forman_post_delta(A, x, y, i_neighbors, j_neighbors, D=None):
    """Launch the kernel that re-evaluates the curvature of ``(x, y)`` per candidate edge.

    Entry ``D[I, J]`` corresponds to adding the edge
    ``(i_neighbors[I], j_neighbors[J])``.

    Args:
        A: Square adjacency tensor; it is handed directly to the CUDA kernel.
        x: First index of the pair whose curvature is re-evaluated.
        y: Second index of the pair whose curvature is re-evaluated.
        i_neighbors: Sequence of candidate source nodes.
        j_neighbors: Sequence of candidate target nodes.
        D: Optional output tensor of shape
            ``(len(i_neighbors), len(j_neighbors))`` overwritten in place.
            When omitted, a zero tensor is allocated with ``.cuda()``.

    Returns:
        The tensor ``D`` filled by the kernel.
    """
    num_nodes = A.shape[0]
    squared = A @ A
    column_sum_x = A[:, x].sum()
    row_sum_y = A[y].sum()

    if D is None:
        D = torch.zeros(len(i_neighbors), len(j_neighbors)).cuda()

    n_rows, n_cols = D.shape[0], D.shape[1]

    # 16 x 16 threads per block; enough blocks to cover every entry of D.
    block_dim = (16, 16)
    grid_dim = (
        math.ceil(n_rows / block_dim[0]),
        math.ceil(n_cols / block_dim[1]),
    )

    kernel = _balanced_forman_post_delta[grid_dim, block_dim]
    kernel(
        A, squared, column_sum_x, row_sum_y, num_nodes, D, x, y,
        np.array(i_neighbors), np.array(j_neighbors), n_rows, n_cols,
    )
    return D


class SDRF(BaseTransform):
    """Graph transform that rewires ``data.edge_index`` with SDRF.

    On each iteration the pair with the lowest curvature is located, and one
    edge between the neighbourhoods of its endpoints is added. The edge is
    sampled with probabilities given by a softmax over the curvature
    improvement of each candidate. Optionally, the pair with the highest
    curvature is removed when that curvature exceeds ``removal_bound``.

    Results are cached on disk under ``<ROOT_PATH>/data/graphs/<graph_name>/sdrf``.
    The cache file name contains the hyper-parameters and a running
    ``graph_index`` that is incremented on every ``forward`` call, so graphs
    must be presented in the same order to hit the right cache entries.

    NOTE: ``is_undirected`` is not part of the cache file name, so runs that
    differ only in that flag share cache files.

    Args:
        graph_name: Name of the sub-directory used for the cache.
        n_loops: Maximum number of rewiring iterations.
        remove_edges: Whether to attempt an edge removal on each iteration.
        removal_bound: An edge is removed only if its curvature is strictly
            greater than this value.
        tau: Multiplier applied to the improvements inside the softmax.
        is_undirected: Treat the graph as undirected (edges are symmetrised
            and added/removed in both directions).
    """

    def __init__(
            self,
            graph_name: str,
            n_loops=10,
            remove_edges=True,
            removal_bound=0.5,
            tau=1,
            is_undirected=False,
    ):
        self.n_loops = n_loops
        self.remove_edges = remove_edges
        self.removal_bound = removal_bound
        self.tau = tau
        self.is_undirected = is_undirected
        # Running counter of processed graphs; part of the cache file names.
        self.graph_index = 0
        self.dirname = os.path.join(ROOT_PATH, 'data','graphs', graph_name, 'sdrf')

    def _cache_paths(self):
        """Return the (edge_index, edge_type) cache file paths for the current graph."""
        prefix = (
            f'iters_{self.n_loops}_tau_{self.tau}_remove_{self.remove_edges}'
            f'_bound_{self.removal_bound}'
        )
        return (
            os.path.join(self.dirname, f'{prefix}_edge_index_{self.graph_index}.pt'),
            os.path.join(self.dirname, f'{prefix}_edge_type_{self.graph_index}.pt'),
        )

    def forward(self, data) -> Any:
        """Rewire ``data`` (or load the cached rewiring) and return it.

        Sets ``data.edge_index`` to the rewired edges and ``data.edge_type``
        to an all-zero long tensor with one entry per edge, then increments
        ``self.graph_index``.
        """
        pathlib.Path(self.dirname).mkdir(parents=True, exist_ok=True)
        edge_index_filename, edge_type_filename = self._cache_paths()

        # Reuse a previously rewired graph when both cache files exist.
        if os.path.exists(edge_index_filename) and os.path.exists(edge_type_filename):
            print(f'[SDRF]: Load graph {self.graph_index}...')
            edge_index = torch.load(edge_index_filename, weights_only=False)
            edge_type = torch.load(edge_type_filename, weights_only=False)

            # Caches written before the edge_type length fix may hold too few
            # entries for undirected graphs. edge_type is always all zeros,
            # so rebuilding it loses nothing.
            if edge_type.numel() != edge_index.size(1):
                edge_type = torch.zeros(edge_index.size(1), dtype=torch.long)

            data.edge_index = edge_index
            data.edge_type = edge_type
            self.graph_index += 1
            return data

        print(f'[SDRF]: Processing graph {self.graph_index}...')
        edge_index = data.edge_index
        if self.is_undirected:
            edge_index = to_undirected(edge_index)
        # Dense adjacency without self-loops; G mirrors it as a networkx graph
        # and is used for neighbourhood queries.
        A = to_dense_adj(remove_self_loops(edge_index)[0])[0]
        N = A.shape[0]
        G = to_networkx(data)
        if self.is_undirected:
            G = G.to_undirected()
        A = A.cuda()
        C = torch.zeros(N, N).cuda()

        for _ in range(self.n_loops):
            can_add = True
            balanced_forman_curvature(A, C=C)

            # Pair with the lowest curvature (C is N x N, flattened row-major).
            x, y = divmod(C.argmin().item(), N)

            if self.is_undirected:
                x_neighbors = list(G.neighbors(x)) + [x]
                y_neighbors = list(G.neighbors(y)) + [y]
            else:
                x_neighbors = list(G.successors(x)) + [x]
                y_neighbors = list(G.predecessors(y)) + [y]

            candidates = [
                (i, j)
                for i in x_neighbors
                for j in y_neighbors
                if (i != j) and (not G.has_edge(i, j))
            ]

            if candidates:
                D = balanced_forman_post_delta(A, x, y, x_neighbors, y_neighbors)
                # Improvement of the (x, y) curvature for every candidate,
                # computed once and moved to the host in a single transfer.
                gain = (D - C[x, y]).cpu()

                # First position of each node, matching list.index semantics.
                row_of = {}
                for pos, node in enumerate(x_neighbors):
                    row_of.setdefault(node, pos)
                col_of = {}
                for pos, node in enumerate(y_neighbors):
                    col_of.setdefault(node, pos)

                improvements = [
                    gain[row_of[i], col_of[j]].item() for (i, j) in candidates
                ]

                choice = np.random.choice(
                    len(candidates),
                    p=softmax(np.array(improvements), tau=self.tau),
                )
                k, l = candidates[choice]
                G.add_edge(k, l)
                if self.is_undirected:
                    A[k, l] = A[l, k] = 1
                else:
                    A[k, l] = 1
            else:
                can_add = False
                if not self.remove_edges:
                    break

            if self.remove_edges:
                # Uses the curvature computed at the start of this iteration.
                x, y = divmod(C.argmax().item(), N)
                if C[x, y] > self.removal_bound:
                    G.remove_edge(x, y)
                    if self.is_undirected:
                        A[x, y] = A[y, x] = 0
                    else:
                        A[x, y] = 0
                elif not can_add:
                    # Nothing was added and nothing can be removed: converged.
                    break

        edge_index = from_networkx(G).edge_index
        # One type entry per column of edge_index. For an undirected G,
        # from_networkx emits both directions, so len(G.edges) would be too
        # small.
        edge_type = torch.zeros(edge_index.size(1), dtype=torch.long)

        with open(edge_index_filename, 'wb') as f:
            torch.save(edge_index, f)

        with open(edge_type_filename, 'wb') as f:
            torch.save(edge_type, f)

        # Apply new edges to data object
        data.edge_index = edge_index
        data.edge_type = edge_type

        self.graph_index += 1

        return data
