from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict
import os

from lmms_eval.tasks import initialize_tasks, include_path, get_task_dict, ConfigurableTask
from lmms_eval.api.registry import ALL_TASKS

import torch.distributed as dist


def rank0_print(*args):
    """Print ``args`` once per job.

    Without an initialized process group every caller prints (plain
    ``print``). With one, only rank 0 prints, prefixed by its rank;
    all other ranks stay silent.
    """
    if not dist.is_initialized():
        print(*args)
        return
    rank = dist.get_rank()
    if rank == 0:
        print(f"Rank {rank}: ", *args)


class BaseEmbedder(ABC):
    """Abstract base for embedders that work on lmms_eval tasks.

    Subclasses implement :meth:`embed_task`. This base class creates the
    output directory and offers :meth:`init_task`, which loads a task and
    collects its request instances.
    """

    def __init__(self, name: str, output_path: str) -> None:
        """Store configuration, create the output directory and init tasks.

        Args:
            name: Identifier for this embedder (kept as ``self.name``).
            output_path: Output directory; created if it does not exist.
        """
        super().__init__()
        self.name = name
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)
        # NOTE(review): presumably registers the lmms_eval tasks that
        # get_task_dict later resolves by name — confirm.
        initialize_tasks()

    def flatten(self, input):
        """Flatten one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``.

        Args:
            input: An iterable of iterables.

        Returns:
            A new list with the inner items in order.
        """
        return [item for sublist in input for item in sublist]

    @staticmethod
    def init_task(
        task: str,
        ignored_ids: Optional[Union[set, List]] = None,
        model_name: str = "llava",
    ):
        """Load an lmms_eval task and collect its request instances.

        Test docs are used when the task has them, otherwise validation docs.

        Args:
            task: Task name passed to ``get_task_dict``.
            ignored_ids: Optional collection of doc ids whose requests are
                skipped.
            model_name: Model name forwarded to ``get_task_dict``
                (defaults to ``"llava"``, the previous hard-coded value).

        Returns:
            Tuple ``(dataset_path, dataset_name, split, requests, task_obj,
            docs)``, where ``requests`` are the task's built instances minus
            those whose doc id is in ``ignored_ids``.
        """
        task_dict = get_task_dict([task], model_name=model_name)
        task_obj = task_dict[task]
        # Some entries come back as a (group, task) pair; keep only the task.
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
        dataset_path = task_obj.DATASET_PATH
        dataset_name = task_obj.DATASET_NAME

        use_test = task_obj.has_test_docs()
        docs = task_obj.test_docs() if use_test else task_obj.validation_docs()
        split = task_obj.config.test_split if use_test else task_obj.config.validation_split
        # Report the doc count for whichever split was chosen (previously the
        # validation branch printed the docs object instead of its length).
        rank0_print(f"\nTask : {task_obj.config.task}\n - #num : {len(docs)}")

        task_obj.build_all_requests()
        # Normalize once so membership is O(1) even when a list is given.
        ignored = set(ignored_ids or ())
        requests = []
        for instance in task_obj.instances:
            # args layout: (contexts, all_gen_kwargs, doc_to_visual, doc_id,
            # task, split). Only doc_id is needed; unpacking the rest into
            # named locals used to overwrite ``task`` and the returned ``split``.
            _, _, _, doc_id, _, _ = instance.args
            if doc_id in ignored:
                continue
            requests.append(instance)
        return dataset_path, dataset_name, split, requests, task_obj, docs

    @abstractmethod
    def embed_task(self, task: str):
        """Embed a task; must be implemented by subclasses.

        Args:
            task: Task identifier (presumably a name as accepted by
                :meth:`init_task` — confirm against subclasses).
        """
        return
