import networkx as nx
import random
import numpy as np
import math
import warnings
import heapq
from dataloder import *
from scipy.optimize import linprog

import heapq


def PM(X0, P, iters):
    """Run ``iters`` steps of the power method with a row-vector convention.

    Each step replaces the current vector ``x`` by ``x @ P``.

    Parameters
    ----------
    X0 : array-like
        Starting vector (a list or a NumPy array).
    P : matrix
        Matrix applied on the right at every step.
    iters : int
        Number of multiplications to perform; ``0`` returns ``X0`` unchanged
        (as an array).

    Returns
    -------
    numpy.ndarray
        The vector obtained after ``iters`` multiplications by ``P``.
    """
    # Coerce once so the result is an ndarray even when no iteration runs.
    current = np.asarray(X0)
    for _ in range(iters):
        current = current @ P  # row vector times matrix
    return current


def normalize(x):
    """Scale ``x`` by its maximum entry and report its largest magnitude.

    Returns a pair ``(fac, x_n)`` where ``fac`` is ``abs(x).max()`` and
    ``x_n`` is ``x`` divided by ``x.max()``.

    Note that the two differ when the entry of largest magnitude is negative:
    the reported factor is the largest absolute value, while the division
    uses the signed maximum.
    """
    return abs(x).max(), x / x.max()

def normalize_rows(matrix):
    """Return ``matrix`` with every row divided by that row's sum.

    The row sums are kept as a column (``keepdims=True``) so the division
    broadcasts across each row. A row whose sum is zero is divided by zero,
    for which NumPy produces nan/inf entries and a RuntimeWarning.
    """
    return matrix / matrix.sum(axis=1, keepdims=True)

def personalize_adj(matrix, v, damping):
    """Build a PageRank-style transition matrix from an adjacency matrix.

    Every row of ``matrix`` is divided by its sum and then mixed with the
    teleport vector: ``damping * row + (1 - damping) * v``.

    Rows that sum to zero (nodes without outgoing weight) cannot be
    normalised; they are made to follow ``v`` entirely, so the result
    contains no NaN and such a row equals ``v``.

    Parameters
    ----------
    matrix : numpy.ndarray
        Square adjacency matrix, one row per source node.
    v : numpy.ndarray
        Teleport (personalization) vector, broadcast across the rows.
    damping : float
        Weight of the link-following part; ``1 - damping`` goes to ``v``.

    Returns
    -------
    numpy.ndarray
        The mixed transition matrix, same shape as ``matrix``.
    """
    row_sums = matrix.sum(axis=1, keepdims=True)
    dangling = row_sums == 0
    # Divide by 1 on dangling rows to avoid 0/0; those rows are overwritten below.
    mat = matrix / np.where(dangling, 1, row_sums)
    mat = np.where(dangling, v, mat)
    return damping * mat + (1 - damping) * v

class MaxHeap:
    """Max-priority queue of ``(x, y)`` pairs ordered by ``y``, built on ``heapq``.

    ``heapq`` only provides a min-heap, so each pair is stored as
    ``(-y, (x, y))``. Pairs with equal ``y`` are ordered by comparing the
    pairs themselves.
    """

    def __init__(self):
        # Internal heapq list of (-priority, pair) entries.
        self.heap = []

    def push(self, item):
        """
        Inserts the pair ``item``; its second element is used as the priority.
        """
        priority = item[1]
        heapq.heappush(self.heap, (-priority, item))

    def pop(self):
        """
        Removes the pair with the largest second element and returns it.
        """
        _, item = heapq.heappop(self.heap)
        return item

    def peek(self):
        """
        Returns the pair with the largest second element, leaving the heap
        unchanged; returns None when the heap holds nothing.
        """
        if not self.heap:
            return None
        return self.heap[0][1]

    def is_empty(self):
        """
        Tells whether the heap currently holds no pairs.
        """
        return not self.heap

