#!/usr/bin/env python
# combine_plots_vis.py - Combine feature space and latent variable space visualizations (supporting different link functions)
# --------------------------------------------------------------------
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import numpy as np

def _scale_to_width(img, target_width):
    """Return ``img`` scaled proportionally so its width equals ``target_width``."""
    if img.width == target_width:
        return img
    # Scale the height by the same factor as the width so the plot is not
    # squashed; clamp to 1 because Pillow rejects zero-sized images.
    target_height = max(1, round(img.height * target_width / img.width))
    return img.resize((target_width, target_height), Image.Resampling.LANCZOS)


def combine_latent_and_feature(feat_path, z_path, output_path, gap=2):
    """Stack the latent-space image above the feature-space image and save it.

    When the two images differ in width, the wider one is scaled down
    proportionally to the narrower width so that both line up without
    distorting either plot.

    Args:
        feat_path: Path of the feature-space image (placed at the bottom).
        z_path: Path of the latent-variable image (placed at the top).
        output_path: Where the stacked image is written (saved with 300 dpi
            metadata).
        gap: Height in pixels of the white strip left between the two images.

    Returns:
        The combined ``PIL.Image.Image`` (mode ``RGB``, white background).
    """
    # copy() loads the pixel data into memory, so both files can be closed
    # deterministically as soon as the ``with`` block ends -- including when
    # opening or decoding the second image fails.
    with Image.open(feat_path) as feat_src, Image.open(z_path) as z_src:
        feat_img = feat_src.copy()
        z_img = z_src.copy()

    target_width = min(feat_img.width, z_img.width)
    feat_img = _scale_to_width(feat_img, target_width)
    z_img = _scale_to_width(z_img, target_width)

    combined_height = z_img.height + gap + feat_img.height
    combined_img = Image.new('RGB', (target_width, combined_height), (255, 255, 255))

    # Latent plot on top, feature plot below it, separated by the white gap.
    combined_img.paste(z_img, (0, 0))
    combined_img.paste(feat_img, (0, z_img.height + gap))

    combined_img.save(output_path, dpi=(300, 300))
    return combined_img

def combine_epochs_horizontal(combined_images, output_path):
    """Place several images side by side, left to right, and save the result.

    Images shorter than the tallest one are scaled up proportionally so that
    every panel has the same height.

    Args:
        combined_images: Iterable of image paths, in left-to-right order.
        output_path: Where the merged image is written (saved with 300 dpi
            metadata).

    Returns:
        The merged ``PIL.Image.Image`` (mode ``RGB``, white background).

    Raises:
        ValueError: If ``combined_images`` is empty.
    """
    image_paths = list(combined_images)
    if not image_paths:
        raise ValueError("combined_images must contain at least one image path")

    # Load every image fully into memory and close its file right away rather
    # than relying on garbage collection to release the handles.
    images = []
    for img_path in image_paths:
        with Image.open(img_path) as src:
            images.append(src.copy())

    max_height = max(img.height for img in images)

    resized_images = []
    for img in images:
        if img.height != max_height:
            # Keep the aspect ratio; clamp to 1 because Pillow rejects a
            # zero-width image if the scaled width truncates to 0.
            new_width = max(1, int(img.width * max_height / img.height))
            img = img.resize((new_width, max_height), Image.Resampling.LANCZOS)
        resized_images.append(img)

    total_width = sum(img.width for img in resized_images)
    final_img = Image.new('RGB', (total_width, max_height), (255, 255, 255))

    x_offset = 0
    for img in resized_images:
        final_img.paste(img, (x_offset, 0))
        x_offset += img.width

    final_img.save(output_path, dpi=(300, 300))
    return final_img

