# generate application 1 dataset
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
import pickle
from pathlib import Path
import yaml
import re
import itertools
from torch_geometric.data import DataLoader
from utils import get_diracs
from tqdm import tqdm
from gurobipy import * 
import gurobipy as gp
from gurobipy import GRB
from torch_geometric.datasets import TUDataset
import pulp
import networkx as nx
from torch_geometric.utils import erdos_renyi_graph, to_networkx, from_networkx
import random
from random import choice
from torch_geometric.utils import degree, remove_self_loops

def greedy(data):
    """Return a maximal independent set found with the min-degree greedy rule.

    Each round picks the remaining node of smallest degree (degrees are
    counted over the source row ``edge_index[0]`` of the edges still
    present). It adds that node to the set and removes it together with all
    of its neighbours and every edge touching them. A node left with no
    incident edge is not adjacent to any chosen node, because those
    neighbours were removed, so it is added to the set as well. Nodes that
    never had an edge also end up in the set; for an edgeless graph that is
    every node.

    Self-loops are ignored. The neighbour/degree logic assumes every
    undirected edge is stored in both directions, as ``from_networkx``
    produces for an undirected graph.

    Args:
        data: object with ``edge_index`` (``2 x E`` integer tensor) and
            ``num_nodes`` attributes, e.g. a PyG ``Data``/``Batch``.

    Returns:
        Float tensor of shape ``(num_nodes,)``: 1 for nodes in the set,
        0 otherwise.
    """
    num_nodes = data.num_nodes
    # Same filtering as torch_geometric.utils.remove_self_loops.
    edge_index = data.edge_index
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]

    chosen_mask = torch.zeros(num_nodes)
    removed = torch.zeros(num_nodes, dtype=torch.bool)
    edges = edge_index
    while edges.shape[1] > 0:
        # Float degrees, matching torch_geometric.utils.degree.
        deg = torch.bincount(edges[0], minlength=num_nodes).float()
        # Removed nodes all have degree 0. After the first round, every
        # degree-0 node has also been removed. So skipping removed.sum()
        # entries of the ascending order lands on the lowest-degree remaining
        # node. In the first round this may be an initially isolated node,
        # which is always safe to choose.
        node = torch.argsort(deg)[int(removed.sum())]
        chosen_mask[node] = 1
        neighbours = torch.cat((edges[1][edges[0] == node],
                                edges[0][edges[1] == node]))
        removed[node] = True
        removed[neighbours] = True
        # One vectorised pass drops every edge touching a removed node.
        edges = edges[:, ~(removed[edges[0]] | removed[edges[1]])]
        # Nodes that just lost their last edge are not adjacent to any chosen
        # node (those neighbours were removed), so they can join the set.
        isolated = (torch.bincount(edges[0], minlength=num_nodes) == 0) & ~removed
        chosen_mask[isolated] = 1
        removed |= isolated
    # Nodes the loop never reached have no edges at all (e.g. an edgeless
    # graph, where the loop does not run); they belong in the set too.
    chosen_mask[~removed] = 1

    # Sanity check: no edge may join two chosen nodes.
    row, col = edge_index
    if (chosen_mask[row] * chosen_mask[col]).sum() > 0:
        print('greedy: chosen nodes are not independent')
    return chosen_mask

def generate_graph(n, d=None, p=None, graph_type='reg', random_seed=0):
    """
    Helper function to generate a NetworkX random graph of specified type,
    given specified parameters (e.g. d-regular, d=3). Must provide one of
    d or p, d with graph_type='reg', and p with graph_type in ['prob', 'erdos'].
    Input:
        n: Problem size (number of nodes)
        d: [Optional] Degree of each node in graph (graph_type='reg')
        p: [Optional] Probability of edge between two nodes
           (graph_type='prob' or 'erdos')
        graph_type: Specifies graph type to generate
        random_seed: Seed value for random generator (None = unseeded)
    Output:
        nx_graph: graph with nodes labelled 0..n-1 and inserted in sorted
            order. Its class is NetworkX OrderedGraph when available (NetworkX
            < 3.0), otherwise plain Graph.
    Raises:
        NotImplementedError: for an unknown graph_type.
    """
    if graph_type == 'reg':
        print(f'Generating d-regular graph with n={n}, d={d}, seed={random_seed}')
        nx_temp = nx.random_regular_graph(d=d, n=n, seed=random_seed)
    elif graph_type == 'prob':
        print(f'Generating p-probabilistic graph with n={n}, p={p}, seed={random_seed}')
        nx_temp = nx.fast_gnp_random_graph(n, p, seed=random_seed)
    elif graph_type == 'erdos':
        print(f'Generating erdos-renyi graph with n={n}, p={p}, seed={random_seed}')
        nx_temp = nx.erdos_renyi_graph(n, p, seed=random_seed)
    else:
        raise NotImplementedError(f'!! Graph type {graph_type} not handled !!')

    # Networkx does not enforce node order by default
    nx_temp = nx.relabel.convert_node_labels_to_integers(nx_temp)
    # Insert nodes in sorted order so downstream tensors follow node ids.
    # NetworkX 3.0 removed OrderedGraph. A plain Graph keeps insertion order
    # on Python 3.7+, so it is used as the fallback.
    graph_cls = getattr(nx, 'OrderedGraph', nx.Graph)
    nx_graph = graph_cls()
    nx_graph.add_nodes_from(sorted(nx_temp.nodes()))
    nx_graph.add_edges_from(nx_temp.edges)
    return nx_graph


