"""
Data Utilities for IOA Framework

This module provides utilities for:
- Data preprocessing and formatting
- Synthetic data validation
- Dataset creation for training
- Evaluation metrics computation
"""

import json
import re
import logging
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class SyntheticDataItem:
    """
    Represents a synthetic data item generated by the Adapter.

    This follows the JSON schema from Appendix J.5. Attributes are stored flat
    on the instance; ``to_dict`` / ``from_dict`` translate to and from the
    nested JSON layout, in which the solution fields live under ``"solution"``
    and ``stage_id`` / ``seed_ref`` live under ``"metadata"`` (``seed_ref`` is
    serialized under the key ``"seed_style_ref"``).
    """

    # Knowledge module identifier
    module: str

    # Prerequisites
    prereq: List[str] = field(default_factory=list)

    # Difficulty tag
    difficulty_tag: str = "introductory"

    # Problem statement (analogy -> formal)
    problem: str = ""

    # Solution with steps
    solution_steps: List[str] = field(default_factory=list)
    final_answer: str = ""
    verification: str = ""

    # Adapter flags
    adapter_flags: Dict[str, Any] = field(default_factory=dict)

    # Metadata
    stage_id: str = ""
    seed_ref: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to the nested dictionary layout used for JSON serialization.

        Returns:
            Dictionary with the keys "module", "prereq", "difficulty_tag",
            "problem", "solution", "adapter_flags" and "metadata"
        """
        return {
            "module": self.module,
            "prereq": self.prereq,
            "difficulty_tag": self.difficulty_tag,
            "problem": self.problem,
            "solution": {
                "steps": self.solution_steps,
                "final_answer": self.final_answer,
                "verification": self.verification
            },
            "adapter_flags": self.adapter_flags,
            "metadata": {
                "stage_id": self.stage_id,
                "seed_style_ref": self.seed_ref
            }
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SyntheticDataItem":
        """
        Create an item from the nested dictionary layout produced by ``to_dict``.

        Missing keys fall back to the same defaults as the dataclass fields.
        The nested "solution" / "metadata" sections and the container fields
        "prereq" / "adapter_flags" are also defaulted when they are present
        but null (or otherwise empty), so a JSON ``null`` no longer raises.

        Args:
            data: Dictionary representing a synthetic data item

        Returns:
            New SyntheticDataItem populated from ``data``
        """
        # `or {}` (rather than a .get default) also covers an explicit null,
        # which previously failed on the `.get` calls below.
        solution = data.get("solution") or {}
        metadata = data.get("metadata") or {}

        return cls(
            module=data.get("module", ""),
            prereq=data.get("prereq") or [],
            difficulty_tag=data.get("difficulty_tag", "introductory"),
            problem=data.get("problem", ""),
            solution_steps=solution.get("steps") or [],
            final_answer=solution.get("final_answer", ""),
            verification=solution.get("verification", ""),
            adapter_flags=data.get("adapter_flags") or {},
            stage_id=metadata.get("stage_id", ""),
            seed_ref=metadata.get("seed_style_ref", "")
        )

    def to_training_format(self) -> Dict[str, str]:
        """
        Convert to training format (input, output pair).

        The output is the solution steps joined by newlines, followed by a
        blank line and "Final Answer: <answer>" when a final answer is set.

        Returns:
            Dictionary with 'input' and 'output' keys
        """
        solution_text = "\n".join(self.solution_steps)
        if self.final_answer:
            solution_text += f"\n\nFinal Answer: {self.final_answer}"

        return {
            "input": self.problem,
            "output": solution_text
        }


def validate_synthetic_item(item: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Validate a synthetic data item against the schema.

    Checks performed, in order (the first failure is reported):
    - the item is a dictionary with "module", "problem" and "solution" keys
    - "solution" is a dictionary with "steps", "final_answer" and
      "verification" keys
    - "steps" and "verification" are non-empty
    - "difficulty_tag" is one of "introductory", "intermediate", "advanced"
      (a missing tag is reported as invalid)

    Malformed input is reported through the return value rather than by
    raising, so a batch of LLM-generated items can be screened in a loop.

    Args:
        item: Dictionary representing a synthetic data item

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid
    """
    # Guard the container types first: membership tests and .get() below
    # would otherwise raise on None / strings / lists.
    if not isinstance(item, dict):
        return False, "Item must be a dictionary"

    for key in ("module", "problem", "solution"):
        if key not in item:
            return False, f"Missing required key: {key}"

    solution = item["solution"]
    if not isinstance(solution, dict):
        return False, "Solution must be a dictionary"

    for key in ("steps", "final_answer", "verification"):
        if key not in solution:
            return False, f"Missing solution key: {key}"

    if not solution["steps"]:
        return False, "Solution steps cannot be empty"

    if not solution["verification"]:
        return False, "Verification cannot be empty"

    valid_difficulties = ("introductory", "intermediate", "advanced")
    if item.get("difficulty_tag", "") not in valid_difficulties:
        return False, f"Invalid difficulty tag: {item.get('difficulty_tag')}"

    return True, ""


