"""MIS (Maximum Independent Set) dataset and random Erdos-Renyi graph environment."""

import glob
import os
import sys
if int(sys.version.split('.')[1])<9:
  import pickle5 as pickle
else:
  import pickle
import numpy as np
import torch
import re
from torch_geometric.data import Data as GraphData
from torch_geometric.data import Batch as GraphBatch

class MISDataset(torch.utils.data.Dataset):
  """Dataset of pickled graphs for the Maximum Independent Set problem.

  Each example is one file matched by the ``data_file`` glob pattern. The file
  is loaded with ``pickle`` and must expose a networkx-style API
  (``number_of_nodes()``, ``nodes``, ``edges``) -- presumably a
  ``networkx.Graph``; confirm against the data-generation pipeline.

  Args:
    data_file: Glob pattern selecting the graph files. Matches are sorted so
      that example indices are reproducible.
    data_label_dir: Optional directory with one label file per graph, named
      after the graph's basename with ``'.gpickle'`` replaced by
      ``'_unweighted.result'`` and holding one integer label per line. When
      ``None``, labels come from the ``'label'`` node attribute (all zeros if
      the graph carries no labels).
    num_testset: If positive, keep only the first ``num_testset`` files.
    shuffle: Stored as ``self.shuffle``; not used inside this class.
  """

  def __init__(self, data_file, data_label_dir=None, num_testset=-1, shuffle=True):
    self.data_file = data_file
    # glob's result order is filesystem-dependent; sort so indices and the
    # ``num_testset`` subset are the same on every machine and every run.
    self.file_lines = sorted(glob.glob(data_file))
    self.shuffle = shuffle
    if num_testset > 0:
      self.file_lines = self.file_lines[:num_testset]

    self.data_label_dir = data_label_dir
    print(f'Loaded "{data_file}" with {len(self.file_lines)} examples')

  def __len__(self):
    return len(self.file_lines)

  def get_example(self, idx):
    """Load example ``idx`` as raw numpy data.

    Returns:
      ``(num_nodes, node_labels, edges, reward)`` where ``node_labels`` is an
      int64 array of length ``num_nodes``; ``edges`` is an int64 array of shape
      ``(2, 2 * num_edges + num_nodes)`` listing every graph edge in both
      directions followed by one self loop per node; ``reward`` is the negated
      size encoded in the file name as ``_m<size>`` or, failing that, the
      negated sum of the ``'label'`` node attributes (0 if unavailable).

    Raises:
      ValueError: If a label file's length differs from the node count.
    """
    file_path = self.file_lines[idx]
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(file_path, "rb") as f:
      graph = pickle.load(f)
    num_nodes = graph.number_of_nodes()

    # Search only the basename so a directory such as ".../run_m5/" cannot be
    # mistaken for the encoded set size.
    match = re.search(r'_m(\d+)', os.path.basename(file_path))
    if match:
      reward = -int(match.group(1))
    else:
      try:
        reward = -sum(node_data['label'] for node_data in graph.nodes.values())
      except (KeyError, TypeError):
        # Missing or non-numeric labels: no reference value is available.
        reward = 0

    if self.data_label_dir is None:
      node_labels = [label for _, label in graph.nodes(data='label')]
      # Only the first node is probed: a graph is assumed to be labelled either
      # everywhere or nowhere. An empty graph falls through to zeros(0).
      if node_labels and node_labels[0] is not None:
        node_labels = np.array(node_labels, dtype=np.int64)
      else:
        node_labels = np.zeros(num_nodes, dtype=np.int64)
    else:
      base_label_file = os.path.basename(file_path).replace('.gpickle', '_unweighted.result')
      node_label_file = os.path.join(self.data_label_dir, base_label_file)
      with open(node_label_file, 'r') as f:
        node_labels = np.array([int(line) for line in f.read().splitlines()], dtype=np.int64)
      if node_labels.shape[0] != num_nodes:
        raise ValueError(
            f'{node_label_file} has {node_labels.shape[0]} labels '
            f'but the graph has {num_nodes} nodes')

    # NOTE(review): assumes node ids are the integers 0..num_nodes-1 in the same
    # order as the labels -- confirm against the data pipeline.
    # reshape(-1, 2) keeps an edgeless graph as a (0, 2) array; otherwise the
    # array is 1-D and the column slice below raises IndexError.
    edges = np.array(graph.edges, dtype=np.int64).reshape(-1, 2)
    # Store each undirected edge in both directions.
    edges = np.concatenate([edges, edges[:, ::-1]], axis=0)
    # add self loop
    self_loop = np.arange(num_nodes, dtype=np.int64).reshape(-1, 1).repeat(2, axis=1)
    edges = np.concatenate([edges, self_loop], axis=0)
    edges = edges.T

    return num_nodes, node_labels, edges, reward

  def __getitem__(self, idx):
    """Return ``(idx_tensor, graph_data, num_nodes_tensor, reward)``.

    ``graph_data.x`` holds the per-node labels, ``graph_data.edge_index`` the
    ``(2, E)`` edge array from :meth:`get_example`, and
    ``graph_data.edge_length`` its column count ``E``.
    """
    num_nodes, node_labels, edge_index, reward = self.get_example(idx)
    graph_data = GraphData(x=torch.from_numpy(node_labels),
                           edge_index=torch.from_numpy(edge_index))
    graph_data.edge_length = graph_data.edge_index.shape[-1]
    point_indicator = np.array([num_nodes], dtype=np.int64)
    return (
        torch.LongTensor(np.array([idx], dtype=np.int64)),
        graph_data,
        torch.from_numpy(point_indicator).long(),
        reward
    )


