import os
import dgl
import numpy as np
import torch
from scipy import ndimage
import cv2
import tqdm
from torch.utils.data import DataLoader, Dataset
import shutil




def resize(input):
    """Resample a 2-D array to a 256 x 256 grid with cubic spline interpolation.

    One zoom factor is derived per axis from the input's first two dimensions
    and handed to ``scipy.ndimage.zoom`` with ``order=3``. Because exactly two
    factors are passed, the input must be 2-D.
    """
    n_rows, n_cols = input.shape[0], input.shape[1]
    zoom_factors = (256 / n_rows, 256 / n_cols)
    return ndimage.zoom(input, zoom_factors, order=3)


def std(input):
    """Min-max normalize ``input`` into the range [0, 1].

    Despite its name this is not a standard deviation: it returns
    ``(input - min) / (max - min)``. It needs only ``.min()``, ``.max()`` and
    elementwise arithmetic, so it presumably works for NumPy arrays as well as
    torch tensors.

    If the input is constant (``max == min``, which includes the all-zero
    case), there is no range to divide by. ``input - min`` is returned instead,
    which is all zeros with the input's shape and dtype. This avoids the 0/0
    division that would otherwise fill the result with NaNs.
    """
    lo = input.min()
    hi = input.max()
    shifted = input - lo
    if hi == lo:
        # Degenerate range: `shifted` is already all zeros; dividing would give NaN.
        return shifted
    return shifted / (hi - lo)


def resize_cv2(input):
    """Resize an image to 256 x 256 using OpenCV's pixel-area interpolation.

    The destination size is square, so the (width, height) ordering that
    ``cv2.resize`` expects for ``dsize`` makes no difference here.
    """
    target_size = (256, 256)
    resized = cv2.resize(input, target_size, interpolation=cv2.INTER_AREA)
    return resized

def copy_file_to_path(source_file_path, target_file_path):
    """Copy one file to an explicit destination path, never overwriting.

    Missing parent directories of the destination are created first. The copy
    uses ``shutil.copy2``, which also copies file metadata. Progress messages
    are printed to stdout.

    Args:
        source_file_path: Path of the file to copy.
        target_file_path: Full destination path, including the file name.

    Returns:
        ``target_file_path``. It is also returned when the destination already
        exists and the copy is skipped.

    Raises:
        FileNotFoundError: If ``source_file_path`` does not exist.
        ValueError: If ``source_file_path`` exists but is not a regular file.
    """
    # The source must exist and must be a regular file (not a directory).
    if not os.path.exists(source_file_path):
        raise FileNotFoundError(f"源文件不存在: {source_file_path}")
    if not os.path.isfile(source_file_path):
        raise ValueError(f"源路径不是一个文件: {source_file_path}")

    # Never overwrite: anything already at the destination short-circuits the copy.
    if os.path.exists(target_file_path):
        print(f"目标文件已存在，跳过复制: {target_file_path}")
        return target_file_path

    # An empty dirname means a bare file name relative to the cwd; nothing to create.
    parent_dir = os.path.dirname(target_file_path)
    must_create_parent = bool(parent_dir) and not os.path.exists(parent_dir)
    if must_create_parent:
        os.makedirs(parent_dir, exist_ok=True)
        print(f"已创建目录: {parent_dir}")

    shutil.copy2(source_file_path, target_file_path)
    print(f"文件已成功复制到: {target_file_path}")
    return target_file_path

