# enhanced_sample_visualizer.py
#
# Enhanced visualizer for experiment results, refactored for clarity and maintainability.
# Features temporal distribution curves, cross-dataset support, and multi-sample analysis.

import os
import json
import glob
import argparse
import warnings
from collections import defaultdict
from typing import List, Dict, Any, Optional, Tuple
from itertools import groupby

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from PIL import Image

# Suppress every Python warning, process-wide, to keep console output clean.
# NOTE(review): this runs at import time and also silences warnings raised by
# any other module in the same process.
warnings.filterwarnings('ignore')

# --- Configuration ---

# Shared style settings. Colour palettes are lists of lists: each inner list is
# the sequence of shades handed out, in order, to the samples of one group.
STYLE_CONFIG = dict(
    memorized_colors=[
        ['#EF5350', '#F44336', '#E53935', '#D32F2F', '#C62828'],  # reds
        ['#FFA726', '#FF9800', '#FB8C00', '#F57C00', '#EF6C00'],  # oranges
        ['#EC407A', '#E91E63', '#D81B60', '#C2185B', '#AD1457'],  # pinks
    ],
    unmemorized_colors=[
        ['#42A5F5', '#2196F3', '#1E88E5', '#1976D2', '#1565C0'],  # blues
        ['#26A69A', '#009688', '#00897B', '#00796B', '#00695C'],  # teals
        ['#9CCC65', '#8BC34A', '#7CB342', '#689F38', '#558B2F'],  # greens
    ],
    # Reference-distribution styling (mean/std bands and histograms).
    distribution=dict(
        memorized_base='#D32F2F',
        unmemorized_base='#1976D2',
        fill_alpha=0.15,
        line_alpha=0.7,
        linestyle='--',
    ),
    # Keyword arguments forwarded to Axes.plot for each sample's trajectory.
    sample_curve=dict(alpha=0.75, linewidth=2.5, marker='.', markersize=4),
    font=dict(
        title=dict(fontsize=16, fontweight='bold'),
        subtitle=dict(fontsize=14, fontweight='bold'),
        label=dict(fontsize=12, fontweight='bold'),
        tick=dict(fontsize=10),
        legend=dict(fontsize=9),
        annotation=dict(fontsize=8, fontweight='bold'),
    ),
)

# --- Data Loading and Management ---

