"""
Metrics calculation for DATE-GFN experiments.
"""

import numpy as np
import torch
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
import time


class PerformanceMetrics:
    """Reward- and coverage-based quality metrics for sampled trajectories.

    Both calculators are static. They treat the last element of every
    trajectory as its final state and delegate all environment-specific
    logic to the ``environment`` object that is passed in.
    """

    @staticmethod
    def calculate_hypergrid_metrics(trajectories: List[List[np.ndarray]],
                                  environment) -> Dict[str, float]:
        """Compute the Hypergrid-specific metrics for a batch of trajectories.

        Args:
            trajectories: Sampled trajectories; each one is a sequence of
                states whose last entry is scored.
            environment: Object providing ``calculate_l1_error``,
                ``calculate_mode_coverage`` (returning ``(coverage,
                num_modes)``) and ``get_reward``.

        Returns:
            Dict with ``rel_l1_error``, ``modes_discovered``,
            ``mode_coverage`` and ``top_100_reward`` (mean reward of the 100
            highest-reward trajectories, or of all of them when fewer exist).
            An empty batch yields the worst-case defaults.
        """
        worst_case = {
            'rel_l1_error': 1.0,
            'modes_discovered': 0,
            'mode_coverage': 0.0,
            'top_100_reward': 0.0
        }
        if not trajectories:
            return worst_case

        # Environment-side statistics are queried first, in this order.
        relative_l1 = environment.calculate_l1_error(trajectories)
        mode_coverage, modes_found = environment.calculate_mode_coverage(trajectories)

        # Score every final state, then average the best (at most) 100.
        terminal_rewards = []
        for trajectory in trajectories:
            terminal_rewards.append(environment.get_reward(trajectory[-1]))
        best_rewards = sorted(terminal_rewards, reverse=True)[:100]
        if best_rewards:
            best_mean = np.mean(best_rewards)
        else:
            best_mean = 0.0

        return {
            'rel_l1_error': relative_l1,
            'modes_discovered': modes_found,
            'mode_coverage': mode_coverage,
            'top_100_reward': best_mean
        }


    @staticmethod
    def calculate_general_metrics(trajectories: List[List[np.ndarray]],
                                environment) -> Dict[str, float]:
        """Compute environment-agnostic statistics for a batch of trajectories.

        Args:
            trajectories: Sampled trajectories; each one is a sequence of
                states whose last entry is scored.
            environment: Object providing ``get_reward`` and ``is_terminal``.

        Returns:
            Dict with ``mean_reward``, ``max_reward``, ``reward_std``,
            ``mean_length`` (mean number of states per trajectory) and
            ``success_rate`` (fraction of trajectories whose final state the
            environment reports as terminal). All values are 0.0 for an
            empty batch.
        """
        if not trajectories:
            return {
                'mean_reward': 0.0,
                'max_reward': 0.0,
                'reward_std': 0.0,
                'mean_length': 0.0,
                'success_rate': 0.0
            }

        # Three separate passes: rewards, lengths, then terminal flags.
        terminal_rewards = []
        for trajectory in trajectories:
            terminal_rewards.append(environment.get_reward(trajectory[-1]))

        step_counts = list(map(len, trajectories))

        reached_terminal = []
        for trajectory in trajectories:
            reached_terminal.append(environment.is_terminal(trajectory[-1]))

        summary = {}
        summary['mean_reward'] = np.mean(terminal_rewards)
        summary['max_reward'] = np.max(terminal_rewards)
        summary['reward_std'] = np.std(terminal_rewards)
        summary['mean_length'] = np.mean(step_counts)
        summary['success_rate'] = np.mean(reached_terminal)
        return summary


