#!/usr/bin/env python
# plot.py
# --------------------------------------------------------------------

import pandas as pd
import numpy as np
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib as mpl
from pathlib import Path
from scipy.ndimage import gaussian_filter1d
from PIL import Image
import argparse

# Shared matplotlib style for every figure this script writes, set one
# rcParams key at a time.

# Typography
mpl.rcParams['font.size'] = 12
mpl.rcParams['axes.labelsize'] = 13
mpl.rcParams['axes.titlesize'] = 14
mpl.rcParams['xtick.labelsize'] = 11
mpl.rcParams['ytick.labelsize'] = 11
mpl.rcParams['legend.fontsize'] = 12

# Lines and grid
mpl.rcParams['lines.linewidth'] = 1.8
mpl.rcParams['grid.linewidth'] = 0.8
mpl.rcParams['axes.grid'] = True
mpl.rcParams['grid.alpha'] = 0.3
mpl.rcParams['grid.linestyle'] = ':'
mpl.rcParams['axes.axisbelow'] = True  # grid drawn beneath the data

# Axes / figure appearance
mpl.rcParams['axes.edgecolor'] = '.15'
mpl.rcParams['axes.facecolor'] = '#FDFDFD'
mpl.rcParams['figure.figsize'] = (3.5, 3.5)  # Square subplots
mpl.rcParams['figure.dpi'] = 100

# Dataset directory name -> short code used in the combined-figure PDF name.
DATASET_ABBREV = dict([
    ("car", "CA"),
    ("ERA", "ER"),
    ("LEV", "LE"),
    ("SWD", "SW"),
    ("winequality-red", "WR"),
])

# Metric key -> (train column, held-out "te" column plotted as validation,
# y-axis label, plot with log-space statistics on a log-scaled y-axis?).
# Insertion order is the order in which metrics are processed.
METRIC_COLUMNS = {
    key: (train_col, test_col, label, log_scale)
    for key, train_col, test_col, label, log_scale in [
        ("NLL", "tr_nll", "te_nll", r"$\mathcal{L}_{\mathrm{NLL}}$", True),
        ("ACC", "tr_acc", "te_acc", "Accuracy", False),
        ("MAE", "tr_mae", "te_mae", "MAE", True),
        ("QWK", "tr_qwk", "te_qwk", "QWK", False),
        ("Within1", "tr_within1", "te_within1", "Within-1 Acc", False),
        ("MinSens", "tr_min_sens", "te_min_sens", "Minimum Sensitivity", False),
        ("ONC1", "onc1_tr", "onc1_te", r"$\mathrm{ONC}_1$", True),
        ("ONC2-1", "onc21_tr", "onc21_te", r"$\mathrm{ONC}_{2-1}$", False),
        ("ONC2-2", "onc22_tr", "onc22_te", r"$\mathrm{ONC}_{2-2}$", False),
        ("ONC3", "onc3_tr", "onc3_te", r"$\mathrm{ONC}_3$", True),
    ]
}

# Metrics drawn on a fixed [0, 1] y-range (with a small margin) instead of
# an autoscaled one.
METRICS_01_RANGE = {
    "ACC", "QWK", "Within1",
    "MinSens", "ONC2-1", "ONC2-2",
}

# Row-major layout of the combined figure: 2 rows x 5 columns.
METRICS_GRID = [
    ["NLL", "ACC", "MAE", "QWK", "Within1"],
    ["MinSens", "ONC1", "ONC2-1", "ONC2-2", "ONC3"],
]

# One colour per threshold mode; train and validation curves share the
# colour and are distinguished by line style (solid vs dashed).
METHOD_COLORS = {
    mode: {"color": hex_code}
    for mode, hex_code in (("fix", "#1f77b4"), ("learn", "#d62728"))
}

def _mean_band(values, use_log, eps=1e-10):
    """Summarise a stacked (folds, epochs) array as per-epoch centre/lower/upper curves.

    Linear mode: NaN-aware mean across folds, band = mean -/+ one standard
    deviation. Log mode: the same statistics are computed on
    ``log(values + eps)`` and mapped back with ``exp``. This gives a geometric
    mean with a multiplicative one-sigma band that stays positive on a log axis.

    Returns:
        ``(centre, lower, upper)``, each a 1-D array with one entry per epoch.
    """
    if use_log:
        logged = np.log(values + eps)
        centre = np.nanmean(logged, axis=0)
        spread = np.nanstd(logged, axis=0)
        return np.exp(centre), np.exp(centre - spread), np.exp(centre + spread)
    centre = np.nanmean(values, axis=0)
    spread = np.nanstd(values, axis=0)
    return centre, centre - spread, centre + spread

