import os
import yaml
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pandas as pd


class DatasetConfig:
    """Resolved settings for one dataset entry of the datasets YAML config.

    Input paths (features, labels) are resolved relative to ``base_path``;
    output paths live under ``outputs/<name>`` (relative to the current
    working directory). Constructing an instance creates the models/results/
    logs output directories if they do not already exist.
    """

    def __init__(self, name: str, config_dict: Dict):
        """Build the config for dataset ``name`` from its raw mapping.

        Args:
            name: Dataset key; also used as the output sub-directory name.
            config_dict: Mapping with required keys ``base_path``,
                ``labels`` (with ``file``, ``column``, ``wsi_name_column``)
                and ``classes`` (with ``precise``). All other keys are
                optional and fall back to defaults.

        Raises:
            KeyError: if a required key is missing.
        """
        self.name = name
        self.display_name = config_dict.get('display_name', name)
        self.description = config_dict.get('description', '')

        self.base_path = Path(config_dict['base_path'])
        self.raw_features_path = self.base_path / 'raw_features'
        self.labels_path = self.base_path / 'labels'

        labels_cfg = config_dict['labels']
        self.label_file = labels_cfg['file']
        self.label_column = labels_cfg['column']
        self.wsi_name_column = labels_cfg['wsi_name_column']

        classes_cfg = config_dict['classes']
        self.precise_classes = classes_cfg['precise']
        # A YAML key written with no value (e.g. ``uncertain:``) loads as None,
        # which ``.get(key, default)`` would pass through unchanged; ``or``
        # maps it to the same empty default a missing key gets.
        self.uncertain_classes = classes_cfg.get('uncertain') or []
        self.class_mapping = classes_cfg.get('mapping') or {}

        self.feature_dim = config_dict.get('feature_dim', 512)
        self.feature_format = config_dict.get('feature_format', 'npy')

        self.data_split_config = config_dict.get('data_split') or {}

        self.output_base = Path('outputs') / self.name
        self.models_path = self.output_base / 'models'
        self.results_path = self.output_base / 'results'
        self.logs_path = self.output_base / 'logs'
        self._create_directories()

    def _create_directories(self):
        """Create the models/results/logs output directories.

        Missing parents are created and already-existing directories are
        left untouched (``exist_ok=True``).
        """
        for path in (self.models_path, self.results_path, self.logs_path):
            path.mkdir(parents=True, exist_ok=True)

    def get_label_file_path(self) -> Path:
        """Return ``<base_path>/labels/<label_file>``."""
        return self.labels_path / self.label_file

    def get_feature_file_path(self, wsi_name: str) -> Path:
        """Return ``<base_path>/raw_features/<wsi_name>.<feature_format>``."""
        return self.raw_features_path / f"{wsi_name}.{self.feature_format}"

    def get_model_save_path(self, feature_type: str, n_clusters: int,
                            n_samples: int, sampling_method: str) -> Path:
        """Return the model directory path encoding the run parameters.

        The name is ``<feature_type>_C<n_clusters>_S<n_samples>_<method>``;
        the ``S<n_samples>`` part is omitted when ``sampling_method`` is
        ``'none'``, since the sample count is unused in that case.
        """
        if sampling_method == 'none':
            param_str = f"{feature_type}_C{n_clusters}_{sampling_method}"
        else:
            param_str = f"{feature_type}_C{n_clusters}_S{n_samples}_{sampling_method}"

        return self.models_path / param_str

    def get_results_file_path(self, filename: str) -> Path:
        """Return ``outputs/<name>/results/<filename>``."""
        return self.results_path / filename


