# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""2D toy example from the paper "Guiding a Diffusion Model with a Bad Version of Itself"."""

import os
import copy
import pickle
import warnings
import functools
import numpy as np
import torch
import matplotlib.pyplot as plt
import click
import tqdm
import dnnlib
from torch_utils import persistence

from toy_example import gt

# Silence torch's notice about unpickling with `weights_only=False` (the filter matches that
# message text); presumably emitted while loading the network snapshots below — verify.
warnings.filterwarnings('ignore', 'You are using `torch.load` with `weights_only=False`')

#----------------------------------------------------------------------------
# EMA helper used by the 'reg' guidance method.
class EMA:
    """Momentum-style running accumulator with optional norm matching.

    Each call folds the new value into ``m`` as ``m <- beta * m + x``. There is no
    ``(1 - beta)`` factor, so the raw accumulator is not a normalized average. By
    default the accumulator is then rescaled row-wise so that its L2 norm over
    dim 1 equals that of the incoming value.
    """

    def __init__(self, m=0, beta=0.9, epsilon=1e-8):
        # m: current accumulator (the Python int 0 until the first update).
        # beta: decay applied to the previous accumulator.
        # epsilon: keeps the norm division in ``scale`` finite.
        self.m, self.beta, self.epsilon = m, beta, epsilon

    def update(self, x):
        """Fold ``x`` into the accumulator and return the new accumulator."""
        decayed = self.beta * self.m
        self.m = decayed + x
        return self.m

    def scale(self, x1, x2):
        """Rescale the rows of ``x1`` to the row norms (over dim 1) of ``x2``.

        Returns a zero tensor while ``x1`` is still the integer initial value.
        """
        if not isinstance(x1, int):
            target_norm = torch.norm(x2, dim=1, keepdim=True)
            current_norm = torch.norm(x1, dim=1, keepdim=True)
            return x1 * target_norm / (current_norm + self.epsilon)
        return torch.tensor(0)

    def __call__(self, x, norm_scale=True):
        """Update with ``x``; return the accumulator, norm-matched to ``x`` when ``norm_scale``."""
        accumulated = self.update(x)
        if not norm_scale:
            return accumulated
        return self.scale(accumulated, x)
    
#----------------------------------------------------------------------------
# APG (adaptive projected guidance) momentum buffer.

# Update rule: self.running_average = self.momentum * self.running_average + update_value
class MomentumBuffer:
    """Running momentum accumulator used by adaptive projected guidance.

    ``running_average`` starts as the int 0. ``update`` replaces it in place with
    ``update_value + momentum * running_average``.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0  # becomes a tensor after the first update

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into ``running_average`` (returns None)."""
        self.running_average = update_value + self.momentum * self.running_average
        
