import os
import numpy as np
import pandas as pd
import networkx as nx
import torch
from scipy.sparse import coo_matrix
from torch_geometric.data import HeteroData, Data
from torch_geometric.utils import to_undirected, add_self_loops


def preprocess_netlist_data(netlist_path):
    """
    Preprocess netlist data from ISPD2015 dataset.

    The file is read line by line and two record kinds are recognised (any
    other line is ignored):

    * ``CELL <cell_id> <x> <y> <width> <height>``
    * ``NET <net_id> <num_pins> <cell_id_1> ... <cell_id_num_pins>``

    A later record with the same id replaces an earlier one.

    Args:
        netlist_path (str): Path to the netlist file

    Returns:
        dict: Dictionary with
            - ``'cell_info'``: ``{cell_id: {'x', 'y', 'width', 'height'}}``
            - ``'net_connections'``: ``{net_id: [cell_id, ...]}``
            - ``'cell_features'``: float array of shape ``(max_cell_id + 1, 4)``
              whose row ``cell_id`` is ``[x, y, width, height]``. Rows for ids
              that never appear stay zero, and the shape is ``(0, 4)`` when
              there are no cells. For ids ``0..n-1`` this is exactly ``(n, 4)``.

    Raises:
        ValueError: If a cell id is negative (it cannot be used as a row index).
    """
    cell_info = {}
    net_connections = {}

    with open(netlist_path, 'r') as f:
        for line in f:
            if line.startswith('CELL'):
                parts = line.split()
                cell_id = int(parts[1])
                cell_info[cell_id] = {
                    'x': float(parts[2]),
                    'y': float(parts[3]),
                    'width': float(parts[4]),
                    'height': float(parts[5]),
                }

            elif line.startswith('NET'):
                parts = line.split()
                net_id = int(parts[1])
                num_pins = int(parts[2])
                net_connections[net_id] = [int(parts[3 + i]) for i in range(num_pins)]

    # Rows are indexed by cell id, so size the array by the largest id rather
    # than the number of cells: 1-based or sparse ids would otherwise overflow.
    if cell_info:
        if min(cell_info) < 0:
            raise ValueError("cell ids must be non-negative integers")
        num_rows = max(cell_info) + 1
    else:
        num_rows = 0

    cell_features = np.zeros((num_rows, 4))
    for cell_id, info in cell_info.items():
        cell_features[cell_id] = [info['x'], info['y'], info['width'], info['height']]

    return {
        'cell_info': cell_info,
        'net_connections': net_connections,
        'cell_features': cell_features
    }


def _bin_coordinates(values, bins):
    """Map 1-D coordinates to integer bin indices in ``[0, bins - 1]``.

    The extent ``[values.min(), values.max()]`` is split into ``bins``
    equal-width bins. When the extent has zero width (e.g. a single point)
    every value goes to bin 0 instead of dividing by zero.
    """
    if values.size == 0:
        return np.zeros(0, dtype=np.intp)
    lo = values.min()
    step = (values.max() - lo) / bins
    if step <= 0:
        return np.zeros(values.shape, dtype=np.intp)
    # Offsets from ``lo`` are non-negative, so the integer cast truncates like
    # floor; the maximum lands on index ``bins`` and is clipped to ``bins - 1``.
    return np.clip(((values - lo) / step).astype(np.intp), 0, bins - 1)


