#!/usr/bin/env python3
"""
Metrics Recording Utilities for GFedCL

This module provides utilities to record:
1. Computation overhead (training time per epoch)
2. IS (Inception Score) for synthetic data encodings
3. FID (Fréchet Inception Distance) for synthetic data encodings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import os
import json
import logging
from scipy import linalg
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd

logger = logging.getLogger('GFedCL')


class ComputationTracker:
    """
    Tracks wall-clock computation time for different phases of training.

    Every measurement is appended to ``timing_data`` (keyed by phase name).
    Measurements recorded with the phase name ``'epoch'`` are additionally
    indexed in ``epoch_times`` as
    ``{task_id: {round_id: [elapsed_seconds, ...]}}``.
    """
    def __init__(self, output_dir='./dump/metrics'):
        """
        Args:
            output_dir: Directory that receives the JSON/CSV/PNG outputs.
                It is created if it does not exist.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        self.timing_data = defaultdict(list)
        # {task_id: {round_id: [epoch_times]}}
        self.epoch_times = defaultdict(lambda: defaultdict(list))

    def start_timer(self):
        """Start a timer and return the start time (``time.time()`` seconds)."""
        return time.time()

    def end_timer(self, start_time, phase_name, task_id=None, round_id=None, client_id=None):
        """
        End a timer and record the elapsed time.

        Args:
            start_time: Start time from start_timer()
            phase_name: Name of the phase (e.g., 'client_training', 'server_training')
            task_id: Current task ID
            round_id: Current round ID
            client_id: Client ID (if applicable)

        Returns:
            Elapsed time in seconds.
        """
        now = time.time()
        elapsed_time = now - start_time

        self.timing_data[phase_name].append({
            'phase': phase_name,
            'elapsed_time': elapsed_time,
            'task_id': task_id,
            'round_id': round_id,
            'client_id': client_id,
            'timestamp': now,
        })

        # Epoch timings are also indexed per task/round for the summaries.
        # The nested defaultdict creates missing levels on first access.
        if phase_name == 'epoch':
            self.epoch_times[task_id][round_id].append(elapsed_time)

        return elapsed_time

    def _collect_epoch_times(self, task_id=None, round_id=None):
        """Return the flat list of epoch times matching the given filter."""
        if task_id is None:
            # Without a task, round_id is ignored and everything is pooled.
            per_task_rounds = list(self.epoch_times.values())
            round_id = None
        else:
            # .get() avoids creating empty entries in the defaultdict.
            rounds = self.epoch_times.get(task_id)
            per_task_rounds = [] if rounds is None else [rounds]

        collected = []
        for rounds in per_task_rounds:
            if round_id is not None:
                collected.extend(rounds.get(round_id, []))
            else:
                for round_times in rounds.values():
                    collected.extend(round_times)
        return collected

    def get_average_epoch_time(self, task_id=None, round_id=None):
        """
        Get average epoch time for a specific task/round or overall.

        Args:
            task_id: Restrict the average to this task. When None, the
                average runs over every recorded epoch (round_id is ignored).
            round_id: Restrict the average to this round of ``task_id``.

        Returns:
            Mean epoch time in seconds as a float; 0.0 when no matching epoch
            has been recorded (including an unknown task/round pair).
        """
        times = self._collect_epoch_times(task_id, round_id)
        return float(np.mean(times)) if times else 0.0

    def save_timing_data(self):
        """Save timing data to JSON and CSV files"""
        # Save raw data as JSON
        json_path = os.path.join(self.output_dir, 'timing_data.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'timing_data': self.timing_data,
                'epoch_times': self.epoch_times
            }, f, indent=2)

        # Create summary CSV: one row per (task, round) that has epochs
        summary_data = [
            {
                'task_id': task_id,
                'round_id': round_id,
                'avg_epoch_time': np.mean(epoch_times),
                'std_epoch_time': np.std(epoch_times),
                'min_epoch_time': np.min(epoch_times),
                'max_epoch_time': np.max(epoch_times),
                'num_epochs': len(epoch_times),
            }
            for task_id, rounds in self.epoch_times.items()
            for round_id, epoch_times in rounds.items()
            if epoch_times
        ]

        if summary_data:
            df = pd.DataFrame(summary_data)
            csv_path = os.path.join(self.output_dir, 'epoch_timing_summary.csv')
            df.to_csv(csv_path, index=False)
            logger.info(f"Saved timing summary to {csv_path}")

    def plot_timing_analysis(self):
        """
        Create a bar chart of the average epoch time per task.

        Nothing is drawn or written when no epoch has been recorded.
        """
        task_ids = sorted(self.epoch_times.keys())
        if not task_ids:
            # Return before creating a figure so none is left open.
            return

        task_avg_times = [
            self.get_average_epoch_time(task_id=task_id) for task_id in task_ids
        ]

        fig = plt.figure(figsize=(10, 6))
        try:
            plt.bar(task_ids, task_avg_times, color='skyblue', edgecolor='black')
            plt.xlabel('Task ID', fontsize=12)
            plt.ylabel('Average Epoch Time (seconds)', fontsize=12)
            plt.title('Average Training Time per Epoch by Task', fontsize=14)
            plt.grid(axis='y', alpha=0.3)

            # Add value labels on bars
            for task_id, avg_time in zip(task_ids, task_avg_times):
                plt.text(task_id, avg_time + 0.01, f'{avg_time:.2f}s',
                         ha='center', va='bottom')

            plt.tight_layout()
            plot_path = os.path.join(self.output_dir, 'epoch_timing_by_task.png')
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        logger.info(f"Saved timing plot to {plot_path}")


