"""
HuggingFace LLM Wrapper for Robot Ontology

Provides a simple interface for local LLM inference with JSON-only output.
Falls back to DummyLLM if model loading fails.
"""

import json
import re
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional


class BaseLLM(ABC):
    """
    Common interface for every LLM backend in this module.

    A backend turns a (system prompt, user prompt) pair into a parsed,
    JSON-derived Python value. Subclasses must override ``generate_json``.
    """

    @abstractmethod
    def generate_json(
        self,
        system_prompt: str,
        user_prompt: str,
        max_new_tokens: int = 512
    ) -> Dict[str, Any]:
        """
        Produce a JSON response for the given prompts.

        Args:
            system_prompt: Instructions that frame the task.
            user_prompt: The concrete request.
            max_new_tokens: Generation budget; a backend may ignore it.

        Returns:
            The parsed JSON response.
        """
        ...


class DummyLLM(BaseLLM):
    """
    Stand-in backend that never proposes anything.

    Returned by ``create_llm`` when no model name is given or when loading a
    real model fails, so callers always get the same result shape.
    """

    def __init__(self):
        self.name = "DummyLLM"

    def generate_json(
        self,
        system_prompt: str,
        user_prompt: str,
        max_new_tokens: int = 512
    ) -> Dict[str, Any]:
        """
        Return a result with every candidate list empty.

        The prompt arguments are accepted only for interface compatibility
        and are ignored. A new dict (with new lists) is built on each call,
        so callers may mutate the result freely.
        """
        candidate_keys = ("proposed", "equivalences", "decompositions", "merges")
        empty_result = {key: [] for key in candidate_keys}
        empty_result["_meta"] = {
            "model": "DummyLLM",
            "note": "No real LLM available, returning empty candidates",
        }
        return empty_result


# One complete JSON string literal, escape-aware. Written as an "unrolled
# loop" so matching stays linear on long strings. The repairs below skip
# these spans so that values like "a, b: c" are never rewritten.
_JSON_STRING_RE = re.compile(r'"[^"\\]*(?:\\.[^"\\]*)*"', re.DOTALL)
# A comma followed only by whitespace and then a closing brace/bracket.
_TRAILING_COMMA_RE = re.compile(r',\s*([}\]])')
# A bare identifier used as an object key right after '{' or ','.
_UNQUOTED_KEY_RE = re.compile(r'([{,])\s*([A-Za-z_][A-Za-z0-9_]*)\s*:')
# The body of a ```json ... ``` (or bare ```) fenced block: an object or array.
_FENCED_JSON_RE = re.compile(r'```(?:json)?\s*([\[{].*?[\]}])\s*```', re.DOTALL)


def _fix_outside_strings(text: str) -> str:
    """Apply the trailing-comma and unquoted-key fixes only outside string literals."""
    def fix(segment: str) -> str:
        segment = _TRAILING_COMMA_RE.sub(r'\1', segment)
        return _UNQUOTED_KEY_RE.sub(r'\1"\2":', segment)

    pieces = []
    cursor = 0
    for match in _JSON_STRING_RE.finditer(text):
        pieces.append(fix(text[cursor:match.start()]))
        pieces.append(match.group(0))  # string literal kept verbatim
        cursor = match.end()
    pieces.append(fix(text[cursor:]))
    return "".join(pieces)


def repair_json(text: str) -> str:
    """
    Attempt to repair common JSON issues in LLM output.

    Repairs:
    - Trim before the first '{' and after the last '}'. If there is no usable
      object span, use the first '[' and the last ']' instead.
    - Remove trailing commas before '}' or ']'.
    - Quote bare identifier keys (basic only: ``{key: 1}`` -> ``{"key": 1}``).

    The comma and key fixes are applied only outside JSON string literals, so
    string contents such as ``"a, b: c"`` or ``"x, }"`` are preserved.

    Args:
        text: Raw model output.

    Returns:
        The repaired candidate string, or ``text`` unchanged when no bracketed
        span is found. The result is not guaranteed to be valid JSON.
    """
    start, end = text.find('{'), text.rfind('}')
    # `end < start` also covers a '}' that only occurs before the first '{',
    # which would otherwise produce an empty slice.
    if start == -1 or end < start:
        start, end = text.find('['), text.rfind(']')
    if start == -1 or end < start:
        return text

    return _fix_outside_strings(text[start:end + 1])