class DiversityMetrics:
    """Calculate diversity metrics for trajectory sets.

    Every metric looks only at the final state (last element) of each
    trajectory.
    """

    @staticmethod
    def calculate_diversity(trajectories: List[List[np.ndarray]],
                          environment,
                          diversity_type: str = 'auto') -> float:
        """
        Calculate diversity based on environment type.

        Args:
            trajectories: List of trajectories
            environment: Environment object; consulted only when
                ``diversity_type`` is 'auto'.
            diversity_type: 'auto', 'hamming', 'cosine', 'euclidean'.
                'auto' delegates to ``environment.calculate_diversity`` when
                the environment defines it and otherwise falls back to
                'euclidean'.

        Returns:
            The mean pairwise distance between final states under the chosen
            measure, or 0.0 when fewer than two trajectories are given.

        Raises:
            ValueError: If ``diversity_type`` is not one of the names above.
        """
        if len(trajectories) < 2:
            # No pair exists, so there is nothing to compare.
            return 0.0

        if diversity_type == 'auto':
            # Use environment-specific diversity if available
            if hasattr(environment, 'calculate_diversity'):
                return environment.calculate_diversity(trajectories)
            diversity_type = 'euclidean'

        final_states = [traj[-1] for traj in trajectories]

        if diversity_type == 'hamming':
            return DiversityMetrics._hamming_diversity(final_states)
        if diversity_type == 'cosine':
            return DiversityMetrics._cosine_diversity(final_states)
        if diversity_type == 'euclidean':
            return DiversityMetrics._euclidean_diversity(final_states)
        raise ValueError(f"Unknown diversity type: {diversity_type}")

    @staticmethod
    def _mean_pairwise(states: List[np.ndarray], distance) -> float:
        """Average ``distance(a, b)`` over all unordered pairs of ``states``.

        Pairs are visited in (i, j) order with i < j. Returns 0.0 when there
        are fewer than two states.
        """
        total_distance = 0.0
        num_pairs = 0
        for i, first in enumerate(states):
            for second in states[i + 1:]:
                total_distance += distance(first, second)
                num_pairs += 1
        return total_distance / num_pairs if num_pairs > 0 else 0.0

    @staticmethod
    def _hamming_diversity(states: List[np.ndarray]) -> float:
        """Calculate Hamming diversity.

        Mean number of differing positions per pair, divided by the length of
        the first state (the maximum possible Hamming distance). Returns 0.0
        when there are no states or the states have zero length.
        """
        if len(states) == 0:
            return 0.0
        max_distance = len(states[0])  # Maximum possible Hamming distance
        if max_distance == 0:
            # Zero-length states cannot differ; avoid dividing by zero.
            return 0.0

        avg_distance = DiversityMetrics._mean_pairwise(
            states, lambda a, b: np.sum(a != b))
        return float(avg_distance / max_distance)

    @staticmethod
    def _cosine_distance(state1: np.ndarray, state2: np.ndarray) -> float:
        """Return ``1 - cosine_similarity``; 1.0 if either vector is (near) zero."""
        norm1, norm2 = np.linalg.norm(state1), np.linalg.norm(state2)
        if norm1 > 1e-8 and norm2 > 1e-8:
            return 1 - np.dot(state1, state2) / (norm1 * norm2)
        # Direction is undefined for a zero vector; treat as maximally distinct.
        return 1.0

    @staticmethod
    def _cosine_diversity(states: List[np.ndarray]) -> float:
        """Calculate cosine diversity (mean pairwise cosine distance)."""
        return float(DiversityMetrics._mean_pairwise(
            states, DiversityMetrics._cosine_distance))

    @staticmethod
    def _euclidean_diversity(states: List[np.ndarray]) -> float:
        """Calculate normalized Euclidean diversity.

        The mean pairwise Euclidean distance is divided by ``sqrt(dim)`` and
        capped at 1.0. That normalizer is the largest possible distance only
        if every coordinate lies in [0, 1] (assumption carried over from the
        original code). Returns 0.0 when there are no states or the states
        have zero length.
        """
        if len(states) == 0:
            return 0.0
        state_dim = len(states[0])
        if state_dim == 0:
            # Zero-length states: nothing to measure, avoid dividing by zero.
            return 0.0

        avg_distance = DiversityMetrics._mean_pairwise(
            states, lambda a, b: np.linalg.norm(a - b))

        # Normalize by maximum possible distance
        max_distance = np.sqrt(state_dim)  # Assuming normalized states
        return float(min(avg_distance / max_distance, 1.0))

    @staticmethod
    def calculate_entropy(trajectories: List[List[np.ndarray]],
                         environment,
                         bins: int = 10) -> float:
        """Calculate entropy of final state distribution.

        Each state dimension is histogrammed independently into ``bins``
        bins, and the Shannon entropies (in bits) of those marginal
        distributions are summed.

        Args:
            trajectories: List of trajectories.
            environment: Unused; kept for interface compatibility.
            bins: Histogram bin specification forwarded to ``np.histogram``.

        Returns:
            Sum over dimensions of the marginal entropy in bits; 0.0 for an
            empty trajectory list.
        """
        if not trajectories:
            return 0.0

        final_states = np.array([traj[-1] for traj in trajectories])
        if final_states.ndim == 1:
            # Scalar states: treat them as a single dimension.
            final_states = final_states.reshape(-1, 1)

        total_entropy = 0.0
        for dim in range(final_states.shape[1]):
            # Work with raw counts rather than densities: a density depends
            # on the bin width, so any fixed smoothing constant added to it
            # would make the result depend on the scale of the data.
            counts, _ = np.histogram(final_states[:, dim], bins=bins)
            occupied = counts[counts > 0]
            probabilities = occupied / counts.sum()
            # Empty bins contribute 0 to the entropy and are simply skipped.
            total_entropy -= float(np.sum(probabilities * np.log2(probabilities)))

        return total_entropy

    @staticmethod
    def calculate_mode_collapse_metric(trajectories: List[List[np.ndarray]],
                                     threshold: float = 0.1) -> float:
        """
        Calculate mode collapse metric.

        Returns fraction of trajectories that end in the most common final state cluster.

        Args:
            trajectories: List of trajectories.
            threshold: DBSCAN ``eps`` neighbourhood radius used to group
                final states.

        Returns:
            Size of the largest cluster divided by the number of
            trajectories; 1.0 when fewer than two trajectories are given.
        """
        if len(trajectories) < 2:
            return 1.0

        final_states = np.array([traj[-1] for traj in trajectories])

        # Imported lazily so scikit-learn is only needed for this metric.
        from sklearn.cluster import DBSCAN

        # min_samples=1 makes every point a core point, so nothing is
        # labelled as noise and every state belongs to some cluster.
        clustering = DBSCAN(eps=threshold, min_samples=1).fit(final_states)

        # Find largest cluster
        _, counts = np.unique(clustering.labels_, return_counts=True)

        # Mode collapse metric: fraction in largest cluster
        return np.max(counts) / len(trajectories)