def preprocess_layout_data(layout_path, grid_size=(64, 64)):
    """
    Preprocess layout data from ISPD2015 dataset.

    Reads a CSV with columns ``x``, ``y``, ``RUDY``, ``PinRUDY``,
    ``MacroRegion``, ``H_MacroMargin``, ``V_MacroMargin``, ``H_Congestion``
    and ``V_Congestion``. Points are binned onto an ``M x N`` grid spanning the
    bounding box of their ``(x, y)`` coordinates.

    Args:
        layout_path (str): Path to the layout file
        grid_size (tuple): Size of the grid (M, N)

    Returns:
        dict: Dictionary with
            - ``'grid_features'``: array of shape ``(M * N, 7)``. Columns 0-4
              are the per-bin sums of the five feature columns, each divided by
              its channel maximum when that maximum is positive. Columns 5-6
              are the normalized bin coordinates ``(i / M, j / N)``.
            - ``'grid_congestion_gt'``: array of shape ``(M * N,)``. Each bin
              holds the mean of ``(H_Congestion + V_Congestion) / 2`` over its
              points, or 0 if the bin is empty.
            - ``'grid_size'``: ``(M, N)``.
        Bin ``(i, j)`` is stored at flat index ``i * N + j``.

    Raises:
        ValueError: If any ``x``/``y`` coordinate is not finite.
    """
    layout_data = pd.read_csv(layout_path)
    M, N = grid_size
    num_grids = M * N

    xs = layout_data['x'].to_numpy(dtype=float)
    ys = layout_data['y'].to_numpy(dtype=float)
    if not (np.isfinite(xs).all() and np.isfinite(ys).all()):
        raise ValueError("layout coordinates 'x' and 'y' must be finite")

    # Flat grid index i * N + j of the bin each layout point falls in.
    flat_idx = _bin_coordinates(xs, M) * N + _bin_coordinates(ys, N)

    # Sum per-point features into their bins; np.add.at accumulates repeated
    # indices (plain fancy-index assignment would keep only one of them).
    feature_columns = ['RUDY', 'PinRUDY', 'MacroRegion', 'H_MacroMargin', 'V_MacroMargin']
    grid_features = np.zeros((num_grids, len(feature_columns)))
    np.add.at(grid_features, flat_idx, layout_data[feature_columns].to_numpy(dtype=float))

    # Per-point congestion is the average of horizontal and vertical
    # congestion; a bin's ground truth is the mean over its points so the
    # result does not depend on the row order of the CSV.
    point_congestion = (
        layout_data['H_Congestion'].to_numpy(dtype=float)
        + layout_data['V_Congestion'].to_numpy(dtype=float)
    ) / 2
    congestion_sum = np.bincount(flat_idx, weights=point_congestion, minlength=num_grids)
    point_count = np.bincount(flat_idx, minlength=num_grids)
    congestion_gt_flat = np.zeros(num_grids)
    np.divide(congestion_sum, point_count, out=congestion_gt_flat, where=point_count > 0)

    # Normalize each feature channel by its maximum; channels whose maximum is
    # not positive are left unchanged.
    channel_max = grid_features.max(axis=0)
    positive = channel_max > 0
    grid_features[:, positive] /= channel_max[positive]

    # Normalized (i / M, j / N) coordinates of every grid node, in flat order.
    coord_i, coord_j = np.meshgrid(np.arange(M) / M, np.arange(N) / N, indexing='ij')
    grid_coords = np.stack([coord_i.ravel(), coord_j.ravel()], axis=1)

    return {
        'grid_features': np.concatenate([grid_features, grid_coords], axis=1),
        'grid_congestion_gt': congestion_gt_flat,
        'grid_size': (M, N)
    }


