import logging
import re
import numpy as np
import os
import yaml

from datasets import Dataset

eval_logger = logging.getLogger("lmms-eval")

try:
    from shapely.geometry import Polygon
    from shapely.affinity import affine_transform
    SHAPELY_AVAILABLE = True
except ImportError:
    SHAPELY_AVAILABLE = False
    # There is no non-Shapely fallback: the IoU / ACC@k / Center_ACC helpers return 0 / False
    # whenever SHAPELY_AVAILABLE is False, so say that plainly instead of promising an approximation.
    eval_logger.warning(
        "Shapely not available. Polygon IoU, ACC@k and Center_ACC metrics will all be reported as 0."
    )

# Metric names produced per sample; the per-sample result keys are "refcoco_poly_<name>".
COCO_POLY_METRICS = ["IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9", "Center_ACC"]


def get_polygon_points_from_metadata(metadata=None):
    """Return the number of polygon vertices to use (default 16).

    Resolution order:
      1. the ``POLYGON_POINTS`` environment variable, when it parses as a positive int
         (unparseable or non-positive values are ignored, falling through to the next source);
      2. ``metadata["polygon_points"]``, when present;
      3. the default of 16.

    Args:
        metadata: optional mapping (task metadata) that may contain ``polygon_points``.

    Returns:
        int: a positive vertex count.

    Raises:
        ValueError: if ``metadata["polygon_points"]`` is not convertible with ``int()`` or is not
            positive (a zero/negative count would yield empty polygons and degenerate regexes).
    """
    env_points = os.getenv("POLYGON_POINTS")
    if env_points:
        try:
            points = int(env_points)
        except ValueError:
            points = 0
        if points > 0:
            return points
        # A malformed override is ignored (best-effort), so metadata/default still apply.

    if metadata and "polygon_points" in metadata:
        points = int(metadata["polygon_points"])
        if points <= 0:
            raise ValueError(
                f"polygon_points must be a positive integer, got {metadata['polygon_points']!r}"
            )
        return points
    return 16


