import os
import json
from typing import Dict, Any, List
from functools import wraps
import time
import re
from schema import CoK


def time_llm_call(func):
    """Decorate ``func`` so the duration of each call can be reported.

    Arguments and the return value of ``func`` pass through unchanged. When
    the ``ENABLE_LLM_TIMING`` environment variable equals ``"true"``
    (case-insensitive) at call time, the elapsed time is printed to stdout
    after the call returns. Nothing is printed if ``func`` raises.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that carries the metadata of ``func`` (via ``functools.wraps``).
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic and high resolution, so the measured
        # interval is not distorted by system clock adjustments the way a
        # wall-clock time.time() difference can be.
        started = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - started

        # The flag is read on every call, so it can be toggled at runtime.
        if os.getenv("ENABLE_LLM_TIMING", "false").lower() == "true":
            print(f"LLM call took {duration:.3f} seconds")

        return result

    return wrapper


def _validate_llm_config(llm_config: Dict[str, Any]) -> None:
    """Validate LLM configuration based on provider."""
    provider = llm_config["provider"]

    if provider == "azure":
        required = ["endpoint", "api_version", "model_name"]
        missing = [f for f in required if not llm_config.get(f)]
        if missing:
            raise ValueError(f"Azure provider missing required fields: {missing}")

    elif provider == "ollama":
        if not llm_config.get("model_name"):
            raise ValueError(f"{provider} provider requires model_name")
        if provider == "ollama" and not llm_config.get("base_url"):
            llm_config["base_url"] = "http://localhost:11434"  # Default Ollama URL


def create_client(llm_config: Dict[str, Any]):
    """Create an LLM client from a provider-specific configuration dict.

    Example configurations:

    Azure:
    {
        "provider": "azure",
        "endpoint": "https://...",
        "api_version": "2025-03-01-preview",
        "model_name": "gpt-4",
        "use_managed_identity": true
    }

    Ollama:
    {
        "provider": "ollama",
        "model_name": "llama2",
        "base_url": "http://localhost:11434"
    }

    Args:
        llm_config: Configuration dict; it is validated with
            ``_validate_llm_config`` before any client is built.
            For Azure, ``use_managed_identity`` defaults to true and selects
            token authentication through ``AzureCliCredential``; when false,
            ``api_key`` from the config or the ``AZURE_OPENAI_API_KEY``
            environment variable is used instead.

    Returns:
        An ``openai.AzureOpenAI`` client for ``"azure"`` or an
        ``ollama.Client`` for ``"ollama"``.

    Raises:
        KeyError: If ``llm_config`` has no ``"provider"`` key.
        ValueError: If validation fails, the Azure API key is missing, the
            ``ollama`` package is not installed, or the provider is unsupported.
    """

    _validate_llm_config(llm_config)

    provider = llm_config["provider"]

    if provider == "azure":
        # Imported lazily so these packages are only needed for this provider.
        from azure.identity import AzureCliCredential, get_bearer_token_provider
        from openai import AzureOpenAI

        if llm_config.get("use_managed_identity", True):
            token_provider = get_bearer_token_provider(
                AzureCliCredential(), "https://cognitiveservices.azure.com/.default"
            )
            return AzureOpenAI(
                api_version=llm_config["api_version"],
                azure_endpoint=llm_config["endpoint"],
                azure_ad_token_provider=token_provider,
            )

        # Explicit config value wins over the environment variable.
        api_key = llm_config.get("api_key") or os.getenv("AZURE_OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "Azure provider with use_managed_identity=false requires api_key"
            )

        return AzureOpenAI(
            api_version=llm_config["api_version"],
            azure_endpoint=llm_config["endpoint"],
            api_key=api_key,
        )

    if provider == "ollama":
        try:
            import ollama
        except ImportError as err:
            # Keep ValueError for callers, but chain the cause for debugging.
            raise ValueError(
                "Ollama provider requires 'ollama' package. Install with: pip install ollama"
            ) from err

        base_url = llm_config.get("base_url", "http://localhost:11434")
        return ollama.Client(host=base_url)

    # Only list the providers this function can actually build a client for.
    raise ValueError(f"Unsupported provider: {provider}. Supported: azure, ollama")


def clean_response(response: str) -> str:
    """Drop any text preceding the first ``{`` and trim surrounding whitespace.

    Falsy input (``""`` or ``None``) is returned unchanged. When the response
    contains no ``{``, or already starts with one, it is only stripped.
    """
    if not response:
        return response

    brace_pos = response.find("{")
    # A position of 0 needs no slicing; -1 means there is no brace at all.
    body = response[brace_pos:] if brace_pos > 0 else response
    return body.strip()


def extract_json_from_markdown_blocks(response: str) -> str:
    """Extract the content of the first markdown code fence, if present.

    A fence is a run of three or more backticks, optionally followed by a
    ``json`` language tag (any letter case). The content up to the next run of
    three or more backticks is returned, stripped of surrounding whitespace.

    Args:
        response: Raw text that may contain a fenced code block.

    Returns:
        The stripped fence content, or the stripped ``response`` when no fence
        is found. Falsy input is returned unchanged.
    """
    if not response:
        return response

    # One variable-length pattern covers ``` as well as longer fences. A fixed
    # three-backtick pattern tried first would also match longer fences and
    # capture the leftover backtick and language tag as content.
    fence_pattern = r"`{3,}\s*(?:json)?\s*\n?(.*?)\n?`{3,}"

    match = re.search(fence_pattern, response, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    return response.strip()


def _find_closing_brace(text: str, start: int, string_aware: bool) -> int:
    """Return the index of the brace that closes the one at ``start``, or -1.

    With ``string_aware`` set, braces inside double-quoted string literals
    (honouring backslash escapes) are not counted.
    """
    depth = 0
    in_string = False
    escaped = False

    for idx in range(start, len(text)):
        char = text[idx]
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue

        if char == '"' and string_aware:
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return idx

    return -1


def find_outermost_json_braces(response: str) -> str:
    """Find and extract content between outermost curly braces.

    Scanning starts at the first ``{``. Braces inside double-quoted strings
    are ignored so that a value such as ``"}"`` does not end the object early.
    If that scan finds no closing brace (for example because of an unbalanced
    quote in malformed output), a plain brace count is tried instead.

    Args:
        response: Text that may contain a JSON object.

    Returns:
        The text from the first ``{`` through its matching ``}``; the text
        from the first ``{`` to the end when no match exists; or ``response``
        unchanged when it is falsy or contains no ``{``.
    """
    if not response:
        return response

    start_idx = response.find("{")
    if start_idx == -1:
        return response

    end_idx = _find_closing_brace(response, start_idx, string_aware=True)
    if end_idx == -1:
        # Malformed quoting: fall back to counting every brace.
        end_idx = _find_closing_brace(response, start_idx, string_aware=False)

    if end_idx == -1:
        # No matching closing brace at all: return from the first brace on.
        return response[start_idx:]

    return response[start_idx : end_idx + 1]


def robust_json_extraction(response: str) -> str:
    """Pull a JSON object out of an LLM response using layered strategies.

    The stripped response is first run through
    ``extract_json_from_markdown_blocks``. If that changed nothing, or its
    result does not start with ``{``, ``find_outermost_json_braces`` is applied
    to that result. If the candidate still does not start with ``{``, it is
    passed through ``clean_response`` as a last resort.

    Falsy input is returned unchanged.
    """
    if not response:
        return response

    stripped = response.strip()

    # Strategy 1: markdown code fences.
    candidate = extract_json_from_markdown_blocks(stripped)

    # Strategy 2: brace matching, when fences gave nothing usable.
    fence_was_unhelpful = candidate == stripped or not candidate.startswith("{")
    if fence_was_unhelpful:
        candidate = find_outermost_json_braces(candidate)

    if candidate.startswith("{"):
        return candidate

    # Strategy 3: the simple preamble-trimming cleaner.
    return clean_response(candidate)


def clean_dspy_response(response: str) -> str:
    """Return the text that follows the ``[[ ## summary ## ]]`` marker.

    If a ``[[ ## completed ## ]]`` marker follows, only the text between the
    two markers is returned; otherwise everything after the summary marker is.
    Marker matching is case-insensitive and tolerates whitespace inside the
    brackets. Without a summary marker the whole response is returned. All
    results are stripped; falsy input is returned unchanged.
    """
    if not response:
        return response

    search_flags = re.DOTALL | re.IGNORECASE
    # Ordered from most to least specific: bounded section first, then
    # open-ended section when the completed marker is absent.
    marker_patterns = (
        r"\[\[\s*##\s*summary\s*##\s*\]\](.*?)\[\[\s*##\s*completed\s*##\s*\]\]",
        r"\[\[\s*##\s*summary\s*##\s*\]\](.*)",
    )

    for marker_pattern in marker_patterns:
        found = re.search(marker_pattern, response, flags=search_flags)
        if found:
            return found.group(1).strip()

    return response.strip()


def extract_json_from_response(response: str) -> str:
    """Return ``response`` if it is valid JSON, else wrap it in a JSON object.

    Args:
        response: Candidate JSON text.

    Returns:
        ``response`` unchanged when ``json.loads`` accepts it. Otherwise a JSON
        document of the form ``{"summary": "<response>"}`` with the text
        properly escaped.
    """
    try:
        json.loads(response)
    except json.JSONDecodeError:
        # json.dumps escapes backslashes, tabs and other control characters as
        # well as quotes and newlines, so the fallback is always parseable.
        # ensure_ascii=False keeps non-ASCII text readable rather than \uXXXX.
        return json.dumps({"summary": response}, ensure_ascii=False)
    return response


def parse_to_summary(response: str) -> dict:
    """Parse a response into a summary dict with an ``"attributes"`` key.

    The response is parsed as JSON directly; if that fails it is first passed
    through ``extract_json_from_response`` and parsed again.

    Args:
        response: JSON text, or free text to be wrapped.

    Returns:
        The parsed object when it is a dict that already has an
        ``"attributes"`` key; otherwise ``{"attributes": <parsed value>}``.
    """
    try:
        data = json.loads(response)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = json.loads(extract_json_from_response(response))

    # Only a dict can already carry an "attributes" key. Checking the type
    # first avoids a substring test on JSON strings and a TypeError on numbers.
    if not isinstance(data, dict) or "attributes" not in data:
        data = {"attributes": data}
    return data


def parse_to_cok(response: str):
    """Parse a response into CoK operations as a plain dict.

    The response is first validated with ``CoK.model_validate_json``; on
    success the ``UPDATE`` and ``ADD`` entries are flattened to
    ``{key: {"update": ...}}`` and ``{key: {"add": ...}}``. If validation
    fails, the response is parsed as plain JSON and its ``UPDATE`` / ``ADD``
    members are used as-is when they are dicts.

    Args:
        response: JSON text expected to describe CoK operations.

    Returns:
        A dict with exactly the keys ``"UPDATE"`` and ``"ADD"``, each mapping
        to a dict. Both are empty when nothing usable can be parsed.
    """
    try:
        # NOTE(review): relies on CoK validation errors being ValueError
        # subclasses, as the original except clause did — confirm against schema.
        cok_obj = CoK.model_validate_json(response)
        return {
            "UPDATE": {
                key: {"update": proposed_update.update}
                for key, proposed_update in cok_obj.UPDATE.items()
            },
            "ADD": {
                key: {"add": proposed_add.add}
                for key, proposed_add in cok_obj.ADD.items()
            },
        }
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    # Lenient fallback: accept any JSON object that has UPDATE / ADD members.
    try:
        data = json.loads(response)
    except (ValueError, TypeError):
        return {"UPDATE": {}, "ADD": {}}

    # A JSON array or scalar has no members to read; treat it as empty rather
    # than failing on ``.get``.
    if not isinstance(data, dict):
        return {"UPDATE": {}, "ADD": {}}

    update_ops = data.get("UPDATE")
    add_ops = data.get("ADD")
    # Non-dict sections (e.g. null) are replaced so the result shape is stable.
    return {
        "UPDATE": update_ops if isinstance(update_ops, dict) else {},
        "ADD": add_ops if isinstance(add_ops, dict) else {},
    }


def merge_summary_objects(summaries: List[dict]) -> dict:
    """Combine the ``"attributes"`` of several summaries into one flat dict.

    Each attribute name maps to a list holding every value seen for it, in
    input order. List values are concatenated; any other value is appended as
    a single element.

    Args:
        summaries: Dicts that each have an ``"attributes"`` mapping.

    Returns:
        The merged attribute mapping itself (not wrapped in ``"attributes"``).
    """
    combined: dict = {}

    for entry in summaries:
        for attr_name, attr_value in entry["attributes"].items():
            additions = attr_value if isinstance(attr_value, list) else [attr_value]
            combined.setdefault(attr_name, []).extend(additions)

    return combined


def _cok_attribute_name(jsonpath: str) -> str:
    """Reduce a CoK path such as ``$.attributes.color`` to its attribute name."""
    return jsonpath.strip('"').replace("$.attributes.", "").replace("$.", "")


def _cok_values(payload, field: str) -> list:
    """Unwrap ``payload[field]`` when present and return the values as a new list."""
    values = payload[field] if isinstance(payload, dict) and field in payload else payload
    if isinstance(values, list):
        return list(values)
    return [values] if values else []


def apply_cok_operations(summary: dict, operations: dict) -> dict:
    """Apply CoK operations to a summary dictionary.

    ``UPDATE`` entries append their values to attributes that already exist;
    entries for unknown attributes are skipped. ``ADD`` entries set the
    attribute to the given values, replacing any existing value. Neither
    ``summary`` nor ``operations`` is modified.

    Args:
        summary: Flat mapping of attribute name to values.
        operations: Mapping with optional ``"UPDATE"`` and ``"ADD"`` sections,
            each keyed by a path (``$.attributes.name``, ``$.name`` or
            ``name``). Values may be wrapped as ``{"update": ...}`` /
            ``{"add": ...}`` or given bare.

    Returns:
        A new attribute dict with the operations applied.
    """
    result_attributes = summary.copy()

    # ``or {}`` tolerates a missing or null section.
    for jsonpath, update_data in (operations.get("UPDATE") or {}).items():
        clean_path = _cok_attribute_name(jsonpath)
        if clean_path not in result_attributes:
            continue

        existing = result_attributes[clean_path]
        # Build a fresh list instead of appending in place: the shallow copy
        # above still shares the value lists with the caller's summary.
        # A non-list value (e.g. a lone string) is promoted to a list.
        base = list(existing) if isinstance(existing, list) else [existing]
        result_attributes[clean_path] = base + _cok_values(update_data, "update")

    for jsonpath, add_data in (operations.get("ADD") or {}).items():
        result_attributes[_cok_attribute_name(jsonpath)] = _cok_values(add_data, "add")

    return result_attributes
