'''
CircuitSAT dataset in PyG.

'''

from typing import List
import os.path as osp
import logging


import torch
from torch_geometric.data import InMemoryDataset

from .data_utils import read_npz_file
from .ordered_data import OrderedData, return_order_info

# Module-level logger, named after this module's import path.
_logger = logging.getLogger(name=__name__)


class CircuitSATDataset(InMemoryDataset):
    r"""
    The PyG dataset for DGDARGRNN.

    The raw data is a single ``.npz`` archive stored directly in ``root``.
    Its ``'dataset'`` entry is unpacked with ``.item()`` into a mapping from
    instance name to a per-instance dict that has at least the keys ``'x'``,
    ``'edge_index'`` and ``'y'``. Each kept instance is converted into an
    :class:`OrderedData` graph. The collated result is cached as
    ``<root>/<folder_name>/data.pt``.

    Args:
        root (string): Root directory that holds the raw ``.npz`` file and
            under which the processed dataset is saved.
        use_aig (boolean): Use AIG format, i.e. read ``aig_dataset.npz``
            (or ``optaig_dataset.npz``, see ``aig_opt``) instead of
            ``ckt_dataset.npz``. (default: :obj:`False`)
        aig_opt (boolean): When ``use_aig`` is set, read the optimized AIG
            file ``optaig_dataset.npz``. Ignored when ``use_aig`` is
            :obj:`False`. (default: :obj:`False`)
        transform (callable, optional): Not supported yet; must be
            :obj:`None`.
        pre_transform (callable, optional): Not supported yet; must be
            :obj:`None`.
        pre_filter (callable, optional): Not supported yet; must be
            :obj:`None`.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``transform``, ``pre_transform`` or ``pre_filter`` is
            not :obj:`None`.
    """

    def __init__(self, root, use_aig=False, aig_opt=False, transform=None,
                 pre_transform=None, pre_filter=None, **kwargs):
        # Reject unsupported hooks up front with a real exception. An
        # `assert` would be stripped under `python -O`.
        if transform is not None or pre_transform is not None or pre_filter is not None:
            raise ValueError(
                "Cannot accept the transform, pre_transform and pre_filter args now.")

        self.root = root
        self.use_aig = use_aig

        # Pick the raw archive and the processed sub-folder for the format.
        if use_aig and aig_opt:
            file_name, self.folder_name = 'optaig_dataset.npz', 'optaig'
        elif use_aig:
            file_name, self.folder_name = 'aig_dataset.npz', 'aig'
        else:
            file_name, self.folder_name = 'ckt_dataset.npz', 'ckt'

        # Bare file name. PyG's `raw_paths` resolves it against `raw_dir`.
        self._raw_file_basename = file_name
        # Kept for backward compatibility with code reading this attribute.
        self.raw_file_name = osp.join(self.root, file_name)

        # `folder_name` and the raw file name must be set before this call.
        # Assumed PyG behaviour: the base constructor resolves the raw and
        # processed paths and runs `process()` when `data.pt` is missing.
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): PyG >= 2.4 prefers `self.load(...)`, and torch >= 2.6
        # defaults to `torch.load(weights_only=True)`. Confirm this line
        # against the pinned versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Raw files live directly in ``root``."""
        return self.root

    @property
    def processed_dir(self):
        """Processed files go to a per-format sub-folder of ``root``."""
        return osp.join(self.root, self.folder_name)

    @property
    def raw_file_names(self) -> List[str]:
        """Name of the raw ``.npz`` archive, relative to :attr:`raw_dir`."""
        # Return only the bare name. PyG joins each entry with `raw_dir`, so
        # returning `root/<file>` produced `root/root/<file>` for relative
        # roots.
        return [self._raw_file_basename]

    @property
    def processed_file_names(self) -> List[str]:
        """The collated dataset is cached in a single ``data.pt`` file."""
        return ['data.pt']

    def download(self):
        """No-op: the raw ``.npz`` file must already be present in ``root``."""
        pass

    def process(self):
        """Convert the raw archive into :class:`OrderedData` graphs and save them.

        Instances with ``y == 0``, or with ``'UNSAT_'`` in their name, are
        skipped. Every other instance becomes an :class:`OrderedData` that
        holds:

        - ``x`` and ``edge_index``;
        - the forward/backward level and index info from
          :func:`return_order_info`;
        - every other key of the raw instance dict that the graph does not
          already have.

        Raises:
            RuntimeError: If no instance survives the filter. Collating an
                empty list would otherwise fail with an obscure error.
        """
        raw_path = self.raw_paths[0]
        dataset = read_npz_file(raw_path)['dataset'].item()
        data_list = []

        for instance_name, data_item in dataset.items():
            # Drop instances labelled 0 or named as UNSAT (presumably the
            # unsatisfiable ones).
            # NOTE: bug? why is the initial `y` all 1?
            if data_item['y'] == 0 or 'UNSAT_' in instance_name:
                continue

            x = data_item['x']
            edge_index = data_item['edge_index']

            # Passes the first dimension of `x` as the node count. This
            # presumably means one row per node; TODO confirm.
            forward_level, forward_index, backward_level, backward_index = \
                return_order_info(edge_index, x.size(0))
            graph = OrderedData(
                x=x, edge_index=edge_index,
                forward_level=forward_level, forward_index=forward_index,
                backward_level=backward_level, backward_index=backward_index,
            )

            # `Data.keys` is a property in older PyG releases and a method
            # from PyG 2.4 on; accept both forms.
            keys = graph.keys
            existing_keys = set(keys() if callable(keys) else keys)
            # Copy over the remaining per-instance fields.
            for item_key, value in data_item.items():
                if item_key not in existing_keys:
                    graph[item_key] = value

            data_list.append(graph)
            _logger.info('Parse %s', instance_name)

        if not data_list:
            raise RuntimeError(f'No usable instances found in {raw_path}')

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self) -> str:
        # The original used `self.name`, which is never defined on this
        # class, so `repr()` raised AttributeError.
        return f'{self.__class__.__name__}({len(self)})'
