import pandas as pd
import numpy as np
import os
import re
from sklearn.cluster import MiniBatchKMeans
import pickle
from typing import Dict, List, Tuple, Optional
from .dataset_manager import get_dataset_config, DatasetConfig
from .data_splitter import create_data_splitter


class WSIDataProcessor:
    """Load WSI labels and patch features and assemble train/val/test datasets.

    The processor is driven by a dataset configuration object obtained from
    ``get_dataset_config``. Labels are read from a CSV/Excel sheet, features
    are read from one array file per WSI, and ``create_datasets`` joins the two.
    """

    def __init__(self, dataset_name: str = None):
        """Resolve the dataset configuration and the label vocabulary.

        Args:
            dataset_name: Name passed to ``get_dataset_config``. When falsy,
                ``get_dataset_config()`` is called without arguments.
        """
        if dataset_name:
            self.config = get_dataset_config(dataset_name)
        else:
            self.config = get_dataset_config()

        self.precise_labels = self.config.precise_classes
        self.uncertain_labels = self.config.uncertain_classes

        if self.config.name == 'hupo_cancer':
            # The hupo_cancer sheet uses its own raw label strings; map them
            # onto the class names used by the rest of the pipeline.
            self.general_labels = ['ampullary_cancer']
            self.label_mapping = {
                'bile_duct_cancer_cn': 'bile_duct_cancer',
                'duodenal_cancer_cn': 'duodenal_cancer',
                'pancreatic_head_cancer_cn': 'pancreatic_head_cancer',
                'bile_duct_cancer_uncertain_cn': 'bile_duct_cancer?',
                'duodenal_cancer_uncertain_cn': 'duodenal_cancer?',
                'ampullary_cancer_cn': 'ampullary_cancer'
            }
        else:
            self.general_labels = []
            self.label_mapping = self.config.class_mapping

    def _parse_feature_type(self, feature_type: str) -> Tuple[str, int]:
        """Validate ``feature_type`` and return it with the configured feature dim.

        Raises:
            ValueError: If ``feature_type`` is anything other than ``'original'``.
        """
        if feature_type == 'original':
            return 'original', self.config.feature_dim
        else:
            raise ValueError(f"Only 'original' feature type is supported, got: {feature_type}")

    def _extract_first_number(self, scan_block: str) -> str:
        """Return the first run of digits in ``scan_block``.

        The value is coerced with ``str()`` first, so non-string cell values
        are accepted. An empty string is returned when no digit is present, so
        callers can test the result for truthiness.
        """
        match = re.search(r'\d+', str(scan_block))
        return match.group() if match else ""

    def _normalize_label(self, chinese_label: str) -> str:
        """Map a raw label through ``self.label_mapping``; unknown labels pass through."""
        return self.label_mapping.get(chinese_label, chinese_label)

    def load_labels_from_excel(self, excel_path: str = None) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
        """Read the label sheet and return ``(train, val, test)`` label dicts.

        Rows whose label is in ``self.precise_labels`` are split into
        train/val/test by the splitter from ``create_data_splitter``. Rows whose
        label is in ``self.uncertain_labels`` go straight into the test dict.
        Rows with any other label are skipped and counted in the printed summary.

        Args:
            excel_path: Path of the label file. A path ending in ``.csv`` is
                read with ``pd.read_csv``, anything else with ``pd.read_excel``.
                Defaults to ``self.config.get_label_file_path()``.

        Returns:
            Three dicts mapping WSI name to label string: train, val and test
            (the precise test split merged with all uncertain-label WSIs).

        Raises:
            ValueError: If a configured column is missing from the file.
        """
        if excel_path is None:
            excel_path = self.config.get_label_file_path()

        print(f"Loading labels from {excel_path}...")

        if str(excel_path).endswith('.csv'):
            df = pd.read_csv(excel_path)
        else:
            df = pd.read_excel(excel_path)

        # Use generic label loading based on dataset configuration
        wsi_name_col = self.config.wsi_name_column
        label_col = self.config.label_column

        # Check if required columns exist
        if wsi_name_col not in df.columns:
            raise ValueError(f"WSI name column '{wsi_name_col}' not found in file")
        if label_col not in df.columns:
            raise ValueError(f"Label column '{label_col}' not found in file")

        # A configuration without one of the class lists is treated as an
        # empty list rather than crashing on `label in None`.
        precise_classes = () if self.precise_labels is None else self.precise_labels
        uncertain_classes = () if self.uncertain_labels is None else self.uncertain_labels
        is_hupo = self.config.name == 'hupo_cancer'

        # Collect precise-label and uncertain-label WSIs separately.
        precise_wsi_labels = {}  # split into train/val/test below
        uncertain_wsi_labels = {}  # used directly as test data
        skipped_count = 0

        for _, row in df.iterrows():
            # Get WSI name and label from configured columns
            wsi_name = str(row[wsi_name_col])
            label = str(row[label_col])

            if is_hupo:
                # Legacy hupo_cancer format: the WSI name is built from the
                # pathology id plus the first number of the scan-block cell.
                pathology_id = str(row['病理号'])
                scan_block = str(row['扫描蜡块'])
                first_number = self._extract_first_number(scan_block)
                if not first_number:
                    print(f"Warning: No number found in scan block '{scan_block}' for {pathology_id}")
                    continue
                wsi_name = f"{pathology_id}-{first_number}"
                label = self._normalize_label(label)

            # Route the row by label type.
            if label in precise_classes:
                precise_wsi_labels[wsi_name] = label
            elif label in uncertain_classes:
                uncertain_wsi_labels[wsi_name] = label
            else:
                # NOTE(review): the previous code had an unreachable branch
                # that appeared intended to keep every label as precise for
                # datasets without uncertain classes. The effective behaviour
                # (dropping the row) is preserved here; confirm the intent.
                skipped_count += 1

        print(f"Loaded WSI labels:")
        print(f"  Precise labels: {len(precise_wsi_labels)} WSIs")
        print(f"  Uncertain labels: {len(uncertain_wsi_labels)} WSIs")
        if skipped_count:
            print(f"  Skipped: {skipped_count} rows with labels outside the configured classes")

        # Split the precise-label WSIs into train/val/test.
        if len(precise_wsi_labels) > 0:
            data_splitter = create_data_splitter(self.config)
            train_labels, val_labels, precise_test_labels = data_splitter.split_data(df, precise_wsi_labels)
        else:
            train_labels, val_labels, precise_test_labels = {}, {}, {}

        # Final test set: precise test split plus every uncertain-label WSI.
        final_test_labels = {}
        final_test_labels.update(precise_test_labels)
        final_test_labels.update(uncertain_wsi_labels)

        print(f"Final label statistics:")
        print(f"  Train: {len(train_labels)} WSIs (precise only)")
        print(f"  Val: {len(val_labels)} WSIs (precise only)")
        print(f"  Test: {len(final_test_labels)} WSIs ({len(precise_test_labels)} precise + {len(uncertain_wsi_labels)} uncertain)")
        print(f"  Precise label types: {set(precise_wsi_labels.values()) if precise_wsi_labels else set()}")
        print(f"  Uncertain label types: {set(uncertain_wsi_labels.values()) if uncertain_wsi_labels else set()}")

        return train_labels, val_labels, final_test_labels

    def load_features(self, features_dir: str = None, feature_type: str = 'original') -> Dict[str, np.ndarray]:
        """Load every feature file in a directory into a ``{wsi_name: array}`` dict.

        Files are selected by the ``.{self.config.feature_format}`` suffix and
        read with ``np.load``; the WSI name is the file name without that
        suffix. Arrays are cast to ``float32``. Files that fail to load are
        reported and skipped.

        Args:
            features_dir: Directory to scan. Defaults to
                ``self.config.raw_features_path``.
            feature_type: Only ``'original'`` is accepted.

        Raises:
            ValueError: If ``feature_type`` is unsupported or the directory
                does not exist.
        """
        print(f"Loading {feature_type} features...")

        if feature_type != 'original':
            raise ValueError(f"Only 'original' feature type is supported, got: {feature_type}")

        features_dict = {}
        feature_dir = features_dir or str(self.config.raw_features_path)
        suffix = f'.{self.config.feature_format}'

        if not os.path.exists(feature_dir):
            raise ValueError(f"Feature directory not found: {feature_dir}")

        loaded_count = 0
        # Sorted so the resulting dict order does not depend on the filesystem.
        for file_name in sorted(os.listdir(feature_dir)):
            if not file_name.endswith(suffix):
                continue
            # Strip exactly the configured suffix, whatever its length.
            wsi_name = file_name[:-len(suffix)]

            try:
                features = np.load(os.path.join(feature_dir, file_name))
                features_dict[wsi_name] = features.astype(np.float32)
                loaded_count += 1
            except Exception as e:
                # Best-effort loading: one unreadable file must not abort the run.
                print(f"Warning: Failed to load {file_name}: {e}")

        print(f"Loaded {loaded_count} feature files")
        if loaded_count > 0:
            patch_counts = [features.shape[0] for features in features_dict.values()]
            feature_dims = sorted({features.shape[1] for features in features_dict.values()})

            # Only claim consistency when it has actually been verified.
            if len(feature_dims) == 1:
                print(f"Feature dimension: {feature_dims[0]}D (consistent across all WSIs)")
            else:
                print(f"Warning: inconsistent feature dimensions across WSIs: {feature_dims}")
            print(f"Patch count statistics:")
            print(f"  Min patches per WSI: {min(patch_counts)}")
            print(f"  Max patches per WSI: {max(patch_counts)}")
            print(f"  Mean patches per WSI: {sum(patch_counts)/len(patch_counts):.1f}")
            print(f"  Total patches: {sum(patch_counts)}")

        return features_dict

    def create_datasets(self, features_dict: Dict[str, np.ndarray],
                       train_labels: Dict[str, str],
                       val_labels: Dict[str, str],
                       test_labels: Dict[str, str],
                       train_val_split: float = 0.8) -> Tuple[Dict, Dict, Dict]:
        """Join features with labels and build the train/val/test dictionaries.

        WSIs without an entry in ``features_dict`` are left out of every split.

        If ``val_labels`` contains WSIs (with features) that are not in
        ``train_labels``, the caller's split is used as given and
        ``train_val_split`` is ignored. Otherwise (for example when the same
        dict is passed for both, or ``val_labels`` is empty) the WSIs of
        ``train_labels`` are shuffled with ``np.random.shuffle`` and divided at
        ``train_val_split``.

        Returns:
            ``(train_data, val_data, test_data)``. Train and val hold
            ``features``, integer ``labels`` and ``label_names``; test keeps
            the original string labels and also carries ``wsi_names``.
        """
        val_labels = val_labels or {}

        train_candidates = [wsi for wsi in train_labels if wsi in features_dict]
        presplit_val = [
            wsi for wsi in val_labels
            if wsi in features_dict and wsi not in train_labels
        ]

        if presplit_val:
            # The caller already provides a disjoint validation split.
            train_wsis = train_candidates
            val_wsis = presplit_val
            val_source = val_labels
        else:
            # No separate validation split: carve one out of the train WSIs.
            np.random.shuffle(train_candidates)
            split_idx = int(len(train_candidates) * train_val_split)
            train_wsis = train_candidates[:split_idx]
            val_wsis = train_candidates[split_idx:]
            # These WSIs came from train_labels, so that is where their
            # labels must be looked up.
            val_source = train_labels

        # Label -> index mapping; include validation classes so a class that
        # only occurs in the validation split cannot raise KeyError.
        unique_precise_labels = sorted(
            set(train_labels.values()) | {val_source[wsi] for wsi in val_wsis}
        )
        label_to_idx = {label: idx for idx, label in enumerate(unique_precise_labels)}

        train_data = {
            'features': {wsi: features_dict[wsi] for wsi in train_wsis},
            'labels': {wsi: label_to_idx[train_labels[wsi]] for wsi in train_wsis},
            'label_names': unique_precise_labels
        }

        val_data = {
            'features': {wsi: features_dict[wsi] for wsi in val_wsis},
            'labels': {wsi: label_to_idx[val_source[wsi]] for wsi in val_wsis},
            'label_names': unique_precise_labels
        }

        # Test data: every WSI of test_labels that has features.
        test_wsis = [wsi for wsi in test_labels.keys() if wsi in features_dict]
        test_data = {
            'features': {wsi: features_dict[wsi] for wsi in test_wsis},
            'labels': {wsi: test_labels[wsi] for wsi in test_wsis},  # Keep original string labels
            'label_names': sorted(list(set(test_labels.values()))),
            'wsi_names': test_wsis  # WSI names in the same order as the dicts above
        }

        print(f"Dataset split:")
        print(f"  Train: {len(train_data['features'])} WSIs")
        print(f"  Val: {len(val_data['features'])} WSIs")
        print(f"  Test: {len(test_data['features'])} WSIs")
        print(f"  Precise label classes: {unique_precise_labels}")

        return train_data, val_data, test_data


