#!/usr/bin/env python3
# combine_plots.py
# --------------------------------------------------------------------
# Note: This functionality has been integrated into plot.py
# This file is kept for backward compatibility

import argparse
from pathlib import Path
from PIL import Image

# Short dataset codes used when naming the combined PDF output file.
# Each entry is (dataset directory name, abbreviation).
DATASET_ABBREV = dict([
    ("car", "CA"),
    ("ERA", "ER"),
    ("LEV", "LE"),
    ("SWD", "SW"),
    ("winequality-red", "WR"),
])

# Metric plots to tile: one inner list per grid row (2 rows x 5 columns = 10 plots).
METRICS = [
    "NLL ACC MAE QWK Within1".split(),            # top row
    "MinSens ONC1 ONC2-1 ONC2-2 ONC3".split(),    # bottom row
]

def _tile_images(image_grid, spacing):
    """Paste a rectangular grid of images onto one white RGB canvas.

    Every column is as wide as its widest image and every row as tall as its
    tallest image; each image is centered in its cell, and adjacent cells are
    separated by *spacing* pixels. Returns the new canvas image.
    """
    n_cols = len(image_grid[0])
    row_heights = [max(img.height for img in row) for row in image_grid]
    col_widths = [max(row[col].width for row in image_grid) for col in range(n_cols)]

    # Gaps only go *between* cells, hence (count - 1) spacings per axis.
    total_width = sum(col_widths) + spacing * (n_cols - 1)
    total_height = sum(row_heights) + spacing * (len(image_grid) - 1)
    combined = Image.new('RGB', (total_width, total_height), (255, 255, 255))

    y_offset = 0
    for row, cell_height in zip(image_grid, row_heights):
        x_offset = 0
        for img, cell_width in zip(row, col_widths):
            # Center the image inside its cell.
            x_pos = x_offset + (cell_width - img.width) // 2
            y_pos = y_offset + (cell_height - img.height) // 2
            combined.paste(img, (x_pos, y_pos))
            x_offset += cell_width + spacing
        y_offset += cell_height + spacing
    return combined


def combine_plots(dataset_name, vis_dir, link_function="logit", output_dir=None, spacing=5,
                  metrics=None):
    """Tile the per-metric comparison plots of one dataset into a single image.

    For each metric name in *metrics* the file
    ``<vis_dir>/<dataset_name>/<link_function>/<dataset_name>_<link_function>_<metric>_comparison.png``
    is loaded. The images are arranged with one grid row per row of *metrics*,
    and the result is saved as both a PNG and a PDF.

    Args:
        dataset_name: Name of the dataset sub-directory inside *vis_dir*.
        vis_dir: Root visualization directory.
        link_function: Link-function sub-directory name (e.g. "logit").
        output_dir: Directory for the outputs. Defaults to the dataset's
            link-function directory; created if it does not exist.
        spacing: Gap in pixels between adjacent cells.
        metrics: Rectangular list of rows of metric names. Defaults to the
            module-level ``METRICS`` layout.

    Returns:
        True if both output files were written, False on any missing input
        or an invalid layout (an error message is printed).
    """
    vis_path = Path(vis_dir)
    dataset_path = vis_path / dataset_name / link_function

    if not dataset_path.exists():
        print(f"Error: Dataset directory {dataset_path} does not exist!")
        return False

    if metrics is None:
        metrics = METRICS
    n_cols = len(metrics[0]) if metrics else 0
    if n_cols == 0 or any(len(row) != n_cols for row in metrics):
        print("Error: Metric layout must be a non-empty rectangular grid!")
        return False

    # If no output directory is specified, write next to the source plots.
    if output_dir is None:
        output_dir = dataset_path
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Image.open keeps the file handle open until the image is closed, so
    # track every opened image and release them all once compositing is done.
    opened = []
    try:
        image_grid = []
        for row in metrics:
            row_images = []
            for metric in row:
                img_path = dataset_path / f"{dataset_name}_{link_function}_{metric}_comparison.png"
                if not img_path.exists():
                    # A missing cell would break the grid, so abort rather than skip.
                    print(f"Error: Image for metric {metric} does not exist, aborting: {img_path}")
                    return False
                img = Image.open(img_path)
                opened.append(img)
                row_images.append(img)
            image_grid.append(row_images)
        # paste() forces the lazy image data to load, so this must run before close().
        combined = _tile_images(image_grid, spacing)
    finally:
        for img in opened:
            img.close()

    output_path_png = output_dir / f"{dataset_name}_{link_function}_combined_metrics.png"
    combined.save(output_path_png, "PNG")

    # Use the dataset abbreviation (falling back to the full name) for the PDF name.
    dataset_abbrev = DATASET_ABBREV.get(dataset_name, dataset_name)
    output_path_pdf = output_dir / f"metrics_curve_{link_function}_{dataset_abbrev}.pdf"
    combined.save(output_path_pdf, "PDF", quality=95)

    print(f"Successfully combined {len(metrics) * n_cols} images to:")
    print(f"  PNG: {output_path_png}")
    print(f"  PDF: {output_path_pdf}")

    return True

