import logging
import os
import traceback
from itertools import chain
from typing import Any, List, Optional

from rich.console import Console

from .eval_utils import set_all_seeds
from .modality import Modality
from .models import BioSeqTransformer
from .tasks.tasks import Task, TaskResult

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DGEB:
    """DGEB class to run the evaluation pipeline."""

    def __init__(self, tasks: List[type[Task]], seed: int = 42):
        """Store the tasks to evaluate and apply the seed.

        Args:
            tasks: Task classes (not instances) to evaluate, in run order.
            seed: Value forwarded to `set_all_seeds`.
        """
        self.tasks = tasks
        set_all_seeds(seed)

    def print_selected_tasks(self):
        """Print the selected tasks."""
        console = Console()
        console.rule("[bold]Selected Tasks\n", style="grey15")
        for task in self.tasks:
            # One bullet per task: display name followed by its type in grey italics.
            prefix = "    - "
            name = f"{task.metadata.display_name}"
            category = f", [italic grey39]{task.metadata.type}[/]"
            console.print(f"{prefix}{name}{category}")
        console.print("\n")

    def run(
        self,
        model,  # type encoder
        output_folder: Optional[str] = "results",
    ) -> List[TaskResult]:
        """Run the evaluation pipeline on the selected tasks.

        Each task class is instantiated and run against `model`. A task that
        raises is logged and skipped, so one failure does not abort the rest.

        Args:
            model: Model to be used for evaluation. Its `hf_name` attribute is
                read to build the output path when results are saved.
            output_folder: Folder where the results will be saved. Defaults to
                'results'. Results are written to
                `{output_folder}/{basename of model.hf_name}/{task id}.json`.
                Pass None or an empty string to skip saving.

        Returns:
            A list of TaskResult objects, one for each task that ran
            successfully.
        """
        # Run selected tasks
        self.print_selected_tasks()
        results = []

        for task in self.tasks:
            logger.info(
                f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
            )

            try:
                result = task().run(model)
            except Exception as e:
                # Best-effort evaluation: report the failure and move on.
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error(f"Error running task {task}")
                continue

            results.append(result)
            if output_folder:
                save_path = get_output_folder(model.hf_name, task, output_folder)
                # Explicit UTF-8 so the JSON file does not depend on the
                # platform's default locale encoding.
                with open(save_path, "w", encoding="utf-8") as f_out:
                    f_out.write(result.model_dump_json(indent=2))
        return results


def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
    """Instantiate the model wrapper that supports `model_name`.

    Searches the direct subclasses of `BioSeqTransformer` for one whose
    `MODEL_NAMES` contains `model_name` and constructs it.

    Args:
        model_name: Name to look up in each subclass's `MODEL_NAMES`.
        **kwargs: Extra keyword arguments forwarded to the subclass constructor.

    Returns:
        An instance of the first matching subclass.

    Raises:
        ValueError: If no subclass lists `model_name`.
    """
    for cls in BioSeqTransformer.__subclasses__():
        if model_name in cls.MODEL_NAMES:
            return cls(model_name, **kwargs)
    # The full list of names is only needed for the error message, so build
    # it on the failure path instead of on every call.
    all_names = get_all_model_names()
    raise ValueError(f"Model {model_name} not found in {all_names}.")


def get_all_model_names() -> List[str]:
    """Return the `MODEL_NAMES` of every direct `BioSeqTransformer` subclass, flattened.

    Names appear in subclass order, then in the order each subclass lists them.
    """
    return [
        name
        for model_cls in BioSeqTransformer.__subclasses__()
        for name in model_cls.MODEL_NAMES
    ]


def get_all_task_names() -> List[str]:
    """Return the metadata id of every task class reported by `get_all_tasks`."""
    task_ids = []
    for task_cls in get_all_tasks():
        task_ids.append(task_cls.metadata.id)
    return task_ids


def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
    """Resolve each task id in `tasks` to its task class, preserving order.

    Raises:
        ValueError: Propagated from `_get_task` when an id is unknown.
    """
    return list(map(_get_task, tasks))


def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
    """Return the task classes whose metadata modality equals `modality`."""
    matching = []
    for candidate in get_all_tasks():
        if candidate.metadata.modality == modality:
            matching.append(candidate)
    return matching


def get_all_tasks() -> List[type[Task]]:
    """Return the direct subclasses of `Task`.

    Only classes that have already been defined when this is called are
    included, and `__subclasses__()` does not descend into subclasses of
    subclasses.
    """
    task_classes = Task.__subclasses__()
    return task_classes


def _get_task(task_name: str) -> type[Task]:
    """Look up a task class by its metadata id.

    Args:
        task_name: Id compared against each task's `metadata.id`.

    Returns:
        The first task class whose id matches.

    Raises:
        ValueError: If no task has that id; the message lists the known ids.
    """
    logger.info(f"Getting task {task_name}")
    found = next(
        (
            candidate
            for candidate in get_all_tasks()
            if candidate.metadata.id == task_name
        ),
        None,
    )
    if found is not None:
        return found

    known_ids = [task.metadata.id for task in get_all_tasks()]
    raise ValueError(
        f"Task {task_name} not found, available tasks are: {known_ids}"
    )


def get_output_folder(
    model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
) -> str:
    """Build the path of the JSON results file for a model/task pair.

    Despite the name, the returned value is a file path:
    `{output_folder}/{basename of model_hf_name}/{task.metadata.id}.json`.

    Args:
        model_hf_name: Model name; only its final path component is used, so
            "org/model" maps to a "model" sub-folder.
        task: Task class whose `metadata.id` names the file.
        output_folder: Root folder for results.
        create: When True, create the model sub-folder if it is missing.

    Returns:
        The path of the results file. The file itself is not created.
    """
    output_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
    if create:
        # exist_ok avoids the check-then-create race when several runs
        # target the same folder at once.
        os.makedirs(output_folder, exist_ok=True)
    return os.path.join(
        output_folder,
        f"{task.metadata.id}.json",
    )
