import os
import dgl
import numpy as np
import torch
from scipy import ndimage
import cv2
import tqdm
from torch.utils.data import DataLoader, Dataset


def resize(input, size=256):
    """Resample a 2-D array to ``size`` x ``size`` with cubic spline interpolation.

    Args:
        input: 2-D array; its first two ``shape`` entries are used to derive
            the per-axis zoom factors.
        size: Target edge length in pixels. Defaults to 256, the value that
            used to be hard-coded.

    Returns:
        The array produced by ``scipy.ndimage.zoom`` with ``order=3``.
    """
    height, width = input.shape[0], input.shape[1]
    # One zoom factor per axis so that each axis lands on ``size`` samples.
    zoom_factors = (size / height, size / width)
    return ndimage.zoom(input, zoom_factors, order=3)

def std(input):
    """Min-max normalize ``input`` to the range [0, 1].

    Args:
        input: Array-like object exposing ``min()`` and ``max()`` and
            supporting element-wise arithmetic (e.g. a NumPy array).

    Returns:
        ``input`` itself when its maximum is 0 (this early return is kept
        from the original implementation); otherwise
        ``(input - min) / (max - min)``. A constant, non-zero input maps to
        all zeros instead of NaN.
    """
    lowest = input.min()
    highest = input.max()
    if highest == 0:
        return input
    span = highest - lowest
    if span == 0:
        # Constant input: ``input - lowest`` is already all zeros, so divide by
        # 1 rather than by the zero range, which would produce NaN.
        span = 1
    return (input - lowest) / span

def resize_cv2(input, size=256):
    """Resize an array to ``size`` x ``size`` with OpenCV area interpolation.

    Args:
        input: Array accepted by ``cv2.resize``.
        size: Target edge length in pixels. Defaults to 256, the value that
            used to be hard-coded.

    Returns:
        The array returned by ``cv2.resize`` with ``cv2.INTER_AREA``.
    """
    # cv2 expects the target as (width, height); the output is square here.
    target_size = (size, size)
    return cv2.resize(input, target_size, interpolation=cv2.INTER_AREA)