# Input: adjacency matrix of G, the colors ('red' / 'blue') of the nodes, which partition V, and the budget bmax of edges to add.
# Output: the list of (red, blue) edges to add (at most bmax) and the stationary distribution used to rank them.
def algoOpt(A, colors, bmax, directed=False, powerMethod=False, pagerank=False, person=0, damping=0.75):
    """Select up to ``bmax`` red-to-blue edges to add to the graph given by ``A``.

    The graph is rebuilt from ``A``. Red nodes that still have at least one
    non-adjacent blue node are the candidates, and red nodes that are isolated
    inside the candidate red subgraph are dropped. The remaining red subgraph
    (or its largest component) is handed to ``optimize`` (undirected) or
    ``optimizeDirected`` (directed), which rank the candidates.

    Parameters
    ----------
    A : matrix exposing ``.A``
        Adjacency matrix; ``A.A`` must give the dense array
        (presumably a SciPy sparse matrix -- confirm with the caller).
    colors : sequence
        ``colors[i]`` is ``'red'`` or ``'blue'`` for node ``i``.
    bmax : int
        Maximum number of edges to select.
    directed : bool
        Build a ``DiGraph`` and use the directed optimiser.
    powerMethod : bool
        Forwarded to ``optimizeDirected`` (power method instead of an
        eigen-decomposition).
    pagerank : bool
        Directed case only: use a PageRank-style walk on the red subgraph
        instead of a plain walk on its largest strongly connected component.
    person : int
        Teleport vector used when ``pagerank`` is true:
        ``0`` uniform over the red subgraph; ``1`` equal mass per strongly
        connected component, split evenly among its nodes; ``2`` the average
        of a degree-proportional vector and the ``1`` vector.
    damping : float
        Damping factor forwarded to ``optimizeDirected`` in PageRank mode.

    Returns
    -------
    tuple
        ``(edgesToAdd, stat)`` as returned by the optimiser: the selected
        ``(red, blue)`` pairs and the stationary distribution. When no red
        candidate is left, ``([], empty array)`` is returned.

    Raises
    ------
    ValueError
        If ``directed`` and ``pagerank`` are set and ``person`` is not 0, 1 or 2.
    """
    # Fail early with a clear message; otherwise no branch below would assign
    # the result and the final return would raise UnboundLocalError.
    if directed and pagerank and person not in (0, 1, 2):
        raise ValueError("person must be 0, 1 or 2, got {!r}".format(person))

    matrixA = A.A
    if not directed:
        G = nx.from_numpy_matrix(matrixA, create_using=nx.Graph())
    else:
        G = nx.from_numpy_matrix(matrixA, create_using=nx.DiGraph())
    numOfNode = len(G.nodes)
    R = [i for i in range(numOfNode) if colors[i] == 'red']
    B = [i for i in range(numOfNode) if colors[i] == 'blue']

    # For every red node, the blue nodes it is not yet adjacent to. Red nodes
    # already linked to every blue node cannot receive a new edge and are dropped.
    r_bngbs = {r: [b for b in B if b not in G[r]] for r in R}
    r_bngbs = {r: cands for r, cands in r_bngbs.items() if len(cands) > 0}

    # Degrees are taken in the full graph, before any node is removed.
    r_degs = {r: G.degree(r) for r in r_bngbs}
    r_ids = list(r_bngbs.keys())
    G_R = nx.induced_subgraph(G, r_ids)
    # Remove candidates with no neighbour inside the red candidate subgraph.
    zero_deg = [node for node, degree in dict(G_R.degree()).items() if degree < 1]
    G.remove_nodes_from(zero_deg)

    # Nothing left to rank: no edge can be proposed.
    if G_R.number_of_nodes() == 0:
        return [], np.zeros(0)

    if not directed:
        G_R = nx.induced_subgraph(G, r_ids)
        largest = max(nx.connected_components(G_R), key=len)
        G_scc = G_R.subgraph(largest).copy()
        edgesToAdd, stat = optimize(G_scc, r_bngbs, r_degs, bmax)
        return edgesToAdd, stat

    G_R = nx.induced_subgraph(G_R, r_ids)
    # Position of every red node in G_R's iteration order; the teleport
    # vectors below are indexed this way.
    R_nodes_map = {nod: idx for idx, nod in enumerate(G_R.nodes)}

    comps = list(nx.strongly_connected_components(G_R))
    compss = sorted(comps, key=len, reverse=True)

    if not pagerank:
        # Plain random walk restricted to the largest strongly connected component.
        G_scc = G_R.subgraph(compss[0]).copy()
        r_degss = {r: G_scc.out_degree(r) for r in G_scc.nodes}
        edgesToAdd, stat = optimizeDirected(G_scc, r_bngbs, r_degss, bmax, powerMethod, pr=pagerank)
        return edgesToAdd, stat

    if person == 0:
        # Uniform teleport over the whole red subgraph.
        G_scc = G_R.copy()
        r_degss = {r: G_scc.out_degree(r) for r in G_scc.nodes}
        NN = len(G_scc.nodes)
        avec = np.ones(NN)/NN
        edgesToAdd, stat = optimizeDirected(G_scc, r_bngbs, r_degss, bmax, powerMethod, pr=pagerank, personalization=avec, damping=damping)
        return edgesToAdd, stat

    r_degss = {r: G_R.out_degree(r) for r in G_R.nodes}
    avec = np.zeros(G_R.number_of_nodes())
    n_comps = len(comps)
    if person == 1:
        # Jump according to component size: equal mass per component,
        # shared evenly by the nodes of that component.
        for comp in comps:
            for element in comp:
                avec[R_nodes_map[element]] = 1./(len(comp)*n_comps)
    else:
        # person == 2: jump according to component size and node degree.
        # norm_degs is the sum of in- plus out-degrees over G_R.
        norm_degs = sum([G_R.out_degree(r)+G_R.in_degree(r) for r in G_R.nodes])
        for comp in comps:
            for element in comp:
                avec[R_nodes_map[element]] = 0.5*(G_R.out_degree(element)+G_R.in_degree(element))/(norm_degs) + 0.5*1./(len(comp)*n_comps)
    edgesToAdd, stat = optimizeDirected(G_R, r_bngbs, r_degss, bmax, powerMethod, pr=pagerank, personalization=avec, damping=damping)
    return edgesToAdd, stat

