import argparse
import os
import pathlib
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from torch.utils.data import DataLoader
from manage.files import FileHandler
from data.data import get_dataset, CurrentDatasetInfo, Modality, StateSpace
from manage.logger import Logger
from manage.generation import GenerationManager
from manage.training import TrainingManager
from evaluate.EvaluationManager import EvaluationManager
from manage.checkpoints import load_experiment, save_experiment
from manage.setup import _get_device, _optimize_gpu, _set_seed

from ddpm_init import init_method_ddpm, init_models_optmizers_ls, init_learning_schedule

from script_utils import initialize_experiment, print_dict

def save_images(images, ncols=5, squared=False, save_file=None):
    """
    Arrange images on a grid of axes and save the figure to disk.

    images: list of torch.Tensor (C,H,W) or numpy arrays (H,W,C) / (H,W).
        Tensors are detached, clamped to [0, 1] and moved to CPU.
        Single-channel images ((1,H,W) tensors, (H,W,1) or (H,W) arrays)
        are drawn with a gray colormap on a fixed scale.
    ncols: number of columns
    squared: if True, force nrows = ncols (images beyond ncols*ncols are
        not drawn)
    save_file: filepath for the output grid image; defaults to
        'generated_samples.png' in the current directory. Missing parent
        directories are created.
    """
    if squared:
        nrows = ncols
    else:
        # at least one row, otherwise plt.subplots fails on an empty list
        nrows = max(1, int(np.ceil(len(images) / ncols)))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 10))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    # plt.subplots returns a bare Axes for a 1x1 grid; normalize to a flat array
    axes = np.array(axes).reshape(-1)
    for i, ax in enumerate(axes):
        if i < len(images):
            img = images[i]
            # handle torch.Tensor: (C,H,W) -> (H,W,C); detach so tensors that
            # track gradients can still be converted to numpy
            if hasattr(img, 'permute'):
                img = img.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
            img = np.asarray(img)
            if img.ndim == 3 and img.shape[-1] == 1:
                # imshow rejects (H,W,1); drop the singleton channel axis
                img = img[..., 0]
            if img.ndim == 2:
                # grayscale: fixed value range instead of per-image autoscaling
                vmax = 255 if np.issubdtype(img.dtype, np.integer) else 1
                ax.imshow(img, cmap='gray', vmin=0, vmax=vmax)
            else:
                ax.imshow(img)
        ax.axis('off')
    if save_file is None:
        save_file = 'generated_samples.png'
    parent = os.path.dirname(save_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    fig.savefig(save_file, bbox_inches='tight', dpi=300)
    plt.close(fig)

def display_exp():
    """
    Command-line entry point with three mutually exclusive modes.

    1. ``--saved_grid_folder``: combine previously saved grid PNGs that share
       either ``--lr`` or ``--bs`` (plus ``--reverse_steps`` and
       ``--deterministic``) into one composite figure.
    2. ``--saved_images_folder``: build one square grid per directory of
       saved ``.png`` images.
    3. default: load each experiment matched by ``--folders``, generate
       ``--num_images`` samples with its EMA model, and save individual
       images plus (unless ``--only_save_images``) a grid plot.

    Outputs go under ``--output`` (default ``./figures``). Invalid argument
    combinations exit through ``argparse`` with a usage error.
    """
    parser = _build_display_parser()
    args = parser.parse_args()

    # --reverse_steps is needed by generation and composite modes; only the
    # plain saved-images plotting mode can run without it.
    needs_reverse_steps = bool(args.saved_grid_folder) or not args.saved_images_folder
    if needs_reverse_steps and args.reverse_steps is None:
        parser.error("the following arguments are required: --reverse_steps")

    output_base = args.output or './figures'

    if args.saved_grid_folder:
        # exactly one of --lr / --bs selects which hyperparameter stays fixed
        if args.lr is None and args.bs is None:
            parser.error("Must provide --lr or --bs when using --saved_grid_folder")
        if args.lr is not None and args.bs is not None:
            parser.error("Cannot provide both --lr and --bs when using --saved_grid_folder")
        _composite_saved_grids(args, output_base)
        return

    if args.saved_images_folder:
        _plot_saved_image_folders(args, output_base)
        return

    _generate_from_experiments(args, output_base)


def _build_display_parser():
    """Build the argparse parser used by display_exp."""
    parser = argparse.ArgumentParser(description='Display or save images from a training experiment')

    parser.add_argument(
        '--folders',
        type=str,
        nargs='+',
        required=False,
        help="One or more folder globs, each containing exactly one eval* and one parameters* file"
    )
    parser.add_argument(
        '--num_images',
        type=int,
        default=25,
        help="Number of images to generate"
    )
    # Not marked required here: display_exp validates it after parsing,
    # because it is optional in --saved_images_folder mode.
    parser.add_argument(
        '--reverse_steps',
        type=int,
        default=None,
        help="Number of reverse steps to use for generation"
    )
    parser.add_argument(
        '--deterministic',
        action='store_true',
        default=False,
        help="Use deterministic generation"
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help="Directory in which to save outputs (individual images or grid plot)"
    )
    parser.add_argument(
        '--only_save_images',
        action='store_true',
        default=False,
        help="Only save individual generated images; skip grid plot"
    )
    parser.add_argument(
        '--saved_images_folder',
        type=str,
        nargs='+',
        default=None,
        help="One or more directories containing saved .png images to plot (will produce one grid per folder)"
    )
    parser.add_argument(
        '--grid_size',
        type=int,
        default=None,
        help="Force an N×N grid when plotting (otherwise auto‐square to fit all images)"
    )
    parser.add_argument(
        '--saved_grid_folder',
        type=str,
        default=None,
        help="Directory containing grid PNGs named like '..._lr_{lr}_bs_{bs}_..._reversesteps_{rs}_deterministic_{True|False}_grid*.png'"
    )
    parser.add_argument(
        '--lr',
        type=float,
        required=False,
        default=None,
        help="learning rate to select"
    )
    parser.add_argument(
        '--bs',
        type=int,
        required=False,
        default=None,
        help="batch size to select"
    )
    return parser


def _composite_saved_grids(args, output_base):
    """Composite mode: tile saved grid PNGs (one hyperparameter fixed) into a single labelled figure."""
    import math
    import re

    folder = pathlib.Path(args.saved_grid_folder)
    det_str = str(bool(args.deterministic))
    # build a glob pattern to match only the requested hyperparams
    if args.lr is not None:
        pattern = f"*lr_{args.lr}_bs*_*reversesteps_{args.reverse_steps}_deterministic_{det_str}_grid*.png"
    else:
        pattern = f"*lr*bs_{args.bs}_*reversesteps_{args.reverse_steps}_deterministic_{det_str}_grid*.png"
    all_grids = sorted(folder.glob(pattern))
    if not all_grids:
        raise FileNotFoundError(f"No grids matching {pattern!r} in {folder}")

    # extract (varying param value, Path) tuples from the file names
    bs_re = re.compile(r'bs_(\d+)')
    # learning rate: float, possibly in scientific notation
    lr_re = re.compile(r'lr_([0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?)')
    param_and_files = []
    for f in all_grids:
        if args.lr is not None:
            m = bs_re.search(f.stem)
            if not m:
                print(f"[WARN] cannot parse batch size from {f.name}, skipping")
                continue
            val = int(m.group(1))
        else:
            m = lr_re.search(f.stem)
            if not m:
                print(f"[WARN] cannot parse learning rate from {f.name}, skipping")
                continue
            val = float(m.group(1))
        param_and_files.append((val, f))
    if not param_and_files:
        raise RuntimeError("No valid parameters parsed from filenames")
    # stable sort by the parsed parameter (ties keep file-name order)
    param_and_files.sort(key=lambda x: x[0])
    params, files = zip(*param_and_files)

    imgs = [mpimg.imread(str(f)) for f in files]
    N = len(imgs)
    G = args.grid_size or math.ceil(math.sqrt(N))

    # squeeze=False keeps a 2-D Axes array even for a 1x1 grid
    fig, axes = plt.subplots(G, G, figsize=(G * 3, G * 3), squeeze=False)
    for idx, ax in enumerate(axes.flatten()):
        if idx < N:
            ax.imshow(imgs[idx])
            label = f"bs={params[idx]}" if args.lr is not None else f"lr={params[idx]:g}"
            ax.set_title(label)
        ax.axis('off')

    if args.lr is not None:
        title = f"lr={args.lr}, reversesteps={args.reverse_steps}, deterministic={det_str}"
        out_name = f"composite_lr_{args.lr}_rs_{args.reverse_steps}_det_{det_str}.png"
    else:
        title = f"bs={args.bs}, reversesteps={args.reverse_steps}, deterministic={det_str}"
        out_name = f"composite_bs_{args.bs}_rs_{args.reverse_steps}_det_{det_str}.png"
    fig.suptitle(title, fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    out_path = os.path.join(output_base, out_name)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    print(f"Saved composite ({N} panels) to {out_path}")


def _plot_saved_image_folders(args, output_base):
    """Plot-only mode: for each provided directory, load its .png files and save a square grid."""
    # output directory must exist before save_images writes into it
    os.makedirs(output_base, exist_ok=True)
    for folder in args.saved_images_folder:
        if not os.path.isdir(folder):
            print(f"[WARN] {folder} is not a directory, skipping.")
            continue
        pth = pathlib.Path(folder)
        img_files = sorted(pth.glob('*.png'))
        if not img_files:
            print(f"[WARN] No .png images found in {pth}, skipping.")
            continue

        images = [mpimg.imread(str(f)) for f in img_files]
        # forced square grid; auto size fits all images
        gs = args.grid_size or int(np.ceil(np.sqrt(len(images))))

        # output named after the folder's basename
        out_fname = f"{pth.name}_grid_{gs}x{gs}.png"
        grid_path = os.path.join(output_base, out_fname)
        save_images(images, ncols=gs, squared=True, save_file=grid_path)
        print(f"Saved grid plot ({gs}×{gs}) of {len(images)} images from “{pth.name}” to {grid_path}")


def _generate_from_experiments(args, output_base):
    """Generation mode: resolve the --folders globs and generate samples for each experiment directory."""
    os.makedirs(output_base, exist_ok=True)
    dirs = []
    for pat in args.folders or []:
        matched = list(pathlib.Path('.').glob(pat))
        if not matched:
            raise FileNotFoundError(f"No folders matched pattern {pat!r}")
        dirs.extend(matched)
        print(f"Matched {len(matched)} folders for pattern {pat!r}")
    dirs = [d for d in dirs if d.is_dir()]
    if not dirs:
        raise FileNotFoundError("No directories found in --folders globs")
    print(f"Found {len(dirs)} experiment directories.")

    for path in dirs:
        _generate_for_experiment(path, args, output_base)


def _generate_for_experiment(path, args, output_base):
    """Load one experiment, sample from its EMA model and save individual images plus an optional grid."""
    # sorted so the "first" parameter file is deterministic across filesystems
    params_files = sorted(path.glob('*/param*.pt'))
    if not params_files:
        print(f"Skipping {path}: no parameter file found")
        return
    p_path = params_files[0]
    print(f"Loading parameters from {p_path}")
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load parameter files from trusted experiment directories.
    p = torch.load(p_path, weights_only=False)
    print_dict(p)

    (trainer, logger, file_handler, models, _optimizers, _learning_schedules,
     _method, _eval_manager, gen_manager) = initialize_experiment(p)
    try:
        load_experiment(p=p, trainer=trainer, fh=file_handler, save_dir=path, checkpoint_steps=None)

        # pick the first EMA model whose decay mu is one of the accepted values
        mus = [0.999, 0.9999]
        ema_model = next(
            (d['default'] for d in trainer.ema_objects if d['default'].mu in mus),
            None,
        )
        if ema_model is None:
            raise ValueError('No EMA model with mu in {} found'.format(mus))
        models = {'default': ema_model.get_ema_model()}
        print(f"Using EMA model with mu={ema_model.mu}")

        gen_manager.generate(
            models,
            nsamples=args.num_images,
            reverse_steps=args.reverse_steps,
            deterministic=args.deterministic,
            print_progression=True,
            get_sample_history=False
        )
        samples = gen_manager.samples.clamp(0, 1)

        run_info = [
            p['data']['dataset'],
            p['optim']['optimizer'],
            'lr', p['optim']['lr'],
            'bs', p['training']['batch_size'],
            trainer.total_steps,
            'reversesteps', args.reverse_steps,
            'deterministic', args.deterministic
        ]
        plot_title = '_'.join(map(str, run_info))

        # individual images, one file per sample
        subdir = os.path.join(output_base, plot_title)
        os.makedirs(subdir, exist_ok=True)
        for idx, img in enumerate(samples):
            # img: torch.Tensor, shape (C, H, W); detach in case it tracks grads
            arr = img.detach().clamp(0, 1).cpu().numpy()
            fname = os.path.join(subdir, f'image_{idx}.png')
            if arr.ndim == 3 and arr.shape[0] == 1:
                # single channel (e.g. MNIST): save (H, W) with a fixed [0, 1]
                # range; without vmin/vmax imsave rescales each image to its own min/max
                plt.imsave(fname, arr[0], cmap='gray', vmin=0, vmax=1, dpi=300)
            else:
                # RGB or RGBA: transpose to (H, W, C)
                plt.imsave(fname, np.transpose(arr, (1, 2, 0)), dpi=300)
        print(f"Saved {len(samples)} individual images to {subdir}")

        # grid plot (unless --only_save_images)
        if not args.only_save_images:
            gs = args.grid_size or int(np.ceil(np.sqrt(len(samples))))
            grid_samples = samples
            if samples.dim() == 4 and samples.shape[1] == 1:
                # replicate the gray channel to RGB: imshow cannot draw (H, W, 1)
                grid_samples = samples.repeat(1, 3, 1, 1)
            grid_file = os.path.join(output_base, f'{plot_title}.png')
            save_images(list(grid_samples), ncols=gs, squared=True, save_file=grid_file)
            print(f"Saved grid plot ({gs}×{gs}) of {len(samples)} samples to {grid_file}")
    finally:
        # release logger resources even if loading/generation fails
        if logger is not None:
            logger.stop()

if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the selected mode.
    display_exp()

    