from ogb.graphproppred import PygGraphPropPredDataset
import pickle
import numpy as np
import os
import pandas as pd
from collections import Counter
from tqdm import tqdm
import dataflow_parser
import torch
from torch_geometric.data import (
  Data,

)




def get_vocab_mapping(seq_list, num_vocab):
  """Build a frequency-ranked vocabulary from a list of token sequences.

  Adapted from OGB.

  Input:
      seq_list: a list of sequences (each an iterable of hashable tokens)
      num_vocab: maximum number of regular tokens to keep
  Output:
      vocab2idx:
          A dictionary that maps vocabulary into integer index. The kept
          tokens are numbered from 0 by decreasing frequency; ties keep the
          order in which the tokens were first seen.
          Additionally, we also index '__UNK__' and '__EOS__'
          '__UNK__' : out-of-vocabulary term
          '__EOS__' : end-of-sentence (always the last index)
      idx2vocab:
          A list that maps idx to actual vocabulary.

  If fewer than `num_vocab` distinct tokens occur, all of them are kept and
  the two special tokens directly follow the last kept token.
  """
  # Counter preserves first-seen order, which the stable sort relies on to
  # break frequency ties deterministically.
  vocab_cnt = Counter(w for seq in seq_list for w in seq)
  vocab_list = list(vocab_cnt)

  cnt_list = np.array([vocab_cnt[w] for w in vocab_list], dtype=np.int64)
  topvocab = np.argsort(-cnt_list, kind='stable')[:num_vocab]

  total_cnt = np.sum(cnt_list)
  covered_cnt = float(np.sum(cnt_list[topvocab]))
  print('Coverage of top {} vocabulary:'.format(num_vocab))
  # Guard against an empty corpus, where the ratio is undefined.
  print(covered_cnt / total_cnt if total_cnt else 0.0)

  idx2vocab = [vocab_list[vocab_idx] for vocab_idx in topvocab]
  vocab2idx = {vocab: idx for idx, vocab in enumerate(idx2vocab)}

  # Number the special tokens after the tokens that were actually kept. This
  # equals num_vocab / num_vocab + 1 whenever enough distinct tokens exist,
  # and stays consistent with idx2vocab when there are fewer.
  vocab2idx['__UNK__'] = len(idx2vocab)
  idx2vocab.append('__UNK__')

  vocab2idx['__EOS__'] = len(idx2vocab)
  idx2vocab.append('__EOS__')

  # test the correspondence between vocab2idx and idx2vocab
  # (this also trips if a real token collides with a special token name)
  for idx, vocab in enumerate(idx2vocab):
    assert (idx == vocab2idx[vocab])

  # test that the idx of '__EOS__' is len(idx2vocab) - 1.
  # This fact will be used in decode_arr_to_seq, when finding __EOS__
  assert (vocab2idx['__EOS__'] == len(idx2vocab) - 1)

  return vocab2idx, idx2vocab


# Source dataset; it provides the split indices and the target sequences.
ds = PygGraphPropPredDataset(name="ogbg-code2", root='../data/')

# The target vocabulary is read from a pickled (vocab2idx, idx2vocab) pair.
# Presumably it was produced once with the call below -- confirm before
# regenerating:
#vocab2idx, idx2vocab = get_vocab_mapping(
#      [ds[i].y for i in ds.get_idx_split()['train']], 5000)
# NOTE: pickle can execute arbitrary code; only load a trusted ./vocab_map.
with open("./vocab_map", "rb") as fp:
  vocab2idx, idx2vocab = pickle.load(fp)

metadata = dict(
    vocab2idx=vocab2idx,
    idx2vocab=idx2vocab,
    # +2 to account for end of sequence and unknown token
    num_vocab=str(5000 + 2).encode('utf-8'),
    max_seq_len=str(5).encode('utf-8'),
)

mapping_dir = os.path.join('../data', 'ogbg_code2', 'mapping')

# Both csv files list (index, name) rows; invert them into name -> index.
attr2idx = {
    row[1]: int(row[0])
    for row in pd.read_csv(
        os.path.join(mapping_dir, 'attridx2attr.csv.gz')).values
}
type2idx = {
    row[1]: int(row[0])
    for row in pd.read_csv(
        os.path.join(mapping_dir, 'typeidx2type.csv.gz')).values
}

# Table whose `code` column is later looked up by positional graph id.
code_dict_file_path = os.path.join(mapping_dir, 'graphidx2code.json.gz')
code_dict = pd.read_json(code_dict_file_path, orient='split')
#code_dict = pd.read_json(code_dict_file_path, orient='split', lines=True, chunksize=100)
#code_dict = next(code_dict)


# Frequency counters accumulated over all splits; they are turned into index
# mappings once every graph has been parsed.
node_type_counter = Counter()
node_value_counter = Counter()
edge_name_counter = Counter()

split_ids = ds.get_idx_split()

