import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import argparse
import copy

# Import from shared utility
from plot_config_utils import (
    BASE_COLORS,
    BASE_DEFAULT_PLOT_STYLE_CONFIG,
    _apply_font_settings,
    deep_update_style_config
)

# nc_plot-specific style defaults: an independent (deep) copy of the shared base
# config, so tweaks made here never leak back into plot_config_utils.
NC_PLOT_DEFAULT_STYLE_CONFIG = copy.deepcopy(BASE_DEFAULT_PLOT_STYLE_CONFIG)
# (width, height) of a single subplot; fall back to (12, 8) when the base
# config does not provide one.
NC_PLOT_DEFAULT_STYLE_CONFIG.setdefault('figure_size_per_subplot', (12, 8))

# Plot text and colour choices that are specific to nc_plot.py.
PLOT_CONTENT_CONFIG = dict(
    # Subplot title and y-axis label, keyed by metric name.
    value_labels=dict(
        nc1=dict(title='Intra-class Variability Collapse', ylabel='NC1'),
        nc2=dict(title='Class Means Separation', ylabel='NC2'),
        nc3=dict(title='Self-duality', ylabel='NC3'),
        test_acc=dict(title='Test Accuracy', ylabel='Accuracy (%)'),
    ),
    # Legend label and line colour, keyed by the lower-cased normalization
    # name read from a run's summary file.
    normalization_style=dict(
        none=dict(label='None', color=BASE_COLORS['blue']),
        standard=dict(label='LayerNorm', color=BASE_COLORS['green']),
        rms=dict(label='RMSNorm', color=BASE_COLORS['orange']),
    ),
)


# ---

def _read_norm_type(summary_path):
    """Return the lower-cased value of the first 'Layer Normalization:' line, or 'unknown'."""
    # errors='replace': only one ASCII-prefixed line is needed, so odd bytes
    # elsewhere in the summary must not abort plotting.
    with open(summary_path, encoding='utf-8', errors='replace') as f:
        for line_content in f:
            if line_content.startswith('Layer Normalization:'):
                return line_content.split(':', 1)[1].strip().lower()
    return 'unknown'


def _apply_metric_ylim(ax, network, dataset, value_name):
    """Set hand-tuned y-limits for known network/dataset/metric combinations.

    Combinations not listed here keep matplotlib's automatic y-limits.
    """
    dataset_l = dataset.lower()
    network_l = network.lower()
    if value_name == 'test_acc':
        if dataset_l == 'mnist':
            ax.set_ylim(99, 99.6)
        elif dataset_l == 'cifar10':
            if network_l == 'resnet18':
                ax.set_ylim(75, 83.1)
            elif network_l == 'resnet50':
                ax.set_ylim(55, 82.1)
        elif 'imagenet' in dataset_l:
            ax.set_ylim(50, 85.1)
    elif dataset_l == 'cifar10':
        if network_l == 'resnet18':
            ax.set_ylim(0, 1.4)
        elif network_l == 'resnet50':
            ax.set_ylim(0, 1.5)


def _load_runs(subdirs, norm_style_dict):
    """Load each run's metrics CSV and its legend label/colour, once per run.

    Subdirectories missing either required file are skipped with a warning.
    Returns a list of ``(csv_path, DataFrame, label, color)`` tuples in the
    order of ``subdirs``.
    """
    runs = []
    for sub in subdirs:
        csv_path = os.path.join(sub, 'metrics_aggregated.csv')
        sum_path = os.path.join(sub, 'summary_experiment.txt')
        if not (os.path.exists(csv_path) and os.path.exists(sum_path)):
            print(f"Warning: Missing files in {sub}, skipping.")
            continue
        norm_type = _read_norm_type(sum_path)
        # Unknown normalization names get a capitalized label and a fallback colour.
        norm_details = norm_style_dict.get(norm_type,
                                           {'label': norm_type.capitalize(),
                                            'color': BASE_COLORS['unknown_norm']})
        runs.append((csv_path, pd.read_csv(csv_path),
                     norm_details['label'], norm_details['color']))
    return runs


