import os
import dgl
import numpy as np
import torch
from scipy import ndimage
import cv2
import tqdm
from torch.utils.data import DataLoader, Dataset

def resize(input):
    """Resample a 2-D array to 256x256 using cubic spline interpolation.

    The zoom factor for each of the first two axes is chosen so that the
    axis becomes 256 samples long; ``scipy.ndimage.zoom`` is called with
    ``order=3``.
    """
    rows, cols = input.shape[0], input.shape[1]
    factors = (256 / rows, 256 / cols)
    return ndimage.zoom(input, factors, order=3)

def std(input):
    """Min-max normalise *input* into the range [0, 1].

    Despite the name, this is not a standard-deviation normalisation: values
    are shifted by the minimum and divided by the (max - min) range.

    Degenerate inputs are returned without dividing:
      * if the maximum is 0, *input* is returned unchanged;
      * if every element is equal (zero range), an all-zero array of the same
        shape is returned instead of the NaNs a 0/0 division would produce.
    """
    lo = input.min()
    hi = input.max()
    if hi == 0:
        return input
    value_range = hi - lo
    if value_range == 0:
        # Constant, non-zero map: (x - min) is 0 everywhere; avoid 0/0 -> NaN.
        return input - lo
    return (input - lo) / value_range

def resize_cv2(input):
    """Resize an image-like array to 256x256 with OpenCV area interpolation."""
    target_size = (256, 256)
    return cv2.resize(input, target_size, interpolation=cv2.INTER_AREA)