# First pass: parse every code snippet into a data-flow graph, update the
# frequency counters and dump all graphs of a split into one "raw" file.
for split, ids in split_ids.items():
  file_path = os.path.join('../data/ogbg-code2-norev-df', split)
  os.makedirs(file_path, exist_ok=True)
  os.makedirs(os.path.join(file_path, 'raw'), exist_ok=True)

  buffer = []
  start_id = ids[0]
  for id_ in tqdm(list(ids)):  # every id of the split is processed
    try:
      id_ = id_.item()
      df_graph = dataflow_parser.py2ogbgraph(
          code_dict.iloc[id_].code, attr2idx, type2idx)[0]

      # The raw features are byte strings: column 0 of node_feat_raw is
      # counted as the node type and column 1 as the node value.
      node_type_counter.update(
          node_type.decode('utf-8')
          for node_type in df_graph['node_feat_raw'][:, 0])
      node_value_counter.update(
          node_value.decode('utf-8')
          for node_value in df_graph['node_feat_raw'][:, 1])
      # NOTE(review): if edge_name holds exactly one element, squeeze()
      # returns a 0-d array that cannot be iterated, so such a graph ends up
      # in the except branch below -- confirm whether that can occur.
      edge_name_counter.update(
          edge_name.decode('utf-8')
          for edge_name in df_graph['edge_name'].squeeze())

      # transform df_graph to pyg graph
      pyg_graph = Data(x=torch.tensor(df_graph['node_feat']),
                       node_feat_raw=df_graph['node_feat_raw'],
                       # NOTE(review): node_dfs_order is filled from
                       # 'node_depth', same as node_depth below; this looks
                       # like a copy/paste slip -- confirm which key
                       # py2ogbgraph provides for the DFS order.
                       node_dfs_order=torch.tensor(df_graph['node_depth']),
                       node_depth=torch.tensor(df_graph['node_depth']),
                       node_is_attributed=torch.tensor(
                           df_graph['node_is_attributed']),
                       edge_index=torch.tensor(df_graph['edge_index']),
                       edge_attr=torch.tensor(df_graph['edge_type']),
                       edge_order=torch.tensor(df_graph['edge_order']),
                       edge_name=df_graph['edge_name'],
                       num_nodes=torch.tensor(df_graph['num_nodes']),
                       num_edges=torch.tensor(df_graph['num_edges']),
                       y=ds[id_].y,
                       )
      buffer.append(pyg_graph)
    except Exception as err:  # pylint: disable=broad-except
      # Best effort: a graph that cannot be converted is reported and skipped.
      # Only Exception is caught so that Ctrl+C / SystemExit still stop the
      # run instead of being reported as a graph error.
      print(f'Error for graph {id_}')
      print(repr(err))
      print(code_dict.iloc[id_].code)

  # All graphs of the split go into a single file named after the first id
  # and the last id that was visited. Despite the .npz suffix the file is a
  # pickle of the list of graphs.
  file_name = os.path.join(
      file_path, 'raw', f'{start_id}_{id_}_raw.npz')
  with open(file_name, 'wb') as f:
    pickle.dump(buffer, f)

# Only the 11,972 most frequent node values get an index of their own; the
# index is the frequency rank (0 = most frequent).
topk_node_values = node_value_counter.most_common(11_972)
node_values_to_idx = {
    value: rank
    for rank, value in enumerate(value for value, _ in topk_node_values)
}


def node_values_to_idx_with_default(value):
  """Return the index of `value`, or one shared index for unknown values.

  Values that are not a key of the module-level `node_values_to_idx` mapping
  all map to `len(node_values_to_idx)`.
  """
  return node_values_to_idx.get(value, len(node_values_to_idx))


# Node types and edge names are indexed by frequency rank (0 = most frequent);
# unlike the node values, none of them is cut off.
node_type_to_idx = {
    name: rank
    for rank, (name, _) in enumerate(node_type_counter.most_common())
}
edge_name_to_idx = {
    name: rank
    for rank, (name, _) in enumerate(edge_name_counter.most_common())
}

metadata.update({
    'node_values_to_idx': node_values_to_idx,
    'node_type_to_idx': node_type_to_idx,
    'edge_name_to_idx': edge_name_to_idx,
})

# The metadata dict is stored as a single object array under the key 'data'.
file = os.path.join("../data/ogbg-code2-norev-df", 'meta.npz')
np.savez_compressed(file, data=np.array(metadata, dtype='object'))

# Second pass: replace the raw byte-string features of every stored graph by
# integer indices and write the result into the split directory.
for split in split_ids:
  file_path = os.path.join("../data/ogbg-code2-norev-df", split)
  raw_dir = os.path.join(file_path, 'raw')

  for raw_name in tqdm(os.listdir(raw_dir)):
    if 'raw' not in raw_name:
      continue

    # Expected to be the pickled graph lists written earlier by this script;
    # pickle can execute arbitrary code, so do not point this at foreign files.
    with open(os.path.join(raw_dir, raw_name), 'rb') as raw_file:
      data_list = pickle.load(raw_file)

    buffer = []
    for graph in data_list:
      node_types = [
          node_type_to_idx[node_type.decode('utf-8')]
          for node_type in graph.node_feat_raw[:, 0]
      ]
      node_values = [
          node_values_to_idx_with_default(node_value.decode('utf-8'))
          for node_value in graph.node_feat_raw[:, 1]
      ]
      # x becomes an int64 tensor with one (type index, value index) row per
      # node; the raw byte strings are no longer needed afterwards.
      graph.x = torch.tensor(
          (node_types, node_values), dtype=torch.int64).transpose(0, 1)
      del graph.node_feat_raw

      # reshape(-1) instead of squeeze(): squeeze() turns a single-element
      # array into a 0-d array, which cannot be iterated.
      edge_names = [
          edge_name_to_idx[edge_name.decode('utf-8')]
          for edge_name in graph.edge_name.reshape(-1)
      ]
      graph.edge_name = torch.tensor(edge_names, dtype=torch.int64)[:, None]

      # merge all edge attributes into one
      # (assumes edge_attr and edge_order are 2-d with one row per edge --
      # TODO confirm against dataflow_parser)
      graph.edge_attr = torch.cat(
          [graph.edge_attr, graph.edge_order, graph.edge_name], dim=-1)
      del graph.edge_order, graph.edge_name

      buffer.append(graph)

    # Same file name without the '_raw' marker, one directory level up. The
    # handle gets its own name so it does not overwrite the file-name variable.
    file_name = os.path.join(file_path, raw_name.replace('_raw', ''))
    with open(file_name, 'wb') as out_file:
      pickle.dump(buffer, out_file)
