import torch
import time
from math import ceil
import torch.nn
import torch
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_geometric.utils import softmax, add_self_loops, remove_self_loops, segregate_self_loops, remove_isolated_nodes, contains_isolated_nodes, add_remaining_self_loops
from torch_geometric.data import DataListLoader, DataLoader
from random import shuffle 
import numpy as np
from gnn_model import mis_MPNN, ErdosLoss_mis
import argparse
import os
from pathlib import Path
import yaml
from dataset.test_20 import REGULAR_test
import networkx as nx
from torch_geometric.utils import to_networkx

def decode_mis_sorted_faster(nx_graph, probs):
    """Greedily round per-node probabilities to a 0/1 assignment.

    Nodes are visited from the highest to the lowest initial probability.
    For each node the current values of its neighbours (already-rounded
    nodes contribute 0 or 1, not-yet-visited nodes contribute their
    original probability) are summed; the node is set to 0 when that sum
    (with a 0.1% slack) exceeds 1 and to 1 otherwise.

    Args:
        nx_graph: graph object exposing ``neighbors(node_id)`` that yields
            integer node ids usable as indices into ``probs``.
            NOTE(review): for a directed graph ``neighbors`` typically means
            successors only -- confirm the edge list is symmetric.
        probs: tensor holding one value per node; any shape is accepted and
            flattened (e.g. ``(N,)`` or ``(N, 1)``).

    Returns:
        0-d tensor with the number of nodes rounded to 1.
    """
    # Work on a detached private copy:
    #  * the caller's tensor is no longer overwritten in place,
    #  * no autograd history is recorded for the rounding writes,
    #  * reshape(-1) keeps a 1-D vector even for a single node, where
    #    squeeze() would collapse it to 0-d and break indexing.
    rounded = probs.detach().reshape(-1).clone()

    # Reverse of the ascending argsort == visit order of the original loop.
    visit_order = torch.argsort(rounded).flip(0).tolist()

    for node in visit_order:
        penalty = sum(rounded[j] for j in nx_graph.neighbors(node))
        # The 1.001 factor makes a neighbour already fixed at 1 trigger the
        # exclusion even in the presence of floating-point rounding.
        if penalty * 1.001 - 1 > 0:
            rounded[node] = 0
        else:
            rounded[node] = 1
    return rounded.sum()

def decode_mis_sorted(data, probs):
    """Round node probabilities to 0/1 by greedy per-node objective comparison.

    Nodes are visited from the highest to the lowest initial probability.
    For each node two candidate assignments are built (node forced to 0 and
    node forced to 1) and the one with the smaller value of
    ``sum_over_edges(p[row] * p[col]) - sum(p)`` is kept. Self-loops are
    removed from ``data.edge_index`` before the edge term is computed.

    Args:
        data: object with an ``edge_index`` attribute in the format accepted
            by ``remove_self_loops`` (unpacked here into ``row`` / ``col``).
        probs: tensor holding one value per node; flattened to 1-D.

    Returns:
        0-d tensor with the sum of the rounded assignment. A message is
        printed if the final edge term is still positive.
    """
    no_loop_index, _ = remove_self_loops(data.edge_index)
    row, col = no_loop_index

    # reshape(-1) instead of squeeze(): a single-node input stays 1-D rather
    # than collapsing to a 0-d tensor that cannot be indexed.
    probs = probs.reshape(-1)

    # Reverse of the ascending argsort == visit order of the original loop.
    # The loop is driven by this order, so the iteration count and the index
    # range always agree (the count was previously taken from data.x).
    visit_order = torch.argsort(probs).flip(0).tolist()

    def objective(assignment):
        # Edge term minus total weight; lower is better.
        return (assignment[row] * assignment[col]).sum() - assignment.sum()

    for node in visit_order:
        excluded = probs.clone()
        excluded[node] = 0
        included = probs.clone()
        included[node] = 1
        # Both candidates are already fresh copies, so no further clone is
        # needed when one of them becomes the current assignment.
        if objective(excluded) < objective(included):
            probs = excluded
        else:
            probs = included

    penalty = (probs[row] * probs[col]).sum()
    if penalty > 0:
        print("penalty infer larger than 0: "+str(penalty))
    return probs.sum()


