from typing import Tuple, Dict, List, Optional, TypedDict
from dataclasses import dataclass, field  # Add field

from tasks.base import TaskMetric


@dataclass
class ArchiveData:
    """Per-model evaluation record.

    ``quality``, ``model_path``, ``sampling_freq`` and ``validation_quality``
    are required; the remaining fields default to ``None``.
    """

    # Model performance on the task.
    quality: float
    # Path to the model which was evaluated.
    model_path: str
    # Number of times the model was sampled.
    sampling_freq: int
    # Model performance on the validation set.
    validation_quality: float
    # Overall fitness of the model across multiple tasks (and their examples).
    overall_fitness: Optional[float] = None
    # Binary vector of examples passed/failed.
    skill_vector: Optional[List[bool]] = None
    # Maps example_id to whether it was solved correctly.
    example_results: Optional[Dict[int, bool]] = None


@dataclass
class ModelEvalResult:
    """Result of evaluating one model: its path plus its task metrics."""

    # Path to the model which was evaluated.
    model_path: str
    # Task metrics, keyed by a string (presumably the task name -- confirm
    # against the code that builds this mapping).
    task_metrics: Dict[str, TaskMetric]


@dataclass
class QDInfo:
    """Per-task QD information: a quality score plus behavior ids."""

    # Name of the task.
    task_name: str
    # Model performance on the task.
    quality: float
    # Model behavior characterization ids. Annotated as a variable-length
    # tuple: ``Tuple[int]`` would describe a tuple of exactly one int.
    bc_ids: Tuple[int, ...]


@dataclass
class MergeResult:
    """Outcome of a merge: QD information and where the model was saved."""

    # QD information, keyed by a string (presumably the task name -- confirm
    # against the code that builds this mapping).
    qd_info: Dict[str, QDInfo]
    # Path to the saved model.
    save_path: str


@dataclass
class LSMergeResult:
    """Merge outcome carrying QD info, task metrics and validation quality.

    ``validation_quality`` is optional and defaults to ``None``.
    """

    # QD information.
    qd_info: Dict[str, QDInfo]
    # Path to the saved model.
    save_path: str
    # Task metrics.
    task_metrics: Dict[str, TaskMetric]
    # Model performance on validation data.
    validation_quality: Optional[float] = None


@dataclass
class ACDTaskEvalDetail:
    """Detailed record of one ACD task evaluation.

    All four fields are required.
    """

    # Identifier of the ACD task.
    task_id: str
    # Instructions associated with the task.
    instructions: str
    # Raw output recorded for the task.
    raw_output: str
    # Score, stored alongside the raw output for context.
    score: float


@dataclass
class ACDDNSMergeResult:
    """Merge outcome that also carries ACD evaluation results.

    Only ``save_path`` is required; every other field has a default.
    """

    # Path to the saved model.
    save_path: str
    # Metrics for standard tasks.
    task_metrics: Optional[Dict[str, TaskMetric]] = None
    # Skill vector for ACD tasks: {task_id: score}.
    acd_skill_vector: Optional[Dict[str, float]] = None
    # Average quality across evaluated ACD tasks.
    avg_acd_quality: Optional[float] = None
    # Detailed ACD eval results (optional list of details). Excluded from the
    # generated repr because the details can be long.
    acd_eval_details: Optional[List[ACDTaskEvalDetail]] = field(
        default=None,
        repr=False,
    )
    # Whether the model returns gibberish or not.
    is_gibberish: bool = False


@dataclass
class DNSSolution:
    """A single solution held in the DNS archive.

    ``rank`` and ``validation_quality`` are optional and default to ``None``.
    """

    # Path to the model.
    model_path: str
    # Overall fitness (accuracy across all tasks).
    fitness: float
    # One boolean per entry of the solution's skill vector.
    skill_vector: List[bool]
    # Domination rank (computed during sorting).
    rank: Optional[int] = None
    # Validation quality (accuracy on validation set).
    validation_quality: Optional[float] = None


@dataclass
class ACDDNSSolution:
    """A DNS archive solution that carries ACD task results.

    Only ``model_path`` and ``fitness`` are required; every other field has
    a default.
    """

    # Path to the model.
    model_path: str
    # Overall fitness (accuracy across all tasks).
    fitness: float
    # Skill vector for ACD tasks: {task_id: score}.
    acd_skill_vector: Optional[Dict[str, float]] = None
    # Domination rank (computed during sorting).
    rank: Optional[int] = None
    # Validation quality (accuracy on validation set).
    validation_quality: Optional[float] = None
    # Detailed ACD eval results (optional list of details). Excluded from the
    # generated repr because the details can be long.
    acd_eval_details: Optional[List[ACDTaskEvalDetail]] = field(
        default=None,
        repr=False,
    )
    # Whether the model returns gibberish or not.
    is_gibberish: bool = False


@dataclass
class DNSArchive:
    """Archive for DNS that maintains a fixed population size.

    After construction, ``solutions`` is replaced by a new list ordered by
    fitness, highest first.

    NOTE(review): nothing in this class trims ``solutions`` to ``max_size``;
    presumably the code that owns the archive enforces it -- confirm.
    """

    # List of current solutions.
    solutions: List[DNSSolution]
    # Maximum archive size.
    max_size: int
    # Distance scaling parameter for domination.
    w: float
    # Whether to adaptively tune w.
    adaptive_w: bool
    # Threshold for fitness comparison (90%).
    fitness_threshold: float = 0.9
    # Threshold for example coverage (95%).
    coverage_threshold: float = 0.95

    def __post_init__(self):
        """Rebind ``solutions`` to a copy sorted by descending fitness."""

        def fitness_of(solution):
            return solution.fitness

        # Sort a copy so the list object passed in by the caller is left
        # untouched; the sort is stable, so ties keep their incoming order.
        ranked = list(self.solutions)
        ranked.sort(key=fitness_of, reverse=True)
        self.solutions = ranked


class ACDDNSArchiveData(TypedDict):
    """Typed dictionary bundling ACD DNS solutions with a ``dirs`` mapping."""

    # The ACD DNS solutions.
    dns_archive: List[ACDDNSSolution]
    # String-to-string mapping; presumably named directory paths -- confirm
    # against the code that builds or reads this dictionary.
    dirs: Dict[str, str]