def optimize(Graph, R_additions, R_degs, b):
    """Greedily select up to ``b`` edges on an undirected graph.

    Nodes are ranked by a marginal gain derived from the stationary
    distribution of a random walk on ``Graph`` (degree divided by the sum of
    degrees). The best node is popped from a max-heap, paired with its first
    remaining candidate endpoint, and pushed back with an updated gain while
    it still has candidates.

    Parameters
    ----------
    Graph : networkx graph
        Graph whose nodes are ranked; only ``degree()`` and ``nodes`` are used.
    R_additions : dict
        Maps each node of ``Graph`` to the list of endpoints it may be linked
        to. It is consumed in place: used endpoints are deleted and exhausted
        nodes are removed from the dict.
    R_degs : dict
        Maps each node to the degree used in the gain formula.
    b : int
        Maximum number of edges to select.

    Returns
    -------
    tuple
        ``(new_edges, pi_R)``: the selected ``(node, endpoint)`` pairs in
        selection order, and the stationary distribution as an array indexed
        by the nodes of ``Graph`` in sorted order.
    """
    new_edges = []

    # Stationary distribution on an undirected graph: d(i) / sum of degrees,
    # with the nodes taken in sorted order.
    degree_pairs = sorted(dict(Graph.degree()).items())
    pi_R_nodes = [node for node, _ in degree_pairs]
    deg_R = [deg for _, deg in degree_pairs]
    R_nodes_map = {node: idx for idx, node in enumerate(pi_R_nodes)}
    m_edges = sum(deg_R)
    pi_R = np.array([deg / m_edges for deg in deg_R])

    # x_vector[i] counts the edges already granted to the node at index i.
    x_vector = np.zeros(len(pi_R))
    M_diag = [prob * deg for prob, deg in zip(pi_R, deg_R)]

    # Initial gain of giving one more edge to each node that has candidates.
    PQ = MaxHeap()
    for node in Graph.nodes:
        if len(R_additions[node]) > 0:
            M_val = M_diag[R_nodes_map[node]]
            n_deg = R_degs[node]
            mygain = (M_val/(1.*n_deg)) - (M_val/(n_deg+1.))
            PQ.push((node, mygain))

    # Pick the edges one at a time, always from the node with the best gain.
    for _ in range(b):
        if PQ.is_empty():
            break
        node, _ = PQ.pop()
        node_id = R_nodes_map[node]
        M_val = M_diag[node_id]
        n_deg = R_degs[node]
        candidates = R_additions[node]
        new_edges.append((node, candidates[0]))
        x_vector[node_id] = x_vector[node_id] + 1
        del candidates[0]
        if len(candidates) == 0:
            # The node is now linked to every available endpoint.
            del R_additions[node]
            continue
        # Re-insert the node with the gain of its next additional edge.
        mygain = (M_val/(1.*n_deg+x_vector[node_id])) - (M_val/(n_deg+1.+x_vector[node_id]))
        PQ.push((node, mygain))
    return new_edges, pi_R

