"""
Global feature normalizer
Calculates normalization parameters based on all datasets to ensure feature consistency across datasets
"""

import numpy as np
import json
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union

# Optional dependency
try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

from .batch_feature_extractor import BatchFeatureExtractor

class SimpleScaler:
    """Minimal scaler interface shared by the concrete normalizers below.

    Subclasses implement ``fit`` (returning ``self``) and ``transform``;
    ``fit_transform`` chains the two.
    """

    def __init__(self):
        # Flipped to True by a subclass once fit() has succeeded.
        self.is_fitted = False

    def fit(self, X):
        """Learn scaling parameters from X; must be overridden."""
        raise NotImplementedError()

    def transform(self, X):
        """Apply the learned scaling to X; must be overridden."""
        raise NotImplementedError()

    def fit_transform(self, X):
        """Fit on X, then return the transformed X."""
        fitted = self.fit(X)
        return fitted.transform(X)

class StandardScaler(SimpleScaler):
    """Standardization: subtract the per-column mean, divide by the per-column std."""

    def __init__(self):
        super().__init__()
        self.mean_ = None
        self.std_ = None

    def fit(self, X):
        """Compute column-wise mean and standard deviation of X and return self."""
        data = np.asarray(X)
        self.mean_ = np.mean(data, axis=0)
        spread = np.std(data, axis=0)
        # Constant columns have zero std; use 1 so transform() never divides by zero.
        self.std_ = np.where(spread == 0, 1, spread)
        self.is_fitted = True
        return self

    def transform(self, X):
        """Return (X - mean) / std using the fitted parameters.

        Raises:
            ValueError: If fit() has not been called.
        """
        if self.is_fitted:
            return (np.asarray(X) - self.mean_) / self.std_
        raise ValueError("Scaler has not been fitted")

class MinMaxScaler(SimpleScaler):
    """Min-max normalization: maps each column's fitted range onto [0, 1]."""

    def __init__(self):
        super().__init__()
        self.min_ = None
        self.max_ = None
        self.scale_ = None

    def fit(self, X):
        """Record the column-wise minimum, maximum and range of X and return self."""
        data = np.asarray(X)
        lowest = np.min(data, axis=0)
        highest = np.max(data, axis=0)
        span = highest - lowest
        self.min_ = lowest
        self.max_ = highest
        # Constant columns have zero range; use 1 so transform() never divides by zero.
        self.scale_ = np.where(span == 0, 1, span)
        self.is_fitted = True
        return self

    def transform(self, X):
        """Return (X - min) / range using the fitted parameters.

        Raises:
            ValueError: If fit() has not been called.
        """
        if self.is_fitted:
            return (np.asarray(X) - self.min_) / self.scale_
        raise ValueError("Scaler has not been fitted")

class RobustScaler(SimpleScaler):
    """Robust normalization: centers on the median and scales by the IQR."""

    def __init__(self):
        super().__init__()
        self.median_ = None
        self.iqr_ = None

    def fit(self, X):
        """Compute the column-wise median and interquartile range of X and return self."""
        data = np.asarray(X)
        self.median_ = np.median(data, axis=0)
        lower_quartile = np.percentile(data, 25, axis=0)
        upper_quartile = np.percentile(data, 75, axis=0)
        spread = upper_quartile - lower_quartile
        # A zero IQR would divide by zero in transform(); fall back to 1.
        self.iqr_ = np.where(spread == 0, 1, spread)
        self.is_fitted = True
        return self

    def transform(self, X):
        """Return (X - median) / IQR using the fitted parameters.

        Raises:
            ValueError: If fit() has not been called.
        """
        if self.is_fitted:
            return (np.asarray(X) - self.median_) / self.iqr_
        raise ValueError("Scaler has not been fitted")

