"""
Utility functions for DATE-GFN experiments.
"""

import torch
import numpy as np
import random
import os
import json
import pickle
from typing import Any, Dict, Optional, Union, List
import time
from pathlib import Path
import logging


def set_seed(seed: int):
    """Seed every random number generator used by the experiments.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus the CUDA seeding entry points), then switches cuDNN
    into deterministic mode.

    Args:
        seed: Value handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(seed)

    # Favor reproducibility over speed: disable cuDNN autotuning and force
    # deterministic kernels (may impact performance).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_logging(log_file: Optional[str] = None, level: int = logging.INFO):
    """Configure root logging to the console and, optionally, to a file.

    Args:
        log_file: Path of a log file to write in addition to the console
            stream. Its parent directory is created when missing. ``None``
            or an empty string disables file logging.
        level: Logging level passed to ``logging.basicConfig``.

    Returns:
        The logger named after this module.

    Note:
        ``logging.basicConfig`` leaves the root logger untouched when it
        already has handlers, so calling this a second time does not
        replace an existing configuration.
    """
    handlers = [logging.StreamHandler()]

    if log_file:
        # A bare filename such as "run.log" has an empty dirname, and
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when there actually is one.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        handlers.append(logging.FileHandler(log_file))

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=handlers
    )

    return logging.getLogger(__name__)


def save_checkpoint(state: Dict[str, Any], filepath: Union[str, Path], 
                   is_best: bool = False):
    """Write a checkpoint with ``torch.save``.

    Args:
        state: Object to serialize.
        filepath: Destination file; missing parent directories are created.
        is_best: When true, a second copy named ``best_<filename>`` is
            written next to ``filepath``.
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)

    torch.save(state, target)

    if not is_best:
        return

    # The "best" copy lives beside the regular checkpoint.
    torch.save(state, target.parent / f"best_{target.name}")


def load_checkpoint(filepath: Union[str, Path]) -> Dict[str, Any]:
    """Read a checkpoint from disk onto the CPU.

    Args:
        filepath: Location of the checkpoint file.

    Returns:
        Whatever ``torch.load`` yields for the file, with tensors mapped to
        the CPU.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    checkpoint_path = Path(filepath)
    if checkpoint_path.exists():
        return torch.load(checkpoint_path, map_location='cpu')

    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")


def save_results(results: Dict[str, Any], filepath: Union[str, Path]):
    """Save experiment results to a JSON file.

    NumPy arrays, NumPy scalars (e.g. ``np.int64``, ``np.float32``,
    ``np.bool_``) and tuples are converted to plain Python values first,
    recursively through dicts, lists and tuples. Anything else is handed to
    ``json.dump`` unchanged.

    Args:
        results: Result mapping to serialize.
        filepath: Destination file; missing parent directories are created.
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    def convert_numpy(obj):
        """Recursively replace NumPy containers/scalars with builtins."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.int64 are rejected by json.dump.
            return obj.item()
        if isinstance(obj, dict):
            return {
                # json only accepts builtin scalar keys, so unwrap NumPy ones.
                (k.item() if isinstance(k, np.generic) else k): convert_numpy(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_numpy(item) for item in obj]
        return obj

    serializable_results = convert_numpy(results)

    with open(filepath, 'w') as f:
        json.dump(serializable_results, f, indent=2)


def load_results(filepath: Union[str, Path]) -> Dict[str, Any]:
    """Read experiment results back from a JSON file.

    Args:
        filepath: Location of the JSON file.

    Returns:
        The decoded JSON content.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    results_path = Path(filepath)
    if results_path.exists():
        with open(results_path, 'r') as handle:
            return json.load(handle)

    raise FileNotFoundError(f"Results file not found: {results_path}")


def save_pickle(obj: Any, filepath: Union[str, Path]):
    """Serialize ``obj`` to ``filepath`` with ``pickle``.

    Args:
        obj: Object to pickle.
        filepath: Destination file; missing parent directories are created.
    """
    destination = Path(filepath)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with open(destination, 'wb') as handle:
        pickle.dump(obj, handle)


