# enhanced_sample_visualizer.py
#
# Enhanced visualizer for MVDream repository structure
# Features temporal distribution curves, cross-dataset support, and multi-sample analysis

import os
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec
from PIL import Image
import argparse
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

class MVDreamSampleVisualizer:
    def __init__(self, output_root='output', output_dir='sample_visualizations'):
        self.output_root = output_root
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Color palette for memorization status - high contrast warm/cool
        self.mem_colors = ['#FF4444', '#FF6B35', '#FF8C42', '#FFB347', '#E85D75', '#D63031']  # Warm: reds, oranges, coral
        self.unmem_colors = ['#0066CC', '#007BFF', '#00CED1', '#20B2AA', '#4682B4', '#5DADE2']  # Cool: blues, teals, steel blue
        
        # Distribution cache for temporal curves
        self._temporal_distribution_cache = {}
        
    def load_sample_data(self, method, dataset, sample_idx, seed=None):
        """
        Loads JSON metrics and PNG images for a specific sample.
        If seed is None, returns data for all available seeds.
        """
        method_dataset_dir = os.path.join(self.output_root, method, dataset)
        
        if not os.path.exists(method_dataset_dir):
            print(f"Warning: Directory not found: {method_dataset_dir}")
            return None, None
        
        # Find JSON files (metrics) - handle seed specification
        if seed is not None:
            json_pattern = os.path.join(method_dataset_dir, f'prompt_{sample_idx:04d}_{int(seed):02d}_*.json')
            json_files = glob.glob(json_pattern)
            
            if not json_files:
                print(f"Warning: No JSON found for sample {sample_idx} (seed={seed}) in {method_dataset_dir}")
                return None, None
            
            # Load the specific seed
            json_path = json_files[0]
            try:
                with open(json_path, 'r') as f:
                    metrics_data = json.load(f)
            except Exception as e:
                print(f"Error loading JSON {json_path}: {e}")
                return None, None
            
            # Find corresponding PNG files
            png_pattern = os.path.join(method_dataset_dir, f'prompt_{sample_idx:04d}_{int(seed):02d}_*.png')
            png_files = sorted(glob.glob(png_pattern))
            
            print(f"Loaded sample {sample_idx} (seed={seed}): {len(png_files)} images, metrics from {os.path.basename(json_path)}")
            
            return metrics_data, png_files
        else:
            # Load all seeds for this sample
            json_pattern = os.path.join(method_dataset_dir, f'prompt_{sample_idx:04d}_*.json')
            json_files = glob.glob(json_pattern)
            
            if not json_files:
                print(f"Warning: No JSON found for sample {sample_idx} in {method_dataset_dir}")
                return None, None
            
            # Return data for all seeds
            all_seeds_data = []
            all_png_files = []
            
            for json_file in json_files:
                try:
                    with open(json_file, 'r') as f:
                        metrics_data = json.load(f)
                    
                    # Extract seed from filename if present
                    basename = os.path.basename(json_file)
                    file_seed = None
                    parts = basename.split('_')
                    if len(parts) >= 3:
                        try:
                            file_seed = int(parts[2])
                        except ValueError:
                            pass
                    
                    # Find corresponding PNG files for this seed
                    if file_seed is not None:
                        png_pattern = os.path.join(method_dataset_dir, f'prompt_{sample_idx:04d}_{file_seed:02d}_*.png')
                    else:
                        png_pattern = os.path.join(method_dataset_dir, f'prompt_{sample_idx:04d}_*.png')
                    png_files = sorted(glob.glob(png_pattern))
                    
                    all_seeds_data.append({
                        'seed': file_seed,
                        'metrics': metrics_data,
                        'png_files': png_files
                    })
                    
                except Exception as e:
                    print(f"Error loading JSON {json_file}: {e}")
                    continue
            
            print(f"Loaded sample {sample_idx}: {len(all_seeds_data)} seeds found")
            
            return all_seeds_data, None  # Return list of seed data instead of single metrics/png_files
    
    def crop_image_frame(self, img, frame_mode='full'):
        """
        Crop image to show specific frame(s).
        frame_mode: 'full' (all 4 frames) or 'first' (leftmost 25%)
        """
        if frame_mode == 'first':
            # Crop to left 25% (first frame)
            width, height = img.size
            frame_width = width // 4
            return img.crop((0, 0, frame_width, height))
        else:
            # Return full image (all 4 frames)
            return img
        """
        Get color for a sample based on memorization status.
        """
        if is_memorized:
            return self.mem_colors[sample_index % len(self.mem_colors)]
        else:
            return self.unmem_colors[sample_index % len(self.unmem_colors)]

    def get_sample_color(self, is_memorized, sample_index):
        """
        Pick a display color for a sample.

        The warm palette is used when ``is_memorized`` is truthy, the cool one
        otherwise; ``sample_index`` wraps around the chosen palette.
        """
        palette = self.mem_colors if is_memorized else self.unmem_colors
        return palette[sample_index % len(palette)]

    def extract_prompt_text(self, metrics_data):
        """
        Find the prompt string inside a metrics record.

        Lookup order:
          1. a non-blank string under one of the known prompt keys at top level;
          2. the same keys searched through nested dicts (depth-limited);
          3. any string longer than 10 characters whose key does not look like
             a path/id/timestamp field (shallower depth limit);
          4. the literal "No prompt found".
        """
        prompt_keys = ['prompt', 'text', 'input_text', 'description', 'caption']
        skip_keys = ['path', 'file', 'url', 'id', 'hash', 'time', 'date']

        def is_usable(candidate):
            return isinstance(candidate, str) and len(candidate.strip()) > 0

        # Stage 1: top-level keys.
        for name in prompt_keys:
            if name in metrics_data:
                candidate = metrics_data[name]
                if is_usable(candidate):
                    return candidate.strip()

        def find_by_known_key(node, depth=0):
            # Stage 2: known keys anywhere in nested dicts, at most 3 levels down.
            if depth > 3 or not isinstance(node, dict):
                return None
            for name in prompt_keys:
                if name in node and is_usable(node[name]):
                    return node[name].strip()
            for child in node.values():
                found = find_by_known_key(child, depth + 1)
                if found:
                    return found
            return None

        def find_any_long_text(node, depth=0):
            # Stage 3: first sufficiently long string under a key that is not
            # an obvious non-prompt field, at most 2 levels down.
            if depth > 2 or not isinstance(node, dict):
                return None
            for name, child in node.items():
                if not isinstance(child, str) or len(child.strip()) <= 10:
                    continue
                if any(marker in name.lower() for marker in skip_keys):
                    continue
                return child.strip()
            for child in node.values():
                found = find_any_long_text(child, depth + 1)
                if found:
                    return found
            return None

        return (
            find_by_known_key(metrics_data)
            or find_any_long_text(metrics_data)
            or "No prompt found"
        )

    def load_temporal_class_distributions(self, metric_name):
        """
        Load temporal metric distributions for memorized and unmemorized classes.
        Returns mean and std trajectories over time steps.
        """
        cache_key = metric_name
        
        if cache_key in self._temporal_distribution_cache:
            return self._temporal_distribution_cache[cache_key]
        
        print(f"Computing temporal class distributions for {metric_name}...")
        
        # Define class directories
        mem_dir = os.path.join(self.output_root, 'baseline', 'laion_memorized')
        unmem_dir = os.path.join(self.output_root, 'baseline', 'laion_unmemorized')
        
        # Collect all trajectories
        mem_trajectories = []
        unmem_trajectories = []
        
        # Load memorized trajectories
        if os.path.exists(mem_dir):
            json_files = glob.glob(os.path.join(mem_dir, 'prompt_*.json'))
            for json_file in json_files[:2500]:  # Limit for performance
                try:
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    metric_value = self._extract_metric_value(data, metric_name)
                    if metric_value is not None and isinstance(metric_value, list):
                        mem_trajectories.append(metric_value)
                except Exception:
                    continue
        
        # Load unmemorized trajectories
        if os.path.exists(unmem_dir):
            json_files = glob.glob(os.path.join(unmem_dir, 'prompt_*.json'))
            for json_file in json_files[:2500]:  # Limit for performance
                try:
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    metric_value = self._extract_metric_value(data, metric_name)
                    if metric_value is not None and isinstance(metric_value, list):
                        unmem_trajectories.append(metric_value)
                except Exception:
                    continue
        
        # Compute temporal statistics
        def compute_temporal_stats(trajectories):
            if not trajectories:
                return None, None, None
            
            # Find the maximum length
            max_len = max(len(traj) for traj in trajectories)
            
            # Pad trajectories to the same length (forward fill last value)
            padded_trajectories = []
            for traj in trajectories:
                if len(traj) < max_len:
                    padded = list(traj) + [traj[-1]] * (max_len - len(traj))
                else:
                    padded = traj[:max_len]
                padded_trajectories.append(padded)
            
            # Convert to numpy array and compute statistics
            traj_array = np.array(padded_trajectories)
            mean_traj = np.mean(traj_array, axis=0)
            std_traj = np.std(traj_array, axis=0)
            steps = np.arange(max_len)
            
            return steps, mean_traj, std_traj
        
        # Compute for both classes
        mem_stats = compute_temporal_stats(mem_trajectories)
        unmem_stats = compute_temporal_stats(unmem_trajectories)
        
        distributions = {
            'memorized': {
                'n_samples': len(mem_trajectories),
                'temporal_stats': mem_stats
            },
            'unmemorized': {
                'n_samples': len(unmem_trajectories),
                'temporal_stats': unmem_stats
            }
        }
        
        self._temporal_distribution_cache[cache_key] = distributions
        print(f"Computed temporal distributions - Mem: {len(mem_trajectories)} trajectories, Unmem: {len(unmem_trajectories)} trajectories")
        
        return distributions
    
    def load_scalar_class_distributions(self, metric_name):
        """
        Load scalar metric distributions for memorized and unmemorized classes.
        """
        print(f"Computing scalar class distributions for {metric_name}...")
        
        mem_dir = os.path.join(self.output_root, 'baseline', 'laion_memorized')
        unmem_dir = os.path.join(self.output_root, 'baseline', 'laion_unmemorized')
        
        mem_values = []
        unmem_values = []
        
        # Load memorized values
        if os.path.exists(mem_dir):
            json_files = glob.glob(os.path.join(mem_dir, 'prompt_*.json'))
            for json_file in json_files[:2500]:
                try:
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    metric_value = self._extract_metric_value(data, metric_name)
                    if metric_value is not None and not isinstance(metric_value, list):
                        mem_values.append(metric_value)
                except Exception:
                    continue
        
        # Load unmemorized values
        if os.path.exists(unmem_dir):
            json_files = glob.glob(os.path.join(unmem_dir, 'prompt_*.json'))
            for json_file in json_files[:2500]:
                try:
                    with open(json_file, 'r') as f:
                        data = json.load(f)
                    metric_value = self._extract_metric_value(data, metric_name)
                    if metric_value is not None and not isinstance(metric_value, list):
                        unmem_values.append(metric_value)
                except Exception:
                    continue
        
        return {
            'memorized': np.array(mem_values) if mem_values else np.array([]),
            'unmemorized': np.array(unmem_values) if unmem_values else np.array([])
        }
    
    def parse_sample_specs(self, sample_specs):
        """
        Turn 'method/dataset/sample_idx[/seed]' strings into dicts.

        Specs with fewer than three '/'-separated fields, or with a non-integer
        index or seed, are reported with a warning and dropped. Fields beyond
        the fourth are ignored.

        Returns:
            list of dicts with keys 'method', 'dataset', 'sample_idx', 'seed'
            (seed is None when the spec has no fourth field).
        """
        parsed = []

        for spec in sample_specs:
            fields = spec.split('/')

            if len(fields) < 3:
                print(f"Warning: Invalid sample spec '{spec}'. Expected format: method/dataset/sample_idx or method/dataset/sample_idx/seed")
                continue

            method, dataset, raw_idx = fields[:3]
            raw_seed = fields[3] if len(fields) > 3 else None

            try:
                entry = {
                    'method': method,
                    'dataset': dataset,
                    'sample_idx': int(raw_idx),
                    'seed': None if raw_seed is None else int(raw_seed),
                }
            except ValueError as e:
                print(f"Warning: Could not parse sample spec '{spec}': {e}")
                continue

            parsed.append(entry)

        return parsed
    
    def plot_cross_dataset_analysis(self, sample_specs, metric_name, show_images=True, frame_mode='full'):
        """
        Creates analysis plot for samples from potentially different datasets/methods.
        sample_specs: list of strings like ['baseline/laion_memorized/1', 'baseline/laion_unmemorized/5/42']
        """
        # Parse sample specifications
        parsed_samples = self.parse_sample_specs(sample_specs)
        
        if not parsed_samples:
            print("No valid sample specifications")
            return
        
        # Load data for all samples
        sample_data = []
        mem_count = 0
        unmem_count = 0
        
        for i, spec in enumerate(parsed_samples):
            loaded_data = self.load_sample_data(
                spec['method'], spec['dataset'], spec['sample_idx'], spec['seed']
            )
            
            if loaded_data[0] is not None:
                if spec['seed'] is not None:
                    # Single seed specified
                    metrics_data, png_files = loaded_data
                    is_memorized = metrics_data.get('memorized', False)
                    
                    color = self.get_sample_color(is_memorized, mem_count if is_memorized else unmem_count)
                    if is_memorized:
                        mem_count += 1
                    else:
                        unmem_count += 1
                    
                    prompt_text = self.extract_prompt_text(metrics_data)
                    
                    sample_data.append({
                        'spec': spec,
                        'metrics': metrics_data,
                        'images': png_files,
                        'color': color,
                        'is_memorized': is_memorized,
                        'prompt_text': prompt_text
                    })
                else:
                    # All seeds for this sample
                    all_seeds_data, _ = loaded_data
                    for seed_data in all_seeds_data:
                        metrics_data = seed_data['metrics']
                        png_files = seed_data['png_files']
                        file_seed = seed_data['seed']
                        
                        is_memorized = metrics_data.get('memorized', False)
                        
                        color = self.get_sample_color(is_memorized, mem_count if is_memorized else unmem_count)
                        if is_memorized:
                            mem_count += 1
                        else:
                            unmem_count += 1
                        
                        prompt_text = self.extract_prompt_text(metrics_data)
                        
                        # Create spec for this specific seed
                        seed_spec = {
                            'method': spec['method'],
                            'dataset': spec['dataset'],
                            'sample_idx': spec['sample_idx'],
                            'seed': file_seed
                        }
                        
                        sample_data.append({
                            'spec': seed_spec,
                            'metrics': metrics_data,
                            'images': png_files,
                            'color': color,
                            'is_memorized': is_memorized,
                            'prompt_text': prompt_text
                        })
        
        if not sample_data:
            print("No valid sample data found")
            return
        
        # Check if metric is temporal or scalar
        first_metric_value = self._extract_metric_value(sample_data[0]['metrics'], metric_name)
        if first_metric_value is None:
            available_metrics = self.get_available_metrics(sample_data[0]['metrics'])
            print(f"Metric '{metric_name}' not found.")
            print(f"Available metrics: {available_metrics[:10]}...")
            return
        
        is_temporal = isinstance(first_metric_value, list)
        
        # Create the plot layout
        if show_images and any(sample['images'] for sample in sample_data):
            fig = plt.figure(figsize=(16, 12))
            # Images on top, metrics on bottom
            gs = GridSpec(2, 1, height_ratios=[1, 1.5], hspace=0.3)
            
            # Create image subplot grid
            ax_images = fig.add_subplot(gs[0, 0])
            ax_images.axis('off')
            
            # Create individual image subplots
            n_samples = len(sample_data)
            img_cols = min(4, n_samples)
            img_rows = (n_samples + img_cols - 1) // img_cols
            
            gs_imgs = GridSpec(img_rows, img_cols, top=0.9, bottom=0.6, 
                              left=0.1, right=0.9, hspace=0.3, wspace=0.1)
            
            for i, sample in enumerate(sample_data):
                if i >= 16:  # Limit to 16 images max
                    break
                    
                row = i // img_cols
                col = i % img_cols
                ax_img = fig.add_subplot(gs_imgs[row, col])
                
                if sample['images']:
                    try:
                        img = Image.open(sample['images'][0])
                        # Apply frame cropping
                        img = self.crop_image_frame(img, frame_mode)
                        ax_img.imshow(img)
                        
                        # Add border based on memorization status
                        border_color = sample['color']
                        
                        # Add border
                        for spine in ax_img.spines.values():
                            spine.set_color(border_color)
                            spine.set_linewidth(3)
                            spine.set_visible(True)
                        
                        # Title with truncated prompt and status
                        spec = sample['spec']
                        seed_str = f" s{spec['seed']}" if spec['seed'] is not None else ""
                        mem_indicator = '' if sample['is_memorized'] else ''
                        
                        # Truncate prompt for display
                        prompt_display = sample['prompt_text']
                        if len(prompt_display) > 40:
                            prompt_display = prompt_display[:37] + "..."
                        
                        title_text = f'{mem_indicator} {prompt_display}\n#{spec["sample_idx"]:04d}{seed_str}'
                        ax_img.set_title(title_text, fontsize=8, fontweight='bold', 
                                       color=sample['color'], wrap=True)
                        
                    except Exception as e:
                        ax_img.text(0.5, 0.5, f'Image\nError', ha='center', va='center',
                                   transform=ax_img.transAxes, color=sample['color'])
                else:
                    ax_img.text(0.5, 0.5, 'No\nImage', ha='center', va='center',
                               transform=ax_img.transAxes, color=sample['color'])
                
                ax_img.axis('off')
            
            # Metrics plot
            ax_metric = fig.add_subplot(gs[1, 0])
        else:
            fig, ax_metric = plt.subplots(1, 1, figsize=(12, 8))
        
        # Main title
        dataset_list = list(set([f"{s['spec']['method']}/{s['spec']['dataset']}" for s in sample_data]))
        if len(dataset_list) == 1:
            title = f"{dataset_list[0].upper()} - Multiple Samples"
        else:
            title = f"Cross-Dataset Analysis"
        fig.suptitle(f'{title}\nMetric: {metric_name}', fontsize=14, fontweight='bold')
        
        # Plot the metric with distribution context
        if is_temporal:
            self._plot_temporal_metric_with_curves(ax_metric, sample_data, metric_name)
        else:
            self._plot_scalar_metric_with_distribution(ax_metric, sample_data, metric_name)
        
        # Save the plot
        safe_metric_name = metric_name.replace('/', '_').replace(' ', '_')
        filename = f'cross_dataset_analysis_{safe_metric_name}_{len(sample_data)}samples.png'
        filepath = os.path.join(self.output_dir, filename)
        
        plt.tight_layout()
        plt.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()
        
        print(f"Created cross-dataset analysis plot: {filepath}")
        return sample_data
    
    def _plot_temporal_metric_with_curves(self, ax, sample_data, metric_name):
        """
        Draw sample trajectories on top of the class mean ± 1 std bands.

        Args:
            ax: matplotlib axes to draw on.
            sample_data: list of per-sample dicts; the keys 'metrics',
                'prompt_text', 'is_memorized' and 'color' are read.
            metric_name: metric resolved with _extract_metric_value; samples
                whose value is missing or not a list are skipped.
        """
        distributions = self.load_temporal_class_distributions(metric_name)

        class_styles = [
            ('memorized', 'Memorized', '#e74c3c'),
            ('unmemorized', 'Unmemorized', '#3498db'),
        ]

        # Background: one shaded band and dashed mean per class. Both classes
        # use the same ±1 std width so the bands are directly comparable.
        summary_lines = []
        for class_key, class_label, class_color in class_styles:
            steps, mean_traj, std_traj = distributions[class_key]['temporal_stats']
            if steps is None:  # no reference data for this class
                continue

            ax.fill_between(steps, mean_traj - std_traj, mean_traj + std_traj,
                            alpha=0.2, color=class_color, label=class_label)
            ax.plot(steps, mean_traj, color=class_color, linestyle='--', alpha=0.8,
                    linewidth=2, label=f'{class_label} Mean')
            summary_lines.append(f"{class_label} (n={distributions[class_key]['n_samples']})")

        # Foreground: the individual sample trajectories.
        for sample in sample_data:
            metric_value = self._extract_metric_value(sample['metrics'], metric_name)
            if metric_value is None or not isinstance(metric_value, list):
                continue

            steps = np.arange(len(metric_value))

            prompt_text = sample['prompt_text']
            if len(prompt_text) > 50:
                prompt_text = prompt_text[:47] + "..."

            mem_str = " (MEM)" if sample['is_memorized'] else " (UNMEM)"
            label = f'"{prompt_text}"{mem_str}'

            # markevery keeps roughly 20 markers regardless of trajectory length.
            ax.plot(steps, metric_value, color=sample['color'], linewidth=3, alpha=0.9,
                    label=label, marker='o', markersize=3,
                    markevery=max(1, len(metric_value) // 20))

        ax.set_xlabel('Timestep / DDIM Step', fontweight='bold', fontsize=12)
        ax.set_ylabel(f'{metric_name}', fontweight='bold', fontsize=12)
        ax.set_title(f'Temporal Evolution: {metric_name}', fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', framealpha=0.9, fontsize=9)

        # Info box listing which reference classes are present and their size.
        dist_text = "Class Distributions:\n" + "\n".join(summary_lines)

        ax.text(0.02, 0.98, dist_text, transform=ax.transAxes, va='top', ha='left',
                bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.9, edgecolor='gray'),
                fontsize=9)
    
    def _plot_scalar_metric_with_distribution(self, ax, sample_data, metric_name):
        """
        Draw class histograms of a scalar metric with one vertical line per sample.

        Samples whose metric is missing or list-valued are ignored; when none
        remain, a 'No scalar data available' note is drawn instead.
        """
        distributions = self.load_scalar_class_distributions(metric_name)

        # Gather what is needed to draw each sample's marker line.
        markers = []
        for sample in sample_data:
            value = self._extract_metric_value(sample['metrics'], metric_name)
            if value is None or isinstance(value, list):
                continue

            shown_prompt = sample['prompt_text']
            if len(shown_prompt) > 50:
                shown_prompt = shown_prompt[:47] + "..."
            suffix = " (MEM)" if sample['is_memorized'] else " (UNMEM)"

            markers.append({
                'value': value,
                'label': f'"{shown_prompt}"{suffix}',
                'color': sample['color'],
                'is_memorized': sample['is_memorized']
            })

        if not markers:
            ax.text(0.5, 0.5, 'No scalar data available', ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            return

        # Background histograms, memorized first so legend order is stable.
        class_styles = (
            ('memorized', 'Memorized', '#e74c3c'),
            ('unmemorized', 'Unmemorized', '#3498db'),
        )
        for class_key, class_title, class_color in class_styles:
            class_values = distributions[class_key]
            if class_values.size > 0:
                ax.hist(class_values, bins=40, alpha=0.4, color=class_color,
                        label=f'{class_title} Distribution (n={len(class_values)})', density=True)

        # The annotation height is taken from the axis limits after the histograms.
        top = ax.get_ylim()[1]
        for marker in markers:
            ax.axvline(x=marker['value'], color=marker['color'], linestyle='-', linewidth=4,
                       alpha=0.9, label=marker['label'])

            ax.text(marker['value'], top * 0.95, f"{marker['value']:.4f}",
                    rotation=90, ha='right', va='top', fontweight='bold',
                    color=marker['color'], fontsize=9,
                    bbox=dict(boxstyle="round,pad=0.2", facecolor='white', alpha=0.9))

        ax.set_xlabel(f'{metric_name}', fontweight='bold', fontsize=12)
        ax.set_ylabel('Density', fontweight='bold', fontsize=12)
        ax.set_title(f'Distribution Analysis: {metric_name}', fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', framealpha=0.9, fontsize=9)
    
    def _extract_metric_value(self, data, metric_name):
        """
        Resolve a metric name against (possibly nested) JSON data.

        Resolution order:
          1. ``metric_name`` as a top-level key;
          2. a nested key whose underscore-joined path (with '-' also read as
             '_') equals ``metric_name``;
          3. the first key that contains, or is contained in, ``metric_name``
             (case-insensitive) and holds a list, float or int.

        Returns None when nothing matches.
        """
        if metric_name in data:
            return data[metric_name]

        def find_by_path(node, prefix=""):
            if not isinstance(node, dict):
                return None
            for key, child in node.items():
                path = f"{prefix}_{key}" if prefix else key
                if path == metric_name or path.replace('-', '_') == metric_name:
                    return child
                hit = find_by_path(child, path)
                if hit is not None:
                    return hit
            return None

        def find_by_substring(node):
            if not isinstance(node, dict):
                return None
            for key, child in node.items():
                overlaps = (metric_name.lower() in key.lower()
                            or key.lower() in metric_name.lower())
                if overlaps and isinstance(child, (list, float, int)):
                    return child
                hit = find_by_substring(child)
                if hit is not None:
                    return hit
            return None

        exact = find_by_path(data)
        return exact if exact is not None else find_by_substring(data)
    
    def get_available_metrics(self, metrics_data):
        """
        List the names of all list/float/int entries in nested JSON data.

        Nested keys are joined with '_' and '-' is rewritten to '_'; keys named
        'memorized' are skipped at every level. Non-dict input yields [].
        """
        def walk(node, prefix=""):
            if not isinstance(node, dict):
                return
            for key, value in node.items():
                if key == 'memorized':
                    continue
                name = f"{prefix}{key}" if prefix else key
                if isinstance(value, dict):
                    yield from walk(value, f"{name}_")
                elif isinstance(value, (list, float, int)):
                    yield name.replace('-', '_')

        return list(walk(metrics_data))
    
    def list_available_samples(self, method, dataset, max_samples=20):
        """
        Lists available samples in a method/dataset directory.
        """
        method_dataset_dir = os.path.join(self.output_root, method, dataset)
        
        if not os.path.exists(method_dataset_dir):
            print(f"Directory not found: {method_dataset_dir}")
            return {}
        
        json_files = glob.glob(os.path.join(method_dataset_dir, 'prompt_*.json'))
        sample_seeds = defaultdict(list)
        
        for json_file in json_files:
            basename = os.path.basename(json_file)
            try:
                parts = basename.split('_')
                if len(parts) >= 2 and parts[0] == 'prompt':
                    sample_idx = int(parts[1])
                    
                    # Check for seed in filename
                    seed = None
                    if 'seed' in parts:
                        seed_idx = parts.index('seed')
                        if seed_idx + 1 < len(parts):
                            try:
                                seed = int(parts[seed_idx + 1])
                            except ValueError:
                                pass
                    
                    if seed not in sample_seeds[sample_idx]:
                        sample_seeds[sample_idx].append(seed)
            except ValueError:
                continue
        
        sorted_samples = dict(sorted(list(sample_seeds.items())[:max_samples]))
        
        print(f"Found samples in {method}/{dataset}:")
        for sample_idx, seeds in sorted_samples.items():
            seeds_str = f"seeds: {seeds}" if any(s is not None for s in seeds) else "no seeds"
            print(f"  Sample {sample_idx:04d} ({seeds_str})")
        
        return sorted_samples
    
    def list_available_metrics(self, method, dataset, sample_idx, seed=None):
        """
        Lists available metrics for a specific sample.
        """
        metrics_data, _ = self.load_sample_data(method, dataset, sample_idx, seed)
        
        if metrics_data is None:
            return []
        
        return self.get_available_metrics(metrics_data)


def main():
    """
    Enhanced command-line interface with cross-dataset support.

    Parses the command line and dispatches to one of the following modes:

    * ``--list_samples``: list the samples found for ``--method``/``--dataset``.
    * ``--list_metrics``: print up to 50 metric names for the first
      ``--sample_idx`` (and the first ``--seeds`` entry, if any).
    * ``--samples``: cross-dataset plot for the given
      ``method/dataset/sample_idx[/seed]`` specifications.
    * legacy mode (``--method`` + ``--dataset`` + ``--sample_idx``): the
      arguments are converted to sample specifications and plotted the same
      way; seeds are paired with sample indices by position.

    Usage errors are reported with ``print`` and the function returns
    ``None`` without raising.
    """
    parser = argparse.ArgumentParser(
        description='Visualize MVDream samples across datasets with enhanced analysis',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Sample specification formats:
  method/dataset/sample_idx          - e.g., baseline/laion_memorized/1
  method/dataset/sample_idx/seed     - e.g., baseline/laion_memorized/1/42

Examples:
  # Single dataset, multiple samples
  python script.py --samples baseline/laion_memorized/1 baseline/laion_memorized/5 --metric Noise_Difference_Norm_noise_diff_norm_traj
  
  # Cross-dataset analysis
  python script.py --samples baseline/laion_memorized/1 baseline/laion_unmemorized/1 baseline/objaverse_sonic/1 --metric Noise_Difference_Norm_noise_diff_norm_traj
  
  # Multiple seeds
  python script.py --samples baseline/laion_memorized/1/42 baseline/laion_memorized/1/123 --metric Noise_Difference_Norm_noise_diff_norm_traj
        """
    )

    # Primary mode: sample specifications
    parser.add_argument('--samples', type=str, nargs='+',
                        help='Sample specifications in format method/dataset/sample_idx[/seed]')
    parser.add_argument('--metric', type=str,
                        help='Metric name to visualize')
    parser.add_argument('--no_images', action='store_true',
                        help='Skip showing images (metrics only)')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes --help fail with a ValueError.
    parser.add_argument('--first_frame_only', action='store_true',
                        help='Show only the first frame (leftmost 25%%) of each image')

    # Legacy single-dataset mode
    parser.add_argument('--method', type=str, default='baseline',
                        help='Method name for legacy mode')
    parser.add_argument('--dataset', type=str,
                        help='Dataset name for legacy mode')
    parser.add_argument('--sample_idx', type=int, nargs='+',
                        help='Sample indices for legacy mode')
    parser.add_argument('--seeds', type=int, nargs='*',
                        help='Seeds for legacy mode')

    # Utility modes
    parser.add_argument('--list_samples', action='store_true',
                        help='List available samples (requires --method and --dataset)')
    parser.add_argument('--list_metrics', action='store_true',
                        help='List available metrics (requires --method, --dataset, and --sample_idx)')

    parser.add_argument('--output_root', type=str, default='output',
                        help='Root directory containing experimental results')

    args = parser.parse_args()

    visualizer = MVDreamSampleVisualizer(output_root=args.output_root)

    # Handle utility modes
    if args.list_samples:
        if not args.method or not args.dataset:
            print("Error: --method and --dataset required for --list_samples")
            return
        visualizer.list_available_samples(args.method, args.dataset)
        return

    if args.list_metrics:
        if not args.method or not args.dataset or not args.sample_idx:
            print("Error: --method, --dataset, and --sample_idx required for --list_metrics")
            return
        # Only the first sample index / seed is inspected in this mode.
        sample_idx = args.sample_idx[0]
        seed = args.seeds[0] if args.seeds else None
        metrics = visualizer.list_available_metrics(args.method, args.dataset, sample_idx, seed)
        print(f"Available metrics for {args.method}/{args.dataset}/sample_{sample_idx} (seed={seed}):")
        for i, metric in enumerate(metrics[:50]):
            print(f"  {i+1:2d}. {metric}")
        if len(metrics) > 50:
            print(f"  ... and {len(metrics) - 50} more metrics")
        return

    if not args.metric:
        print("Error: --metric is required")
        return

    # Main visualization modes: both end in the same plotting call, so only
    # the list of sample specifications differs between them.
    if args.samples:
        # New cross-dataset mode
        sample_specs = args.samples
    elif args.method and args.dataset and args.sample_idx:
        # Legacy single-dataset mode - convert to sample specs.
        # Seeds are matched to sample indices by position; indices without a
        # corresponding seed get a spec without the trailing seed component.
        seeds = args.seeds or []
        sample_specs = []
        for i, sample_idx in enumerate(args.sample_idx):
            seed = seeds[i] if i < len(seeds) else None
            if seed is not None:
                spec = f"{args.method}/{args.dataset}/{sample_idx}/{seed}"
            else:
                spec = f"{args.method}/{args.dataset}/{sample_idx}"
            sample_specs.append(spec)
    else:
        print("Error: Either --samples OR (--method + --dataset + --sample_idx) required")
        print("Use --help for usage examples")
        return

    frame_mode = 'first' if args.first_frame_only else 'full'
    visualizer.plot_cross_dataset_analysis(sample_specs, args.metric,
                                           show_images=not args.no_images,
                                           frame_mode=frame_mode)


def example_enhanced_analysis():
    """
    Run the three built-in demo plots with a default visualizer.

    Each demo prints a heading and then calls
    ``plot_cross_dataset_analysis`` with a fixed list of sample
    specifications and a fixed metric name.
    """
    # (heading, sample specifications, metric) for each demo, in run order.
    demos = (
        # Example 1: Cross-dataset comparison
        (
            "\n1. Cross-dataset comparison:",
            [
                'baseline/laion_memorized/1',
                'baseline/laion_unmemorized/1',
                'baseline/objaverse_sonic/1',
                'baseline/objaverse_fazbear/1',
            ],
            'Noise_Difference_Norm_noise_diff_norm_traj',
        ),
        # Example 2: Multiple samples same dataset
        (
            "\n2. Multiple samples from same dataset:",
            [
                'baseline/laion_memorized/1',
                'baseline/laion_memorized/5',
                'baseline/laion_memorized/10',
            ],
            'Noise_Difference_Norm_noise_diff_norm_mean',
        ),
        # Example 3: Multiple seeds
        (
            "\n3. Multiple seeds analysis:",
            [
                'baseline/laion_memorized/1/42',
                'baseline/laion_memorized/1/123',
                'baseline/laion_memorized/1/456',
            ],
            'Hessian_SAIL_diff_sum_t50',
        ),
    )

    demo_visualizer = MVDreamSampleVisualizer()

    print("=== Enhanced Cross-Dataset Analysis Examples ===")

    for heading, specs, metric_name in demos:
        print(heading)
        demo_visualizer.plot_cross_dataset_analysis(specs, metric_name)

    print("\nEnhanced analysis complete!")


if __name__ == "__main__":
    # Import sys explicitly instead of reaching it through `os.sys`, which
    # only works because the os module happens to import sys internally.
    import sys

    # With no command-line arguments, run the built-in example analysis;
    # otherwise hand control to the argparse-based CLI.
    if len(sys.argv) == 1:
        print("Running enhanced cross-dataset example analysis...")
        example_enhanced_analysis()
    else:
        main()

# Example session (listing the metrics of one sample):
#
#   python evaluation/visualizer_v1.py --method  baseline --dataset laion_memorized --sample_idx 0 --seeds 0 --list_metrics
#   Available metrics for baseline/laion_memorized/sample_0 (seed=0):
#    1. Noise_Difference_Norm_noise_diff_norm_mean
#    2. Noise_Difference_Norm_noise_diff_norm_traj
#    3. Hessian_SAIL_Metric_hessian_sail_norm
#    4. Hessian_SAIL_Metric_visualizations_t50_cond_magnitudes
#    5. Hessian_SAIL_Metric_visualizations_t50_uncond_magnitudes
#    6. Hessian_SAIL_Metric_visualizations_t1_cond_magnitudes
#    7. Hessian_SAIL_Metric_visualizations_t1_uncond_magnitudes
#    8. Hessian_SAIL_Metric_visualizations_t20_cond_magnitudes
#    9. Hessian_SAIL_Metric_visualizations_t20_uncond_magnitudes
#   10. HessianMetric_t50_cond_eigvals
#   11. HessianMetric_t50_uncond_eigvals
#   12. HessianMetric_t1_cond_eigvals
#   13. HessianMetric_t1_uncond_eigvals
#   14. HessianMetric_t20_cond_eigvals
#   15. HessianMetric_t20_uncond_eigvals
#   16. BrightEnding_LD_Score_ld_score
#   17. BrightEnding_LD_Score_d_score
#   18. BrightEnding_LD_Score_be_intensity
#   19. CrossAttention_Entropy_entropy