import numpy as np
from tqdm import tqdm
from copy import deepcopy
import math
import os
import torch
import pandas as pd


def parse_coordinate(coord_str):
    """
    Parse a coordinate written as 'X: Y' into a pair of integers.

    Whitespace around either number is ignored.

    Args:
        coord_str (str): Text expected to hold exactly one ':' separating
            two integer literals.

    Returns:
        (x, y) tuple of ints, or None when the text has no ':', has more
        than one ':', or either side is not an integer.
    """
    if ":" not in coord_str:
        return None

    pieces = coord_str.split(":")
    if len(pieces) != 2:
        return None

    try:
        return tuple(int(piece.strip()) for piece in pieces)
    except ValueError:
        # At least one side was not an integer literal.
        return None

def _int_or_text(token):
    """Return *token* converted to int when possible, otherwise unchanged."""
    try:
        return int(token)
    except ValueError:
        return token


def parse_layout_file(file_path):
    """
    Read a layout text file and separate it into header lines and cells.

    A cell is the group of lines between a "BOUNDARY" line and the next
    "ENDEL" line. Inside a cell the following records are recognised:
      - "LAYER <value>"    -> stored under "Layer" (int when it parses as one)
      - "DATATYPE <value>" -> stored under "Datatype" (int when it parses as one)
      - "XY [X: Y]"        -> starts the coordinate list; the first pair may
                              follow the keyword on the same line
      - "X: Y"             -> further coordinate pairs, accepted only after
                              the XY record
    Any other line inside a cell is ignored. Every non-blank line outside a
    cell is collected as a header line (this includes trailing records such
    as "ENDSTR"/"ENDLIB" that appear after the last cell).

    Args:
        file_path (str): Path of the layout text file to read.

    Returns:
        tuple: (header, cells) where header is a list of stripped lines and
        cells is a list of dicts with optional keys "Layer", "Datatype" and
        "XY" (a list of (x, y) int tuples).
    """
    header = []
    cells = []
    current_cell = {}
    in_cell = False
    # Must exist before the first cell: an unrecognised record can appear
    # inside a cell before any XY line has been seen.
    xy_started = False

    with open(file_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # Skip empty lines

            if line == "BOUNDARY":
                # Start of a cell block: begin with a clean state.
                in_cell = True
                xy_started = False
                current_cell = {}
            elif line == "ENDEL" and in_cell:
                # End of a cell block.
                cells.append(current_cell)
                in_cell = False
                xy_started = False
            elif not in_cell:
                header.append(line)
            elif line.startswith("LAYER"):
                parts = line.split()
                if len(parts) >= 2:
                    current_cell["Layer"] = _int_or_text(parts[1])
            elif line.startswith("DATATYPE"):
                parts = line.split()
                if len(parts) >= 2:
                    current_cell["Datatype"] = _int_or_text(parts[1])
            elif line.startswith("XY"):
                # The XY line may carry the first coordinate pair after the keyword.
                parts = line.split(maxsplit=1)
                if len(parts) > 1:
                    coord = parse_coordinate(parts[1])
                    if coord:
                        current_cell["XY"] = [coord]
                xy_started = True
            elif xy_started:
                # Additional coordinate lines. setdefault covers the case
                # where the XY line itself carried no (valid) coordinate.
                coord = parse_coordinate(line)
                if coord:
                    current_cell.setdefault("XY", []).append(coord)

    return header, cells


def cells_to_tensor(cells):
    """
    Pack four-cornered cells into a float32 tensor with 9 values per row.

    Each row is [Layer, x0, y0, x1, y1, x2, y2, x3, y3]. A trailing
    coordinate equal to the first one (closed polygon) is dropped before
    packing. Cells without any "XY" coordinates are skipped.

    Args:
        cells (list of dict): Cell dictionaries with a "Layer" entry and an
            "XY" list of (x, y) pairs.

    Returns:
        torch.Tensor: float32 tensor of shape (N, 9), where N is the number
        of cells that had coordinates. When N is 0 the result is an empty
        tensor of shape (0, 9), so column indexing still works downstream.

    Raises:
        ValueError: If a cell does not have exactly 4 coordinates after
            removing the closing point.
        KeyError: If a cell with coordinates has no "Layer" entry.
    """
    tensor_data = []

    for cell in cells:
        xy = cell.get("XY", [])
        if not xy:
            # Skip before touching "Layer": a cell that is dropped anyway
            # should not be able to fail on a missing key.
            continue
        # Remove the repeated last coordinate if it is the same as the first.
        if xy[0] == xy[-1]:
            xy = xy[:-1]
        if len(xy) != 4:
            raise ValueError(f"Expected 4 unique coordinates per cell, got {len(xy)} in cell: {cell}")
        cell_data = [cell["Layer"]]
        for coord in xy:
            cell_data.extend([coord[0], coord[1]])
        tensor_data.append(cell_data)

    if not tensor_data:
        # torch.tensor([]) would be 1-D with shape (0,); keep the 2-D layout.
        return torch.empty((0, 9), dtype=torch.float32)

    return torch.tensor(tensor_data, dtype=torch.float32)



def write_layout_file(file_path, header, cells):
    """
    Serialise header lines and cells into a layout text file.

    Header lines are written one per line, with two exceptions: a line
    containing "BGNSTR" gets a blank line in front of it, and lines
    containing "ENDLIB" or "ENDSTR" are left out because the closing
    "ENDSTR" / "ENDLIB" records are always written after the cells.

    Every cell becomes a "BOUNDARY" ... "ENDEL" group followed by a blank
    line. "LAYER" and "DATATYPE" records are written only when the cell has
    the corresponding key; coordinates are written as "X: Y", the first one
    prefixed with "XY ".

    Args:
        file_path (str): Path to write the output file.
        header (list of str): Header lines.
        cells (list of dict): Cell definitions with optional keys:
            - "Layer": The layer number.
            - "Datatype": The datatype.
            - "XY": A list of (x, y) coordinate tuples.
    """
    record_keys = (("Layer", "LAYER"), ("Datatype", "DATATYPE"))

    with open(file_path, "w") as out:
        for header_line in header:
            if 'BGNSTR' in header_line:
                out.write("\n" + header_line + "\n")
                continue
            if "ENDLIB" in header_line or "ENDSTR" in header_line:
                # Closing records are emitted once, after all cells.
                continue
            out.write(header_line + "\n")
        out.write("\n")  # Blank line separating the header from the cells

        for cell in cells:
            out.write("BOUNDARY\n")
            for key, keyword in record_keys:
                if key in cell:
                    out.write(f"{keyword} {cell[key]}\n")
            points = cell.get("XY")
            if points:
                for index, point in enumerate(points):
                    # Only the first coordinate carries the "XY" keyword.
                    prefix = "XY " if index == 0 else ""
                    out.write(f"{prefix}{point[0]}: {point[1]}\n")
            out.write("ENDEL\n\n")

        out.write("ENDSTR\n")
        out.write("ENDLIB")


def _grid_spans(lo, hi, resolution):
    """Cut the interval [lo, hi] at multiples of resolution into (start, end) spans."""
    # int() lets a float resolution work: floor division then yields floats,
    # which range() does not accept.
    first_tile = int(lo // resolution)
    last_tile = int(hi // resolution)
    spans = []
    for tile in range(first_tile, last_tile + 1):
        start = max(tile * resolution, lo)
        end = min((tile + 1) * resolution, hi)
        # When hi sits exactly on a grid line the last tile contributes an
        # empty span; drop it. An interval that is empty to begin with
        # (lo == hi) is kept so the caller still gets one piece back.
        if start < end or lo == hi:
            spans.append((start, end))
    return spans


def split_cell(cell, resolution):
    """
    Splits a rectangular cell along the grid lines spaced `resolution` apart.

    The cell is replaced by the pieces of its bounding box that fall into
    each grid tile [k * resolution, (k + 1) * resolution], so the cell is
    assumed to be rectangular (i.e. its bounding box is the cell). No piece
    is wider or taller than `resolution`, and no zero-area piece is produced
    for a cell whose edge lies exactly on a grid line.

    Args:
        cell (dict): A dictionary with keys "Layer", "Datatype", and "XY" (a list of (x, y) tuples).
        resolution (int or float): Grid spacing, i.e. the maximum allowed
            width/height for each subcell.

    Returns:
        List[dict]: A list of new cell dictionaries representing subcells,
        ordered by x tile and then by y tile. Each subcell is a closed
        rectangle (first coordinate repeated at the end) carrying the
        original "Layer" and "Datatype". A cell without coordinates is
        returned unchanged as a single-element list.
    """
    coords = cell.get("XY", [])
    if not coords:
        return [cell]

    # Compute the bounding box of the cell.
    xs = [pt[0] for pt in coords]
    ys = [pt[1] for pt in coords]
    x_spans = _grid_spans(min(xs), max(xs), resolution)
    y_spans = _grid_spans(min(ys), max(ys), resolution)

    layer = cell.get("Layer")
    datatype = cell.get("Datatype")

    new_cells = []
    for sub_min_x, sub_max_x in x_spans:
        for sub_min_y, sub_max_y in y_spans:
            # Build a rectangular polygon (closed polygon: first coordinate repeated at end)
            new_cells.append({
                "Layer": layer,
                "Datatype": datatype,
                "XY": [
                    (sub_min_x, sub_min_y),
                    (sub_min_x, sub_max_y),
                    (sub_max_x, sub_max_y),
                    (sub_max_x, sub_min_y),
                    (sub_min_x, sub_min_y),
                ],
            })
    return new_cells


def split_cells(cells, resolution):
    """
    Split every cell that crosses a grid line spaced `resolution` apart.

    For each cell the grid tile index (coordinate // resolution) of its
    smallest and largest x and y values is compared. When the largest value
    falls in a later tile than the smallest along either axis, the cell is
    handed to split_cell; otherwise it is kept as is. Cells without "XY"
    coordinates are dropped from the result. Progress is reported via tqdm.

    Args:
        cells (list of dict): List of cell dictionaries.
        resolution (int or float): Grid spacing used to decide on and
            perform the split.

    Returns:
        list of dict: A new list of cell dictionaries (some may have been split).
    """
    result = []
    for cell in tqdm(cells):
        points = cell.get("XY", [])
        if not points:
            continue

        x_values = [point[0] for point in points]
        y_values = [point[1] for point in points]
        crosses_x = max(x_values) // resolution - min(x_values) // resolution > 0
        crosses_y = max(y_values) // resolution - min(y_values) // resolution > 0

        pieces = split_cell(cell, resolution) if (crosses_x or crosses_y) else [cell]
        result.extend(pieces)
    return result


def tensor_to_cells(tensor):
    """
    Convert a 2D PyTorch tensor (of shape [N, 9]) back to a list of cell dictionaries.

    Each row is read as:
      - value 0: the layer code, truncated to an int
      - values 1..8: four points (x0, y0, x1, y1, x2, y2, x3, y3), each
        truncated to an int

    Layer codes accepted by this function (class index or layer number):
      0 or 515   -> Layer 515
      1 or 644   -> Layer 644
      2 or 1457  -> Layer 1457
    Any other layer code stops the conversion: that row and all rows after
    it are left out of the result.

    Each produced cell has the form:
      {
          "Layer": <layer>,
          "Datatype": 100,
          "XY": [(x0,y0), (x1,y1), (x2,y2), (x3,y3), (x0,y0)]
      }

    Args:
        tensor (torch.Tensor): A 2D tensor of shape (N, 9) with cell data.

    Returns:
        List[dict]: A list of cell dictionaries.

    Raises:
        ValueError: If a row does not hold exactly 9 values.
    """
    layer_by_code = {0: 515, 515: 515, 1: 644, 644: 644, 2: 1457, 1457: 1457}

    cells = []
    for row in tensor.tolist():
        if len(row) != 9:
            raise ValueError(f"Expected 9 values per cell, got {len(row)} in row {row}")

        layer = layer_by_code.get(int(row[0]))
        if layer is None:
            # Unknown layer code: stop here and ignore the remaining rows.
            break

        coords_flat = row[1:]
        # Defensive check; it cannot trigger once the 9-value check passed.
        if len(coords_flat) != 8:
            raise ValueError(f"Expected 8 coordinate values, got {len(coords_flat)}")

        values = [int(value) for value in coords_flat]
        # Pair up (x, y) values, then repeat the first point to close the polygon.
        xy = list(zip(values[0::2], values[1::2]))
        xy.append(xy[0])

        cells.append({"Layer": layer, "Datatype": 100, "XY": xy})
    return cells

def normalize_data(cell_tensor, resolution):
    """
    Return a normalised float32 copy of a cell tensor.

    Column 0 (the layer) is remapped 515 -> 0, 644 -> 1, 1457 -> 2; any other
    layer value is left unchanged. All remaining columns are scaled with
    value / (resolution / 2) - 1, which maps 0 to -1 and `resolution` to 1.
    The input is not modified.

    Args:
        cell_tensor (torch.Tensor): 2D tensor whose first column holds layer
            numbers and whose other columns hold coordinates. Any numeric
            dtype is accepted; array-likes that torch.as_tensor understands
            work as well.
        resolution (int or float): Coordinate value that should map to 1.

    Returns:
        torch.Tensor: A new float32 tensor with the same shape as the input.
    """
    source = torch.as_tensor(cell_tensor)
    # copy=True guarantees a fresh tensor even when the input is already
    # float32, so the caller's data is never written to.
    normalized_tensor = source.to(dtype=torch.float32, copy=True)

    # Masks come from the original layer values, so a value written for one
    # layer can never be picked up again by a later comparison.
    layers = source[:, 0]
    for class_index, layer_number in enumerate((515, 644, 1457)):
        normalized_tensor[layers == layer_number, 0] = class_index

    normalized_tensor[:, 1:] = (normalized_tensor[:, 1:] / (resolution / 2)) - 1
    return normalized_tensor

def prepare_batch(cell_tensor, max_len):
    """
    Pad a cell tensor with zero rows up to `max_len` rows and build a padding mask.

    Padding rows have 9 columns, so `cell_tensor` is expected to be 9 columns
    wide and to hold at most `max_len` rows.

    Args:
        cell_tensor (torch.Tensor): Tensor of cell rows to pad.
        max_len (int): Number of rows in the padded result.

    Returns:
        tuple: (pad_tensor, mask_idx) where pad_tensor is `cell_tensor`
        followed by the zero rows, and mask_idx is a boolean tensor of
        length `max_len` that is True for padding rows and False for the
        original rows.
    """
    real_count = len(cell_tensor)
    pad_count = max_len - real_count

    # 1 marks an original row, 0 marks a padding row.
    real_flags = torch.ones(real_count, dtype=torch.long)
    pad_flags = torch.zeros(pad_count, dtype=torch.long)
    zero_rows = torch.zeros((pad_count, 9), dtype=torch.float)

    pad_tensor = torch.cat((cell_tensor, zero_rows), 0)
    presence = torch.cat((real_flags, pad_flags), 0)
    mask_idx = presence == 0

    return pad_tensor, mask_idx
