import networkx as ntx
import numpy as np
from data_handling import get_zinc_data
import numpy as np
from matplotlib import pyplot as plt
import torch
import math
from torch_geometric.utils import to_networkx
import torch.optim as optim
from models import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool, global_add_pool
from scipy.linalg import pinv
from scipy.sparse import csr_matrix


def get_resistance_indices(alpha, data):
    """Pick the node pair at the ``alpha`` quantile of effective resistance.

    ``data`` is converted to an undirected networkx graph. Effective
    resistances are computed from the pseudoinverse ``Q`` of ``L + J/n``
    (``L`` the graph Laplacian, ``J`` the all-ones matrix, ``n`` the number
    of nodes) as ``R_ij = Q_ii + Q_jj - 2 Q_ij``. All ordered pairs of
    distinct nodes are then ranked by resistance, and the one at position
    ``alpha`` in that ranking is returned.

    Args:
        alpha: Quantile in ``[0, 1]``. 0 selects the lowest-resistance pair
            and 1 the highest.
        data: Graph object accepted by ``torch_geometric.utils.to_networkx``.

    Returns:
        ``(i, j, commute_time)``. ``i`` and ``j`` are node indices (Python
        ints), and ``commute_time = 2 * num_edges * R_ij``, with the edge
        count taken from the undirected graph.

    Raises:
        ValueError: If ``alpha`` lies outside ``[0, 1]``, or the graph has
            fewer than two nodes (so there is no pair to choose).
    """
    G = to_networkx(data, to_undirected=True)
    v = G.number_of_nodes()
    if v < 2:
        raise ValueError(f"need at least two nodes to pick a pair, got {v}")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha!r}")

    # ``toarray`` works whether networkx hands back a SciPy sparse matrix or
    # a sparse array. Calling ``csr_matrix.todense`` unbound on the result
    # does not.
    lap = ntx.laplacian_matrix(G).toarray()
    # The J/n shift makes the matrix invertible for a connected graph. Its
    # contribution cancels in Q_ii + Q_jj - 2 Q_ij, so resistances are
    # unaffected.
    q = pinv(lap + 1.0 / v)

    diag = np.diag(q)
    resist = diag[:, None] + diag[None, :] - 2.0 * q

    # Rank all v*v entries in row-major order: flat index k <-> (k // v, k % v).
    order = np.argsort(resist, axis=None)
    # The v diagonal entries are exactly 0 and sort first (assuming a
    # connected graph, so that off-diagonal resistances are positive). Skip
    # them and index into the v*v - v off-diagonal entries by quantile.
    chosen = order[v + int((v * v - 1 - v) * alpha)]
    i, j = divmod(int(chosen), v)

    commute_time = resist[i, j] * 2 * G.number_of_edges()

    return i, j, commute_time

def add_rand_input(i, j, data, mix_type):
    """Replace the node features of ``data`` with two random scalar inputs.

    Every node gets a single float32 feature equal to 0, except nodes ``i``
    and ``j``. Those receive ``randA`` and ``randB``, drawn with
    ``torch.rand`` (``mix_type`` 0 or 1) or ``torch.rand * 1.5``
    (``mix_type`` 2).

    Args:
        i: Node index that receives ``randA``.
        j: Node index that receives ``randB``.
        data: Graph object with a node-feature tensor ``data.x``. Only its
            row count is read, and ``data.x`` is overwritten in place.
        mix_type: 0, 1 or 2; selects the sampling range.

    Returns:
        ``(randA, randB, data)``: the two sampled 1-element tensors and the
        same ``data`` object that was passed in.

    Raises:
        ValueError: If ``mix_type`` is not 0, 1 or 2.
    """
    if mix_type in (0, 1):
        randA, randB = torch.rand(1), torch.rand(1)
    elif mix_type == 2:
        randA, randB = torch.rand(1) * 1.5, torch.rand(1) * 1.5
    else:
        # Without this, randA/randB would be unbound below.
        raise ValueError(f"unknown mix_type {mix_type!r}; expected 0, 1 or 2")

    new_x = torch.zeros(data.x.size(0), 1, dtype=torch.float32)
    new_x[i] = randA
    new_x[j] = randB
    data.x = new_x

    return randA, randB, data

def mixing(randA, randB, data, mix_type):
    """Set the regression target ``data.y`` from the two random inputs.

    ``mix_type`` 0 uses ``tanh(randA + randB)``. ``mix_type`` 1 and 2 use
    ``exp(randA + randB)``.

    Args:
        randA: First random input tensor.
        randB: Second random input tensor.
        data: Graph object whose ``y`` attribute is overwritten in place.
        mix_type: 0, 1 or 2; selects the target function.

    Returns:
        The same ``data`` object, with ``data.y`` set.

    Raises:
        ValueError: If ``mix_type`` is not 0, 1 or 2. Raising prevents
            ``data.y`` from silently keeping its previous value.
    """
    if mix_type == 0:
        combine = torch.tanh
    elif mix_type in (1, 2):
        combine = torch.exp
    else:
        raise ValueError(f"unknown mix_type {mix_type!r}; expected 0, 1 or 2")
    data.y = combine(randA + randB)
    return data

def get_distance_between_two_nodes(i, j, data):
    """Return the shortest-path length between nodes ``i`` and ``j``.

    ``data`` is first converted to an undirected networkx graph, so edge
    direction is ignored. The length comes from
    ``networkx.shortest_path_length`` with no ``weight`` argument.
    """
    undirected = to_networkx(data, to_undirected=True)
    return ntx.shortest_path_length(undirected, source=i, target=j)

def get_eff_res_data(alpha=0.25, mix_type=0):
    """Build the synthetic effective-resistance task from the ZINC splits.

    The train, test and val splits are loaded via ``get_zinc_data``. For
    every graph in them, this function:

    - chooses a node pair with ``get_resistance_indices``;
    - replaces the node features with two random inputs at that pair
      (``add_rand_input``);
    - sets ``data.y`` from those inputs (``mixing``).

    The largest shortest-path distance between chosen pairs, over all three
    splits, is printed.

    Args:
        alpha: Resistance quantile forwarded to ``get_resistance_indices``.
        mix_type: Forwarded to ``add_rand_input`` and ``mixing``; selects
            the input range and the target function.

    Returns:
        ``(train, test, val)``: lists of the transformed graph objects.
    """
    max_distance = 0
    # Loaded up front in the same order as before: train, test, val.
    datasets = [get_zinc_data(name) for name in ('train', 'test', 'val')]

    transformed = []
    with torch.no_grad():
        for dataset in datasets:
            new_dataset = []
            for data in dataset:
                i, j, _ = get_resistance_indices(alpha, data)
                distance = get_distance_between_two_nodes(i, j, data)
                max_distance = max(max_distance, distance)
                randA, randB, data = add_rand_input(i, j, data, mix_type)
                # mixing mutates data in place and returns the same object.
                data = mixing(randA, randB, data, mix_type)
                new_dataset.append(data)
            transformed.append(new_dataset)

    print("Finished preparing dataset")
    print("Max distance: ", max_distance)

    train, test, val = transformed
    return train, test, val