class DataLoader:
    """Handles all file system interactions for loading experiment data.

    Per-sample results are looked up under ``<root_dir>/<method>/<dataset>/`` as
    ``prompt_<idx:04d>_<seed>_*.json`` (metrics) and ``prompt_<idx:04d>_<seed>_*.png``
    (images). Reference class distributions are read from
    ``<root_dir>/baseline/laion_memorized`` and ``<root_dir>/baseline/laion_unmemorized``.
    """

    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        # (metric_name, is_temporal) -> {'memorized': ..., 'unmemorized': ...}
        self._distribution_cache: Dict[Tuple[str, bool], Dict[str, Any]] = {}

    def load_sample_data(self, method: str, dataset: str, sample_idx: int, seed: Optional[int] = None) -> List[Dict[str, Any]]:
        """Load every run of one sample, or only the run for ``seed`` when given.

        Args:
            method: Sub-directory of ``root_dir`` naming the method.
            dataset: Sub-directory of ``method`` naming the dataset.
            sample_idx: Prompt index (zero-padded to four digits in file names).
            seed: Optional seed; ``None`` loads every seed found on disk.

        Returns:
            One record per readable JSON object file, in file-name order, with keys
            ``spec``, ``metrics``, ``images``, ``is_memorized`` and ``prompt``.
            ``spec['seed']`` is ``None`` when the seed token in the file name is not
            numeric. Missing directories/files or unreadable files are reported with
            a printed warning and skipped.
        """
        method_dataset_dir = os.path.join(self.root_dir, method, dataset)
        if not os.path.exists(method_dataset_dir):
            print(f"Warning: Directory not found: {method_dataset_dir}")
            return []

        # Escape the directory so glob metacharacters ([, *, ?) in paths are literal.
        dir_pattern = glob.escape(method_dataset_dir)
        if seed is not None:
            json_pattern = os.path.join(dir_pattern, f'prompt_{sample_idx:04d}_{int(seed):02d}_*.json')
        else:
            json_pattern = os.path.join(dir_pattern, f'prompt_{sample_idx:04d}_*.json')

        # glob order is filesystem-dependent; sort so results are reproducible.
        json_files = sorted(glob.glob(json_pattern))
        if not json_files:
            print(f"Warning: No JSON found for sample {sample_idx} in {method_dataset_dir}")
            return []

        loaded_data = []
        for json_path in json_files:
            try:
                with open(json_path, 'r') as f:
                    metrics = json.load(f)
            except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
                print(f"Warning: Could not process file {json_path}: {e}")
                continue
            if not isinstance(metrics, dict):
                print(f"Warning: Could not process file {json_path}: top-level JSON is not an object")
                continue

            # File names look like prompt_<idx>_<seed>_<...>.json. Reuse the seed token
            # verbatim for the PNG lookup: re-formatting the parsed int would miss files
            # such as '007', and a non-numeric token used to crash the ':02d' format.
            parts = os.path.splitext(os.path.basename(json_path))[0].split('_')
            seed_token = parts[2] if len(parts) > 2 else ''
            current_seed = int(seed_token) if seed_token.isdecimal() else None

            png_pattern = os.path.join(dir_pattern, f'prompt_{sample_idx:04d}_{glob.escape(seed_token)}_*.png')
            png_files = sorted(glob.glob(png_pattern))

            loaded_data.append({
                'spec': {'method': method, 'dataset': dataset, 'sample_idx': sample_idx, 'seed': current_seed},
                'metrics': metrics,
                'images': png_files,
                'is_memorized': metrics.get('memorized', False),
                'prompt': self._extract_prompt_text(metrics)
            })
        return loaded_data

    @staticmethod
    def _extract_prompt_text(metrics_data: Dict[str, Any]) -> str:
        """Return the first string stored under a prompt-like key.

        Top-level keys are checked first, then the same keys inside each directly
        nested dict (one level deep). Returns "Prompt not found" if nothing matches.
        """
        prompt_keys = ('prompt', 'text', 'caption', 'description')
        containers = [metrics_data] + [v for v in metrics_data.values() if isinstance(v, dict)]
        for container in containers:
            for key in prompt_keys:
                value = container.get(key)
                if isinstance(value, str):
                    return value
        return "Prompt not found"

    def get_class_distributions(self, metric_name: str, is_temporal: bool) -> Dict[str, Any]:
        """Load and memoize the reference distribution of a metric for each class.

        Args:
            metric_name: Metric key, resolved with ``_find_metric_value``.
            is_temporal: True when the metric is a per-timestep list.

        Returns:
            ``{'memorized': ..., 'unmemorized': ...}``. For temporal metrics each
            value is ``(steps, mean, std)`` or ``None`` when no usable trajectories
            were found; for scalar metrics each value is a 1-D numpy array of the
            numeric values found (possibly empty).
        """
        cache_key = (metric_name, is_temporal)
        if cache_key in self._distribution_cache:
            return self._distribution_cache[cache_key]

        print(f"Computing class distributions for metric: {metric_name}...")
        results: Dict[str, Any] = {}
        for status in ('memorized', 'unmemorized'):
            # Baseline distributions live in a fixed location under root_dir.
            path = os.path.join(self.root_dir, 'baseline', f'laion_{status}')
            values = self._load_values_from_dir(path, metric_name)
            if is_temporal:
                results[status] = self._compute_temporal_stats(values)
            else:
                # Keep plain numbers only: lists/strings would build a ragged or
                # object array that the histogram cannot plot.
                results[status] = np.array([v for v in values if isinstance(v, (int, float))])

        self._distribution_cache[cache_key] = results
        return results

    def _load_values_from_dir(self, directory: str, metric_name: str, max_files: int = 2500) -> List[Any]:
        """Collect ``metric_name`` from up to ``max_files`` ``prompt_*.json`` files.

        Files are visited in sorted order so the capped subset is reproducible.
        Unreadable or malformed files, and files lacking the metric, are skipped.
        Returns ``[]`` when ``directory`` does not exist.
        """
        if not os.path.exists(directory):
            return []

        json_files = sorted(glob.glob(os.path.join(glob.escape(directory), 'prompt_*.json')))
        values = []
        for json_file in json_files[:max_files]:  # cap for performance
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)
            except (OSError, ValueError):
                # Best-effort scan: one corrupt file must not abort the whole distribution.
                continue
            metric_value = self._find_metric_value(data, metric_name)
            if metric_value is not None:
                values.append(metric_value)
        return values

    @staticmethod
    def _find_metric_value(data: Dict, metric_name: str) -> Optional[Any]:
        """Find ``metric_name`` in ``data``, descending into nested dicts.

        A name such as "Noise_Difference_Norm_noise_diff_norm_traj" also matches
        ``data['Noise_Difference_Norm']['noise_diff_norm_traj']``: every split of the
        name at an underscore is tried as (outer key, remaining name), recursively.
        Returns ``None`` when nothing matches or ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            return None
        if metric_name in data:
            return data[metric_name]

        parts = metric_name.split('_')
        for i in range(1, len(parts)):
            nested = data.get('_'.join(parts[:i]))
            if isinstance(nested, dict):
                found = DataLoader._find_metric_value(nested, '_'.join(parts[i:]))
                if found is not None:
                    return found
        return None

    @staticmethod
    def _compute_temporal_stats(trajectories: List[List[float]]) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Compute per-step mean and std over a list of trajectories.

        Non-sequence and empty entries are ignored. Shorter trajectories are padded
        by repeating their last value ('edge' mode) up to the longest length.

        Returns:
            ``(steps, mean, std)`` with ``steps = arange(max_len)``, or ``None`` when
            no usable trajectory remains.
        """
        valid = [
            np.asarray(t, dtype=float)
            for t in trajectories
            if isinstance(t, (list, tuple, np.ndarray)) and len(t) > 0
        ]
        if not valid:
            return None
        max_len = max(len(t) for t in valid)
        # 'edge' cannot pad an empty array, which is why empties were filtered above.
        traj_array = np.stack([np.pad(t, (0, max_len - len(t)), mode='edge') for t in valid])
        return np.arange(max_len), traj_array.mean(axis=0), traj_array.std(axis=0)