class DatasetManager:
    """Registry of DatasetConfig objects backed by a YAML file.

    On construction the YAML file is read (after first being created with a
    built-in default entry if it does not exist) and every entry under the
    top-level ``datasets`` key becomes a DatasetConfig. One dataset can be
    marked as current via set_dataset().
    """

    def __init__(self, config_file: str = 'configs/datasets.yaml'):
        """Load all dataset configs from ``config_file``.

        Args:
            config_file: Path to the datasets YAML file; created with a
                default entry if missing.
        """
        self.config_file = Path(config_file)
        self.datasets: Dict[str, DatasetConfig] = {}
        self.current_dataset: Optional[str] = None

        self._load_configs()

    def _read_config_file(self) -> Dict:
        """Parse the YAML config and guarantee a ``datasets`` mapping.

        An empty file makes ``yaml.safe_load`` return None, and a bare
        ``datasets:`` key loads as None; both are normalized to an empty
        ``datasets`` dict instead of crashing on subscript.

        Raises:
            FileNotFoundError: if the config file does not exist.
        """
        with open(self.config_file, 'r', encoding='utf-8') as f:
            configs = yaml.safe_load(f) or {}
        if configs.get('datasets') is None:
            configs['datasets'] = {}
        return configs

    def _write_config_file(self, configs: Dict) -> None:
        """Serialize ``configs`` and write it to the config file.

        ``yaml.safe_dump`` is used because the file is read back with
        ``yaml.safe_load``: values the safe loader cannot reconstruct (e.g.
        ``pathlib.Path``) are rejected here rather than written as
        python-specific tags. Serializing to a string before opening the
        file means such a failure cannot leave a truncated file behind.
        """
        text = yaml.safe_dump(configs, default_flow_style=False, allow_unicode=True)
        self.config_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.config_file, 'w', encoding='utf-8') as f:
            f.write(text)

    def _load_configs(self):
        """Populate ``self.datasets`` from the config file, creating it if absent."""
        if not self.config_file.exists():
            self._create_default_config()

        configs = self._read_config_file()
        for name, config in configs['datasets'].items():
            self.datasets[name] = DatasetConfig(name, config)

    def _create_default_config(self):
        """Write a config file containing a single ``hupo_cancer`` dataset entry."""
        default_config = {
            'datasets': {
                'hupo_cancer': {
                    'display_name': 'Periampullary Cancer',
                    'description': 'Periampullary cancer WSI classification',
                    'base_path': 'datasets/hupo_cancer',
                    'labels': {
                        'file': 'periampullary_cancer.xlsx',
                        'column': 'label',
                        'wsi_name_column': 'wsi_name'
                    },
                    'classes': {
                        'precise': ['bile_duct_cancer', 'duodenal_cancer', 'pancreatic_head_cancer'],
                        'uncertain': ['bile_duct_cancer?', 'duodenal_cancer?', 'ampullary_cancer'],
                        'mapping': {}
                    },
                    'feature_dim': 512,
                    'feature_format': 'npy'
                }
            }
        }

        self._write_config_file(default_config)

    def list_datasets(self) -> List[str]:
        """Return the names of all registered datasets, in registration order."""
        return list(self.datasets.keys())

    def set_dataset(self, dataset_name: str):
        """Mark ``dataset_name`` as the current dataset.

        Raises:
            ValueError: if ``dataset_name`` is not registered.
        """
        if dataset_name not in self.datasets:
            raise ValueError(f"Dataset '{dataset_name}' not found. Available: {self.list_datasets()}")

        self.current_dataset = dataset_name
        print(f"Active dataset: {self.datasets[dataset_name].display_name}")

    def get_current_config(self) -> DatasetConfig:
        """Return the config of the current dataset.

        Raises:
            ValueError: if no dataset has been selected with set_dataset().
        """
        if self.current_dataset is None:
            raise ValueError("No dataset selected. Use set_dataset() first.")

        return self.datasets[self.current_dataset]

    def get_config(self, dataset_name: str) -> DatasetConfig:
        """Return the config registered under ``dataset_name``.

        Raises:
            ValueError: if ``dataset_name`` is not registered.
        """
        if dataset_name not in self.datasets:
            # Same message shape as set_dataset(), so callers see the valid names.
            raise ValueError(f"Dataset '{dataset_name}' not found. Available: {self.list_datasets()}")

        return self.datasets[dataset_name]

    def add_dataset(self, name: str, config_dict: Dict):
        """Register a dataset and persist its entry to the config file.

        The entry is validated by building its DatasetConfig and is written
        to disk before being registered in memory, so a failure while
        reading or writing the file leaves this manager's registry unchanged.
        An existing entry with the same name is overwritten.

        Args:
            name: Dataset key.
            config_dict: Raw config mapping; values must be plain YAML types
                (str, int, float, bool, None, list, dict) — convert e.g.
                ``pathlib.Path`` values to str first.

        Raises:
            KeyError: if ``config_dict`` lacks a required key.
            FileNotFoundError: if the config file no longer exists.
            yaml.YAMLError: if ``config_dict`` holds values that cannot be
                written as plain YAML.
        """
        dataset_config = DatasetConfig(name, config_dict)

        configs = self._read_config_file()
        configs['datasets'][name] = config_dict
        self._write_config_file(configs)

        self.datasets[name] = dataset_config
        print(f"Dataset '{name}' added successfully")


# Process-wide manager shared by the module-level helper functions.
# It is instantiated at import time, so importing this module reads
# configs/datasets.yaml relative to the current working directory (creating
# it with a default entry if absent) and creates each dataset's output
# directories.
dataset_manager: DatasetManager = DatasetManager()


def get_dataset_config(dataset_name: Optional[str] = None) -> DatasetConfig:
    """Return the config for ``dataset_name``, or the active dataset's config.

    Args:
        dataset_name: Name of a registered dataset. When None (or empty),
            the dataset selected via set_active_dataset() is used instead.

    Raises:
        ValueError: if the named dataset is not registered, or if no name is
            given and no dataset has been made active.
    """
    # Truthiness (not ``is not None``) is kept on purpose: an empty string
    # falls back to the active dataset, as before.
    if dataset_name:
        return dataset_manager.get_config(dataset_name)
    return dataset_manager.get_current_config()


def set_active_dataset(dataset_name: str):
    """Select ``dataset_name`` as the active dataset on the shared manager.

    Raises:
        ValueError: if ``dataset_name`` is not a registered dataset.
    """
    manager = dataset_manager
    manager.set_dataset(dataset_name)


def list_available_datasets() -> List[str]:
    """Return the names of every dataset registered on the shared manager."""
    names = dataset_manager.list_datasets()
    return names