def _coerce_to_dict_list(data: Any) -> List[Dict[str, Any]]:
    """Normalize a decoded JSON value to a list of objects, dropping anything else."""
    if isinstance(data, dict):
        return [data]
    if isinstance(data, list):
        return [entry for entry in data if isinstance(entry, dict)]
    return []


def _extract_embedded_json(text: str) -> List[Dict[str, Any]]:
    """Decode the first JSON array/object embedded in text that yields at least one object."""
    decoder = json.JSONDecoder()
    # Try every opening bracket in document order. raw_decode stops at the
    # end of the value it parsed, so trailing prose is ignored, and an outer
    # object is found before any array nested inside it.
    for opener in re.finditer(r"[\[{]", text):
        try:
            value, _ = decoder.raw_decode(text, opener.start())
        except json.JSONDecodeError:
            continue
        found = _coerce_to_dict_list(value)
        if found:
            return found
    return []


def parse_llm_json_response(response: str) -> List[Dict[str, Any]]:
    """
    Parse JSON from LLM response, handling common formatting issues.

    A leading "```json" / "```" fence and a trailing "```" fence are removed
    before decoding. If the remaining text is not valid JSON as a whole, a
    warning is logged and the first embedded JSON array or object that
    contains at least one object is used instead.

    The result always honours the declared return type: a single object is
    wrapped in a list, non-object elements of an array are dropped, and
    scalar JSON values produce an empty list.

    Args:
        response: Raw LLM response text

    Returns:
        Parsed JSON as list of dictionaries (empty if nothing usable is found)
    """
    text = response.strip()

    # Remove markdown code fences
    if text.startswith("```json"):
        text = text[7:]
    elif text.startswith("```"):
        text = text[3:]

    if text.endswith("```"):
        text = text[:-3]

    text = text.strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError as e:
        logger.warning(f"JSON parse error: {e}")
        return _extract_embedded_json(text)

    return _coerce_to_dict_list(data)


def compute_rouge_l(prediction: str, reference: str) -> float:
    """
    Compute the ROUGE-L F1 score between a prediction and a reference.

    Both texts are lowercased and split on whitespace; the score is the
    harmonic mean of LCS/len(prediction tokens) and LCS/len(reference tokens),
    where LCS is the length of the longest common token subsequence.
    Used for instruction following evaluation (Appendix D).

    Args:
        prediction: Model prediction
        reference: Ground truth reference

    Returns:
        ROUGE-L F1 score; 0.0 if either text has no tokens or nothing overlaps
    """
    candidate = prediction.lower().split()
    target = reference.lower().split()

    if not (candidate and target):
        return 0.0

    # LCS length by dynamic programming, keeping only the previous table row:
    # previous[col] is the LCS of the candidate tokens consumed so far against
    # the first `col` target tokens.
    previous = [0] * (len(target) + 1)
    for cand_token in candidate:
        current = [0]
        for col, target_token in enumerate(target, start=1):
            if cand_token == target_token:
                current.append(previous[col - 1] + 1)
            else:
                current.append(max(previous[col], current[col - 1]))
        previous = current
    overlap = previous[-1]

    precision = overlap / len(candidate)
    recall = overlap / len(target)

    if precision + recall == 0:
        return 0.0

    return 2 * precision * recall / (precision + recall)


def compute_exact_match(prediction: str, reference: str) -> bool:
    """
    Check whether prediction and reference are equal after normalization.

    Normalization lowercases the text, collapses every run of whitespace to a
    single space (dropping leading/trailing whitespace) and then removes any
    trailing periods. No other punctuation is stripped.

    Args:
        prediction: Model prediction
        reference: Ground truth reference

    Returns:
        True if the normalized strings are identical
    """
    normalized_pred, normalized_ref = (
        " ".join(raw.lower().split()).rstrip(".")
        for raw in (prediction, reference)
    )
    return normalized_pred == normalized_ref