def create_combined_visualization(dataset, method, fold, link, vis_dir, output_dir,
                                  target_epochs=(0, 10, 50, 5000)):
    """Build one overview image for a dataset/method/fold/link combination.

    For each epoch in ``target_epochs`` the feature-space and latent-space
    snapshots are stacked vertically into a temporary image; the temporary
    images are then merged left to right into
    ``<output_dir>/<tag>_combined_visualization.png`` where ``tag`` is
    ``{dataset}_{link}_{method}_f{fold}``. Epochs whose snapshot files are
    missing, or whose stacking fails, are reported and skipped.

    Args:
        dataset: Dataset name; also the sub-directory of ``vis_dir`` to read.
        method: Method identifier used in the file tag (e.g. "fix", "learn").
        fold: Fold number used in the file tag.
        link: Link function name used in the file tag.
        vis_dir: Root directory that holds the per-dataset snapshot folders.
        output_dir: Directory for the final image (temporary files go to
            ``<output_dir>/temp``).
        target_epochs: Epoch numbers to include, in left-to-right order.

    Returns:
        True if the final image was written, False otherwise.
    """
    tag = f"{dataset}_{link}_{method}_f{fold}"
    feat_dir = Path(vis_dir) / dataset / f"{tag}_feat_epochs"
    z_dir = Path(vis_dir) / dataset / f"{tag}_z_epochs"

    if not feat_dir.exists() or not z_dir.exists():
        print(f"Directories do not exist: {feat_dir} or {z_dir}")
        return False

    temp_dir = Path(output_dir) / "temp"
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Every temporary path we attempted to write, including ones whose
    # creation failed part-way, so that cleanup can remove all of them.
    attempted_paths = []
    combined_paths = []

    try:
        for epoch in target_epochs:
            epoch_str = f"{epoch:04d}"

            feat_path = feat_dir / f"epoch_{epoch_str}.png"
            z_path = z_dir / f"epoch_{epoch_str}.png"

            if not feat_path.exists() or not z_path.exists():
                print(f"Files do not exist: {feat_path} or {z_path}")
                continue

            combined_path = temp_dir / f"{tag}_epoch_{epoch_str}_combined.png"
            attempted_paths.append(combined_path)

            # Deliberately broad: one unreadable snapshot must not abort the
            # remaining epochs of a batch run.
            try:
                combine_latent_and_feature(feat_path, z_path, combined_path)
            except Exception as e:
                print(f"Error combining images for epoch {epoch}: {e}")
                continue

            combined_paths.append(combined_path)
            print(f"Created combined image: {combined_path}")

        if not combined_paths:
            print("Not enough epoch images found for combination")
            return False

        final_output_path = Path(output_dir) / f"{tag}_combined_visualization.png"

        try:
            combine_epochs_horizontal(combined_paths, final_output_path)
        except Exception as e:
            print(f"Error creating final combined image: {e}")
            return False

        print(f"Created final combined image: {final_output_path}")
        return True
    finally:
        # Remove temporary files on success and on failure alike. A cleanup
        # problem is reported but does not change the result.
        for temp_path in attempted_paths:
            try:
                temp_path.unlink()
            except FileNotFoundError:
                pass  # nothing was written for this epoch
            except OSError as e:
                print(f"Warning: could not remove temporary file {temp_path}: {e}")

def get_available_datasets(vis_dir):
    """Return the dataset names that are processed by default.

    ``vis_dir`` is accepted but not inspected: the result is a fixed list,
    and a fresh list object is returned on every call.
    """
    default_names = ("ERA", "LEV", "SWD", "car", "winequality-red")
    return list(default_names)

def main(argv=None):
    """Command-line entry point: build combined visualizations per dataset.

    For every selected dataset both threshold methods ("fix" and "learn") are
    processed with the chosen link function and fold.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        0 if every selected dataset was processed successfully for both
        methods, 1 if at least one dataset failed or was skipped.
    """
    parser = argparse.ArgumentParser(description="Combine feature space and latent variable space visualizations (supporting different link functions)")
    parser.add_argument("--dataset", help="Dataset name, process all if not specified")
    parser.add_argument("--datasets", nargs='+', help="Multiple dataset names, separated by space")
    parser.add_argument("--vis-dir", default="vis", help="Visualization images directory")
    parser.add_argument("--output-dir", default="vis/combined", help="Output directory")
    parser.add_argument("--link", type=str, default="logit",
                        choices=["logit", "probit"],
                        help="Link function type: logit, probit")
    parser.add_argument("--fold", type=int, default=29, help="Fold number, default is 29")
    args = parser.parse_args(argv)

    vis_dir = Path(args.vis_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --datasets takes precedence over --dataset; with neither, fall back to
    # the default dataset list.
    if args.datasets:
        datasets = args.datasets
    elif args.dataset:
        datasets = [args.dataset]
    else:
        datasets = get_available_datasets(vis_dir)

    # (method id, label used in progress messages, label used in failure messages)
    methods = (
        ("fix", "fixed threshold", "Fixed threshold"),
        ("learn", "learnable threshold", "Learnable threshold"),
    )
    separator = '=' * 50
    unsuccessful = []  # datasets that were skipped or had a failing method

    for dataset in datasets:
        print(f"\n{separator}")
        print(f"Processing dataset: {dataset} (link function: {args.link})")
        print(separator)

        dataset_dir = vis_dir / dataset
        if not dataset_dir.exists():
            print(f"Dataset directory does not exist: {dataset_dir}")
            unsuccessful.append(dataset)
            continue

        # Always run both methods, even if the first one fails.
        failed_labels = []
        for method, progress_label, failure_label in methods:
            print(f"\nProcessing {progress_label} method ({args.link})...")
            succeeded = create_combined_visualization(
                dataset, method, args.fold, args.link, vis_dir, output_dir
            )
            if not succeeded:
                failed_labels.append(failure_label)

        if not failed_labels:
            print(f"✓ Dataset {dataset} processing completed")
        else:
            print(f"✗ Dataset {dataset} processing failed")
            for failure_label in failed_labels:
                print(f"  - {failure_label} method failed")
            unsuccessful.append(dataset)

    print(f"\n{separator}")
    if unsuccessful:
        # Do not claim overall success when something failed or was skipped.
        print(
            f"Finished with problems: {len(unsuccessful)} of {len(datasets)} "
            f"dataset(s) failed or were skipped: {', '.join(unsuccessful)}"
        )
    else:
        print("All combined visualizations completed!")
    print(f"Output directory: {output_dir}")
    print(separator)

    return 1 if unsuccessful else 0

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()