# --- Plotting Service ---

class PlottingService:
    """Handles the creation of all visualizations.

    ``config`` is a style dictionary with the same structure as ``STYLE_CONFIG``
    (keys 'distribution', 'sample_curve' and 'font' are read here).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    def create_analysis_plot(self, samples: List[Dict], metric_name: str, distributions: Dict, output_path: str, frame_mode: str, legend_cols: int):
        """Render the analysis figure and save it as PNG and PDF.

        Args:
            samples: Non-empty list of sample records carrying 'metrics', 'images',
                'prompt', 'color', 'group_key' and 'spec'. ``samples[0]`` decides
                whether the metric is drawn as a time series (list) or a scalar.
            metric_name: Metric to plot.
            distributions: Reference class distributions, as returned by
                ``DataLoader.get_class_distributions``.
            output_path: PNG destination. The PDF is written to the same path with
                its extension replaced by ``.pdf``.
            frame_mode: ``'first'`` crops thumbnails to their left quarter.
            legend_cols: Number of thumbnail columns in the legend panel.
        """
        # The thumbnail legend panel is only added when some sample has an image.
        has_images = any(s['images'] for s in samples)

        if has_images:
            fig = plt.figure(figsize=(20, 9))
            outer = GridSpec(1, 2, width_ratios=[2.2, 1], wspace=0.05)
            ax_metric = fig.add_subplot(outer[0])
            self._plot_legend_panel(fig, outer[1], samples, frame_mode, legend_cols)
        else:
            fig, ax_metric = plt.subplots(figsize=(12, 8))

        title = self._generate_main_title(samples)
        fig.suptitle(f"{title} | {metric_name}", **self.config['font']['title'])

        first_value = self._get_metric_from_sample(samples[0], metric_name)
        if isinstance(first_value, list):
            self._plot_temporal_metric(ax_metric, samples, metric_name, distributions)
        else:
            self._plot_scalar_metric(ax_metric, samples, metric_name, distributions)

        fig.tight_layout(rect=[0, 0.02, 1, 0.95])  # leave room for the suptitle and bottom margin
        # Swap only the extension: str.replace("png", "pdf") would also rewrite any
        # "png" inside directory names or the metric name.
        pdf_path = os.path.splitext(output_path)[0] + '.pdf'
        for path in (output_path, pdf_path):
            fig.savefig(path, dpi=250, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        print(f"Successfully saved visualization to: {output_path}")

    def _plot_legend_panel(self, fig: plt.Figure, gs_main: Any, samples: List[Dict], frame_mode: str, num_columns: int):
        """Render a thumbnail legend: one cell per sample, groups in separate row bands.

        Args:
            fig: Figure to draw into.
            gs_main: Subplot-spec cell reserved for the legend (``gs[1]`` in
                ``create_analysis_plot``); it is subdivided with ``subgridspec``.
            samples: Sample records sorted so equal 'group_key' values are adjacent
                (``itertools.groupby`` only merges consecutive items).
            frame_mode: ``'first'`` crops each image to its left quarter.
            num_columns: Cells per row; values below 1 are treated as 1.
        """
        num_columns = max(1, num_columns)  # 0 would divide by zero below

        # Each group starts on a fresh row, so groups never share a row.
        bands = []  # (first_row, members)
        total_rows = 0
        for _, group in groupby(samples, key=lambda s: s['group_key']):
            members = list(group)
            bands.append((total_rows, members))
            total_rows += (len(members) + num_columns - 1) // num_columns

        if total_rows == 0:
            return

        legend_gs = gs_main.subgridspec(total_rows, num_columns, wspace=0.1, hspace=0.0)
        caption_font = self.config['font']['legend']

        for first_row, members in bands:
            for i, sample in enumerate(members):
                row, col = divmod(i, num_columns)
                ax = fig.add_subplot(legend_gs[first_row + row, col])
                self._draw_thumbnail(ax, sample, frame_mode)

                prompt = sample['prompt']
                caption = prompt[:25] + '...' if len(prompt) > 25 else prompt
                ax.set_title(caption, color=sample['color'], y=-0.15, wrap=True, **caption_font)

    @staticmethod
    def _draw_thumbnail(ax: plt.Axes, sample: Dict, frame_mode: str) -> None:
        """Show the sample's first image inside a border drawn in the sample's colour.

        Cells without images are left blank; images that fail to load show 'Img Error'.
        """
        if not sample['images']:
            ax.axis('off')
            return

        try:
            # The context manager closes the file handle; imshow is called inside it
            # so the pixel data is consumed before the file is closed.
            with Image.open(sample['images'][0]) as img:
                # 'first' keeps the left quarter only (presumably the first of four tiled views).
                shown = img.crop((0, 0, img.width // 4, img.height)) if frame_mode == 'first' else img
                ax.imshow(shown)
        except Exception as exc:  # best-effort: one bad thumbnail must not abort the figure
            print(f"Warning: Could not render image {sample['images'][0]}: {exc}")
            ax.axis('off')
            ax.text(0.5, 0.5, 'Img Error', ha='center', va='center')
            return

        # Hide ticks but keep the frame: ax.axis('off') also stops the spines from being
        # drawn, so the coloured border would never appear.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color(sample['color'])
            spine.set_linewidth(3.5)

    def _plot_temporal_metric(self, ax: plt.Axes, samples: List[Dict], metric_name: str, dists: Dict):
        """Plot per-sample trajectories over the class mean +/- one std band.

        ``dists`` maps a class name ('memorized'/'unmemorized') to
        ``(steps, mean, std)`` or ``None``; ``None`` entries are skipped.
        """
        dist_style = self.config['distribution']
        longest = 0  # longest plotted series, used to place the x ticks

        for status, stats in dists.items():
            if stats is None:
                continue
            steps, mean_traj, std_traj = stats
            color = dist_style[f'{status}_base']
            ax.fill_between(steps, mean_traj - std_traj, mean_traj + std_traj,
                            color=color, alpha=dist_style['fill_alpha'])
            ax.plot(steps, mean_traj, color=color, alpha=dist_style['line_alpha'],
                    linestyle=dist_style['linestyle'], label=f'{status.capitalize()} Mean')
            longest = max(longest, len(mean_traj))

        curve_style = self.config['sample_curve']
        for sample in samples:
            values = self._get_metric_from_sample(sample, metric_name)
            if isinstance(values, list):
                ax.plot(np.arange(len(values)), values, color=sample['color'], **curve_style)
                longest = max(longest, len(values))

        if longest:
            # About ten evenly spaced ticks (0, 5, ..., 45 for 50-step trajectories).
            # A fixed 0..45 set would make set_xticks widen the view for short runs
            # and crowd the start of the axis for long ones.
            ax.set_xticks(np.arange(0, longest, max(1, longest // 10)))
        ax.set_xlabel("Timestep", **self.config['font']['label'])
        ax.grid(True, linestyle=':', alpha=0.6)
        ax.legend(loc='upper left', bbox_to_anchor=(0, 1.0), **self.config['font']['legend'])

    def _plot_scalar_metric(self, ax: plt.Axes, samples: List[Dict], metric_name: str, dists: Dict):
        """Plot class histograms plus a labelled vertical line per sample value.

        ``dists`` maps a class name to a 1-D array of reference values; ``None`` or
        empty entries are skipped. Samples whose metric is missing or not a number
        are not drawn.
        """
        ax.set_title("Scalar Distribution", **self.config['font']['subtitle'])

        for status, values in dists.items():
            if values is None:
                continue
            values = np.asarray(values)
            # Test size, not .any(): an all-zero distribution is still data to plot.
            if values.size == 0:
                continue
            color = self.config['distribution'][f'{status}_base']
            ax.hist(values, bins=50, color=color, alpha=0.5, density=True, label=f'{status.capitalize()} Dist.')

        y_max = ax.get_ylim()[1] or 1.0
        for sample in samples:
            metric_value = self._get_metric_from_sample(sample, metric_name)
            # Lists (trajectories) and missing values cannot be drawn as a vertical line.
            if not isinstance(metric_value, (int, float)):
                continue
            ax.axvline(x=metric_value, color=sample['color'], linestyle='-', linewidth=2.5, alpha=0.8)
            ax.text(metric_value, y_max * 0.9, f"{metric_value:.3f}", rotation=90, ha='right',
                    va='top', color=sample['color'], **self.config['font']['annotation'])

        ax.set_xlabel(metric_name, **self.config['font']['label'])
        ax.set_ylabel("Density", **self.config['font']['label'])
        ax.grid(True, linestyle=':', alpha=0.6)
        ax.legend(loc='upper left', bbox_to_anchor=(0, 1.0), **self.config['font']['legend'])

    @staticmethod
    def _get_metric_from_sample(sample: Dict, metric_name: str) -> Optional[Any]:
        """Look up ``metric_name`` in the sample's metrics (nested keys allowed)."""
        return DataLoader._find_metric_value(sample['metrics'], metric_name)

    @staticmethod
    def _generate_main_title(samples: List[Dict]) -> str:
        """Name the single method/dataset source, or fall back to "MVDream" for mixed sources."""
        unique_sources = {f"{s['spec']['method']}/{s['spec']['dataset']}" for s in samples}
        if len(unique_sources) == 1:
            return f"Analysis for {next(iter(unique_sources))}"
        return "MVDream"

