from pathlib import Path
import pandas as pd 
import torch
from torch_geometric.data import Data 
from data.base import BaseDataset
from data.preprocess_gds import parse_layout_file, cells_to_tensor
import os
from tqdm import tqdm
import multiprocessing as mp
import numpy as np
import gc

def process_single_layout(args):
    """Convert one GDS layout file into a ``torch_geometric`` ``Data`` object.

    Kept for backward compatibility with existing callers. This function used
    to repeat, one cell at a time in a Python loop, exactly the computation
    done by ``process_single_layout_optimized``:
      * the same label remap;
      * the same ``(xc, yc, width, height)`` boxes divided by 40000 and
        clipped to ``[0, 1]``;
      * the same ``attr`` dict and the same ``None``-on-failure contract.

    It now delegates to that vectorized implementation, so the two entry
    points cannot drift apart. This also avoids the old loop's legacy
    ``torch.FloatTensor(tensor)`` call, which only accepted float32 input.

    Args:
        args: ``(file_path, data_path, cell_length)`` tuple, forwarded as-is.

    Returns:
        Whatever ``process_single_layout_optimized(args)`` returns: a ``Data``
        object, or ``None`` if the file is missing, empty, or fails to process.
    """
    return process_single_layout_optimized(args)
    
def process_single_layout_optimized(args):
    """Convert one GDS layout file into a ``torch_geometric`` ``Data`` object.

    Every cell becomes one row of ``x`` holding its axis-aligned bounding box.
    The boxes for all cells are computed at once with tensor operations.

    Args:
        args: ``(file_path, data_path, cell_length)`` packed as one tuple, so
            the function can be handed to map-style APIs.
            ``file_path`` is joined onto ``os.path.join(data_path, '..')``.
            ``cell_length`` is stored unchanged in ``data.attr``.

    Returns:
        A ``Data`` object with these fields:
          * ``x``: float32 tensor ``[N, 4]`` of ``(xc, yc, width, height)``,
            each divided by 40000 and clamped to ``[0, 1]``.
          * ``y``: int64 tensor ``[N]`` of labels. Raw layout ID 515 -> 0,
            644 -> 1, 1457 -> 2; any other ID is kept unchanged.
          * ``attr``: ``{'name': <file name>, 'cell_length': cell_length}``.

        Returns ``None`` when the file does not exist, yields no cells, or
        any step raises. Errors are printed rather than propagated, so a
        caller looping over many files keeps going.
    """
    file_path, data_path, cell_length = args

    try:
        abs_file_path = os.path.join(data_path, '..', file_path)
        if not os.path.exists(abs_file_path):
            return None

        _, cells = parse_layout_file(abs_file_path)
        if not cells:
            return None

        cell_tensor = cells_to_tensor(cells)
        num_cells = len(cell_tensor)
        if num_cells == 0:
            return None

        # Remap raw layout IDs to class indices. Every mask is taken from the
        # untouched ID column, so the result does not depend on the order of
        # the assignments. IDs not listed here are left unchanged.
        raw_ids = cell_tensor[:, 0]
        labels = raw_ids.clone()
        for raw_id, class_idx in ((515, 0), (644, 1), (1457, 2)):
            labels[raw_ids == raw_id] = class_idx

        # Columns 1..8 hold four corner points (x1, y1, ..., x4, y4).
        # Use an explicit row count rather than -1. With -1, a tensor with
        # the wrong number of coordinate columns could be silently regrouped
        # into a different number of "cells" than there are labels. With an
        # explicit count it raises instead.
        coords = cell_tensor[:, 1:9].reshape(num_cells, 4, 2)  # [N, 4, 2]

        # Axis-aligned bounding box of each cell's corners.
        min_coords = coords.min(dim=1).values  # [N, 2]
        max_coords = coords.max(dim=1).values  # [N, 2]
        centers = (min_coords + max_coords) / 2  # [N, 2] -> (xc, yc)
        sizes = max_coords - min_coords  # [N, 2] -> (width, height)

        # Scale to [0, 1]. Values outside [0, resolution] are clamped, not
        # rejected.
        resolution = 40000
        centers_norm = torch.clamp(centers / resolution, 0, 1)
        sizes_norm = torch.clamp(sizes / resolution, 0, 1)
        boxes = torch.cat([centers_norm, sizes_norm], dim=1)  # [N, 4]

        data = Data(x=boxes.float(), y=labels.long())
        data.attr = {
            'name': Path(file_path).name,
            'cell_length': cell_length,
        }
        return data

    except Exception as e:
        print(f"Error processing {file_path}: {str(e)}")
        return None

class GDSLayout(BaseDataset):
    """Dataset of GDS layout graphs registered under the name ``'gds'``.

    Processing is not done here: ``process`` raises, so the data must be
    generated externally and pre-saved under ``data_path``. It is presumably
    generated by the ``process_single_layout*`` helpers in this module —
    TODO confirm.
    """

    # Class names, indexed by the label ids that the processing helpers in
    # this module write into ``Data.y``: raw layout ID 515 -> 0, 644 -> 1,
    # 1457 -> 2.
    labels = [
        'layout_515',       # 0
        'layout_644',       # 1
        'layout_1457'       # 2
    ]

    def __init__(self, split='train', transform=None, data_path=None):
        """Create the dataset for ``split``.

        Args:
            split: Dataset split name, forwarded to ``BaseDataset``.
            transform: Optional transform, forwarded to ``BaseDataset``.
            data_path: Location of the pre-saved data. It is stored on the
                instance and also passed to ``BaseDataset``.
        """
        # Set before calling the base initializer. Presumably BaseDataset may
        # read ``self.data_path`` during its own init — verify.
        self.data_path = data_path
        super().__init__('gds', split, transform, data_path=self.data_path)

    def download(self):
        """No-op: the data is expected to be present locally already."""
        pass

    def process(self):
        """Always raise: processing happens outside this class.

        Raises:
            ValueError: unconditionally.
        """
        raise ValueError("Processing should be done externally and data should be pre-saved.")

