import os
import pandas as pd
from typing import List, Optional
from src.dataset_processing.common.config.base_configs import BaseDatasetConfig
from src.dataset_processing.common.enums.field_names import DatasetEntryNames, PerturbationNames
from src.dataset_processing.common.enums.source_types import DatasetSourceType
from src.dataset_processing.common.models.dataset_entry import DatasetEntry
from src.dataset_processing.common.models.dataset_result import DatasetResult
from src.dataset_processing.common.processing.base_processor import DatasetProcessor
from src.dataset_processing.datasets.mmlu.data_processor import MMLUDataProcessor
from src.dataset_processing.datasets.mmlu.file_handler import MMLUFileHandler
from src.dataset_processing.datasets.mmlu.models import ProcessedMMLUEntry
from src.dataset_processing.perturbations.config.perturbation_config import PerturbationConfig
from src.loggers.setup_logging import setup_logging

# Module-level logger obtained from the project's setup_logging() helper.
logger = setup_logging()

def get_data_path(path):
    """Return the value of the environment variable named ``path``.

    Despite the parameter name, ``path`` is the *name* of the variable
    (e.g. ``'DATA_PATH'``), not a filesystem path. Only ``os.environ`` is
    consulted; this function does not load a ``.env`` file itself.

    Raises:
        ValueError: if the variable is unset or set to the empty string.
    """
    # Unset and empty are treated the same: both fall through to the error.
    resolved = os.environ.get(path, "")
    if resolved:
        return resolved
    raise ValueError(f"Required environment variable '{path}' is missing or empty in .env file")

# Read once at import time and used as MMLUProcessor's default base_dir;
# importing this module raises ValueError if DATA_PATH is unset or empty.
data_path = get_data_path(path="DATA_PATH")

