"""
Utility functions for data visualization task disambiguation.
"""
import re
import json


def extract_code_from_response(response: str) -> str:
    """Extract Python code from an LLM response.

    Preference order:

    1. The first fence tagged ``python``, matched case-insensitively, so
       ``Python`` and ``python3`` fences also qualify.
    2. Otherwise, the first fenced block of any kind.
    3. Otherwise, the whole response.

    A bare info string on the opening fence line (``py``, ``json``, or the
    ``3`` left over from ``python3``) is dropped so it does not leak into
    the returned code. It is only dropped when a newline follows it, so a
    one-line fence such as ```` ```x = 1``` ```` is kept intact.

    Args:
        response: Raw model output, possibly containing Markdown fences.

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    fence = "```"
    # Search the original string with IGNORECASE rather than indexing into
    # response.lower(): lower() can change the length of some Unicode text.
    match = re.search(r"```python", response, flags=re.IGNORECASE)
    if match:
        body = response[match.end():]
    elif fence in response:
        body = response.split(fence, 1)[1]
    else:
        return response.strip()

    # Keep everything up to the closing fence, or up to the end if unclosed.
    body = body.split(fence, 1)[0]

    # Drop the rest of the opening fence line when it is just a language tag.
    tag, newline, rest = body.partition("\n")
    if newline and re.fullmatch(r"[\w.+-]*", tag.strip()):
        body = rest
    return body.strip()


def normalize_viz_code(code: str) -> str:
    """Normalize visualization code for comparison.

    Strips ``#`` comments and collapses every whitespace run to a single
    space. A ``#`` inside a string literal is not a comment and is kept, so
    hex colours such as ``color="#ff0000"`` survive intact.

    Args:
        code: Python source text. It does not need to be syntactically
            complete.

    Returns:
        The normalized code, stripped of leading and trailing whitespace.
    """
    # Group 1 matches a string literal, which is put back unchanged. The
    # trailing alternative matches a comment, which is removed. Triple-quoted
    # forms come first so that '"""' is not read as an empty string plus '"'.
    string_or_comment = (
        r'("""[\s\S]*?"""'
        r"|'''[\s\S]*?'''"
        r'|"(?:\\[\s\S]|[^"\\\n])*"'
        r"|'(?:\\[\s\S]|[^'\\\n])*')"
        r"|#[^\n]*"
    )
    code = re.sub(string_or_comment, lambda m: m.group(1) or "", code)
    # Normalize whitespace
    code = re.sub(r"\s+", " ", code)
    return code.strip()


def detect_chart_type(code: str) -> str:
    """Detect the chart type from visualization code.

    Chart types are checked in the order listed below, and the first match
    wins. Plain patterns (e.g. ``"heatmap"``) are case-insensitive substring
    matches.

    Call patterns ending in ``"("`` (e.g. ``"plot("``) must start a name:
    a letter or digit may not come right before them. This stops them from
    firing inside longer calls such as ``subplot(``, ``boxplot(``,
    ``colorbar(`` or ``axhline(``. A leading underscore is still allowed,
    so Altair's ``mark_bar(`` / ``mark_line(`` keep matching.

    Args:
        code: Visualization source code.

    Returns:
        One of the chart-type keys below, or ``"unknown"`` when nothing
        matches.
    """
    code_lower = code.lower()

    chart_patterns = {
        "heatmap": ["heatmap", "imshow", "pcolormesh"],
        "scatter": ["scatter", "scatterplot"],
        "line": ["plot(", "lineplot", "line("],
        "bar": ["bar(", "barplot", "barh("],
        "histogram": ["hist(", "histplot", "histogram"],
        "pie": ["pie("],
        "box": ["boxplot", "box("],
        "violin": ["violinplot", "violin("],
        "area": ["fill_between", "stackplot", "area("],
        "contour": ["contour", "contourf"],
    }

    def matches(pattern: str) -> bool:
        if pattern.endswith("("):
            # Reject a letter/digit just before the call name.
            regex = r"(?<![a-z0-9])" + re.escape(pattern)
            return re.search(regex, code_lower) is not None
        return pattern in code_lower

    for chart_type, patterns in chart_patterns.items():
        if any(matches(pattern) for pattern in patterns):
            return chart_type
    return "unknown"


