import os
from typing import List, Optional

import pandas as pd
from autogluon.multimodal import MultiModalPredictor

from .base import BaseTask, TaskMetric


class CamoSemSegTask(BaseTask):
    """Semantic-segmentation task ("camo_sem_seg") scored with an AutoGluon MultiModalPredictor.

    Loads a training split (``train.csv``) and a held-out split (``test_CAMO.csv``)
    from a dataset directory, and scores saved predictor checkpoints on them. The
    ``"sm"`` metric returned by ``MultiModalPredictor.evaluate`` is used as the task
    quality and is also binned into the single BC dimension this task supports.
    """

    # Default location of the dataset root, relative to the current working directory.
    _DEFAULT_DATASET_ROOT = "autogluon/examples/automm/Conv-LoRA/datasets"

    # Metrics requested from MultiModalPredictor.evaluate. Only "sm" feeds the
    # returned TaskMetric; the others are still requested as in the original setup.
    _EVAL_METRICS = ("sm", "fm", "em", "mae")

    def __init__(
        self,
        bc_num_dims: int,
        bc_min_vals: List[float],
        bc_max_vals: List[float],
        bc_grid_sizes: List[int],
        dataset_dir: Optional[str] = None,
    ) -> None:
        """Initialise the task and load both data splits.

        Args:
            bc_num_dims: Forwarded unchanged to ``BaseTask.__init__``.
            bc_min_vals: Forwarded unchanged to ``BaseTask.__init__``.
            bc_max_vals: Forwarded unchanged to ``BaseTask.__init__``.
            bc_grid_sizes: Forwarded unchanged to ``BaseTask.__init__``.
            dataset_dir: Directory containing ``train.csv`` and ``test_CAMO.csv``.
                If None, uses the original relative default
                ``autogluon/examples/automm/Conv-LoRA/datasets/camo_sem_seg/camo_sem_seg``.
        """
        super().__init__(bc_num_dims, bc_min_vals, bc_max_vals, bc_grid_sizes)
        self.task_name = "camo_sem_seg"
        if dataset_dir is None:
            # The default layout nests a directory named after the task inside another
            # one of the same name (presumably how the dataset archive unpacks).
            dataset_dir = os.path.join(self._DEFAULT_DATASET_ROOT, self.task_name, self.task_name)
        # _expand_path is provided by BaseTask; it receives dataset_dir, presumably to
        # resolve relative paths stored in the CSV — TODO confirm against BaseTask.
        self.train_df = self._expand_path(pd.read_csv(os.path.join(dataset_dir, "train.csv")), dataset_dir)
        self.test_df = self._expand_path(pd.read_csv(os.path.join(dataset_dir, "test_CAMO.csv")), dataset_dir)

    def get_q_and_bc(self, ckpt_path: str, data_split: str) -> TaskMetric:
        """Evaluate a saved predictor checkpoint on one data split.

        Args:
            ckpt_path: Path passed to ``MultiModalPredictor.load``.
            data_split: ``"train"`` evaluates on ``train.csv``; ``"validation"``
                evaluates on ``test_CAMO.csv``.

        Returns:
            TaskMetric whose ``quality`` is the ``"sm"`` evaluation result and whose
            ``bc_ids`` is a one-element tuple ``(self._get_bin_id(0, quality),)``.

        Raises:
            ValueError: If ``data_split`` is not ``"train"`` or ``"validation"``, or if
                the task was configured with ``bc_num_dims != 1``.
        """
        # Validate cheap preconditions before paying for checkpoint loading/evaluation.
        if data_split == "train":
            eval_df = self.train_df
        elif data_split == "validation":
            eval_df = self.test_df
        else:
            raise ValueError(f"Invalid data split: {data_split}")
        if self.bc_num_dims != 1:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"{type(self).__name__} supports exactly 1 BC dimension, "
                f"got bc_num_dims={self.bc_num_dims}"
            )

        predictor = MultiModalPredictor.load(ckpt_path)
        res = predictor.evaluate(eval_df, metrics=list(self._EVAL_METRICS))
        q_val = res["sm"]
        bc_ids = (self._get_bin_id(0, q_val),)
        return TaskMetric(quality=q_val, bc_ids=bc_ids)