from typing import Union, List, Tuple
import os
import pickle
import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset, Data, download_url

# Vocabulary for the "pins" dataset variant: maps every vertex type name (used
# as the ``type`` vertex attribute) to an integer id. Ids are the positions in
# the tuple below, so new names must be appended at the end to keep existing
# ids unchanged.
NAME_TO_ID_PINS = {
    name: type_id
    for type_id, name in enumerate(
        (
            "VDD", "VSS", "VIN",
            "NM", "NM_D", "NM_G", "NM_S", "NM_B",
            "VOUT", "R", "C", "VB",
            "PM", "PM_D", "PM_G", "PM_S", "PM_B",
            "IB", "VCONT", "IOUT", "IIN", "VCM", "VREF", "IREF",
            "NPN", "NPN_C", "NPN_B", "NPN_E",
            "PNP", "PNP_C", "PNP_B", "PNP_E",
            "VCLK",
            "TRANSMISSION_GATE", "TRANSMISSION_GATE_A", "TRANSMISSION_GATE_B",
            "TRANSMISSION_GATE_C", "TRANSMISSION_GATE_VDD", "TRANSMISSION_GATE_VSS",
            "INVERTER", "INVERTER_A", "INVERTER_Q", "INVERTER_VDD", "INVERTER_VSS",
            "L",
            "DIO", "DIO_P", "DIO_N",
            "XOR", "XOR_A", "XOR_B", "XOR_VDD", "XOR_VSS", "XOR_Y",
            "LOGICA", "LOGICB",
            "PFD", "PFD_A", "PFD_B", "PFD_QA", "PFD_QB", "PFD_VDD", "PFD_VSS",
            "VRF", "VLO", "VIF", "VBB",
            "LOGICQA", "LOGICQB", "LOGICF", "LOGICG", "LOGICD", "LOGICQ",
            "VLOGICA", "VLOGICB", "VLATCH", "VTRACK", "VHOLD", "IN", "VCKL",
            "net",
        )
    )
}

# Vocabulary for the "nodes" dataset variant (the graphs without pin vertices):
# maps each vertex type name to an integer id. Ids are positions in the tuple
# below, so new names must be appended at the end to keep existing ids unchanged.
NAME_TO_ID_NODES = {
    name: type_id
    for type_id, name in enumerate(
        (
            "VDD", "VSS", "VIN", "NM", "VOUT", "R", "C", "VB", "PM", "net",
            "IB", "VCONT", "IOUT", "IIN", "VCM", "VREF", "IREF",
            "NPN", "PNP", "VCLK", "TRANSMISSION_GATE", "INVERTER", "L", "DIO",
            "VRF", "VLO", "VIF", "VBB",
        )
    )
}

# Reverse lookups (integer id -> vertex type name) for NAME_TO_ID_NODES and NAME_TO_ID_PINS.
ID_TO_NAME_NODES = dict(zip(NAME_TO_ID_NODES.values(), NAME_TO_ID_NODES.keys()))

ID_TO_NAME_PINS = dict(zip(NAME_TO_ID_PINS.values(), NAME_TO_ID_PINS.keys()))

# Pin names of each multi-terminal component type. The order matters: it is the
# SPICE terminal order, and a component's pin-connection labels follow it.
node2pins = dict(
    PM=["PM_D", "PM_G", "PM_S", "PM_B"],
    NM=["NM_D", "NM_G", "NM_S", "NM_B"],
    NPN=["NPN_C", "NPN_B", "NPN_E"],
    PNP=["PNP_C", "PNP_B", "PNP_E"],
    DIO=["DIO_P", "DIO_N"],
    XOR=["XOR_A", "XOR_B", "XOR_VDD", "XOR_VSS", "XOR_Y"],
    PFD=["PFD_A", "PFD_B", "PFD_QA", "PFD_QB", "PFD_VDD", "PFD_VSS"],
    INVERTER=["INVERTER_A", "INVERTER_Q", "INVERTER_VDD", "INVERTER_VSS"],
    TRANSMISSION_GATE=[
        "TRANSMISSION_GATE_A",
        "TRANSMISSION_GATE_B",
        "TRANSMISSION_GATE_C",
        "TRANSMISSION_GATE_VDD",
        "TRANSMISSION_GATE_VSS",
    ],
)

# Vertex type names that look like electrical nets/ports rather than devices.
# NOTE(review): meaning inferred from the names; the list is not referenced in
# this chunk, so confirm how callers use it.
eletric_nodes = [
    "VDD", "VIN", "VOUT", "VSS", "net",
    "VB", "VBB", "VLO", "VIF",
    "IB", "VCONT", "VCLK", "VREF", "IIN", "VCM", "IOUT", "IREF", "VRF",
    "VLATCH", "VTRACK", "VHOLD", "IN", "VCKL",
]

class AnalogGenieDataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from the AnalogGenie igraph pickle.

    ``process`` reads ``<root>/raw/igraph_analogenie_dataset_v2_{pins,nodes}.pkl``,
    splits its records 80/10/10 (in file order) into train/val/test, converts
    each record to a :class:`torch_geometric.data.Data` and saves every split to
    its own file under ``<root>/processed``. The constructor then loads the
    requested split.
    """

    # Split names, in the same order as ``processed_file_names``.
    _SPLITS = ("train", "val", "test")

    def __init__(self, root: str, transform=None, pre_filter=None, split="train", pins=True, pin_prediction=False):
        """
        Load one split of the dataset, processing the raw pickle first if needed.

        Args:
            root (str): Root directory. The raw pickle is read from ``root/raw``
                and processed split files are written to ``root/processed``.
            transform (callable, optional): Transform applied to each ``Data``
                object every time it is accessed.
            pre_filter (callable, optional): Predicate applied to each raw
                pickle record (before conversion to ``Data``). Records for which
                it returns a falsy value are dropped.
            split (str): Split to load: ``"train"``, ``"val"`` or ``"test"``.
            pins (bool): Load the variant with explicit pin vertices (True) or
                the variant without pins (False).
            pin_prediction (bool): Whether the network is trained only to
                predict the connections between pins and the other nodes. In
                this class it only changes the processed file names
                (``*_pin_prediction.pt``); the processed content is the same.

        Raises:
            ValueError: If ``split`` is not one of ``"train"``, ``"val"``, ``"test"``.
        """
        # Validate before touching the filesystem (and without relying on
        # ``assert``, which is stripped under ``python -O``).
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")

        # These attributes must exist before ``super().__init__``, which calls
        # ``raw_file_names`` / ``processed_file_names`` / ``process``.
        self.pins = pins
        self.pin_prediction = pin_prediction
        self.node2id = NAME_TO_ID_PINS if pins else NAME_TO_ID_NODES
        self.raw_path = os.path.join(root, "raw", self.raw_file_names)
        self.root = root
        os.makedirs(self.root, exist_ok=True)

        # Pass pre_filter by keyword: InMemoryDataset's third positional
        # parameter is ``pre_transform``. Passing it positionally left
        # ``self.pre_filter`` as None, so the filter was silently never applied.
        super().__init__(root, transform, pre_filter=pre_filter)

        # Load exactly the file ``process`` wrote for this split. Its name
        # includes the ``_pin_prediction`` suffix when that flag is set, and
        # the base class has already (re)processed it if it was missing.
        self.load(self.processed_paths[self._SPLITS.index(split)])

    @property
    def raw_file_names(self):
        """Name of the raw pickle inside ``root/raw`` for the selected variant."""
        variant = "pins" if self.pins else "nodes"
        return f"igraph_analogenie_dataset_v2_{variant}.pkl"

    @property
    def processed_dir(self):
        """Directory holding the processed split files: ``root/processed``."""
        return os.path.join(self.root, 'processed')

    def _dataset_version(self):
        """Return the processed-file suffix: ``pins``/``nodes``, plus ``_pin_prediction`` if set."""
        version = "pins" if self.pins else "nodes"
        if self.pin_prediction:
            version += "_pin_prediction"
        return version

    @property
    def processed_file_names(self) -> List[str]:
        """Processed file names for the train, val and test splits (in that order)."""
        version = self._dataset_version()
        return [f"{split_name}_{version}.pt" for split_name in self._SPLITS]

    def download(self):
        """No-op: the raw pickle is not downloaded and must already be in ``root/raw``."""

    def _get_number_of_files(self):
        """
        Return the number of records in the raw pickle (before any ``pre_filter``).

        ``process`` treats the pickle as one flat sequence of records, so the
        count is its length. (The previous ``len(splits[0]) + len(splits[1])``
        counted the keys of the first two records instead.)
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted raw files.
        with open(self.raw_path, 'rb') as file:
            records = pickle.load(file)
        return len(records)

    def pretransform(self, data):
        """
        Convert one raw record to a ``Data`` object.

        The record is indexed with ``"igraph"`` and ``"index"``. For the nodes
        variant it is also indexed with ``"labels"``; the pins variant has no labels.
        """
        labels = None if self.pins else data["labels"]
        return self.igraph_to_pyg(data["igraph"], labels, int(data["index"]))

    @staticmethod
    def _name_number(name):
        """Return the digits of a vertex name as an int (e.g. ``"NM1"`` -> 1), or 0 if it has none."""
        digits = "".join(filter(str.isdigit, name))
        return int(digits) if digits else 0

    def _node_features(self, igraph_graph):
        """
        Build the long node-feature matrix from the ``type`` and ``name`` vertex attributes.

        One column is produced per supported attribute, in igraph's attribute
        order. Returns None when neither attribute is present.
        """
        columns = []
        for attr in igraph_graph.vs.attributes():
            values = igraph_graph.vs[attr]
            if attr == "type":
                # Map each vertex type name to its integer id for this variant.
                columns.append(torch.tensor([self.node2id[type_name] for type_name in values]))
            elif attr == "name":
                # Names look like "NM1". The type is already in ``type``, so only
                # the number is kept.
                columns.append(torch.tensor([self._name_number(name) for name in values]))
        if not columns:
            return None
        return torch.stack(columns, dim=1).long()

    @staticmethod
    def _edge_features(igraph_graph, num_edges):
        """
        Build long edge features.

        Returns one column per igraph edge attribute, or a 1-D vector of ones
        when the graph has no edge attributes.
        """
        attrs = igraph_graph.es.attributes()
        if not attrs:
            return torch.ones((num_edges,), dtype=torch.long)
        return torch.stack([torch.tensor(igraph_graph.es[attr]) for attr in attrs], dim=1).long()

    @staticmethod
    def _pin_labels(igraph_graph, label):
        """
        Translate name-based pin labels into vertex indices.

        Each component name in ``label`` (e.g. ``"NM1"``) becomes
        ``[component_vertex_idx, [connected_vertex_idx, ...]]``. The connected
        vertices keep the order of ``label[component]``, which follows the SPICE
        pin order stored in ``node2pins``.

        Raises:
            ValueError: If a pin name matches more than one vertex.
        """
        label_idxs = []
        if not label:
            return label_idxs
        for component in label:
            component_idx = igraph_graph.vs.select(name=component).indices[0]
            pin_idxs = []
            for pin in label[component]:
                matches = igraph_graph.vs.select(name=pin).indices
                if len(matches) > 1:
                    raise ValueError(f"Multiple nodes found for pin '{pin}' in component '{component}'.")
                pin_idxs.append(matches[0])
            label_idxs.append([component_idx, pin_idxs])
        return label_idxs

    def igraph_to_pyg(self, igraph_graph, label, circuit_idx):
        """
        Convert an igraph graph to a PyTorch Geometric ``Data`` object.

        Args:
            igraph_graph (igraph.Graph): Input circuit graph.
            label (dict or None): Maps a component vertex name (e.g. ``"NM1"``)
                to the names of the vertices its pins connect to, in
                ``node2pins`` order. ``None`` or empty means no labels.
            circuit_idx (int): Circuit index, stored as ``data.circuit_idx``.

        Returns:
            torch_geometric.data.Data: The graph, with these fields:

            - ``x``: long features from the ``type``/``name`` vertex attributes, or None.
            - ``edge_index``: shape (2, num_edges).
            - ``edge_attr``: long edge features.
            - ``y``: pin-connection labels as nested lists of vertex indices.
            - any graph-level igraph attributes.

        Raises:
            ValueError: If a pin name in ``label`` matches more than one vertex.
        """
        edges = igraph_graph.get_edgelist()  # list of (source, target) tuples
        # reshape(-1, 2) keeps a (2, 0) shape for graphs without edges. A plain
        # .t() on the empty 1-D tensor would give shape (0,).
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

        data = Data(
            x=self._node_features(igraph_graph),
            edge_index=edge_index,
            edge_attr=self._edge_features(igraph_graph, len(edges)),
            y=self._pin_labels(igraph_graph, label),
            rrwp=None,
            rrwp_index=None,
            rrwp_val=None,
            log_deg=None,
            deg=None,
            circuit_idx=circuit_idx,
        )

        # Copy any graph-level igraph attributes onto the Data object.
        for attr in igraph_graph.attributes():
            setattr(data, attr, igraph_graph[attr])

        return data

    def process(self):
        """
        Split the raw records 80/10/10 in file order, convert them and save each split.

        ``pre_filter`` (if given) is applied to the raw records of each split
        before conversion.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted raw files.
        with open(self.raw_path, 'rb') as file:
            records = pickle.load(file)

        num_records = len(records)
        train_end = int(num_records * 0.8)
        val_end = int(num_records * 0.9)
        splits = (records[:train_end], records[train_end:val_end], records[val_end:])

        os.makedirs(self.processed_dir, exist_ok=True)
        # processed_paths follows processed_file_names: train, val, test.
        for split_records, path in zip(splits, self.processed_paths):
            if self.pre_filter:
                split_records = [record for record in split_records if self.pre_filter(record)]
            data_list = [self.pretransform(record) for record in split_records]
            self.save(data_list, path)
