import torch
import networkx as nx
import numpy as np
from scipy.sparse.csgraph import laplacian
from scipy.linalg import eig
import cvxpy as cp
from numpy.linalg import det, eigvals


def check_inf_valid(func):
    """Decorate a metric method so that an infinite result is reported as -1.

    Both ``+inf`` and ``-inf`` are replaced by the sentinel ``-1``; every other
    return value is passed through unchanged.
    """
    from functools import wraps

    # wraps keeps the decorated method's __name__/__doc__ for introspection.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        value = func(self, *args, **kwargs)
        if value == float("inf") or value == float("-inf"):
            return -1
        return value

    return wrapper

class Graph:
    """Compute a fixed set of structural invariants of an undirected graph.

    The graph is supplied as a square adjacency matrix. Each metric method is
    wrapped in ``check_inf_valid``, so an infinite result comes back as -1.
    Some methods also return -1 themselves when the metric is undefined for
    the input (e.g. the diameter/radius of a disconnected graph).
    """

    # Metric methods, in the order ``result`` and ``test`` report them.
    _METRIC_NAMES = (
        "size",
        "order",
        "fiedler_value",
        "diameter",
        "estrada_index",
        "fractional_chromatic_number",
        "hyper_wiener_index",
        "lovasz_number",
        "parry_sullivan_invariant",
        "radius",
        "randic_index",
        "rank",
        "splittance",
        "strength",
        "wiener_index",
    )

    def __init__(self, adj_mat):
        """Store the adjacency matrix and build the matching networkx graph.

        Args:
            adj_mat: Square array-like adjacency matrix. ``nx.from_numpy_array``
                turns every non-zero entry into an edge of ``self.graph``.
                The LP/SDP-based methods and ``randic_index`` only treat
                entries equal to 1 as edges.
        """
        self.adj_mat = np.array(adj_mat)
        self.graph = nx.from_numpy_array(self.adj_mat)

    def _is_connected(self):
        """Return True if the graph is connected; the null graph is not."""
        # nx.is_connected raises NetworkXPointlessConcept on a graph with no
        # nodes, so short-circuit that case instead of crashing.
        return self.graph.number_of_nodes() > 0 and nx.is_connected(self.graph)

    @check_inf_valid
    def size(self):
        """Return the number of edges."""
        return self.graph.size()

    @check_inf_valid
    def order(self):
        """Return the number of vertices."""
        return self.graph.order()

    @check_inf_valid
    def fiedler_value(self):
        """Return the algebraic connectivity (second-smallest Laplacian eigenvalue).

        Returns -1 for graphs with fewer than two vertices, where no second
        eigenvalue exists.
        """
        if self.adj_mat.shape[0] < 2:
            return -1
        lap = laplacian(self.adj_mat)
        # The Laplacian of an undirected graph is symmetric, so eigvalsh
        # returns real eigenvalues already sorted ascending: no complex
        # round-off to strip and no argsort needed.
        eigenvalues = np.linalg.eigvalsh(lap)
        return eigenvalues[1]

    @check_inf_valid
    def diameter(self):
        """Return the diameter, or -1 if the graph is empty or disconnected."""
        if not self._is_connected():
            return -1
        return nx.diameter(self.graph)

    @check_inf_valid
    def estrada_index(self):
        """Return ``nx.estrada_index`` of the graph, clamped to [-1000, 1000]."""
        value = nx.estrada_index(self.graph)
        # Clamp so the feature stays bounded for large graphs.
        if value > 1000:
            return 1000
        if value < -1000:
            return -1000
        return value

    @check_inf_valid
    def fractional_chromatic_number(self):
        """Solve the LP  min sum(x)  s.t.  x_i + x_j >= 1 for each edge i<j,  x >= 0.

        Only upper-triangle entries equal to 1 count as edges. Returns 0.0 for
        a graph with no vertices.

        NOTE(review): despite the method name, this LP is the fractional
        relaxation of minimum vertex cover (the fractional vertex cover
        number), not the fractional chromatic number. The latter needs one
        weight per independent set, which is exponentially many in general.
        The value is left unchanged so existing feature vectors stay
        comparable; rename or replace it deliberately.
        """
        adj_matrix = self.adj_mat
        n = len(adj_matrix)
        if n == 0:
            return 0.0

        # One non-negative weight per vertex.
        x = cp.Variable(n)
        constraints = [x >= 0]
        # Every edge must be "covered": its two endpoint weights sum to >= 1.
        rows, cols = np.nonzero(np.triu(adj_matrix == 1, k=1))
        constraints += [x[i] + x[j] >= 1 for i, j in zip(rows, cols)]

        prob = cp.Problem(cp.Minimize(cp.sum(x)), constraints)
        return prob.solve()

    @check_inf_valid
    def hyper_wiener_index(self):
        """Return the hyper-Wiener index  WW = 1/2 * sum_{u<v} (d(u,v) + d(u,v)^2).

        Distances are unweighted shortest-path lengths, matching
        ``nx.wiener_index`` as called by ``wiener_index``. A disconnected
        graph has infinite distances and yields inf, which
        ``check_inf_valid`` reports as -1. An empty or single-vertex graph
        gives 0.0.
        """
        n = self.graph.number_of_nodes()
        total = 0
        for _, lengths in nx.all_pairs_shortest_path_length(self.graph):
            if len(lengths) < n:
                # Some vertex is unreachable from this source.
                return float("inf")
            total += sum(d + d * d for d in lengths.values())
        # Every unordered pair was visited twice (once from each endpoint;
        # the zero self-distance adds nothing). Dividing by 4 removes that
        # double count and applies the formula's own factor of 1/2.
        return total / 4

    @check_inf_valid
    def lovasz_number(self):
        """Return the Lovász theta number via its semidefinite program.

        Solves  max sum(B)  s.t.  trace(B) = 1,  B_ij = 0 for every edge ij,
        B PSD. Off-diagonal entries equal to 1 (in either triangle) count as
        edges; diagonal entries (self-loops) are ignored, since theta is
        defined for simple graphs. Returns 0.0 for a graph with no vertices.
        An infeasible solve yields -inf, which ``check_inf_valid`` reports
        as -1.
        """
        adjacency_matrix = self.adj_mat
        n = adjacency_matrix.shape[0]
        if n == 0:
            return 0.0

        B = cp.Variable((n, n), symmetric=True)
        constraints = [B >> 0, cp.trace(B) == 1]
        # B is symmetric, so pinning the upper triangle pins both sides.
        edge_mask = np.triu(
            (adjacency_matrix == 1) | (adjacency_matrix.T == 1), k=1
        )
        rows, cols = np.nonzero(edge_mask)
        constraints += [B[i, j] == 0 for i, j in zip(rows, cols)]

        problem = cp.Problem(cp.Maximize(cp.sum(B)), constraints)
        return problem.solve()

    @check_inf_valid
    def parry_sullivan_invariant(self):
        """Return det(I - A) for the adjacency matrix A, clamped to [-1000, 1000]."""
        A = self.adj_mat
        det_i_minus_a = np.linalg.det(np.eye(A.shape[0]) - A)
        if det_i_minus_a > 1000:
            return 1000
        if det_i_minus_a < -1000:
            return -1000
        return det_i_minus_a

    @check_inf_valid
    def radius(self):
        """Return the radius, or -1 if the graph is empty or disconnected."""
        if not self._is_connected():
            return -1
        return nx.radius(self.graph)

    @check_inf_valid
    def randic_index(self):
        """Return the Randić index: sum over edges ij of 1 / sqrt(deg(i) * deg(j)).

        Edges are the upper-triangle entries equal to 1. Degrees are row
        sums of the adjacency matrix.
        """
        adjacency_matrix = self.adj_mat
        degrees = np.sum(adjacency_matrix, axis=1)
        rows, cols = np.nonzero(np.triu(adjacency_matrix == 1, k=1))
        return float(np.sum(1.0 / np.sqrt(degrees[rows] * degrees[cols])))

    @check_inf_valid
    def rank(self):
        """Return the matrix rank of the adjacency matrix."""
        return np.linalg.matrix_rank(self.adj_mat)

    @check_inf_valid
    def splittance(self):
        """Return the splittance: the fewest edge insertions/deletions needed
        to turn the graph into a split graph (0 means it already is split).

        Uses the Hammer–Simeone formula. With degrees sorted
        d_1 >= ... >= d_n and m = max{i : d_i >= i - 1}:
            sigma = (m(m-1) - sum_{i<=m} d_i + sum_{i>m} d_i) / 2.
        Any non-zero off-diagonal entry (in either triangle) is an edge;
        self-loops are ignored. Returns 0 for a graph with no vertices.
        """
        adj = self.adj_mat != 0
        if adj.ndim != 2 or adj.shape[0] == 0:
            return 0
        # Symmetrise to match the undirected networkx graph, then drop loops.
        adj = adj | adj.T
        np.fill_diagonal(adj, False)
        n = adj.shape[0]

        degrees = np.sort(adj.sum(axis=1))[::-1]
        # d_i - (i - 1) is strictly decreasing in i, so the indices with
        # d_i >= i - 1 form a prefix and m is simply their count.
        m = int(np.count_nonzero(degrees >= np.arange(n)))
        head = int(degrees[:m].sum())
        tail = int(degrees[m:].sum())
        # m(m-1) is even, and tail - head = 2|E| - 2*head is even,
        # so the integer division is exact.
        return (m * (m - 1) - head + tail) // 2

    @check_inf_valid
    def strength(self):
        """Return half the sum of all adjacency entries (each edge is stored twice)."""
        return np.sum(self.adj_mat) / 2

    @check_inf_valid
    def wiener_index(self):
        """Return ``nx.wiener_index`` of the graph."""
        return nx.wiener_index(self.graph)

    def result(self):
        """Return all metrics as a NumPy array, in ``_METRIC_NAMES`` order."""
        return np.array([getattr(self, name)() for name in self._METRIC_NAMES])

    def test(self):
        """Print the networkx graph summary followed by each metric, one per line."""
        print(self.graph)
        for name in self._METRIC_NAMES:
            print(name, getattr(self, name)())
