import os
import dgl
import numpy as np
import torch
from scipy import ndimage
import cv2
import tqdm
from torch.utils.data import DataLoader, Dataset


def resize(input):
    """Resample a 2-D array to 256x256 using cubic-spline ``ndimage.zoom``.

    The zoom factors are computed from the first two axes only, so the input
    is expected to be 2-D (``ndimage.zoom`` requires one factor per axis).
    """
    rows, cols = input.shape[0], input.shape[1]
    factors = (256 / rows, 256 / cols)
    return ndimage.zoom(input, factors, order=3)

def std(input):
    """Min-max normalize ``input`` into the range [0, 1].

    Despite the name, this is not a standard-deviation scaling: it computes
    ``(x - min) / (max - min)``.

    Args:
        input: Array-like supporting ``min()``, ``max()`` and broadcasting
            arithmetic (e.g. a NumPy array).

    Returns:
        The normalized array. When every element is equal (zero range), an
        all-zero array of the same shape is returned instead of dividing by
        zero, which would otherwise yield NaNs.
    """
    lo = input.min()
    span = input.max() - lo
    shifted = input - lo
    # Guard on the range rather than on max()==0: a constant non-zero map
    # would divide by zero, and a map with negative values but max 0 would
    # otherwise be returned un-normalized.
    if span == 0:
        return shifted
    return shifted / span

def resize_cv2(input):
    """Resample ``input`` to 256x256 pixels with OpenCV area interpolation."""
    target_size = (256, 256)  # cv2 expects (width, height); square, so order is moot
    return cv2.resize(input, target_size, interpolation=cv2.INTER_AREA)

def get_vit_dataset(data_root, list_path, istrain):
    """Build a ``Vit_dataset`` from a sample-name list file under ``data_root``.

    Args:
        data_root: Dataset root directory; the list file is resolved relative
            to it and it is also passed to ``Vit_dataset``.
        list_path: Two-element sequence of list-file paths relative to
            ``data_root``; index 0 is used when ``istrain`` is truthy,
            index 1 otherwise.
        istrain: Selects which entry of ``list_path`` to read.

    Returns:
        A ``Vit_dataset`` over the names in the selected file, one per
        non-blank line, with the trailing newline removed.
    """
    chosen = list_path[0] if istrain else list_path[1]
    read_list = []
    # ``with`` guarantees the list file is closed even if reading fails.
    with open(os.path.join(data_root, chosen), 'r') as list_file:
        for line in list_file:
            name = line.rstrip('\n')
            # Skip blank lines (e.g. a trailing empty line); an empty name
            # would later resolve to a directory path in np.load.
            if name:
                read_list.append(name)
    return Vit_dataset(data_root, read_list)