# Decompose v0 into parallel and orthogonal components with respect to v1.
def project(
        v0: torch.Tensor, # [..., D]; projection is taken along the last dim
        v1: torch.Tensor, # same shape as v0
    ):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    The projection is done per row along the last dimension and in float64.
    Both results are cast back to ``v0``'s dtype.

    Returns:
        (parallel, orthogonal), with ``parallel + orthogonal == v0`` up to rounding.
    """
    out_dtype = v0.dtype
    v0_64 = v0.double()
    unit_dir = torch.nn.functional.normalize(v1.double(), dim=-1)
    coeff = (v0_64 * unit_dir).sum(dim=-1, keepdim=True)
    parallel = coeff * unit_dir
    orthogonal = v0_64 - parallel
    return parallel.to(out_dtype), orthogonal.to(out_dtype)

# Apply APG: convert pred_cond into the normalized (guided) prediction.
def adaptive_projected_guidance(
        pred_cond: torch.Tensor, # [..., D]; norms/projections use the last dim
        pred_uncond: torch.Tensor, # same shape as pred_cond
        guidance_scale: float,
        momentum_buffer: MomentumBuffer = None,
        eta: float = 0.0, # 1.0
        norm_threshold: float = 2.5, # 0.0
    ):
    """Guide ``pred_cond`` away from ``pred_uncond`` using adaptive projected guidance.

    The difference ``pred_cond - pred_uncond`` is:
    1. optionally smoothed by ``momentum_buffer``;
    2. clipped to L2 norm ``norm_threshold`` along the last dim (if > 0);
    3. split into parts parallel/orthogonal to ``pred_cond``.

    The parallel part is weighted by ``eta``. The result is added to ``pred_cond``
    with weight ``guidance_scale - 1``.
    """
    delta = pred_cond - pred_uncond
    if momentum_buffer is not None:
        # Feed the raw difference into the buffer and continue with the smoothed value.
        momentum_buffer.update(delta)
        delta = momentum_buffer.running_average
    if norm_threshold > 0:
        # Shrink rows whose norm exceeds the threshold; shorter rows are kept as is.
        delta_norm = delta.norm(p=2, dim=[-1], keepdim=True)
        delta = delta * torch.minimum(torch.ones_like(delta), norm_threshold / delta_norm)
    parallel, orthogonal = project(delta, pred_cond)
    step = orthogonal + eta * parallel
    return pred_cond + (guidance_scale - 1) * step

#----------------------------------------------------------------------------
# # fdg
# from kornia.geometry.transform import build_laplacian_pyramid
# def project(
#     v0: torch.Tensor, # [B, C, H, W]
#     v1: torch.Tensor, # [B, C, H, W]
# ):
#     dtype = v0.dtype
#     v0, v1 = v0.double(), v1.double()
#     # v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
#     v1 = torch.nn.functional.normalize(v1, dim=[-1])
#     # v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
#     v0_parallel = (v0 * v1).sum(dim=-1, keepdim=True) * v1
#     v0_orthogonal = v0 - v0_parallel
#     return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
    
# def build_image_from_pyramid(pyramid):
#     img = pyramid[-1]
#     for i in range(len(pyramid) - 2, -1, -1):
#         img = kornia.geometry.pyrup(img) + pyramid[i]
#     return img

# # We assume all model predictions are converted to "x_0" prediction. 
# def laplacian_guidance(
#     pred_cond: torch.Tensor, # [B, C, H, W]
#     pred_uncond: torch.Tensor, # [B, C, H, W]
#     guidance_scale=[1.0, 1.0], # Guidance scales from high- to low-frequency
#     parallel_weights=None, # Optional weights for projection
# ):
#     levels = len(guidance_scale)
#     if parallel_weights == None:
#         parallel_weights = [1.0] * levels
    
#     pred_cond_pyramid = build_laplacian_pyramid(pred_cond, levels)
#     pred_uncond_pyramid = build_laplacian_pyramid(pred_uncond, levels)
    
#     pred_guided_pyramid = []
#     parameters = zip(
#         pred_cond_pyramid, pred_uncond_pyramid, guidance_scale, parallel_weights
#         )
        
#     for idx, (p_cond, p_uncond, scale, par_weight) in enumerate(parameters):
#         diff = p_cond - p_uncond
#         diff_parallel, diff_orthogonal = project(diff, p_cond)
#         diff = par_weight * diff_parallel + diff_orthogonal
#         p_guided = p_cond + (scale - 1) * diff
#         pred_guided_pyramid.append(p_guided)
#     pred_guided = build_image_from_pyramid(pred_guided_pyramid)
#     return pred_guided.to(pred_cond.dtype)

#----------------------------------------------------------------------------
# Simulate the EDM sampling ODE for the given set of initial sample points.
# Adapted from generate_images.py.

def _make_denoiser(net, guidance, gnet, pred_type, kwargs):
    """Return ``denoise(x, sigma)`` implementing guidance method ``pred_type``.

    ``denoise`` returns ``x + score * sigma**2`` for the (possibly guided) score.
    For ``'cfgpp'`` it returns the pair ``(unguided, guided)`` of such estimates.
    Stateful methods ('reg', 'apg') get fresh EMA / momentum buffers here. Their
    state is shared by all steps of one sampling run but not between runs.
    Raises ValueError for an unknown ``pred_type``; 'fdg' is currently disabled.
    """
    if pred_type in (None, 'cfg', 'ag'):
        # Plain classifier-free guidance / autoguidance: extrapolate from gnet towards net.
        def denoise(x, sigma):
            score = net.score(x, sigma)
            if gnet is not None:
                score = gnet.score(x, sigma).lerp(score, guidance)
            return x + score * (sigma ** 2)

    elif pred_type == 'reg':
        # beta_1 / beta_2 are the cond / uncond EMA decays exposed on the command line.
        cond_ema = EMA(beta=kwargs.get('beta_1', 0.9))
        uncond_ema = EMA(beta=kwargs.get('beta_2', 0.9))
        gamma = kwargs.get('gamma', 0.3)
        def denoise(x, sigma):
            score = net.score(x, sigma)
            if gnet is not None:
                score_uncond = gnet.score(x, sigma)
                score_cond_ema = cond_ema(score)
                score_uncond_ema = uncond_ema(score_uncond)
                score_uncond += gamma * (score_cond_ema - score_uncond_ema)
                score = score_uncond.lerp(score, guidance)
            return x + score * (sigma ** 2)

    elif pred_type == 'apg':
        # Momentum comes from the APG option --beta; 'beta_1' is still honoured as a
        # fallback for callers that passed only that key, then -0.75.
        momentum_buffer = MomentumBuffer(momentum=kwargs.get('beta', kwargs.get('beta_1', -0.75)))
        eta = kwargs.get('eta', 0.0)
        norm_threshold = kwargs.get('r_scale', 2.5)
        def denoise(x, sigma):
            score = net.score(x, sigma)
            if gnet is not None:
                score_uncond = gnet.score(x, sigma)
                score = adaptive_projected_guidance(
                    score, score_uncond, guidance, momentum_buffer=momentum_buffer,
                    eta=eta, norm_threshold=norm_threshold,
                )
            return x + score * (sigma ** 2)

    elif pred_type == 'ig':
        # Interval guidance: guide only for sigma in (sigma_low, sigma_high], and only if a
        # guiding net exists. Missing bounds default to "always guide".
        sigma_low = kwargs.get('sigma_low', 0.0)
        sigma_high = kwargs.get('sigma_high', float('inf'))
        def denoise(x, sigma):
            score = net.score(x, sigma)
            if gnet is not None and sigma_low < float(sigma) <= sigma_high:
                score = gnet.score(x, sigma).lerp(score, guidance)
            return x + score * (sigma ** 2)

    elif pred_type == 'tcfg':
        def denoise(x, sigma):
            score = net.score(x, sigma)
            if gnet is not None:
                score_uncond = gnet.score(x, sigma)
                # Stack (cond, uncond) per sample into a [B, 2, D] matrix and take its SVD.
                # score_uncond is then projected onto the first right-singular vector only
                # (row 1 of Vh is zeroed).
                stacked = torch.stack((score, score_uncond), dim=1).to(dtype=torch.float32)
                stacked = stacked.reshape(stacked.size(0), stacked.size(1), -1)
                _, _, Vh = torch.linalg.svd(stacked, full_matrices=False)
                Vh_modified = Vh.clone()
                Vh_modified[:, 1] = 0
                uncond_flat = score_uncond.reshape(score_uncond.size(0), 1, -1).to(dtype=torch.float32)
                projected = torch.matmul(torch.matmul(uncond_flat, Vh.transpose(-2, -1)), Vh_modified)
                score_uncond = projected.reshape(*score_uncond.shape).to(score.dtype).to(score.device)
                score = score_uncond.lerp(score, guidance)
            return x + score * (sigma ** 2)

    elif pred_type == 'cfgpp':
        # Returns (denoised with net alone, denoised with the guided score). Without gnet both
        # entries coincide and the sampler reduces to plain Heun.
        # NOTE(review): the sampler reconstructs from the first entry and re-noises with the
        # second; published CFG++ uses the guided x0 and the *unguided* noise. Confirm intent.
        # The --sigma_cfgpp option is not read here.
        def denoise(x, sigma):
            score = net.score(x, sigma)
            score_guided = score if gnet is None else gnet.score(x, sigma).lerp(score, guidance)
            return x + score * (sigma ** 2), x + score_guided * (sigma ** 2)

    else:
        raise ValueError(f'Unknown pred_type {pred_type!r}; expected one of cfg, ag, reg, apg, ig, tcfg, cfgpp')
    return denoise


def do_sample(net, x_init, guidance=1, gnet=None, num_steps=32, sigma_min=0.002, sigma_max=5, rho=7, **kwargs):
    """Integrate the EDM sampling ODE with Heun's 2nd-order method, starting from ``x_init``.

    Args:
        net: main model exposing ``score(x, sigma)``.
        x_init: initial samples at noise level ``sigma_max``.
        guidance: guidance weight; 1 leaves the net's score unchanged for every method.
        gnet: guiding model exposing ``score(x, sigma)``, or None for unguided sampling.
        num_steps, sigma_min, sigma_max, rho: EDM time-step discretization.
        **kwargs: ``pred_type`` selects the guidance method. It defaults to plain guidance,
            the same as 'cfg'. The remaining keys are method-specific options.

    Returns:
        Tensor of shape ``[num_steps + 1, *x_init.shape]`` holding the whole trajectory.

    Raises:
        ValueError: if ``pred_type`` is not a supported method.
    """
    pred_type = kwargs.get('pred_type')
    denoise = _make_denoiser(net, guidance, gnet, pred_type, kwargs)
    is_cfgpp = pred_type == 'cfgpp'

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float32, device=x_init.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_cur = x_init
    trajectory = [x_cur]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1

        # Euler step.
        if is_cfgpp:
            denoised, denoised_ref = denoise(x_cur, t_cur)
            d_cur = (x_cur - denoised) / t_cur
            d_cur_ref = (x_cur - denoised_ref) / t_cur
            x_next = x_cur - t_cur * d_cur + t_next * d_cur_ref
        else:
            d_cur = (x_cur - denoise(x_cur, t_cur)) / t_cur
            x_next = x_cur + (t_next - t_cur) * d_cur

        # 2nd order correction (skipped on the final step, where t_next = 0).
        if i < num_steps - 1:
            if is_cfgpp:
                denoised_prime, denoised_prime_ref = denoise(x_next, t_next)
                d_prime = (x_next - denoised_prime) / t_next
                d_prime_ref = (x_next - denoised_prime_ref) / t_next
                x_next = x_cur - t_cur * (0.5 * d_cur + 0.5 * d_prime) + t_next * (0.5 * d_cur_ref + 0.5 * d_prime_ref)
            else:
                d_prime = (x_next - denoise(x_next, t_next)) / t_next
                x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)

        # Record trajectory.
        x_cur = x_next
        trajectory.append(x_cur)
    return torch.stack(trajectory)

#----------------------------------------------------------------------------
# Draw the given set of plot elements using matplotlib.

def do_plot(
    net=None, guidance=1, gnet=None, elems=frozenset({'gt_uncond', 'gt_outline', 'samples'}),
    view_x=0, view_y=0, view_size=1.6, grid_resolution=400, arrow_len=0.002,
    num_samples=1<<13, seed=1, sample_distance=0, sigma_max=5,
    device=torch.device('cuda'), **kwargs,
):
    """Draw the requested plot elements into the current matplotlib axes.

    Args:
        net, gnet: main / guiding models. For sampling, ``net`` falls back to ``gt('A', device)``.
        guidance: guidance weight forwarded to ``do_sample``.
        elems: names of the elements to draw. Any of 'p_net', 'p_gnet', 'p_ratio',
            'gt_uncond', 'gt_outline', 'gt_smax', 'gt_shaded', 'trajectories',
            'scores_net', 'scores_gnet', 'scores_ratio', 'samples', 'samples_before',
            'samples_after'.
        view_x, view_y, view_size: centre and half-width of the square view.
        grid_resolution: grid points per axis for the contour plots.
        arrow_len: multiplier applied to score vectors before drawing them as arrows.
        num_samples, seed: number of initial samples and the seed used to draw them.
        sample_distance: if > 0, greedily drop samples closer than this to an earlier kept one.
        sigma_max: noise level of the initial samples and of the plotted densities / scores.
        device: torch device for all tensors.
        **kwargs: forwarded to ``do_sample`` (e.g. ``pred_type``). 'size' and 'alpha' also
            style the 'samples' points.
    """
    # Generate initial samples.
    if any(x.startswith(y) for x in elems for y in ['samples', 'trajectories', 'scores']):
        samples = gt('A', device).sample(num_samples, sigma_max, generator=torch.Generator(device).manual_seed(seed))
        if sample_distance > 0:
            ok = torch.ones(len(samples), dtype=torch.bool, device=device)
            for i in range(1, len(samples)):
                ok[i] = (samples[i] - samples[:i][ok[:i]]).square().sum(-1).sqrt().min() >= sample_distance
            samples = samples[ok]

    # Run sampler only for the elements that actually draw its output. 'samples_before'
    # shows the initial points only and must not pay for a full sampling run.
    if not {'samples', 'samples_after', 'trajectories'}.isdisjoint(elems):
        trajectories = do_sample(net=(net or gt('A', device)), x_init=samples, guidance=guidance, gnet=gnet, sigma_max=sigma_max, **kwargs)

    # Initialize plot.
    gridx = torch.linspace(view_x - view_size, view_x + view_size, steps=grid_resolution, device=device)
    gridy = torch.linspace(view_y - view_size, view_y + view_size, steps=grid_resolution, device=device)
    gridxy = torch.stack(torch.meshgrid(gridx, gridy, indexing='xy'), axis=-1)
    plt.xlim(float(gridx[0]), float(gridx[-1]))
    plt.ylim(float(gridy[0]), float(gridy[-1]))
    plt.gca().set_aspect('equal')
    plt.gca().set_axis_off()

    # Plot helper functions.
    def contours(values, levels, colors=None, cmap=None, alpha=1, linecolors='black', linealpha=1, linewidth=2.5):
        # Re-map values as -sqrt(max - value) before contouring.
        values = -(values.max() - values).sqrt().cpu().numpy()
        plt.contourf(gridx.cpu().numpy(), gridy.cpu().numpy(), values, levels=levels, antialiased=True, extend='max', colors=colors, cmap=cmap, alpha=alpha)
        plt.contour(gridx.cpu().numpy(), gridy.cpu().numpy(), values, levels=levels, antialiased=True, colors=linecolors, alpha=linealpha, linestyles='solid', linewidths=linewidth)
    def lines(pos, color='black', alpha=1):
        plt.plot(*pos.cpu().numpy().T, '-', linewidth=5, solid_capstyle='butt', color=color, alpha=alpha)
    def arrows(pos, dir, color='black', alpha=1):
        plt.quiver(*pos.cpu().numpy().T, *dir.cpu().numpy().T * arrow_len, scale=0.6, width=5e-3, headwidth=4, headlength=3, headaxislength=2.5, capstyle='round', color=color, alpha=alpha)
    def points(pos, color='black', alpha=1, size=30):
        plt.plot(*pos.cpu().numpy().T, '.', markerfacecolor='black', markeredgecolor='none', color=color, alpha=alpha, markersize=size)

    # Draw requested plot elements.
    if 'p_net' in elems:            contours(net.logp(gridxy, sigma_max), levels=np.linspace(-2.5, 2.5, num=20)[1:-1], cmap='Greens', linealpha=0.2)
    if 'p_gnet' in elems:           contours(gnet.logp(gridxy, sigma_max), levels=np.linspace(-2.5, 3.5, num=20)[1:-1], cmap='Reds', linealpha=0.2)
    if 'p_ratio' in elems:          contours(net.logp(gridxy, sigma_max) - gnet.logp(gridxy, sigma_max), levels=np.linspace(-2.2, 1.0, num=20)[1:-1], cmap='Blues', linealpha=0.2)
    if 'gt_uncond' in elems:        contours(gt('AB', device).logp(gridxy), levels=[-2.12, 0], colors=[[0.9,0.9,0.9]], linecolors=[[0.7,0.7,0.7]], linewidth=1.5)
    if 'gt_outline' in elems:       contours(gt('A', device).logp(gridxy), levels=[-2.12, 0], colors=[[1.0,0.8,0.6]], linecolors=[[0.8,0.6,0.5]], linewidth=1.5)
    if 'gt_smax' in elems:          contours(gt('A', device).logp(gridxy, sigma_max), levels=[-1.41, 0], colors=['C1'], alpha=0.2, linealpha=0.2)
    if 'gt_shaded' in elems:        contours(gt('A', device).logp(gridxy), levels=np.linspace(-2.5, 3.07, num=15)[1:-1], cmap='Oranges', linealpha=0.2)
    if 'trajectories' in elems:     lines(trajectories.transpose(0, 1), alpha=0.3)
    if 'scores_net' in elems:       arrows(samples, net.score(samples, sigma_max), color='C2')
    if 'scores_gnet' in elems:      arrows(samples, gnet.score(samples, sigma_max), color='C3')
    if 'scores_ratio' in elems:     arrows(samples, net.score(samples, sigma_max) - gnet.score(samples, sigma_max), color='C0')
    if 'samples' in elems:          points(trajectories[-1], size=kwargs.get('size',15), alpha=kwargs.get('alpha',0.25))
    if 'samples_before' in elems:   points(samples)
    if 'samples_after' in elems:    points(trajectories[-1])

#----------------------------------------------------------------------------
# Main command line.

@click.group()
def cmdline():
    """2D toy example from the paper "Guiding a Diffusion Model with a Bad Version of Itself".

    Examples:

    \b
    # Visualize sampling distributions using autoguidance.
    python toy_example.py plot

    \b
    # Same, but save the plot as PNG instead of displaying it.
    python toy_example.py plot --save=out.png

    \b
    # Same, but specify the models explicitly.
    python toy_example.py plot \\
        --net=https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsA-layers04-dim64/iter4096.pkl \\
        --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsA-layers04-dim32/iter0512.pkl \\
        --guidance=3

    \b
    # Same, but using classifier-free guidance.
    python toy_example.py plot \\
        --net=https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsA-layers04-dim64/iter4096.pkl \\
        --gnet=https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsAB-layers04-dim32/iter0512.pkl \\
        --guidance=4

    \b
    # Retrain the main model and visualize progress.
    python toy_example.py train

    \b
    # Retrain the main model and save snapshots.
    python toy_example.py train \\
        --outdir=toy-example/clsA-layers04-dim64 \\
        --cls=A --layers=4 --dim=64 --viz=false
    """
    # The docstring above doubles as the click help text, so it is kept verbatim.
    # WORLD_SIZE is presumably set by a distributed launcher; only a single process is supported.
    world_size = os.environ.get('WORLD_SIZE', '1')
    if world_size == '1':
        return
    raise click.ClickException('Distributed execution is not supported.')

#----------------------------------------------------------------------------
# 'plot' subcommand.

@cmdline.command()
@click.option('--net',      help='Main model  [default: download]', metavar='PKL|URL',          type=str, default='https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsA-layers04-dim64/iter4096.pkl')
@click.option('--gnet',     help='Guiding model  [default: autoguidance]', metavar='PKL|URL',   type=str, default='https://nvlabs-fi-cdn.nvidia.com/edm2/toy-example/clsA-layers04-dim32/iter0512.pkl')
@click.option('--guidance', help='Guidance weight', metavar='FLOAT',                            type=float, default=3, show_default=True)
@click.option('--save',     help='Save figure, do not display', metavar='PNG|PDF',              type=str, default='plot')
########## added ##########
@click.option('--device',   help='device configuration', metavar='STR',                         type=str, default='cuda')
@click.option('--plot_type',help='ploting method', metavar='STR',                               type=str, default='dist')
@click.option('--pred_type',help='guidance method', metavar='STR',                              type=str, default='cfg')
@click.option('--num_samples',help='sample numbers for dist', metavar='INT',                    type=int, default=1<<16)
@click.option('--alpha',    help='sample dot alpha for dist', metavar='FLOAT',                        type=float, default=0.25)
@click.option('--size',     help='sample dot size for dist', metavar='INT',                           type=int, default=15)
##### reg #####
@click.option('--beta_1',   help='cond EMA weight', metavar='FLOAT',                            type=float, default=0.9, show_default=True)
@click.option('--beta_2',   help='uncond EMA weight', metavar='FLOAT',                          type=float, default=0.9, show_default=True)
@click.option('--gamma',    help='rectifying weight', metavar='FLOAT',                          type=float, default=0.3, show_default=True)
###############
##### apg #####
@click.option('--beta',     help='Momentum weight', metavar='FLOAT',                            type=float, default=0.9, show_default=True)
@click.option('--eta',      help='Parallel weight', metavar='FLOAT',                            type=float, default=0.0, show_default=True)
@click.option('--r_scale',  help='Norm threshold', metavar='FLOAT',                             type=float, default=2.5, show_default=True)
###############
##### fdg #####
@click.option('--w_high',   help='high freq weight', metavar='FLOAT',                           type=float, default=3.0, show_default=True)
@click.option('--w_low',    help='low freq weight', metavar='FLOAT',                            type=float, default=1.0, show_default=True)
@click.option('--w_par',    help='parallel freq weight', metavar='FLOAT',                       type=float, default=3.0, show_default=True)
###############
##### ig ######
@click.option('--sigma_low',    help='sigma low', metavar='FLOAT',                              type=float, default=3.0, show_default=True)
@click.option('--sigma_high',    help='sigma high', metavar='FLOAT',                            type=float, default=3.0, show_default=True)
###############
#### tcfg #####
###############
#### cfg++ ####
@click.option('--sigma_cfgpp',    help='sigma', metavar='FLOAT',                                type=float, default=3.0, show_default=True)
###############

########## added ##########
def plot(net, gnet, guidance, save, num_samples, device=torch.device('cuda'), **kwargs):
    """Visualize sampling distributions with and without guidance."""
    # Layouts (--plot_type):
    #   'dist': one large panel of guided samples;
    #   'traj': zoomed sampler trajectories;
    #   'pdf':  the net / gnet / ratio densities.
    # All method-specific options in **kwargs (pred_type, beta, eta, ...) are forwarded so
    # the plot uses the method recorded in the output filename.
    print('Loading models...')
    # NOTE(review): pickle.load executes arbitrary code; only load snapshots from trusted sources.
    if isinstance(net, str):
        with dnnlib.util.open_url(net) as f:
            net = pickle.load(f).to(device)
    if isinstance(gnet, str):
        with dnnlib.util.open_url(gnet) as f:
            gnet = pickle.load(f).to(device)

    # Initialize plot.
    print('Drawing plots...')
    plt.rcParams['font.size'] = 28
    plot_type = kwargs.get('plot_type')
    # Zoomed-in view shared by the 'traj' and 'pdf' layouts.
    zoom_kwargs = dict(view_x=0.45, view_y=1.22, view_size=0.3, num_samples=num_samples, device=device, sample_distance=0.045, sigma_max=0.03)

    if plot_type == 'dist':
        plt.figure(figsize=[25, 25], dpi=40, tight_layout=True)
        plt.subplot(1, 1, 1)
        fig1_kwargs = dict(view_x=0.30, view_y=0.30, view_size=1.2, num_samples=num_samples, device=device)
        do_plot(net=net, gnet=gnet, guidance=guidance, elems={'gt_uncond', 'gt_outline', 'samples'}, **fig1_kwargs, **kwargs)

    elif plot_type == 'traj':
        plt.figure(figsize=[25, 25], dpi=40, tight_layout=True)
        plt.subplot(1, 1, 1)
        do_plot(net=net, gnet=gnet, guidance=guidance, elems={'gt_shaded', 'trajectories', 'samples_after'}, **zoom_kwargs, **kwargs)

    elif plot_type == 'pdf':
        plt.figure(figsize=[75, 25], dpi=40, tight_layout=True)
        plt.subplot(1, 3, 1)
        do_plot(net=net, elems={'p_net', 'gt_smax', 'scores_net', 'samples_before'}, **zoom_kwargs, **kwargs)
        plt.subplot(1, 3, 2)
        plt.title('PDF of guiding model')
        do_plot(net=net, gnet=gnet, elems={'p_gnet', 'gt_smax', 'scores_gnet', 'samples_before'}, **zoom_kwargs, **kwargs)
        plt.subplot(1, 3, 3)
        plt.title('PDF ratio (main / guiding)')
        do_plot(net=net, gnet=gnet, elems={'p_ratio', 'gt_smax', 'scores_ratio', 'samples_before'}, **zoom_kwargs, **kwargs)

    else:
        raise click.ClickException(f'Unknown --plot_type {plot_type!r}; expected dist, traj or pdf.')

    # Save or display. Check for None before building the filename; otherwise the
    # display branch can never be reached.
    if save is None:
        print('Displaying...')
        plt.show()
    else:
        # Encode the guidance method and its hyperparameters in the filename.
        pred_type = kwargs.get('pred_type')
        save = f'{save}-{pred_type}-gui{guidance}'
        if pred_type == 'reg':
            save += f'-betac{kwargs.get("beta_1")}-betau{kwargs.get("beta_2")}-gam{kwargs.get("gamma")}'
        elif pred_type == 'apg':
            save += f'-beta{kwargs.get("beta")}-eta{kwargs.get("eta")}-r{kwargs.get("r_scale")}'
        elif pred_type == 'fdg':
            save += f'-wh{kwargs.get("w_high")}-wl{kwargs.get("w_low")}-wp{kwargs.get("w_par")}'
        elif pred_type == 'ig':
            save += f'-sigl{kwargs.get("sigma_low")}-sigh{kwargs.get("sigma_high")}'
        elif pred_type == 'cfgpp':
            save += f'-sig{kwargs.get("sigma_cfgpp")}'
        save += '.png'
        print(f'Saving to {save}')
        if os.path.dirname(save):
            os.makedirs(os.path.dirname(save), exist_ok=True)
        plt.savefig(save, dpi=80)
    print('Done.')

#----------------------------------------------------------------------------

# Script entry point: dispatch to the click command group.
if __name__ == '__main__':
    cmdline()

#----------------------------------------------------------------------------
