import pandas as pd
import numpy as np
from typing import Dict, Tuple, List
from sklearn.model_selection import train_test_split


class DataSplitter:
    """Split WSI-level labels into train / validation / test dictionaries.

    The strategy is chosen by ``config.data_split_config['method']``:

    * ``'predefined'`` -- read each WSI's split from a column of a DataFrame.
    * ``'random'`` -- split each label's WSIs randomly according to ratios.

    ``config`` must expose ``data_split_config`` (a mapping supporting
    ``.get``). The predefined strategy also reads ``config.wsi_name_column``
    and ``config.name``.
    """

    def __init__(self, config):
        self.config = config
        self.split_config = config.data_split_config

    def split_data(self, df: pd.DataFrame, wsi_labels: Dict[str, str]) -> Tuple[Dict, Dict, Dict]:
        """Dispatch to the configured split strategy.

        Args:
            df: Table holding the split assignment. Only used by ``'predefined'``.
            wsi_labels: Mapping of WSI name to label.

        Returns:
            ``(train_labels, val_labels, test_labels)``, each a subset of
            ``wsi_labels``.

        Raises:
            ValueError: If the method is unknown or the chosen strategy
                rejects its configuration.
        """
        method = self.split_config.get('method', 'random')

        if method == 'predefined':
            return self._split_predefined(df, wsi_labels)
        if method == 'random':
            return self._split_random(wsi_labels)
        raise ValueError(f"Unknown split method: {method}")

    def _split_predefined(self, df: pd.DataFrame, wsi_labels: Dict[str, str]) -> Tuple[Dict, Dict, Dict]:
        """Assign WSIs to splits according to ``split_column`` of ``df``.

        Predefined splits are treated as two-way:

        * Rows whose split value equals the configured ``'train'`` value go
          to train.
        * Rows equal to the configured ``'val'`` or ``'test'`` value go to
          validation.
        * The returned test dict is always empty.

        WSIs in ``wsi_labels`` that are absent from ``df``, or whose split
        value is unrecognised, are skipped with a warning.

        Raises:
            ValueError: If ``split_column`` is not a column of ``df``.
        """
        split_column = self.split_config.get('split_column', 'split')
        split_values = self.split_config.get('split_values') or {}

        if split_column not in df.columns:
            raise ValueError(f"Split column '{split_column}' not found in data")

        # Cells are compared as strings, so coerce the configured values too
        # (e.g. a config value of 0 must match a cell containing 0).
        train_value = str(split_values.get('train', 'train'))
        val_value = str(split_values.get('val', 'val'))
        test_value = str(split_values.get('test', 'test'))
        val_split_values = {val_value, test_value}

        wsi_to_split = self._build_wsi_to_split(df, split_column)

        train_labels: Dict[str, str] = {}
        val_labels: Dict[str, str] = {}
        test_labels: Dict[str, str] = {}

        for wsi_name, label in wsi_labels.items():
            split_value = wsi_to_split.get(wsi_name)
            if split_value is None:
                print(f"Warning: WSI {wsi_name} not found in split column '{split_column}'")
            elif split_value == train_value:
                train_labels[wsi_name] = label
            elif split_value in val_split_values:
                val_labels[wsi_name] = label
            else:
                print(f"Warning: Unknown split value '{split_value}' for WSI {wsi_name}")

        print("Predefined split results:")
        print(f"  Train: {len(train_labels)} WSIs (from '{train_value}' split)")
        print(f"  Val: {len(val_labels)} WSIs (from '{val_value}' and '{test_value}' splits)")
        print("  Test: 0 WSIs (no test set for predefined splits)")

        return train_labels, val_labels, test_labels

    def _build_wsi_to_split(self, df: pd.DataFrame, split_column: str) -> Dict[str, str]:
        """Map each row's WSI name to the string form of its split value.

        For the ``'hupo_cancer'`` dataset the name is rebuilt as
        ``"<病理号>-<first number in 扫描蜡块>"``. The columns are the pathology
        number and the scanned paraffin block. Rows whose block contains no
        digits keep the name from ``wsi_name_column``.

        When a name repeats, the later row wins.
        """
        # Column-wise access avoids iterrows' per-row Series construction
        # (and its dtype upcasting when every column is numeric).
        wsi_names = [str(v) for v in df[self.config.wsi_name_column].tolist()]

        if self.config.name == 'hupo_cancer':
            pathology_ids = df['病理号'].tolist()
            scan_blocks = df['扫描蜡块'].tolist()
            for i, (pathology_id, scan_block) in enumerate(zip(pathology_ids, scan_blocks)):
                first_number = self._extract_first_number(str(scan_block))
                if first_number:
                    wsi_names[i] = f"{pathology_id!s}-{first_number}"

        splits = [str(v) for v in df[split_column].tolist()]
        return dict(zip(wsi_names, splits))

    def _split_random(self, wsi_labels: Dict[str, str]) -> Tuple[Dict, Dict, Dict]:
        """Randomly split WSIs per label according to the configured ratios.

        Each label is split independently, so every split roughly preserves
        the label proportions of ``wsi_labels``:

        1. The test portion is carved out first.
        2. The remainder is divided between train and validation in the
           ratio ``train_ratio : val_ratio``.

        Labels with fewer than 3 WSIs go entirely to train.

        Config keys (in ``data_split_config``):

        * ``train_ratio`` (default 0.7)
        * ``val_ratio`` (default 0.2)
        * ``test_ratio`` (default 0.1)
        * ``random_state`` (optional int seed)

        Without ``random_state``, sklearn's default global NumPy RNG is used
        and the split is not reproducible. With a seed, the split is
        reproducible for the same ``wsi_labels``, including its iteration
        order.

        Raises:
            ValueError: If any ratio is negative or the ratios do not sum to 1.
        """
        train_ratio = self.split_config.get('train_ratio', 0.7)
        val_ratio = self.split_config.get('val_ratio', 0.2)
        test_ratio = self.split_config.get('test_ratio', 0.1)

        for ratio_name, ratio in (('train_ratio', train_ratio),
                                  ('val_ratio', val_ratio),
                                  ('test_ratio', test_ratio)):
            if ratio < 0:
                raise ValueError(f"{ratio_name} must be non-negative, got {ratio}")

        total_ratio = train_ratio + val_ratio + test_ratio
        if abs(total_ratio - 1.0) > 1e-6:
            raise ValueError(f"Split ratios must sum to 1.0, got {total_ratio}")

        # A single RandomState shared by all calls gives one reproducible
        # stream; None keeps the historical (unseeded) behaviour.
        seed = self.split_config.get('random_state')
        rng = np.random.RandomState(seed) if seed is not None else None

        label_to_wsis: Dict[str, List[str]] = {}
        for wsi_name, label in wsi_labels.items():
            label_to_wsis.setdefault(label, []).append(wsi_name)

        train_labels: Dict[str, str] = {}
        val_labels: Dict[str, str] = {}
        test_labels: Dict[str, str] = {}

        for label, wsi_list in label_to_wsis.items():
            if len(wsi_list) < 3:
                print(f"Warning: Only {len(wsi_list)} samples for label '{label}', putting all in train set")
                for wsi in wsi_list:
                    train_labels[wsi] = label
                continue

            train_val_wsis, test_wsis = self._holdout_split(wsi_list, test_ratio, rng)
            for wsi in test_wsis:
                test_labels[wsi] = label

            if val_ratio > 0 and len(train_val_wsis) > 1:
                # Rescale so val_ratio is relative to the train+val remainder.
                val_ratio_adjusted = val_ratio / (train_ratio + val_ratio)
                train_wsis, val_wsis = self._holdout_split(train_val_wsis, val_ratio_adjusted, rng)
            else:
                train_wsis, val_wsis = train_val_wsis, []

            for wsi in train_wsis:
                train_labels[wsi] = label
            for wsi in val_wsis:
                val_labels[wsi] = label

        print("Random split results:")
        print(f"  Train: {len(train_labels)} WSIs")
        print(f"  Val: {len(val_labels)} WSIs")
        print(f"  Test: {len(test_labels)} WSIs")
        print(f"  Split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio}")

        return train_labels, val_labels, test_labels

    @staticmethod
    def _holdout_split(items: List[str], holdout_ratio: float, rng) -> Tuple[List[str], List[str]]:
        """Split ``items`` into ``(kept, held_out)`` with ``holdout_ratio`` held out.

        Ratios at or beyond 0 or 1 are resolved here directly, because
        ``train_test_split`` rejects a float ``test_size`` outside the open
        interval (0, 1). Plain lists are returned so dict keys stay ``str``
        rather than ``numpy.str_``.
        """
        if holdout_ratio <= 0:
            return list(items), []
        if holdout_ratio >= 1:
            return [], list(items)
        kept, held_out = train_test_split(
            items, test_size=holdout_ratio, random_state=rng, stratify=None
        )
        return list(kept), list(held_out)

    def _extract_first_number(self, scan_block: str) -> str:
        """Return the first run of digits in ``scan_block``, or ``""`` if none."""
        import re
        match = re.search(r'\d+', scan_block)
        return match.group() if match else ""


def create_data_splitter(config):
    """Construct and return a ``DataSplitter`` bound to ``config``.

    Thin factory kept for callers that prefer a function over calling the
    class constructor directly.
    """
    splitter = DataSplitter(config)
    return splitter