class MIS_ERGraphEnvironment:
    """On-the-fly generator of random Erdos-Renyi graphs for MIS training.

    Each call to :meth:`get_batch` samples fresh graphs whose node count ``n``
    is drawn uniformly from ``[lower_bound, upper_bound]`` (both inclusive) and
    in which each pair of distinct nodes is connected with probability ``p``.

    Args:
        lower_bound: Minimum number of nodes per graph (inclusive).
        upper_bound: Maximum number of nodes per graph (inclusive).
        p: Target probability that two distinct nodes share an edge.

    Raises:
        ValueError: If ``lower_bound`` is negative or exceeds ``upper_bound``,
            or if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, lower_bound=700, upper_bound=800, p=0.15):
        if not 0 <= lower_bound <= upper_bound:
            raise ValueError(
                f"expected 0 <= lower_bound <= upper_bound, "
                f"got lower_bound={lower_bound}, upper_bound={upper_bound}")
        if not 0.0 <= p <= 1.0:
            # Outside [0, 1], the sqrt below would give NaN and every
            # comparison would be False, silently producing edgeless graphs.
            raise ValueError(f"p must be in [0, 1], got {p}")

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # get_batch keeps edge (i, j) when either of two independent uniform
        # draws (u_ij or u_ji) is below self.p, so the realised edge probability
        # is 1 - (1 - self.p)**2. Solving that for p gives this per-draw value.
        self.p = 1 - np.sqrt(1 - p)

    def get_batch(self, batch_size):
        """Sample ``batch_size`` random graphs and collate them into one batch.

        Every graph has all-zero float node features ``x`` of length ``n``, a
        symmetric ``edge_index`` containing each edge in both directions plus
        one self loop per node, and ``edge_length`` equal to the column count
        of ``edge_index``.

        Returns:
            tuple: (None, batch_data, num_nodes_tensor, None), where
            ``batch_data`` is the collated ``GraphBatch`` and
            ``num_nodes_tensor`` is a long tensor of shape ``(batch_size,)``
            with each graph's node count. The ``None`` slots presumably mirror
            the index/reward positions of ``MISDataset.__getitem__`` -- confirm
            against the training loop.
        """
        graph_list = []
        num_nodes_list = []

        for _ in range(batch_size):
            n = torch.randint(low=self.lower_bound, high=self.upper_bound + 1, size=(1,)).item()
            num_nodes_list.append(n)

            # NOTE(review): float features here, whereas MISDataset stores int64
            # labels in x -- confirm the consumer handles both dtypes.
            x = torch.zeros(n, dtype=torch.float)

            rand_matrix = torch.rand(n, n)
            self_loops = torch.eye(n, dtype=torch.bool)
            # OR-ing with the transpose makes the adjacency symmetric; the
            # diagonal is forced on to add self loops.
            mask = (rand_matrix < self.p) | (rand_matrix.T < self.p) | self_loops

            edge_index = mask.nonzero(as_tuple=False).t().contiguous()

            data = GraphData(x=x, edge_index=edge_index, edge_length=edge_index.shape[-1])
            graph_list.append(data)

        batch_data = GraphBatch.from_data_list(graph_list)

        num_nodes_tensor = torch.tensor(num_nodes_list, dtype=torch.long)

        return None, batch_data, num_nodes_tensor, None
    

if __name__ == '__main__':
    # Smoke test: draw a small random batch and show what get_batch returns.
    demo_env = MIS_ERGraphEnvironment()
    first, batch, sizes, last = demo_env.get_batch(batch_size=4)

    report = (
        ("a:", first),
        ("batch_data.x.shape:", batch.x.shape),
        ("batch_data.edge_index.shape:", batch.edge_index.shape),
        ("num_nodes_tensor:", sizes),
        ("d:", last),
    )
    for label, value in report:
        print(label, value)