#!/usr/bin/env python3
"""
Visualize toy action samples in a single PNG.
Rows: algorithms, Columns: hyperparameter variants.
"""

import argparse
import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def pick_algorithm_color(algorithm_name):
    """Return the hex colour used to draw an algorithm's samples.

    Names containing ``RAFMAC`` or ``RADAC`` (case-sensitive) are drawn in
    green; every other algorithm is drawn in red.
    """
    highlighted_tags = ('RAFMAC', 'RADAC')
    is_highlighted = any(tag in algorithm_name for tag in highlighted_tags)
    # '#2ca02c' is green, '#e41a1c' is red.
    return '#2ca02c' if is_highlighted else '#e41a1c'


def extract_hyperparameter_value(filename):
    """Identify the algorithm and swept hyperparameter encoded in a file name.

    Matching is case-insensitive and proceeds in this order:

    1. ``eta<number>`` combined with a ``radac``, ``fql`` or ``diff`` tag
       (checked in that priority) yields ``(<algorithm>, 'eta', value)``.
    2. ``lam<number>`` combined with ``oraac`` yields
       ``('ORAAC-Diffusion' | 'ORAAC-Flow' | 'ORAAC', 'lamda', value)``.
    3. A ``cvae`` tag yields ``('QL-CVAE', None, None)``.

    Returns:
        A ``(algorithm, param_name, param_value)`` tuple, or ``None`` when the
        file name matches none of the patterns above.
    """
    lowered = filename.lower()

    eta = re.search(r'eta([0-9]+\.?[0-9]*)', lowered)
    if eta is not None:
        # First matching tag wins; 'diff' also covers 'diffusion'.
        eta_tags = (('radac', 'RADAC'), ('fql', 'FQL'), ('diff', 'Diff-QL'))
        for tag, label in eta_tags:
            if tag in lowered:
                return (label, 'eta', float(eta.group(1)))

    lam = re.search(r'lam([0-9]+\.?[0-9]*)', lowered)
    if lam is not None and 'oraac' in lowered:
        if 'diffusion' in lowered:
            label = 'ORAAC-Diffusion'
        elif 'flow' in lowered:
            label = 'ORAAC-Flow'
        else:
            label = 'ORAAC'
        return (label, 'lamda', float(lam.group(1)))

    # 'cvae' also covers the 'ql_cvae' spelling; this variant has no swept
    # hyperparameter.
    return ('QL-CVAE', None, None) if 'cvae' in lowered else None


def load_action_files(directory):
    """Load toy action arrays from *directory*, grouped by algorithm.

    Every file matching ``*_actions_*.npy`` is classified with
    ``extract_hyperparameter_value``; unrecognised names are ignored. Files
    that fail to load, or whose array is not 2-D with exactly two columns,
    are skipped with a printed warning.

    Args:
        directory: Path (``str`` or ``Path``) of the folder to scan.

    Returns:
        A dict mapping algorithm name to a list of entries, each a dict with
        keys ``'param_name'``, ``'param_value'`` (``0.0`` when the file name
        carries no hyperparameter), ``'actions'`` and ``'filename'``. Each
        list is sorted by ``param_value``, ties broken by file name.

    Raises:
        ValueError: If *directory* does not exist.
    """
    directory = Path(directory)
    if not directory.exists():
        raise ValueError(f"Directory not found: {directory}")

    grouped = {}
    # glob() yields files in a filesystem-dependent order; sorting makes the
    # group insertion order and the emitted warnings reproducible.
    for npy_file in sorted(directory.glob('*_actions_*.npy')):
        info = extract_hyperparameter_value(npy_file.name)
        if info is None:
            continue

        algorithm, param_name, param_value = info
        try:
            actions = np.load(npy_file)
        except Exception as exc:
            # Deliberately best-effort: one unreadable file must not abort
            # the whole figure.
            print(f"Warning: failed to load {npy_file}: {exc}")
            continue

        if actions.ndim != 2 or actions.shape[1] != 2:
            print(f"Warning: {npy_file.name} has shape {actions.shape}, expected (N, 2)")
            continue

        grouped.setdefault(algorithm, []).append({
            'param_name': param_name,
            'param_value': param_value if param_value is not None else 0.0,
            'actions': actions,
            'filename': npy_file.name,
        })

    for entries in grouped.values():
        # The filename tie-break keeps variants that share a hyperparameter
        # value in a stable order.
        entries.sort(key=lambda entry: (entry['param_value'], entry['filename']))

    return grouped


def _format_param_value(value):
    if value is None:
        return ''
    return f'{float(value):.1f}'


