from tqdm import tqdm
import os
import random
import pickle
import numpy as np
from easydict import EasyDict
import os.path as osp
import pandas as pd
import re
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import remove_self_loops, to_undirected
from torch_geometric.transforms import Compose
import torch_geometric.transforms as T
from utils.get_mag_lap import AddMagLaplacianEigenvectorPE, AddLaplacianEigenvectorPE

from utils.misc import create_nested_folder


# Normalisation statistics per regression target, keyed by config['task']['target'].
# HLSDataProcessor feeds them to get_target_transform to standardise data.y.
# NOTE(review): presumably computed from the training labels — confirm.
MEAN = dict(dsp=11.1125, cp=7.7051)
STD = dict(dsp=11.3296, cp=2.1629)

def add_node_attr(data, value, attr_name):
    """Attach ``value`` to ``data`` as a node attribute and return ``data``.

    If ``attr_name`` is given, ``value`` is stored under that key.
    Otherwise it is appended column-wise to ``data.x``: a 1-D ``x`` is first
    viewed as a column, and ``value`` is cast to the device/dtype of ``x``.
    When ``data`` has no ``x`` yet, ``value`` becomes ``x``.
    """
    if attr_name is not None:
        data[attr_name] = value
        return data

    if 'x' not in data:
        data.x = value
        return data

    existing = data.x
    if existing.dim() == 1:
        existing = existing.view(-1, 1)
    data.x = torch.cat([existing, value.to(existing.device, existing.dtype)], dim=-1)
    return data


def get_target_transform(target_name, mean, std):
    """Build a transform that sets ``data.y`` to the standardised target.

    The returned callable reads ``data[target_name]``, stores
    ``(value - mean) / std`` in ``data.y`` and returns the same ``data``.
    The raw target attribute is left untouched.
    """
    def my_transform(data):
        raw_target = data[target_name]
        data.y = (raw_target - mean) / std
        return data

    return my_transform

def symmetrize_transform(data):
    """Add a reversed copy of every edge to ``data`` in place and return it.

    Rows 1 and 0 of ``edge_index`` are swapped to form the reversed edges,
    which are appended after the originals. ``edge_attr`` is duplicated
    along dim 0 so that each reversed edge keeps its forward edge's features.
    No deduplication is done.
    """
    forward = data.edge_index
    backward = torch.stack((forward[1], forward[0]), dim=0)
    data.edge_index = torch.cat([forward, backward], dim=-1)

    attrs = data.edge_attr
    data.edge_attr = torch.cat([attrs, attrs], dim=0)
    return data