def load_pickle(filepath: Union[str, Path]) -> Any:
    """Deserialize a pickled object from ``filepath``.

    Unpickling can execute arbitrary code, so only load files that come
    from a trusted source.

    Args:
        filepath: Location of the pickle file.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    pickle_path = Path(filepath)
    if pickle_path.exists():
        with open(pickle_path, 'rb') as handle:
            return pickle.load(handle)

    raise FileNotFoundError(f"Pickle file not found: {pickle_path}")


class Timer:
    """Context manager for timing code execution.

    Usage::

        with Timer("training") as t:
            ...
        print(t.elapsed)

    Attributes are wall-clock timestamps from ``time.time()``; both are
    ``None`` until the corresponding event has happened.
    """

    def __init__(self, name: str = "Operation"):
        # Label for the timed operation; stored but not used by the timer.
        self.name = name
        self.start_time = None
        self.end_time = None

    def __enter__(self):
        # Clear any end time left over from a previous use so that `elapsed`
        # reports the running duration while this block is active.
        self.end_time = None
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets exceptions raised in the block propagate.
        self.end_time = time.time()

    @property
    def elapsed(self) -> float:
        """Elapsed time in seconds.

        Returns 0.0 if the timer was never started, the running duration
        while inside the ``with`` block, and the final duration afterwards.
        """
        if self.start_time is None:
            return 0.0
        if self.end_time is None:
            return time.time() - self.start_time
        return self.end_time - self.start_time


class ProgressTracker:
    """Track training progress and estimate completion time.

    Call :meth:`update` once per step. A progress line is printed every
    ``log_every`` steps and on the final step.
    """

    def __init__(self, total_steps: int, log_every: int = 100):
        """
        Args:
            total_steps: Number of steps the run is expected to take.
            log_every: Print a progress line whenever ``step`` is a multiple
                of this value. A non-positive value disables periodic
                logging; only the final step is then reported.
        """
        self.total_steps = total_steps
        self.log_every = log_every
        self.start_time = time.time()
        # Timestamps of the most recent update() calls (bounded window).
        self.step_times = []

    def _window(self) -> int:
        """Number of recent timestamps used for the ETA estimate."""
        # At least two timestamps are needed to form one step duration.
        return max(self.log_every, 2)

    def update(self, step: int, metrics: Optional[Dict[str, float]] = None):
        """Record that ``step`` finished and log progress when due.

        Args:
            step: Index of the step that just completed.
            metrics: Optional metric values appended to the progress line.
        """
        self.step_times.append(time.time())

        # Only the most recent timestamps feed the ETA, so drop older ones
        # instead of letting the list grow for the whole run.
        window = self._window()
        if len(self.step_times) > window:
            del self.step_times[:-window]

        periodic = self.log_every > 0 and step % self.log_every == 0
        if periodic or step == self.total_steps:
            self._log_progress(step, metrics)

    def _log_progress(self, step: int, metrics: Optional[Dict[str, float]] = None):
        """Print one progress line with elapsed time, ETA and metrics."""
        elapsed = time.time() - self.start_time
        # Guard against total_steps == 0 to avoid a ZeroDivisionError.
        progress = step / self.total_steps if self.total_steps > 0 else 0.0

        recent_times = self.step_times[-self._window():]
        if len(recent_times) > 1:
            avg_step_time = np.mean(np.diff(recent_times))
            # Never report a negative ETA when step overshoots total_steps.
            remaining_steps = max(self.total_steps - step, 0)
            eta = remaining_steps * avg_step_time
        else:
            eta = 0

        message = f"Step {step}/{self.total_steps} ({progress:.1%}) | "
        message += f"Elapsed: {elapsed:.1f}s | ETA: {eta:.1f}s"

        if metrics:
            metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
            message += f" | {metric_str}"

        print(message)


def get_device() -> torch.device:
    """Get the best available device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        A ``torch.device`` for the first available backend.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older PyTorch builds have no ``torch.backends.mps`` module; look it up
    # defensively so those installs fall through to the CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')


def count_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of a model.

    Only parameters with ``requires_grad`` set are included.

    Args:
        model: Module whose parameters are counted.

    Returns:
        Total number of elements across the trainable parameters.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


def calculate_memory_usage() -> Dict[str, float]:
    """Report current CUDA and system memory usage.

    Returns:
        A dict that may contain:

        * ``cuda_allocated``, ``cuda_reserved``, ``cuda_max_allocated`` (GB),
          present only when CUDA is available;
        * ``system_memory_percent`` and ``system_memory_available`` (GB),
          present only when ``psutil`` can be imported.
    """
    bytes_per_gb = 1024**3
    usage = {}

    if torch.cuda.is_available():
        usage['cuda_allocated'] = torch.cuda.memory_allocated() / bytes_per_gb
        usage['cuda_reserved'] = torch.cuda.memory_reserved() / bytes_per_gb
        usage['cuda_max_allocated'] = torch.cuda.max_memory_allocated() / bytes_per_gb

    # psutil is optional; without it only the CUDA figures are reported.
    try:
        import psutil
    except ImportError:
        return usage

    usage['system_memory_percent'] = psutil.virtual_memory().percent
    usage['system_memory_available'] = psutil.virtual_memory().available / bytes_per_gb
    return usage


