import torch
import numpy as np
import math

class Ring:
    """Uniform mixing matrix for a cycle (ring) of ``n_nodes`` nodes.

    Node ``i`` puts weight 1/3 on itself and on its two cyclic neighbours
    ``(i - 1) % n_nodes`` and ``(i + 1) % n_nodes``. The matrix is stored in
    ``self.w`` as an ``(n_nodes, n_nodes)`` torch tensor.

    NOTE(review): for ``n_nodes < 3`` the three offsets hit overlapping
    columns, so rows sum to less than 1.
    """

    def __init__(self, n_nodes):
        weight = 1 / 3
        self.w = torch.zeros((n_nodes, n_nodes))
        for node in range(n_nodes):
            # self, clockwise neighbour, counter-clockwise neighbour
            for offset in (0, 1, -1):
                self.w[node, (node + offset) % n_nodes] = weight


class kRing:
    """Uniform mixing matrix for a ring where each node reaches ``k`` hops.

    Node ``i`` puts weight ``1 / (2k + 1)`` on itself and on every node
    ``(i +- h) % n_nodes`` for ``h = 1..k``. The matrix is stored in
    ``self.w`` as an ``(n_nodes, n_nodes)`` torch tensor.

    NOTE(review): when ``n_nodes < 2k + 1`` some hops land on the same
    column, so rows sum to less than 1.
    """

    def __init__(self, n_nodes, k):
        share = 1 / (2 * k + 1)
        self.w = torch.zeros((n_nodes, n_nodes))
        for row in range(n_nodes):
            self.w[row, row] = share
            for hop in range(1, k + 1):
                self.w[row, (row + hop) % n_nodes] = share
                self.w[row, (row - hop) % n_nodes] = share
            
class Torus():
    """Mixing matrix for a ``p x q`` torus (2-D grid with wrap-around).

    Nodes are numbered row-major: row ``r`` holds nodes
    ``r*q, ..., r*q + q - 1``. Each grid row and each grid column is treated
    as a cycle, and every node is linked to its two cyclic neighbours in both.
    All nonzero entries of ``self.w`` (including the diagonal) are 1/5. With
    ``p, q >= 3`` each node has four distinct neighbours, so every row sums
    to 1 and ``self.w`` is symmetric.

    Raises:
        ValueError: if ``p <= 2`` or ``q <= 2``. A side of length <= 2 makes
            the "two" cyclic neighbours along it coincide, so the resulting
            rows would not sum to 1.
    """

    def __init__(self, p, q):
        self.n_nodes = p * q
        self.p = p
        self.q = q

        if p <= 2 or q <= 2:
            raise ValueError(f"Torus requires p >= 3 and q >= 3, got p={p}, q={q}")

        self.w = torch.zeros((self.n_nodes, self.n_nodes))

        # Grid rows first, then grid columns; each one is a cycle of nodes.
        for cycle in self.split_nodes() + self.split_nodes2():
            self._add_cycle(cycle)

    def _add_cycle(self, cycle):
        """Link consecutive nodes of ``cycle`` (wrapping around) with weight 1/5."""
        for i_idx, j_idx in zip(cycle, cycle[1:] + cycle[:1]):
            self.w[i_idx, j_idx] = 1/5
            self.w[j_idx, i_idx] = 1/5
            self.w[i_idx, i_idx] = 1/5
            self.w[j_idx, j_idx] = 1/5

    def split_nodes(self):
        """Return the ``p`` grid rows as lists of node indices (row-major)."""
        return [list(range(r * self.q, (r + 1) * self.q)) for r in range(self.p)]

    def split_nodes2(self):
        """Return the ``q`` grid columns as lists of node indices.

        Column ``c`` contains ``c, c + q, c + 2q, ...``, one node per row.
        """
        return [list(range(c, self.n_nodes, self.q)) for c in range(self.q)]

    