def plot_combined_grid(grouped_data, output_path, max_samples=2000, algorithm_order=None):
    """Draw every action set as a scatter panel in one grid and save it.

    Each row is one algorithm and each column one of its variants, in the
    order they are stored in *grouped_data*. Rows with fewer variants than
    the widest row get blank panels on the right.

    Args:
        grouped_data: Dict mapping algorithm name to a list of entries; each
            entry is a dict providing ``'actions'`` (indexed as ``[:, 0]``
            and ``[:, 1]``), ``'param_name'`` and ``'param_value'``.
        output_path: Destination image path; missing parent directories are
            created.
        max_samples: Panels with more points than this are randomly
            subsampled (without replacement, fixed seed) down to this count.
        algorithm_order: Optional sequence of algorithm names giving the row
            order. Names absent from *grouped_data* are ignored; algorithms
            not listed are appended afterwards.

    Returns:
        None. Nothing is drawn or written when there is no data to plot.
    """
    if not grouped_data:
        return

    if algorithm_order:
        algorithms = [name for name in algorithm_order if name in grouped_data]
        algorithms.extend([name for name in grouped_data if name not in algorithms])
    else:
        algorithms = list(grouped_data)

    n_rows = len(algorithms)
    n_cols = max(len(variants) for variants in grouped_data.values())
    if n_cols == 0:
        # Every algorithm has an empty variant list: there is nothing to
        # draw, and a zero-column subplot grid is not valid.
        return

    # squeeze=False always returns a 2-D array of axes, so single-row and
    # single-column grids need no special-casing.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3.2, n_rows * 3.2),
                            squeeze=False)
    try:
        # Fixed seed keeps the subsampled points identical between runs.
        rng = np.random.default_rng(0)

        for row, algorithm_name in enumerate(algorithms):
            data_list = grouped_data[algorithm_name]
            color = pick_algorithm_color(algorithm_name)
            # The first variant's parameter name labels the whole row.
            param_name = data_list[0]['param_name'] if data_list else None

            for col in range(n_cols):
                ax = axs[row, col]
                if col >= len(data_list):
                    # No variant for this cell: leave it blank.
                    ax.axis('off')
                    continue

                data = data_list[col]
                actions = data['actions']
                n_samples = len(actions)
                if n_samples > max_samples:
                    indices = rng.choice(n_samples, size=max_samples, replace=False)
                    actions_plot = actions[indices]
                else:
                    actions_plot = actions

                ax.scatter(actions_plot[:, 0], actions_plot[:, 1],
                           c=color, s=10, alpha=0.5, edgecolors='none')
                ax.set_xlim(-1.1, 1.1)
                ax.set_ylim(-1.1, 1.1)
                ax.set_aspect('equal')
                ax.set_xticks([])
                ax.set_yticks([])
                ax.grid(alpha=0.2, linewidth=0.5)

                if param_name is not None:
                    ax.set_title(f'{param_name}={_format_param_value(data["param_value"])}', fontsize=8)

            # Row label, placed to the left of the first panel in axes
            # coordinates.
            axs[row, 0].text(-0.35, 0.5, algorithm_name, transform=axs[row, 0].transAxes,
                             ha='right', va='center', fontsize=9, fontweight='bold')

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.subplots_adjust(left=0.18, wspace=0.25, hspace=0.3)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved combined plot: {output_path}")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main():
    """Parse command-line options, load the action files and render the grid.

    Prints a notice and returns without plotting when the input directory
    yields no usable action files.
    """
    cli = argparse.ArgumentParser(description='Visualize toy action samples')
    option_specs = (
        ('--input_dir', {
            'type': str,
            'default': 'frozen_logs/toys/npy',
            'help': 'Directory containing toy .npy action files',
        }),
        ('--output', {
            'type': str,
            'default': 'frozen_logs/toys/toy_algorithms.png',
            'help': 'Output PNG path',
        }),
        ('--max_samples', {
            'type': int,
            'default': 2000,
            'help': 'Max samples per subplot',
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    grouped_actions = load_action_files(options.input_dir)
    if grouped_actions:
        # Top-to-bottom row order of the figure.
        row_order = [
            'ORAAC',
            'ORAAC-Diffusion',
            'ORAAC-Flow',
            'Diff-QL',
            'FQL',
            'QL-CVAE',
            'RADAC',
        ]
        plot_combined_grid(
            grouped_actions,
            options.output,
            max_samples=options.max_samples,
            algorithm_order=row_order,
        )
    else:
        print("No valid action files found.")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
