"""
Question Answer Data Handler

This module provides a data handler for the Question Answer (QA) task: it loads
train and test splits from CSV files and exposes accessors over their rows.
"""

from typing import Optional

import numpy as np
import pandas as pd

from src.utils.decorator_utils import with_logger


class QATaskDataHandler:
    """
    Data handler for the Question Answer task.

    Loads a train split from ``{data_path}/train/{dataset_name}.csv`` and a
    test split from ``{data_path}/test/{dataset_name}.csv`` into pandas
    DataFrames, and exposes positional accessors over them. Rows read through
    ``get_data_by_id`` must provide ``question``, ``answer`` and ``reasoning``
    columns.
    """

    @with_logger
    def __init__(
        self,
        data_path,
        dataset_name,
    ):
        """
        Initialise the Question Answer data handler and load both splits.

        Args:
            data_path: Prefix under which the ``train/`` and ``test/``
                sub-directories live.
            dataset_name: Base file name (without the ``.csv`` suffix) of the
                dataset inside each split directory.
        """
        self.data_path = data_path
        self.dataset_name = dataset_name
        # Both DataFrames are populated by init_dataset() below.
        self.data_df_tr = None
        self.data_df_te = None

        self.init_dataset()

    @with_logger
    def init_dataset(self):
        """
        Load the train and test splits from CSV into DataFrames.

        The paths are built by plain string formatting, and ``data_path`` is
        passed through to ``pd.read_csv`` unchanged.
        """
        path_tr = f"{self.data_path}/train/{self.dataset_name}.csv"
        path_te = f"{self.data_path}/test/{self.dataset_name}.csv"

        # NOTE(review): ``logger`` is not defined in this module; presumably it
        # is provided by the ``with_logger`` decorator — TODO confirm.
        logger.debug(f"Loading training data from {path_tr}")
        self.data_df_tr = pd.read_csv(path_tr)

        logger.debug(f"Loading testing data from {path_te}")
        self.data_df_te = pd.read_csv(path_te)

    def _select_split(
        self,
        flag: str,
        arg_name: str = "train_test_flag",
    ) -> pd.DataFrame:
        """
        Return the DataFrame for the requested split.

        Args:
            flag: "train" or "test".
            arg_name: Parameter name used in the error message, so each public
                accessor reports the argument name its own caller passed.

        Raises:
            ValueError: If ``flag`` is neither "train" nor "test".
        """
        if flag == "train":
            return self.data_df_tr
        if flag == "test":
            return self.data_df_te
        raise ValueError(f"{arg_name} must be 'train' or 'test'")

    @with_logger
    def get_data_by_id(
        self,
        train_test_flag: str,
        data_id: int,
    ):
        """
        Return the question, answer and reasoning stored in a single row.

        Args:
            train_test_flag: "train" or "test".
            data_id: Zero-based positional index of the row within the split.

        Returns:
            A ``(question, answer, reasoning)`` tuple read from the row's
            ``question``, ``answer`` and ``reasoning`` columns.

        Raises:
            ValueError: If ``train_test_flag`` is invalid or ``data_id`` is
                outside ``[0, len(split))``.
        """
        data_df = self._select_split(train_test_flag)

        if data_id < 0 or data_id >= len(data_df):
            raise ValueError("data_id is out of bounds")

        data_row = data_df.iloc[data_id]

        return (data_row["question"], data_row["answer"], data_row["reasoning"])

    @with_logger
    def get_data_by_id_lst(
        self,
        train_test_flag: str,
        data_id_lst: list[int],
    ) -> pd.DataFrame:
        """
        Return the rows at the given positional indices.

        Args:
            train_test_flag: "train" or "test".
            data_id_lst: Zero-based positional row indices within the split.

        Returns:
            A DataFrame containing the selected rows, in the order given.

        Raises:
            ValueError: If ``train_test_flag`` is invalid, ``data_id_lst`` is
                empty, or any id is outside ``[0, len(split))``.
        """
        data_df = self._select_split(train_test_flag)

        # ``len`` rather than truthiness so numpy arrays of ids are accepted
        # (``not array`` is ambiguous for arrays with more than one element).
        if len(data_id_lst) == 0:
            raise ValueError("data_id_lst is empty")

        # Same bounds rule as get_data_by_id: without this check, negative ids
        # would silently wrap around to rows counted from the end via iloc.
        n_rows = len(data_df)
        if any(data_id < 0 or data_id >= n_rows for data_id in data_id_lst):
            raise ValueError("data_id_lst contains an out-of-bounds id")

        return data_df.iloc[data_id_lst]

    def get_data(self, train_test_flag: str) -> pd.DataFrame:
        """
        Get the data for a specific train/test split.

        Args:
            train_test_flag: "train" or "test"

        Returns:
            The DataFrame for the specified split.

        Raises:
            ValueError: If ``train_test_flag`` is neither "train" nor "test".
        """
        return self._select_split(train_test_flag)

    @with_logger
    def get_size(
        self,
        data_type: str,
    ) -> int:
        """
        Get the number of rows in a split.

        Args:
            data_type: "train" or "test".

        Returns:
            The number of data items (rows) in the selected split.

        Raises:
            ValueError: If ``data_type`` is neither "train" nor "test".
        """
        return self._select_split(data_type, arg_name="data_type").shape[0]

    @with_logger
    def get_answer_dtype(self) -> Optional[type]:
        """
        Map the dtype of the train split's ``answer`` column to a Python type.

        ``pandas.api.types`` predicates are used instead of ``np.issubdtype``
        because the latter raises ``TypeError`` on pandas extension dtypes
        (nullable ``Int64``/``Float64``, ``string``); for plain numpy dtypes
        the classification is the same.

        Returns:
            ``int`` for integer dtypes, ``float`` for floating dtypes, ``str``
            for object/string dtypes, and ``None`` for anything else (e.g.
            bool, datetime, categorical).
        """
        column_dtype = self.data_df_tr["answer"].dtype
        ptypes = pd.api.types

        if ptypes.is_integer_dtype(column_dtype):
            return int
        if ptypes.is_float_dtype(column_dtype):
            return float
        if ptypes.is_object_dtype(column_dtype) or ptypes.is_string_dtype(
            column_dtype
        ):
            return str
        return None  # Unsupported or unknown dtype
