import torch
import time
from math import ceil
import torch.nn
import torch
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_geometric.utils import softmax, add_self_loops, remove_self_loops, segregate_self_loops, remove_isolated_nodes, contains_isolated_nodes, add_remaining_self_loops
from utils import get_diracs
import scipy
import scipy.io
from matplotlib.lines import Line2D
import GPUtil
import pickle
from torch_geometric.data import DataListLoader, DataLoader
from random import shuffle
from torch_geometric.datasets import TUDataset
import visdom 
from visdom import Visdom 
import numpy as np
import matplotlib.pyplot as plt
import gurobipy as gp
from gurobipy import GRB
from gnn_model import mis_MPNN, ErdosLoss_mis
from torch_geometric.nn.norm.graph_size_norm import GraphSizeNorm
import argparse
import os
from pathlib import Path
import yaml
from dataset.test_20 import REGULAR_test
import networkx as nx
from torch_geometric.utils import to_networkx

def decode_mis_sorted_faster(nx_graph, probs):
    """Greedily round per-node probabilities to a 0/1 node selection.

    Nodes are visited in descending order of probability. When a node is
    visited, the current values of its neighbours are summed (neighbours that
    were already visited hold 0 or 1, unvisited ones still hold their
    probability). The node is set to 0 if ``penalty * 1.01 - 1 > 0`` (i.e. the
    neighbour mass exceeds ``1 / 1.01``), otherwise to 1.

    The caller's ``probs`` tensor is not modified.

    Args:
        nx_graph: graph exposing ``neighbors(node)``; node ids are assumed to
            be ``0..num_nodes-1`` and to index ``probs`` (e.g. the output of
            ``torch_geometric.utils.to_networkx``).
        probs: tensor with one probability per node, e.g. shape
            ``(num_nodes, 1)`` or ``(num_nodes,)``.

    Returns:
        0-d tensor on ``probs``' device holding the number of selected nodes.
    """
    device = probs.device
    # reshape(-1) rather than squeeze(): a single-node graph must stay 1-d.
    flat = probs.detach().reshape(-1)
    # Ascending argsort on the original device, reversed, so the visiting order
    # (including how ties are broken) is the same as before.
    order = torch.argsort(flat).tolist()[::-1]
    # Private CPU copy: leaves the caller's tensor (and any autograd graph)
    # untouched and avoids a device sync on every element access below.
    values = flat.cpu().clone()
    for node in order:
        penalty = 0
        for neighbor in nx_graph.neighbors(node):
            penalty = penalty + values[neighbor]
        # A neighbour already set to 1 alone gives penalty >= 1 and rejects the node.
        values[node] = 0 if penalty * 1.01 - 1 > 0 else 1
    return values.sum().to(device)


def decode_mis_sorted(data, probs):
    """Round per-node probabilities to 0/1 by testing each node at 0 and at 1.

    Nodes are visited in descending order of ``probs``. For each node, two
    candidates are built that differ only in that node's value (0 or 1). The
    candidate with the smaller ``penalty - weight`` is kept; ties keep the 1.
    Here ``weight`` is the sum of all node values and ``penalty`` is the sum of
    ``value[row] * value[col]`` over the non-self-loop entries of
    ``data.edge_index``. If both directions of each edge are stored, every edge
    is counted twice, which does not change the comparison.

    A message is printed if the final assignment still has a positive penalty.

    Args:
        data: graph object exposing ``edge_index`` (2 x E index tensor);
            presumably a torch_geometric ``Data`` — confirm against callers.
        probs: tensor with one probability per node, e.g. shape
            ``(num_nodes, 1)`` or ``(num_nodes,)``. It is not modified.

    Returns:
        0-d tensor holding the sum of the rounded values.
    """
    no_loop_index, _ = remove_self_loops(data.edge_index)
    row, col = no_loop_index
    # detach: the rounding is discrete, so there is no point tracking autograd
    # through 2*N full-size clones. reshape(-1) rather than squeeze() keeps a
    # single-node graph 1-d.
    values = probs.detach().reshape(-1)
    # Iterate exactly over the nodes present in ``probs``, highest first.
    order = torch.argsort(values).tolist()[::-1]
    for node in order:
        tmp_0 = values.clone()
        tmp_1 = values.clone()
        tmp_0[node] = 0
        tmp_1[node] = 1
        weight_0 = tmp_0.sum()
        weight_1 = tmp_1.sum()
        penalty_0 = (tmp_0[row] * tmp_0[col]).sum()
        penalty_1 = (tmp_1[row] * tmp_1[col]).sum()
        # Both candidates are fresh clones, so no further copy is needed.
        values = tmp_0 if penalty_0 - weight_0 < penalty_1 - weight_1 else tmp_1
    penalty = (values[row] * values[col]).sum()
    if penalty > 0:
        print("penalty infer larger than 0: "+str(penalty))
    return values.sum()