def optimizeDirected(Graph, R_additions, R_degs, b, powerMethod, pr=False, damping=0.75, personalization=None):
    """Greedily select up to ``b`` edges on a directed graph.

    A stationary distribution is computed for a random walk on ``Graph``
    (plain row-normalised walk, or a PageRank-style walk when ``pr`` is true).
    Nodes are then ranked by a marginal gain built from that distribution and
    their out-degree. The best node is popped from a max-heap, paired with its
    first remaining candidate endpoint, and pushed back with an updated gain
    while it still has candidates.

    All per-node arrays (adjacency rows, stationary distribution, out-degrees
    and ``personalization``) use one ordering: the iteration order of
    ``Graph.nodes``.

    Parameters
    ----------
    Graph : networkx directed graph
        Graph whose nodes are ranked.
    R_additions : dict
        Maps each node of ``Graph`` to the list of endpoints it may be linked
        to. It is consumed in place: used endpoints are deleted and exhausted
        nodes are removed from the dict.
    R_degs : dict
        Maps each node to the degree used in the gain formula.
    b : int
        Maximum number of edges to select.
    powerMethod : bool
        If true, approximate the stationary distribution with 30 power-method
        steps from the uniform vector; otherwise use an eigen-decomposition.
    pr : bool
        Use the PageRank-style transition matrix from ``personalize_adj``.
    damping : float
        Damping factor for the PageRank-style matrix.
    personalization : array-like or None
        Teleport vector indexed in ``Graph.nodes`` order; uniform when None.

    Returns
    -------
    tuple
        ``(new_edges, pi_R)``: the selected ``(node, endpoint)`` pairs in
        selection order, and the stationary distribution in ``Graph.nodes``
        order.
    """
    new_edges = []

    # Single node ordering shared by every array below.
    node_order = list(Graph.nodes)
    R_nodes_map = {node: idx for idx, node in enumerate(node_order)}

    # Transition matrix of the walk on the directed graph.
    A_R = nx.adjacency_matrix(Graph, nodelist=node_order).A
    if not pr:
        normalizedA = normalize_rows(A_R)
    else:
        if personalization is None:
            NN = A_R.shape[0]
            personalization = np.ones(NN)/NN
        normalizedA = personalize_adj(A_R, personalization, damping)
        print("Computed PR distrib!")

    if not powerMethod:
        # Transpose so that the stationary row vector becomes a right
        # eigenvector, which is what np.linalg.eig computes.
        evals, evecs = np.linalg.eig(normalizedA.T)
        # Boolean indexing keeps the second axis; take the first eigenvector
        # whose eigenvalue is (numerically) 1.
        evec1 = evecs[:, np.isclose(evals, 1)][:, 0]
        # eig works in complex arithmetic; normalise, then keep the real part.
        pi_R = (evec1 / evec1.sum()).real
    else:
        # Start from the uniform distribution and apply the power method for 30 steps.
        n_nodes = len(node_order)
        X0 = [1/n_nodes for _ in range(n_nodes)]
        pi_R = PM(X0, normalizedA, 30)

    if len(pi_R) != A_R.shape[1]:  # pi_R has size n and A_R is n x n
        print("ERROR, the two shapes of pi_R, and A_R are different!", len(pi_R), A_R.shape[1])

    # x_vector[i] counts the edges already granted to the node at index i.
    x_vector = np.zeros(len(pi_R))
    out_degrees = dict(Graph.out_degree())
    deg_R = [out_degrees[node] for node in node_order]
    M_diag = [pi_R[idx]*deg_R[idx] for idx in range(len(pi_R))]

    # Initial gain of giving one more edge to each node that has candidates.
    # NOTE(review): a node whose R_degs entry is 0 yields a 0/0 (nan) gain
    # here -- confirm whether such nodes can reach this function.
    PQ = MaxHeap()
    for node in node_order:
        if len(R_additions[node]) > 0:
            M_val = M_diag[R_nodes_map[node]]
            n_deg = R_degs[node]
            mygain = (M_val/(1.*n_deg)) - (M_val/(n_deg+1.))
            PQ.push((node, mygain))

    # Pick the edges one at a time, always from the node with the best gain.
    for _ in range(b):
        if PQ.is_empty():
            break
        node, _ = PQ.pop()
        node_id = R_nodes_map[node]
        M_val = M_diag[node_id]
        n_deg = R_degs[node]
        candidates = R_additions[node]
        new_edges.append((node, candidates[0]))
        x_vector[node_id] = x_vector[node_id] + 1
        del candidates[0]
        if len(candidates) == 0:
            # The node is now linked to every available endpoint.
            del R_additions[node]
            continue
        # Re-insert the node with the gain of its next additional edge.
        mygain = (M_val/(1.*n_deg+x_vector[node_id])) - (M_val/(n_deg+1.+x_vector[node_id]))
        PQ.push((node, mygain))
    return new_edges, pi_R
    