class EfficiencyMetrics:
    """Calculate efficiency and computational metrics.

    Attributes are plain lists/dicts that callers fill through the
    ``log_*`` / ``start_timing`` / ``end_timing`` methods:

    * ``timing_data``: operation name -> list of elapsed seconds.
    * ``step_times``: per-step durations in seconds.
    * ``memory_usage``: logged memory readings in MB.
    """

    def __init__(self):
        self.timing_data = defaultdict(list)
        self.step_times = []
        self.memory_usage = []
        # Start timestamps of operations that are currently being timed.
        self._start_times = {}

    def start_timing(self, operation: str):
        """Start timing an operation.

        Calling this again for the same operation before ``end_timing``
        restarts its clock.
        """
        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments.
        self._start_times[operation] = time.perf_counter()

    def end_timing(self, operation: str):
        """End timing an operation.

        Returns:
            Elapsed seconds since the matching ``start_timing`` call (also
            appended to ``timing_data[operation]``), or 0.0 if the operation
            was never started.
        """
        started_at = self._start_times.pop(operation, None)
        if started_at is None:
            return 0.0
        elapsed = time.perf_counter() - started_at
        self.timing_data[operation].append(elapsed)
        return elapsed

    def log_step_time(self, step_time: float):
        """Log time for a training step."""
        self.step_times.append(step_time)

    def log_memory_usage(self, memory_mb: float):
        """Log memory usage in MB."""
        self.memory_usage.append(memory_mb)

    def get_efficiency_metrics(self) -> Dict[str, float]:
        """Get efficiency metrics summary.

        Returns:
            Dict of plain floats containing, for every timed operation with
            at least one sample, ``<op>_mean_time``, ``<op>_total_time`` and
            ``<op>_std_time``; ``mean_step_time``, ``total_training_time``
            and ``steps_per_second`` when step times were logged; and
            ``peak_memory_mb`` / ``mean_memory_mb`` when memory was logged.
        """
        metrics = {}

        # Timing statistics
        for operation, times in self.timing_data.items():
            if times:
                metrics[f"{operation}_mean_time"] = float(np.mean(times))
                metrics[f"{operation}_total_time"] = float(np.sum(times))
                metrics[f"{operation}_std_time"] = float(np.std(times))

        # Step timing
        if self.step_times:
            mean_step_time = float(np.mean(self.step_times))
            metrics['mean_step_time'] = mean_step_time
            metrics['total_training_time'] = float(np.sum(self.step_times))
            # A zero mean (e.g. timer resolution too coarse) would divide by
            # zero; report 0.0 instead of inf.
            metrics['steps_per_second'] = (
                1.0 / mean_step_time if mean_step_time > 0 else 0.0
            )

        # Memory usage
        if self.memory_usage:
            metrics['peak_memory_mb'] = float(np.max(self.memory_usage))
            metrics['mean_memory_mb'] = float(np.mean(self.memory_usage))

        return metrics

    @staticmethod
    def _steps_until(performance_history: List[float], target: float) -> int:
        """Return the 1-based index of the first entry >= ``target``.

        Falls back to ``len(performance_history)`` when no entry reaches the
        target.
        """
        return next(
            (i + 1 for i, perf in enumerate(performance_history) if perf >= target),
            len(performance_history),
        )

    def calculate_sample_efficiency(self, performance_history: List[float],
                                  target_performance: float) -> Dict[str, int]:
        """Calculate sample efficiency metrics.

        Args:
            performance_history: Performance value per step, in order.
            target_performance: Absolute performance level to reach.

        Returns:
            Dict with ``steps_to_target`` and, for a non-empty history,
            ``steps_to_50pct``, ``steps_to_90pct`` and ``steps_to_95pct``
            (first step reaching that fraction of the final value). Each
            entry is a 1-based step count. When the level is never reached
            the history length is reported, which is indistinguishable from
            reaching it on the last step.
        """
        metrics = {
            'steps_to_target': self._steps_until(performance_history, target_performance)
        }

        # Steps to reach 50%, 90%, 95% of final performance
        if performance_history:
            final_perf = performance_history[-1]
            for threshold in [0.5, 0.9, 0.95]:
                metrics[f'steps_to_{int(threshold*100)}pct'] = self._steps_until(
                    performance_history, threshold * final_perf
                )

        return metrics

    def calculate_convergence_metrics(self, performance_history: List[float],
                                    window_size: int = 100,
                                    variance_threshold: float = 0.01) -> Dict[str, float]:
        """Calculate convergence-related metrics.

        Args:
            performance_history: Performance value per step, in order.
            window_size: Number of most recent steps used for the variance,
                and of earliest steps used as the improvement baseline.
            variance_threshold: The run counts as converged when the variance
                of the recent window is below this value.

        Returns:
            Dict with ``convergence_variance``, ``is_converged`` and
            ``improvement_rate``. With fewer than ``window_size`` samples the
            variance is ``inf`` and the run is not converged.
            ``improvement_rate`` is the change from the mean of the first
            window to the mean of the last window, relative to the magnitude
            of the first; it is 0.0 when fewer than ``2 * window_size``
            samples exist or the baseline mean is zero.
        """
        if len(performance_history) < window_size:
            # Same keys as the full result so callers can index uniformly.
            return {
                'convergence_variance': float('inf'),
                'is_converged': False,
                'improvement_rate': 0.0
            }

        # Look at variance in recent window
        recent_performance = performance_history[-window_size:]
        convergence_variance = float(np.var(recent_performance))

        # Check if converged (low variance)
        is_converged = bool(convergence_variance < variance_threshold)

        # Calculate improvement rate
        improvement_rate = 0.0
        if len(performance_history) >= 2 * window_size:
            early_mean = float(np.mean(performance_history[:window_size]))
            recent_mean = float(np.mean(recent_performance))
            if early_mean != 0.0:
                # Divide by the magnitude so that an increase is reported as
                # positive even when the baseline itself is negative.
                improvement_rate = (recent_mean - early_mean) / abs(early_mean)

        return {
            'convergence_variance': convergence_variance,
            'is_converged': is_converged,
            'improvement_rate': improvement_rate
        }