class REGULAR_train(InMemoryDataset):
    """Training set of random d-regular graphs labelled with a greedy MIS.

    ``process`` generates 3000 graphs with n=1000 nodes and a degree d drawn
    at random from ``d_list``. Each stored ``Data`` holds:

    * ``x``           -- mask returned by ``greedy`` for the graph, flattened;
    * ``edge_index``  -- ``edge_index`` returned by ``get_diracs``;
    * ``train_batch`` -- ``batch`` returned by ``get_diracs``;
    * ``max_set``     -- ``ratio * n``, where ``ratio`` is the entry of
                         ``ratio_list`` paired with d (presumably a reference
                         independence ratio for d-regular graphs; TODO confirm).

    Args:
        config: dict with at least ``'data_dir'``, used as the dataset root.
    """

    def __init__(self, config: dict):
        self.config = config
        self.data_path = Path(config['data_dir'])
        super(REGULAR_train, self).__init__(root=self.data_path)
        # Load the collated tensors written by ``process``.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Graphs are generated in ``process``; there are no raw files.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download: ``process`` generates the data itself.
        pass

    def get_idx_split(self, split_type='Random'):
        """Return ``{'train': LongTensor}`` covering every stored graph.

        ``split_type`` is accepted for API compatibility and ignored. The
        indices span ``len(self)``, so they always match the graphs actually
        stored, rather than a hard-coded count.
        """
        train_idx = torch.arange(len(self), dtype=torch.long)
        return {'train': train_idx}

    def process(self):
        """Generate the graphs, label them with ``greedy`` and save them."""
        n_list = [1000]
        d_list = [3, 5, 7, 10, 20]
        # One ratio per entry of d_list; max_set = ratio * n.
        ratio_list = [0.45537, 0.38443, 0.33567, 0.28521, 0.19732]
        data_list = []
        for _ in tqdm(range(3000)):
            n_ = choice(n_list)
            d_ = choice(d_list)
            ratio_ = ratio_list[d_list.index(d_)]
            max_set = ratio_ * n_
            nx_graph = generate_graph(n=n_, d=d_, p=None, graph_type='reg', random_seed=None)
            edge_index = from_networkx(nx_graph).edge_index
            x = torch.zeros(n_).reshape(-1, 1)
            # Wrap the single graph in a batch-size-1 loader so it reaches
            # greedy/get_diracs as a batched object with a ``batch`` vector.
            loader = DataLoader([Data(x=x, edge_index=edge_index)], batch_size=1)
            for batch in loader:
                # greedy already returns a tensor; flatten it directly instead
                # of re-wrapping with torch.tensor (copy + UserWarning).
                node_feature = greedy(batch).reshape(-1)
                new_data = get_diracs(batch, 1, sparse=True, effective_volume_range=0.15, receptive_field=5)
                data_list.append(Data(x=node_feature, edge_index=new_data.edge_index,
                                      train_batch=new_data.batch, max_set=max_set))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

        #import pdb; pdb.set_trace()
# 85.24 / 130.21
if __name__ == '__main__':
    configs = Path('./configs')
    # Build the dataset for every config file whose name starts with
    # "config", passing the 'train_3_20' section of each to REGULAR_train.
    for cfg in sorted(configs.iterdir()):
        # Match on the file name, not str(cfg), so the filter does not depend
        # on the platform's path separator.
        if cfg.name.startswith("config"):
            with cfg.open('r') as fh:
                cfg_dict = yaml.safe_load(fh)
            dataset = REGULAR_train(cfg_dict['train_3_20'])