def combine_all_datasets(vis_dir, link_function="logit", output_dir=None, spacing=5):
    """Combine metric plots for every dataset directory under *vis_dir*.

    Each immediate sub-directory of *vis_dir* that itself contains a
    *link_function* sub-directory is handed to ``combine_plots``. Datasets are
    processed in sorted order because ``Path.iterdir`` yields entries in an
    arbitrary, filesystem-dependent order.

    Returns:
        Number of datasets whose plots were combined successfully
        (0 when *vis_dir* has no sub-directories).
    """
    vis_path = Path(vis_dir)

    dataset_dirs = sorted(d for d in vis_path.iterdir() if d.is_dir())
    if not dataset_dirs:
        print(f"Error: No dataset directories found in {vis_dir}!")
        return 0

    success_count = 0
    for dataset_dir in dataset_dirs:
        # is_dir() is already False for non-existent paths; no exists() needed.
        if not (dataset_dir / link_function).is_dir():
            continue
        if combine_plots(dataset_dir.name, vis_dir, link_function, output_dir, spacing):
            success_count += 1

    print(f"Completed image combination for {success_count} datasets.")
    return success_count

def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse. When None, argparse falls back to
            ``sys.argv[1:]``, so existing ``main()`` calls behave as before.

    Returns:
        Exit status: 0 on success, 1 if the visualization directory is not an
        existing directory or a single requested dataset could not be combined.
    """
    parser = argparse.ArgumentParser(description="Combine multiple metric plots for a dataset into a 2x5 grid")
    parser.add_argument("--dataset", help="Dataset name, processes all datasets if not specified")
    parser.add_argument("--vis-dir", required=True, help="Visualization directory path")
    parser.add_argument("--output-dir",
                        help="Output directory path, defaults to <vis-dir>/<dataset>/<link-function>")
    parser.add_argument("--link-function", choices=["logit", "probit"], default="logit",
                        help="Link function type, defaults to logit")
    parser.add_argument("--spacing", type=int, default=5, help="Spacing between images (pixels), defaults to 5")
    args = parser.parse_args(argv)

    # is_dir() also rejects a regular file, which would otherwise crash later in iterdir().
    if not Path(args.vis_dir).is_dir():
        print(f"Error: Visualization directory {args.vis_dir} does not exist or is not a directory!")
        return 1

    ok = True
    if args.dataset:
        ok = combine_plots(args.dataset, args.vis_dir, args.link_function, args.output_dir, args.spacing)
    else:
        combine_all_datasets(args.vis_dir, args.link_function, args.output_dir, args.spacing)

    print("Done!")
    return 0 if ok else 1

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (SystemExit(None) exits with status 0).
    raise SystemExit(main())