class CircuitNetDataset(Dataset):
    """In-memory dataset of CircuitNet feature and label maps.

    Every sample name listed in the ``namelist`` file is loaded eagerly in
    ``__init__``. Each map is read with ``np.load``, resized with the
    module-level ``resize`` / ``resize_cv2`` helpers and stacked as one
    channel, so ``self.features`` and ``self.labels`` are tensors whose first
    axis is the sample and whose second axis is the channel.

    Feature channels (each one is normalized with ``std`` after resizing):
        * ``"congestion"``: the 3 maps in ``self.congestion_feature_list``.
        * ``"DRC"``: the 9 maps in ``self.drc_feature_list``.
        * ``"IR_drop"``: one channel per non-``power_t`` map plus the first
          20 slices (axis 0) of ``power_t``.
        * ``"all"`` and ``"thermal"``: congestion, DRC and IR-drop channels
          concatenated in that order.

    Label channels:
        * ``"congestion"``: horizontal + vertical overflow summed into one
          channel.
        * ``"DRC"``: clipped to [0, 200] and divided by 200.
        * ``"IR_drop"``: clipped to [1e-6, 50], then
          ``(log10(x) + 6) / (log10(50) + 6)``.
        * ``"thermal"``: ``hotmap/<name>.npy``, resized only. Samples whose
          file is missing are dropped from the dataset.
        * ``"all"``: congestion, DRC and IR-drop labels (thermal is not
          included).

    Args:
        data_root: Root directory that contains the name list and the
            feature/label directories.
        namelist: Path of the name-list file, relative to ``data_root``;
            one sample file name per line.
        args: Object exposing ``feature`` and ``label`` attributes, each one
            of ``VALID_TASKS``.
    """

    # Task names accepted for both ``args.feature`` and ``args.label``.
    VALID_TASKS = ("congestion", "DRC", "IR_drop", "thermal", "all")
    # Number of leading ``power_t`` slices that become separate channels.
    POWER_T_SLICES = 20

    def __init__(self, data_root, namelist, args):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_root = data_root
        self.name_list = self._read_name_list(os.path.join(data_root, namelist))

        # congestion
        self.congestion_root = 'routability_features'
        self.congestion_feature_list = [os.path.join(self.data_root, self.congestion_root, 'macro_region'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY'),
                                        os.path.join(self.data_root, self.congestion_root, 'RUDY/RUDY_pin')]
        self.congestion_label_list = [os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow'),
                                      os.path.join(self.data_root, self.congestion_root,
                                                   'congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow')]
        # DRC task
        feature_list = ['routability_features/macro_region',
                        'routability_features/cell_density',
                        'routability_features/RUDY/RUDY_long',
                        'routability_features/RUDY/RUDY_short',
                        'routability_features/RUDY/RUDY_pin_long',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_horizontal_overflow',
                        'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_vertical_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
                        'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow']
        label_list = ['routability_features/DRC/DRC_all']
        self.drc_feature_list = [os.path.join(self.data_root, fea) for fea in feature_list]
        self.drc_label_list = [os.path.join(self.data_root, fea) for fea in label_list]
        # IR task
        feature_list = ['IR_drop_features_decompressed/power_i', 'IR_drop_features_decompressed/power_s',
                        'IR_drop_features_decompressed/power_sca', 'IR_drop_features_decompressed/power_all',
                        'IR_drop_features_decompressed/power_t']

        label_list = ['IR_drop_features_decompressed/IR_drop']
        self.ir_feature_list = [os.path.join(self.data_root, fea) for fea in feature_list]
        self.ir_label_list = [os.path.join(self.data_root, fea) for fea in label_list]
        # thermal task
        self.thermal_label_list = [os.path.join(self.data_root, "hotmap")]

        self.feature = args.feature
        assert self.feature in self.VALID_TASKS, f"unknown feature task: {self.feature!r}"
        self.label = args.label
        assert self.label in self.VALID_TASKS, f"unknown label task: {self.label!r}"

        # Features are loaded before labels: the thermal label loader relies on
        # this order to drop the features of samples whose label is missing.
        # The thermal task has no feature set of its own and uses all features.
        if self.feature == "all" or self.feature == "thermal":
            self.get_all_features()
        else:
            self.get_features_one_task()
        if self.label == "all":
            self.get_all_labels()
        else:
            self.get_labels_one_task()
        self.feature_dim = self.features.shape[1]
        self.label_dim = self.labels.shape[1]

    @staticmethod
    def _read_name_list(list_path):
        """Return the non-empty lines of ``list_path`` without line endings."""
        names = []
        # The context manager closes the file even if reading fails.
        with open(list_path, mode='r') as list_file:
            for line in list_file:
                name = line.rstrip('\r\n')
                # A blank line would otherwise become an empty sample name.
                if name:
                    names.append(name)
        return names

    def _load_map_features(self, feature_dirs, name):
        """Load, resize and normalize ``name`` from every directory in ``feature_dirs``."""
        return [torch.tensor(std(resize(np.load(os.path.join(feature_dir, name)))))
                for feature_dir in feature_dirs]

    def _load_ir_features(self, name):
        """Load the IR-drop feature channels of ``name``.

        ``power_t`` contributes its first ``POWER_T_SLICES`` slices along
        axis 0 as separate channels; every other array is squeezed and used
        as a single channel.
        """
        channels = []
        for feature_dir in self.ir_feature_list:
            feature = np.load(os.path.join(feature_dir, name))
            if feature_dir.endswith('power_t'):
                for step in range(self.POWER_T_SLICES):
                    time_slice = feature[step, :, :]
                    channels.append(torch.tensor(std(resize_cv2(time_slice))))
            else:
                channels.append(torch.tensor(std(resize_cv2(feature.squeeze()))))
        return channels

    def _load_congestion_label(self, name):
        """Return horizontal + vertical congestion of ``name`` summed into one map."""
        overflow_maps = [torch.tensor(resize(np.load(os.path.join(label_dir, name))))
                         for label_dir in self.congestion_label_list]
        return torch.stack(overflow_maps).sum(dim=0)

    def _load_drc_labels(self, name):
        """Load the DRC labels of ``name``, clipped to [0, 200] and scaled by 1/200."""
        channels = []
        for label_dir in self.drc_label_list:
            label = np.load(os.path.join(label_dir, name))
            label = np.clip(label, 0, 200)
            channels.append(torch.tensor(resize_cv2(label) / 200))
        return channels

    def _load_ir_labels(self, name):
        """Load the IR-drop labels of ``name`` on a normalized log10 scale."""
        channels = []
        for label_dir in self.ir_label_list:
            label = np.squeeze(np.load(os.path.join(label_dir, name)))
            # The lower clip bound keeps log10 finite; with the upper bound of
            # 50 the expression below maps the clipped range onto [0, 1].
            label = np.clip(label, 1e-6, 50)
            channels.append(torch.tensor((np.log10(resize_cv2(label)) + 6) / (np.log10(50) + 6)))
        return channels

    def _load_thermal_labels(self):
        """Load thermal labels and drop samples whose thermal file is missing.

        Returns:
            One stacked label tensor per sample that has a thermal file.
        """
        labels = []
        kept_indices = []
        for index, name in enumerate(self.name_list):
            channels = []
            for label_dir in self.thermal_label_list:
                # NOTE(review): unlike the other tasks, ".npy" is appended to
                # the sample name here - confirm against the name-list format.
                path = os.path.join(label_dir, name + ".npy")
                try:
                    label = np.load(path)
                except FileNotFoundError:
                    print(f"Warning: Thermal file not found: {path}, skipping...")
                    continue
                channels.append(torch.tensor(resize(label)))
            # Only keep the sample when at least one thermal file was found.
            if channels:
                labels.append(torch.stack(channels))
                kept_indices.append(index)
        # Skipping a label without also dropping its feature would pair every
        # later feature with the wrong label.
        if kept_indices and len(kept_indices) != len(self.name_list):
            self._keep_samples(kept_indices)
        return labels

    def _keep_samples(self, kept_indices):
        """Restrict ``name_list`` (and ``features`` if loaded) to ``kept_indices``."""
        self.name_list = [self.name_list[index] for index in kept_indices]
        features = getattr(self, "features", None)
        if features is not None:
            self.features = features[torch.as_tensor(kept_indices, dtype=torch.long)]

    def get_all_features(self):
        """Load congestion, DRC and IR-drop features into ``self.features``."""
        features = []
        for name in self.name_list:
            channels = self._load_map_features(self.congestion_feature_list, name)
            channels.extend(self._load_map_features(self.drc_feature_list, name))
            channels.extend(self._load_ir_features(name))
            features.append(torch.stack(channels))
        self.features = torch.stack(features)

    def get_features_one_task(self):
        """Load the features of the single task ``self.feature`` into ``self.features``."""
        features = []
        if self.feature == "congestion":
            features = [torch.stack(self._load_map_features(self.congestion_feature_list, name))
                        for name in self.name_list]
        elif self.feature == 'DRC':
            features = [torch.stack(self._load_map_features(self.drc_feature_list, name))
                        for name in self.name_list]
        elif self.feature == 'IR_drop':
            features = [torch.stack(self._load_ir_features(name)) for name in self.name_list]
        self.features = torch.stack(features)

    def get_all_labels(self):
        """Load congestion, DRC and IR-drop labels into ``self.labels``."""
        labels = []
        for name in self.name_list:
            # Congestion: horizontal and vertical channels summed into one.
            channels = [self._load_congestion_label(name)]
            channels.extend(self._load_drc_labels(name))
            channels.extend(self._load_ir_labels(name))
            labels.append(torch.stack(channels))
        self.labels = torch.stack(labels)

    def get_labels_one_task(self):
        """Load the labels of the single task ``self.label`` into ``self.labels``.

        Raises:
            RuntimeError: If no label could be loaded.
        """
        labels = []
        if self.label == "congestion":
            labels = [torch.stack([self._load_congestion_label(name)]) for name in self.name_list]
        elif self.label == 'DRC':
            labels = [torch.stack(self._load_drc_labels(name)) for name in self.name_list]
        elif self.label == 'IR_drop':
            labels = [torch.stack(self._load_ir_labels(name)) for name in self.name_list]
        elif self.label == "thermal":
            labels = self._load_thermal_labels()

        if not labels:
            # torch.stack would raise the same exception type with a less
            # helpful message.
            raise RuntimeError(f"No labels could be loaded for task {self.label!r}")
        self.labels = torch.stack(labels)

    def __len__(self):
        """Return the number of samples, i.e. the length of ``self.labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` pair stored at position ``idx``."""
        feature_map = self.features[idx]
        labels = self.labels[idx]
        return feature_map, labels