class ConfigValidator:
    """Static checks for experiment configuration dictionaries.

    Every validator raises ``ValueError`` on the first problem it finds and
    returns ``None`` when the configuration passes.
    """

    @staticmethod
    def validate_base_config(config: Dict[str, Any]) -> None:
        """Ensure the base keys are present.

        Raises:
            ValueError: If ``seed``, ``device``, ``log_dir`` or
                ``experiment_name`` is missing.
        """
        expected = ('seed', 'device', 'log_dir', 'experiment_name')
        absent = [name for name in expected if name not in config]
        if absent:
            # Report only the first missing key, in declaration order.
            raise ValueError(f"Missing required config key: {absent[0]}")

    @staticmethod
    def validate_training_config(config: Dict[str, Any]) -> None:
        """Ensure the training keys are present and ``num_steps`` is large enough.

        Raises:
            ValueError: If ``num_steps``, ``batch_size`` or ``learning_rate``
                is missing, or if ``num_steps`` is below 2000.
        """
        expected = ('num_steps', 'batch_size', 'learning_rate')
        absent = [name for name in expected if name not in config]
        if absent:
            raise ValueError(f"Missing required training config key: {absent[0]}")

        num_steps = config['num_steps']
        if num_steps < 2000:
            raise ValueError("num_steps must be at least 2000 for statistical significance")

    @staticmethod
    def validate_date_gfn_config(config: Dict[str, Any]) -> None:
        """Ensure the DATE-GFN keys are present and within their valid ranges.

        Raises:
            ValueError: If ``population_size``, ``elite_ratio``,
                ``teachability_weight`` or ``student_updates`` is missing,
                if ``elite_ratio`` is outside ``[0, 1]``, or if
                ``teachability_weight`` is negative.
        """
        expected = ('population_size', 'elite_ratio', 'teachability_weight', 'student_updates')
        absent = [name for name in expected if name not in config]
        if absent:
            raise ValueError(f"Missing required DATE-GFN config key: {absent[0]}")

        elite_ratio = config['elite_ratio']
        ratio_in_range = 0 <= elite_ratio <= 1
        if not ratio_in_range:
            raise ValueError("elite_ratio must be between 0 and 1")

        teachability_weight = config['teachability_weight']
        if teachability_weight < 0:
            raise ValueError("teachability_weight must be non-negative")


def create_experiment_directory(base_dir: Union[str, Path], 
                               experiment_name: str,
                               timestamp: bool = True) -> Path:
    """Create the directory layout for one experiment run.

    Args:
        base_dir: Parent directory for all experiments.
        experiment_name: Name of the experiment directory.
        timestamp: When true, ``_<YYYYmmdd_HHMMSS>`` (local time) is appended
            to the directory name.

    Returns:
        Path of the experiment directory, which contains ``checkpoints``,
        ``logs``, ``results`` and ``plots`` subdirectories.
    """
    dir_name = experiment_name
    if timestamp:
        stamp = time.strftime("%Y%m%d_%H%M%S")
        dir_name = f"{experiment_name}_{stamp}"

    experiment_root = Path(base_dir) / dir_name
    experiment_root.mkdir(parents=True, exist_ok=True)

    # Standard per-run layout.
    for child in ('checkpoints', 'logs', 'results', 'plots'):
        experiment_root.joinpath(child).mkdir(exist_ok=True)

    return experiment_root


