import pickle, copy, os

class Data:
    """Bundle of graphs, features, splits and lookup tables for one dataset.

    A new instance is mutable so that a loader can fill it in. After
    ``freeze()`` has been called, every attribute assignment raises
    ``AttributeError``. Only rebinding is blocked: containers already stored
    on the instance can still be mutated in place.
    """

    def __init__(self, name=None):
        # The lock flag is created first, while assignment is still allowed.
        self._locked = False
        self.name = name

        # Graph index -> label, and the per-graph [nodes, edge indices] pairs
        # (this is how the loaders in this module fill them).
        self.graph_to_label = {}
        self.graphs = []

        # Graph indices belonging to each split.
        self.train_graphs = set()
        self.val_graphs = set()
        self.test_graphs = set()

        # Edge list of (from_node, to_node) pairs; an edge's position in the
        # list is its edge index.
        self.A = []

        # Node / edge features. They start as empty dicts here; the loaders
        # in this module rebind them to lists indexed by node / edge.
        self.X_node = {}
        self.X_edge = {}

        # node -> {(edge index, successor)}, node -> {(edge index, predecessor)}
        # and (from_node, to_node) -> edge index.
        self.succ_node_to_nodes = {}
        self.pred_node_to_nodes = {}
        self.nodes_to_edge = {}

        # NOTE(review): the next two sets are not filled by the loaders
        # visible here; presumably downstream code owns them - confirm.
        self.left_graphs = set()
        self.target_labeled_graphs = set()

        # Label -> set of graph indices carrying that label.
        self.label_to_graphs = {}

        # NOTE(review): numeric settings whose meaning and units are not
        # established in this file (the loaders only set timeLimit to 10).
        self.epsilon = 0.0
        self.expected = 0.0
        self.timeLimit = 0

    def freeze(self):
        """Lock the instance; any later attribute assignment will raise."""
        self._locked = True

    def __setattr__(self, key, value):
        """Assign ``key`` unless the instance has been frozen.

        Raises:
            AttributeError: If ``freeze()`` was called before this assignment.
        """
        # The getattr default covers the very first assignment in __init__,
        # when ``_locked`` itself does not exist yet.
        locked = getattr(self, '_locked', False)
        if not locked:
            super().__setattr__(key, value)
            return
        raise AttributeError("data is immutable")


def check_error(name):
    """Verify that every mandatory file of dataset ``name`` is present.

    The files are looked up in ``datasets/<name>/`` and checked in a fixed
    order; the first missing one aborts the check.

    Args:
        name: Dataset name, used both as directory name and as file prefix.

    Raises:
        ValueError: Naming the first required file that does not exist.
    """
    dataset_dir = os.path.join("datasets", name)
    required = (
        f"{name}_train.txt",
        f"{name}_val.txt",
        f"{name}_test.txt",
        f"{name}_graph_indicator.txt",
        f"{name}_A.txt",
        f"{name}_graph_labels.txt",
    )
    for required_name in required:
        if os.path.exists(os.path.join(dataset_dir, required_name)):
            continue
        raise ValueError(f"{required_name} does not exist in {dataset_dir}")

def process_graph_indicator(name):
    """Read ``datasets/<name>/<name>_graph_indicator.txt``.

    Line ``i`` (0-based) of the file holds the 1-based id of the graph that
    node ``i`` belongs to.

    Args:
        name: Dataset name (directory and file prefix).

    Returns:
        A pair ``(graph_to_nodes, nodes)``: a dict from 0-based graph index to
        the list of its node indices in file order, and the set of all node
        indices.
    """
    indicator_path = os.path.join("datasets", name, f"{name}_graph_indicator.txt")
    members_by_graph = {}
    all_nodes = set()
    with open(indicator_path) as handle:
        for node_idx, raw in enumerate(handle):
            # Graph ids are 1-based on disk; shift them to 0-based.
            members_by_graph.setdefault(int(raw.strip()) - 1, []).append(node_idx)
            all_nodes.add(node_idx)
    return members_by_graph, all_nodes

def _read_file_to_set(file_path):
    """Return the set of integers found in ``file_path``, one per line."""
    values = set()
    with open(file_path) as handle:
        for raw in handle:
            values.add(int(raw.strip()))
    return values

