"""
Visualization tools
Provides functionality for plotting learning curves and phase diagrams
"""

import math
import os
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # matplotlib is optional: every plotting function checks for None and
    # skips plotting instead of failing.
    plt = None
else:
    # Apply a global (process-wide) plot theme. The style is called
    # 'seaborn-v0_8-whitegrid' in matplotlib >= 3.6 and 'seaborn-whitegrid'
    # in older releases. An unknown style name makes style.use raise OSError,
    # and a missing cosmetic theme must not make this module fail to import.
    for _style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
        try:
            plt.style.use(_style_name)
            break
        except OSError:
            continue


def plot_learning_curves(
    results: Dict[Tuple[int, float], Dict[str, float]],
    n_list: List[int],
    delta_list: List[float],
    methods: List[str],
    title_prefix: str = "",
    save_path: Optional[str] = None,
):
    """Plot learning curves (MSE vs sample size), one figure per delta.

    For each value in ``delta_list``, draws a log-log figure with one line per
    entry in ``methods``. The line reads ``results[(n, delta)][method]`` for
    every ``n`` in ``n_list``.

    Args:
        results: Mapping ``(n, delta) -> {method name: mean squared error}``.
        n_list: Sample sizes plotted along the x axis, in the given order.
        delta_list: Black-box bias values; one figure is produced per value.
        methods: Method names to draw; each must be a key of the inner dicts.
        title_prefix: Optional text prepended to every figure title.
        save_path: Path prefix for the output files; each figure is written
            to ``f"{save_path}_delta_{delta}.png"`` and missing parent
            directories are created. When falsy, the figures are built and
            closed without being saved or shown.

    Raises:
        KeyError: If a required ``(n, delta)`` pair or method name is missing
            from ``results``.
    """
    if plt is None:
        print("matplotlib not available, skipping plots.")
        return

    if save_path:
        # A bare prefix such as "curves" has dirname "", and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when there is one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Avoid a leading space in the title when no prefix is given.
    title_start = f"{title_prefix} " if title_prefix else ""

    for delta in delta_list:
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            for m in methods:
                ys = [results[(n, delta)][m] for n in n_list]
                ax.loglog(n_list, ys, marker="o", label=m)
            ax.set_xlabel("Sample size n (log scale)")
            ax.set_ylabel("Mean squared error")
            ax.set_title(f"{title_start}Learning curves (δ = {delta})")
            ax.legend()
            ax.grid(True, which="both", ls="--", alpha=0.5)

            if save_path:
                fig.savefig(
                    f"{save_path}_delta_{delta}.png", dpi=300, bbox_inches="tight"
                )
        finally:
            # Always release the figure, even when a lookup or savefig fails,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)