def aggregate_results_across_seeds(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate results across multiple random seeds.

    For every key, numeric scalars (Python or NumPy) and the flattened
    elements of numeric lists/arrays are pooled over all runs. Keys with no
    numeric data (e.g. strings, or lists that cannot be read as a numeric
    array) are skipped.

    Args:
        results: One result dict per seed; runs may have different keys.

    Returns:
        A dict with ``<key>_mean``, ``<key>_std``, ``<key>_min``,
        ``<key>_max`` and ``<key>_median`` entries for each aggregated key,
        as plain Python numbers, ordered by first appearance of the key.
        Empty when ``results`` is empty.
    """
    if not results:
        return {}

    # dict.fromkeys keeps first-seen order, so the output ordering is
    # reproducible from run to run (a set would not guarantee that).
    ordered_keys = dict.fromkeys(key for result in results for key in result)

    aggregated = {}
    for key in ordered_keys:
        values = []
        for result in results:
            if key not in result:
                continue
            value = result[key]
            if isinstance(value, (int, float, np.number)):
                # np.number covers scalars such as np.float32 / np.int64,
                # which are not instances of the builtin int/float.
                values.append(value)
            elif isinstance(value, (list, np.ndarray)):
                try:
                    array = np.asarray(value)
                except (ValueError, TypeError):
                    # Ragged or otherwise unconvertible list: not a metric.
                    continue
                # Only bool/int/uint/float data can be averaged.
                if array.dtype.kind in "biuf":
                    values.extend(array.ravel())

        if not values:
            continue

        pooled = np.asarray(values)
        # .item() yields builtin numbers so the summary is JSON-friendly.
        aggregated[f"{key}_mean"] = np.mean(pooled).item()
        aggregated[f"{key}_std"] = np.std(pooled).item()
        aggregated[f"{key}_min"] = np.min(pooled).item()
        aggregated[f"{key}_max"] = np.max(pooled).item()
        aggregated[f"{key}_median"] = np.median(pooled).item()

    return aggregated


def statistical_significance_test(group1: List[float], group2: List[float]) -> Dict[str, Any]:
    """Perform a statistical significance test between two groups.

    When SciPy is available, each group with at least three samples is
    checked for normality with Shapiro-Wilk. If both look normal
    (p > 0.05) an independent two-sample t-test is used; otherwise
    (including when a group has fewer than three samples) a two-sided
    Mann-Whitney U test is used. Cohen's d is computed with the pooled
    standard deviation.

    Args:
        group1: Samples of the first group.
        group2: Samples of the second group.

    Returns:
        With SciPy: a dict with ``p_value``, ``statistic``, ``cohens_d``
        (floats), ``test_used`` (``"t-test"`` or ``"mann-whitney"``),
        ``significant`` (``p_value < 0.05``) and ``effect_size``
        (``"large"`` for |d| > 0.8, ``"medium"`` for |d| > 0.5, else
        ``"small"``).
        Without SciPy: a dict with only ``mean_diff``, ``std_diff`` and
        ``test_used == "simple_comparison"``.
        All values are plain Python types.
    """
    try:
        from scipy import stats
    except ImportError:
        # Fallback to simple comparison if scipy not available
        return {
            'mean_diff': float(np.mean(group1) - np.mean(group2)),
            'std_diff': float(np.std(group1) - np.std(group2)),
            'test_used': 'simple_comparison'
        }

    # Shapiro-Wilk test for normality; too-small groups count as non-normal.
    _, p_normal1 = stats.shapiro(group1) if len(group1) >= 3 else (None, 0)
    _, p_normal2 = stats.shapiro(group2) if len(group2) >= 3 else (None, 0)

    # Choose appropriate test
    if p_normal1 > 0.05 and p_normal2 > 0.05:
        # Both groups are normal, use t-test
        statistic, p_value = stats.ttest_ind(group1, group2)
        test_used = "t-test"
    else:
        # Non-normal data, use Mann-Whitney U test
        statistic, p_value = stats.mannwhitneyu(group1, group2, alternative='two-sided')
        test_used = "mann-whitney"

    # Effect size (Cohen's d). The pooled variance is built from sums of
    # squared deviations, i.e. (n - 1) * var(ddof=1), which is simply 0 for
    # a single-sample group rather than NaN.
    samples1 = np.asarray(group1, dtype=float)
    samples2 = np.asarray(group2, dtype=float)
    dof = len(samples1) + len(samples2) - 2
    if dof > 0:
        squared_dev = (np.sum((samples1 - samples1.mean()) ** 2)
                       + np.sum((samples2 - samples2.mean()) ** 2))
        pooled_std = float(np.sqrt(squared_dev / dof))
    else:
        pooled_std = 0.0

    if pooled_std > 0:
        cohens_d = float((samples1.mean() - samples2.mean()) / pooled_std)
    else:
        cohens_d = 0.0

    magnitude = abs(cohens_d)
    # Cast NumPy scalars to builtins so the result can be dumped to JSON.
    p_value = float(p_value)
    return {
        'p_value': p_value,
        'statistic': float(statistic),
        'cohens_d': cohens_d,
        'test_used': test_used,
        'significant': p_value < 0.05,
        'effect_size': 'large' if magnitude > 0.8 else 'medium' if magnitude > 0.5 else 'small'
    }