# --- Main Orchestrator ---

class Visualizer:
    """Coordinates data loading and plotting."""

    def __init__(self, output_root: str, output_dir: str):
        """Create the output directory and wire up the loader and plotter.

        Args:
            output_root: Root directory of experiment results (given to DataLoader).
            output_dir: Directory that receives generated figures; created if missing.
        """
        self.output_root = output_root
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.loader = DataLoader(output_root)
        self.plotter = PlottingService(STYLE_CONFIG)

    def generate_visualization(self, sample_specs_str: List[str], metric_name: str, show_images: bool, frame_mode: str, legend_cols: int):
        """Load the requested samples, colour them and render one analysis figure.

        Args:
            sample_specs_str: Specs of the form ``method/dataset/sample_idx[/seed]``.
            metric_name: Metric key to plot (nested keys may be joined with '_').
            show_images: When False, image paths are dropped so the plotter draws its
                metrics-only layout.
            frame_mode: Passed through to the plotter ('first' or 'full').
            legend_cols: Column count for the thumbnail legend grid.

        Problems (no valid specs, no data, metric missing everywhere) are reported
        with print() and the method returns without plotting. Runs lacking the
        metric are skipped with a warning.
        """
        parsed_specs = self._parse_sample_specs(sample_specs_str)
        if not parsed_specs:
            print("Error: No valid sample specifications provided.")
            return

        all_data_points = []
        for spec in parsed_specs:
            all_data_points.extend(self.loader.load_sample_data(**spec))

        if not all_data_points:
            print("Error: Could not load data for any of the specified samples.")
            return

        # Keep only runs that carry the metric. The plotter decides temporal vs
        # scalar from samples[0] and cannot draw a missing value.
        with_metric = []
        for data_point in all_data_points:
            if self.plotter._get_metric_from_sample(data_point, metric_name) is None:
                spec = data_point['spec']
                print(f"Warning: Metric '{metric_name}' not found in "
                      f"{spec['method']}/{spec['dataset']}/{spec['sample_idx']} (seed {spec['seed']}); skipping.")
            else:
                with_metric.append(data_point)

        if not with_metric:
            print(f"Error: Metric '{metric_name}' not found in any loaded sample.")
            return

        samples = self._assign_colors_to_samples(with_metric)

        if not show_images:
            # --no_images: with no image paths the plotter skips the legend panel.
            for sample in samples:
                sample['images'] = []

        first_metric = self.plotter._get_metric_from_sample(samples[0], metric_name)
        is_temporal = isinstance(first_metric, list)
        distributions = self.loader.get_class_distributions(metric_name, is_temporal)

        safe_metric = metric_name.replace('/', '_').replace(' ', '_')
        filename = f"analysis_{safe_metric}_{len(samples)}samples.png"
        output_path = os.path.join(self.output_dir, filename)

        self.plotter.create_analysis_plot(samples, metric_name, distributions, output_path, frame_mode, legend_cols)

    @staticmethod
    def _get_grouping_key(s):
        """Group by memorization flag, then dataset; every 'objaverse' dataset shares one group."""
        dataset_name = s['spec']['dataset']
        label = 'Objaverse Samples' if 'objaverse' in dataset_name else dataset_name
        return (s['is_memorized'], label)

    def _assign_colors_to_samples(self, data_points: List[Dict]) -> List[Dict]:
        """Sort samples by group and give each a colour and its 'group_key'.

        Each group gets the next palette of its class (memorized/unmemorized,
        cycling through the plotter config's palettes), and members take that
        palette's shades in order, so samples in one group look alike.
        The input list is not reordered; the returned list holds the same dicts,
        mutated in place.
        """
        config = self.plotter.config  # same style config the plotter draws with
        palettes_by_class = {True: config['memorized_colors'], False: config['unmemorized_colors']}
        groups_seen = {True: 0, False: 0}

        colored_samples = []
        ordered = sorted(data_points, key=self._get_grouping_key)
        for group_key, members in groupby(ordered, key=self._get_grouping_key):
            is_memorized = bool(group_key[0])
            palettes = palettes_by_class[is_memorized]
            palette = palettes[groups_seen[is_memorized] % len(palettes)]
            groups_seen[is_memorized] += 1

            for position, data_point in enumerate(members):
                data_point['group_key'] = group_key
                data_point['color'] = palette[position % len(palette)]
                colored_samples.append(data_point)

        return colored_samples

    @staticmethod
    def _parse_sample_specs(specs_str: List[str]) -> List[Dict]:
        """Parse ``method/dataset/sample_idx[/seed]`` strings into loader kwargs.

        Surrounding whitespace and leading/trailing slashes are ignored. Specs that
        lack a component or have a non-integer index/seed are skipped with a warning.
        """
        parsed = []
        for spec in specs_str:
            parts = spec.strip().strip('/').split('/')
            try:
                parsed.append({
                    'method': parts[0],
                    'dataset': parts[1],
                    'sample_idx': int(parts[2]),
                    'seed': int(parts[3]) if len(parts) > 3 else None
                })
            except (IndexError, ValueError):
                print(f"Warning: Skipping invalid sample spec format: '{spec}'")
        return parsed