def _read_labels(file_path):
    """Read one integer label per line from ``file_path``.

    Returns:
        A dict mapping the 0-based line number to its integer label, or
        ``None`` when the file does not exist.
    """
    if os.path.exists(file_path):
        with open(file_path) as handle:
            return {row: int(raw.strip()) for row, raw in enumerate(handle)}
    # A missing file is reported with None rather than an error.
    return None

def _read_attributes(file_path):
    """Read comma-separated float attributes, one record per line.

    Returns:
        A dict mapping the 0-based line number to the list of floats parsed
        from that line, or ``None`` when the file does not exist.
    """
    if not os.path.exists(file_path):
        return None
    with open(file_path) as handle:
        return {
            row: [float(token) for token in raw.strip().split(',')]
            for row, raw in enumerate(handle)
        }



def load_Data(name):
    """Build a ``Data`` object from the text files in ``datasets/<name>/``.

    Required files (verified by ``check_error``): ``<name>_train.txt``,
    ``<name>_val.txt``, ``<name>_test.txt``, ``<name>_graph_indicator.txt``,
    ``<name>_A.txt`` and ``<name>_graph_labels.txt``. Optional files, used
    when present and non-empty: ``<name>_node_attributes.txt``,
    ``<name>_node_labels.txt``, ``<name>_edge_attributes.txt`` and
    ``<name>_edge_labels.txt``.

    Node ids in ``<name>_A.txt`` and graph ids in the indicator file are
    1-based on disk and are converted to 0-based indices.

    Args:
        name: Dataset name; both the directory under ``datasets`` and the
            prefix of every file inside it.

    Returns:
        A populated ``Data`` instance (not frozen) with ``timeLimit`` set
        to 10.

    Raises:
        ValueError: If a required file is missing, a value cannot be parsed
            as a number, or a graph label is negative.
    """
    data = Data(name)
    check_error(name)

    base_path = os.path.join("datasets", name)

    def dataset_file(suffix):
        # Path of "<name>_<suffix>.txt" inside the dataset directory.
        return os.path.join(base_path, f"{name}_{suffix}.txt")

    data.train_graphs = _read_file_to_set(dataset_file("train"))
    data.val_graphs = _read_file_to_set(dataset_file("val"))
    data.test_graphs = _read_file_to_set(dataset_file("test"))

    graph_to_nodes, nodes = process_graph_indicator(name)
    node_to_graph = {
        node: graph_idx
        for graph_idx, members in graph_to_nodes.items()
        for node in members
    }

    # Each edge is attributed to the graph that owns its source node.
    graph_to_edges = {i: [] for i in range(len(graph_to_nodes))}
    A = []
    with open(dataset_file("A")) as file:
        for i, line in enumerate(file):
            # int() ignores surrounding whitespace, so splitting on ","
            # accepts both "1, 2" and "1,2".
            edge = line.strip().split(',')
            fr_node = int(edge[0]) - 1
            to_node = int(edge[1]) - 1
            A.append((fr_node, to_node))
            if fr_node in node_to_graph:
                graph_idx = node_to_graph[fr_node]
                # setdefault keeps this working when the graph ids in the
                # indicator file are not the contiguous range 0..n-1.
                graph_to_edges.setdefault(graph_idx, []).append(i)
    data.A = A

    with open(dataset_file("graph_labels")) as file:
        for i, line in enumerate(file):
            label = int(line.strip())
            if label < 0:
                raise ValueError("We assume labels are non-negative integers")
            data.graph_to_label[i] = label
            data.label_to_graphs.setdefault(label, set()).add(i)

    # Node features: attribute vectors if available, otherwise one empty
    # list per node; integer node labels are appended as an extra feature.
    node_to_attributes = _read_attributes(dataset_file("node_attributes"))
    if node_to_attributes:
        data.X_node = [node_to_attributes[i] for i in sorted(node_to_attributes)]
    else:
        data.X_node = [[] for _ in range(len(nodes))]

    node_to_label = _read_labels(dataset_file("node_labels"))
    if node_to_label:
        for node, node_label in node_to_label.items():
            data.X_node[node].append(node_label)

    # Edge features follow the same scheme, indexed by position in ``A``.
    edge_to_attributes = _read_attributes(dataset_file("edge_attributes"))
    if edge_to_attributes:
        data.X_edge = [edge_to_attributes[i] for i in sorted(edge_to_attributes)]
    else:
        data.X_edge = [[] for _ in range(len(data.A))]

    edge_to_label = _read_labels(dataset_file("edge_labels"))
    if edge_to_label:
        for edge_idx, edge_label in edge_to_label.items():
            data.X_edge[edge_idx].append(edge_label)

    # One [node list, edge index list] pair per labelled graph.
    data.graphs = [
        [graph_to_nodes[i], graph_to_edges.get(i, [])]
        for i in range(len(data.graph_to_label))
    ]

    # Adjacency lookups keyed by node and by (from, to) pair.
    for idx, (fr_node, to_node) in enumerate(data.A):
        data.nodes_to_edge[(fr_node, to_node)] = idx
        data.succ_node_to_nodes.setdefault(fr_node, set()).add((idx, to_node))
        data.pred_node_to_nodes.setdefault(to_node, set()).add((idx, fr_node))

    data.timeLimit = 10
    return data


