# -*- coding: utf-8 -*-

import json
import re
import argparse
from typing import List, Dict, Any, Optional

# ============================================================================ #
# Section 1: Core Parsing and Pin Generation Logic
# ============================================================================ #

class PredictionParser:
    """Converts raw model predictions into structured IC pin layouts.

    Each prediction is expected to carry its result inside an
    ``<answer>...</answer>`` tag. The answer is either a list of named
    numeric parameters (``name = value``) describing a regular layout, or an
    explicit "pin count / pin dimensions / pin coordinates" listing.
    Predictions are matched to images purely by position: the i-th
    prediction belongs to the i-th entry of the image order file.
    """

    # Patterns are compiled once here instead of on every call.
    _ANSWER_RE = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
    _FINAL_PARAMS_RE = re.compile(r'Final\s+parameters\s*:([^<]+)$', re.IGNORECASE | re.DOTALL)
    _KEY_VALUE_RE = re.compile(r'([\w\s-]+?)\s*=\s*(-?\d+(?:\.\d+)?)')
    # Each section stops at the other section's header (or at the end of the
    # text) so that one section never swallows the entries of the other.
    _COORDS_SECTION_RE = re.compile(
        r'pin coordinates[:\s]*(.*?)(?=pin dimensions|\Z)', re.IGNORECASE | re.DOTALL
    )
    _DIMS_SECTION_RE = re.compile(
        r'pin dimensions[:\s]*(.*?)(?=pin coordinates|\Z)', re.IGNORECASE | re.DOTALL
    )
    # "<index>, <x>, <y>"
    _COORD_ENTRY_RE = re.compile(r'[^,;]+?,\s*(-?\d+\.?\d*)\s*,\s*(-?\d+\.?\d*)')
    # "<index>, <dx>[, <dy>]" -- the second dimension is optional.
    _DIM_ENTRY_RE = re.compile(r'[^,;]+?,\s*(-?\d+\.?\d*)\s*(?:,\s*(-?\d+\.?\d*))?')

    def __init__(self, image_order_file: str):
        """
        Initializes the parser with the image order.

        Args:
            image_order_file (str): Path to a JSON file containing a list of
                                    items, where each item has an 'images' key
                                    holding a non-empty list; only the first
                                    entry of that list is used.

        Raises:
            FileNotFoundError, json.JSONDecodeError, KeyError, IndexError,
            TypeError: Re-raised after printing a message when the file is
                missing, is not valid JSON, or an item lacks a usable
                'images' entry.
        """
        try:
            with open(image_order_file, 'r', encoding='utf-8') as f:
                order_data = json.load(f)
            self.image_paths = [item['images'][0] for item in order_data]
            print(f"Successfully loaded {len(self.image_paths)} image paths.")
        except (FileNotFoundError, json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
            # IndexError: empty 'images' list; TypeError: item is not a mapping.
            print(f"Error: Could not load or parse image order file '{image_order_file}': {e}")
            raise

    def process_file(self, prediction_file_path: str) -> List[Dict[str, Any]]:
        """Parses every prediction in a .jsonl file into the output format.

        Args:
            prediction_file_path: Path to the .jsonl file; each record is
                expected to be an object with a "predict" string.

        Returns:
            One formatted result per prediction that could be parsed into at
            least one pin. Failed predictions are reported via a debug dump
            and skipped. Returns an empty list if the file cannot be loaded.
        """
        raw_predictions = self._load_predictions(prediction_file_path)
        if not raw_predictions:
            return []

        print(f"\nProcessing {len(raw_predictions)} raw predictions...")
        processed_results = []

        for i, pred_item in enumerate(raw_predictions):
            if i >= len(self.image_paths):
                print(f"Warning: More predictions ({len(raw_predictions)}) than image paths ({len(self.image_paths)}). Stopping.")
                break

            image_path = self.image_paths[i]
            # A record that is not an object, or whose "predict" is null or
            # non-text, is treated as an unparsable prediction rather than
            # aborting the whole run.
            gpt_response = pred_item.get("predict", "") if isinstance(pred_item, dict) else ""
            if not isinstance(gpt_response, str):
                gpt_response = ""

            answer_match = self._ANSWER_RE.search(gpt_response)
            if not answer_match:
                self._debug_dump("PARSE_FAILURE", i, image_path, gpt_response)
                continue

            pins = self._parse_answer_content(answer_match.group(1).strip())
            if pins:
                processed_results.append(self._format_conversations(pins, image_path))
            else:
                self._debug_dump("PROCESS_FAILURE", i, image_path, gpt_response)

        print(f"\nSuccessfully processed {len(processed_results)} predictions.")
        return processed_results

    def _load_predictions(self, file_path: str) -> List[Dict[str, Any]]:
        """Loads raw predictions from a .jsonl file.

        Blank lines are skipped. If the file is missing or any line is not
        valid JSON, an error is printed and an empty list is returned.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return [json.loads(line) for line in f if line.strip()]
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error: Could not load prediction file '{file_path}': {e}")
            return []

    def _parse_answer_content(self, content: str) -> Optional[List[Dict]]:
        """Turns the text inside an <answer> tag into a list of pins.

        Content mentioning "pin count" is parsed as an explicit pin listing.
        Otherwise named parameters are extracted, a few aliases and derived
        values are filled in, and the first layout whose required parameters
        are all present is generated (grid, then 4-sides, then 2-sides).

        Returns:
            The generated pins, or None when no layout can be determined.
        """
        if "pin count" in content.lower():
            return self._parse_as_other_type(content)

        params = self._extract_final_parameters(content)
        if not params:
            return None

        # Aliases: accept the singular spelling and the un-numbered dx/dy.
        if 'pin_per_side' in params:
            params['pins_per_side'] = params['pin_per_side']
        if 'dx' in params and 'dx1' not in params:
            params['dx1'] = params['dx']
        if 'dy' in params and 'dy1' not in params:
            params['dy1'] = params['dy']

        # Derived values: convert inner/outer distances using the pin size.
        if 'side_to_side_distance' not in params and 'dx1' in params:
            if 'inner_side_distance' in params:
                params['side_to_side_distance'] = params['inner_side_distance'] + params['dx1']
            elif 'outer_side_distance' in params:
                params['side_to_side_distance'] = params['outer_side_distance'] - params['dx1']
        if 'row_spacing' not in params and 'inner_row_spacing' in params and 'dy' in params:
            params['row_spacing'] = params['inner_row_spacing'] + params['dy']

        grid_reqs = ['row', 'column', 'row_spacing', 'column_spacing', 'diameter']
        sides4_req = ['pins_per_side', 'side_to_side_distance', 'pin_spacing_within_side', 'dx1', 'dy1']
        sides2_req = ['row', 'column', 'row_spacing', 'column_spacing', 'dx']

        if all(k in params for k in grid_reqs):
            return self._generate_grid_pins(**params)
        if all(k in params for k in sides4_req):
            # The centre pad is optional; a zero dx2 means "no centre pad".
            params.setdefault('dx2', 0.0)
            params.setdefault('dy2', 0.0)
            return self._generate_4sides_pins(**params)
        if all(k in params for k in sides2_req):
            params.setdefault('dy', 0.0)
            return self._generate_2sides_pins(**params)

        return None

    def _extract_final_parameters(self, content: str) -> Dict[str, float]:
        """Extracts key-value numerical parameters from text.

        If the text ends with a "Final parameters:" section (containing no
        '<'), only that section is searched; otherwise the whole text is.
        A key is the run of word characters, whitespace and hyphens directly
        before an '='; it is lower-cased and its whitespace/hyphen runs are
        replaced by underscores. Later duplicates overwrite earlier ones.
        """
        match = self._FINAL_PARAMS_RE.search(content)
        text_to_search = match.group(1) if match else content

        return {
            re.sub(r'[\s-]+', '_', key.strip()).lower(): float(value)
            for key, value in self._KEY_VALUE_RE.findall(text_to_search)
        }

    def _parse_as_other_type(self, content: str) -> Optional[List[Dict]]:
        """Handles parsing for the 'other' IC type.

        Expects a "pin dimensions" section and a "pin coordinates" section,
        in either order. Each section is read only up to the other section's
        header, so the entries of one are never counted as entries of the
        other.

        Returns:
            One pin per coordinate/dimension pair, or None when a section is
            missing, empty, or the two sections differ in length.
        """
        coords_match = self._COORDS_SECTION_RE.search(content)
        dims_match = self._DIMS_SECTION_RE.search(content)

        if not (coords_match and dims_match):
            return None

        try:
            coordinates = [
                (float(x), float(y))
                for x, y in self._COORD_ENTRY_RE.findall(coords_match.group(1))
            ]
            # A missing second dimension is stored as 0.0 (treated as a circle).
            dimensions = [
                (float(dx), float(dy) if dy else 0.0)
                for dx, dy in self._DIM_ENTRY_RE.findall(dims_match.group(1))
            ]
        except ValueError:
            return None

        if coordinates and dimensions and len(coordinates) == len(dimensions):
            return self._generate_other_pins(coordinates, dimensions)
        return None

    def _generate_2sides_pins(self, row, column, row_spacing, column_spacing, dx, dy, **kwargs) -> List[Dict]:
        """Generates a row x column array of pins centred on the origin.

        Pins are circles of size ``dx`` when ``dy`` is 0 and ``dx`` x ``dy``
        rectangles otherwise. Extra keyword arguments are ignored.
        """
        n_rows, n_cols = int(row), int(column)
        if dy == 0:
            shape = {"type": "circle", "dx": dx, "dy": 0}
        else:
            shape = {"type": "rectangle", "dx": dx, "dy": dy}

        pins = []
        for i in range(n_rows):
            y = (i - (n_rows - 1) / 2) * row_spacing
            for j in range(n_cols):
                x = (j - (n_cols - 1) / 2) * column_spacing
                # Each pin gets its own shape dict so pins never share state.
                pins.append({"center": (round(x, 3), round(y, 3)), "shape": dict(shape)})
        return pins

    def _generate_4sides_pins(self, pins_per_side, side_to_side_distance, pin_spacing_within_side, dx1, dy1, dx2, dy2, **kwargs) -> List[Dict]:
        """Generates pins along the four sides of a square, plus an optional centre pad.

        For each position along a side, pins are emitted in the order
        left, right, top, bottom. Left/right pins are ``dx1`` x ``dy1``;
        top/bottom pins have the two dimensions swapped. When ``dx2`` is
        positive a centre pad of ``dx2`` x ``dy2`` is appended (square if
        ``dy2`` is not positive). Centres are rounded to 3 decimals, like
        the other generators. Extra keyword arguments are ignored.
        """
        pins = []
        n = int(pins_per_side)
        half_dist = round(side_to_side_distance / 2, 3)
        for i in range(n):
            offset = round((i - (n - 1) / 2) * pin_spacing_within_side, 3)
            pins.append({"center": (-half_dist, offset), "shape": {"type": "rectangle", "dx": dx1, "dy": dy1}})
            pins.append({"center": (half_dist, offset), "shape": {"type": "rectangle", "dx": dx1, "dy": dy1}})
            pins.append({"center": (offset, half_dist), "shape": {"type": "rectangle", "dx": dy1, "dy": dx1}})
            pins.append({"center": (offset, -half_dist), "shape": {"type": "rectangle", "dx": dy1, "dy": dx1}})
        if dx2 > 0:
            pins.append({"center": (0, 0), "shape": {"type": "rectangle", "dx": dx2, "dy": dy2 if dy2 > 0 else dx2}})
        return pins

    def _generate_grid_pins(self, row, column, row_spacing, column_spacing, diameter, **kwargs) -> List[Dict]:
        """Generates a row x column array of circular pins centred on the origin.

        Extra keyword arguments are ignored.
        """
        n_rows, n_cols = int(row), int(column)
        pins = []
        for i in range(n_rows):
            y = (i - (n_rows - 1) / 2) * row_spacing
            for j in range(n_cols):
                x = (j - (n_cols - 1) / 2) * column_spacing
                pins.append({"center": (round(x, 3), round(y, 3)), "shape": {"type": "circle", "dx": diameter, "dy": 0}})
        return pins

    def _generate_other_pins(self, coordinates: List, dimensions: List) -> List[Dict]:
        """Pairs explicit (x, y) centres with (dx, dy) sizes, one pin per pair.

        A pin whose ``dy`` is 0 is a circle; any other pin is a rectangle.
        """
        pins = []
        for (x, y), (dx, dy) in zip(coordinates, dimensions):
            shape_type = "circle" if dy == 0 else "rectangle"
            pins.append({"center": (round(x, 3), round(y, 3)), "shape": {"type": shape_type, "dx": dx, "dy": dy}})
        return pins

    def _format_conversations(self, pins: List[Dict], image_path: str) -> Dict:
        """Formats the generated pin data into the final output structure.

        The result holds three "gpt" messages -- the pin count, one size
        line per pin ("index,dx" for circles, "index,dx,dy" otherwise) and
        one coordinate line per pin ("index,x,y") -- plus the image path.
        Pin indices start at 1.
        """
        sizes = []
        coords = []
        for number, pin in enumerate(pins, start=1):
            shape = pin['shape']
            if shape['type'] == "circle":
                sizes.append(f"{number},{shape['dx']}")
            else:
                sizes.append(f"{number},{shape['dx']},{shape['dy']}")
            coords.append(f"{number},{pin['center'][0]},{pin['center'][1]}")

        return {
            "conversations": [
                {"from": "gpt", "value": str(len(pins))},
                {"from": "gpt", "value": "\n".join(sizes)},
                {"from": "gpt", "value": "\n".join(coords)},
            ],
            "images": [image_path],
        }

    def _debug_dump(self, kind: str, index: int, image_path: str, raw_text: str):
        """Prints a debug message for a failed parsing attempt.

        Shows the whitespace-collapsed <answer> content, truncated to 300
        characters, or a note that no <answer> tag was found.
        """
        ans_match = self._ANSWER_RE.search(raw_text)
        ans = ans_match.group(1).strip() if ans_match else 'No <answer> tag found'
        print(f"--- {kind} [Index: {index}, Image: {image_path}] ---")
        preview = ' '.join(ans.split())
        print(f"Content: {preview[:300]}{'...' if len(preview) > 300 else ''}")
        print("-" * 50)

# ============================================================================ #
# Section 2: Main Execution Block
# ============================================================================ #

def main() -> int:
    """Command-line entry point: parse predictions and write them as JSON.

    Reads the prediction file and the image order file named on the command
    line, processes the predictions, and writes the results to the output
    file. No output file is created when nothing could be processed.

    Returns:
        0 when the run completed, 1 when an unexpected error was reported,
        so that calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Parse IC pin layout predictions from a model's output file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("prediction_file", help="Path to the .jsonl prediction file.")
    parser.add_argument("image_order_file", help="Path to the JSON file that defines the order of images.")
    parser.add_argument("-o", "--output_file", default="processed_pins.json", help="Path to save the processed output JSON file.")

    args = parser.parse_args()

    try:
        # 1. Initialize the parser with the image order.
        pin_parser = PredictionParser(image_order_file=args.image_order_file)

        # 2. Process the prediction file to get structured pin data.
        processed_pins = pin_parser.process_file(prediction_file_path=args.prediction_file)

        # 3. Save the results to the specified output file.
        if processed_pins:
            with open(args.output_file, 'w', encoding='utf-8') as f:
                json.dump(processed_pins, f, indent=4)
            print(f"\nSuccessfully saved {len(processed_pins)} processed items to '{args.output_file}'.")
        else:
            print("\nNo predictions were successfully processed. Output file was not created.")

    except Exception as e:
        # Top-level boundary: report the error, then signal failure through
        # the exit status instead of exiting as if the run had succeeded.
        print(f"\nAn unexpected error occurred during execution: {e}")
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())