'''
DeepGate dataset in PyG.

'''

import imp
from typing import List
import os.path as osp
import logging

import torch
from torch_geometric.data import InMemoryDataset

from .data_utils import read_npz_file
from .parser import circuit_parse_pyg

_logger = logging.getLogger(__name__)


class GATEDateset(InMemoryDataset):
    r"""
    A variety of circuit graph datasets, *e.g.*, open-sourced benchmarks,
    random circuits.

    The raw data is a single ``{dataset_type}_prob_dataset.npz`` file located
    directly in :obj:`root`. The processed dataset is cached as
    ``{root}/{dataset_type}_prob/data.pt``.

    Args:
        root (string): Root directory that holds the raw ``.npz`` file and
            under which the processed dataset is saved.
        dataset_type (string, optional): Prefix used to build the raw file
            name and the processed directory name.
            (default: :obj:`"optaig"`)
        transform (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)
        pre_transform (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)
        pre_filter (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, root, dataset_type='optaig', transform=None, pre_transform=None, pre_filter=None, **kwargs):
        # Set before ``super().__init__`` so the path properties below can be
        # evaluated while the base class checks for raw/processed files.
        self.root = root

        self.dataset_type = dataset_type
        # Public attribute kept for callers that read it directly.
        self.raw_file_name = osp.join(self.root, '{}_prob_dataset.npz'.format(dataset_type))

        assert (transform is None) and (pre_transform is None) and (pre_filter is None), "Cannot accept the transform, pre_transfrom and pre_filter args now."

        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # caches. torch>=2.6 defaults to weights_only=True, which may reject
        # PyG Data objects — confirm against the pinned torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Raw files live directly in the dataset root."""
        return self.root

    @property
    def processed_dir(self):
        """Directory holding the processed cache, ``{root}/{dataset_type}_prob``."""
        return osp.join(self.root, "{}_prob".format(self.dataset_type))

    @property
    def raw_file_names(self) -> List[str]:
        """Raw file name(s), relative to :attr:`raw_dir`."""
        # PyG builds ``raw_paths`` as ``osp.join(self.raw_dir, name)``, so the
        # name must be relative; returning an already-joined path would double
        # the root prefix whenever ``root`` is a relative path.
        return ['{}_prob_dataset.npz'.format(self.dataset_type)]

    @property
    def processed_file_names(self) -> List[str]:
        """Processed file name(s), relative to :attr:`processed_dir`."""
        return ['data.pt']

    def download(self):
        """No-op: the raw ``.npz`` file must already be present in ``root``."""
        pass

    def process(self):
        """Parse every circuit in the raw ``.npz`` file and save the collated dataset."""
        # NOTE(review): ``.item()`` unpacks a 0-d object array, so the file is
        # presumably loaded with pickling enabled — only process trusted files.
        circuits = read_npz_file(self.raw_paths[0])['circuits'].item()

        data_list = []
        for cir_name, circuit in circuits.items():
            graph = circuit_parse_pyg(
                circuit["x"],
                circuit["edge_index"],
                circuit["y"],
                un_directed=False,
                num_gate_types=4,
                mask=True,
            )
            graph.name = cir_name
            data_list.append(graph)
            _logger.info('Parse %s', cir_name)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self) -> str:
        # ``self.name`` is never assigned on this class, so report the class
        # name instead (the same convention PyG's ``Dataset.__repr__`` uses).
        return f'{type(self).__name__}({len(self)})'