import copy
import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np

from haipr.data import HAIPRData
from haipr.utils.logger import setup_logger

# Module-level logger for this module.
# NOTE(review): the level is hard-coded to DEBUG here, so the per-fold debug
# messages below are requested unconditionally — confirm this is intended
# outside development (setup_logger's handler configuration is not shown here).
logger = setup_logger(__name__, log_level=logging.DEBUG)


@dataclass()
class CVFold:
    """A single (outer fold, inner split) combination from nested CV.

    When built by ``NestedCVManager``, every index array holds row positions
    in the full, original dataset (inner indices are already mapped back).
    """

    # Training indices of this inner split.
    train_idx: np.ndarray
    # Validation indices of this inner split (drawn from the outer training set).
    val_idx: np.ndarray
    # Test indices of the enclosing outer fold (shared by all its inner splits).
    test_idx: np.ndarray
    # Number of the outer fold this combination belongs to.
    fold_id: int


class NestedCVManager:
    """Manages nested cross-validation splits and related data.

    Outer splits are always produced with the "cv" method on the wrapped
    ``HAIPRData`` object. Each outer training partition is then re-split with
    ``inner_split_method``. Every (outer, inner) combination is stored as a
    ``CVFold`` whose indices are mapped back to row positions of the original
    dataset.
    """

    def __init__(
        self,
        data: HAIPRData,
        inner_split_method: str = "cv",
        n_outer_splits: int = 5,
        n_inner_splits: int = 5,
        mlflow_run_id: Optional[str] = None,
    ):
        """
        Initialize the NestedCVManager and eagerly generate all splits.

        Args:
            data: HAIPRData object containing the dataset. NOTE: its ``cv``
                attribute is overwritten and ``generate_splits("cv")`` is
                called on it to build the outer splits.
            inner_split_method: Method to use for inner splits ('cv', 'lomo', etc.)
            n_outer_splits: Number of outer CV splits
            n_inner_splits: Number of inner splits; applied (via the inner
                dataset's ``cv`` attribute) when ``inner_split_method == "cv"``.
            mlflow_run_id: Optional MLflow run ID to use for logging
        """
        self.data = data
        self.inner_split_method = inner_split_method
        self.n_outer_splits = n_outer_splits
        self.n_inner_splits = n_inner_splits
        self.mlflow_run_id = mlflow_run_id
        self.folds: List[CVFold] = []
        # One (outer_train_idx, outer_test_idx) pair per outer fold, in fold order.
        self._outer_splits: List[Tuple[np.ndarray, np.ndarray]] = []

        self._generate_nested_splits()

    def _generate_nested_splits(self) -> None:
        """Generate nested cross-validation splits with CV outer and specified inner splits.

        Rebuilds ``self.folds`` (one ``CVFold`` per outer/inner pair) and
        ``self._outer_splits`` (one entry per outer fold) from scratch, so
        calling this again does not duplicate folds.
        """
        logger.info(
            f"Generating nested splits with {self.inner_split_method} inner splits"
        )
        self.folds = []
        self._outer_splits = []

        # Generate outer CV splits (the fold count is set through `.cv` before
        # generate_splits, which stores its result in `.splits`).
        self.data.cv = self.n_outer_splits
        self.data.generate_splits("cv")  # Always use CV for outer splits
        outer_splits = self.data.splits

        for outer_fold, (train_idx, test_idx) in enumerate(outer_splits):
            logger.debug(f"Processing outer fold {outer_fold}")

            # Coerce to arrays so `train_idx[inner_idx]` below is positional
            # fancy indexing even if the splitter yields plain lists.
            train_idx = np.asarray(train_idx)
            test_idx = np.asarray(test_idx)

            # Record the true outer split so get_outer_folds() does not have to
            # reconstruct it from (possibly missing) inner folds.
            self._outer_splits.append((train_idx, test_idx))

            inner_folds = self._build_inner_folds(outer_fold, train_idx, test_idx)
            if not inner_folds:
                logger.warning(
                    f"Outer fold {outer_fold} produced no "
                    f"{self.inner_split_method} inner splits"
                )
            self.folds.extend(inner_folds)

    def _build_inner_folds(
        self, outer_fold: int, train_idx: np.ndarray, test_idx: np.ndarray
    ) -> List[CVFold]:
        """Split one outer training partition and map the inner indices back.

        Args:
            outer_fold: Number of the outer fold being processed.
            train_idx: Outer training row positions in the original dataset.
            test_idx: Outer test row positions in the original dataset.

        Returns:
            One ``CVFold`` per inner split, with indices into the original data.
        """
        # Temporary dataset holding only the outer training rows (re-indexed
        # 0..len-1, so inner split indices are positions within train_idx).
        train_data = self.data.data.iloc[train_idx].reset_index(drop=True)
        train_labels = self.data.labels[train_idx]

        # Deep copy: a shallow .copy() shares nested sections, so assigning
        # tmp_config.data.split_method would also modify self.data.config.
        tmp_config = copy.deepcopy(self.data.config)
        tmp_config.data.split_method = self.inner_split_method

        temp_data = HAIPRData(train_data, tmp_config)
        temp_data.labels = train_labels
        if self.inner_split_method == "cv":
            # Mirror the outer setup so n_inner_splits actually takes effect.
            temp_data.cv = self.n_inner_splits

        # Generate inner splits using specified method
        temp_data.generate_splits(self.inner_split_method)

        # Map inner (position-within-train_idx) indices back to original rows.
        return [
            CVFold(
                train_idx=train_idx[inner_train_idx],
                val_idx=train_idx[inner_val_idx],
                test_idx=test_idx,
                fold_id=outer_fold,
            )
            for inner_train_idx, inner_val_idx in temp_data.splits
        ]

    def get_outer_folds(self) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Get unique outer folds.

        Returns:
            One ``(train_idx, test_idx)`` tuple per outer fold, in outer-fold
            order, exactly as produced by the outer CV splitter. Outer folds
            whose inner splitting yielded nothing are still included, so list
            position ``i`` always corresponds to ``fold_id == i``.
        """
        return list(self._outer_splits)

    def get_inner_folds(
        self, outer_fold_id: int
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Get inner folds for a specific outer fold.

        Args:
            outer_fold_id: Number of the outer fold.

        Returns:
            ``(train_idx, val_idx)`` pairs (original-data row positions) for
            every inner split of that outer fold; empty if there are none.
        """
        return [
            (fold.train_idx, fold.val_idx)
            for fold in self.folds
            if fold.fold_id == outer_fold_id
        ]

    def get_all_folds(self) -> List[CVFold]:
        """Get all fold combinations (the internal list, not a copy)."""
        return self.folds
