"""
Formalization data model.
"""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from .proof_attempt import ProofAttempt

if TYPE_CHECKING:
    # Needed only for the 'Lemma' forward reference in annotations; kept out of
    # the runtime import graph (presumably to avoid a circular import).
    from .lemma import Lemma

# Effective parameters for formalization model (goedel-prover-v2-8b).
# Formalization.get_cost multiplies a token count by this factor to derive an
# 'input_sflops' / 'output_sflops' value when detailed_cost has none recorded.
# NOTE(review): the value 8.0 presumably means billions of parameters - confirm.
FORMALIZATION_EFFECTIVE_PARAMS = 8.0


@dataclass
class Formalization:
    """Represents a formalization of a lemma/theorem generated by the formalizer model.

    A lemma/theorem can have multiple formalizations (different samples from the formalizer).
    Each formalization contains the Lean code, compilation/validation status, and proof attempts.

    Serialization note: ``to_dict`` / ``from_dict`` handle the flat fields only.
    ``proof_attempts``, ``proof_attempts_by_round``, ``compilation_result`` and
    ``compilation_errors`` are neither written by ``to_dict`` nor read by
    ``from_dict``.
    """
    id: int = 0  # Formalization index/sample number
    compilation_pass: bool = False  # Whether formalization compiled successfully
    validation_pass: bool = False  # Whether validation passed
    is_selected: bool = False  # Whether this formalization is in selected_formalizations (was kept)
    detailed_cost: Optional[Dict[str, Any]] = None  # Cost from formalization LLM call

    # Lazy-loaded fields (Optional for skeleton mode)
    formal_statement: Optional[str] = None  # The Lean code
    formalization_reasoning: Optional[str] = None  # LLM reasoning for the formalization
    compilation_result: Optional[str] = None  # Compilation output/result
    compilation_errors: Optional[str] = None  # Compilation error messages
    validation_result: Optional[str] = None  # Validation output/result (raw response with reasoning)
    validation_reasoning: Optional[str] = None  # Extracted validation reasoning

    proof_attempts: List[ProofAttempt] = field(default_factory=list)
    proof_attempts_by_round: Optional[Dict[int, List[ProofAttempt]]] = None  # {correction_round_id: [attempts]}

    def get_best_attempt(self, lemmas_dict: Optional[Dict[int, 'Lemma']] = None) -> Optional[ProofAttempt]:
        """Returns the preferred passing proof attempt, or None if none passes.

        Only attempts whose ``is_passing()`` is true are considered. Among
        those, the first match in ``proof_attempts`` order wins at each level:

        1. (Only when ``lemmas_dict`` is given) attempts that use at least one
           lemma and whose used lemmas are all present in ``lemmas_dict`` and
           report ``is_solved()``.
        2. Attempts that use no lemmas (direct proofs). An attempt with
           ``used_lemma_ids=None`` is treated as using no lemmas and therefore
           qualifies here.
        3. Attempts with lemma dependencies that are not all solved.

        Args:
            lemmas_dict: Optional dict of lemmas (keyed by lemma id) used to
                check whether an attempt's dependencies are solved. When
                omitted, priority 1 is skipped entirely.

        Returns:
            The selected attempt, or None when no attempt is passing.
        """
        passing_attempts = [attempt for attempt in self.proof_attempts if attempt.is_passing()]

        if not passing_attempts:
            return None

        # Priority 1: all used lemmas are solved AND at least one lemma is used.
        # Prefer proofs that use the breakdown structure.
        if lemmas_dict is not None:

            def lemma_is_solved(lemma_id) -> bool:
                # Compare against None explicitly so that a Lemma object which
                # happens to be falsy is not mistaken for a missing entry.
                lemma = lemmas_dict.get(lemma_id)
                return lemma is not None and lemma.is_solved()

            for attempt in passing_attempts:
                # Treat None as empty set for this check
                if attempt.used_lemma_ids is None:
                    continue
                used_lemmas = attempt.get_used_lemmas(lemmas_dict=lemmas_dict, recursive=False)
                if used_lemmas and all(lemma_is_solved(lid) for lid in used_lemmas):
                    return attempt

        # Priority 2: direct proofs with no lemmas (None or empty set).
        # Only reached if no fully-solved lemma-based proof exists.
        for attempt in passing_attempts:
            if attempt.used_lemma_ids is None:
                return attempt
            if not attempt.get_used_lemmas(lemmas_dict=None, recursive=False):
                return attempt

        # Priority 3: passing attempts with unsolved lemma dependencies.
        # Better than nothing, but not ideal.
        for attempt in passing_attempts:
            if attempt.used_lemma_ids is not None:
                return attempt

        # Defensive fallback: every passing attempt should already have been
        # matched by priority 2 or 3 above.
        return passing_attempts[0]

    def is_proven(self) -> bool:
        """Returns True if at least one proof attempt is passing.

        Delegates to ``get_best_attempt()`` without a lemma dict, which returns
        a non-None attempt exactly when some attempt reports ``is_passing()``.
        Whether lemma dependencies are solved is NOT checked here.
        """
        return self.get_best_attempt() is not None

    def get_cost(self, cost_type: str) -> float:
        """Returns the cost of the formalization LLM call itself (no proof attempts).

        For sflops cost types, calculates from tokens * effective_params if not pre-computed.
        Formalization uses goedel-prover-v2-8b with effective_params = 8.

        Args:
            cost_type: Key to look up in ``detailed_cost`` (e.g. 'cost',
                'input_tokens', 'output_sflops'). 'prover_calls' always
                yields 0.0 for a formalization.

        Returns:
            The recorded value as a float; 0.0 when ``detailed_cost`` is
            missing/empty, or when the requested entry is absent or None and
            cannot be derived from token counts.
        """
        if cost_type == 'prover_calls':
            return 0.0

        if not self.detailed_cost:
            return 0.0

        # A key that is present but None (e.g. a JSON null) is treated the same
        # as a missing key instead of crashing in float()/multiplication.
        recorded = self.detailed_cost.get(cost_type)
        if recorded is not None:
            return float(recorded)

        # sflops requested but not recorded: derive it from the token count.
        if cost_type in ('input_sflops', 'output_sflops'):
            tokens_key = 'input_tokens' if cost_type == 'input_sflops' else 'output_tokens'
            tokens = self.detailed_cost.get(tokens_key) or 0
            return float(tokens * FORMALIZATION_EFFECTIVE_PARAMS)

        return 0.0

    def get_total_cost(self, cost_type: str = "cost", exclude_prover_calls: bool = False) -> float:
        """Returns the sum of costs from formalization and all proof attempts.

        Args:
            cost_type: Type of cost to return (e.g., 'cost', 'output_sflops', 'input_tokens')
            exclude_prover_calls: Forwarded unchanged to every attempt's
                ``get_total_cost``; it does not affect the formalization's own
                cost, which is always included.
                NOTE(review): earlier documentation claimed that True drops
                proof attempt costs entirely, but attempt costs are still
                summed here - the actual effect depends on
                ``ProofAttempt.get_total_cost``; confirm the intended meaning.
        """
        formalization_cost = self.get_cost(cost_type)
        attempt_cost = sum(
            attempt.get_total_cost(cost_type, exclude_prover_calls=exclude_prover_calls)
            for attempt in self.proof_attempts
        )
        return formalization_cost + attempt_cost

    def to_dict(self, origin_problem_id: str, round_id: int, breakdown_id: int, lemma_id: int,
                minified: bool = False) -> Dict[str, Any]:
        """Convert to dictionary representation (flat, no nested proof_attempts).

        The identifying arguments are stored under the "metadata" key together
        with this formalization's ``id`` (as "formalization_id").

        Args:
            origin_problem_id: Origin problem ID
            round_id: Round ID
            breakdown_id: Breakdown ID
            lemma_id: Lemma ID (-1 for theorem)
            minified: If True, exclude the "formalization_reasoning" and
                "validation_reasoning" keys ("validation_result" is always kept)

        Returns:
            A new dict; ``detailed_cost`` is referenced, not copied.
        """
        result = {
            "metadata": {
                "origin_problem_id": origin_problem_id,
                "round_id": round_id,
                "breakdown_id": breakdown_id,
                "lemma_id": lemma_id,
                "formalization_id": self.id
            },
            "compilation_pass": self.compilation_pass,
            "validation_pass": self.validation_pass,
            "is_selected": self.is_selected,
            "detailed_cost": self.detailed_cost,
            "formal_statement": self.formal_statement,
            "validation_result": self.validation_result
        }

        if not minified:
            result["formalization_reasoning"] = self.formalization_reasoning
            result["validation_reasoning"] = self.validation_reasoning

        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Formalization':
        """Reconstruct Formalization from dictionary representation (without proof_attempts).

        The id is read from ``data["metadata"]["formalization_id"]``, falling
        back to a top-level ``"id"`` key and finally to 0. Every other missing
        key falls back to the corresponding field default.

        Args:
            data: Dict in the shape produced by ``to_dict``.

        Returns:
            A new Formalization with an empty ``proof_attempts`` list.
        """
        # "metadata" may be absent or explicitly null; both mean "no metadata".
        metadata = data.get("metadata") or {}

        return cls(
            id=metadata.get("formalization_id", data.get("id", 0)),
            compilation_pass=data.get("compilation_pass", False),
            validation_pass=data.get("validation_pass", False),
            is_selected=data.get("is_selected", False),
            detailed_cost=data.get("detailed_cost"),
            formal_statement=data.get("formal_statement"),
            formalization_reasoning=data.get("formalization_reasoning"),
            validation_result=data.get("validation_result"),
            validation_reasoning=data.get("validation_reasoning"),
            proof_attempts=[]  # Proof attempts loaded separately
        )