def construct_cell_hypergraph(netlist_data):
    """
    Construct cell-based hypergraph from netlist data using PyTorch Geometric.

    Cells and nets are both node types. Every pin (a cell listed in a net)
    becomes a ``('cell', 'to', 'net')`` edge plus its reverse
    ``('net', 'to', 'cell')`` edge.

    Args:
        netlist_data (dict): Netlist data from preprocess_netlist_data

    Returns:
        torch_geometric.data.HeteroData: Cell-based hypergraph with
            - ``data['cell'].x``: ``cell_features`` as float32, one node per row;
            - ``data['net'].x``: zeros of shape ``(num_nets, num_cell_features)``;
            - the two edge types above, omitted when there are no valid pins.
        Net node ``k`` is the net with the ``k``-th smallest id. Net ids
        ``0..num_nets-1`` therefore map to themselves, and any other id scheme
        still yields in-range edge indices.
    """
    net_connections = netlist_data['net_connections']
    cell_features = netlist_data['cell_features']

    # Create a HeteroData object
    data = HeteroData()

    # One cell node per feature row; presumably row index == cell id, as
    # preprocess_netlist_data builds it.
    num_cells = cell_features.shape[0]
    data['cell'].x = torch.as_tensor(cell_features, dtype=torch.float32)

    # Compact net ids to 0..num_nets-1 so edge indices stay inside the net
    # node range even when ids are sparse or 1-based.
    net_index = {net_id: idx for idx, net_id in enumerate(sorted(net_connections))}
    data['net'].x = torch.zeros(len(net_index), cell_features.shape[1])

    # Pins whose cell id has no feature row (negative or too large) are dropped.
    cell_to_net_edges = []
    for net_id, connected_cells in net_connections.items():
        net_node = net_index[net_id]
        for cell_id in connected_cells:
            if 0 <= cell_id < num_cells:
                cell_to_net_edges.append((cell_id, net_node))

    if cell_to_net_edges:
        edge_index = torch.tensor(cell_to_net_edges, dtype=torch.long).t().contiguous()
        data['cell', 'to', 'net'].edge_index = edge_index
        # Reverse edges (net -> cell): same pairs with source/target swapped.
        data['net', 'to', 'cell'].edge_index = edge_index.flip(0)

    return data


def construct_grid_hypergraph(layout_data, netlist_data, grid_size=(64, 64)):
    """
    Construct grid-based hypergraph from layout data and netlist data using PyTorch Geometric.

    Grid bins become ``'grid'`` nodes and nets become ``'net'`` nodes. Each net
    is linked, in both directions, to every grid bin that contains at least one
    of its cells. Each grid bin is also linked to its 4-neighbours through
    ``('grid', 'adjacent', 'grid')`` edges (each neighbour pair appears once in
    each direction).

    Args:
        layout_data (dict): Layout data from preprocess_layout_data
        netlist_data (dict): Netlist data from preprocess_netlist_data
        grid_size (tuple): Size of the grid (M, N). ``M * N`` must equal the
            number of rows in ``layout_data['grid_features']``.

    Returns:
        torch_geometric.data.HeteroData: Grid-based hypergraph. Net node ``k``
        is the net with the ``k``-th smallest id, so net ids ``0..n-1`` map to
        themselves. Grid bin ``(i, j)`` is node ``i * N + j``.

    Raises:
        ValueError: If ``grid_size`` does not match the grid feature rows.
    """
    cell_info = netlist_data['cell_info']
    net_connections = netlist_data['net_connections']
    grid_features = layout_data['grid_features']
    M, N = grid_size
    num_grids = M * N
    # A mismatched grid_size would silently produce edges to the wrong (or
    # non-existent) grid nodes.
    if grid_features.shape[0] != num_grids:
        raise ValueError(
            f"grid_size {grid_size} implies {num_grids} grid nodes, but "
            f"layout_data['grid_features'] has {grid_features.shape[0]} rows"
        )

    # Create a HeteroData object
    data = HeteroData()
    data['grid'].x = torch.as_tensor(grid_features, dtype=torch.float32)

    # Compact net ids to 0..num_nets-1 so edge indices stay inside the net node range.
    net_index = {net_id: idx for idx, net_id in enumerate(sorted(net_connections))}
    data['net'].x = torch.zeros(len(net_index), grid_features.shape[1])

    # Bin every cell once, using the bounding box of the cell positions.
    # NOTE(review): preprocess_layout_data bins by the extent of the layout CSV
    # instead; if the two extents differ, these bins are not aligned with the
    # grid features — confirm this is intended.
    cell_to_grid = {}
    if cell_info:
        xs = [info['x'] for info in cell_info.values()]
        ys = [info['y'] for info in cell_info.values()]
        x_min, y_min = min(xs), min(ys)
        x_step = (max(xs) - x_min) / M
        y_step = (max(ys) - y_min) / N
        for cell_id, info in cell_info.items():
            # A zero-width extent (e.g. a single cell) maps to index 0 instead
            # of dividing by zero.
            grid_i = int((info['x'] - x_min) / x_step) if x_step > 0 else 0
            grid_j = int((info['y'] - y_min) / y_step) if y_step > 0 else 0
            grid_i = min(M - 1, max(0, grid_i))
            grid_j = min(N - 1, max(0, grid_j))
            cell_to_grid[cell_id] = grid_i * N + grid_j

    # Connect each net to the distinct grid bins its known cells occupy.
    grid_to_net_edges = []
    for net_id, connected_cells in net_connections.items():
        touched = {cell_to_grid[c] for c in connected_cells if c in cell_to_grid}
        net_node = net_index[net_id]
        grid_to_net_edges.extend((grid_idx, net_node) for grid_idx in sorted(touched))

    if grid_to_net_edges:
        edge_index = torch.tensor(grid_to_net_edges, dtype=torch.long).t().contiguous()
        data['grid', 'to', 'net'].edge_index = edge_index
        # Reverse edges (net -> grid): same pairs with source/target swapped.
        data['net', 'to', 'grid'].edge_index = edge_index.flip(0)

    # Spatial 4-connectivity between neighbouring grid bins.
    grid_adjacency_edges = []
    for i in range(M):
        for j in range(N):
            grid_idx = i * N + j
            if i > 0:
                grid_adjacency_edges.append((grid_idx, (i - 1) * N + j))  # up
            if i < M - 1:
                grid_adjacency_edges.append((grid_idx, (i + 1) * N + j))  # down
            if j > 0:
                grid_adjacency_edges.append((grid_idx, i * N + (j - 1)))  # left
            if j < N - 1:
                grid_adjacency_edges.append((grid_idx, i * N + (j + 1)))  # right

    if grid_adjacency_edges:
        edge_index = torch.tensor(grid_adjacency_edges, dtype=torch.long).t().contiguous()
        data['grid', 'adjacent', 'grid'].edge_index = edge_index

    return data