class GlobalFeatureNormalizer:
    """
    Global feature normalizer.

    Fits a single scaler on the feature matrices of all dataset folders so that
    every dataset is normalized with the same parameters.
    """

    def __init__(self, features_dir: str = "features", normalization_method: str = "standard"):
        """
        Initialize the normalizer
        Args:
            features_dir: Feature file directory
            normalization_method: Normalization method ['standard', 'minmax', 'robust']
        Raises:
            ValueError: If normalization_method is not one of the supported names.
        """
        self.features_dir = Path(features_dir)
        self.normalization_method = normalization_method
        self.batch_extractor = BatchFeatureExtractor(output_dir=features_dir)
        # Select normalizer
        if normalization_method == "standard":
            self.scaler = StandardScaler()  # zero mean, unit variance
        elif normalization_method == "minmax":
            self.scaler = MinMaxScaler()    # scale to [0,1]
        elif normalization_method == "robust":
            self.scaler = RobustScaler()    # robust to outliers
        else:
            raise ValueError(f"Unknown normalization method: {normalization_method}")
        self.is_fitted = False
        self.feature_names = None
        self.global_stats = None

    def _get_batch_extractor(self):
        """Return the batch extractor, creating it on first use.

        Instances restored by load_normalizer() start without an extractor so
        that loading a pickle does not touch the features directory.
        """
        extractor = getattr(self, 'batch_extractor', None)
        if extractor is None:
            extractor = BatchFeatureExtractor(output_dir=str(self.features_dir))
            self.batch_extractor = extractor
        return extractor

    def _discover_folders(self) -> List[str]:
        """List sub-folders of features_dir that contain '<name>_features.npz'.

        The result is sorted so the row order of combined datasets does not
        depend on filesystem iteration order.
        """
        return sorted(
            folder_path.name
            for folder_path in self.features_dir.iterdir()
            if folder_path.is_dir()
            and (folder_path / f"{folder_path.name}_features.npz").exists()
        )

    def collect_all_features(self, folders: Optional[List[str]] = None) -> Tuple[np.ndarray, List[str], List[str]]:
        """
        Collect feature data from all folders
        Args:
            folders: Folder names to load; None auto-discovers every processed folder.
        Returns:
            (combined_features, file_labels, feature_names). Feature names are
            taken from the first folder that loads successfully. When nothing
            could be loaded the result is (empty array, [], []).
        Raises:
            ValueError: If the per-folder feature matrices cannot be stacked.
        """
        if folders is None:
            # Auto-discover all processed folders
            folders = self._discover_folders()
        print(f"Collecting features from {len(folders)} folders: {folders}")
        all_features = []
        all_labels = []
        feature_names = None
        for folder_name in folders:
            folder_data = self._get_batch_extractor().load_folder_features(folder_name, 'npz')
            if folder_data is None:
                print(f"  Warning: failed to load features for {folder_name}")
                continue
            filenames = folder_data['filenames']
            if feature_names is None:
                feature_names = folder_data['feature_names']
            all_features.append(folder_data['feature_matrix'])
            # Create labels: foldername_filename
            all_labels.extend(f"{folder_name}_{filename}" for filename in filenames)
            print(f"  {folder_name}: {len(filenames)} files")
        if not all_features:
            return np.array([]), [], []
        try:
            combined_features = np.vstack(all_features)
        except ValueError as err:
            # Surface which shapes clash instead of NumPy's generic message.
            shapes = [np.shape(matrix) for matrix in all_features]
            raise ValueError(
                f"Cannot combine feature matrices with mismatched shapes: {shapes}"
            ) from err
        print(f"Total collected: {combined_features.shape[0]} files, {combined_features.shape[1]} feature dimensions")
        return combined_features, all_labels, feature_names

    def fit(self, folders: Optional[List[str]] = None) -> 'GlobalFeatureNormalizer':
        """
        Compute normalization parameters based on all data
        Args:
            folders: Folder names to fit on; None uses every processed folder.
        Returns:
            self, to allow chaining.
        Raises:
            ValueError: If no feature data could be collected.
        """
        print(f"Start computing global normalization parameters (method: {self.normalization_method})")
        # Collect all features
        features, _labels, feature_names = self.collect_all_features(folders)
        if len(features) == 0:
            raise ValueError("No feature data found")
        self.feature_names = feature_names
        # Fit normalizer
        self.scaler.fit(features)
        self.is_fitted = True
        # Compute and save global statistics
        self.global_stats = self._calculate_global_stats(features)
        print(f"Normalization parameter computation complete, based on {len(features)} samples")
        print("Feature statistics:")
        for i, name in enumerate(feature_names[:10]):  # Show first 10 features
            print(f"  {name}: mean={self.global_stats['mean'][i]:.4f}, "
                  f"std={self.global_stats['std'][i]:.4f}")
        return self

    def _calculate_global_stats(self, features: np.ndarray) -> Dict:
        """Compute column-wise statistics of the raw (un-normalized) features."""
        return {
            'mean': np.mean(features, axis=0),
            'std': np.std(features, axis=0),
            'min': np.min(features, axis=0),
            'max': np.max(features, axis=0),
            'median': np.median(features, axis=0),
            'q25': np.percentile(features, 25, axis=0),
            'q75': np.percentile(features, 75, axis=0),
            'sample_count': features.shape[0],
            'feature_count': features.shape[1]
        }

    def transform(self, features: np.ndarray) -> np.ndarray:
        """
        Apply normalization transformation
        Raises:
            ValueError: If the normalizer has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Normalizer has not been fitted. Call fit() first.")
        return self.scaler.transform(features)

    def fit_transform(self, folders: Optional[List[str]] = None) -> Tuple[np.ndarray, List[str], List[str]]:
        """
        Fit and transform all features
        Returns:
            (normalized_features, labels, feature_names), or
            (empty array, [], []) when no feature data was found. In that case
            the normalizer state is left untouched.
        """
        # Collect features
        features, labels, feature_names = self.collect_all_features(folders)
        if len(features) == 0:
            return np.array([]), [], []
        # Fit normalizer
        self.feature_names = feature_names
        normalized_features = self.scaler.fit_transform(features)
        self.is_fitted = True
        # Compute statistics
        self.global_stats = self._calculate_global_stats(features)
        print(f"Normalization complete: {normalized_features.shape}")
        print(f"Post-normalization stats - mean: {np.mean(normalized_features, axis=0)[:5]}")
        print(f"Post-normalization stats - std: {np.std(normalized_features, axis=0)[:5]}")
        return normalized_features, labels, feature_names

    def transform_folder(self, folder_name: str) -> Optional[np.ndarray]:
        """
        Normalize features for a single folder
        Returns:
            The normalized feature matrix, or None if the folder could not be loaded.
        Raises:
            ValueError: If the normalizer has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Normalizer has not been fitted. Call fit() first.")
        folder_data = self._get_batch_extractor().load_folder_features(folder_name, 'npz')
        if folder_data is None:
            return None
        return self.scaler.transform(folder_data['feature_matrix'])

    def save_normalizer(self, output_path: str):
        """Save normalizer and statistics

        Writes a pickle to output_path and a human-readable JSON summary next
        to it (same name, '.json' suffix).

        Raises:
            ValueError: If the normalizer has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Normalizer has not been fitted. Call fit() first.")
        stored_dir = getattr(self, 'features_dir', None)
        normalizer_data = {
            'scaler': self.scaler,
            'normalization_method': self.normalization_method,
            'feature_names': self.feature_names,
            'global_stats': self.global_stats,
            'is_fitted': self.is_fitted,
            # Stored so a loaded normalizer can locate the feature folders again.
            'features_dir': None if stored_dir is None else str(stored_dir)
        }
        with open(output_path, 'wb') as f:
            pickle.dump(normalizer_data, f)

        def to_jsonable(value):
            # json cannot serialize NumPy arrays or NumPy scalar types directly.
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, np.generic):
                return value.item()
            return value

        # Also save human-readable statistics
        stats_path = Path(output_path).with_suffix('.json')
        names = self.feature_names
        readable_stats = {
            'normalization_method': self.normalization_method,
            # Feature names may arrive as an ndarray; convert to a plain list.
            'feature_names': None if names is None else np.asarray(names).tolist(),
            'global_stats': {
                key: to_jsonable(value)
                for key, value in self.global_stats.items()
            }
        }
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(readable_stats, f, indent=2)
        print(f"Normalizer saved to: {output_path}")
        print(f"Statistics saved to: {stats_path}")

    @classmethod
    def load_normalizer(cls, normalizer_path: str, features_dir: Optional[str] = None) -> 'GlobalFeatureNormalizer':
        """Load a saved normalizer

        Args:
            normalizer_path: Pickle file written by save_normalizer().
            features_dir: Optional override for the feature directory. When
                omitted, the directory stored in the pickle is used, falling
                back to "features" for files that do not record one.

        SECURITY: this unpickles the file, which can execute arbitrary code.
        Only load normalizer files from trusted sources.
        """
        with open(normalizer_path, 'rb') as f:
            normalizer_data = pickle.load(f)
        # Create new instance without running __init__ (no scaler re-creation)
        normalizer = cls.__new__(cls)
        normalizer.scaler = normalizer_data['scaler']
        normalizer.normalization_method = normalizer_data['normalization_method']
        normalizer.feature_names = normalizer_data['feature_names']
        normalizer.global_stats = normalizer_data['global_stats']
        normalizer.is_fitted = normalizer_data['is_fitted']
        # Restore the attributes __init__ would have set so folder-based methods
        # (transform_folder, fit, create_normalized_dataset) work after loading.
        if features_dir is None:
            features_dir = normalizer_data.get('features_dir') or "features"
        normalizer.features_dir = Path(features_dir)
        normalizer.batch_extractor = None  # created lazily on first use
        print(f"Normalizer loaded: {normalizer_path}")
        print(f"Method: {normalizer.normalization_method}")
        print(f"Number of features: {len(normalizer.feature_names)}")
        return normalizer

    def get_feature_importance_by_variance(self):
        """
        Analyze feature importance by variance (original variance before normalization)
        Returns:
            One record per feature (name, variance, std, mean, range and the
            1-based variance/range ranks), sorted by variance descending.
            A pandas DataFrame when pandas is installed, otherwise a list of dicts.
        Raises:
            ValueError: If no global statistics are available.
        """
        if self.global_stats is None:
            raise ValueError("No global statistics available. Call fit() first.")
        variance = self.global_stats['std'] ** 2
        range_values = self.global_stats['max'] - self.global_stats['min']

        def ranks_descending(values):
            # Invert the sort permutation once instead of searching it per feature.
            order = np.argsort(values)[::-1]
            ranks = np.empty(len(order), dtype=np.intp)
            ranks[order] = np.arange(1, len(order) + 1)
            return ranks

        variance_ranks = ranks_descending(variance)
        range_ranks = ranks_descending(range_values)
        importance_data = [
            {
                'feature_name': name,
                'variance': variance[i],
                'std': self.global_stats['std'][i],
                'mean': self.global_stats['mean'][i],
                'range': range_values[i],
                'variance_rank': variance_ranks[i],
                'range_rank': range_ranks[i]
            }
            for i, name in enumerate(self.feature_names)
        ]
        # Sort by variance
        importance_data.sort(key=lambda x: x['variance'], reverse=True)
        if HAS_PANDAS:
            return pd.DataFrame(importance_data)
        return importance_data

    def create_normalized_dataset(self, output_dir: str = "normalized_features", folders: Optional[List[str]] = None):
        """
        Create normalized dataset files

        Writes '<folder>_normalized.npz' for each folder, a combined
        'all_datasets_normalized.npz', and 'global_normalizer.pkl' (plus its
        JSON summary) into output_dir.

        Args:
            output_dir: Destination directory; created if missing.
            folders: Folder names to process; None uses every processed folder.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Fit normalizer if not already
        if not self.is_fitted:
            self.fit(folders)
        # Create normalized version for each folder
        if folders is None:
            folders = self._discover_folders()
        for folder_name in folders:
            folder_data = self._get_batch_extractor().load_folder_features(folder_name, 'npz')
            if folder_data is None:
                continue
            # Normalize features
            normalized_features = self.transform(folder_data['feature_matrix'])
            # Save normalized features
            normalized_file = output_path / f"{folder_name}_normalized.npz"
            np.savez_compressed(
                normalized_file,
                features=normalized_features,
                original_features=folder_data['feature_matrix'],
                filenames=np.array(folder_data['filenames']),
                feature_names=np.array(self.feature_names),
                normalization_method=self.normalization_method,
                folder_name=folder_name
            )
            print(f"Saved normalized features: {normalized_file}")
        # Save combined normalized dataset. Reuse the already-fitted parameters
        # (no re-fit) so the combined file, the per-folder files and the saved
        # normalizer all share exactly the same normalization.
        combined_features, labels, feature_names = self.collect_all_features(folders)
        if len(combined_features) > 0:
            combined_normalized = self.transform(combined_features)
            combined_file = output_path / "all_datasets_normalized.npz"
            np.savez_compressed(
                combined_file,
                features=combined_normalized,
                labels=np.array(labels),
                feature_names=np.array(feature_names),
                normalization_method=self.normalization_method
            )
            print(f"Saved combined normalized dataset: {combined_file}")
        # Save normalizer
        normalizer_file = output_path / "global_normalizer.pkl"
        self.save_normalizer(str(normalizer_file))