def plot_metrics_comparison(df_fix_list, df_learn_list, tr_col, te_col, metric, title, use_log, out_path, smooth_sigma=2):
    """Plot train/validation curves of one metric for fixed vs. learnable thresholds.

    For each threshold mode the per-fold histories are stacked, summarised by
    ``_mean_band`` and Gaussian-smoothed along the epoch axis. Each curve is
    drawn as a line with a shaded band: solid for the train column, dashed for
    the validation column, one colour per mode.

    Args:
        df_fix_list, df_learn_list: Per-fold DataFrames containing *tr_col* and
            *te_col*. All folds must have the same number of rows (``np.stack``
            raises otherwise). The x axis is the ``epoch`` column of the first
            fixed-threshold fold.
        tr_col, te_col: Train and validation column names.
        metric: Metric key; keys in ``METRICS_01_RANGE`` get a fixed 0-1 y-range.
        title: Y-axis label.
        use_log: Compute statistics in log space and use a log-scaled y-axis.
        out_path: Destination image path (``str`` or ``Path``); missing parent
            directories are created.
        smooth_sigma: Gaussian smoothing width in epochs; curves with at most
            this many points are left unsmoothed.
    """
    out_path = Path(out_path)
    epochs = df_fix_list[0]["epoch"].values

    def smooth(data, sigma=smooth_sigma):
        return gaussian_filter1d(data, sigma=sigma) if len(data) > sigma else data

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    try:
        lowest_mean = np.inf
        # Fixed train/val first, then learnable train/val.
        for method, dfs in (("fix", df_fix_list), ("learn", df_learn_list)):
            color = METHOD_COLORS[method]["color"]
            for col, linestyle in ((tr_col, "-"), (te_col, "--")):
                stacked = np.stack([df[col].values for df in dfs], axis=0)
                centre, lower, upper = (smooth(curve) for curve in _mean_band(stacked, use_log))
                ax.plot(epochs, centre, color=color, linewidth=2.5, linestyle=linestyle)
                ax.fill_between(epochs, lower, upper, color=color, alpha=0.15)
                finite = centre[np.isfinite(centre)]
                if finite.size:
                    lowest_mean = min(lowest_mean, float(finite.min()))

        if use_log:
            ax.set_yscale('log')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.set_box_aspect(1)  # Ensure plot area is square

        if metric in METRICS_01_RANGE:
            # Fixed [0, 1] range with a small margin. The lower limit is
            # extended only if a mean curve would otherwise be clipped; QWK is
            # a kappa statistic and can be negative.
            ax.set_ylim(min(-0.05, lowest_mean - 0.05), 1.05)

        ax.grid(True, linestyle=':', alpha=0.6)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def create_legend_image(out_dir):
    """Render the shared four-entry legend as a standalone square PNG.

    The figure is 3.5 x 3.5 inches, the same size used for the per-metric plots,
    so the legend can occupy a grid cell of matching size.

    Args:
        out_dir: Directory (``str`` or ``Path``) to write ``legend.png`` into;
            it is created if it does not exist.

    Returns:
        The ``Path`` of the written legend image.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "legend.png"

    # (threshold mode, legend label, line style): solid = train, dashed = val.
    entries = (
        ("fix", "Fixed Train", "-"),
        ("fix", "Fixed Val", "--"),
        ("learn", "Learnable Train", "-"),
        ("learn", "Learnable Val", "--"),
    )

    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    try:
        # Empty plots only produce line handles for the legend.
        handles = [
            ax.plot([], [], color=METHOD_COLORS[method]["color"],
                    label=label, linewidth=2.5, linestyle=style)[0]
            for method, label, style in entries
        ]

        # Legend in the center - 4 rows, 1 column
        ax.legend(handles=handles, loc='center', ncol=1, frameon=True, framealpha=1.0,
                  fontsize=12, handlelength=2.5, borderpad=1.2)

        ax.set_facecolor('white')
        ax.set_box_aspect(1)  # Ensure legend box is square
        ax.axis('off')

        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        plt.close(fig)

    print(f"  Created legend image: {out_path}")
    return out_path

def find_history_files(ds_dir, link_function):
    """Locate the per-fold history CSVs of both threshold modes in *ds_dir*.

    Files tagged with *link_function* are preferred. Each mode independently
    falls back to the untagged ``fold*_history_{fix,learn}.csv`` names when no
    tagged file is present.

    Returns:
        ``(fix_files, learn_files)``: two lexicographically sorted lists of paths.
    """
    fix_files = (sorted(ds_dir.glob(f"fold*_history_{link_function}_fix.csv"))
                 or sorted(ds_dir.glob("fold*_history_fix.csv")))
    learn_files = (sorted(ds_dir.glob(f"fold*_history_{link_function}_learn.csv"))
                   or sorted(ds_dir.glob("fold*_history_learn.csv")))
    return fix_files, learn_files
    
def process_dataset(dataset, results_dir, vis_dir, link_function="logit"):
    """Aggregate, plot and tile every metric of one dataset for one link function.

    Reads the per-fold fixed/learnable history CSVs returned by
    ``find_history_files`` and writes one comparison plot per metric in
    ``METRIC_COLUMNS`` to ``vis_dir/<dataset>/<link_function>/``. It also
    writes the per-epoch NaN-aware mean/std of every metric to
    ``results_dir/<dataset>/agg_history_<link_function>_combined.csv``, then
    calls ``combine_plots``.

    A missing dataset directory or missing history files are reported and the
    dataset is skipped (returns None). Metrics whose columns are absent from
    any fold are reported and skipped.
    """
    ds_dir = results_dir / dataset
    if not ds_dir.is_dir():
        print(f"Skipping non-existent dataset: {dataset}")
        return

    print(f"Processing dataset: {dataset} with link function: {link_function}")

    fix_files, learn_files = find_history_files(ds_dir, link_function)

    if not fix_files or not learn_files:
        print(f"  {dataset} missing history files for {link_function}, skipping.")
        return

    dfs_fix = [pd.read_csv(fp) for fp in fix_files]
    dfs_learn = [pd.read_csv(fp) for fp in learn_files]
    runs = (("fix", dfs_fix), ("learn", dfs_learn))

    agg = pd.DataFrame({"epoch": dfs_fix[0]["epoch"]})

    dataset_out_dir = vis_dir / dataset / link_function
    dataset_out_dir.mkdir(parents=True, exist_ok=True)

    for metric, (tr_col, te_col, title, use_log) in METRIC_COLUMNS.items():
        # Check every fold, not just the first: a column missing from a later
        # fold would otherwise surface as a KeyError while stacking.
        if not all(tr_col in df.columns and te_col in df.columns for df in dfs_fix):
            print(f"  Warning: {metric} columns ({tr_col}, {te_col}) not found in fixed threshold data, skipping.")
            continue

        if not all(tr_col in df.columns and te_col in df.columns for df in dfs_learn):
            print(f"  Warning: {metric} columns ({tr_col}, {te_col}) not found in learnable threshold data, skipping.")
            continue

        out_png = dataset_out_dir / f"{dataset}_{link_function}_{metric}_comparison.png"
        plot_metrics_comparison(
            dfs_fix, dfs_learn, tr_col, te_col,
            metric, title, use_log, out_png
        )

        # NaN-aware statistics, like the plotted curves, so a single NaN fold
        # entry does not blank out the aggregate. The CSV always stores
        # linear-space mean/std, even for metrics plotted on a log axis.
        # Column order: fix_tr, fix_te, learn_tr, learn_te (mean, then std).
        for method, dfs in runs:
            for split, col in (("tr", tr_col), ("te", te_col)):
                stacked = np.stack([df[col].values for df in dfs], axis=0)
                agg[f"{metric}_{method}_{split}_mean"] = np.nanmean(stacked, axis=0)
                agg[f"{metric}_{method}_{split}_std"] = np.nanstd(stacked, axis=0)

    out_csv = ds_dir / f"agg_history_{link_function}_combined.csv"
    agg.to_csv(out_csv, index=False)
    print(f"  Wrote {out_csv} and plots to {dataset_out_dir}")

    # Combine plots into 2x5 grid
    print("  Combining plots into 2x5 grid...")
    combine_plots(dataset, link_function, vis_dir)

def combine_plots(dataset, link_function, vis_dir, spacing=5):
    """Tile the per-metric comparison plots of one dataset into a single figure.

    Loads ``<dataset>_<link_function>_<metric>_comparison.png`` for every metric
    in ``METRICS_GRID`` (row-major) from ``vis_dir/<dataset>/<link_function>/``.
    Each column is as wide as its widest image and each row as tall as its
    tallest. Images are centred in their cell, and cells are separated by
    *spacing* pixels of white. The result is saved in the same directory as
    ``<dataset>_<link_function>_combined_metrics.png`` and as
    ``metrics_curve_<link_function>_<abbrev>.pdf``. ``<abbrev>`` comes from
    ``DATASET_ABBREV`` and falls back to the dataset name.

    Returns:
        True on success; False (after printing a message) if the directory or
        any metric image is missing.
    """
    dataset_dir = Path(vis_dir) / dataset / link_function

    if not dataset_dir.exists():
        print(f"  Error: Dataset directory {dataset_dir} does not exist!")
        return False

    # Load images for each metric in the grid
    image_grid = []
    for row in METRICS_GRID:
        row_images = []
        for metric in row:
            img_path = dataset_dir / f"{dataset}_{link_function}_{metric}_comparison.png"
            if not img_path.exists():
                print(f"  Warning: Image not found: {img_path}")
                return False
            # Image.open is lazy and keeps the file open until pixels are read.
            # Take a detached in-memory copy so the handle is released now.
            with Image.open(img_path) as src:
                row_images.append(src.copy())
        image_grid.append(row_images)

    # Cell sizes derived from the grid shape rather than a hard-coded 2x5.
    n_rows = len(image_grid)
    n_cols = max(len(row) for row in image_grid)
    row_heights = [max(img.height for img in row) for row in image_grid]
    col_widths = [
        max(row[col].width for row in image_grid if col < len(row))
        for col in range(n_cols)
    ]

    total_width = sum(col_widths) + spacing * (n_cols - 1)
    total_height = sum(row_heights) + spacing * (n_rows - 1)

    combined = Image.new('RGB', (total_width, total_height), (255, 255, 255))

    # Paste images into the grid, centred in their cells
    y_offset = 0
    for row, cell_height in zip(image_grid, row_heights):
        x_offset = 0
        for img, cell_width in zip(row, col_widths):
            x_pos = x_offset + (cell_width - img.width) // 2
            y_pos = y_offset + (cell_height - img.height) // 2
            combined.paste(img, (x_pos, y_pos))
            x_offset += cell_width + spacing
        y_offset += cell_height + spacing

    # Save combined image as PNG and PDF
    output_path_png = dataset_dir / f"{dataset}_{link_function}_combined_metrics.png"
    combined.save(output_path_png, "PNG", quality=95)

    # Get dataset abbreviation for PDF naming
    dataset_abbrev = DATASET_ABBREV.get(dataset, dataset)
    output_path_pdf = dataset_dir / f"metrics_curve_{link_function}_{dataset_abbrev}.pdf"
    combined.save(output_path_pdf, "PDF", quality=95)
    print(f"  Combined image saved: {output_path_pdf}")

    return True

def get_available_datasets(results_dir):
    """Return the sorted names of sub-directories holding per-fold history CSVs.

    Args:
        results_dir: Results root (``str`` or ``Path``). Only its direct
            sub-directories are inspected.

    Returns:
        Sorted list of sub-directory names that contain at least one file
        matching ``fold*_history_*.csv``.
    """
    # "fold*_history_*.csv" already matches the link-tagged names
    # (fold*_history_logit_*.csv, fold*_history_probit_*.csv), so one glob
    # finds every history file without rescanning the tree.
    return sorted({
        file_path.parent.name
        for file_path in Path(results_dir).glob("*/fold*_history_*.csv")
    })

def main():
    """Parse command-line arguments and generate visualizations for the chosen datasets.

    By default only ``--dataset`` (ERA unless overridden) is processed. Pass
    ``--all-datasets`` (or an empty ``--dataset``) to process every dataset that
    ``get_available_datasets`` finds under ``--results-dir``.

    Returns:
        Exit status: 0 on success, 1 if the requested dataset does not exist or
        no datasets were found.
    """
    parser = argparse.ArgumentParser(description="Generate improved visualizations for ordinal regression results")
    parser.add_argument("--dataset", default="ERA", help="Dataset name, defaults to ERA")
    parser.add_argument("--all-datasets", action="store_true",
                        help="Process every dataset found in --results-dir (overrides --dataset)")
    parser.add_argument("--results-dir", default="results", help="Results directory path")
    parser.add_argument("--vis-dir", default="vis/aggregated", help="Visualization output directory")
    parser.add_argument("--link-function", choices=["logit", "probit", "both"], default="logit",
                        help="Link function to use, defaults to logit")
    args = parser.parse_args()

    results_dir = Path(args.results_dir)
    vis_dir = Path(args.vis_dir)
    link_functions = ["logit", "probit"] if args.link_function == "both" else [args.link_function]

    # --dataset has a non-empty default, so the all-datasets branch needs an
    # explicit opt-in; otherwise it is reachable only via `--dataset ""`.
    if args.dataset and not args.all_datasets:
        # is_dir() matches the check process_dataset performs.
        if not (results_dir / args.dataset).is_dir():
            print(f"Error: Dataset '{args.dataset}' does not exist in {results_dir}!")
            return 1

        print(f"Processing single dataset: {args.dataset}")
        datasets = [args.dataset]
    else:
        datasets = get_available_datasets(results_dir)
        if not datasets:
            print(f"Error: No valid datasets found in {results_dir}!")
            return 1

        print(f"Found {len(datasets)} datasets: {', '.join(datasets)}")

    for dataset in datasets:
        for link_function in link_functions:
            process_dataset(dataset, results_dir, vis_dir, link_function)

    print("All visualizations completed!")
    return 0
    
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means success),
    # so shell pipelines can detect failures.
    raise SystemExit(main())