import os
import os.path as osp
import pickle
import shutil
from typing import Callable, List, Optional
import mmcv
import torch
from tqdm import tqdm
import numpy as np
import scipy.io as scio
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)




class Tile(InMemoryDataset):
    """PyG ``InMemoryDataset`` built from per-graph ``.npz`` files of the tile XLA data.

    Raw files are read from ``<root>/tile_xla/raw/<split>``. The collated result
    is cached at ``<root>/tile_xla/processed<processed_suffix>/data_<split>.pt``.

    Args:
        url: Stored on the instance; not used by any method in this class.
        root: Root directory passed to ``InMemoryDataset``.
        processed_suffix: Appended to the processed directory name. This lets
            differently pre-transformed variants be cached side by side.
        split: Subdirectory of ``raw/`` to load (e.g. ``'train'``).
        transform: PyG per-access transform hook.
        pre_transform: PyG hook applied once to every graph before caching.
        pre_filter: PyG hook; graphs for which it returns False are dropped
            before caching.
    """

    def __init__(self, url=None, root='data', processed_suffix='', split='train',
                 transform=None, pre_transform=None, pre_filter=None):
        # These attributes are set before calling the base constructor.
        # The raw_dir / processed_dir properties depend on them, and the base
        # constructor may resolve those paths (download/process checks).
        self.url = url
        self.root = root
        self.transform = transform
        self.pre_filter = pre_filter
        self.pre_transform = pre_transform
        self.split = split
        self.raw = os.path.join(root, 'tile_xla')
        self.processed_suffix = processed_suffix
        super().__init__(root=root, transform=transform, pre_transform=pre_transform,
                         pre_filter=pre_filter)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may refuse to unpickle PyG Data objects.
        # Confirm against the installed torch/PyG versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        """Directory holding this split's raw ``.npz`` files."""
        return os.path.join(self.raw, 'raw', self.split)

    @property
    def processed_dir(self):
        """Directory holding the cached, collated dataset files."""
        return os.path.join(self.raw, 'processed' + self.processed_suffix)

    @property
    def raw_file_names(self):
        """All entries of ``raw_dir``, sorted so graph order is reproducible."""
        # os.listdir order is arbitrary and filesystem-dependent; sorting keeps
        # dataset indices stable across machines and runs.
        return sorted(os.listdir(self.raw_dir))

    @property
    def processed_file_names(self):
        """Single cached file per split: ``data_<split>.pt``."""
        return ['data_' + self.split + '.pt']

    def np2pyg(self, raw_data):
        """Convert one raw graph record into a ``torch_geometric.data.Data``.

        Args:
            raw_data: Mapping with keys ``'edge_index'``, ``'node_feat'``,
                ``'node_opcode'``, ``'config_feat'``, ``'config_runtime'`` and
                ``'config_runtime_normalizers'``. The raw ``edge_index`` is
                presumably ``(num_edges, 2)`` (TODO confirm); it is transposed
                here to PyG's ``(2, num_edges)`` layout.

        Returns:
            A ``Data`` object with these attributes:
            ``y`` is ``config_runtime / config_runtime_normalizers`` as float32.
            ``config_feat_marker`` is ``config_feat.shape[0]``.
        """
        edge_index = np.asarray(raw_data['edge_index'])
        x = torch.tensor(raw_data['node_feat'])
        # Count nodes from the feature matrix so nodes without edges are kept.
        # Only take edge_index into account when it is non-empty: np.max on an
        # empty array raises, which would crash on edge-less graphs.
        num_nodes = int(x.shape[0])
        if edge_index.size > 0:
            num_nodes = max(num_nodes, int(edge_index.max()) + 1)
        opcode = torch.tensor(raw_data['node_opcode']).int()
        config_feat = torch.tensor(raw_data['config_feat'])
        config_feat_marker = config_feat.shape[0]  # a single number[c]
        y = raw_data['config_runtime'] / raw_data['config_runtime_normalizers']
        return Data(edge_index=torch.tensor(edge_index).transpose(0, 1), x=x, opcode=opcode,
                    num_nodes=num_nodes, config_feat=config_feat,
                    y=torch.tensor(y).float(), config_feat_marker=config_feat_marker)

    def process(self):
        """Convert every raw ``.npz`` file to ``Data``, apply hooks, and save the collation."""
        print('Processing data from ' + self.raw_dir + '...')
        data_list = []
        for count, file in enumerate(self.raw_paths, start=1):
            # allow_pickle=True permits arbitrary object deserialization;
            # only load raw files from a trusted source.
            # The context manager closes the NpzFile handle after reading.
            with np.load(file, allow_pickle=True) as npz:
                raw_data = dict(npz)
            data_list.append(self.np2pyg(raw_data))
            if count % 500 == 0:
                print('Loading raw data: #%d' % count)
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            print('pre-transforming for data at ' + self.processed_paths[0])
            transformed = []
            for i, data in enumerate(data_list):
                # Progress report only. Every graph must still be transformed,
                # so there is no early exit here.
                if i != 0 and i % 5000 == 0:
                    print('Pre-processing %d/%d' % (i, len(data_list)))
                transformed.append(self.pre_transform(data))
            data_list = transformed
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