def extract_answer(text: str) -> str:
    """
    Extract the final answer from a solution text.

    Cue phrases are tried in priority order, case-insensitively, and the
    first cue that occurs anywhere in the text wins; the answer is whatever
    follows that cue up to the end of its line:
    - "Final Answer: X"
    - "The answer is X"
    - "Therefore, X"
    - "Thus, X"
    - "= X" (the first "=" in the text)

    If no cue matches, the last line of the stripped text is returned.

    Args:
        text: Full solution text

    Returns:
        Extracted answer with surrounding whitespace removed
    """
    cue_patterns = (
        r"Final Answer:\s*(.+?)(?:\n|$)",
        r"The answer is\s*(.+?)(?:\n|$)",
        r"Therefore,?\s*(.+?)(?:\n|$)",
        r"Thus,?\s*(.+?)(?:\n|$)",
        r"=\s*(.+?)(?:\n|$)",
    )

    # Lazy generator: later cues are not searched once an earlier one matches.
    attempts = (re.search(cue, text, re.IGNORECASE) for cue in cue_patterns)
    first_hit = next((hit for hit in attempts if hit), None)
    if first_hit is not None:
        return first_hit.group(1).strip()

    # No cue found: fall back to the last line ("" for empty input)
    return text.strip().split('\n')[-1].strip()


def format_for_training(
    items: List[SyntheticDataItem],
    include_system_prompt: bool = True
) -> List[Dict[str, Any]]:
    """
    Format synthetic data items as chat-style training examples.

    Each item becomes ``{"messages": [...]}`` holding a user turn (the item's
    training input) and an assistant turn (its training output), optionally
    preceded by a fixed system turn.

    Args:
        items: List of synthetic data items
        include_system_prompt: Whether to prepend the system prompt turn

    Returns:
        List of training examples, one per item, in input order
    """
    system_prompt = (
        "You are a helpful assistant that provides clear, step-by-step solutions. "
        "Always explain your reasoning and verify your answers."
    )

    def build_example(item):
        pair = item.to_training_format()
        turns = []
        if include_system_prompt:
            # A fresh dict per example so examples never share message objects
            turns.append({"role": "system", "content": system_prompt})
        turns.append({"role": "user", "content": pair["input"]})
        turns.append({"role": "assistant", "content": pair["output"]})
        return {"messages": turns}

    return [build_example(item) for item in items]


def save_synthetic_data(
    items: List[SyntheticDataItem],
    filepath: str,
    format: str = "jsonl"
) -> None:
    """
    Write synthetic data items to a UTF-8 file.

    With ``format="jsonl"`` each item's dictionary form is written as one
    JSON document per line. Any other value writes a single JSON array
    indented by two spaces. Non-ASCII characters are written unescaped.

    Args:
        items: List of synthetic data items
        filepath: Output file path (overwritten if it exists)
        format: "jsonl" or "json"
    """
    with open(filepath, 'w', encoding='utf-8') as out:
        if format == "jsonl":
            # Lazy generator: each item is serialized as it is written
            out.writelines(
                json.dumps(item.to_dict(), ensure_ascii=False) + '\n'
                for item in items
            )
        else:
            json.dump(
                [item.to_dict() for item in items],
                out,
                indent=2,
                ensure_ascii=False
            )

    logger.info(f"Saved {len(items)} items to {filepath}")


def load_synthetic_data(filepath: str) -> List[SyntheticDataItem]:
    """
    Read synthetic data items from a UTF-8 file.

    Paths ending in ".jsonl" are read as one JSON document per line, with
    whitespace-only lines skipped. Any other path is read as a single JSON
    document whose elements are each converted to an item.

    Args:
        filepath: Input file path

    Returns:
        List of synthetic data items, in file order
    """
    with open(filepath, 'r', encoding='utf-8') as src:
        if filepath.endswith('.jsonl'):
            # Decode and convert line by line, so the first bad record is
            # the one that raises.
            items = [
                SyntheticDataItem.from_dict(json.loads(line))
                for line in src
                if line.strip()
            ]
        else:
            items = [
                SyntheticDataItem.from_dict(entry)
                for entry in json.load(src)
            ]

    logger.info(f"Loaded {len(items)} items from {filepath}")
    return items


