import torch
from data.preprocess_gds import *
import os
from tqdm import tqdm
import argparse

def compute_bboxes(cell_tensor, mask_idx=None):
    """
    Axis-aligned bounding boxes for the rows of ``cell_tensor``.

    Args:
        cell_tensor: (N, 9) tensor. Column 0 holds the layer id; columns 1-8
            hold four corner points laid out as x1, y1, x2, y2, x3, y3, x4, y4.
        mask_idx: optional boolean mask over the rows; rows where it is True
            are dropped before the boxes are computed.

    Returns:
        Tuple ``(x_min, x_max, y_min, y_max)`` of 1-D tensors, one entry per
        kept row.
    """
    rows = cell_tensor if mask_idx is None else cell_tensor[~mask_idx]
    # (rows, 8) -> (rows, 4 corners, 2 coordinates)
    corners = rows[:, 1:].view(-1, 4, 2)
    xs = corners[:, :, 0]
    ys = corners[:, :, 1]
    return (
        xs.min(dim=1).values,
        xs.max(dim=1).values,
        ys.min(dim=1).values,
        ys.max(dim=1).values,
    )

def intersection_area(ax_min, ax_max, ay_min, ay_max,
                      bx_min, bx_max, by_min, by_max):
    """
    Intersection area of axis-aligned boxes A and B.

    All eight arguments are tensors and are combined with broadcasting, so
    passing A as column vectors and B as row vectors yields the full pairwise
    matrix. Disjoint boxes contribute 0.
    """
    overlap_w = (torch.minimum(ax_max, bx_max) - torch.maximum(ax_min, bx_min)).clamp(min=0)
    overlap_h = (torch.minimum(ay_max, by_max) - torch.maximum(ay_min, by_min)).clamp(min=0)
    return overlap_w * overlap_h

def area_from_bbox(x_min, x_max, y_min, y_max):
    """Area of each bounding box; inverted extents count as zero width/height."""
    width = (x_max - x_min).clamp(min=0)
    height = (y_max - y_min).clamp(min=0)
    return width * height


def check_rule1_torch(cells, mask_idx=None):
    """
    Rule 1: every layer-2 cell (GDS layer 1457) must overlap a small layer-0
    cell (GDS layer 515).

    Layer-0 cells whose bounding-box area is greater than 1 are not eligible
    overlap partners. A layer-2 cell violates the rule when its bounding box
    has zero intersection area with every eligible layer-0 cell.

    Args:
        cells: (N, 9) tensor; column 0 is the layer id, columns 1-8 are four
            (x, y) corner points.
        mask_idx: optional boolean mask over the rows; rows where it is True
            are excluded.

    Returns:
        Scalar float tensor: the fraction of layer-2 cells in violation.
        0 when there are no layer-2 cells; 1 when there are layer-2 cells but
        no eligible layer-0 cell for them to overlap.
    """
    layers = cells[:, 0] if mask_idx is None else cells[~mask_idx, 0]
    x_min, x_max, y_min, y_max = compute_bboxes(cells, mask_idx=mask_idx)
    areas = area_from_bbox(x_min, x_max, y_min, y_max)

    mask_1457 = layers == 2
    # Large layer-0 shapes are excluded as overlap partners.
    mask_515 = (layers == 0) & ~(areas > 1)

    if not mask_1457.any():
        return torch.tensor(0.0, device=layers.device)
    if not mask_515.any():
        # Nothing to overlap with: every layer-2 cell violates the rule.
        return torch.tensor(1.0, device=layers.device)

    # Pairwise (M, K) intersection areas: layer-2 boxes as rows, layer-0 as columns.
    inter_area = intersection_area(
        x_min[mask_1457][:, None], x_max[mask_1457][:, None],
        y_min[mask_1457][:, None], y_max[mask_1457][:, None],
        x_min[mask_515][None, :], x_max[mask_515][None, :],
        y_min[mask_515][None, :], y_max[mask_515][None, :],
    )
    # Best overlap each layer-2 cell achieves with any eligible layer-0 cell.
    max_overlap = inter_area.max(dim=1).values
    return (max_overlap == 0).sum() / max_overlap.numel()


