
from dataclasses import dataclass
from functools import total_ordering
from typing import List, Any, Optional
from enum import Enum
from dataclasses_json import DataClassJsonMixin
from src.utils import trim_long_string

@dataclass
@total_ordering
class PerformanceMetric(DataClassJsonMixin):
    """Summary statistics (mean/std/min/max) over ``num_trials`` runs.

    Ordering follows "better is greater": ``a < b`` means ``a`` is the worse
    metric. A larger ``mean`` is worse. When the means are exactly equal, a
    larger ``std`` is worse. ``min``, ``max`` and ``num_trials`` are kept for
    reporting only and take no part in ``==`` or in the ordering.
    """

    mean: float
    std: float
    min: float
    max: float
    num_trials: int

    def __eq__(self, other: Any) -> bool:
        # Equality looks only at the ranking keys (mean, std) so that it stays
        # consistent with __lt__.
        if isinstance(other, PerformanceMetric):
            return self.mean == other.mean and self.std == other.std
        return NotImplemented

    def __lt__(self, other: Any) -> bool:
        """Return True when ``self`` is the worse metric of the two.

        The larger mean loses. ``std`` only breaks exact ties on ``mean``,
        and the larger std loses.
        """
        if isinstance(other, PerformanceMetric):
            means_differ = self.mean != other.mean
            return self.mean > other.mean if means_differ else self.std > other.std
        return NotImplemented
        
@dataclass
@total_ordering
class OpVerificationResult(DataClassJsonMixin):
    """Outcome of verifying one op: compile status, correctness and performance.

    Results are totally ordered with "better is greater" semantics:
    ``a < b`` means ``a`` is the worse result, so ``max(results)`` yields the
    best one. Keys are compared in priority order:

    1. ``performance``: a result with a metric beats one without. Two metrics
       are compared using ``PerformanceMetric``'s own ordering.
    2. ``correctness``: ``True`` beats ``False``.
    3. ``compiled``: ``True`` beats ``False``.

    ``==`` looks at exactly those three fields, so it agrees with the
    ordering. ``hardware``, ``device_id`` and the ``*_info`` messages are
    ignored when comparing.
    """

    compiled: bool
    compile_info: str  # free-form message; trimmed when rendered by __str__
    correctness: bool
    correctness_info: str  # free-form message; trimmed when rendered by __str__
    performance: Optional[PerformanceMetric]  # None when no metric is available
    hardware: str
    device_id: int = -1  # presumably -1 means "unspecified" -- confirm with callers

    def __eq__(self, other: Any) -> bool:
        """Compare only the ranking keys (compiled, correctness, performance)."""
        if not isinstance(other, OpVerificationResult):
            return NotImplemented
        return (
            self.compiled == other.compiled
            and self.correctness == other.correctness
            and self.performance == other.performance
        )

    def __lt__(self, other: Any) -> bool:
        """Return True if ``self`` is a worse result than ``other``.

        Keys are checked in the order performance, correctness, compiled.
        The first key that differs decides. If none differ, the results are
        equal and False is returned.
        """
        if not isinstance(other, OpVerificationResult):
            return NotImplemented

        # 1) performance: having a metric beats not having one. When both
        #    sides have one, defer to PerformanceMetric ordering.
        # NOTE(review): performance outranks correctness here. An incorrect
        # result with timing data therefore ranks above a correct one without
        # it. This is only harmless if performance is never recorded for
        # incorrect results -- confirm against the verification pipeline.
        self_has_perf = self.performance is not None
        other_has_perf = other.performance is not None
        if self_has_perf != other_has_perf:
            # Exactly one side has a metric; the side without it is worse.
            return other_has_perf
        if self_has_perf and self.performance != other.performance:
            return self.performance < other.performance

        # 2) correctness: an incorrect result is worse than a correct one.
        if self.correctness != other.correctness:
            return (not self.correctness) and other.correctness

        # 3) compiled: a result that failed to compile is worse.
        if self.compiled != other.compiled:
            return (not self.compiled) and other.compiled
        return False

    def __str__(self) -> str:
        """Render a multi-line, human-readable report of this result.

        Performance figures are shown with 4 significant digits. The
        compile and correctness messages are passed through
        ``trim_long_string``.
        """
        perf = self.performance
        if perf is None:
            perf_block = "None"
        else:
            perf_block = (
                f"mean={perf.mean:.4g}, "
                f"std={perf.std:.4g}, "
                f"min={perf.min:.4g}, "
                f"max={perf.max:.4g}, "
                f"trials={perf.num_trials}"
            )

        compile_info = trim_long_string(self.compile_info)
        correctness_info = trim_long_string(self.correctness_info)

        return (
            "VerificationResult\n"
            "  status:\n"
            f"    compiled:     {self.compiled}\n"
            f"    correctness:  {self.correctness}\n"
            "  performance:\n"
            f"    {perf_block}\n"
            "  environment:\n"
            f"    hardware:     {self.hardware}\n"
            f"    device_id:    {self.device_id}\n"
            "  messages:\n"
            f"    compile_info:     {compile_info}\n"
            f"    correctness_info: {correctness_info}"
        )
    
class VerifyType(str, Enum):
    """String-valued enum of verification types.

    Because it mixes in ``str``, each member compares equal to its raw
    value, e.g. ``VerifyType.MODEL == "model"``.
    """

    MODEL = "model"
    RULE = "rule"