import os
import torch
import random
import argparse

import tqdm

from feature_extration.src.read import ReadInnovusOutput, read_lef, read_lef_pin_map
from feature_extration.src.read_features import ReadFeaturesOutput


class Paraser(object):
    """Command-line options for the data paths and the ViT/GNN hyper-parameters.

    The configured ``argparse.ArgumentParser`` is exposed as ``self.parser``;
    call ``self.parser.parse_args()`` to obtain the option namespace.
    """

    # Spellings accepted by boolean options (compared case-insensitively).
    _TRUE_STRINGS = ('true', 't', 'yes', 'y', '1', 'on')
    _FALSE_STRINGS = ('false', 'f', 'no', 'n', '0', 'off', '')

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string into a real boolean.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because every non-empty string is truthy.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in Paraser._TRUE_STRINGS:
            return True
        if text in Paraser._FALSE_STRINGS:
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    def __init__(self) -> None:
        self.parser = argparse.ArgumentParser()

        # ---- dataset locations -------------------------------------------
        self.parser.add_argument('--data_set', default='CircuitNet-N28',
                                 help='dataset name')
        self.parser.add_argument('--data_root', default='/home/shiyu/CircuitNet-N28',
                                 help='the parent dir of innovus workspace')
        self.parser.add_argument('--train_list', default='lists/trainlist100.txt')
        self.parser.add_argument('--test_list', default='lists/testlist100.txt')
        self.parser.add_argument('--graph_save_root', default='graphs2',
                                 help='directory for generated graphs (created if missing)')
        self.parser.add_argument('--place_def_root',
                                 default='raw_data/LEFDEF/place-DEF/DEF',
                                 help='directory of placement DEF files, relative to data_root')
        self.parser.add_argument('--route_def_root',
                                 default='raw_data/LEFDEF/route-DEF/DEF',
                                 help='directory of routed DEF files, relative to data_root')
        # self.parser.add_argument('--lef_path', default=['./LEF/circuitnet.lef'], help='path to LEF files')
        self.parser.add_argument('--lef_path', default='raw_data/LEFDEF/LEF/circuitnet.lef', help='path to LEF files')
        # type=int so that a value given on the command line has the same
        # type as the integer default.
        self.parser.add_argument('--unit', type=int, default=2000, help='unit defined in the begining of DEF')
        self.parser.add_argument('--save_path', default='./out', help='save path')
        # self.parser.add_argument('--process_capacity', default=10, help='multi process setting, number of files for each process, determine number of process')
        # self.parser.add_argument('--debug', default=True, help='disable multi process to use pdb')

        # We give route def in place_def_name argument for simple.
        # But users should provide place DEF here, due to the difference in place and route DEF from possible optDesign.
        # self.parser.add_argument('--place_def_name', default='1-RISCY-a-1-c2-u0.7-m1-p1-f0.def.gz')#'detailed_route.def.gz')
        # self.parser.add_argument('--route_def_name', default='1-RISCY-a-1-c2-u0.7-m1-p1-f0.def.gz')#'detailed_route.def.gz')
        # self.parser.add_argument('--eGR_congestion_name', default='eGR_congestion') # congestion report from 'dumpNanoCongestArea'
        # self.parser.add_argument('--route_congestion_name', default='route_congestion')
        # self.parser.add_argument('--drc_rpt_name', default='drc.rpt')               # drc report from 'verify_drc'
        # self.parser.add_argument('--twf_rpt_name', default='cts.twf')               # timing window file from 'write_timing_windows'
        # self.parser.add_argument('--power_rpt_name', default='dyn_power.rpt')       # power report from 'report_power'
        # self.parser.add_argument('--ir_rpt_name', default='route_dynamic_ir.rpt')   # ir report form 'report_power_rail_results'
        # self.parser.add_argument('--n_time_window', default='20')                   # number of divided timing windows
        # self.parser.add_argument('--scaling', default=None)

        # ---- ViT options -------------------------------------------------
        self.parser.add_argument('--vit_model_name', default='deit_base_patch16_256')
        self.parser.add_argument('--vit_layers', default='3')
        self.parser.add_argument('--start_epoch', type=int, default=0)
        self.parser.add_argument('--epochs', type=int, default=100)
        self.parser.add_argument('--hiddendim', type=int, default=32)
        self.parser.add_argument('--vitlr', type=float, default=1e-8)
        self.parser.add_argument('--vitweight_decay', type=float, default=1e-4)
        self.parser.add_argument('--vitlr_decay', type=float, default=5e-4)
        # NOTE(review): parsed as float although the default is the int 32;
        # left unchanged because consumers of this value are not visible here.
        self.parser.add_argument('--vitbatch_size', type=float, default=32)

        self.parser.add_argument('--lr', type=float, default=5e-1)
        self.parser.add_argument('--weight_decay', type=float, default=5e-4)

        # ---- GNN options (trailing comments are the original author's notes)
        self.parser.add_argument('--gnnlr', type=float, default=1e-4)
        self.parser.add_argument('--gnnweight_decay', type=float, default=2e-4)
        self.parser.add_argument('--gnnlr_decay', type=float, default=2e-4)
        self.parser.add_argument('--recurrent', type=self._str2bool, default=False)  # False
        self.parser.add_argument('--topo_conv_type', type=str, default='CFCNN')  # CFCNN
        self.parser.add_argument('--geom_conv_type', type=str, default='SAGE')  # SAGE
        self.parser.add_argument('--agg_type', type=str, default='max')  # max
        self.parser.add_argument('--cat_raw', type=self._str2bool, default=False)  # True
        self.parser.add_argument('--gnn_layers', type=int, default=2)  # 2
        self.parser.add_argument('--node_feats', type=int, default=96)  # 64
        self.parser.add_argument('--net_feats', type=int, default=128)  # 128
        self.parser.add_argument('--pin_feats', type=int, default=16)  # 16
        self.parser.add_argument('--edge_feats', type=int, default=4)  # 4
        self.parser.add_argument('--topo_geom', type=str, default='both')
        self.parser.add_argument('--outtype', type=str, default='tanh')
        self.parser.add_argument('--feat_start', type=float, default=0.1)
        self.parser.add_argument('--feat_end', type=float, default=0.5)

        self.parser.add_argument('--l1', type=float, default=1.)
        self.parser.add_argument('--l2', type=float, default=5.)
        self.parser.add_argument('--l3', type=float, default=0.1)

        # ---- checkpoints -------------------------------------------------
        self.parser.add_argument('--gnn_checkpoint_root', type=str,
                                 default="./results/CircuitNet-N28/2025-07-23_15-33-27/gnn.pth")
        self.parser.add_argument('--vit_checkpoint_root', type=str,
                                 default="./results/CircuitNet-N28/2025-07-23_15-33-27/vit.pth")