class Vit_dataset(Dataset):
    """In-memory multi-task dataset covering congestion, DRC and IR-drop maps.

    On construction, every sample in ``namelist`` is loaded with ``np.load``
    from each task's feature and label directories under ``data_root``. Each
    map is resampled to 256x256 and normalized, and the results are
    concatenated channel-wise across tasks into ``all_features`` and
    ``all_labels``.
    """

    # Number of leading time slices (axis 0) of each power_t array that are
    # used as separate IR-drop feature channels.
    IR_POWER_T_SLICES = 20

    def __init__(self, data_root, namelist):
        """Resolve every feature/label directory and eagerly load all samples.

        Args:
            data_root: Dataset root; all feature and label directories are
                resolved relative to it.
            namelist: Sample file names to load from each directory.
        """
        self.use_tqdm = 1  # when truthy, show a tqdm progress bar per task
        self.device = "cuda" if torch.cuda.is_available() else "cpu"  # NOTE(review): not used by this class
        self.data_root = data_root
        self.pos_flatten = []  # NOTE(review): not used by this class
        self.tasks = ['congestion', 'DRC', 'IR_drop']
        self.name_list = namelist

        # Congestion task: 3 feature maps, horizontal + vertical GR overflow labels.
        self.congestion_root = 'routability_features'
        self.congestion_feature_list = [os.path.join(self.data_root, self.congestion_root, 'macro_region'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY_pin')]
        self.congestion_label_list = [os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow'),
                                      os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow')]

        # DRC task.
        drc_features = ['routability_features/macro_region',
                        'routability_features/cell_density',
                        'routability_features/RUDY/RUDY_long',
                        'routability_features/RUDY/RUDY_short',
                        'routability_features/RUDY/RUDY_pin_long',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_horizontal_overflow',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_vertical_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow']
        drc_labels = ['routability_features/DRC/DRC_all']
        self.drc_feature_list = [os.path.join(self.data_root, fea) for fea in drc_features]
        self.drc_label_list = [os.path.join(self.data_root, fea) for fea in drc_labels]

        # IR-drop task.
        ir_features = ['IR_drop_features_decompressed/power_i', 'IR_drop_features_decompressed/power_s',
                       'IR_drop_features_decompressed/power_sca', 'IR_drop_features_decompressed/power_all',
                       'IR_drop_features_decompressed/power_t']
        ir_labels = ['IR_drop_features_decompressed/IR_drop']
        self.ir_feature_list = [os.path.join(self.data_root, fea) for fea in ir_features]
        self.ir_label_list = [os.path.join(self.data_root, fea) for fea in ir_labels]

        self.generate_features()

    def _progress(self):
        """Return ``self.name_list``, wrapped in a tqdm progress bar when ``use_tqdm`` is set."""
        if self.use_tqdm:
            return tqdm.tqdm(self.name_list, total=len(self.name_list))
        return self.name_list

    @staticmethod
    def _load(root, name):
        """Load sample ``name`` from directory ``root`` with ``np.load``."""
        return np.load(os.path.join(root, name))

    def _pack_congestion(self):
        """Pack congestion features and labels, and record each sample's native map size."""
        features, labels = [], []
        self.graphsize = []
        for name in self._progress():
            maps = [self._load(root, name) for root in self.congestion_feature_list]
            # Exactly one entry per sample, so graphsize[i] describes sample i.
            self.graphsize.append(maps[0].shape[-2:])
            features.append(torch.stack([torch.tensor(std(resize(m))) for m in maps]))
            # Labels are resampled but, unlike the features, not normalized.
            labels.append(torch.stack([torch.tensor(resize(self._load(root, name)))
                                       for root in self.congestion_label_list]))
        fs = torch.stack(features)
        ls = torch.stack(labels)
        if ls.dim() == 4:
            # Merge the horizontal and vertical overflow labels into one channel.
            ls = torch.sum(ls, dim=1, keepdim=True)
        else:
            ls = torch.unsqueeze(ls, 1)
        return fs, ls

    def _pack_drc(self):
        """Pack DRC features and DRC labels clipped to [0, 200] and divided by 200."""
        features, labels = [], []
        for name in self._progress():
            features.append(torch.stack([torch.tensor(std(resize(self._load(root, name))))
                                         for root in self.drc_feature_list]))
            sample_labels = []
            for root in self.drc_label_list:
                label = np.clip(self._load(root, name), 0, 200)
                sample_labels.append(torch.tensor(resize_cv2(label) / 200))
            labels.append(torch.stack(sample_labels))
        return torch.stack(features), torch.stack(labels)

    def _pack_ir_drop(self):
        """Pack IR-drop features (power_t expanded into time slices) and log-scaled labels."""
        log_scale = np.log10(50) + 6  # width of the log10 range [-6, log10(50)]
        features, labels = [], []
        for name in self._progress():
            channels = []
            for root in self.ir_feature_list:
                feature = self._load(root, name)
                if root.endswith('power_t'):
                    # Each leading slice along axis 0 becomes its own channel.
                    for t in range(self.IR_POWER_T_SLICES):
                        channels.append(torch.tensor(std(resize_cv2(feature[t, :, :]))))
                else:
                    channels.append(torch.tensor(std(resize_cv2(feature.squeeze()))))
            sample_labels = []
            for root in self.ir_label_list:
                label = np.clip(np.squeeze(self._load(root, name)), 1e-6, 50)
                # Clipping keeps log10 finite; shift/scale the log range to [0, 1].
                sample_labels.append(torch.tensor((np.log10(resize_cv2(label)) + 6) / log_scale))
            features.append(torch.stack(channels))
            labels.append(torch.stack(sample_labels))
        return torch.stack(features), torch.stack(labels)

    def pack_data(self, task):
        """Load, resample and normalize every sample for one task.

        Args:
            task: ``'congestion'``, ``'DRC'`` or ``'IR_drop'``.

        Returns:
            ``(features, labels)`` tensors stacked over samples, with one
            256x256 map per channel. For an unrecognized ``task``, returns
            ``([], [])``.
        """
        if task == 'congestion':
            return self._pack_congestion()
        if task == 'DRC':
            return self._pack_drc()
        if task == 'IR_drop':
            return self._pack_ir_drop()
        return [], []

    def generate_features(self):
        """Pack every enabled task and concatenate all tasks along the channel axis.

        Populates ``feature_length`` / ``label_length`` (per-task channel
        counts, in task order) and ``all_features`` / ``all_labels``
        (concatenated along dim 1 and cast to float32).
        """
        self.feature_length = []
        self.label_length = []
        per_task_features = []
        per_task_labels = []
        # Fixed task order so channel offsets do not depend on the order of self.tasks.
        for task in ('congestion', 'DRC', 'IR_drop'):
            if task not in self.tasks:
                continue
            fs, ls = self.pack_data(task)
            self.feature_length.append(fs.shape[1])
            self.label_length.append(ls.shape[1])
            per_task_features.append(fs)
            per_task_labels.append(ls)
        self.all_features = torch.cat(per_task_features, 1).float()
        self.all_labels = torch.cat(per_task_labels, 1).float()

    def __len__(self):
        """Return the number of samples (entries in ``name_list``)."""
        return len(self.name_list)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` for sample ``idx``.

        ``features`` is ``all_features[idx]`` unchanged. ``labels`` is
        ``all_labels[idx]`` flattened from ``(C_label, H, W)`` to
        ``(C_label, H * W)`` and transposed to ``(H * W, C_label)``, so each
        row holds one pixel's labels across all tasks.
        """
        feature_map = self.all_features[idx]
        labels = self.all_labels[idx].flatten(1).transpose(0, 1)
        return feature_map, labels