class HLSDataProcessor(InMemoryDataset):
    """In-memory PyG dataset of HLS graphs with cached, seed-specific splits.

    On first use, ``process`` does the following:

    * reads the raw CSV graph dumps;
    * builds one ``Data`` object per graph;
    * optionally applies a positional-encoding ``pre_transform``;
    * saves one collated file per split under ``processed_dir``.

    The instance then loads the split selected by ``mode``.

    Args:
        config: nested config mapping. The following entries are read:

            * ``task``: processed_folder, name, type, divide_seed,
              raw_data_path, target;
            * ``model.args``: pe_type and, depending on it, pe_dim_input,
              q, q_dim;
            * ``train.directed``.

        mode: split to load; one of the keys of ``processed_file_names``.

    Raises:
        ValueError: if ``model.args.pe_type`` is set to something other
            than 'lap' or 'maglap'.
    """

    # Split layout. NOTE(review): these counts assume the raw dumps list the
    # cdfg graphs as 18570 synthetic ones followed by the real ones, and that
    # the dfg dump holds 19119 graphs — confirm against the data.
    _NUM_SYNTH_GRAPHS = 18570
    _NUM_TRAIN = 16570
    _NUM_VALID = 1000
    _REAL_GRAPH_RANGE = (18570, 18626)
    _NUM_OTHER_TYPE_GRAPHS = 19119
    _NUM_OTHER_TYPE_SAMPLES = 1000

    def __init__(self, config, mode):
        self.config = config
        task_cfg = config['task']
        processed_folder = task_cfg['processed_folder']
        self.save_folder = osp.join(processed_folder, str(task_cfg['name']) + '_' + str(task_cfg['type']))
        create_nested_folder(self.save_folder)
        self.divide_seed = task_cfg['divide_seed']
        self.mode = mode
        self.raw_data_root = task_cfg['raw_data_path']
        # Positional encodings are computed once at processing time.
        pre_transform = self._build_pe_transform(config['model']['args'])
        target = task_cfg['target']
        transform = get_target_transform(target, MEAN[target], STD[target])
        if not config['train']['directed']:
            transform = T.Compose([transform, symmetrize_transform])
        # Everything processed_dir / processed_file_names read must be set
        # before super().__init__, because that call triggers processing.
        super().__init__(root=self.save_folder, pre_transform=pre_transform,
                         transform=transform)
        self.data, self.slices = torch.load(self.processed_paths[mode])

    @staticmethod
    def _build_pe_transform(model_args):
        """Return the PE pre-transform selected by ``model_args['pe_type']``.

        Returns None when no PE is configured.
        """
        pe_type = model_args.get('pe_type')
        if pe_type is None:
            return None
        pe_dim = model_args['pe_dim_input']
        if pe_type == 'lap':
            return AddLaplacianEigenvectorPE(k=pe_dim, attr_name='lap_pe')
        if pe_type == 'maglap':
            return AddMagLaplacianEigenvectorPE(k=pe_dim, q=model_args['q'],
                                                multiple_q=model_args['q_dim'], attr_name='maglap_pe')
        raise ValueError(f"Unsupported pe_type {pe_type!r}; expected 'lap', 'maglap' or None")

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        """Map each split name to its cache file name (suffixed with the split seed)."""
        return {
            'train': 'train_'+str(self.divide_seed)+'.pt',
            'valid': 'val_'+str(self.divide_seed)+'.pt',
            'test': 'test_'+str(self.divide_seed)+'.pt',
            'test_real': 'test_real_'+str(self.divide_seed)+'.pt',
            'test_othertype': 'test_other_type_'+str(self.divide_seed)+'.pt'
        }

    @property
    def processed_dir(self) -> str:
        """Cache directory; its name encodes the PE settings so configs don't collide."""
        model_args = self.config['model']['args']
        pe_type = model_args.get('pe_type')
        processed_dir = osp.join(self.save_folder, 'processed')
        if pe_type is not None:
            processed_dir += '_' + pe_type + str(model_args['pe_dim_input'])
        if pe_type == 'maglap':
            processed_dir += '_' + str(model_args['q_dim']) + 'q' + str(model_args['q'])
        return processed_dir

    @property
    def processed_paths(self):
        """Map each split name to the full path of its cache file."""
        # NOTE: the base class checks this value with ``files_exist``, which
        # iterates the dict *keys* (split names, not paths). That check
        # therefore never passes and ``process`` is always called, so
        # ``process`` does its own cache check.
        return {mode: osp.join(self.processed_dir, fname) for mode, fname in self.processed_file_names.items()}

    def process(self):
        """Build and save every split, unless all split files already exist.

        If any split file is missing, all splits are regenerated together.
        The RNG is seeded with ``divide_seed``, so the split matches the
        seed encoded in the file names.
        """
        paths = self.processed_paths
        if all(osp.isfile(path) for path in paths.values()):
            print('all datasets already exists, directly load.')
            return

        split_indices = self._split_indices()
        # Plain string concatenation: raw_data_path is expected to end with a separator.
        raw_root = self.config['task']['raw_data_path']
        cdfg_graph_list = self.read_csv_graph_raw(raw_root + self.config['task']['type'] + '_cp_all/')
        dfg_graph_list = self.read_csv_graph_raw(raw_root + 'dfg' + '_cp/')

        for key in self.processed_file_names:
            print('pre-transforming '+key+' dataset...')
            # Only the "other type" split is drawn from the dfg graphs.
            graph_list = dfg_graph_list if key == 'test_othertype' else cdfg_graph_list
            indices = split_indices[key]
            data_list = []
            for i, graph_idx in enumerate(indices):
                if i % 1000 == 0 and i != 0:
                    print('pre-transforming ' + key + ' dataset: %d/%d' % (i, len(indices)))
                data = self._graph_to_data(graph_list[graph_idx])
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
            data, slices = self.collate(data_list)
            torch.save((data, slices), paths[key])

    def _split_indices(self):
        """Return a mapping from split name to the graph indices in that split."""
        rng = random.Random(self.divide_seed)
        synth = list(range(self._NUM_SYNTH_GRAPHS))
        rng.shuffle(synth)
        train_end = self._NUM_TRAIN
        valid_end = train_end + self._NUM_VALID
        return {
            'train': synth[:train_end],
            'valid': synth[train_end:valid_end],
            'test': synth[valid_end:],
            'test_real': list(range(*self._REAL_GRAPH_RANGE)),
            'test_othertype': rng.sample(range(self._NUM_OTHER_TYPE_GRAPHS), self._NUM_OTHER_TYPE_SAMPLES),
        }

    @staticmethod
    def _graph_to_data(graph):
        """Convert one graph dict produced by ``read_csv_graph_raw`` into a ``Data``."""
        def as_float(key):
            return torch.tensor(graph[key]).to(dtype=torch.float32)

        return Data(
            x=torch.tensor(graph['node_feat']).long(),
            edge_index=torch.tensor(graph['edge_index']),
            edge_attr=torch.tensor(graph['edge_feat']).long(),
            dsp=as_float('dsp'), cp=as_float('cp'), lut=as_float('lut'),
            ff=as_float('ff'), slice=as_float('slice'),
        )

    @staticmethod
    def _parse_number(value):
        """Return the first number in a label string, or None if there is none."""
        if not isinstance(value, str):
            # Missing cells in a string column come through as NaN floats.
            return None if pd.isna(value) else float(value)
        match = re.search(r'\d+\.?\d*', value)
        return float(match.group()) if match else None

    @staticmethod
    def _read_optional_feat(path):
        """Read a headerless feature CSV as int64 or float32; return None if it is missing."""
        try:
            feat = pd.read_csv(path, header=None).values
        except FileNotFoundError:
            return None
        return feat.astype(np.int64) if 'int' in str(feat.dtype) else feat.astype(np.float32)

    def read_csv_graph_raw(self, raw_dir):
        """Read a CSV graph dump (OGB raw layout) plus its label mapping.

        Expected files:

        * ``raw_dir + 'raw'``: edge.csv, num-node-list.csv and
          num-edge-list.csv, plus optionally node-feat.csv and edge-feat.csv;
        * ``raw_dir + 'mapping'``: mapping.csv with DSP, LUT, CP, FF and
          SLICE columns.

        ``raw_dir`` is joined by string concatenation, so it should end with
        a path separator.

        Returns:
            One dict per graph, with keys edge_index, edge_feat, node_feat,
            dsp, lut, cp, ff, slice and num_nodes. The feature entries are
            None when the corresponding file is missing.

        Raises:
            RuntimeError: if edge.csv or one of the count files is missing.
        """
        label_dir = raw_dir + 'mapping'
        raw_dir = raw_dir + 'raw'
        labels = pd.read_csv(osp.join(label_dir, 'mapping.csv'))
        if isinstance(labels['DSP'][0], str):
            # Labels stored as text: keep the first numeric substring of each entry.
            labels['DSP'] = labels['DSP'].apply(self._parse_number)
            labels['LUT'] = labels['LUT'].apply(self._parse_number).round(3)
            labels['CP'] = labels['CP'].apply(self._parse_number).round(3)
            labels['FF'] = labels['FF'].apply(self._parse_number).round(3)
            labels['SLICE'] = labels['SLICE'].apply(self._parse_number)
        try:
            # (2, num_edge) array
            edge = pd.read_csv(osp.join(raw_dir, 'edge.csv'), header=None).values.T.astype(np.int64)
            # (num_graph,) python lists of per-graph node / edge counts
            num_node_list = pd.read_csv(osp.join(raw_dir, 'num-node-list.csv'), header=None).astype(np.int64)[0].tolist()
            num_edge_list = pd.read_csv(osp.join(raw_dir, 'num-edge-list.csv'), header=None).astype(np.int64)[0].tolist()
        except FileNotFoundError as err:
            raise RuntimeError('No such file') from err

        node_feat = self._read_optional_feat(osp.join(raw_dir, 'node-feat.csv'))
        # NOTE(review): original notes recorded node-feature min [0]*7 and
        # max [3 256 7 56 2 2 257] on the authors' data.
        if node_feat is not None:
            print('node feature min'+str(node_feat.min(axis=0)))
            print('node feat max:'+str(node_feat.max(axis=0)))

        edge_feat = self._read_optional_feat(osp.join(raw_dir, 'edge-feat.csv'))
        if edge_feat is not None:
            print('edge feat min'+str(edge_feat.min(axis=0)))
            print('edge feat max:'+str(edge_feat.max(axis=0)))

        graph_list = []
        num_node_accum = 0
        num_edge_accum = 0
        print('Processing graphs...')
        num_graphs = min(len(num_node_list), len(num_edge_list))
        for graph_id, (num_node, num_edge) in tqdm(enumerate(zip(num_node_list, num_edge_list)), total=num_graphs):
            # Graphs are stored back to back; slice out this graph's edges and nodes.
            edge_end = num_edge_accum + num_edge
            node_end = num_node_accum + num_node
            graph = {
                'edge_index': edge[:, num_edge_accum:edge_end],
                'edge_feat': None if edge_feat is None else edge_feat[num_edge_accum:edge_end],
                'node_feat': None if node_feat is None else node_feat[num_node_accum:node_end],
                'dsp': labels['DSP'][graph_id],
                'lut': labels['LUT'][graph_id],
                'cp': labels['CP'][graph_id],
                'ff': labels['FF'][graph_id],
                'slice': labels['SLICE'][graph_id],
                'num_nodes': num_node,
            }
            num_edge_accum, num_node_accum = edge_end, node_end
            graph_list.append(graph)
        return graph_list
    

    