def _print_summary(model_output, greedy_output, gt_output, time_list):
    """Print mean/std approximation ratios per graph group and mean run times.

    The first 40 test graphs are treated as two groups of 20, labelled "103"
    (indices 0-19) and "104" (indices 20-39). Ratios are model (or greedy)
    set size divided by the ground-truth ``max_set`` value.
    """
    gt = np.asarray(gt_output, dtype=float)
    if gt.shape[0] < 40:
        raise ValueError(
            f"expected at least 40 test graphs (two groups of 20), got {gt.shape[0]}")
    groups = (('103', slice(0, 20)), ('104', slice(20, 40)))
    for label, output in (('model', model_output), ('greedy', greedy_output)):
        for size, idx in groups:
            ratios = output[idx] / gt[idx]
            print(f"Mean ratio {size} {label}: {ratios.mean()} +/-  {ratios.std()}")

    print('avg_time:')
    # The first graph of each group is excluded from the timing average,
    # presumably as a warm-up run; TODO confirm.
    print(np.mean(time_list[1:20]))
    print(np.mean(time_list[21:40]))


def main():
    """Evaluate a trained ``mis_MPNN`` checkpoint on the REGULAR test set.

    For each test graph: run the network, round its output with
    ``decode_mis_sorted_faster``, and print the ground-truth ``max_set``, the
    decoded set size, the node count and the elapsed time. Then print
    approximation ratios for the model and for the dataset's ``greedy_num``
    baseline, plus mean timings.
    """
    parser = argparse.ArgumentParser(
        description='evaluate a trained model on the maximum independent set test set')
    parser.add_argument('--save_path', dest='save_path', default='train_files/max_clique/new_train/')
    parser.add_argument('--gpu', dest='gpu', default='0')
    parser.add_argument('--model_path', dest='model_path', required=True,
                        help='path to a saved mis_MPNN state_dict (loaded with torch.load)')
    args = parser.parse_args()

    # prepare the dataset
    cfg = Path('./dataset/configs/config.yaml')
    with cfg.open('r') as cfg_file:  # close the config file once parsed
        cfg_dict = yaml.safe_load(cfg_file)
    testdata = REGULAR_test(cfg_dict['test_20'])
    test_loader = DataLoader(testdata, batch_size=1, shuffle=False)
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # set up random seeds
    torch.manual_seed(66)
    np.random.seed(2)

    # hyper-parameters (presumably must match the checkpoint being loaded)
    numlayers = 6
    penalty_coeff = 4
    hidden_1 = 64
    hidden_2 = 1
    num_trials = 1  # decodes per graph; the largest decoded set is kept

    net = mis_MPNN(testdata, numlayers, hidden_1, hidden_2, 1)
    net.reset_parameters()
    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.to(device)
    # NOTE(review): net.eval() was commented out in the original script, so the
    # model runs in training mode; confirm this is intended for mis_MPNN.

    model_output = np.zeros(len(testdata))
    greedy_output = np.zeros(len(testdata))
    gt_output = []
    time_list = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for model_index, data in enumerate(test_loader):
            nx_graph = to_networkx(data)  # built from the CPU copy of the batch
            data = data.to(device)
            greedy_output[model_index] = data.greedy_num.item()
            time_per_data = 0
            for _ in range(num_trials):
                start_t = time.time()
                probs = net(data.x, data.edge_index, data.batch, None, penalty_coeff)
                # float(): store a plain number, not a (device) tensor, in numpy.
                num_vertex = float(decode_mis_sorted_faster(nx_graph, probs))
                model_output[model_index] = max(model_output[model_index], num_vertex)
                end_t = time.time()
                time_per_data += end_t - start_t
            time_list.append(time_per_data)
            vertex_gt = data.max_set.item()
            gt_output.append(vertex_gt)
            print('model_index:' + str(model_index) + " gt:" + str(vertex_gt)
                  + ' model:' + str(model_output[model_index])
                  + " all_node:" + str(data.x.shape[0]) + " time:" + str(time_per_data))

    _print_summary(model_output, greedy_output, gt_output, time_list)


if __name__ == '__main__':
    main()