class ExperimentTracker:
    """High-level experiment tracking with metric aggregation.

    Attributes:
        experiment_name: Label copied into every summary.
        metrics_history: metric name -> list of numeric values, one per
            logged step that reported the metric.
        step_results: Full metric dict of every logged step, in order.
    """

    def __init__(self, experiment_name: str):
        self.experiment_name = experiment_name
        self.performance_metrics = PerformanceMetrics()
        self.diversity_metrics = DiversityMetrics()
        self.efficiency_metrics = EfficiencyMetrics()

        # Storage for all metrics
        self.metrics_history = defaultdict(list)
        self.step_results = []

    @staticmethod
    def _is_numeric(value: Any) -> bool:
        """Return True for Python and NumPy real scalars.

        NumPy reductions can return scalar types such as ``np.int64`` or
        ``np.float32`` that are not subclasses of ``int``/``float``; a plain
        ``isinstance(value, (int, float))`` check would silently drop them.
        ``bool`` is a subclass of ``int`` and is therefore accepted as well.
        """
        return isinstance(value, (int, float, np.integer, np.floating))

    def log_step_results(self, step: int, trajectories: List[List[np.ndarray]],
                        environment, method_name: str,
                        additional_metrics: Optional[Dict[str, float]] = None):
        """Log results for a single training step.

        Args:
            step: Step index stored under ``'step'``.
            trajectories: Trajectories sampled at this step.
            environment: Environment used to score the trajectories.
                Hypergrid-specific metrics are added when it defines
                ``calculate_l1_error``.
            method_name: Label stored under ``'method'``.
            additional_metrics: Extra entries merged into the result; they
                override computed metrics with the same key.

        Returns:
            The metric dict recorded for this step.
        """
        # Only the metric computation below is timed; the duration is logged
        # as this step's "step time" in the efficiency metrics.
        step_start_time = time.time()

        # Calculate all metrics
        perf_metrics = self.performance_metrics.calculate_general_metrics(trajectories, environment)

        # Environment-specific metrics
        if hasattr(environment, 'calculate_l1_error'):  # Hypergrid
            env_metrics = self.performance_metrics.calculate_hypergrid_metrics(trajectories, environment)
        else:
            env_metrics = {}

        # Diversity metrics
        diversity = self.diversity_metrics.calculate_diversity(trajectories, environment)
        entropy = self.diversity_metrics.calculate_entropy(trajectories, environment)
        mode_collapse = self.diversity_metrics.calculate_mode_collapse_metric(trajectories)

        # Combine all metrics
        step_metrics = {
            'step': step,
            'method': method_name,
            'num_trajectories': len(trajectories),
            'diversity': diversity,
            'entropy': entropy,
            'mode_collapse': mode_collapse,
            **perf_metrics,
            **env_metrics
        }

        if additional_metrics:
            step_metrics.update(additional_metrics)

        # Store metrics
        self.step_results.append(step_metrics)
        for key, value in step_metrics.items():
            if self._is_numeric(value):
                self.metrics_history[key].append(value)

        # Log timing
        step_time = time.time() - step_start_time
        self.efficiency_metrics.log_step_time(step_time)

        return step_metrics

    def get_summary_statistics(self, last_n_steps: Optional[int] = None,
                               target_performance: float = 0.8) -> Dict[str, Any]:
        """Get summary statistics for the experiment.

        Args:
            last_n_steps: Restrict the aggregation to the most recent
                ``last_n_steps`` logged steps; ``None`` uses every step and
                a non-positive value selects no steps.
            target_performance: ``mean_reward`` level used for the
                ``steps_to_target`` sample-efficiency entry.

        Returns:
            Empty dict when no step is selected. Otherwise a dict with
            ``experiment_name``; ``<metric>_mean/_std/_max/_min/_final`` for
            every numeric metric except ``step`` and ``num_trajectories``;
            the efficiency metrics (computed over all logged steps); and,
            when ``mean_reward`` is present, the sample-efficiency entries
            computed over the selected steps only.
        """
        if last_n_steps is None:
            relevant_results = self.step_results
        elif last_n_steps > 0:
            relevant_results = self.step_results[-last_n_steps:]
        else:
            # A slice of [-0:] would return the whole list instead of nothing.
            relevant_results = []

        if not relevant_results:
            return {}

        summary = {'experiment_name': self.experiment_name}

        # Aggregate numeric metrics
        excluded_keys = ('step', 'num_trajectories')
        numeric_metrics = defaultdict(list)
        for result in relevant_results:
            for key, value in result.items():
                if key not in excluded_keys and self._is_numeric(value):
                    numeric_metrics[key].append(value)

        # Calculate statistics
        for metric_name, values in numeric_metrics.items():
            if values:
                summary[f'{metric_name}_mean'] = np.mean(values)
                summary[f'{metric_name}_std'] = np.std(values)
                summary[f'{metric_name}_max'] = np.max(values)
                summary[f'{metric_name}_min'] = np.min(values)
                summary[f'{metric_name}_final'] = values[-1]

        # Add efficiency metrics
        summary.update(self.efficiency_metrics.get_efficiency_metrics())

        # Add sample efficiency if performance metric available
        if 'mean_reward' in numeric_metrics:
            sample_eff = self.efficiency_metrics.calculate_sample_efficiency(
                numeric_metrics['mean_reward'],
                target_performance=target_performance
            )
            summary.update(sample_eff)

        return summary

    def compare_methods(self, other_trackers: List['ExperimentTracker']) -> Dict[str, Any]:
        """Compare this method with other methods.

        Args:
            other_trackers: Trackers of the methods to compare against.

        Returns:
            Dict mapping ``'<method>_<metric>'`` to the value from each
            tracker's last logged step, for the metrics ``mean_reward``,
            ``diversity`` and ``rel_l1_error`` when present. Trackers without
            logged steps are skipped; trackers sharing a method name
            overwrite each other's entries.
        """
        comparison = {}

        # Get final performance for each method
        for tracker in [self] + other_trackers:
            if not tracker.step_results:
                continue
            last_result = tracker.step_results[-1]
            method_name = last_result.get('method', 'Unknown')

            # Key metrics for comparison
            for metric in ['mean_reward', 'diversity', 'rel_l1_error']:
                if metric in last_result:
                    comparison[f'{method_name}_{metric}'] = last_result[metric]

        return comparison