# --- Command-Line Interface ---

def main():
    """Parse command-line arguments and run the visualizer.

    Exits with status 2 (via argparse) on invalid arguments, including a
    non-positive ``--legend_cols``.
    """
    parser = argparse.ArgumentParser(
        description='Visualize experiment samples with enhanced analysis.',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog="""
Examples:
  # Analyze multiple samples from one dataset
  python visualizer.py --samples baseline/laion_memorized/1 baseline/laion_memorized/5 --metric Noise_Difference_Norm_noise_diff_norm_traj

  # Cross-dataset analysis (memorized vs. unmemorized)
  python visualizer.py --samples baseline/laion_memorized/10 baseline/laion_unmemorized/10 --metric Noise_Difference_Norm_noise_diff_norm_traj

  # Specify a particular seed
  python visualizer.py --samples baseline/laion_memorized/1/42 --metric Hessian_SAIL_diff_sum_t50
"""
    )

    parser.add_argument('--samples', type=str, nargs='+', required=True,
                        help='Sample specifications in format method/dataset/sample_idx[/seed]')
    parser.add_argument('--metric', type=str, required=True,
                        help='Metric name to visualize')
    parser.add_argument('--output_root', type=str, default='output',
                        help='Root directory of experimental results')
    parser.add_argument('--output_dir', type=str, default='visualizations',
                        help='Directory to save generated plots')
    parser.add_argument('--no_images', action='store_true',
                        help='Generate a metrics-only plot without images')
    parser.add_argument('--first_frame_only', action='store_true',
                        help='Show only the first 25%% of the image (first frame)')
    parser.add_argument('--legend_cols', type=int, default=3,
                        help='Number of columns for the legend panel grid')

    args = parser.parse_args()

    # The legend grid divides samples into rows of legend_cols cells; 0 would divide
    # by zero and a negative count produces a nonsensical layout.
    if args.legend_cols < 1:
        parser.error('--legend_cols must be a positive integer')

    visualizer = Visualizer(args.output_root, args.output_dir)
    visualizer.generate_visualization(
        sample_specs_str=args.samples,
        metric_name=args.metric,
        show_images=not args.no_images,
        frame_mode='first' if args.first_frame_only else 'full',
        legend_cols=args.legend_cols
    )

if __name__ == "__main__":
    main()