def map_congestion_to_cells(netlist_data, layout_data, grid_size=(64, 64)):
    """
    Map grid-based congestion to cell-based congestion.

    Each cell is assigned the congestion of the grid bin containing its
    ``(x, y)`` position. Bins split the bounding box of all cell positions
    into ``M x N`` equal parts.

    Args:
        netlist_data (dict): Netlist data from preprocess_netlist_data
        layout_data (dict): Layout data from preprocess_layout_data
        grid_size (tuple): Size of the grid (M, N)

    Returns:
        np.ndarray: Cell-based congestion values of shape ``(max_cell_id + 1,)``,
        indexed by cell id. Ids that never appear stay 0, and the array is
        empty when there are no cells. For ids ``0..n-1`` the shape is ``(n,)``.

    Raises:
        ValueError: If a cell id is negative, or if ``grid_congestion_gt``
            cannot be reshaped to ``grid_size``.
    """
    cell_info = netlist_data['cell_info']
    M, N = grid_size
    grid_congestion = np.asarray(layout_data['grid_congestion_gt']).reshape(grid_size)

    if not cell_info:
        return np.zeros(0)
    if min(cell_info) < 0:
        raise ValueError("cell ids must be non-negative integers")

    # Extract layout dimensions from the cell positions
    xs = [info['x'] for info in cell_info.values()]
    ys = [info['y'] for info in cell_info.values()]
    x_min, y_min = min(xs), min(ys)
    x_step = (max(xs) - x_min) / M
    y_step = (max(ys) - y_min) / N

    # Sized by the largest id because the array is indexed by cell id.
    cell_congestion = np.zeros(max(cell_info) + 1)
    for cell_id, info in cell_info.items():
        # A zero-width extent (e.g. a single cell) maps to index 0 instead of
        # dividing by zero.
        grid_i = int((info['x'] - x_min) / x_step) if x_step > 0 else 0
        grid_j = int((info['y'] - y_min) / y_step) if y_step > 0 else 0
        grid_i = min(M - 1, max(0, grid_i))
        grid_j = min(N - 1, max(0, grid_j))
        cell_congestion[cell_id] = grid_congestion[grid_i, grid_j]

    return cell_congestion