def load_BBBP():
  """Load the BBBP dataset from the pickles under ``datasets/BBBP/``.

  The stored labels are expected to be 1 or -1; -1 is remapped to 0 so the
  resulting labels are 0/1. The node list of every graph is rebuilt from the
  endpoints of its edges, so a node that appears in no edge is not listed.

  Returns:
    A populated ``Data`` instance (not frozen) with ``timeLimit`` set to 10.

  Raises:
    ValueError: If a stored label is neither 1 nor -1 (a message is printed
      first).
  """
  # NOTE(review): pickle.load can execute code embedded in the file; these
  # pickles are assumed to be trusted local dataset files - confirm.
  def read_pickle(path):
    with open(path, 'rb') as handle:
      return pickle.load(handle)

  data = Data()
  data.name = 'BBBP'

  train_graphs = read_pickle("datasets/BBBP/tr.pickle")
  val_graphs = read_pickle("datasets/BBBP/va.pickle")
  test_graphs = read_pickle("datasets/BBBP/te.pickle")
  graph_to_label = read_pickle("datasets/BBBP/graph_to_label_bbbp.pickle")
  X_node = read_pickle("datasets/BBBP/X_node_bbbp.pickle")
  X_edge = read_pickle("datasets/BBBP/X_edge_bbbp.pickle")
  A = read_pickle("datasets/BBBP/new_A_bbbp.pickle")
  graphs = read_pickle("datasets/BBBP/graphs_bbbp.pickle")

  # Replace each graph's stored node list by the endpoints of its edges;
  # the edge index list (second entry) is kept unchanged.
  rebuilt_graphs = []
  for stored_graph in graphs:
    edge_ids = stored_graph[1]
    endpoints = set()
    for edge_id in edge_ids:
      src, dst = A[edge_id]
      endpoints.update((src, dst))
    rebuilt_graphs.append([list(endpoints), edge_ids])
  graphs = rebuilt_graphs

  # Map the stored {1, -1} labels onto {1, 0} and group graphs by label.
  label_to_graphs = {0: set(), 1: set()}
  new_graph_to_label = {}
  for graph_idx, _ in enumerate(graph_to_label):
    stored_label = graph_to_label[graph_idx]
    if stored_label == 1:
      binary_label = 1
    elif stored_label == -1:
      binary_label = 0
    else:
      print("Cannot be happened")
      raise ValueError
    label_to_graphs[binary_label].add(graph_idx)
    new_graph_to_label[graph_idx] = binary_label

  # Adjacency lookups: (from, to) -> edge index, plus successor and
  # predecessor sets of (edge index, neighbour) per node.
  nodes_to_edge = {}
  succ_node_to_nodes = {}
  pred_node_to_nodes = {}
  for edge_idx, edge in enumerate(A):
    fr_node, to_node = edge[0], edge[1]
    nodes_to_edge[(fr_node, to_node)] = edge_idx
    succ_node_to_nodes.setdefault(fr_node, set()).add((edge_idx, to_node))
    pred_node_to_nodes.setdefault(to_node, set()).add((edge_idx, fr_node))

  data.train_graphs = train_graphs
  data.val_graphs = val_graphs
  data.test_graphs = test_graphs
  data.graphs = graphs
  data.X_edge = X_edge
  data.X_node = X_node
  data.A = A
  # The remapped 0/1 labels are stored, not the raw pickle contents.
  data.graph_to_label = new_graph_to_label
  data.label_to_graphs = label_to_graphs
  data.succ_node_to_nodes = succ_node_to_nodes
  data.pred_node_to_nodes = pred_node_to_nodes
  data.nodes_to_edge = nodes_to_edge
  data.timeLimit = 10
  return data



