from typing import Union, List, Tuple
import os
import pickle
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data, download_url
import zipfile
from torch_geometric.data.separate import separate
import copy


class OCBDataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from pickled OCB circuit graphs.

    The raw data consists of two pickle files located in ``<root>/raw``:
    ``OCB101v2_graphs`` (the graphs) and ``OCB101v2_labels`` (the labels).
    Each file holds a two-element sequence whose parts are concatenated; the
    resulting list is then cut, in order, into train (first 80%), validation
    (next 10%) and test (last 10%) splits.  Every split is stored twice in
    ``<root>/processed``: once for the ``no_pins`` graph variant and once for
    the ``pins`` variant.
    """

    # Cumulative split boundaries, expressed as fractions of the sample count.
    _TRAIN_END = 0.8
    _VAL_END = 0.9

    def __init__(self, root: str, transform=None, pre_filter=None, use_pins=False, split="train"):
        """
        Build (if necessary) and load one split of the dataset.

        Args:
            root (str): Root directory where the dataset should be stored.
                Must contain 'CktBench101' or 'CktBench301' in its path.
            transform (callable, optional): A function/transform applied to each graph during loading.
            pre_filter (callable, optional): A predicate applied during processing to each raw
                sample (the indexable item holding both graph variants, i.e. *before* conversion
                to a PyG ``Data`` object). Samples for which it returns a falsy value are dropped
                together with their label.
            use_pins (bool, optional): If True, load the 'pins' variant of the graphs,
                otherwise the 'no_pins' variant.
            split (str, optional): Which split to load: 'train', 'val' or 'test'.

        Raises:
            ValueError: If ``root`` does not name a supported benchmark or ``split`` is invalid.
        """
        # Validate with explicit exceptions: `assert` statements vanish under `python -O`.
        if ('CktBench301' not in root) and ('CktBench101' not in root):
            raise ValueError("Invalid dataset name. Should include 'CktBench301' or 'CktBench101'.")
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"Invalid split {split!r}. Should be one of 'train', 'val' or 'test'.")

        self.raw_path = os.path.join(root, 'raw')
        self.root = root
        os.makedirs(self.root, exist_ok=True)

        self.raw_filename = 'OCB101v2_graphs' # f'ckt_bench_{key}.pkl'
        self.perform_filename = 'OCB101v2_labels' # 'OCB101v2_labels' # f'perform{key}.csv'

        # Pass by keyword: the third positional parameter of InMemoryDataset is
        # `pre_transform`, so a positional `pre_filter` would land in the wrong slot
        # and `self.pre_filter` would silently stay None.
        super().__init__(root, transform=transform, pre_filter=pre_filter)
        key = 'pins' if use_pins else 'no_pins'
        path = os.path.join(self.processed_dir, f'{split}_{key}.pt')
        self.load(path)

    @property
    def raw_file_names(self):
        """Names of the raw files expected inside the raw directory."""
        return [self.raw_filename, self.perform_filename]

    @property
    def processed_dir(self):
        """Directory in which the processed split files are stored."""
        return os.path.join(self.root, 'processed')

    @property
    def processed_file_names(self) -> List[str]:
        """Names of the six processed files (three splits x two graph variants)."""
        return ['train_pins.pt', 'val_pins.pt', 'test_pins.pt', 'train_no_pins.pt', 'val_no_pins.pt', 'test_no_pins.pt']

    def download(self):
        """Download the original benchmark files into the raw directory.

        NOTE(review): the files fetched here (``ckt_bench_*.pkl`` / ``perform*.csv``)
        do not carry the names listed in ``raw_file_names`` and read by ``process()``
        (``OCB101v2_graphs`` / ``OCB101v2_labels``). Presumably the v2 raw files have
        to be generated and placed in ``<root>/raw`` manually -- confirm.

        Raises:
            ValueError: If the root path names neither 'CktBench301' nor 'CktBench101'.
        """
        if "CktBench301" in self.root:
            url = "https://raw.githubusercontent.com/zehao-dong/CktGNN/refs/heads/main/OCB/CktBench301/ckt_bench_301.pkl.zip"
            download_url(url, self.raw_path)
            # The OCB-301 graphs are shipped zipped; unpack them next to the archive.
            with zipfile.ZipFile(os.path.join(self.raw_path, 'ckt_bench_301.pkl.zip'), 'r') as zip_ref:
                zip_ref.extractall(self.raw_path)
            url = "https://raw.githubusercontent.com/zehao-dong/CktGNN/refs/heads/main/OCB/CktBench301/perform301.csv"
            download_url(url, self.raw_path)
        elif "CktBench101" in self.root:
            url = "https://raw.githubusercontent.com/zehao-dong/CktGNN/refs/heads/main/OCB/CktBench101/ckt_bench_101.pkl"
            download_url(url, self.raw_path)
            url = "https://raw.githubusercontent.com/zehao-dong/CktGNN/refs/heads/main/OCB/CktBench101/perform101.csv"
            download_url(url, self.raw_path)
        else:
            raise ValueError("Invalid dataset name. Should include 'CktBench301' or 'CktBench101'.")

    def _load_raw_pickle(self, filename):
        """Unpickle and return the content of ``filename`` from the raw directory."""
        # SECURITY: pickle can execute arbitrary code while loading; only use
        # raw files that come from a trusted source.
        with open(os.path.join(self.raw_path, filename), 'rb') as file:
            return pickle.load(file)

    def _get_number_of_files(self):
        """
        Return the total number of graphs stored in the raw graph file,
        i.e. the summed lengths of its two parts.
        """
        parts = self._load_raw_pickle(self.raw_filename)
        return len(parts[0]) + len(parts[1])

    def pretransform(self, data, label):
        """Convert one igraph graph and its label into a PyG ``Data`` object."""
        return self.igraph_to_pyg(data, label)

    def igraph_to_pyg(self, igraph_graph, label):
        """
        Converts an igraph graph to a PyTorch Geometric Data object.

        Args:
            igraph_graph (igraph.Graph): The input graph in igraph format.
            label: Target value(s) of the graph; anything accepted by ``torch.tensor``.

        Returns:
            torch_geometric.data.Data: The converted graph in PyTorch Geometric format.
        """
        # Edge list -> (2, num_edges) index tensor. The reshape keeps the result
        # two-dimensional even when the graph has no edges at all.
        edges = igraph_graph.get_edgelist()  # List of tuples (source, target)
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

        # Node features: one column per vertex attribute, cast to integers.
        vertex_attrs = igraph_graph.vs.attributes()
        if vertex_attrs:
            columns = []
            for attr in vertex_attrs:
                values = igraph_graph.vs[attr]
                if None in values:
                    # Attribute is (partially) missing: the whole column is replaced
                    # by random integers in [1, 100). Not seeded here, so the result
                    # differs between runs unless NumPy's global RNG is seeded.
                    columns.append(torch.from_numpy(np.random.randint(1, 100, len(values))))
                else:
                    columns.append(torch.tensor(values))
            x = torch.stack(columns, dim=1).long()
        else:
            x = None

        # Edge features: one column per edge attribute, or a vector of ones
        # when the graph carries no edge attributes.
        edge_attrs = igraph_graph.es.attributes()
        if edge_attrs:
            edge_attr = torch.stack([torch.tensor(igraph_graph.es[attr]) for attr in edge_attrs], dim=1).long()
        else:
            edge_attr = torch.ones((len(edges),), dtype=torch.long)

        # Leading batch dimension -> shape (1, ...); presumably (1, 3) for the
        # current label format -- TODO confirm against the label file.
        label = torch.tensor(label).unsqueeze(0)

        # Create the PyG Data object; the extra fields are placeholders set to None.
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=label, rrwp=None, rrwp_index=None, rrwp_val=None, log_deg=None, deg=None)

        # Copy graph-level attributes over as attributes of the Data object.
        for attr in igraph_graph.attributes():
            setattr(data, attr, igraph_graph[attr])

        return data

    def process(self):
        """Read the raw pickles, split the samples and save every split/variant.

        The two parts of the graph file and of the label file are concatenated,
        the samples are split by position into train/val/test (80/10/10), and
        each split is converted and saved for both the 'no_pins' and 'pins'
        graph variants.
        """
        graphs = self._load_raw_pickle(self.raw_filename)
        graphs = graphs[0] + graphs[1]

        # Labels are stored in the same two-part layout as the graphs.
        labels = self._load_raw_pickle(self.perform_filename)
        labels = labels[0] + labels[1]

        # The same boundaries apply to OCB-101 and OCB-301 (the latter is
        # unsupported without re-generating OCB-301 beforehand).
        n_graphs = len(graphs)
        train_end = int(n_graphs * self._TRAIN_END)
        val_end = int(n_graphs * self._VAL_END)
        windows = [
            ("train", slice(None, train_end)),
            ("val", slice(train_end, val_end)),
            ("test", slice(val_end, None)),
        ]

        os.makedirs(self.processed_dir, exist_ok=True)
        for split_name, window in windows:
            split_graphs = graphs[window]
            split_labels = labels[window]

            if self.pre_filter is not None:
                # Filter once per split and drop graph and label together, so
                # that every remaining graph keeps its own label.
                kept = [idx for idx, item in enumerate(split_graphs) if self.pre_filter(item)]
                split_graphs = [split_graphs[idx] for idx in kept]
                split_labels = [split_labels[idx] for idx in kept]

            # idx 0: graph w/ circuit nodes, idx 1: nodes + gm pins
            for k_idx, key in enumerate(['no_pins', 'pins']):
                data_list = [
                    self.pretransform(item[k_idx], split_labels[idx])
                    for idx, item in enumerate(split_graphs)
                ]
                self.save(data_list, os.path.join(self.processed_dir, f'{split_name}_{key}.pt'))