def main():
    """Command-line entry point: normalize all feature folders and print a variance ranking.

    Parses --features-dir, --output-dir, --method and --folders, writes the
    normalized datasets, then prints the ten highest-variance features.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Global feature normalization')
    parser.add_argument('--features-dir', default='features',
                       help='Features directory')
    parser.add_argument('--output-dir', default='normalized_features',
                       help='Output directory for normalized features')
    parser.add_argument('--method', choices=['standard', 'minmax', 'robust'],
                       default='standard', help='Normalization method')
    parser.add_argument('--folders', nargs='+',
                       help='Specific folders to process (default: all)')
    args = parser.parse_args()
    # Create global normalizer
    normalizer = GlobalFeatureNormalizer(
        features_dir=args.features_dir,
        normalization_method=args.method
    )
    # Create normalized dataset
    normalizer.create_normalized_dataset(
        output_dir=args.output_dir,
        folders=args.folders
    )
    # Show feature importance analysis
    print("\n=== Feature importance analysis (by variance) ===")
    importance = normalizer.get_feature_importance_by_variance()
    if hasattr(importance, 'head'):
        # pandas is available: a DataFrame was returned
        print(importance.head(10))
    else:
        # Without pandas the result is a plain list of dicts, which has no .head()
        for row in importance[:10]:
            print(row)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()