"""Task abstract class for evaluation and results."""

import logging
from abc import ABC, abstractmethod
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional

import datasets
from pydantic import BaseModel, model_validator

# HACK: prefer the canonical Modality enum from the package's `modality`
# module. When that relative import fails (e.g. this file is loaded outside
# its package, as the leaderboard does), fall back to a local copy below.
try:
    from ..modality import Modality
except Exception:
    # Fallback copy used only to get the leaderboard working.
    # It SHOULD MATCH the definition in modality.py exactly; if the two copies
    # diverge, TaskMetadata.modality may validate differently depending on
    # which one was loaded.
    # NOTE(review): `except Exception` also hides genuine errors raised while
    # executing modality.py; `ImportError` is likely the intended catch.
    # NOTE(review): this Enum import duplicates the module-level one.
    # Could the fallback instead read and exec modality.py? (open question)
    from enum import Enum

    class Modality(Enum):
        """Data modality, either DNA or protein sequence."""

        PROTEIN = "protein"
        DNA = "dna"


# NOTE(review): this configures the root logger (if it has no handlers yet)
# as a side effect of importing this module, which affects every program
# that imports it; applications usually own this call.
logging.basicConfig(level=logging.INFO)

# Closed set of task categories. It is the type of `TaskMetadata.type`, so
# pydantic rejects any string not listed here.
TaskType = Literal[
    "classification",
    "pair_classification",
    "clustering",
    "eds",
    "bigene_mining",
    "retrieval",
]


class TaskMetric(BaseModel):
    """A single named metric value reported for one layer of a task run."""

    # Metric identifier; TaskResult validation compares it against
    # `TaskMetadata.primary_metric_id`.
    id: str
    # Human-readable label for the metric.
    display_name: str
    # Optional longer explanation of the metric.
    description: Optional[str] = None
    # Metric value; 0.0 when not supplied.
    value: float = 0.0


class LayerResult(BaseModel):
    """All metrics computed for one model layer."""

    # Numeric layer index (TaskResult.from_dict derives it with int(layer)).
    layer_number: int
    # Label for the layer (TaskResult.from_dict derives it with str(layer)).
    layer_display_name: str
    # Metrics for this layer; TaskResult validation requires the task's
    # primary metric to be present here.
    metrics: List[TaskMetric]


class DGEBModel(BaseModel):
    """Metadata about the model whose results a TaskResult records."""

    # Model name; presumably a Hugging Face Hub identifier given the field
    # name — confirm against callers.
    hf_name: str
    # Number of layers in the model.
    num_layers: int
    # Total parameter count.
    num_params: int
    # Embedding dimension — presumably the size of the per-sequence
    # embeddings the model produces; confirm.
    embed_dim: int


class Dataset(BaseModel):
    """A Hugging Face dataset reference pinned to a specific revision."""

    # Dataset path/name passed as the first argument to `load_dataset`.
    path: str
    # Revision (e.g. branch, tag, or commit) to load, pinning the data.
    revision: str

    def load(self) -> datasets.DatasetDict:
        """Load this dataset at its pinned revision.

        Returns:
            The loaded `datasets.DatasetDict`.

        Raises:
            ValueError: if `datasets.load_dataset` returns anything other than
                a `datasets.DatasetDict`.
        """
        loaded = datasets.load_dataset(self.path, revision=self.revision)
        if isinstance(loaded, datasets.DatasetDict):
            return loaded
        raise ValueError(
            f"Dataset {self.path} is not a datasets.DatasetDict object."
        )


class TaskMetadata(BaseModel):
    """Static description of an evaluation task."""

    # Task identifier.
    id: str
    # Human-readable task name.
    display_name: str
    # Free-text description of the task.
    description: str
    # Sequence modality the task operates on.
    modality: Modality
    # Task category; must be one of the TaskType literals.
    type: TaskType
    # Datasets used by the task. Each entry is a `Dataset` (a path plus a
    # pinned revision); its `load()` method calls `datasets.load_dataset()`.
    datasets: List[Dataset]
    # Id of the metric that every layer's results must include
    # (enforced by TaskResult validation).
    primary_metric_id: str


# tasks.py
class TaskResult(BaseModel):
    """Results of running one task with one model, broken down by layer."""

    # Installed version of the `dgeb` package that produced these results.
    dgeb_version: str
    task: TaskMetadata
    # TODO: Convert model to ModelMetadata
    model: DGEBModel
    # One entry per evaluated layer.
    results: List[LayerResult]

    @model_validator(mode="after")
    def check_valid_primary_metric(self):
        """Ensure every layer reports the task's primary metric.

        Raises:
            ValueError: if some layer has no metric whose id equals
                ``task.primary_metric_id`` (pydantic surfaces this as a
                ValidationError).
        """
        primary_id = self.task.primary_metric_id
        for result in self.results:
            if not any(metric.id == primary_id for metric in result.metrics):
                raise ValueError(
                    f"Primary metric {primary_id} not found in results.metrics"
                )
        return self

    @classmethod
    def from_dict(
        cls,
        task_metadata: TaskMetadata,
        layer_results: Dict[str, Any],
        model_metadata: DGEBModel,
    ) -> "TaskResult":
        """Build a validated result from a raw per-layer metrics mapping.

        Args:
            task_metadata: metadata of the task that was run.
            layer_results: mapping with a ``"layers"`` key whose value maps
                each layer key (anything accepted by ``int()``) to a
                ``{metric_name: value}`` mapping.
            model_metadata: metadata of the evaluated model.

        Returns:
            A TaskResult stamped with the installed ``dgeb`` version.

        Raises:
            KeyError: if ``layer_results`` has no ``"layers"`` key.
            importlib.metadata.PackageNotFoundError: if ``dgeb`` is not
                installed.
        """
        return cls(
            dgeb_version=version("dgeb"),
            task=task_metadata,
            model=model_metadata,
            results=[
                LayerResult(
                    layer_number=int(layer),
                    layer_display_name=str(layer),
                    metrics=[
                        # The raw metric name doubles as id and display name.
                        TaskMetric(id=metric, display_name=metric, value=value)
                        for metric, value in metrics.items()
                    ],
                )
                for layer, metrics in layer_results["layers"].items()
            ],
        )


# move to model.py?
class Task(ABC):
    """Abstract base class for an evaluation task.

    Subclasses provide ``metadata`` describing the task and implement ``run``.
    """

    metadata: TaskMetadata

    # `model` is typed as Any rather than "BioSeqTransformer" so the
    # leaderboard can import this module without installing all model deps.
    @abstractmethod
    def run(self, model: Any, layers: Optional[List[int]] = None) -> TaskResult:
        """Evaluate ``model`` on this task.

        Args:
            model: the model to evaluate.
            layers: indices of the layers to evaluate; None presumably means
                the implementation's default selection — confirm per subclass.

        Returns:
            The TaskResult for the evaluated layers.
        """
