from dataclasses import dataclass
from typing import Optional

from src.dataset_processing.common.config.base_configs import BaseProcessedDatasetConfig, BaseRawDatasetConfig
from src.dataset_processing.common.enums.dataset_types import DatasetType

####################################################################################################
# TriviaQA
####################################################################################################
        
@dataclass
class TriviaQARawConfig(BaseRawDatasetConfig):
    """Raw-stage settings for the TriviaQA dataset.

    Extends ``BaseRawDatasetConfig`` with TriviaQA-specific fields and pins
    the dataset identity (type and name) after construction.
    """

    # Name of the split to read; defaults to "validation".
    split: str = "validation"
    # Version identifier for TriviaQA; defaults to "rc.wikipedia".
    version: str = "rc.wikipedia"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the TriviaQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.TRIVIAQA, "triviaqa"
        
@dataclass
class TriviaQAProcessedConfig(BaseProcessedDatasetConfig):
    """Processed-stage settings for the TriviaQA dataset.

    Extends ``BaseProcessedDatasetConfig`` with TriviaQA-specific fields and
    pins the dataset identity (type and name) after construction.
    """

    # Name of the split to use; defaults to "validation".
    split: str = "validation"
    # Version identifier for TriviaQA; defaults to "rc.wikipedia".
    version: str = "rc.wikipedia"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the TriviaQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.TRIVIAQA, "triviaqa"

####################################################################################################
# CommonSenseQA
####################################################################################################

@dataclass
class CommonSenseQARawConfig(BaseRawDatasetConfig):
    """Raw-stage settings for the CommonSenseQA dataset.

    Extends ``BaseRawDatasetConfig`` with a split selector and pins the
    dataset identity (type and name) after construction.
    """

    # Name of the split to read; defaults to "train".
    split: str = "train"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the CommonSenseQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.COMMONSENSEQA, "commonsenseqa"

@dataclass
class CommonSenseQAProcessedConfig(BaseProcessedDatasetConfig):
    """Processed-stage settings for the CommonSenseQA dataset.

    Extends ``BaseProcessedDatasetConfig`` with a split selector and pins the
    dataset identity (type and name) after construction.
    """

    # Name of the split to use; defaults to "train".
    split: str = "train"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the CommonSenseQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.COMMONSENSEQA, "commonsenseqa"

####################################################################################################
# CoQA
####################################################################################################

@dataclass
class CoQARawConfig(BaseRawDatasetConfig):
    """Raw-stage settings for the CoQA dataset.

    Extends ``BaseRawDatasetConfig`` with CoQA-specific fields and pins the
    dataset identity (type and name) after construction.
    """

    # Name of the split to read; defaults to "dev".
    split: str = "dev"
    # Optional per-conversation question count; defaults to None.
    # NOTE(review): presumably None means "no limit" -- confirm in the loader.
    questions_per_conversation: Optional[int] = None

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the CoQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.COQA, "coqa"

@dataclass
class CoQAProcessedConfig(BaseProcessedDatasetConfig):
    """Processed-stage settings for the CoQA dataset.

    Extends ``BaseProcessedDatasetConfig`` with CoQA-specific fields and pins
    the dataset identity (type and name) after construction.
    """

    # Name of the split to use; defaults to "dev".
    split: str = "dev"
    # Optional per-conversation question count; defaults to None.
    # NOTE(review): presumably None means "no limit" -- confirm in the processor.
    questions_per_conversation: Optional[int] = None

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the CoQA identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.COQA, "coqa"
     
####################################################################################################
# MMLU
####################################################################################################

@dataclass
class MMLURawConfig(BaseRawDatasetConfig):
    """Raw-stage settings for the MMLU dataset.

    Extends ``BaseRawDatasetConfig`` with MMLU-specific fields and pins the
    dataset identity (type and name) after construction.
    """

    # Name of the split to read; defaults to "test".
    split: str = "test"
    # Subject identifier; defaults to "abstract_algebra".
    subject: str = "abstract_algebra"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the MMLU identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.MMLU, "mmlu"

@dataclass
class MMLUProcessedConfig(BaseProcessedDatasetConfig):
    """Processed-stage settings for the MMLU dataset.

    Extends ``BaseProcessedDatasetConfig`` with MMLU-specific fields and pins
    the dataset identity (type and name) after construction.
    """

    # Name of the split to use; defaults to "test".
    split: str = "test"
    # Subject identifier; defaults to "abstract_algebra".
    subject: str = "abstract_algebra"

    def __post_init__(self) -> None:
        """Overwrite ``dataset_type``/``dataset_name`` with the MMLU identity.

        Runs after the generated ``__init__``, so any values a caller passed
        for these two attributes are replaced unconditionally.
        NOTE(review): this does not call ``super().__post_init__()``; confirm
        the base config defines no ``__post_init__`` that needs to run.
        """
        self.dataset_type, self.dataset_name = DatasetType.MMLU, "mmlu"