def check_rule2_torch(cells, mask_idx=None):
    """
    Rule 2: overlap penalty for cells on layer 0 (GDS 515) and layer 1 (GDS 644).

    The bounding-box intersection area is summed over every unordered pair of
    distinct cells drawn from layers 0 and 1. This covers cross-layer pairs
    (0, 1) as well as same-layer pairs (0, 0) and (1, 1). That total is divided
    by the summed bounding-box area of all layer-0 and layer-1 cells and
    clipped to [0, 1].

    Args:
        cells: (N, 9) tensor; column 0 is the layer id, columns 1-8 are four
            (x, y) corner points.
        mask_idx: optional boolean mask over the rows; rows where it is True
            are excluded.

    Returns:
        Scalar float tensor in [0, 1]. It is 0 when either layer has no cells
        or when the combined area is zero.
    """
    layers = cells[:, 0] if mask_idx is None else cells[~mask_idx, 0]
    x_min, x_max, y_min, y_max = compute_bboxes(cells, mask_idx=mask_idx)
    areas = area_from_bbox(x_min, x_max, y_min, y_max)
    zero = torch.tensor(0.0, device=layers.device)

    mask_515 = layers == 0
    mask_644 = layers == 1
    # The rule is only evaluated when both layers are present.
    if not (mask_515.any() and mask_644.any()):
        return zero

    box_515 = (x_min[mask_515], x_max[mask_515], y_min[mask_515], y_max[mask_515])
    box_644 = (x_min[mask_644], x_max[mask_644], y_min[mask_644], y_max[mask_644])

    def pairwise(a, b):
        # (len(a), len(b)) matrix of intersection areas between two box groups.
        return intersection_area(*(t[:, None] for t in a), *(t[None, :] for t in b))

    inter_cross = pairwise(box_515, box_644)
    inter_515 = pairwise(box_515, box_515)
    inter_644 = pairwise(box_644, box_644)
    # A box fully overlaps itself; self-pairs are not violations.
    inter_515.fill_diagonal_(0.0)
    inter_644.fill_diagonal_(0.0)

    # Same-layer matrices are symmetric, so each unordered pair appears twice
    # and is halved. The cross-layer matrix lists each (515, 644) pair exactly
    # once and must not be halved.
    overlap_total = inter_cross.sum() + (inter_515.sum() + inter_644.sum()) / 2
    area_total = areas[mask_515].sum() + areas[mask_644].sum()
    if area_total <= 0:
        # All boxes are degenerate; avoid 0/0 = NaN.
        return zero
    return torch.clip(overlap_total / area_total, min=0, max=1)

def check_rule3_torch(cells, mask_idx=None, min_spacing=0.1, min_num=None):
    """
    Rule 3: minimum horizontal spacing between side-by-side layer-2
    (GDS 1457) cells.

    A pair of layer-2 cells is checked when the vertical center of either cell
    lies inside the vertical extent [y_min, y_max] of the other. A checked pair
    violates the rule when the horizontal distance between the two cell
    centers is smaller than ``min_spacing``.

    Args:
        cells: (N, 9) tensor; column 0 is the layer id, columns 1-8 are four
            (x, y) corner points.
        mask_idx: optional boolean mask over the rows; rows where it is True
            are excluded.
        min_spacing: required center-to-center horizontal distance.
        min_num: optional int. When there are fewer than ``min_num`` layer-2
            cells, the layout receives the full penalty of 1.

    Returns:
        Scalar float tensor: the fraction of checked pairs in violation. It is
        0 when fewer than two layer-2 cells exist or no pair is checked.
    """
    layers = cells[:, 0] if mask_idx is None else cells[~mask_idx, 0]
    x_min, x_max, y_min, y_max = compute_bboxes(cells, mask_idx=mask_idx)
    device = layers.device
    mask_1457 = layers == 2
    count = int(mask_1457.sum())

    # Apply the minimum-count requirement first, so that layouts with 0 or 1
    # layer-2 cells are penalized too.
    if isinstance(min_num, int) and count < min_num:
        return torch.tensor(1.0, device=device)
    if count < 2:
        return torch.tensor(0.0, device=device)

    centers_x = ((x_min + x_max) / 2.0)[mask_1457]
    centers_y = ((y_min + y_max) / 2.0)[mask_1457]
    y_lo = y_min[mask_1457]
    y_hi = y_max[mask_1457]

    # [i, j]: center of i lies in j's vertical span / center of j lies in i's.
    i_in_j = (centers_y[:, None] >= y_lo[None, :]) & (centers_y[:, None] <= y_hi[None, :])
    j_in_i = (centers_y[None, :] >= y_lo[:, None]) & (centers_y[None, :] <= y_hi[:, None])
    # Strict upper triangle (j > i) so each unordered pair is counted once.
    # It is built on the input's device to avoid a CPU/GPU mismatch.
    upper = torch.triu(torch.ones(count, count, device=device), diagonal=1).bool()
    valid = (i_in_j | j_in_i) & upper
    if not valid.any():
        return torch.tensor(0.0, device=device)

    gap_x = torch.abs(centers_x[:, None] - centers_x[None, :])
    violations = gap_x[valid] < min_spacing
    return violations.sum() / violations.numel()