def load_BACE():
  """Load the BACE dataset from the pickles under ``datasets/BACE/``.

  The stored labels must already be 0 or 1. They are validated and copied
  into a plain ``graph index -> label`` dict, which is what gets stored on
  the result (the same shape ``load_BBBP`` and ``load_Data`` produce). The
  node list of every graph is rebuilt from the endpoints of its edges, so a
  node that appears in no edge is not listed.

  Returns:
    A populated ``Data`` instance (not frozen) with ``timeLimit`` set to 10.

  Raises:
    ValueError: If a stored label is neither 0 nor 1.
  """
  # NOTE(review): pickle.load can execute code embedded in the file; these
  # pickles are assumed to be trusted local dataset files - confirm.
  def read_pickle(path):
    with open(path, 'rb') as handle:
      return pickle.load(handle)

  data = Data()
  data.name = 'BACE'

  train_graphs = read_pickle("datasets/BACE/tr.pickle")
  val_graphs = read_pickle("datasets/BACE/va.pickle")
  test_graphs = read_pickle("datasets/BACE/te.pickle")
  graph_to_label = read_pickle("datasets/BACE/graph_to_label_bace.pickle")
  X_node = read_pickle("datasets/BACE/X_node_bace.pickle")
  X_edge = read_pickle("datasets/BACE/X_edge_bace.pickle")
  A = read_pickle("datasets/BACE/new_A_bace.pickle")
  graphs = read_pickle("datasets/BACE/graphs_bace.pickle")

  # Replace each graph's stored node list by the endpoints of its edges;
  # the edge index list (second entry) is kept unchanged.
  rebuilt_graphs = []
  for stored_graph in graphs:
    edge_ids = stored_graph[1]
    endpoints = set()
    for edge_id in edge_ids:
      src, dst = A[edge_id]
      endpoints.update((src, dst))
    rebuilt_graphs.append([list(endpoints), edge_ids])
  graphs = rebuilt_graphs

  # Validate the 0/1 labels and group graphs by label.
  label_to_graphs = {0: set(), 1: set()}
  new_graph_to_label = {}
  for graph_idx, _ in enumerate(graph_to_label):
    stored_label = graph_to_label[graph_idx]
    if stored_label == 1:
      binary_label = 1
    elif stored_label == 0:
      binary_label = 0
    else:
      raise ValueError(
          f"unexpected BACE label {stored_label!r} for graph {graph_idx}; "
          "expected 0 or 1")
    label_to_graphs[binary_label].add(graph_idx)
    new_graph_to_label[graph_idx] = binary_label

  # Adjacency lookups: (from, to) -> edge index, plus successor and
  # predecessor sets of (edge index, neighbour) per node.
  nodes_to_edge = {}
  succ_node_to_nodes = {}
  pred_node_to_nodes = {}
  for edge_idx, edge in enumerate(A):
    fr_node, to_node = edge[0], edge[1]
    nodes_to_edge[(fr_node, to_node)] = edge_idx
    succ_node_to_nodes.setdefault(fr_node, set()).add((edge_idx, to_node))
    pred_node_to_nodes.setdefault(to_node, set()).add((edge_idx, fr_node))

  data.train_graphs = train_graphs
  data.val_graphs = val_graphs
  data.test_graphs = test_graphs
  data.graphs = graphs
  data.X_edge = X_edge
  data.X_node = X_node
  data.A = A
  # Store the validated dict rather than the raw pickle object, which was
  # previously assigned here while the validated copy went unused.
  data.graph_to_label = new_graph_to_label
  data.label_to_graphs = label_to_graphs
  data.succ_node_to_nodes = succ_node_to_nodes
  data.pred_node_to_nodes = pred_node_to_nodes
  data.nodes_to_edge = nodes_to_edge
  data.timeLimit = 10
  return data



#if __name__ == "__main__":
#    data = load_Data('PTC_MR')
#    print(len(data.train_graphs))
#    print(len(data.val_graphs))
#    print(len(data.test_graphs))




