from mlcirc_utils import StandardGraph, StandardAST, ComponentEnum
from mlcirc_utils.graphs import EdgeTerminalEnum
from mlcirc_tech_sky130 import BaseSky130Configuration, Sky130TechPlugin, Sky130CornerEnum
from mlcirc_pipelines.interfaces_general import ASTtoGraphInterface
from mlcirc_pipelines.interfaces_rl import get_enum_name_from_value, RLEdgeTerminalEnum, RLNodeTypeEnum
from pathlib import Path
import yaml
import pprint as ppr

from infrastructure.netlist_dataset import NetlistDataset
import pickle

FOLDER_PATH = "./dataset/full_edit" # all the .scs files

def _split_signal_pins(pin_order):
    """Split ``pin_order`` into ``(output_pins, input_pins)`` around the supply pins.

    Pins listed before the first ``"VDD"`` or ``"VSS"`` are treated as outputs;
    every other non-supply pin is treated as an input. The supply pins themselves
    are excluded from both sets.

    NOTE(review): this relies on the netlists declaring output pins before the
    supplies and input pins after them — confirm for any new cell library.
    """
    outputs, inputs = set(), set()
    seen_supply = False
    for pin in pin_order:
        if pin in ("VDD", "VSS"):
            seen_supply = True
        elif seen_supply:
            inputs.add(pin)
        else:
            outputs.add(pin)
    return outputs, inputs


def add_rl_features(graph: StandardGraph, target_circuit, pin_order) -> StandardGraph:
    """Attach RL-specific one-hot node and edge features to ``graph`` in place.

    Node features: every node's ``custom_features["one_hot_node_type"]`` is set to
    a one-hot list over ``RLNodeTypeEnum``. Net nodes are classified as VDD,
    GND (the ``"VSS"`` net), OUTPUT, INPUT, or plain NET using ``pin_order``
    (see ``_split_signal_pins``). They also record ``node_name``, ``name``
    (the target circuit), and ``pin_order``. Every other node uses the
    ``RLNodeTypeEnum`` member whose name matches its ``ComponentEnum`` name.

    Edge features: every edge's ``one_hot_edge_attr`` field is set to a one-hot
    list over ``RLEdgeTerminalEnum``, derived from its
    ``condensed_terminal_feature``. Bits for the ``NET`` terminal are skipped.

    Args:
        graph: Graph to annotate; modified in place.
        target_circuit: Name of the top-level subcircuit the graph was built from.
        pin_order: Pin names of ``target_circuit`` in declaration order.

    Returns:
        The same ``graph`` object, so callers can collect the annotated graph.

    Raises:
        KeyError: if a component type or terminal name has no counterpart in
            ``RLNodeTypeEnum`` / ``RLEdgeTerminalEnum``.
    """
    outputs, inputs = _split_signal_pins(pin_order)

    for node in graph.get_nodes():
        node_name = node[0]
        custom_features = graph.get_node_features(node_name, ["custom_features"])
        if custom_features is None:
            custom_features = {}
        component_type = graph.get_node_features(node_name, ["component_type"])
        one_hot = [0] * len(RLNodeTypeEnum)

        if component_type == ComponentEnum.NET:
            if node_name == "VDD":
                rl_type = RLNodeTypeEnum.VDD
            elif node_name == "VSS":
                rl_type = RLNodeTypeEnum.GND
            elif node_name in outputs:
                rl_type = RLNodeTypeEnum.OUTPUT
            elif node_name in inputs:
                rl_type = RLNodeTypeEnum.INPUT
            else:
                rl_type = RLNodeTypeEnum.NET
            one_hot[rl_type.value] = 1
            custom_features["one_hot_node_type"] = one_hot
            custom_features["node_name"] = node_name
            # Circuit-level metadata is carried on every net node.
            custom_features["name"] = target_circuit
            custom_features["pin_order"] = pin_order
        else:
            # Non-net components map onto RLNodeTypeEnum by member name.
            one_hot[RLNodeTypeEnum[component_type.name].value] = 1
            custom_features["one_hot_node_type"] = one_hot

        graph.overwrite_node_field(node_name, "custom_features", custom_features)

    for edge in graph.get_edges():
        src, dst = edge[0], edge[1]
        terminal_bits = graph.get_edge_features(src, dst, ["condensed_terminal_feature"])
        rl_terminal = [0] * len(RLEdgeTerminalEnum)
        for idx, bit in enumerate(terminal_bits):
            if bit != 1:
                continue
            enum_name = get_enum_name_from_value(EdgeTerminalEnum, idx)
            # NET terminal bits are not carried over into the RL edge encoding.
            if enum_name != "NET":
                rl_terminal[RLEdgeTerminalEnum[enum_name].value] = 1
        graph.overwrite_edge_field(src, dst, "one_hot_edge_attr", rl_terminal)

    return graph

configuration = BaseSky130Configuration()
plugin = Sky130TechPlugin(configuration)
test_visualization_path = Path(__file__).resolve().parent / "figures" / "test"
interface = ASTtoGraphInterface(plugin, test_visualization_path)

output_dir = Path(FOLDER_PATH)
# Sorted so the dataset order (and therefore item indices in the outputs) is
# reproducible; Path.iterdir() yields entries in arbitrary order.
file_list = sorted(
    str(f) for f in output_dir.iterdir() if f.is_file() and f.suffix == ".scs"
)
graph_list = []
netlist_dataset_homogenous = NetlistDataset(".")
for file_no, file_name in enumerate(file_list):
    ast = plugin.parse_file(file_name)  # Invoke the parser from the tech plugin, collect the AST.
    ast.subcircuits.update(plugin.base_ast.subcircuits)
    ppr.pprint(ast.subcircuits)

    # Top-level circuits are expected to carry an "A_" marker in their name
    # (e.g. "AA_inv_chain"). Resolve it per file so a missing one fails loudly
    # instead of silently reusing the previous file's circuit name.
    target_circuit = next((name for name in ast.subcircuits if "A_" in name), None)
    if target_circuit is None:
        raise ValueError(
            f"{file_name}: no top-level subcircuit whose name contains 'A_' was found"
        )
    print(file_no, target_circuit)
    pin_order = ast.subcircuits[target_circuit].pins

    std_graph = interface.transform_to_standard_graph(file_name, target_circuit, "tt")
    # add_rl_features annotates std_graph in place; collect the graph itself.
    add_rl_features(std_graph, target_circuit, pin_order)
    netlist_dataset_homogenous.add_item(std_graph, target_circuit)
    graph_list.append(std_graph)

with open(output_dir / "Z_graphs.pkl", "wb") as file:
    pickle.dump(graph_list, file)
netlist_dataset_homogenous.process_new_data(str(output_dir / "data_netlist.pt"))