class SyntheticDataEvaluator:
    """
    Evaluates the quality of synthetic data encodings using IS and FID scores.

    The "inception" network is replaced by a small MLP classifier operating
    directly on encodings; it must be fitted with ``train_classifier`` before
    the Inception Score is meaningful. Every computed score is appended to
    ``is_scores`` / ``fid_scores`` as plain Python floats so the history can
    be written with ``json``.
    """
    def __init__(self, feature_dim=512, num_classes=26, device='cuda'):
        """
        Args:
            feature_dim: Size of a (flattened) encoding vector.
            num_classes: Number of classes predicted by the classifier.
            device: Requested torch device; 'cpu' is used when CUDA is
                not available.
        """
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Storage for scores
        self.is_scores = []
        self.fid_scores = []

        # Initialize a simple classifier for IS calculation
        self.classifier = self._create_classifier()

    def _create_classifier(self):
        """Create a simple classifier for IS calculation"""
        classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, self.num_classes)
        ).to(self.device)
        return classifier

    def _flatten_encodings(self, encodings):
        """Collapse every dimension after the first into one feature axis."""
        if encodings.dim() > 2:
            return encodings.reshape(encodings.size(0), -1)
        return encodings

    def _prepare_encodings(self, encodings):
        """Concatenate a list of batches, detach from autograd and flatten."""
        if isinstance(encodings, list):
            encodings = torch.cat(encodings, dim=0)
        # detach(): metrics must not backpropagate into the producer, and
        # Tensor.numpy() refuses tensors that still require grad.
        return self._flatten_encodings(encodings.detach())

    def train_classifier(self, real_encodings, labels, epochs=50):
        """
        Train the classifier on real encodings for IS calculation.

        Training is full-batch: each epoch is a single optimizer step on all
        provided encodings.

        Args:
            real_encodings: Real encodings, a tensor or a list of batch tensors
            labels: Corresponding labels (class indices, or one row per
                sample with ``num_classes`` scores), tensor or list of tensors
            epochs: Number of training epochs
        """
        self.classifier.train()
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Prepare data
        encodings = self._prepare_encodings(real_encodings).to(self.device)

        if isinstance(labels, list):
            labels = torch.cat(labels, dim=0)
        labels = labels.detach().to(self.device)

        if labels.dim() > 1:
            if labels.size(-1) == self.num_classes:
                # Rows of per-class scores (e.g. one-hot) -> class indices
                labels = labels.argmax(dim=-1)
            else:
                labels = labels.view(-1)

        # Ensure labels are long type for CrossEntropyLoss
        labels = labels.long()

        # Training loop
        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = self.classifier(encodings)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if (epoch + 1) % 10 == 0:
                logger.debug(f"Classifier training epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

        self.classifier.eval()
        logger.info("Classifier training completed for IS calculation")

    def calculate_inception_score(self, synthetic_encodings, splits=10):
        """
        Calculate Inception Score for synthetic encodings.

        The samples are cut into ``splits`` equally sized consecutive parts
        (trailing samples that do not fill a part are ignored); the score of
        a part is ``exp(mean_x KL(p(y|x) || p(y)))``.

        Args:
            synthetic_encodings: Synthetic encodings, tensor or list of tensors
            splits: Number of splits for IS calculation. Clamped to the
                number of samples so that no split is empty.

        Returns:
            Tuple of (mean_score, std_score) as Python floats.

        Raises:
            ValueError: If no samples are provided.
        """
        self.classifier.eval()

        encodings = self._prepare_encodings(synthetic_encodings).to(self.device)

        n_samples = encodings.shape[0]
        if n_samples == 0:
            raise ValueError("Cannot compute Inception Score without samples")

        with torch.no_grad():
            logits = self.classifier(encodings)
            preds = F.softmax(logits, dim=1).cpu().numpy().astype(np.float64)

        # More splits than samples would yield empty parts (NaN scores).
        n_splits = max(1, min(int(splits), n_samples))
        split_size = n_samples // n_splits
        eps = 1e-8

        split_scores = []
        for k in range(n_splits):
            part = preds[k * split_size:(k + 1) * split_size, :]
            py = np.mean(part, axis=0, keepdims=True)
            # Per-sample KL(p(y|x) || p(y)), computed for the whole part at once
            kl = np.sum(part * (np.log(part + eps) - np.log(py + eps)), axis=1)
            split_scores.append(np.exp(np.mean(kl)))

        # Plain floats: NumPy float32 scalars are not JSON serializable.
        is_mean = float(np.mean(split_scores))
        is_std = float(np.std(split_scores))

        self.is_scores.append({
            'mean': is_mean,
            'std': is_std,
            'timestamp': time.time()
        })

        return is_mean, is_std

    def calculate_fid_score(self, real_encodings, synthetic_encodings):
        """
        Calculate Fréchet Inception Distance between real and synthetic encodings.

        Computed directly on the encodings as
        ``||mu_r - mu_s||^2 + Tr(S_r + S_s - 2 (S_r S_s)^(1/2))``.

        Args:
            real_encodings: Real encodings, tensor or list of tensors
            synthetic_encodings: Synthetic encodings, tensor or list of tensors

        Returns:
            FID score as a Python float.

        Raises:
            ValueError: If either set has fewer than two samples (the
                covariance is undefined) or the feature sizes differ.
        """
        real = self._prepare_encodings(real_encodings).cpu().numpy().astype(np.float64)
        synthetic = self._prepare_encodings(synthetic_encodings).cpu().numpy().astype(np.float64)
        # Guarantee a (samples, features) layout even for 1-D input.
        real = real.reshape(real.shape[0], -1)
        synthetic = synthetic.reshape(synthetic.shape[0], -1)

        if real.shape[0] < 2 or synthetic.shape[0] < 2:
            raise ValueError(
                "FID needs at least 2 real and 2 synthetic samples, got "
                f"{real.shape[0]} and {synthetic.shape[0]}"
            )
        if real.shape[1] != synthetic.shape[1]:
            raise ValueError(
                "Real and synthetic encodings have different feature sizes: "
                f"{real.shape[1]} vs {synthetic.shape[1]}"
            )

        # Calculate mean and covariance
        mu_real = np.mean(real, axis=0)
        mu_synthetic = np.mean(synthetic, axis=0)

        # atleast_2d: np.cov returns a 0-d array for a single feature
        sigma_real = np.atleast_2d(np.cov(real, rowvar=False))
        sigma_synthetic = np.atleast_2d(np.cov(synthetic, rowvar=False))

        diff = mu_real - mu_synthetic

        # Called without `disp` so only the matrix is returned.
        covmean = linalg.sqrtm(sigma_real.dot(sigma_synthetic))
        if not np.isfinite(covmean).all():
            # Product is (almost) singular: regularise the diagonals and retry
            offset = np.eye(sigma_real.shape[0]) * 1e-6
            covmean = linalg.sqrtm((sigma_real + offset).dot(sigma_synthetic + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = float(diff.dot(diff) + np.trace(sigma_real + sigma_synthetic - 2 * covmean))

        self.fid_scores.append({
            'score': fid,
            'timestamp': time.time()
        })

        return fid

    def save_scores(self, output_dir='./dump/metrics'):
        """Save IS and FID score histories to a JSON file and log their averages."""
        os.makedirs(output_dir, exist_ok=True)

        # Save scores as JSON
        scores_data = {
            'is_scores': self.is_scores,
            'fid_scores': self.fid_scores
        }

        json_path = os.path.join(output_dir, 'synthetic_quality_scores.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(scores_data, f, indent=2)

        # Create summary
        if self.is_scores:
            is_means = [s['mean'] for s in self.is_scores]
            is_stds = [s['std'] for s in self.is_scores]
            logger.info(f"Average IS: {np.mean(is_means):.4f} ± {np.mean(is_stds):.4f}")

        if self.fid_scores:
            fid_values = [s['score'] for s in self.fid_scores]
            logger.info(f"Average FID: {np.mean(fid_values):.4f}")

        logger.info(f"Saved synthetic quality scores to {json_path}")

    def plot_score_evolution(self, output_dir='./dump/metrics'):
        """Plot the recorded IS and FID scores against the evaluation index."""
        os.makedirs(output_dir, exist_ok=True)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        try:
            # Plot IS scores
            if self.is_scores:
                is_means = [s['mean'] for s in self.is_scores]
                is_stds = [s['std'] for s in self.is_scores]
                x = range(len(is_means))

                ax1.errorbar(x, is_means, yerr=is_stds, marker='o', linewidth=2,
                             capsize=5, capthick=2, markersize=8)
                ax1.set_xlabel('Evaluation Step', fontsize=12)
                ax1.set_ylabel('Inception Score', fontsize=12)
                ax1.set_title('Inception Score Evolution', fontsize=14)
                ax1.grid(True, alpha=0.3)

            # Plot FID scores
            if self.fid_scores:
                fid_values = [s['score'] for s in self.fid_scores]
                x = range(len(fid_values))

                ax2.plot(x, fid_values, 'o-', linewidth=2, markersize=8, color='orange')
                ax2.set_xlabel('Evaluation Step', fontsize=12)
                ax2.set_ylabel('FID Score', fontsize=12)
                ax2.set_title('FID Score Evolution', fontsize=14)
                ax2.grid(True, alpha=0.3)

            plt.suptitle('Synthetic Data Quality Metrics', fontsize=16)
            plt.tight_layout()

            plot_path = os.path.join(output_dir, 'synthetic_quality_evolution.png')
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        finally:
            # Close this specific figure, also when drawing/saving fails
            plt.close(fig)

        logger.info(f"Saved quality evolution plot to {plot_path}")


class MetricsRecorder:
    """
    Main class to coordinate all metrics recording.

    Owns a ComputationTracker (timings) and a SyntheticDataEvaluator (IS/FID)
    and buffers real encodings/labels that serve as the reference set for the
    quality scores. All outputs go to ``<opt.output_dir>/metrics``.
    """
    # Maximum number of buffered real samples (most recent ones are kept)
    MAX_BUFFER_SIZE = 10000
    # Number of buffered samples used to fit the IS classifier
    CLASSIFIER_TRAIN_SAMPLES = 5000

    def __init__(self, opt):
        """
        Args:
            opt: Configuration object. Reads ``output_dir``, ``nh``,
                ``num_classes`` and ``device`` here; further fields are read
                when the summary report is written.
        """
        self.opt = opt
        self.output_dir = os.path.join(opt.output_dir, 'metrics')
        os.makedirs(self.output_dir, exist_ok=True)

        # Initialize components
        self.computation_tracker = ComputationTracker(self.output_dir)
        self.synthetic_evaluator = SyntheticDataEvaluator(
            feature_dim=opt.nh,  # Use hidden dimension from config
            num_classes=opt.num_classes,
            device=opt.device
        )

        # Storage for encodings
        self.real_encodings_buffer = []
        self.synthetic_encodings_buffer = []
        self.labels_buffer = []

        # Set once the IS classifier has been fitted on real encodings
        self._classifier_trained = False

    def record_epoch_time(self, start_time, task_id, round_id, client_id=None):
        """Record the time for an epoch and return the elapsed seconds."""
        return self.computation_tracker.end_timer(
            start_time, 'epoch', task_id, round_id, client_id
        )

    def record_phase_time(self, start_time, phase_name, task_id=None, round_id=None):
        """Record the time for a training phase and return the elapsed seconds."""
        return self.computation_tracker.end_timer(
            start_time, phase_name, task_id, round_id
        )

    @staticmethod
    def _to_storage(item):
        """Detach a tensor from autograd and move it to CPU for buffering."""
        return item.detach().cpu() if torch.is_tensor(item) else item

    @staticmethod
    def _count_samples(encodings):
        """Number of samples in a tensor or in a list of batch tensors."""
        if isinstance(encodings, (list, tuple)):
            return sum(int(batch.shape[0]) for batch in encodings)
        return int(encodings.shape[0])

    def add_real_encodings(self, encodings, labels):
        """
        Add real encodings for quality evaluation.

        Args:
            encodings: Iterable of per-sample encodings (e.g. a batch tensor,
                which is split along its first dimension).
            labels: Iterable of the corresponding per-sample labels.
        """
        # Buffer detached CPU copies: holding the originals would keep the
        # autograd graph (and device memory) of every batch alive.
        self.real_encodings_buffer.extend(self._to_storage(e) for e in encodings)
        self.labels_buffer.extend(self._to_storage(lbl) for lbl in labels)

        # Limit buffer size to prevent memory issues
        max_buffer_size = self.MAX_BUFFER_SIZE
        if len(self.real_encodings_buffer) > max_buffer_size:
            self.real_encodings_buffer = self.real_encodings_buffer[-max_buffer_size:]
            self.labels_buffer = self.labels_buffer[-max_buffer_size:]

    def evaluate_synthetic_quality(self, synthetic_encodings):
        """
        Evaluate the quality of synthetic encodings.

        On the first call the IS classifier is fitted on (a subset of) the
        buffered real encodings. The FID reference set is the first
        ``n`` buffered real samples, ``n`` being the number of synthetic
        samples.

        Args:
            synthetic_encodings: Tensor, or list of batch tensors.

        Returns:
            Tuple ``(is_mean, fid_score)``, or ``(None, None)`` when there
            are no real or no synthetic encodings to evaluate.
        """
        if not self.real_encodings_buffer:
            logger.warning("No real encodings available for quality evaluation")
            return None, None

        # len() of a list would count batches, not samples
        n_synthetic = self._count_samples(synthetic_encodings)
        if n_synthetic == 0:
            logger.warning("No synthetic encodings provided for quality evaluation")
            return None, None

        # Train classifier if needed (only once)
        if not getattr(self, '_classifier_trained', False):
            logger.info("Training classifier for IS calculation...")
            subset = self.CLASSIFIER_TRAIN_SAMPLES  # Use subset
            real_tensor = torch.stack(
                [torch.as_tensor(e) for e in self.real_encodings_buffer[:subset]]
            )
            labels_tensor = torch.stack(
                [torch.as_tensor(lbl) for lbl in self.labels_buffer[:subset]]
            )
            self.synthetic_evaluator.train_classifier(real_tensor, labels_tensor)
            self._classifier_trained = True

        # Calculate IS
        is_mean, is_std = self.synthetic_evaluator.calculate_inception_score(synthetic_encodings)
        logger.info(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")

        # Calculate FID against as many real samples as synthetic ones
        real_sample = torch.stack(
            [torch.as_tensor(e) for e in self.real_encodings_buffer[:n_synthetic]]
        )
        fid_score = self.synthetic_evaluator.calculate_fid_score(real_sample, synthetic_encodings)
        logger.info(f"FID Score: {fid_score:.4f}")

        return is_mean, fid_score

    def save_all_metrics(self):
        """Save all collected metrics, plots and the text summary."""
        # Save timing data
        self.computation_tracker.save_timing_data()
        self.computation_tracker.plot_timing_analysis()

        # Save synthetic quality scores
        self.synthetic_evaluator.save_scores(self.output_dir)
        self.synthetic_evaluator.plot_score_evolution(self.output_dir)

        logger.info(f"All metrics saved to {self.output_dir}")

        # Create summary report
        self._create_summary_report()

    def _create_summary_report(self):
        """Create a plain-text summary report of all metrics."""
        report_path = os.path.join(self.output_dir, 'metrics_summary.txt')
        tracker = self.computation_tracker
        evaluator = self.synthetic_evaluator

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("=== GFedCL Metrics Summary ===\n\n")

            # Timing summary
            f.write("1. Computation Overhead:\n")
            avg_epoch_time = tracker.get_average_epoch_time()
            f.write(f"   - Average epoch time: {avg_epoch_time:.4f} seconds\n")

            # Task-wise timing
            for task_id in sorted(tracker.epoch_times.keys()):
                avg_task_time = tracker.get_average_epoch_time(task_id=task_id)
                f.write(f"   - Task {task_id} average: {avg_task_time:.4f} seconds\n")

            f.write("\n2. Synthetic Data Quality:\n")

            # IS scores
            if evaluator.is_scores:
                is_means = [s['mean'] for s in evaluator.is_scores]
                is_stds = [s['std'] for s in evaluator.is_scores]
                f.write(f"   - Average IS: {np.mean(is_means):.4f} ± {np.mean(is_stds):.4f}\n")
                f.write(f"   - Best IS: {np.max(is_means):.4f}\n")
                f.write(f"   - Worst IS: {np.min(is_means):.4f}\n")

            # FID scores (lower is better, hence best = min)
            if evaluator.fid_scores:
                fid_values = [s['score'] for s in evaluator.fid_scores]
                f.write(f"   - Average FID: {np.mean(fid_values):.4f}\n")
                f.write(f"   - Best FID: {np.min(fid_values):.4f}\n")
                f.write(f"   - Worst FID: {np.max(fid_values):.4f}\n")

            # A missing config field is reported as N/A instead of aborting
            # the report half-written.
            f.write("\n3. Configuration:\n")
            config_lines = (
                ("Number of clients", 'num_clients'),
                ("Number of tasks", 'num_task'),
                ("Classes per task", 'class_per_task'),
                ("Local epochs", 'num_local_epochs'),
                ("Communication rounds", 'num_rounds'),
            )
            for label, attr in config_lines:
                f.write(f"   - {label}: {getattr(self.opt, attr, 'N/A')}\n")

        logger.info(f"Metrics summary saved to {report_path}")