class MMLUProcessor(DatasetProcessor):
    """Processor for MMLU datasets.

    Raw MMLU rows are read through ``MMLUFileHandler``, parsed and converted
    by ``MMLUDataProcessor`` (together with an optional perturbation config),
    and persisted as CSV so later runs can reload the processed result.
    """

    # Columns with dedicated fields on DatasetEntry / PerturbationConfig; every
    # other CSV column is treated as per-entry metadata when loading.
    _RESERVED_COLUMNS = frozenset({
        DatasetEntryNames.QUESTION.value,
        DatasetEntryNames.ANSWER.value,
        PerturbationNames.PERTURBATION_TYPE.value,
        PerturbationNames.PERTURBATION_INTENSITY.value,
    })

    def __init__(
        self,
        base_dir: str = data_path
    ):
        """Initialize processor with base directory.

        Args:
            base_dir: Root data directory handed to ``MMLUFileHandler``.
                Defaults to the ``DATA_PATH`` environment variable, which is
                read once when this module is imported.
        """
        super().__init__()
        self.base_dir = base_dir
        self.file_handler = MMLUFileHandler(base_dir)
        self.data_processor = MMLUDataProcessor()

    def cache_dataset(self, result: DatasetResult, config: BaseDatasetConfig) -> None:
        """Write ``result`` as CSV to the cache path derived from ``config``.

        The cache directory is created when missing, so the first write for a
        given source type does not fail.
        """
        cache_dir = self.file_handler.get_dataset_dir(config.source_type)
        cache_path = self.file_handler.get_cache_path(cache_dir, config)

        # Same as _save_to_processed: DataFrame.to_csv does not create parent
        # directories.
        os.makedirs(cache_dir, exist_ok=True)

        df = self._convert_result_to_dataframe(result)
        df.to_csv(cache_path, index=False)
        logger.info(f"Cached dataset to {cache_path}")

    def load_from_cache(self, config: BaseDatasetConfig) -> Optional[DatasetResult]:
        """Load a cached dataset for ``config`` if one exists.

        Returns:
            The cached result (truncated to ``config.num_entries`` when that is
            set), or ``None`` when no cache file exists or loading it fails.
        """
        cache_dir = self.file_handler.get_dataset_dir(config.source_type)
        cache_path = self.file_handler.get_cache_path(cache_dir, config)

        if not os.path.exists(cache_path):
            return None

        # Best-effort: any failure is logged and reported as a cache miss.
        try:
            df = pd.read_csv(cache_path)
            if config.num_entries:
                df = df.head(config.num_entries)
            # NOTE(review): _create_result_from_dataframe is not defined in this
            # class; presumably DatasetProcessor provides it. If it does not,
            # the AttributeError is swallowed below and every cache lookup
            # silently misses — confirm.
            return self._create_result_from_dataframe(df)
        except Exception as e:
            logger.error(f"Error loading cached dataset: {str(e)}")
            return None

    def process_dataset(self, config: BaseDatasetConfig) -> DatasetResult:
        """Process MMLU dataset according to configuration.

        Non-PROCESSED (raw) sources are always built from the raw files and
        are not saved. PROCESSED sources are loaded from the processed
        directory, unless ``config.force_reprocess`` is set or no usable
        processed file exists; in those cases the raw data is processed and
        the result is saved to the processed directory before returning.
        """
        if config.source_type != DatasetSourceType.PROCESSED:
            return self._process_raw_dataset(config)

        if not config.force_reprocess:
            try:
                return self._load_processed_dataset(config)
            except FileNotFoundError:
                pass  # Missing or empty processed file: rebuild it below.

        result = self._process_raw_dataset(config)
        self._save_to_processed(result, config)
        return result

    def _process_raw_dataset(self, config: BaseDatasetConfig) -> DatasetResult:
        """Parse raw MMLU rows for ``config`` into a DatasetResult.

        At most ``config.num_entries`` rows are processed when it is set
        (truthy). A perturbation config is passed to the data processor only
        for PROCESSED sources; otherwise ``None`` is passed.
        """
        dataset = self.file_handler.load_dataset(config.split, config.subject)

        # Identical for every row, so build it once outside the loop.
        perturbation_config = None
        if config.source_type == DatasetSourceType.PROCESSED:
            perturbation_config = PerturbationConfig(
                type=config.perturbation_type,
                intensity=config.perturbation_intensity
            )

        processed_entries = []
        # Limit by row position, not index label: the label equals the
        # position only for a default 0..n-1 index, and a filtered frame may
        # carry non-contiguous labels.
        for position, (idx, raw_entry) in enumerate(dataset.iterrows()):
            if config.num_entries and position >= config.num_entries:
                break

            entry = self.data_processor.parse_raw_entry(raw_entry, config.subject, idx)
            processed_entry = self.data_processor.create_processed_entry(
                entry,
                perturbation_config
            )
            processed_entries.append(self._convert_to_dataset_entry(processed_entry))

        return self._create_dataset_result(processed_entries, config)

    def _convert_to_dataset_entry(self, processed_entry: ProcessedMMLUEntry) -> DatasetEntry:
        """Convert processed entry to dataset entry (question, answer, metadata)."""
        return DatasetEntry(
            question=processed_entry.question,
            answer=processed_entry.answer,
            metadata=processed_entry.metadata
        )

    def _convert_result_to_dataframe(self, result: DatasetResult) -> pd.DataFrame:
        """Flatten ``result`` into one DataFrame row per entry.

        Columns are question, answer, every metadata key, and, when the
        result carries a perturbation config, its type and intensity.
        """
        # Only the first perturbation config is written, and it is applied to
        # every row.
        perturbation_columns = {}
        if result.perturbation_config:
            first_config = result.perturbation_config[0]
            perturbation_columns = {
                PerturbationNames.PERTURBATION_TYPE.value: first_config.type,
                PerturbationNames.PERTURBATION_INTENSITY.value: first_config.intensity,
            }

        rows = [
            {
                DatasetEntryNames.QUESTION.value: entry.question,
                DatasetEntryNames.ANSWER.value: entry.answer,
                **entry.metadata,
                **perturbation_columns,
            }
            for entry in result.entries
        ]
        return pd.DataFrame(rows)

    def _create_dataset_result(
        self,
        entries: List[DatasetEntry],
        config: BaseDatasetConfig
    ) -> DatasetResult:
        """Create dataset result with perturbation info.

        A single-element perturbation config list is attached only for
        PROCESSED sources whose perturbation type is not ``"none"``;
        otherwise ``perturbation_config`` is ``None``.
        """
        perturbation_config = None
        if (
            config.source_type == DatasetSourceType.PROCESSED
            and config.perturbation_type != "none"
        ):
            perturbation_config = [
                PerturbationConfig(
                    type=config.perturbation_type,
                    intensity=config.perturbation_intensity
                )
            ]
        return DatasetResult(
            entries=entries,
            perturbation_config=perturbation_config,
            config=config
        )

    def _load_processed_dataset(self, config: BaseDatasetConfig) -> DatasetResult:
        """Load an already processed dataset from the processed directory.

        Raises:
            FileNotFoundError: if the processed CSV does not exist, or (via
                ``_load_from_csv``) if it is empty.
        """
        processed_dir = self.file_handler.get_dataset_dir(DatasetSourceType.PROCESSED)
        processed_path = self.file_handler.get_cache_path(processed_dir, config)

        if not os.path.exists(processed_path):
            raise FileNotFoundError(f"Processed dataset not found: {processed_path}")

        return self._load_from_csv(processed_path, config)

    def _save_to_processed(self, result: DatasetResult, config: BaseDatasetConfig) -> None:
        """Save processed dataset as CSV to the processed directory, creating it if needed."""
        processed_dir = self.file_handler.get_dataset_dir(DatasetSourceType.PROCESSED)
        processed_path = self.file_handler.get_cache_path(processed_dir, config)

        os.makedirs(processed_dir, exist_ok=True)

        df = self._convert_result_to_dataframe(result)
        df.to_csv(processed_path, index=False)
        logger.info(f"Saved processed dataset to {processed_path}")

    def _load_from_csv(self, file_path: str, config: BaseDatasetConfig) -> DatasetResult:
        """Load a dataset CSV into a DatasetResult.

        Rows are truncated to ``config.num_entries`` when that is set. When
        both perturbation columns are present, one PerturbationConfig is
        collected per row; otherwise ``perturbation_config`` is ``None``.

        Raises:
            FileNotFoundError: if the file is empty, so callers such as
                ``process_dataset`` treat it like a missing file.
        """
        try:
            df = pd.read_csv(file_path)
        except pd.errors.EmptyDataError as err:
            raise FileNotFoundError(f"Empty dataset file: {file_path}") from err

        if config.num_entries:
            df = df.head(config.num_entries)

        type_column = PerturbationNames.PERTURBATION_TYPE.value
        intensity_column = PerturbationNames.PERTURBATION_INTENSITY.value
        # Every row's labels are the frame's columns, so one check suffices.
        has_perturbation = type_column in df.columns and intensity_column in df.columns

        entries = []
        perturbation_config = []
        for _, row in df.iterrows():
            entries.append(self._create_dataset_entry_from_row(row))
            if has_perturbation:
                perturbation_config.append(
                    PerturbationConfig(
                        type=row[type_column],
                        intensity=row[intensity_column]
                    )
                )

        return DatasetResult(
            entries=entries,
            perturbation_config=perturbation_config if perturbation_config else None,
            config=config
        )

    def _create_dataset_entry_from_row(self, row: pd.Series) -> DatasetEntry:
        """Create dataset entry from DataFrame row.

        Every column other than question, answer and the perturbation columns
        goes into ``metadata``, with values as parsed by pandas (e.g. NaN for
        empty cells).
        """
        metadata = {
            key: value
            for key, value in row.items()
            if key not in self._RESERVED_COLUMNS
        }

        return DatasetEntry(
            question=row[DatasetEntryNames.QUESTION.value],
            answer=row[DatasetEntryNames.ANSWER.value],
            metadata=metadata
        )