def cluster_all_patches(features_dict: Dict[str, np.ndarray], n_clusters: int = 10,
                       save_path: str = None, dataset_name: str = None) -> Dict[str, np.ndarray]:
    """Cluster the patches of all WSIs together and return per-WSI cluster ids.

    All feature arrays are stacked, a ``MiniBatchKMeans`` model is fitted on
    the stack, and the resulting labels are sliced back per WSI in the
    iteration order of ``features_dict``.

    Args:
        features_dict: Mapping of WSI name to its patch feature array.
        n_clusters: Number of clusters to fit.
        save_path: When given, the fitted clusterer is pickled to this path.
        dataset_name: When given together with a relative ``save_path``, the
            file is written under that dataset's ``models_path`` instead,
            keeping only the base name of ``save_path``.

    Returns:
        Mapping of WSI name to the array of cluster labels of its patches.

    Raises:
        ValueError: If ``features_dict`` is empty.
    """
    print(f"Step 1: Clustering all patches into {n_clusters} clusters...")

    if not features_dict:
        # Fail with a clear message instead of an opaque np.vstack error.
        raise ValueError("features_dict is empty: no patches to cluster")

    # Combine all patch features, remembering how many rows each WSI contributed.
    wsi_names = list(features_dict)
    wsi_patch_counts = [len(features_dict[name]) for name in wsi_names]
    all_patches = np.vstack([features_dict[name] for name in wsi_names])
    print(f"Total patches: {all_patches.shape[0]}, Feature dim: {all_patches.shape[1]}")

    # Clustering. The batch size is clamped to at least 1: with fewer than
    # ten patches the integer division would otherwise yield 0.
    clusterer = MiniBatchKMeans(
        n_clusters=n_clusters,
        random_state=None,  # unseeded: results vary between runs
        batch_size=max(1, min(1000, len(all_patches) // 10)),
        n_init=10
    )

    cluster_labels = clusterer.fit_predict(all_patches)

    # Save clusterer
    if save_path:
        # If dataset_name provided, use dataset-specific path
        if dataset_name and not os.path.isabs(save_path):
            config = get_dataset_config(dataset_name)
            save_path = str(config.models_path / os.path.basename(save_path))

        # A bare file name has an empty dirname, which os.makedirs rejects.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        with open(save_path, 'wb') as f:
            pickle.dump(clusterer, f)
        print(f"Clusterer saved to {save_path}")

    # Assign cluster labels back to each WSI
    cluster_labels_dict = {}
    start_idx = 0

    for wsi_name, patch_count in zip(wsi_names, wsi_patch_counts):
        end_idx = start_idx + patch_count
        cluster_labels_dict[wsi_name] = cluster_labels[start_idx:end_idx]
        start_idx = end_idx

    print("Clustering completed!")
    return cluster_labels_dict


def prepare_dataset_for_training(dataset: Dict, cluster_labels_dict: Dict[str, np.ndarray]) -> Tuple[List, List, List]:
    """Pair each WSI's features with its cluster labels for training.

    Only WSIs present in both ``dataset['features']`` and
    ``cluster_labels_dict`` are kept, in the iteration order of
    ``dataset['features']``.

    Returns:
        ``(wsi_data, wsi_labels, wsi_names)`` as three parallel lists, where
        each ``wsi_data`` entry is a ``(features, cluster_labels)`` tuple and
        each ``wsi_labels`` entry comes from ``dataset['labels']``.
    """
    feature_map = dataset['features']
    kept_names = [name for name in feature_map if name in cluster_labels_dict]

    paired = [(feature_map[name], cluster_labels_dict[name]) for name in kept_names]
    targets = [dataset['labels'][name] for name in kept_names]

    return paired, targets, kept_names
