import torch
import time
from math import ceil
import torch.nn
import torch
from utils import get_diracs
import pickle
from torch_geometric.data import DataLoader
from random import shuffle
from torch_geometric.datasets import TUDataset
import numpy as np
from  utils import solve_gurobi_maxclique
from gnn_model import clique_MPNN, ErdosLoss_clique
from utils import decode_clique_final_speed
import argparse
import os
from pathlib import Path
import yaml
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
from dataset.rb500_test import RB500_test


def finetune(net, data_prime, optimizer, epoch, device, receptive_field, penalty_coeff, totalretdict, batch_size, criterion):
    """Run one forward pass (and, after a warm-up, one optimizer step) on a batch.

    Args:
        net: Model called as ``net(x, edge_index, batch, None, penalty_coeff)``.
        data_prime: Batch exposing ``to(device)``, ``x``, ``edge_index`` and
            ``batch``.
        optimizer: Optimizer over ``net``'s parameters; its gradients are
            zeroed on every call.
        epoch: Current epoch index. The backward pass and optimizer step only
            run when ``epoch > 2``.
        device: Device the batch is moved to and that is handed to
            ``criterion``.
        receptive_field: Unused; kept for interface compatibility.
        penalty_coeff: Forwarded to both ``net`` and ``criterion``.
        totalretdict: Running totals, updated in place. Maps a statistic name
            to ``[accumulated_value, tag]``.
        batch_size: Unused; kept for interface compatibility.
        criterion: Callable returning a dict of ``name -> [tensor, tag]``
            entries; must contain a ``"loss"`` entry when ``epoch > 2``.

    Returns:
        The output of ``net`` for this batch.
    """
    optimizer.zero_grad()
    graph = data_prime.to(device)
    probs = net(graph.x, graph.edge_index, graph.batch, None, penalty_coeff)
    step_stats = criterion(probs, graph.edge_index, graph.batch, penalty_coeff, device)

    # Fold every statistic tagged with "sequence" into the running totals.
    for name, entry in step_stats.items():
        if "sequence" not in entry[1]:
            continue
        scalar = entry[0].item()
        if name in totalretdict:
            totalretdict[name][0] += scalar
        else:
            totalretdict[name] = [scalar, entry[1]]

    # The first epochs only collect statistics; weights are updated afterwards.
    if epoch > 2:
        step_stats["loss"][0].backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()

    if epoch > -1:
        # Dividing by 1 leaves the totals numerically unchanged.
        for entry in totalretdict.values():
            if "sequence" in entry[1]:
                entry[0] = entry[0] / 1

    return probs
    


def _load_reference_cliques(testdata, cache_dir):
    """Return the test graphs annotated with a ``clique_number`` attribute.

    The annotated list is cached as ``testlist.pkl`` inside ``cache_dir``. If
    the cache exists it is loaded; otherwise every graph is solved with
    ``solve_gurobi_maxclique`` and the result is written to the cache.
    """
    cache_file = Path(cache_dir) / 'testlist.pkl'
    if cache_file.exists():
        # NOTE(security): unpickling can execute arbitrary code. Only load a
        # cache file that this script produced itself.
        with cache_file.open('rb') as stored_dataset:
            return pickle.load(stored_dataset)

    test_data_clique = []
    for data in testdata:
        my_graph = to_networkx(Data(x=data.x, edge_index=data.edge_index)).to_undirected()
        # Second argument is presumably the solver time limit -- TODO confirm.
        cliqno, _ = solve_gurobi_maxclique(my_graph, 500)
        data.clique_number = cliqno
        test_data_clique.append(data)

    # Make sure the cache directory exists before writing into it.
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with cache_file.open('wb') as stored_dataset:
        pickle.dump(test_data_clique, stored_dataset)
    return test_data_clique


def main():
    """Evaluate a pretrained max-clique model on the RB500 test set.

    For each test graph the model is run ``--seed_num`` times on inputs
    produced by ``get_diracs``; the largest decoded clique size is kept. The
    sizes are then compared against reference clique numbers (solved with
    Gurobi and cached on disk), and the mean/std ratio plus the average
    per-graph time are printed.
    """
    parser = argparse.ArgumentParser(description='this is the arg parser for max clique problem')
    parser.add_argument('--save_path', dest='save_path', default='train_files/max_clique/new_train/')
    parser.add_argument('--use_gpu', dest='use_gpu', default='0')
    parser.add_argument('--model_path', dest='model_path', default='./train_files/maml/maml_outer1e3_inner5e5_step5_mb128/best_model.pth')
    parser.add_argument('--seed_num', dest='seed_num', type=int, default=8)
    args = parser.parse_args()

    # prepare the dataset
    cfg = Path('./dataset/configs/config.yaml')
    with cfg.open('r') as cfg_file:
        cfg_dict = yaml.safe_load(cfg_file)
    testdata = RB500_test(cfg_dict['test'])
    batch_size = 1
    test_loader = DataLoader(testdata, batch_size, shuffle=False)
    device = torch.device('cuda:' + str(args.use_gpu) if torch.cuda.is_available() else 'cpu')

    # set up random seeds
    torch.manual_seed(66)
    np.random.seed(2)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # hyper-parameters
    numlayers = 4
    receptive_field = numlayers + 1
    penalty_coeff = 4.
    hidden_1 = 64
    hidden_2 = 1

    net = clique_MPNN(testdata, numlayers, hidden_1, hidden_2, 1)
    state_dict = torch.load(args.model_path, map_location=torch.device('cpu'))

    # NOTE(review): the model is evaluated in train mode; this mirrors the
    # original script -- confirm it is intentional.
    net.train()
    model_output = np.zeros(len(testdata))
    time_list = []
    num_seed = args.seed_num
    for model_index, data in enumerate(test_loader):
        time_per_data = 0
        data = data.to(device)
        for _ in range(num_seed):
            # get a different data input for every seed
            data_prime = get_diracs(args.use_gpu, data, 1, sparse=True, effective_volume_range=0.15, receptive_field=receptive_field)
            data_prime = data_prime.to(device)
            # Start every run from the same pretrained weights.
            net.reset_parameters()
            net.load_state_dict(state_dict)
            net.to(device)
            criterion = ErdosLoss_clique()
            start_t = time.time()
            probs = net(data_prime.x, data_prime.edge_index, data_prime.batch, None, penalty_coeff)
            retdict = criterion(probs, data_prime.edge_index, data_prime.batch, penalty_coeff, device)
            _, _, set_cardinality = decode_clique_final_speed(data_prime, (retdict["output"][0]), weight_factor=0., draw=False)
            # Store a plain Python scalar so the NumPy array never receives a
            # (possibly non-CPU) tensor.
            cardinality = set_cardinality.item()
            if cardinality > model_output[model_index]:
                model_output[model_index] = cardinality
            end_t = time.time()
            time_per_data = time_per_data + end_t - start_t

        time_list.append(time_per_data)

        print('model_index:'+str(model_index)+' model:'+str(model_output[model_index])+"time:"+str(time_per_data))

    test_data_clique = _load_reference_cliques(testdata, './dataset/testset/')

    ratios = np.array([
        found / test_data_clique[i].clique_number
        for i, found in enumerate(model_output)
    ])
    print(f"Mean ratio: {ratios.mean()} +/-  {ratios.std()}")

    print('avg_time:')
    print(np.mean(time_list))
    

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()