'''
CNF dataset in PyG.

'''

from typing import List
import os.path as osp
import logging

import torch
from torch_geometric.data import InMemoryDataset

from .data_utils import read_npz_file
from .ordered_data import OrderedData

_logger = logging.getLogger(__name__)


class CNFDataset(InMemoryDataset):
    r"""
    The PyG in-memory dataset built from a ``cnf_dataset.npz`` archive.

    The archive is expected to live directly inside ``root``; the processed
    (collated) tensors are cached in ``<root>/cnf/data.pt``.

    Args:
        root (string): Root directory that contains ``cnf_dataset.npz`` and
            where the processed dataset is saved.
        transform (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)
        pre_transform (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)
        pre_filter (callable, optional): Not supported yet; must be
            :obj:`None`. (default: :obj:`None`)

    Raises:
        AssertionError: If any of ``transform``, ``pre_transform`` or
            ``pre_filter`` is not :obj:`None`.
    """

    # Name of the raw archive, relative to :attr:`raw_dir`.
    _RAW_FILE_NAME = 'cnf_dataset.npz'

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        self.root = root
        # Kept as a public attribute for callers that read it directly.
        self.raw_file_name = osp.join(self.root, self._RAW_FILE_NAME)

        assert (transform is None) and (pre_transform is None) and (pre_filter is None), "Cannot accept the transform, pre_transfrom and pre_filter args now."

        # The base-class constructor triggers ``process()`` when the processed
        # file is missing, so the load below always finds it afterwards.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Directory holding the raw archive (the dataset root itself)."""
        return self.root

    @property
    def processed_dir(self):
        """Directory in which the processed dataset file is stored."""
        return osp.join(self.root, "cnf")

    @property
    def raw_file_names(self) -> List[str]:
        """Raw file names, relative to :attr:`raw_dir`."""
        # PyG joins these names with ``raw_dir`` itself, so returning a path
        # that already contains the root would duplicate it for relative roots.
        return [self._RAW_FILE_NAME]

    @property
    def processed_file_names(self) -> List[str]:
        """Processed file names, relative to :attr:`processed_dir`."""
        return ['data.pt']

    def download(self):
        """Do nothing: the raw archive must already be present in ``root``."""
        pass

    @staticmethod
    def _forward_level(x_info):
        """Return the forward level (0, 1 or 2) of one node feature row."""
        # NOTE(review): the meaning of the individual feature columns is not
        # visible here; only their use as level selectors is.
        if x_info[0] == 1 or x_info[1] == 1:
            return 0
        if x_info[2] == 1:
            return 1
        return 2

    def process(self):
        '''
        The cnf dataset generation process follows https://github.com/ryanzhangfan/NeuroSAT/blob/master/src/data_maker.py
        Here we ignore the constraint of `max_nodes_per_batch`.

        Reads every instance from the raw archive, attaches forward/backward
        node orderings and levels, and saves the collated result to
        ``processed_paths[0]``.
        '''
        dataset = read_npz_file(self.raw_paths[0])['dataset'].item()
        data_list = []

        for data_item in dataset.values():
            num_nodes = len(data_item['x'])

            # Forward order is the stored node order; backward is its reverse.
            forward_index = list(range(num_nodes))
            backward_index = list(range(num_nodes - 1, -1, -1))

            # Levels are mirrored: forward level k <-> backward level 2 - k.
            forward_level = [self._forward_level(x_info) for x_info in data_item['x']]
            backward_level = [2 - level for level in forward_level]

            graph = OrderedData(x=data_item['x'], edge_index=data_item['edge_index'],
                                forward_level=forward_level, backward_level=backward_level,
                                forward_index=forward_index, backward_index=backward_index)

            # Copy over any remaining per-instance fields without overwriting
            # the attributes set above. ``in graph`` is used instead of
            # ``graph.keys`` because ``keys`` is a property in older PyG
            # releases and a method in newer ones.
            for item_key, item_value in data_item.items():
                if item_key not in graph:
                    graph[item_key] = item_value

            data_list.append(graph)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self) -> str:
        # ``self.name`` is never defined on this dataset; use the class name.
        return f'{self.__class__.__name__}({len(self)})'