class Line():
    """Mixing matrix for a path (line) graph of ``n_nodes`` nodes.

    Interior nodes put weight 1/3 on themselves and on each of their two
    neighbours. The two end nodes have a single neighbour, so they keep 2/3
    on themselves and give 1/3 to that neighbour. Every row of ``self.w``
    therefore sums to 1 and the matrix is symmetric. A lone node
    (``n_nodes == 1``) gets weight 1.
    """

    def __init__(self, n_nodes):
        self.w = torch.zeros((n_nodes, n_nodes))

        if n_nodes == 1:
            # No neighbours: writing w[0, 1] below would be out of bounds.
            self.w[0, 0] = 1.0
            return

        for i in range(n_nodes):
            if i == 0:
                self.w[0, 0] = 2/3
                self.w[0, 1] = 1/3
            elif i == n_nodes - 1:
                self.w[i, i] = 2/3
                # The last node must link back to its only neighbour too,
                # otherwise its row sums to 2/3 and w is not symmetric.
                self.w[i, i - 1] = 1/3
            else:
                self.w[i, i] = 1/3
                self.w[i, i + 1] = 1/3
                self.w[i, i - 1] = 1/3
    
    
def max_eigenvalue_and_vector(matrix):
    """Return all eigenvalues of ``matrix`` as computed by ``np.linalg.eig``.

    NOTE(review): despite its name, this returns the full eigenvalue array
    (in ``np.linalg.eig`` order), not the dominant eigenvalue/eigenvector.
    The return value is kept unchanged for callers. The previously computed
    but never-returned dominant eigenpair has been dropped.

    Args:
        matrix: square matrix accepted by ``np.linalg.eig``.

    Returns:
        1-D NumPy array of eigenvalues (complex if ``eig`` returns complex).
    """
    eigenvalues, _ = np.linalg.eig(matrix)
    return eigenvalues


def generate_graph(matrix, b):
    """Rebuild ``matrix`` with all but its dominant eigenvalues replaced by ``b``.

    The matrix is eigendecomposed with ``np.linalg.eig``. These eigenvalues
    are kept unchanged:

    * every eigenvalue whose square equals the largest squared eigenvalue;
    * the first eigenvalue (in ``eig`` order) whose square equals the
      second-largest squared eigenvalue.

    Every other eigenvalue is replaced by ``b``. The result is reassembled as
    ``V @ diag(eigenvalues) @ V^{-1}``.

    Args:
        matrix: square matrix accepted by ``np.linalg.eig``.
        b: value substituted for the non-dominant eigenvalues.

    Returns:
        The reconstructed matrix as a NumPy array.

    Raises:
        numpy.linalg.LinAlgError: if the eigenvector matrix cannot be inverted.
    """
    eigenvalues, eigenvectors = np.linalg.eig(matrix)

    # Sort the squared eigenvalues once instead of twice per eigenvalue.
    # The lookups are guarded so a 1x1 (or empty) input behaves as before;
    # in that case the second-largest entry is never consulted.
    squares = sorted([v**2 for v in eigenvalues])
    largest_sq = squares[-1] if squares else None
    second_sq = squares[-2] if len(squares) > 1 else None

    new_eigenvalues = []
    keep_second = True  # only the first match of the second-largest square is kept
    for v in eigenvalues:
        if v**2 == largest_sq:
            new_eigenvalues.append(v)
        elif v**2 == second_sq and keep_second:
            new_eigenvalues.append(v)
            keep_second = False
        else:
            new_eigenvalues.append(b)
    eigenvalues = np.array(new_eigenvalues)

    # Reassemble A = V @ D @ V^{-1} with the modified spectrum.
    D = np.diag(eigenvalues)
    V_inv = np.linalg.inv(eigenvectors)
    return eigenvectors @ D @ V_inv



def calc_average_spectral_gap(matrix):
    """Average ``lam**2 / (1 - lam**2)`` over the eigenvalues of ``matrix``.

    Eigenvalues come from ``np.linalg.eig`` and are sorted ascending. The
    last (largest) one is left out of the sum, but the sum is still divided
    by the total number of eigenvalues.

    NOTE(review): any remaining eigenvalue equal to +1 or -1 makes a term's
    denominator zero.
    """
    eigvals_ascending = sorted(np.linalg.eig(matrix)[0])
    n_eigs = len(eigvals_ascending)
    total = sum(lam**2 / (1 - lam**2) for lam in eigvals_ascending[:-1])
    return total / n_eigs