def parse_json_safely(text: str) -> Dict[str, Any]:
    """
    Parse JSON from LLM output, trying progressively more lenient strategies.

    Strategies, in order:
    1. the raw text;
    2. ``repair_json(text)``;
    3. the contents of a markdown code fence, first raw, then repaired.

    A strategy counts as successful only if it yields a JSON object or array.
    Bare scalars such as ``42`` or ``null`` are treated as failures, because
    the declared result is a dict.

    Args:
        text: Raw model output.

    Returns:
        The parsed object (a top-level array is passed through as-is), or
        ``{"error": "parse_failed", "raw": <first 500 chars of text>}`` if
        every strategy fails.
    """
    def candidates():
        yield text
        yield repair_json(text)
        # Only reached if the attempts above failed, e.g. when braces in
        # surrounding prose defeat the first-'{'/last-'}' trimming.
        fenced = _FENCED_JSON_RE.search(text)
        if fenced:
            yield fenced.group(1)
            yield repair_json(fenced.group(1))

    for candidate in candidates():
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, (dict, list)):
            return parsed

    # All attempts failed
    return {
        "error": "parse_failed",
        "raw": text[:500]  # Truncate for safety
    }


class HFLLM(BaseLLM):
    """
    HuggingFace LLM wrapper using transformers.

    Supports local models via AutoModelForCausalLM or pipeline.
    Enforces JSON-only output with repair attempts.
    """

    # Appended to every system prompt to push the model toward bare JSON.
    _JSON_INSTRUCTION = (
        "IMPORTANT: Return ONLY valid JSON. No explanations, no markdown, "
        "just the JSON object."
    )

    def __init__(
        self,
        model_name_or_path: str,
        device: str = "auto",
        use_pipeline: bool = True
    ):
        """
        Initialize the LLM and load its weights immediately.

        Args:
            model_name_or_path: HuggingFace model name or local path
            device: Device to use ("auto", "cuda", "cpu"); forwarded to
                transformers as ``device_map``
            use_pipeline: Use transformers pipeline (simpler) vs raw model

        Raises:
            ImportError: if transformers or torch cannot be imported.
            RuntimeError: if the model cannot be loaded.
        """
        self.model_name = model_name_or_path
        self.device = device
        self.use_pipeline = use_pipeline
        self.model = None
        self.tokenizer = None
        self.pipe = None

        self._load_model()

    def _load_model(self):
        """
        Load either a text-generation pipeline or a model/tokenizer pair.

        Sets ``self.pipe`` when ``use_pipeline`` is true; otherwise sets
        ``self.tokenizer`` and ``self.model``.

        Raises:
            ImportError: if transformers or torch cannot be imported.
            RuntimeError: on any other failure, chained to the original error.
        """
        # The ImportError translation covers only the imports themselves.
        # Model loading can raise ImportError for *other* optional packages
        # (e.g. accelerate, which device_map needs), and those must not be
        # reported as "transformers/torch not installed".
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import torch
        except ImportError as e:
            raise ImportError(
                f"transformers/torch not installed. Install with: pip install transformers torch\n"
                f"Original error: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self.model_name}: {e}") from e

        try:
            # Use half precision only when running on a GPU. An explicit CPU
            # request is honoured even if CUDA exists, because fp16 on CPU is
            # slow or unsupported for many ops.
            wants_cpu = str(self.device).lower().startswith("cpu")
            use_fp16 = torch.cuda.is_available() and not wants_cpu
            load_kwargs = dict(
                device_map=self.device,
                torch_dtype=torch.float16 if use_fp16 else torch.float32,
                # SECURITY: trust_remote_code=True executes Python code shipped
                # with the model repository. Only load models you trust.
                trust_remote_code=True,
            )

            if self.use_pipeline:
                # Use pipeline for simplicity
                self.pipe = pipeline(
                    "text-generation",
                    model=self.model_name,
                    **load_kwargs
                )
            else:
                # Load model and tokenizer separately
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    **load_kwargs
                )
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self.model_name}: {e}") from e

    def _build_prompt(self, system_prompt: str, user_prompt: str) -> str:
        """Format the system and user prompts into one completion-style prompt string."""
        name = self.model_name.lower()
        # NOTE(review): these hand-written role markers only approximate the
        # models' real chat formats. Consider tokenizer.apply_chat_template.
        if "llama" in name or "qwen" in name:
            return (
                f"### System:\n{system_prompt}\n\n{self._JSON_INSTRUCTION}\n\n"
                f"### User:\n{user_prompt}\n\n"
                "### Assistant:\n"
            )
        return (
            f"<|system|>\n{system_prompt}\n\n{self._JSON_INSTRUCTION}\n</s>\n"
            f"<|user|>\n{user_prompt}\n</s>\n"
            "<|assistant|>\n"
        )

    def generate_json(
        self,
        system_prompt: str,
        user_prompt: str,
        max_new_tokens: int = 512
    ) -> Dict[str, Any]:
        """
        Generate a JSON response.

        Builds a prompt that tells the model to return only valid JSON, samples
        a completion, and parses/repairs it with ``parse_json_safely``.

        Returns:
            The parsed JSON, a ``{"error": "parse_failed", ...}`` dict, or
            ``{"error": "generation_failed", "message": ...}`` if generation
            raised. This method does not raise on generation errors.
        """
        full_prompt = self._build_prompt(system_prompt, user_prompt)

        try:
            if self.pipe is not None:
                result = self.pipe(
                    full_prompt,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    return_full_text=False
                )
                generated_text = result[0]['generated_text']
            else:
                inputs = self.tokenizer(full_prompt, return_tensors="pt")
                inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )

                # Decode only the newly generated tokens, not the echoed prompt.
                generated_text = self.tokenizer.decode(
                    outputs[0][inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True
                )

            return parse_json_safely(generated_text)

        except Exception as e:
            # Deliberate best-effort: report the failure as data, not an exception.
            return {
                "error": "generation_failed",
                "message": str(e)
            }


def create_llm(
    model_name: Optional[str] = None,
    device: str = "auto",
    fallback_to_dummy: bool = True,
    use_pipeline: bool = True,
) -> BaseLLM:
    """
    Factory function to create an LLM instance.

    Args:
        model_name: Model name or path (None or empty for DummyLLM)
        device: Device to use, forwarded to HFLLM
        fallback_to_dummy: If True, return DummyLLM when loading fails
            (a RuntimeWarning is emitted); if False, re-raise the error
        use_pipeline: Forwarded to HFLLM; use a transformers pipeline (True)
            or a raw model/tokenizer pair (False)

    Returns:
        BaseLLM instance
    """
    if not model_name:
        return DummyLLM()

    try:
        return HFLLM(model_name, device, use_pipeline=use_pipeline)
    except Exception as e:
        if not fallback_to_dummy:
            raise
        # Report through the warnings machinery (stderr, filterable) rather
        # than print(), so the notice does not mix into the program's stdout.
        warnings.warn(
            f"Failed to load LLM ({e}), using DummyLLM",
            RuntimeWarning,
            stacklevel=2,
        )
        return DummyLLM()


# Manual smoke test: run this module directly to exercise the fallback
# backend and the tolerant JSON parser.
if __name__ == "__main__":
    # The fallback backend should always return empty candidate lists.
    dummy = DummyLLM()
    dummy_output = dummy.generate_json("Test", "Test")
    print(f"DummyLLM result: {dummy_output}")

    # Typical shapes of LLM output that the parser should cope with.
    samples = (
        '{"key": "value"}',                           # already valid
        '  {"key": "value",}  ',                      # trailing comma
        '```json\n{"key": "value"}\n```',             # markdown fence
        'Here is the JSON: {"key": "value"} done',    # embedded in prose
    )
    for sample in samples:
        print(f"Input: {sample[:30]}... -> Parsed: {parse_json_safely(sample)}")