class ReadFeaturesOutputSave(Dataset):
    """Copy per-design feature and label files from ``data_root`` to ``data_root2``.

    Each task (congestion, DRC, IR_drop) has a set of feature and label
    directories under ``data_root``. These are mirrored at the same relative
    location under ``data_root2``. For every file name in ``namelist``,
    ``<src_dir>/<name>`` is copied to ``<dst_dir>/<name>`` with
    ``copy_file_to_path``, which skips destinations that already exist. All
    copying happens eagerly in ``__init__``.

    NOTE(review): although this subclasses ``Dataset``, it defines no
    ``__len__``/``__getitem__``; it appears to be used only for its copying
    side effect.

    Args:
        data_root: Source dataset root directory.
        data_root2: Destination root directory; the same relative layout is
            recreated beneath it.
        namelist: File names (relative to each feature/label directory) to copy.
    """

    # Supported tasks, in the fixed order generate_features() processes them.
    _TASK_ORDER = ('congestion', 'DRC', 'IR_drop')

    def __init__(self, data_root, data_root2, namelist):
        # NOTE(review): `device` and `pos_flatten` are not used in this class;
        # kept for backward compatibility with any external readers.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_root = data_root
        self.data_root2 = data_root2
        self.pos_flatten = []
        self.name_list = namelist

        # Tasks whose files generate_features() will copy.
        self.tasks = ['congestion', 'DRC', 'IR_drop']

        # congestion task: paths are <root>/routability_features/<rel>.
        self.congestion_root = 'routability_features'
        congestion_features = ['macro_region', 'RUDY/RUDY', 'RUDY/RUDY_pin']
        congestion_labels = [
            'congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
            'congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow',
        ]
        self.congestion_feature_list = self._join_all(self.data_root, congestion_features, self.congestion_root)
        self.congestion_label_list = self._join_all(self.data_root, congestion_labels, self.congestion_root)
        self.congestion_feature_list2 = self._join_all(self.data_root2, congestion_features, self.congestion_root)
        self.congestion_label_list2 = self._join_all(self.data_root2, congestion_labels, self.congestion_root)

        # DRC task: paths are <root>/<rel>.
        drc_features = [
            'routability_features/macro_region',
            'routability_features/cell_density',
            'routability_features/RUDY/RUDY_long',
            'routability_features/RUDY/RUDY_short',
            'routability_features/RUDY/RUDY_pin_long',
            'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_horizontal_overflow',
            'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_vertical_overflow',
            'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
            'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow',
        ]
        drc_labels = ['routability_features/DRC/DRC_all']
        self.drc_feature_list = self._join_all(self.data_root, drc_features)
        self.drc_label_list = self._join_all(self.data_root, drc_labels)
        self.drc_feature_list2 = self._join_all(self.data_root2, drc_features)
        self.drc_label_list2 = self._join_all(self.data_root2, drc_labels)

        # IR_drop task: paths are <root>/<rel>.
        ir_features = [
            'IR_drop_features_decompressed/power_i',
            'IR_drop_features_decompressed/power_s',
            'IR_drop_features_decompressed/power_sca',
            'IR_drop_features_decompressed/power_all',
            'IR_drop_features_decompressed/power_t',
        ]
        ir_labels = ['IR_drop_features_decompressed/IR_drop']
        self.ir_feature_list = self._join_all(self.data_root, ir_features)
        self.ir_label_list = self._join_all(self.data_root, ir_labels)
        self.ir_feature_list2 = self._join_all(self.data_root2, ir_features)
        self.ir_label_list2 = self._join_all(self.data_root2, ir_labels)

        self.generate_features()

    @staticmethod
    def _join_all(root, rel_paths, *prefix):
        """Return ``[os.path.join(root, *prefix, rel) for rel in rel_paths]``."""
        return [os.path.join(root, *prefix, rel) for rel in rel_paths]

    def copy_data(self, task):
        """Copy every file in ``self.name_list`` for one task to the mirror root.

        For each name, all feature files are copied first, then all label files.
        Each source directory is paired positionally with its mirrored
        destination directory under ``data_root2``.

        Args:
            task: One of ``'congestion'``, ``'DRC'`` or ``'IR_drop'``.

        Raises:
            ValueError: If ``task`` is not a supported task name.
            FileNotFoundError, ValueError: Propagated from ``copy_file_to_path``
                when a source file is missing or is not a regular file.
        """
        # task -> ((src feature dirs, dst feature dirs), (src label dirs, dst label dirs))
        dirs_by_task = {
            'congestion': ((self.congestion_feature_list, self.congestion_feature_list2),
                           (self.congestion_label_list, self.congestion_label_list2)),
            'DRC': ((self.drc_feature_list, self.drc_feature_list2),
                    (self.drc_label_list, self.drc_label_list2)),
            'IR_drop': ((self.ir_feature_list, self.ir_feature_list2),
                        (self.ir_label_list, self.ir_label_list2)),
        }
        if task not in dirs_by_task:
            raise ValueError(f"Unknown task {task!r}; expected one of {list(self._TASK_ORDER)}")

        (feature_src, feature_dst), (label_src, label_dst) = dirs_by_task[task]
        for name in self.name_list:
            for src_dir, dst_dir in zip(feature_src, feature_dst):
                copy_file_to_path(os.path.join(src_dir, name), os.path.join(dst_dir, name))
            for src_dir, dst_dir in zip(label_src, label_dst):
                copy_file_to_path(os.path.join(src_dir, name), os.path.join(dst_dir, name))

    def generate_features(self):
        """Run ``copy_data`` for each supported task present in ``self.tasks``.

        Tasks are processed in the fixed order congestion, DRC, IR_drop,
        regardless of their order in ``self.tasks``. Unrecognised entries in
        ``self.tasks`` are ignored.
        """
        for task in self._TASK_ORDER:
            if task in self.tasks:
                self.copy_data(task)