def calc_spectral_gap(matrix):
    """Largest ``lam**2 / (1 - lam**2)`` among all but the top eigenvalue.

    Eigenvalues come from ``np.linalg.eig`` and are sorted ascending. The
    last (largest) one is excluded, and the maximum ratio over the rest is
    returned.

    NOTE(review): with a single eigenvalue the candidate list is empty and
    ``max`` raises ``ValueError``.
    """
    ordered = sorted(np.linalg.eig(matrix)[0])
    ratios = [lam**2 / (1 - lam**2) for lam in ordered[:-1]]
    return max(ratios)

def generate_graph3(matrix, average_spectral_gap):
    """Rebuild ``matrix`` so its spectrum targets a given average spectral gap.

    The dominant eigenvalues are kept, selected the same way as in the loop
    below: all eigenvalues with the largest square, plus the first one with
    the second-largest square. Every other eigenvalue is replaced by a common
    value ``lam``. With ``n`` eigenvalues, ``s = calc_spectral_gap(matrix)``
    and ``a = average_spectral_gap``, ``lam**2`` is the solution of

        a == (s + (n - 2) * lam**2 / (1 - lam**2)) / n

    i.e. ``lam**2 = (a*n - s) / ((1 + a)*n - 2 - s)``.
    NOTE(review): this assumes the kept second eigenvalue contributes exactly
    ``s`` and the largest contributes nothing to the average. Confirm this
    holds for the matrices passed in.

    Args:
        matrix: square matrix accepted by ``np.linalg.eig``.
        average_spectral_gap: target value for the average spectral gap.

    Returns:
        The reconstructed matrix ``V @ diag(new_eigenvalues) @ V^{-1}``.

    Raises:
        ValueError: if the target is infeasible, i.e. ``lam**2`` is not in
            ``[0, 1)`` (NaN included) or ``average_spectral_gap`` exceeds
            the matrix's spectral gap.
    """
    eigenvalues, eigenvectors = np.linalg.eig(matrix)
    n_nodes = len(eigenvalues)

    spectral_gap = calc_spectral_gap(matrix)

    eig_sq = (average_spectral_gap * n_nodes - spectral_gap) / (
        (1 + average_spectral_gap) * n_nodes - 2 - spectral_gap
    )

    # An infeasible target must stop here: a negative value would crash in
    # math.sqrt, and a value >= 1 would yield an invalid mixing matrix.
    # Written as `not (0 <= x < 1)` so that NaN is rejected as well.
    if not (0 <= eig_sq < 1) or average_spectral_gap > spectral_gap:
        raise ValueError(
            f"cannot reach average spectral gap {average_spectral_gap}: "
            f"spectral gap is {spectral_gap}, required eigenvalue^2 is {eig_sq}"
        )
    eig = math.sqrt(eig_sq)

    # Sort the squared eigenvalues once rather than twice per eigenvalue.
    squares = sorted([v**2 for v in eigenvalues])
    largest_sq = squares[-1] if squares else None
    second_sq = squares[-2] if len(squares) > 1 else None

    new_eigenvalues = []
    keep_second = True  # only the first match of the second-largest square is kept
    for v in eigenvalues:
        if v**2 == largest_sq:
            new_eigenvalues.append(v)
        elif v**2 == second_sq and keep_second:
            new_eigenvalues.append(v)
            keep_second = False
        else:
            new_eigenvalues.append(eig)
    eigenvalues = np.array(new_eigenvalues)

    # Reassemble A = V @ D @ V^{-1} with the modified spectrum.
    D = np.diag(eigenvalues)
    V_inv = np.linalg.inv(eigenvectors)
    return eigenvectors @ D @ V_inv


