from typing import Tuple, Dict, Optional, List
import numpy as np
import torch
import json
import os
import hashlib
import yaml
import glob
import gzip
from sklearn.metrics import balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from torch.nn import Module

import warnings
# Silence FutureWarning instances whose originating module matches
# 'sklearn.impute._base'.
warnings.filterwarnings('ignore', category=FutureWarning, module='sklearn.impute._base')
# Silence any warning whose message starts with the text below.
# NOTE(review): presumably emitted by sklearn's balanced_accuracy_score when a
# predicted class is absent from the targets — confirm against sklearn.
warnings.filterwarnings('ignore', message='y_pred contains classes not in y_true')

class DataPreprocessor:
    """Load, preprocess and optionally cache locally stored task datasets.

    Every task lives in ``<data_root>/openml_task_<id>/`` as four gzipped
    ``.npy`` files (``X_train``, ``X_test``, ``y_train``, ``y_test``), with an
    optional ``metadata.json`` beside them.
    """

    # Files that must all exist for a task directory to count as available.
    _REQUIRED_FILES = (
        "X_train.npy.gz",
        "X_test.npy.gz",
        "y_train.npy.gz",
        "y_test.npy.gz",
    )

    def __init__(self, data_root: Optional[str] = None, cache_dir: Optional[str] = None):
        """
        Initialize DataPreprocessor.

        Args:
            data_root: Root directory containing dataset folders. Defaults to
                ``thesis/data/full_datasets`` (relative path) when falsy.
            cache_dir: Directory for caching processed data. Created if it
                does not exist; caching is disabled when falsy.
        """
        self.data_root = data_root or os.path.join("thesis", "data", "full_datasets")
        self.cache_dir = cache_dir
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Not used by the methods of this class; kept as a public attribute.
        self.imputer = SimpleImputer(strategy='constant', keep_empty_features=True)

    def get_available_tasks(self) -> List[int]:
        """Return the sorted task IDs that have all required data files.

        Entries under ``data_root`` named ``openml_task_<suffix>`` are
        considered; those whose suffix is not an integer, or that miss one of
        the required files, are skipped.
        """
        task_ids = []
        for task_dir in glob.glob(os.path.join(self.data_root, "openml_task_*")):
            try:
                task_id = int(os.path.basename(task_dir).split("_")[-1])
            except ValueError:
                # Suffix is not numeric, so this is not a task directory.
                continue
            if all(
                os.path.exists(os.path.join(task_dir, name))
                for name in self._REQUIRED_FILES
            ):
                task_ids.append(task_id)
        return sorted(task_ids)

    def _load_compressed_numpy(self, filepath: str) -> np.ndarray:
        """Load a gzipped numpy file."""
        with gzip.open(filepath, 'rb') as f:
            return np.load(f)

    def get_data(
        self,
        task_id: int,
        test_size: float = 0.2,  # Not used but kept for API compatibility
        seed: int = 42  # Not used but kept for API compatibility
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load and preprocess data from local files.

        Args:
            task_id: Numeric suffix of the ``openml_task_<id>`` directory.
            test_size: Ignored; the split is read from disk.
            seed: Ignored; the split is read from disk.

        Returns:
            ``(X_train, X_test, y_train, y_test)``. Features are float32 and
            labels int64 whenever they are produced by this method; a cache
            hit returns whatever was stored in the cache file.

        Raises:
            FileNotFoundError: If the task directory does not exist.
        """
        cache_path = (
            os.path.join(self.cache_dir, f'task_{task_id}_data.npz')
            if self.cache_dir else None
        )

        # Serve from the cache when possible. np.load on an .npz returns a
        # lazily-read archive that holds the file open, so close it as soon as
        # the four arrays have been materialised.
        if cache_path and os.path.exists(cache_path):
            with np.load(cache_path) as cached:
                return (cached['X_train'], cached['X_test'],
                        cached['y_train'], cached['y_test'])

        # Build path to dataset directory
        task_dir = os.path.join(self.data_root, f"openml_task_{task_id}")
        if not os.path.exists(task_dir):
            raise FileNotFoundError(f"Dataset directory not found: {task_dir}")

        # Load data
        X_train = self._load_compressed_numpy(os.path.join(task_dir, "X_train.npy.gz"))
        X_test = self._load_compressed_numpy(os.path.join(task_dir, "X_test.npy.gz"))
        y_train = self._load_compressed_numpy(os.path.join(task_dir, "y_train.npy.gz"))
        y_test = self._load_compressed_numpy(os.path.join(task_dir, "y_test.npy.gz"))

        # Metadata is optional; without it every column is treated as
        # numerical and preprocessing is always applied.
        metadata_path = os.path.join(task_dir, "metadata.json")
        needs_preprocessing = True
        categorical_features = []

        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            needs_preprocessing = not metadata.get("is_preprocessed", False)
            # `or []` also tolerates an explicit null in the JSON file.
            categorical_indicator = metadata.get("categorical_indicator") or []
            categorical_features = [
                i for i, is_cat in enumerate(categorical_indicator) if is_cat
            ]

        if needs_preprocessing:
            # Every column that is not flagged categorical is numerical; a set
            # keeps the membership test O(1) per column.
            categorical_set = set(categorical_features)
            numerical_features = [
                i for i in range(X_train.shape[1]) if i not in categorical_set
            ]

            # Create preprocessing pipelines
            preprocessor = ColumnTransformer(
                transformers=[
                    ('num', Pipeline([
                        ('imputer', SimpleImputer(strategy='median')),
                        ('scaler', StandardScaler())
                    ]), numerical_features),
                    ('cat', Pipeline([
                        ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
                        ('onehot', OneHotEncoder(drop='first', sparse_output=False))
                    ]), categorical_features)
                ],
                # Pass columns through only when neither list selects anything.
                remainder='passthrough' if not (numerical_features or categorical_features) else 'drop'
            )

            # Fit on the training split only, then apply to both splits.
            X_train = preprocessor.fit_transform(X_train).astype(np.float32)
            X_test = preprocessor.transform(X_test).astype(np.float32)

            if not np.issubdtype(y_train.dtype, np.integer):
                # Non-integer labels are mapped to 0..n_classes-1.
                label_encoder = LabelEncoder()
                y_train = label_encoder.fit_transform(y_train).astype(np.int64)
                y_test = label_encoder.transform(y_test).astype(np.int64)
            else:
                # Already integer-coded: only normalise the dtype so both
                # branches of this method hand back int64 labels.
                y_train = y_train.astype(np.int64)
                y_test = y_test.astype(np.int64)
        else:
            # Ensure correct dtypes
            X_train = X_train.astype(np.float32)
            X_test = X_test.astype(np.float32)
            y_train = y_train.astype(np.int64)
            y_test = y_test.astype(np.int64)

        # Cache processed data
        if cache_path:
            np.savez(
                cache_path,
                X_train=X_train, X_test=X_test,
                y_train=y_train, y_test=y_test
            )

        return X_train, X_test, y_train, y_test

def compute_checksum(state_dict: Dict[str, torch.Tensor]) -> str:
    """Compute SHA-256 checksum of model state dict.

    Tensors are visited in sorted key order and their raw element bytes are
    fed to the hash. Only the values contribute: key names, shapes and dtypes
    are not hashed, so the digest identifies the parameter bytes, not the
    layout.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        Hex-encoded SHA-256 digest.
    """
    hasher = hashlib.sha256()
    for key in sorted(state_dict.keys()):
        tensor = state_dict[key].detach().cpu()
        try:
            raw = tensor.numpy().tobytes()
        except (TypeError, RuntimeError):
            # NumPy has no equivalent for some torch dtypes (e.g. bfloat16).
            # Reinterpret the storage as bytes instead; flattening first is
            # required because a dtype-changing view needs at least one dim.
            raw = tensor.contiguous().reshape(-1).view(torch.uint8).numpy().tobytes()
        hasher.update(raw)
    return hasher.hexdigest()

def evaluate_model(model: Module, test_loader, criterion, device) -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; must expose ``config.num_classes``.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
            NOTE(review): the per-batch loss is re-weighted by batch size,
            which assumes the criterion averages over the batch — confirm.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, balanced_accuracy)`` over the samples actually yielded
        by ``test_loader``.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    num_classes = model.config.num_classes  # Get number of classes from model config

    total_loss = 0.0
    n_samples = 0
    predictions = []
    targets_all = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            # Count what was really evaluated: len(test_loader.dataset) is
            # wrong when the loader skips samples (drop_last, samplers).
            n_samples += batch_size

            # Ensure predictions are within valid class range
            preds = outputs.argmax(dim=1).clamp(0, num_classes - 1)
            predictions.extend(preds.cpu().numpy())
            targets_all.extend(targets.cpu().numpy())

    return (
        total_loss / n_samples,
        balanced_accuracy_score(targets_all, predictions)
    )

def save_results(results: Dict, filepath: str, save_yaml: bool = True) -> None:
    """Save results with optional YAML export for hyperparameters.

    The full ``results`` mapping is written as indented JSON to ``filepath``.
    With ``save_yaml`` set, ``results['best_config']`` is additionally written
    to ``task_<results['task_id']>_hyperparams.yml`` in the same directory.

    Saving is best-effort: any failure is printed and a JSON copy is attempted
    at ``filepath + '.backup'``. An error raised by that backup write itself
    propagates to the caller.

    Args:
        results: JSON-serialisable results; needs ``task_id`` and
            ``best_config`` keys when ``save_yaml`` is true.
        filepath: Destination of the JSON file.
        save_yaml: Whether to also export the hyperparameters as YAML.
    """
    out_dir = os.path.dirname(filepath)
    try:
        # A bare filename has an empty dirname and os.makedirs('') raises,
        # which used to divert every such save into the backup file.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save full results as JSON
        with open(filepath, 'w') as f:
            json.dump(results, f, indent=4)

        # Save hyperparameters as YAML if requested
        if save_yaml:
            yaml_path = os.path.join(
                out_dir,
                f"task_{results['task_id']}_hyperparams.yml"
            )
            with open(yaml_path, 'w') as f:
                yaml.dump(results['best_config'], f, default_flow_style=False)

    except Exception as e:
        print(f"Error saving results: {e}")
        # Save backup
        with open(filepath + '.backup', 'w') as f:
            json.dump(results, f, indent=4)

def get_device(device_str: Optional[str] = None) -> torch.device:
    """Get PyTorch device with fallback to CPU.

    Args:
        device_str: ``'cuda'`` or ``'cuda:<index>'`` to request a GPU. Any
            other value, including ``None``, selects the CPU.

    Returns:
        The requested CUDA device when CUDA is available (and, for an indexed
        request, the index refers to an existing GPU); otherwise the CPU.
    """
    if isinstance(device_str, str) and torch.cuda.is_available():
        name, sep, index = device_str.partition(':')
        if name == 'cuda':
            if not sep:
                return torch.device('cuda')
            # Honour an explicit GPU index only if that GPU exists; anything
            # malformed or out of range falls through to the CPU.
            if index.isdecimal() and int(index) < torch.cuda.device_count():
                return torch.device('cuda', int(index))
    return torch.device('cpu')