def read(data_root, graph_save_root, read_list, place_def_root, route_def_root, unit, lef_dic, lef_dic_jnet):
    """Preprocess every design named in ``read_list`` and merge the results.

    For each name a ``ReadInnovusOutput`` is built for ``name + '.def.gz'``
    and its ``preprocess()`` is run; the resulting ``list_hetero_graph`` and
    ``gcell_size`` attributes are collected in input order.

    Args:
        data_root: forwarded to ``ReadInnovusOutput`` and ``ReadFeaturesOutput``.
        graph_save_root: forwarded to ``ReadInnovusOutput``.
        read_list: sized sequence of design names without the ``.def.gz``
            suffix (``len()`` is taken for the progress bar).
        place_def_root: forwarded to ``ReadInnovusOutput``.
        route_def_root: forwarded to ``ReadInnovusOutput``.
        unit: forwarded to ``ReadInnovusOutput``.
        lef_dic: forwarded to ``ReadInnovusOutput``.
        lef_dic_jnet: not used by this function; kept so existing callers
            keep working.

    Returns:
        The ``ReadFeaturesOutput`` built from ``data_root``, the collected
        graphs, the collected gcell sizes and ``read_list``.
    """
    graphlists = []
    gcelllist = []
    print("gengerating dataset")
    # Keep the progress bar under its own name: rebinding ``read_list`` would
    # hand the tqdm wrapper, not the caller's list, to ReadFeaturesOutput.
    progress = tqdm.tqdm(read_list, total=len(read_list))
    for name in progress:
        def_name = name + '.def.gz'
        process_log = ReadInnovusOutput(data_root, graph_save_root, def_name,
                                        place_def_root, route_def_root, unit, lef_dic)
        process_log.preprocess()
        graphlists.append(process_log.list_hetero_graph)
        gcelllist.append(process_log.gcell_size)
        ### call functions to get features, reference in README ###
        # process_log.read_route_pin_position()
        # process_log.compute_cell_density()
        # process_log.get_RUDY()
        # process_log.read_instance_placement()
    print("merging dataset")
    features_log = ReadFeaturesOutput(data_root, graphlists, gcelllist, read_list)
    return features_log


def get_unit(path):
    """Return the integer database unit declared in the ``UNITS`` section of ``path``.

    The file is scanned line by line. After a line starting with ``UNITS``,
    the third whitespace-separated token of a line starting with ``DATABASE``
    is taken as the unit, and scanning stops at the ``END UNITS`` line.

    Args:
        path: path of the text file to scan (the caller passes a LEF file).

    Returns:
        The unit as an ``int``.

    Raises:
        ValueError: if no ``DATABASE`` line is found inside a ``UNITS``
            section, or if the value is not an integer.
    """
    unit = None
    in_units = False
    with open(path, 'r') as read_file:
        for line in read_file:
            tokens = line.split()
            if not tokens:
                continue
            stripped = line.lstrip()
            if stripped.startswith('UNITS'):
                in_units = True
            if in_units and stripped.startswith('DATABASE') and len(tokens) >= 3:
                # Tolerate a terminating ';' glued to the number ("2000;").
                unit = tokens[2].rstrip(';')
            # Compare tokens, not the raw line: the raw line still carries its
            # newline, so an exact match with 'END UNITS' would never succeed.
            if tokens[:2] == ['END', 'UNITS']:
                break
    if unit is None:
        raise ValueError('no DATABASE unit found in a UNITS section of {!r}'.format(path))
    return int(unit)