def refcoco_poly_preprocess_dataset(dataset: Dataset, metadata=None):
    """Prepare the RefCOCO polygon split for evaluation.

    Steps:
      1. record ``image_width`` / ``image_height`` from each row's image;
      2. build a ground-truth ``polygon``: the segmentation (first part only when it is a list of
         parts) resampled to ``num_points`` vertices, or the ``bbox`` rectangle when there is no
         segmentation; coordinates are divided by image width/height;
      3. expand every row into one row per element of its ``answer`` column.

    Args:
        dataset: ``Dataset`` with ``image`` (exposing ``width``/``height``), ``segmentation``,
            ``bbox`` (``[x, y, w, h]``) and ``answer`` (an iterable of referring expressions)
            columns.
        metadata: optional task metadata; may set ``polygon_points``.

    Returns:
        Dataset: one row per (sample, answer) pair with ``answer`` holding a single expression.
    """
    # Read the vertex count once instead of once per example.
    num_points = get_polygon_points_from_metadata(metadata)

    dataset = dataset.map(lambda x: {"image_width": x["image"].width, "image_height": x["image"].height})

    def process_segmentation(example):
        segmentation = example["segmentation"]
        if not segmentation:
            # No mask: use the bbox rectangle, densified by edge midpoints so the ring is a simple
            # polygon with num_points vertices (repeating the 4 corners would retrace the same edges).
            x, y, w, h = example["bbox"]
            rectangle = [x, y, x + w, y, x + w, y + h, x, y + h]
            polygon_coords = normalize_polygon(rectangle, target_points=num_points)
        else:
            # Only the first part of a multi-part segmentation is used.
            polygon_coords = segmentation[0] if isinstance(segmentation[0], list) else segmentation
            polygon_coords = normalize_polygon(polygon_coords, target_points=num_points)

        width, height = example["image_width"], example["image_height"]
        normalized_coords = []
        for i in range(0, len(polygon_coords), 2):
            normalized_coords.append(polygon_coords[i] / width)
            normalized_coords.append(polygon_coords[i + 1] / height)
        return {"polygon": normalized_coords}

    # The whole split is processed; no hard-coded subset.
    dataset = dataset.map(process_segmentation)

    exploded_rows = []
    for example in dataset:
        answers = example.pop("answer")
        exploded_rows.extend({"answer": answer, **example} for answer in answers)

    new_dataset = Dataset.from_list(exploded_rows)
    eval_logger.info(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows")

    return new_dataset


def normalize_polygon(segmentation_data, target_points=16):
    """Resample a flat polygon ``[x1, y1, x2, y2, ...]`` to exactly ``target_points`` vertices.

    - Fewer vertices than requested: repeatedly split the currently longest edge (the closing
      edge from the last vertex back to the first included) at its midpoint.
    - More vertices than requested: keep ``target_points`` vertices at evenly spaced indices
      (the first and last vertices are always kept).
    - Exactly ``target_points`` vertices: the coordinates are returned unchanged (as floats).

    A trailing unpaired coordinate is ignored. Input that is ``None`` or has fewer than 3 complete
    vertices yields a placeholder of ``2 * target_points`` zeros.

    Args:
        segmentation_data: flat sequence of alternating x/y coordinates.
        target_points: number of vertices in the result.

    Returns:
        list[float]: ``2 * target_points`` coordinates.
    """
    placeholder = [0.0] * (target_points * 2)
    if segmentation_data is None:
        return placeholder

    # Work in float: with an integer array np.insert would truncate the inserted midpoints.
    coords = np.asarray(segmentation_data, dtype=float).ravel()
    coords = coords[: len(coords) - len(coords) % 2]  # drop a dangling x without a y
    if len(coords) < 6:  # at least 3 vertices are needed for a polygon
        return placeholder

    points = coords.reshape(-1, 2)
    num_points = len(points)

    if num_points < target_points:
        while len(points) < target_points:
            edge_lengths = np.linalg.norm(np.roll(points, -1, axis=0) - points, axis=1)
            longest = int(np.argmax(edge_lengths))
            midpoint = (points[longest] + points[(longest + 1) % len(points)]) / 2
            points = np.insert(points, longest + 1, midpoint, axis=0)
    elif num_points > target_points:
        indices = np.linspace(0, num_points - 1, target_points, dtype=int)
        points = points[indices]

    return points.ravel().tolist()


def refcoco_poly_doc_to_visual(doc):
    """Return the document's image, converted to RGB, as a single-element list.

    Args:
        doc: dataset row whose ``image`` supports ``.convert(mode)`` (presumably a PIL image).

    Returns:
        list: ``[doc["image"].convert("RGB")]``.
    """
    # One conversion is enough; converting the already-RGB result again is redundant work.
    return [doc["image"].convert("RGB")]


def refcoco_poly_doc_to_text(doc, metadata=None):
    """Build the polygon-grounding prompt for one referring expression.

    Args:
        doc: dataset row whose ``answer`` is a single referring expression string.
        metadata: optional task metadata; may set ``polygon_points``.

    Returns:
        str: the instruction followed by the referring expression.

    Raises:
        TypeError: if ``doc["answer"]`` is not a string (an explicit check, unlike ``assert``,
            survives ``python -O``).
    """
    answer = doc["answer"]
    if not isinstance(answer, str):
        raise TypeError(f"Answer must be a string, got {type(answer).__name__}")
    num_points = get_polygon_points_from_metadata(metadata)
    return (
        f"Please provide the {num_points} points polygon coordinate of the region this sentence describes: "
        + answer
    )

def parse_polygon_sequence_within(input_str, metadata=None):
    """
    Extract a flat polygon of 2*num_points coordinates from a model response.

    Formats are tried in order; the first that matches wins:
      1. the first ``[...]`` holding exactly 2*num_points comma-separated numbers;
      2. the first ``[...]`` of bracketed pairs, e.g. ``[[x1, y1], [x2, y2], ...]`` or
         ``[(x1, y1), (x2, y2), ...]``, flattened;
      3. the first ``[...]`` of comma-separated numbers of any length.
    Results of (2) and (3) are zero-padded or truncated to 2*num_points values.

    If any extracted value has magnitude > 1, the whole list is treated as lying on a 0-1000 grid:
    every value is divided by 1000 and clamped to [0, 1].

    Args:
    input_str (str): model output text; non-string input is treated as unparseable.
    metadata (dict): task metadata; may set ``polygon_points``.

    Returns:
    list: 2*num_points floats, or 2*num_points zeros when nothing can be parsed.
    """
    num_points = get_polygon_points_from_metadata(metadata)
    num_values = 2 * num_points
    if not isinstance(input_str, str):
        return [0.0] * num_values

    def rescale(values):
        # Values above 1 are assumed to be on a 0-1000 grid rather than already normalized.
        if any(abs(v) > 1.0 for v in values):
            return [max(0.0, min(1.0, v / 1000.0)) for v in values]
        return values

    def fit(values):
        # Zero-pad or truncate to exactly num_values entries.
        return (values + [0.0] * num_values)[:num_values]

    # 1. Flat list with exactly the expected number of values (one capture group per value).
    float_pattern = r"(-?\d+(?:\.\d+)?)"
    comma_pattern = r",\s*"
    strict_pattern = (
        r"\[\s*" + float_pattern + (comma_pattern + float_pattern) * (num_values - 1) + r"\s*\]"
    )
    match = re.search(strict_pattern, input_str)
    if match:
        return rescale([float(match.group(i)) for i in range(1, num_values + 1)])

    # 2. List of bracketed (x, y) pairs; without this stage only the first pair would be read.
    number = r"-?\d+(?:\.\d+)?"
    pair = r"[\[(]\s*" + number + r"\s*,\s*" + number + r"\s*[\])]"
    nested_pattern = r"\[\s*" + pair + r"(?:\s*,\s*" + pair + r")*\s*,?\s*\]"
    nested_match = re.search(nested_pattern, input_str)
    if nested_match:
        values = [float(v) for v in re.findall(number, nested_match.group(0))]
        return fit(rescale(values))

    # 3. Flat list of any length.
    simple_match = re.search(r"\[([\d\.,\s\-]+)\]", input_str)
    if simple_match:
        try:
            values = [float(x.strip()) for x in simple_match.group(1).split(",") if x.strip()]
        except ValueError:
            pass  # e.g. "1.2.3" or a lone "-": fall through to the zero placeholder
        else:
            return fit(rescale(values))

    return [0.0] * num_values


def refcoco_poly_process_result(doc, result, metadata=None):
    """
    Parse one model response and hand the same per-sample record to every polygon metric.

    Args:
        doc: one row of the evaluation dataset (reads ``answer``, ``question_id``, ``polygon``).
        result: list of model outputs; only the first is parsed, an empty list parses as "".
        metadata: task metadata; may set ``polygon_points``.
    Returns:
        dict mapping ``refcoco_poly_<metric>`` (for each name in COCO_POLY_METRICS) to one shared
        record holding the answer, parsed prediction, annotation id and ground-truth polygon.
    """
    raw_prediction = result[0] if len(result) > 0 else ""
    parsed_prediction = parse_polygon_sequence_within(raw_prediction, metadata)
    annotation_id = doc["question_id"]
    record = {
        "answer": doc["answer"],
        "pred": parsed_prediction,
        "ann_id": annotation_id,
        "polygon": doc["polygon"],
    }
    # Every key maps to the same record object.
    return dict.fromkeys((f"refcoco_poly_{name}" for name in COCO_POLY_METRICS), record)


def _coords_to_polygon(coords):
    """Build a Shapely polygon from flat ``[x1, y1, x2, y2, ...]`` coordinates.

    Invalid rings (self-intersecting, all points identical, ...) are repaired with ``buffer(0)``.
    Requires Shapely; construction errors propagate to the caller.
    """
    polygon = Polygon(list(zip(coords[0::2], coords[1::2])))
    if not polygon.is_valid:
        # NOTE(review): buffer(0) repair of a self-intersecting ring may discard part of the shape;
        # shapely.validation.make_valid is an alternative, but switching would change metric values.
        polygon = polygon.buffer(0)
    return polygon


def compute_polygon_iou(poly1_coords, poly2_coords):
    """
    Compute the Intersection over Union (IoU) of two polygons.

    Parameters:
    - poly1_coords (list of float): flat polygon coordinates [x1, y1, ..., xN, yN].
    - poly2_coords (list of float): flat polygon coordinates [x1, y1, ..., xN, yN].

    Returns:
    - float: IoU clamped to [0, 1]; 0.0 when Shapely is unavailable, the union is empty, or the
      geometry operation fails (the failure is logged as a warning).
    """
    if not SHAPELY_AVAILABLE:
        # No non-Shapely fallback exists.
        return 0.0
    try:
        polygon1 = _coords_to_polygon(poly1_coords)
        polygon2 = _coords_to_polygon(poly2_coords)
        intersection_area = polygon1.intersection(polygon2).area
        union_area = polygon1.union(polygon2).area
    except Exception as e:
        # Best-effort metric: a single malformed polygon scores 0 instead of aborting the run.
        eval_logger.warning(f"Error computing polygon IoU with Shapely: {e}")
        return 0.0
    if union_area == 0:
        return 0.0
    return max(0.0, min(1.0, intersection_area / union_area))


def compute_polygon_accuracy(poly1_coords, poly2_coords, threshold=0.5):
    """Return True if the polygon IoU is at least ``threshold`` (no bbox fallback)."""
    iou = compute_polygon_iou(poly1_coords, poly2_coords)
    return iou >= threshold


def compute_polygon_center_accuracy(poly1_coords, poly2_coords):
    """Return True if the centroid of polygon 2 lies inside polygon 1 (no bbox fallback).

    Uses Shapely ``contains`` (per Shapely's semantics, a centroid exactly on the boundary does not
    count). Returns False when Shapely is unavailable or the geometry operation fails (logged).
    """
    if not SHAPELY_AVAILABLE:
        return False
    try:
        polygon1 = _coords_to_polygon(poly1_coords)
        polygon2 = _coords_to_polygon(poly2_coords)
        return polygon1.contains(polygon2.centroid)
    except Exception as e:
        eval_logger.warning(f"Error computing polygon center accuracy with Shapely: {e}")
        return False
    


def refcoco_poly_aggregation_result(results, metric):
    """
    Average one polygon metric over all per-sample results.

    Args:
    - results (list of dict): per-sample records with ``polygon`` (ground truth) and ``pred`` keys.
    - metric (str): one of "IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9",
      "Center_ACC".

    Returns:
    - float: mean score over ``results`` (boolean scores count as 0/1); 0.0, with a warning,
      when ``results`` is empty.

    Raises:
    - KeyError: if ``metric`` is not one of the names above.
    """
    scorers = {
        "IoU": compute_polygon_iou,
        "ACC@0.1": lambda x, y: compute_polygon_accuracy(x, y, 0.1),
        "ACC@0.3": lambda x, y: compute_polygon_accuracy(x, y, 0.3),
        "ACC@0.5": lambda x, y: compute_polygon_accuracy(x, y, 0.5),
        "ACC@0.7": lambda x, y: compute_polygon_accuracy(x, y, 0.7),
        "ACC@0.9": lambda x, y: compute_polygon_accuracy(x, y, 0.9),
        "Center_ACC": compute_polygon_center_accuracy,
    }
    scorer = scorers[metric]
    # Scorers take (ground truth, prediction) in that order.
    scores = [scorer(result["polygon"], result["pred"]) for result in results]
    if not scores:
        # Avoid ZeroDivisionError on an empty split.
        eval_logger.warning(f"No results to aggregate for {metric}; reporting 0.0")
        return 0.0
    score = sum(scores) / len(scores)
    eval_logger.info(f"Aggregated {metric} score: {score}")
    return score


def refcoco_poly_iou(results):
    """Aggregation hook: mean polygon IoU over ``results``."""
    return refcoco_poly_aggregation_result(results, metric="IoU")


def refcoco_poly_acc01(results):
    """Aggregation hook: fraction of samples with polygon IoU >= 0.1."""
    return refcoco_poly_aggregation_result(results, metric="ACC@0.1")


def refcoco_poly_acc03(results):
    """Aggregation hook: fraction of samples with polygon IoU >= 0.3."""
    return refcoco_poly_aggregation_result(results, metric="ACC@0.3")


def refcoco_poly_acc05(results):
    """Aggregation hook: fraction of samples with polygon IoU >= 0.5."""
    return refcoco_poly_aggregation_result(results, metric="ACC@0.5")


def refcoco_poly_acc07(results):
    """Aggregation hook: fraction of samples with polygon IoU >= 0.7."""
    return refcoco_poly_aggregation_result(results, metric="ACC@0.7")


def refcoco_poly_acc09(results):
    """Aggregation hook: fraction of samples with polygon IoU >= 0.9."""
    return refcoco_poly_aggregation_result(results, metric="ACC@0.9")


def refcoco_poly_center_acc(results):
    """Aggregation hook: fraction of samples whose predicted centroid lies inside the GT polygon."""
    return refcoco_poly_aggregation_result(results, metric="Center_ACC")