def detect_library(code: str) -> str:
    """Detect the visualization library used.

    Libraries are checked in priority order: plotly, seaborn, altair,
    bokeh. Each is recognized by its name, matched case-insensitively
    anywhere in the code. Plotly, seaborn and altair are also recognized by
    their conventional import alias followed by a dot (``px.``/``go.``,
    ``sns.``, ``alt.``). An alias only counts when it starts a word, so
    ``logo.png`` or ``algo.run`` are not mistaken for ``go.``.

    Args:
        code: Visualization source code.

    Returns:
        ``"plotly"``, ``"seaborn"``, ``"altair"`` or ``"bokeh"``. Returns
        ``"matplotlib"`` both for explicit matplotlib code and as the
        fallback when nothing else is detected.
    """
    lowered = code.lower()

    def has_alias(alias: str) -> bool:
        # Aliases are matched case-sensitively, as in the original checks.
        return re.search(rf"\b{re.escape(alias)}\.", code) is not None

    if "plotly" in lowered or has_alias("px") or has_alias("go"):
        return "plotly"
    if "seaborn" in lowered or has_alias("sns"):
        return "seaborn"
    if "altair" in lowered or has_alias("alt"):
        return "altair"
    if "bokeh" in lowered:
        return "bokeh"
    # Explicit matplotlib usage and "nothing detected" both map here.
    return "matplotlib"


def detect_color_scheme(code: str) -> str:
    """Detect color scheme from code.

    Colormap names are matched case-insensitively in the order listed
    below, and the first match wins. A name only counts when:

    - no letter or digit comes right before it, so ``preds`` is not
      ``reds``;
    - no letter comes right after it, so ``redshift`` is not ``reds``.

    ``_r`` suffixes and sized palette names such as ``Viridis256`` still
    match.

    Args:
        code: Visualization source code.

    Returns:
        The matched scheme name, or ``"default"`` if none is found.
    """
    code_lower = code.lower()

    schemes = {
        "viridis": ["viridis"],
        "plasma": ["plasma"],
        "inferno": ["inferno"],
        "magma": ["magma"],
        "cividis": ["cividis"],
        "coolwarm": ["coolwarm"],
        "blues": ["blues"],
        "reds": ["reds"],
        "greens": ["greens"],
        "spectral": ["spectral"],
        "rainbow": ["rainbow"],
        "jet": ["jet"],
    }

    for scheme, patterns in schemes.items():
        for pattern in patterns:
            regex = rf"(?<![a-z0-9]){re.escape(pattern)}(?![a-z])"
            if re.search(regex, code_lower):
                return scheme
    return "default"


def extract_viz_features(code: str) -> dict:
    """Extract visualization features for clustering.

    Args:
        code: Visualization source code.

    Returns:
        A dict with the following keys:

        - ``chart_type``, ``library``, ``color_scheme``: strings from the
          ``detect_*`` helpers.
        - ``has_title``, ``has_legend``, ``has_grid``, ``has_labels``:
          booleans from case-insensitive keyword presence.
        - ``interactive``: True when the detected library is plotly or
          bokeh, or when the code mentions bokeh at all.
    """
    lowered = code.lower()
    library = detect_library(code)
    return {
        "chart_type": detect_chart_type(code),
        "library": library,
        "color_scheme": detect_color_scheme(code),
        "has_title": "title" in lowered,
        "has_legend": "legend" in lowered,
        "has_grid": "grid" in lowered,
        "has_labels": "label" in lowered,
        # Derive from the detected library so alias-only Plotly snippets
        # (e.g. ``px.scatter(...)``) are flagged interactive too. The bokeh
        # substring check keeps the previous behaviour when bokeh is
        # mentioned alongside another library.
        "interactive": library in ("plotly", "bokeh") or "bokeh" in lowered,
    }


def features_to_text(features: dict) -> str:
    """Render a feature dict as a space-separated token string for embedding.

    The chart type, library and colour scheme are always emitted as
    ``prefix:value`` tokens, in that order. A missing key raises
    ``KeyError``. The ``titled``, ``legend``, ``grid`` and ``interactive``
    tokens follow, each emitted only when its flag is truthy.
    ``has_labels`` is not rendered.
    """
    required_fields = (
        ("chart:", "chart_type"),
        ("lib:", "library"),
        ("color:", "color_scheme"),
    )
    flag_tokens = (
        ("has_title", "titled"),
        ("has_legend", "legend"),
        ("has_grid", "grid"),
        ("interactive", "interactive"),
    )

    tokens = [f"{prefix}{features[key]}" for prefix, key in required_fields]
    tokens += [token for flag, token in flag_tokens if features.get(flag)]
    return " ".join(tokens)