def check_rule4_torch(cells, mask_idx=None, min_spacing=0.2, min_num=None):
    """
    Rule 4: minimum vertical spacing between stacked layer-2 (GDS 1457) cells.

    A pair of layer-2 cells is checked when the horizontal center of either
    cell lies inside the horizontal extent [x_min, x_max] of the other. A
    checked pair violates the rule when the vertical distance between the two
    cell centers is smaller than ``min_spacing``.

    Args:
        cells: (N, 9) tensor; column 0 is the layer id, columns 1-8 are four
            (x, y) corner points.
        mask_idx: optional boolean mask over the rows; rows where it is True
            are excluded.
        min_spacing: required center-to-center vertical distance.
        min_num: optional int. When there are fewer than ``min_num`` layer-2
            cells, the layout receives the full penalty of 1.

    Returns:
        Scalar float tensor: the fraction of checked pairs in violation (not a
        graded penalty). It is 0 when fewer than two layer-2 cells exist or no
        pair is checked.
    """
    layers = cells[:, 0] if mask_idx is None else cells[~mask_idx, 0]
    x_min, x_max, y_min, y_max = compute_bboxes(cells, mask_idx=mask_idx)
    device = layers.device
    mask_1457 = layers == 2
    count = int(mask_1457.sum())

    # Apply the minimum-count requirement first, so that layouts with 0 or 1
    # layer-2 cells are penalized too.
    if isinstance(min_num, int) and count < min_num:
        return torch.tensor(1.0, device=device)
    if count < 2:
        return torch.tensor(0.0, device=device)

    centers_x = ((x_min + x_max) / 2.0)[mask_1457]
    centers_y = ((y_min + y_max) / 2.0)[mask_1457]
    x_lo = x_min[mask_1457]
    x_hi = x_max[mask_1457]

    # [i, j]: center of i lies in j's horizontal span / center of j lies in i's.
    i_in_j = (centers_x[:, None] >= x_lo[None, :]) & (centers_x[:, None] <= x_hi[None, :])
    j_in_i = (centers_x[None, :] >= x_lo[:, None]) & (centers_x[None, :] <= x_hi[:, None])
    # Strict upper triangle (j > i) so each unordered pair is counted once.
    # It is built on the input's device to avoid a CPU/GPU mismatch.
    upper = torch.triu(torch.ones(count, count, device=device), diagonal=1).bool()
    valid = (i_in_j | j_in_i) & upper
    if not valid.any():
        return torch.tensor(0.0, device=device)

    gap_y = torch.abs(centers_y[:, None] - centers_y[None, :])
    violations = gap_y[valid] < min_spacing
    return violations.sum() / violations.numel()

def parse_arguments():
    """Parse command-line options; ``--folder`` (required) names the input directory."""
    cli = argparse.ArgumentParser(prog="Evaluate DRC")
    cli.add_argument(
        "--folder",
        type=str,
        required=True,
        help="Path to the folder containing sampled .txt files",
    )
    parsed = cli.parse_args()
    return parsed

def main():
    """
    Run DRC rules 1-4 on every ``.txt`` layout in ``--folder`` and print the
    mean score of each rule across all files.

    Exits with status 1 when the folder is missing or contains no ``.txt`` files.
    """
    import sys

    args = parse_arguments()

    if not os.path.isdir(args.folder):
        print(f"Error: {args.folder} is not a valid directory.", file=sys.stderr)
        sys.exit(1)

    print(f"Processing files in {args.folder}...")

    txt_files = sorted(f for f in os.listdir(args.folder) if f.endswith('.txt'))
    if not txt_files:
        print("No .txt files found in the specified folder.", file=sys.stderr)
        sys.exit(1)

    print(f"Found {len(txt_files)} .txt files.")

    # Loop-invariant settings, hoisted out of the per-file loop.
    resolution = 40000
    scale = resolution * 0.025

    drc_array = []
    for file in tqdm(txt_files):
        file_path = os.path.join(args.folder, file)
        _, cells = parse_layout_file(file_path=file_path)
        cell_tensor = cells_to_tensor(cells)
        normalized_tensor = normalize_data(cell_tensor, resolution=resolution)
        cell_tensor_generated, mask_generated = prepare_batch(normalized_tensor, max_len=600)

        scores = (
            check_rule1_torch(cell_tensor_generated, mask_generated),
            check_rule2_torch(cell_tensor_generated, mask_generated),
            check_rule3_torch(cell_tensor_generated, mask_generated,
                              min_spacing=60 / scale, min_num=0),
            check_rule4_torch(cell_tensor_generated, mask_generated,
                              min_spacing=50 / scale, min_num=0),
        )
        # Plain floats, so scores produced on any device can be combined below.
        drc_array.append([float(s) for s in scores])

    value = torch.tensor(drc_array)
    print("DRC checks:")
    for rule_no, mean in enumerate(value.mean(dim=0).tolist(), start=1):
        print('Rule %d: %.3f' % (rule_no, mean))


if __name__ == "__main__":
    main()
