import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import os
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from model.vae_models.vae import VAE, vae_loss


class VAEEvaluator:
    """Evaluator class for VAE models.

    Computes ELBO-style loss metrics on data loaders and writes diagnostic
    artifacts (reconstruction grid, 2-D latent projections, prior samples and a
    JSON metrics file) into ``config['results']['save_dir']``.
    """

    def __init__(self, model: "VAE", config: Dict, device: torch.device):
        """Move ``model`` to ``device`` and make sure the results directory exists.

        Args:
            model: VAE to evaluate. The methods below call ``model(data)``
                (expected to return a 3-tuple whose first item is the
                reconstruction), ``model.encode(data)`` (2-tuple whose first
                item is used as the latent code), ``model.sample(n, device)``
                and, when defined, ``model.compute_elbo_loss(data, beta=...)``.
            config: Nested config dict; ``config['results']['save_dir']`` is
                read here and ``config['training']['beta']`` during evaluation.
            device: Device that input batches are moved to before the model runs.
        """
        self.model = model.to(device)
        self.config = config
        self.device = device
        self.results_dir = config['results']['save_dir']
        os.makedirs(self.results_dir, exist_ok=True)

    # ------------------------------------------------------------------ #
    # Metrics
    # ------------------------------------------------------------------ #
    def evaluate_model(self, loader: DataLoader, prefix: str = 'test') -> Dict[str, float]:
        """Evaluate the model on ``loader`` and return averaged loss terms.

        Each batch contributes one ``(loss, recon_loss, kl_loss)`` triple. The
        returned values are the unweighted mean over batches, so a smaller
        final batch counts as much as a full one. Whether each per-batch value
        is a per-sample mean or a sum is decided by ``compute_elbo_loss`` /
        ``vae_loss`` (not visible here) -- confirm there before comparing
        numbers across batch sizes.

        Args:
            loader: Iterable of ``(data, label)`` pairs; labels are ignored.
            prefix: Prepended to every metric key, e.g. ``'test'`` -> ``'test_loss'``.

        Returns:
            Dict with keys ``f'{prefix}_loss'``, ``f'{prefix}_recon_loss'`` and
            ``f'{prefix}_kl_loss'`` mapped to Python floats.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        self.model.eval()
        beta = self.config['training']['beta']
        # Prefer the model's own ELBO when it provides one (e.g. hierarchical
        # models); otherwise fall back to the generic vae_loss on (mu, logvar).
        use_model_elbo = hasattr(self.model, 'compute_elbo_loss')

        total_loss = 0.0
        total_recon_loss = 0.0
        total_kl_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for data, _ in loader:
                data = data.to(self.device)
                if use_model_elbo:
                    loss, recon_loss, kl_loss = self.model.compute_elbo_loss(data, beta=beta)
                else:
                    recon_batch, mu, logvar = self.model(data)
                    loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)

                total_loss += loss.item()
                total_recon_loss += recon_loss.item()
                total_kl_loss += kl_loss.item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError(f"Cannot evaluate '{prefix}': the data loader yielded no batches.")

        return {
            f'{prefix}_loss': total_loss / num_batches,
            f'{prefix}_recon_loss': total_recon_loss / num_batches,
            f'{prefix}_kl_loss': total_kl_loss / num_batches
        }

    def compute_metrics(self, test_loader: DataLoader, train_loader: Optional[DataLoader] = None) -> Dict[str, float]:
        """Compute loss metrics on the test set and, optionally, the train set.

        When ``train_loader`` is given, the train metrics are added together
        with ``gap_*`` entries defined as ``test_* - train_*`` (positive means
        the test loss is higher).

        Args:
            test_loader: Loader for the held-out set (keys prefixed ``test_``).
            train_loader: Optional loader for the training set (keys prefixed ``train_``).

        Returns:
            Flat dict of float metrics.
        """
        results: Dict[str, float] = dict(self.evaluate_model(test_loader, prefix='test'))

        if train_loader is not None:
            results.update(self.evaluate_model(train_loader, prefix='train'))
            # Empirical generalization gaps.
            results['gap_loss'] = results['test_loss'] - results['train_loss']
            results['gap_recon_loss'] = results['test_recon_loss'] - results['train_recon_loss']
            results['gap_kl_loss'] = results['test_kl_loss'] - results['train_kl_loss']

        # Additional metrics (reconstruction quality, sample diversity, ...) can be added here.
        return results

    def save_results(self, results: Dict[str, float], filename: str = 'evaluation_results.json') -> None:
        """Write ``results`` as indented JSON to ``<results_dir>/<filename>``."""
        import json

        save_path = os.path.join(self.results_dir, filename)
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print(f"Results saved to {save_path}")

    # ------------------------------------------------------------------ #
    # Data collection
    # ------------------------------------------------------------------ #
    def generate_reconstructions(self, test_loader: DataLoader, num_samples: int = 16) -> Tuple[torch.Tensor, torch.Tensor]:
        """Collect up to ``num_samples`` inputs together with their reconstructions.

        Batches are read from ``test_loader`` only until ``num_samples`` inputs
        have been gathered. A running count is used because batch sizes may differ.

        Args:
            test_loader: Iterable of ``(data, label)`` pairs; labels are ignored.
            num_samples: Maximum number of rows to return.

        Returns:
            ``(original, reconstructed)`` CPU tensors with at most ``num_samples``
            rows each. They have fewer rows if the loader runs out first.

        Raises:
            ValueError: If ``test_loader`` yields no batches.
        """
        self.model.eval()
        original_images: List[torch.Tensor] = []
        reconstructed_images: List[torch.Tensor] = []
        collected = 0

        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(self.device)
                recon_batch, _, _ = self.model(data)

                original_images.append(data.cpu())
                reconstructed_images.append(recon_batch.cpu())

                collected += data.size(0)
                if collected >= num_samples:
                    break

        if not original_images:
            raise ValueError("Cannot generate reconstructions: the data loader yielded no batches.")

        original = torch.cat(original_images, dim=0)[:num_samples]
        reconstructed = torch.cat(reconstructed_images, dim=0)[:num_samples]
        return original, reconstructed

    # ------------------------------------------------------------------ #
    # Plot helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _to_2d_images(images: torch.Tensor) -> torch.Tensor:
        """Return ``images`` on the CPU in a per-item layout ``imshow`` accepts.

        - ``(B, D)`` flattened inputs become ``(B, S, S)`` with ``S * S == D``
          (``D == 784`` gives the 28x28 layout).
        - ``(B, 1, H, W)`` is squeezed to ``(B, H, W)``.
        - ``(B, 3, H, W)`` / ``(B, 4, H, W)`` is permuted to channels-last ``(B, H, W, C)``.
        - Any other shape is returned unchanged.

        Raises:
            ValueError: If a flattened input's feature count is not a perfect square.
        """
        images = images.detach().cpu()
        if images.dim() == 2:
            num_features = images.size(1)
            side = int(round(float(np.sqrt(num_features))))
            if side * side != num_features:
                raise ValueError(
                    f"Cannot reshape flattened images with {num_features} features into a square image."
                )
            return images.reshape(-1, side, side)
        if images.dim() == 4 and images.size(1) == 1:
            return images.squeeze(1)
        if images.dim() == 4 and images.size(1) in (3, 4):
            return images.permute(0, 2, 3, 1)
        return images

    @staticmethod
    def _grid_shape(n: int) -> Tuple[int, int]:
        """Return ``(rows, cols)`` of a near-square grid holding ``n >= 1`` panels."""
        cols = int(np.ceil(np.sqrt(n)))
        rows = -(-n // cols)  # ceiling division
        return rows, cols

    # ------------------------------------------------------------------ #
    # Visualizations
    # ------------------------------------------------------------------ #
    def visualize_reconstructions(self, test_loader: DataLoader, save_path: Optional[str] = None,
                                  num_samples: int = 16) -> None:
        """Plot original/reconstruction pairs side by side and save the figure.

        Pairs are laid out four per row (original, then reconstruction), so the
        default of 16 samples yields a 4x8 grid. If the loader provides fewer
        samples than requested, only those are drawn and the spare axes stay hidden.

        Args:
            test_loader: Source of input batches.
            save_path: Output image path; defaults to ``<results_dir>/reconstructions.png``.
            num_samples: Number of pairs to draw.

        Raises:
            ValueError: If no samples are available.
        """
        original, reconstructed = self.generate_reconstructions(test_loader, num_samples=num_samples)
        original = self._to_2d_images(original)
        reconstructed = self._to_2d_images(reconstructed)

        n = min(len(original), len(reconstructed))
        if n == 0:
            raise ValueError("No samples available to visualize reconstructions.")

        pairs_per_row = min(4, n)
        rows = -(-n // pairs_per_row)  # ceiling division

        if save_path is None:
            save_path = os.path.join(self.results_dir, 'reconstructions.png')

        fig, axes = plt.subplots(rows, 2 * pairs_per_row,
                                 figsize=(4 * pairs_per_row, 2 * rows), squeeze=False)
        try:
            fig.suptitle('VAE Reconstructions', fontsize=16)
            for i in range(n):
                row, col = divmod(i, pairs_per_row)

                ax_orig = axes[row, 2 * col]
                ax_orig.imshow(original[i], cmap='gray')
                ax_orig.set_title(f'Original {i+1}')

                ax_recon = axes[row, 2 * col + 1]
                ax_recon.imshow(reconstructed[i], cmap='gray')
                ax_recon.set_title(f'Reconstructed {i+1}')

            # Hide ticks on every panel, including unused trailing ones.
            for ax in axes.flat:
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting fails part-way.
            plt.close(fig)

        print(f"Reconstructions saved to {save_path}")

    def visualize_latent_space(self, test_loader: DataLoader, method: str = 'tsne',
                               save_path: Optional[str] = None, max_points: int = 1000) -> None:
        """Project encoder latent codes to 2-D and save a label-coloured scatter plot.

        The first element returned by ``model.encode`` is used as the latent code.
        At most ``max_points`` samples are encoded.

        Args:
            test_loader: Iterable of ``(data, target)`` pairs; ``target`` tensors
                are used as scatter colours.
            method: ``'tsne'`` or ``'pca'``.
            save_path: Output image path; defaults to
                ``<results_dir>/latent_space_<method>.png``.
            max_points: Upper bound on the number of points plotted.

        Raises:
            ValueError: For an unknown ``method``, an empty loader, or fewer than
                two points (a 2-D projection is undefined).
        """
        # Validate before the (potentially expensive) encoding pass.
        if method not in ('tsne', 'pca'):
            raise ValueError(f"Unknown method: {method}")

        self.model.eval()
        latent_batches: List[np.ndarray] = []
        label_batches: List[np.ndarray] = []
        collected = 0

        with torch.no_grad():
            for data, target in test_loader:
                data = data.to(self.device)
                mu, _ = self.model.encode(data)
                latent_batches.append(mu.cpu().numpy())
                label_batches.append(target.cpu().numpy())

                collected += data.size(0)
                if collected >= max_points:
                    break

        if not latent_batches:
            raise ValueError("Cannot visualize latent space: the data loader yielded no batches.")

        latent_vectors = np.concatenate(latent_batches, axis=0)[:max_points]
        labels = np.concatenate(label_batches, axis=0)[:max_points]
        num_points = len(latent_vectors)
        if num_points < 2:
            raise ValueError(f"Need at least 2 latent vectors for a 2-D projection, got {num_points}.")

        if method == 'tsne':
            # sklearn requires perplexity < n_samples. Keep the default (30)
            # unless the point set is too small for it.
            reducer = TSNE(n_components=2, random_state=42, perplexity=min(30.0, num_points - 1))
        else:
            reducer = PCA(n_components=2)

        latent_2d = reducer.fit_transform(latent_vectors)

        if save_path is None:
            save_path = os.path.join(self.results_dir, f'latent_space_{method}.png')

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            scatter = ax.scatter(latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap='tab10', alpha=0.6)
            fig.colorbar(scatter, ax=ax)
            ax.set_title(f'Latent Space Visualization ({method.upper()})')
            ax.set_xlabel(f'{method.upper()} 1')
            ax.set_ylabel(f'{method.upper()} 2')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        print(f"Latent space visualization saved to {save_path}")

    def generate_samples(self, num_samples: int = 16, save_path: Optional[str] = None) -> None:
        """Draw ``num_samples`` images via ``model.sample`` and save them as a grid.

        The grid is near-square and sized from ``num_samples``; the default of
        16 gives a 4x4 grid.

        Args:
            num_samples: Number of samples to draw; must be >= 1.
            save_path: Output image path; defaults to ``<results_dir>/generated_samples.png``.

        Raises:
            ValueError: If ``num_samples < 1`` or the model returns no samples.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        self.model.eval()
        with torch.no_grad():
            samples = self.model.sample(num_samples, self.device)

        samples = self._to_2d_images(samples)
        n = min(num_samples, len(samples))
        if n == 0:
            raise ValueError("The model returned no samples to plot.")

        rows, cols = self._grid_shape(n)

        if save_path is None:
            save_path = os.path.join(self.results_dir, 'generated_samples.png')

        fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
        try:
            fig.suptitle('Generated Samples', fontsize=16)
            for i in range(n):
                ax = axes[i // cols, i % cols]
                ax.imshow(samples[i], cmap='gray')
                ax.set_title(f'Sample {i+1}')

            # Hide ticks on every panel, including unused trailing ones.
            for ax in axes.flat:
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        print(f"Generated samples saved to {save_path}")

    # ------------------------------------------------------------------ #
    # Orchestration
    # ------------------------------------------------------------------ #
    def create_evaluation_report(self, test_loader: DataLoader, train_loader: Optional[DataLoader] = None) -> None:
        """Compute and save metrics, then write all standard visualizations.

        Runs, in order: metrics (saved as JSON), reconstruction grid, t-SNE and
        PCA latent plots, and the prior-sample grid, all into ``results_dir``.
        """
        print("Creating evaluation report...")

        metrics = self.compute_metrics(test_loader, train_loader=train_loader)
        self.save_results(metrics)

        self.visualize_reconstructions(test_loader)
        self.visualize_latent_space(test_loader, method='tsne')
        self.visualize_latent_space(test_loader, method='pca')
        self.generate_samples()

        print("Evaluation report completed!")