def create_training_dataset(
    synthetic_data: List[SyntheticDataItem],
    tokenizer: Any,
    max_length: int = 2048
) -> Any:
    """
    Build a tokenized HuggingFace Dataset from synthetic data items.

    Items are first converted to chat examples with ``format_for_training``
    (using its default arguments). Each example's messages are flattened into
    one string in which every turn is rendered as a role tag line
    ("<|system|>", "<|user|>", or "<|assistant|>" for any other role)
    followed by the content and a newline. That string is tokenized with
    truncation and padding to ``max_length``, and the "messages" column is
    dropped from the result.

    Args:
        synthetic_data: List of synthetic data items
        tokenizer: HuggingFace tokenizer
        max_length: Maximum sequence length

    Returns:
        HuggingFace Dataset

    Raises:
        ImportError: If the ``datasets`` package is not installed
    """
    try:
        from datasets import Dataset
    except ImportError:
        raise ImportError("Install datasets: pip install datasets")

    def render_and_tokenize(example):
        segments = []
        for message in example["messages"]:
            speaker = message["role"]
            body = message["content"]
            if speaker == "system":
                header = "<|system|>"
            elif speaker == "user":
                header = "<|user|>"
            else:
                header = "<|assistant|>"
            segments.append(f"{header}\n{body}\n")

        return tokenizer(
            "".join(segments),
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )

    chat_examples = format_for_training(synthetic_data)
    return Dataset.from_list(chat_examples).map(
        render_and_tokenize,
        remove_columns=["messages"]
    )


def filter_by_verification(
    items: List[SyntheticDataItem],
    strict: bool = True
) -> List[SyntheticDataItem]:
    """
    Filter synthetic data items by verification status.

    As mentioned in Section 3.4, if obtained results cannot pass
    verification, the reasoning process will be filtered out. Here that
    means: in strict mode, items whose ``verification`` is empty are dropped;
    in non-strict mode every item is kept.

    Args:
        items: List of synthetic data items
        strict: If True, require non-empty verification

    Returns:
        Filtered list of items, in input order
    """
    # Could add more sophisticated verification here
    kept = [item for item in items if not strict or item.verification]

    logger.info(f"Filtered {len(items)} -> {len(kept)} items")
    return kept


def deduplicate_items(
    items: List[SyntheticDataItem],
    similarity_threshold: float = 0.9
) -> List[SyntheticDataItem]:
    """
    Remove duplicate or very similar items.

    As mentioned in Appendix J.2, semantic deduplication maintains
    diversity across items. Items are scanned in order and an item is kept
    only if the ROUGE-L score of its problem text against every previously
    kept item's problem text is below the threshold, so the first of a group
    of near-duplicates is the one that survives.

    Args:
        items: List of synthetic data items
        similarity_threshold: Score at or above which two items count as duplicates

    Returns:
        Deduplicated list, in input order
    """
    if not items:
        return []

    kept = []
    for candidate in items:
        # any() over an empty `kept` is False, so the first item is always
        # retained without a comparison.
        already_covered = any(
            compute_rouge_l(candidate.problem, survivor.problem) >= similarity_threshold
            for survivor in kept
        )
        if not already_covered:
            kept.append(candidate)

    logger.info(f"Deduplicated {len(items)} -> {len(kept)} items")
    return kept


if __name__ == "__main__":
    # Manual smoke test: runs only when this module is executed as a script
    # and prints the results for visual inspection (nothing is asserted).
    
    # ROUGE-L on two sentences that differ in two words
    pred = "The quick brown fox jumps over the lazy dog"
    ref = "The fast brown fox leaps over the lazy dog"
    rouge = compute_rouge_l(pred, ref)
    print(f"ROUGE-L score: {rouge:.3f}")
    
    # JSON parsing of a reply wrapped in a markdown ```json fence, the shape
    # parse_llm_json_response is meant to unwrap
    response = """```json
    [
        {
            "module": "math/algebra",
            "problem": "Solve x + 2 = 5",
            "solution": {
                "steps": ["Step 1: Subtract 2", "Step 2: x = 3"],
                "final_answer": "x = 3",
                "verification": "3 + 2 = 5 ✓"
            },
            "difficulty_tag": "introductory"
        }
    ]
    ```"""
    
    parsed = parse_llm_json_response(response)
    print(f"Parsed {len(parsed)} items")
    
    # Build a SyntheticDataItem from the first parsed record, if any
    if parsed:
        item = SyntheticDataItem.from_dict(parsed[0])
        print(f"Module: {item.module}")
        print(f"Problem: {item.problem}")