def _strip_def_suffix(name):
    """Return ``name`` without a trailing ``.def.gz``, if present."""
    suffix = '.def.gz'
    return name[:-len(suffix)] if name.endswith(suffix) else name


def _write_name_list(path, names):
    """Write one design name per line to ``path``, dropping ``.def.gz`` suffixes."""
    with open(path, mode='w') as list_file:
        for name in names:
            list_file.write(_strip_def_suffix(name) + '\n')


def get_cell(data_root, graph_save_root, list_path, lef_path, place_def_root, route_def_root, istrain=1):
    """Load the train or test design list and build the dataset for it.

    If the train list file does not exist yet, a split is generated first:
    names present in both the route and the place DEF directories are
    collected, names whose third '-'-separated field is ``FPU`` or whose
    second field is ``zero`` are dropped, the rest is shuffled and divided
    4:1 into the train and test list files.

    Args:
        data_root: base directory; every other path except
            ``graph_save_root`` is joined onto it.
        graph_save_root: directory created if missing and forwarded to ``read``.
        list_path: two-element sequence ``(train_list, test_list)`` of list
            file paths relative to ``data_root``.
        lef_path: LEF file path relative to ``data_root``.
        place_def_root: placement DEF directory relative to ``data_root``.
        route_def_root: routed DEF directory relative to ``data_root``.
        istrain: truthy selects the train list, falsy the test list.

    Returns:
        Whatever ``read`` returns for the selected list.

    Raises:
        IndexError: while generating a split, if a DEF file name has fewer
            than three '-'-separated fields.
    """
    os.makedirs(graph_save_root, exist_ok=True)

    lef_file = os.path.join(data_root, lef_path)
    unit = get_unit(lef_file)
    lef_dic = read_lef(lef_file, {}, unit)
    # The pin-map dictionary is not populated at the moment; it is still
    # handed to read(), which takes it as its last argument.
    # lef_dic_jnet = read_lef_pin_map(os.path.join(data_root,lef_path),lef_dic_jnet, unit)
    lef_dic_jnet = {}

    train_list_file = os.path.join(data_root, list_path[0])
    test_list_file = os.path.join(data_root, list_path[1])

    if not os.path.exists(train_list_file):
        routenames = set(os.listdir(os.path.join(data_root, route_def_root)))
        placenames = set(os.listdir(os.path.join(data_root, place_def_root)))
        # Sort before shuffling: set order depends on string hashing, so
        # without it a seeded ``random`` still gives a different split per run.
        names = sorted(routenames & placenames)
        print(len(names))

        part_names = []
        for name in names:
            fields = name.split('-')
            if fields[2] == 'FPU' or fields[1] == 'zero':
                continue
            part_names.append(name)
        print(f"deal with {len(part_names)} samples ")

        random.shuffle(part_names)
        # if len(part_names) >= 1000:
        #    part_names = part_names[:1000]
        partition = len(part_names) * 4 // 5
        train_names = part_names[:partition]
        test_names = part_names[partition:]
        print(len(train_names))
        print(len(test_names))

        # Only write the lists here. The names actually used are read back
        # from the selected file below, so the first run yields exactly the
        # same list as later runs (no duplicates, no test names in train).
        _write_name_list(train_list_file, train_names)
        _write_name_list(test_list_file, test_names)

    selected_list_file = train_list_file if istrain else test_list_file
    read_list = []
    with open(selected_list_file, 'r') as list_file:
        for line in list_file:
            name = line.rstrip('\r\n')
            if name:  # skip blank lines instead of producing an empty design name
                read_list.append(name)

    return read(data_root, graph_save_root, read_list, place_def_root, route_def_root, unit, lef_dic, lef_dic_jnet)


if __name__ == '__main__':
    # Script entry point: parse the command line and build the dataset for
    # the train list (get_cell's ``istrain`` defaults to 1).
    argp = Paraser()
    arg = argp.parser.parse_args()

    # get_cell takes the (train, test) list paths as its third argument.
    # It reads the database unit from the LEF file itself, so ``arg.unit``
    # is not passed.
    get_cell(arg.data_root, arg.graph_save_root, [arg.train_list, arg.test_list],
             arg.lef_path, arg.place_def_root, arg.route_def_root)