def plot_experiment_results(network, dataset, directory, value_name_list,
                            style_config_overrides=None, content_config_overrides=None,
                            save_figure=False, pdf=False, legend_subplot_index=3):
    """
    Plot aggregated metrics for experiments and optionally save the figure.

    Every subdirectory of ``directory`` whose name starts with
    ``f"{network}_{dataset}"`` is one run. A run must contain
    ``metrics_aggregated.csv`` (``<metric>_mean`` / ``<metric>_std`` columns,
    optionally an ``epoch`` column) and ``summary_experiment.txt`` (whose
    ``Layer Normalization:`` line picks the line label and colour). Each metric
    in ``value_name_list`` gets one subplot showing mean +/- std per run.
    Runs are drawn in sorted directory order, so output is deterministic.

    Args:
        network: Network name; used in the run-directory prefix, the output
            filename and the hard-coded y-limits.
        dataset: Dataset name; same uses as ``network``.
        directory: Root directory holding the run subdirectories; saved
            figures are written here.
        value_name_list: Metric names to plot, one subplot each.
        style_config_overrides: Nested dict merged onto the default style
            config via ``deep_update_style_config``.
        content_config_overrides: Nested dict merged onto
            ``PLOT_CONTENT_CONFIG`` the same way.
        save_figure: If True, save the figure and close it; otherwise call
            ``plt.show()``.
        pdf: Save as PDF instead of PNG (only used when ``save_figure``).
        legend_subplot_index: Subplot that receives the normalization legend.
            Clamped into the available subplots (so with fewer than 4 metrics
            the default lands on the last one); ``None`` disables the legend.

    Returns:
        Tuple ``(fig, axes, save_path)``; ``save_path`` is None unless saved.

    Raises:
        ValueError: If ``value_name_list`` is empty.
        FileNotFoundError: If no matching run subdirectory exists, or
            ``directory`` itself does not exist.
    """
    n_vals = len(value_name_list)
    if n_vals < 1:
        raise ValueError("Number of metrics must be at least 1.")

    s_cfg = copy.deepcopy(NC_PLOT_DEFAULT_STYLE_CONFIG)  # nc_plot's own defaults
    if style_config_overrides:
        deep_update_style_config(s_cfg, style_config_overrides)

    c_cfg = copy.deepcopy(PLOT_CONTENT_CONFIG)
    if content_config_overrides:
        deep_update_style_config(c_cfg, content_config_overrides)  # reuses the style helper

    _apply_font_settings(s_cfg['font_settings'])

    prefix = f"{network}_{dataset}"
    # os.listdir order is arbitrary; sort so line/legend order is reproducible.
    subdirs = sorted(
        os.path.join(directory, d) for d in os.listdir(directory)
        if d.startswith(prefix) and os.path.isdir(os.path.join(directory, d))
    )
    if not subdirs:
        raise FileNotFoundError(f"No subdirs starting with '{prefix}' in '{directory}'")

    runs = _load_runs(subdirs, c_cfg['normalization_style'])

    if legend_subplot_index is None:
        legend_idx = None
    else:
        legend_idx = max(0, min(legend_subplot_index, n_vals - 1))

    fig_width_per, fig_height = s_cfg['figure_size_per_subplot']
    fig, axes = plt.subplots(1, n_vals, figsize=(fig_width_per * n_vals, fig_height), squeeze=False)
    axes = axes[0]

    for ax_idx, (ax, value_name) in enumerate(zip(axes, value_name_list)):
        plot_details = c_cfg['value_labels'].get(value_name, {})
        pretty_name = value_name.replace('_', ' ').capitalize()
        subplot_title = plot_details.get('title', f"{pretty_name} vs Epochs")
        y_label_text = plot_details.get('ylabel', pretty_name)

        ax.set_title(subplot_title, fontsize=s_cfg['title_fontsize'], fontweight='bold')
        ax.set_xlabel('Epoch', fontsize=s_cfg['axis_label_fontsize'])
        if y_label_text:
            ax.set_ylabel(y_label_text, fontsize=s_cfg['axis_label_fontsize'])

        _apply_metric_ylim(ax, network, dataset, value_name)

        mean_col, std_col = f"{value_name}_mean", f"{value_name}_std"
        legend_entries = {}  # label -> first line drawn with it (insertion-ordered)
        for csv_path, df, line_label, line_color in runs:
            if mean_col not in df or std_col not in df:
                print(f"Warning: Metric {value_name} not found in {csv_path}, skipping.")
                continue

            epochs = df['epoch'].to_numpy() if 'epoch' in df else np.arange(len(df))
            means, stds = df[mean_col].to_numpy(), df[std_col].to_numpy()

            line_plot, = ax.plot(epochs, means, label=line_label, color=line_color,
                                 linewidth=s_cfg['plot_linewidth'])
            ax.fill_between(epochs, means - stds, means + stds,
                            alpha=s_cfg['fill_alpha'], color=line_color)
            # Several runs may share a normalization; keep one legend entry each.
            legend_entries.setdefault(line_label, line_plot)

        if ax_idx == legend_idx:
            if legend_entries:
                legend = ax.legend(
                    handles=list(legend_entries.values()),
                    labels=list(legend_entries),
                    title='Normalization',
                    fontsize=s_cfg['legend_fontsize'],
                    title_fontsize=s_cfg['legend_title_fontsize'],
                    frameon=s_cfg['legend_frameon'],
                    fancybox=s_cfg['legend_fancybox'],
                    shadow=s_cfg['legend_shadow'],
                    facecolor=s_cfg['legend_facecolor'],
                    edgecolor=s_cfg['legend_edgecolor'],
                    framealpha=0.9,
                    loc='lower right'
                )
                legend.get_frame().set_linewidth(s_cfg['legend_frame_linewidth'])
            else:
                print(f"Warning: No data to plot for legend on subplot {ax_idx} "
                      f"for value '{value_name}'.")

        ax.grid(True, linestyle=s_cfg['grid_linestyle'],
                linewidth=s_cfg['grid_linewidth'],
                alpha=s_cfg['grid_alpha'],
                color=BASE_COLORS['light_gray'])

        ax.tick_params(axis='both', which='major',
                       labelsize=s_cfg['tick_label_fontsize'],
                       width=s_cfg['axes_linewidth'])

        for spine_pos in ('top', 'right', 'bottom', 'left'):
            ax.spines[spine_pos].set_linewidth(s_cfg['axes_linewidth'])
            ax.spines[spine_pos].set_color(BASE_COLORS['gray'])

    plt.tight_layout(pad=s_cfg['tight_layout_pad'])

    save_path = None
    if save_figure:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        value_str = "-".join(value_name_list)
        extension = 'pdf' if pdf else 'png'
        filename = f"{network}_{dataset}_{value_str}_{timestamp}.{extension}"
        save_path = os.path.join(directory, filename)
        fig.savefig(save_path, dpi=s_cfg['save_dpi'],
                    bbox_inches='tight', facecolor=s_cfg['save_facecolor'])
        print(f"Saved figure to: {save_path}")
        plt.close(fig)
    else:
        plt.show()

    return fig, axes, save_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Plot experimental results for neural nets.')
    parser.add_argument('--network', type=str, required=True, help='Network name (e.g., resnet18)')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name (e.g., cifar10)')
    parser.add_argument('--directory', type=str, required=True, help='Root experiment directory')
    parser.add_argument('--values', type=str, nargs='+', required=True,
                        help='Metric names to plot (e.g., test_acc nc1)')
    parser.add_argument('--save', action='store_true', help='Save figure instead of showing interactively')
    parser.add_argument('--font', type=str, default='sans-serif', choices=['sans-serif', 'serif', 'times new roman'],
                        help="Choose font style: 'sans-serif' (default), 'serif', or 'times new roman'.")
    parser.add_argument('--pdf', action='store_true', help="save figure as pdf")

    args = parser.parse_args()

    # CLI-specific style tweaks layered on top of NC_PLOT_DEFAULT_STYLE_CONFIG.
    custom_style_overrides = {
        'title_fontsize': 70,  # User value
    }

    if args.font in ('times new roman', 'serif'):
        # Both serif choices put Times New Roman first and keep the configured
        # serif fonts (minus any duplicate) as fallbacks.
        default_serif = NC_PLOT_DEFAULT_STYLE_CONFIG['font_settings'].get('serif', [])
        custom_style_overrides['font_settings'] = {
            'family': 'serif',
            'serif': ['Times New Roman'] + [f for f in default_serif if f != 'Times New Roman'],
        }
    else:  # 'sans-serif'; argparse `choices` rules out anything else
        custom_style_overrides['font_settings'] = {'family': 'sans-serif'}

    print("--- Running nc_plot.py with Plot Style Overrides ---")
    # Rebuild the merged config here purely to report which fonts will be used.
    effective_style_config_for_print = copy.deepcopy(NC_PLOT_DEFAULT_STYLE_CONFIG)
    deep_update_style_config(effective_style_config_for_print, custom_style_overrides)

    effective_font = effective_style_config_for_print['font_settings']
    print(f"Effective font family for plotting: {effective_font['family']}")
    if effective_font['family'] == 'serif':
        print(f"Preferred serif fonts: {effective_font['serif']}")

    try:
        plot_experiment_results(
            args.network,
            args.dataset,
            args.directory,
            args.values,
            style_config_overrides=custom_style_overrides,
            save_figure=args.save,
            pdf=args.pdf,
        )
    except FileNotFoundError as err:
        # Missing/empty experiment directory: report as a CLI usage error
        # (message + exit status 2) rather than a traceback.
        parser.error(str(err))