def plot_phase_diagram(
    results: Dict[Tuple[int, float], Dict[str, float]],
    n_list: List[int],
    delta_list: List[float],
    method: str,
    title: str = "",
    save_path: Optional[str] = None,
):
    """Plot a phase diagram of one method's risk over (black-box bias, n).

    Draws a heatmap whose cell ``(i, j)`` shows ``log10(MSE + 1e-10)`` for
    ``n_list[i]`` (rows, bottom to top) and ``delta_list[j]`` (columns, left
    to right). Each cell is one unit wide and its tick is labelled with the
    actual grid value. Labels therefore stay aligned with cells whatever the
    spacing or ordering of the grids, including single-element grids and
    delta = 0.

    Args:
        results: Mapping ``(n, delta) -> {method name: mean squared error}``.
        n_list: Sample sizes, one heatmap row each.
        delta_list: Black-box bias values, one heatmap column each.
        method: Method whose error is plotted.
        title: Figure title; defaults to ``"Phase diagram - {method}"``.
        save_path: Path prefix; the figure is written to
            ``f"{save_path}_{method}.png"`` and missing parent directories
            are created. When falsy, the figure is built and closed without
            being saved or shown.

    Raises:
        ValueError: If ``n_list`` or ``delta_list`` is empty.
        KeyError: If a required ``(n, delta)`` pair or ``method`` is missing
            from ``results``.
    """
    if plt is None:
        print("matplotlib not available, skipping plots.")
        return

    if not n_list or not delta_list:
        raise ValueError("n_list and delta_list must both be non-empty")

    # Rows follow n_list, columns follow delta_list.
    risk_mat = np.array(
        [[results[(n, delta)][method] for delta in delta_list] for n in n_list],
        dtype=float,
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        im = ax.imshow(
            np.log10(risk_mat + 1e-10),  # Offset avoids log10(0) = -inf
            origin="lower",
            aspect="auto",
        )
        fig.colorbar(im, ax=ax, label="log10 Mean squared error")
        ax.set_xlabel("δ (black-box bias)")
        ax.set_ylabel("n (sample size)")
        ax.set_title(title or f"Phase diagram - {method}")

        # One major tick per cell centre, labelled with the real grid value.
        # ':g' keeps small biases such as 0.001 distinguishable.
        ax.set_xticks(np.arange(len(delta_list)))
        ax.set_xticklabels([f"{d:g}" for d in delta_list])
        ax.set_yticks(np.arange(len(n_list)))
        ax.set_yticklabels([str(n) for n in n_list])

        # Draw grid lines on cell boundaries (minor ticks) instead of
        # through the cell centres.
        ax.set_xticks(np.arange(len(delta_list) + 1) - 0.5, minor=True)
        ax.set_yticks(np.arange(len(n_list) + 1) - 0.5, minor=True)
        ax.tick_params(which="minor", length=0)
        ax.grid(False, which="major")
        ax.grid(True, which="minor", linestyle="--", linewidth=0.5)

        if save_path:
            # dirname is "" for a bare prefix; os.makedirs("") would raise.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{save_path}_{method}.png", dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def plot_method_comparison(
    results: Dict[Tuple[int, float], Dict[str, float]],
    n_list: List[int],
    delta_list: List[float],
    save_path: Optional[str] = None,
):
    """Plot all methods side by side, one stacked subplot per delta.

    Each subplot uses a linear n axis and a logarithmic MSE axis (semilogy).
    It draws one line per method, reading ``results[(n, delta)][method]``
    for every ``n`` in ``n_list``.

    Args:
        results: Mapping ``(n, delta) -> {method name: mean squared error}``.
            The method names are taken from the first entry; every entry
            used for plotting must contain those same names.
        n_list: Sample sizes plotted along the x axis, in the given order.
        delta_list: Black-box bias values; one subplot row per value.
        save_path: Path prefix; the figure is written to
            ``f"{save_path}_comparison.png"`` and missing parent directories
            are created. When falsy, the figure is built and closed without
            being saved or shown.

    Raises:
        ValueError: If ``results`` or ``delta_list`` is empty.
        KeyError: If a required ``(n, delta)`` pair or method is missing.
    """
    if plt is None:
        print("matplotlib not available, skipping plots.")
        return

    if not delta_list:
        raise ValueError("delta_list must be non-empty")
    if not results:
        raise ValueError("results must be non-empty")

    # Method names come from the first entry of ``results``.
    methods = list(next(iter(results.values())).keys())

    # squeeze=False always returns a 2-D array of Axes, so a single delta
    # needs no special case.
    fig, axes = plt.subplots(
        len(delta_list), 1, figsize=(10, 6 * len(delta_list)), squeeze=False
    )
    try:
        for ax, delta in zip(axes[:, 0], delta_list):
            for method in methods:
                ys = [results[(n, delta)][method] for n in n_list]
                ax.semilogy(n_list, ys, marker="o", label=method)

            ax.set_xlabel("Sample size n")
            ax.set_ylabel("Mean squared error (log scale)")
            ax.set_title(f"Method comparison (δ = {delta})")
            ax.legend()
            ax.grid(True, which="both", ls="--", alpha=0.5)

        fig.tight_layout()

        if save_path:
            # dirname is "" for a bare prefix; os.makedirs("") would raise.
            out_dir = os.path.dirname(save_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            fig.savefig(f"{save_path}_comparison.png", dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)