class ReadFeaturesOutput(Dataset):
    def __init__(self, data_root, graph_list,gcell_list, namelist):
        self.use_tqdm = 1
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_root = data_root
        self.graphs = graph_list
        self.topo_geom = "both"
        self.gcellsize = gcell_list
        self.pos_flatten = []

        #self.save_file_path = os.path.join(self.data_root, self.save_root, self.name.split('.def')[0] + '.pickle')
        # tasks
        self.tasks = ['congestion', 'DRC', 'IR_drop']

        # congestion task
        self.congestion_root = 'routability_features'
        self.congestion_feature_list = [os.path.join(self.data_root, self.congestion_root, 'macro_region'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY_pin')]
        self.congestion_label_list = [os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow'),
                                      os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow')]
        self.name_list = namelist
        


        # DRC task
        feature_list = ['routability_features/macro_region',
                        'routability_features/cell_density',
                        'routability_features/RUDY/RUDY_long',
                        'routability_features/RUDY/RUDY_short',
                        'routability_features/RUDY/RUDY_pin_long',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_horizontal_overflow',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_vertical_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow']
        label_list = ['routability_features/DRC/DRC_all']
        self.drc_feature_list = [os.path.join(self.data_root, fea) for fea in feature_list]
        self.drc_label_list = [os.path.join(self.data_root, fea) for fea in label_list]
        # IR task
        feature_list = ['IR_drop_features_decompressed/power_i', 'IR_drop_features_decompressed/power_s',
                        'IR_drop_features_decompressed/power_sca', 'IR_drop_features_decompressed/power_all','IR_drop_features_decompressed/power_t']
        
        label_list = ['IR_drop_features_decompressed/IR_drop']
        self.ir_feature_list = [os.path.join(self.data_root, fea) for fea in feature_list]
        self.ir_label_list = [os.path.join(self.data_root, fea) for fea in label_list]
        self.generate_features()
        self.combine_features()


    def combine_features(self):
        self.graphsize = self.gcellsize
        self.node_feature_num = self.graphs[0][0].nodes['node'].data['hv'].shape[0]
        self.map_feature_num = self.all_features.shape[1]
        self.nodemap_feature_num = self.node_feature_num+self.map_feature_num
        self.label_num = self.all_labels.shape[1]
        self.feature_size = self.all_features.shape[-1]
        loop = range(len(self.graphs))
        loop = tqdm.tqdm(loop, total=len(loop))
        print("combine_features")
        for i in loop:
            list_hetero_graph = self.graphs[i]
            graphsize = self.graphsize[i][0]
            feature = self.all_features[i]
            label = self.all_labels[i]
            pos_flatten = []
            for hetero_graph in list_hetero_graph:
                cellxy = hetero_graph.nodes['node'].data['hv'][:, -2:]
                cellxy = (cellxy/graphsize*self.feature_size).floor().to(torch.int)
                pos_flatten.append(cellxy)
                feature_from_features = []
                label_from_labels = []
                for x,y in cellxy:
                    feature_from_features.append(feature[:,x,y])
                    label_from_labels.append(label[:,x,y])
                feature_from_features = torch.stack(feature_from_features)
                label_from_labels = torch.stack(label_from_labels)
                hetero_graph.nodes['node'].data['hv'] = torch.cat([hetero_graph.nodes['node'].data['hv'],feature_from_features],dim=1)
                hetero_graph.nodes['node'].data['label'] = label_from_labels
            self.pos_flatten.append(pos_flatten)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.all_features = self.all_features.float().to(device)
        self.all_labels = self.all_labels.float().to(device)
        for gs in self.graphs:
            for gi in range(len(gs)):
                gs[gi].nodes['node'].data['hv'] = gs[gi].nodes['node'].data['hv'].float()
                gs[gi].nodes['net'].data['hv'] = gs[gi].nodes['net'].data['hv'].float()
                gs[gi].edges['pinned'].data['he'] = gs[gi].edges['pinned'].data['he'].float()
                gs[gi].edges['near'].data['he'] =  gs[gi].edges['near'].data['he'].float()
                gs[gi].nodes['node'].data['label'] = gs[gi].nodes['node'].data['label'].float()

                #gs[gi] = gs[gi].to(self.device)

    def fit_topo_geom(self):
        if self.topo_geom == 'topo':
            ltg = [(g, dgl.remove_edges(hg, hg.edges('eid', etype='near'), etype='near')) for g, hg in ltg]
        elif self.topo_geom == 'geom':
            ltg = [(g, dgl.remove_edges(hg, hg.edges('eid', etype='pinned'), etype='pinned')) for g, hg in ltg]
        ltg = [(g, dgl.add_self_loop(hg, etype='near')) for g, hg in ltg]
        return ltg

    def pack_data(self, task):
        features = []
        labels = []
        name_list = tqdm.tqdm(self.name_list, total=len(self.name_list))
        if task == 'congestion':
            self.graphsize = []
            for name in name_list:
                #name = os.path.basename(name)
                out_feature_list = []
                out_label_list = []
                for feature_name in self.congestion_feature_list:
                    feature = np.load(os.path.join(feature_name, name))
                    self.graphsize.append(feature.shape[-2:])
                    feature = torch.tensor(std(resize(feature)))
                    out_feature_list.append(feature)

                for label_name in self.congestion_label_list:
                    label = np.load(os.path.join(label_name, name))
                    label = torch.tensor((resize(label)))
                    out_label_list.append(label)
                features.append(torch.stack(out_feature_list))
                labels.append(torch.stack(out_label_list))

            fs = torch.stack(features)
            ls = torch.stack(labels)
            if len(ls.shape) == 4:
                ls = torch.sum(ls, dim=1, keepdim=True)
            else:
                ls = torch.unsqueeze(ls, 1)
            self.graphsize = self.graphsize[:fs.shape[0]]
        elif task == 'DRC':
            for name in name_list:
                #name = os.path.basename(name)
                out_feature_list = []
                out_label_list = []
                for feature_name in self.drc_feature_list:
                    feature = np.load(os.path.join(feature_name, name))
                    feature = torch.tensor(std(resize(feature)))
                    out_feature_list.append(feature)

                for label_name in self.drc_label_list:
                    label = np.load(os.path.join(label_name, name))
                    label = np.clip(label, 0, 200)
                    label = torch.tensor(resize_cv2(label) / 200)
                    out_label_list.append(label)
                features.append(torch.stack(out_feature_list))
                labels.append(torch.stack(out_label_list))
            fs = torch.stack(features)
            ls = torch.stack(labels)
        elif task == 'IR_drop':
            for name in name_list:
                #name = os.path.basename(name)
                out_feature_list = []
                out_label_list = []
                for feature_name in self.ir_feature_list:
                    feature = np.load(os.path.join(feature_name, name))
                    if feature_name.endswith('power_t'):
                        for i in range(20):
                            slice = feature[i, :, :]
                            out_feature_list.append(torch.tensor(std(resize_cv2(slice))))
                    else:
                        feature = torch.tensor(std(resize_cv2(feature.squeeze())))
                        out_feature_list.append(feature)
                for label_name in self.ir_label_list:
                    label = np.load(os.path.join(label_name, name))
                    label = np.squeeze(label)
                    label = np.clip(label, 1e-6, 50)
                    label = torch.tensor((np.log10(resize_cv2(label)) + 6) / (np.log10(50) + 6))
                    out_label_list.append(label)
                features.append(torch.stack(out_feature_list))
                labels.append(torch.stack(out_label_list))
            fs = torch.stack(features)
            ls = torch.stack(labels)
        else:
            fs, ls = [], []
        return fs, ls


    def generate_features(self):
        self.feature_length = []
        self.label_length = []
        self.all_features = []
        self.all_labels = []
        if 'congestion' in self.tasks:
            congestion_features_map, congestion_labels_map = self.pack_data('congestion')
            self.feature_length.append(congestion_features_map.shape[1])
            self.label_length.append(congestion_labels_map.shape[1])
            self.all_features.append(congestion_features_map)
            self.all_labels.append(congestion_labels_map)
        if 'DRC' in self.tasks:
            drc_features_map, drc_labels_map = self.pack_data('DRC')
            self.feature_length.append(drc_features_map.shape[1])
            self.label_length.append(drc_labels_map.shape[1])
            self.all_features.append(drc_features_map)
            self.all_labels.append(drc_labels_map)
        if 'IR_drop' in self.tasks:
            ir_features_map, ir_labels_map = self.pack_data('IR_drop')
            self.feature_length.append(ir_features_map.shape[1])
            self.label_length.append(ir_labels_map.shape[1])
            self.all_features.append(ir_features_map)
            self.all_labels.append(ir_labels_map)
        self.all_features = torch.cat(self.all_features,1)
        self.all_labels = torch.cat(self.all_labels,1)




    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        #graphs = self.graphs[idx]
        feature_map = self.all_features[idx]
        labels = self.all_labels[idx].flatten(1).transpose(0, 1)
        pos = self.pos_flatten[idx]
        return feature_map, labels, pos