def main():
    """Evaluate a trained network on the ``test_5_105`` split and print statistics.

    Steps:
      1. Parse command-line arguments and load ``./configs/config.yaml``.
      2. Build the test dataset / loader (one graph per batch) and the
         network, then load its weights from ``--model_path``.
      3. For every graph run the network several times, round each output
         with ``decode_mis_sorted_faster`` and keep the best (largest) result.
      4. Print per-graph results, the mean/std ratio of the model and of the
         dataset-provided greedy baseline to the ground truth, and the
         average time per graph.
    """
    parser = argparse.ArgumentParser(description='this is the arg parser for max clique problem')
    parser.add_argument('--save_path', dest='save_path', default='train_files/max_clique/new_train/')
    parser.add_argument('--gpu', dest='gpu', default='0')
    # The checkpoint location used to be a hard-coded placeholder; the default
    # keeps the previous value so existing invocations behave the same.
    parser.add_argument('--model_path', dest='model_path', default='put the model here',
                        help='path to the saved state_dict of the trained network')
    args = parser.parse_args()

    # prepare the dataset (the context manager closes the config file handle)
    with Path('./configs/config.yaml').open('r') as cfg_file:
        cfg_dict = yaml.safe_load(cfg_file)
    testdata = REGULAR_test(cfg_dict['test_5_105'])
    batch_size = 1
    test_loader = DataLoader(testdata, batch_size, shuffle=False)
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # set up random seeds
    torch.manual_seed(123)
    np.random.seed(123)

    # hyper-parameters
    numlayers = 6
    penalty_coeff = 0.5
    hidden_1 = 64
    hidden_2 = 1
    num_trials = 8  # forward passes per graph; the best rounding is kept

    net = mis_MPNN(testdata, numlayers, hidden_1, hidden_2, 1)

    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))
    # NOTE(review): the network is left in training mode (net.eval() was
    # disabled in the original script). Confirm whether that is intentional,
    # e.g. to keep the repeated trials stochastic.
    net.reset_parameters()
    net.load_state_dict(state_dict)
    net.to(device)

    model_output = np.zeros(len(testdata))
    greedy_output = np.zeros(len(testdata))
    gt_output = []
    time_list = []
    for model_index, data in enumerate(test_loader):
        # Per-graph work that does not change between trials is done once.
        nx_graph = to_networkx(data)
        data = data.to(device)
        greedy_output[model_index] = data.greedy_num.item()

        time_per_data = 0
        for _ in range(num_trials):
            start_t = time.time()
            print('start!')
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                probs = net(data.x, data.edge_index, data.batch, None, penalty_coeff)
            print('finish infer with NN!' + str(time.time() - start_t))
            probs = probs.to(torch.device('cpu'))
            # Convert to a plain float before storing it in a NumPy array.
            num_vertex = float(decode_mis_sorted_faster(nx_graph, probs))
            print('finish rounding:' + str(time.time() - start_t))
            if num_vertex > model_output[model_index]:
                model_output[model_index] = num_vertex
            time_per_data = time_per_data + time.time() - start_t

        time_list.append(time_per_data)
        vertex_gt = data.max_set.item()
        gt_output.append(vertex_gt)
        print(f"model_index:{model_index} gt:{vertex_gt} model:{model_output[model_index]}"
              f" all_node:{data.x.shape[0]} time:{time_per_data}")

    # Summarise over however many graphs were actually evaluated instead of
    # assuming exactly 20.
    num_graphs = len(gt_output)
    if num_graphs == 0:
        print('no test graphs were evaluated')
        return

    gt = np.array(gt_output)
    ratio_105 = model_output[:num_graphs] / gt
    ratio_105_greedy = greedy_output[:num_graphs] / gt

    print(f"Mean ratio 105 model: {ratio_105.mean()} +/-  {ratio_105.std()}")
    print(f"Mean ratio 105 greedy: {ratio_105_greedy.mean()} +/-  {ratio_105_greedy.std()}")

    print('avg_time:')
    # The first graph is excluded from the average, as in the original
    # (presumably to discard warm-up cost -- confirm). With a single graph
    # there is nothing else to average, so fall back to that one timing.
    timed = time_list[1:] or time_list
    print(np.mean(timed))
    

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()