import os.path
import time
import warnings
from typing import Tuple, Any, Protocol, Union

import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm import trange

from phi.flow import *
from phi.math import NUMPY
from phi.vis._log import SceneLog
from phi.vis._vis_base import index_label


class Experiment(Protocol):
    """
    Structural interface of an inverse-problem experiment.

    An experiment generates problems, runs the differentiable forward process, defines the loss
    and supplies the networks and plots consumed by the training and visualization routines.
    """

    def __repr__(self):
        # The representation is embedded into output paths such as f"~/phi/RP/{experiment}/...".
        raise NotImplementedError

    def generate_problem(self, batch_dims: Shape, test: bool) -> Tuple[Any, Any, Tensor]:
        """
        Generate a set of problems together with their solutions.

        Args:
            batch_dims: Batch shape of the examples to generate.
            test: `True` to generate the problems that are actually of interest.
                `False` to generate data purely for training, whose distribution may differ from the test data.

        Returns:
            full_initial_state: Complete initial state. Must be a `PhiTreeNode` so that it can be sliced.
            masked_initial_state: Initial state reduced to the information optimizers are allowed to see.
                A model fit working from this approximation will likely not reach zero loss.
            ground_truth: Generated solution. Variables are listed in non-batch dimensions along a `Tensor`.
        """
        raise NotImplementedError

    def forward_process(self, initial_state, guess: Tensor) -> Any:
        """
        Run the differentiable process forward. Solutions are found by optimizing `guess`.

        Args:
            initial_state: Example configuration from a data set.
            guess: Solution estimate or ground truth solution.

        Returns:
            output: Observed output state, which is passed on to the loss function.
        """
        raise NotImplementedError

    def get_observations(self, output) -> Tensor:
        """
        Extract the observed data from a process output.

        Returns:
            observed_data: Tensor passed to an inverse network as input.
        """
        raise NotImplementedError

    def get_desired(self, output) -> Tensor:
        """
        Desired observations for `output`. By default identical to `get_observations(output)`.
        """
        desired = self.get_observations(output)
        return desired

    def loss_function(self, ref_output, guess_output) -> Tensor:
        """
        L2 loss between the observations of a reference output and of a guessed output.

        Args:
            ref_output: Simulation output from a reference solution. Gradients may or may not be provided for this variable.
            guess_output: Simulation output from a guess to be optimized. Gradients will be provided for this variable.

        Returns:
            Loss `Tensor`
        """
        ref_observed = self.get_observations(ref_output)
        guess_observed = self.get_observations(guess_output)
        return math.l2_loss(ref_observed - guess_observed)

    def create_inverse_net(self, lib) -> Any:
        """
        Set up and return a neural network that outputs a solution guess.

        Args:
            lib: Neural network library used to build the network.
        """
        raise NotImplementedError

    def run_inverse_net(self, net, initial_state, observed: Tensor, sol_shape: Shape):
        """
        Evaluate the inverse network on `observed`.

        The default implementation ignores `initial_state` and maps the native network output
        onto the spatial and channel dimensions of `sol_shape`.
        """
        spatial_dims = sol_shape.spatial
        channel_dims = sol_shape.channel
        return math.native_call(net, observed, spatial_dim=spatial_dims, channel_dim=channel_dims)

    def create_forward_net(self, lib) -> Any:
        """
        Set up and return a neural network that outputs a final state from an initial state and a guess.
        Such a network can serve as a surrogate for the forward process.
        It must be twice differentiable to be usable with higher-order optimizers.
        """
        raise NotImplementedError

    def run_forward_net(self, net, initial_state, guess: Tensor) -> Any:
        """
        Evaluate a surrogate network in place of the forward process. Signature matches `forward_process()`.
        """
        raise NotImplementedError

    def plot_process(self, path: str, initial_state, guess: Tensor):
        """
        Plot one or multiple example simulations and write the results to files according to `path`.

        Args:
            path: Path to output files. File extensions and suffixes may be added.
            initial_state: Generated examples.
            guess: Parameters state for the simulation.

        Returns:
            args: Positional arguments for plot()
            kwargs: Keyword arguments for plot()
            (`visualize_dataset()` does not use the return value.)
        """
        raise NotImplementedError

    def get_plot(self, plot_type: str, initial_state, guess: Tensor, ref_output) -> dict:
        """
        Build the sub-plot description for an experiment-specific `plot_type`.

        `paper_plot()` reads the plotted object from key `'obj'` and forwards all other keys as plot arguments.
        """
        raise NotImplementedError


@dataclass
class Dataset:
    """
    Generated problem set plus, optionally, the results of the reference forward simulation.

    Instances are built positionally by `dataset()`, so the field order is part of the interface.
    """
    experiment: Experiment  # experiment that generated the problems
    seed: int  # seed passed to `math.seed()` before generation
    batch_size: int  # number of examples along the `example` batch dimension
    initial_state: Any  # masked
    full_initial_state: Any  # unmasked initial state, used for the reference simulation
    ground_truth: Tensor  # generated solutions
    output: Any  # forward process output for the ground truth; `None` if not computed
    observed: Tensor  # `experiment.get_observations(output)`; `None` if not computed
    desired: Tensor  # `experiment.get_desired(output)`; `None` if not computed


def dataset(batch_size, seed, experiment, test, compute_output=True):
    """
    Seed the random number generator and generate a `Dataset` of `batch_size` examples.

    Args:
        batch_size: Number of examples, laid out along the batch dimension `example`.
        seed: Value passed to `math.seed()` before the problems are generated.
        experiment: Problem definition providing `generate_problem()` and the forward process.
        test: Forwarded to `experiment.generate_problem()`.
        compute_output: If `False`, the reference simulation is skipped and
            `output`, `observed` and `desired` are `None`.

    Returns:
        The generated `Dataset`.
    """
    math.seed(seed)
    example_dims = batch(example=batch_size)
    full_state, masked_state, solution = experiment.generate_problem(example_dims, test=test)
    assert 'example' in shape(solution)
    output = observed = desired = None
    if compute_output:
        # The reference simulation runs on the full (unmasked) initial state.
        output = experiment.forward_process(full_state, solution)
        observed = experiment.get_observations(output)
        desired = experiment.get_desired(output)
    return Dataset(experiment, seed, batch_size, masked_state, full_state, solution, output, observed, desired)


def _create_optimizer(nets, net, name, learning_rate):
    if name == 'adam':
        return nets.adam(net, learning_rate)
    elif name == 'BFGS':
        import torch
        return torch.optim.LBFGS(net.parameters())
    else:
        raise ValueError(name)


@math.broadcast
def train_network(batch_size, seed, learning_rate, iterations, optimizer_name, experiment: Experiment, print_every=100, nets=None):
    """
    Train the experiment's inverse network on one generated data set, with the loss evaluated on the output of the forward process.

    Properties, logged scalars and the trained weights are written to a new scene below `~/phi/RP/{experiment}/`.

    Args:
        batch_size: Number of examples in the data set.
        seed: Seed used to generate the data set.
        learning_rate: Learning rate handed to `_create_optimizer()`.
        iterations: Number of weight updates.
        optimizer_name: Optimizer identifier understood by `_create_optimizer()`.
        experiment: Problem definition.
        print_every: The loss is printed whenever the iteration index is a multiple of this value.
        nets: Neural network library. Although the default is `None`, it is used unconditionally, so a value must be passed.
    """
    # The network is trained directly on the problems of interest (test=True).
    dset = dataset(batch_size, seed, experiment, True)
    # --- Create scene ---
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_net_lr{learning_rate:.0e}_seed{seed}")
    scene.put_properties(seed=seed, batch_size=batch_size, learning_rate=learning_rate, iterations=iterations, optimizer=optimizer_name)
    print(scene)
    # --- Train ---
    # NOTE(review): `view()` may inspect the local variables of this frame - confirm before renaming locals.
    viewer = view(scene=scene, gui='dash')
    net = experiment.create_inverse_net(nets)
    print(f"Training diff.phys. network with {nets.parameter_count(net)} parameters, bs={batch_size}, lr={learning_rate}, seed={seed}...")
    optimizer = _create_optimizer(nets, net, optimizer_name, learning_rate)
    def loss_function(dset: Dataset):
        # Network prediction -> forward process -> loss against the reference output.
        prediction = experiment.run_inverse_net(net, dset.initial_state, dset.observed, dset.ground_truth.shape)
        # `clip_gradient` is defined outside this block; presumably it limits the backpropagated gradient - TODO confirm
        clipped_grad = clip_gradient(prediction)
        output = experiment.forward_process(dset.initial_state, clipped_grad)
        # The prediction is returned without gradient so that it can be logged.
        return experiment.loss_function(dset.output, output), math.stop_gradient(prediction)
    for i, _ in zip(trange(iterations), viewer.range(iterations)):
        loss, params = nets.update_weights(net, optimizer, loss_function, dset)
        # Logs the loss together with the predicted parameters, keyed by the entries of `params.vector`.
        viewer.log_scalars(loss=loss, reduce=None, **params.vector)
        if i % print_every == 0:
            print(loss)
    nets.save_state(net, scene.subpath(f"net"))
    print(f"Network saved to {scene.subpath('net')}")


@math.broadcast
def train_surrogate(batch_size, seed, learning_rate, iterations, optimizer_name, experiment: Experiment, nets=None):
    """
    Train the experiment's forward (surrogate) network to reproduce the output of the forward process.

    Properties, logged scalars and the trained weights are written to a new scene below `~/phi/RP/{experiment}/`.

    Args:
        batch_size: Number of training examples.
        seed: Test seed; the training seed is derived from it as `1664525 * seed + 1013904223`.
        learning_rate: Learning rate handed to `_create_optimizer()`.
        iterations: Number of weight updates.
        optimizer_name: Optimizer identifier understood by `_create_optimizer()`.
        experiment: Problem definition.
        nets: Neural network library. Although the default is `None`, it is used unconditionally, so a value must be passed.
    """
    # Training data is generated with test=False and a seed derived from, but different to, the test seed.
    training_dataset = dataset(batch_size, 1664525 * seed + 1013904223, experiment, False)
    # --- Create scene ---
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_surrogate_lr{learning_rate:.0e}_seed{seed}")
    scene.put_properties(training_seed=1664525 * seed + 1013904223, test_seed=seed, batch_size=batch_size, learning_rate=learning_rate, iterations=iterations, optimizer=optimizer_name)
    print(scene)
    # --- Train ---
    # NOTE(review): `view()` may inspect the local variables of this frame - confirm before renaming locals.
    viewer = view(scene=scene, gui='dash')
    net = experiment.create_forward_net(nets)
    print(f"Training surrogate with {nets.parameter_count(net)} parameters, bs={batch_size}, lr={learning_rate}, seed={seed}...")
    optimizer = _create_optimizer(nets, net, optimizer_name, learning_rate)
    def loss_function(dset: Dataset):
        # The surrogate receives the masked initial state and the true parameters and is compared to the reference output.
        output = experiment.run_forward_net(net, dset.initial_state, dset.ground_truth)
        return experiment.loss_function(dset.output, output)
    for i, _ in zip(trange(iterations), viewer.range(iterations)):
        train_loss = nets.update_weights(net, optimizer, loss_function, training_dataset)
        viewer.log_scalars(surrogate_train_loss=train_loss, reduce=None)
    nets.save_state(net, scene.subpath("net"))
    print(f"Network saved to {scene.subpath('net')}")


@math.broadcast
def train_supervised(batch_size, seed, learning_rate, iterations, optimizer_name, experiment: Experiment, eval_every=100, nets=None):
    """
    Train the experiment's inverse network with a supervised L2 loss on the ground truth parameters.

    The network is trained on a separate training data set and periodically evaluated on the test data set
    by running its predictions through the forward process.
    Properties, logged scalars and the trained weights are written to a new scene below `~/phi/RP/{experiment}/`.

    Args:
        batch_size: Number of examples in both the training and the test data set.
        seed: Test seed; the training seed is derived from it as `1664525 * seed + 1013904223`.
        learning_rate: Learning rate handed to `_create_optimizer()`.
        iterations: Number of weight updates.
        optimizer_name: Optimizer identifier understood by `_create_optimizer()`.
        experiment: Problem definition.
        eval_every: The test loss is evaluated whenever the iteration index is a multiple of this value.
        nets: Neural network library. Although the default is `None`, it is used unconditionally, so a value must be passed.
    """
    training_dataset = dataset(batch_size, 1664525 * seed + 1013904223, experiment, False)
    test_dataset = dataset(batch_size, seed, experiment, True)
    # --- Create scene ---
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_sup_lr{learning_rate:.0e}_seed{seed}")
    scene.put_properties(seed=seed, batch_size=batch_size, learning_rate=learning_rate, iterations=iterations, optimizer=optimizer_name)
    print(scene)
    # --- Train ---
    # NOTE(review): `view()` may inspect the local variables of this frame - confirm before renaming locals.
    viewer = view(scene=scene, gui='dash')
    net = experiment.create_inverse_net(nets)
    print(f"Training supervised with {nets.parameter_count(net)} parameters, bs={batch_size}, lr={learning_rate}, seed={seed}...")
    optimizer = _create_optimizer(nets, net, optimizer_name, learning_rate)
    def loss_function(dset, test=False):
        # At test time the network is fed the desired observations, during training the observed ones.
        observation = dset.desired if test else dset.observed
        prediction = experiment.run_inverse_net(net, dset.initial_state, observation, dset.ground_truth.shape)
        # Compare against the ground truth of the data set being evaluated, not always the training set.
        return math.l2_loss(prediction - dset.ground_truth), prediction
    for i, _ in zip(trange(iterations), viewer.range(iterations)):
        train_loss, _ = nets.update_weights(net, optimizer, loss_function, training_dataset)
        viewer.log_scalars(training_loss=train_loss, reduce=None)
        if i % eval_every == 0:
            print(f"Evaluating inference loss at iteration {i}")
            # Only the prediction is needed here; the test loss is measured on the forward process output.
            _, prediction = loss_function(test_dataset, test=True)
            output = experiment.forward_process(test_dataset.initial_state, prediction)
            test_loss = experiment.loss_function(test_dataset.output, output)
            viewer.log_scalars(reduce=None, loss=math.convert(test_loss, NUMPY), **prediction.vector)
    nets.save_state(net, scene.subpath(f"net"))
    print(f"Network saved to {scene.subpath('net')}")


@math.broadcast
def visualize_dataset(batch_size, seed, experiment: Experiment, selection={'example': slice(0, 3)}):
    """
    Plot per-parameter loss landscapes and example simulations for a generated test data set.

    For every parameter of the ground truth, the loss and its gradient are scanned between the smallest and
    largest value of that parameter in the data set while all other parameters are kept at their true values.
    All figures are saved below `~/phi/RP/{experiment}/`.

    Args:
        batch_size: Number of examples in the data set.
        seed: Seed used to generate the data set.
        experiment: Problem definition.
        selection: Slice of examples to visualize. The default dict is only read, never modified.
    """
    dset = dataset(batch_size, seed, experiment, True)
    gt_min = math.min(dset.ground_truth, 'example')
    gt_max = math.max(dset.ground_truth, 'example')
    initial_state = math.slice(dset.initial_state, selection)
    ground_truth = dset.ground_truth[selection]
    ref_output = math.slice(dset.output, selection)
    # `exist_ok` replaces the separate existence check and avoids its check-then-create race.
    os.makedirs(os.path.expanduser(f"~/phi/RP/{experiment}"), exist_ok=True)
    # --- Loss landscape ---
    dset_idx = math.range(batch(example=batch_size))[selection]
    for p_idx in non_batch(ground_truth).meshgrid(names=True):  # for each parameter
        p_name = str(next(iter(p_idx.values())))
        print(f"Plotting loss landscape for parameter {p_name} from {gt_min[p_idx]} to {gt_max[p_idx]}")
        def loss_function(param):
            # Replace only the scanned parameter, keep all others at their ground truth values.
            guess = math.scatter(expand(ground_truth, shape(param)), p_idx, param)
            output = experiment.forward_process(initial_state, guess)
            return experiment.loss_function(ref_output, output)
        loss_scan = CenteredGrid(loss_function, bounds=Box(**{p_name: (gt_min[p_idx], gt_max[p_idx])}), **{p_name: 200})
        gradient_function = math.functional_gradient(loss_function, 'param', get_output=False)
        grad_scan = loss_scan.with_values(lambda p: gradient_function(expand(p, batch(loss_scan))))
        title = -f-f"Example {dset_idx}, true={ground_truth[p_idx]:.1f}"
        show(stack([loss_scan, grad_scan], channel(c='L,∇ L')), title=title)
        vis.savefig(f"~/phi/RP/{experiment}/{batch_size}_loss_landscape_{p_name}_{batch_size}_seed{seed}.jpg", close=True)
        show(stack([loss_scan, grad_scan], channel(c='L,∇ L')), title=title, row_dims='c', same_scale=False)
        vis.savefig(f"~/phi/RP/{experiment}/{batch_size}_loss_landscape_{p_name}_{batch_size}_seed{seed}_rows.jpg", close=True)
    # --- Plot simulation ---
    example_path = f"~/phi/RP/{experiment}/{batch_size}_seed{seed}_example_0"
    experiment.plot_process(example_path, initial_state, ground_truth)
    vis.close()
    # Report the path that was actually handed to `plot_process()`.
    print(f"Saved experiment visualization to {example_path}")


@math.broadcast
def visualize_training(training_type: str, batch_size, seed, learning_rate, experiment: Experiment, scene_id=-1, selection={'example': slice(0, 3)}):
    """
    Plot the logged curves of one training run and save the figures to the `plots/` sub-directory of its scene.

    Args:
        training_type: One of `'net'`, `'sup'`, `'surrogate'`. Selects the scene name and the logged curves to plot.
        batch_size: Data set size of the run, part of the scene path.
        seed: Seed of the run, part of the scene path. Also used to regenerate the test data set.
        learning_rate: Learning rate of the run, part of the scene path.
        experiment: Problem definition.
        scene_id: Index into the list of matching scenes. The default `-1` picks the last entry.
        selection: Slice of examples for which parameter trajectories are plotted.

    Raises:
        KeyError: If `training_type` is not one of the supported values.
    """
    dset = dataset(batch_size, seed, experiment, True)
    scenes = Scene.list(f"~/phi/RP/{experiment}/{batch_size}_{training_type}_lr{learning_rate:.0e}_seed{seed}")
    scene = scenes[scene_id]
    # Optimizer and iteration count are read back from the properties stored by the training function.
    optimizer_name = scene.properties['optimizer']
    iterations = scene.properties['iterations']
    # --- Loss curves ---
    curve_names = {'net': ['loss'], 'sup': ['loss', 'training_loss'], 'surrogate': ['surrogate_train_loss']}[training_type]
    for curve_name in curve_names:
        curve = vis.load_scalars(scene.paths, curve_name, batch_dim=batch('example'))
        # Loss value at the last logged iteration, one per example.
        final_loss = curve[curve_name].iteration[-1]
        # --- Mean curve ---
        show(math.mean(curve, 'example'), err=math.std(curve, 'example'), log_dims=curve_name, title=f"{training_type} training |D|={batch_size} lr={learning_rate:.0e}")
        vis.savefig(scene.subpath(f"plots/{curve_name} curve mean.jpg", create_parent=True), close=True)
        # --- Individual curves ---
        # Only the first 20 examples are drawn.
        show(curve.example[:20].example.as_channel(), log_dims=curve_name, title=f"{training_type} training |D|={batch_size} lr={learning_rate:.0e}")
        vis.savefig(scene.subpath(f"plots/{curve_name} curve all.jpg", create_parent=True), close=True)
        # --- Loss histogram ---
        base_title = f"Network error after {iterations} {optimizer_name} iterations. {experiment} |D|={batch_size} lr={learning_rate:.0e}"
        # Without any finite final loss, the remaining plots for this curve are skipped.
        if math.is_nan(final_loss).all:
            warnings.warn(f"All examples diverged on {training_type} with bs={batch_size}, seed={seed}, lr={learning_rate}")
            continue
        histogram, bin_edges, bin_centers = math.histogram(math.b2i(final_loss), instance(**{curve_name: 20}))
        show(PointCloud(bin_centers, histogram), title=base_title)
        vis.savefig(scene.subpath(f"plots/{curve_name} histogram.jpg", create_parent=True), close=True)
        # --- scatter loss vs parameters ---
        # The histogram is recomputed with a spatial instead of an instance bin dimension for this plot.
        histogram, bin_edges, bin_centers = math.histogram(math.b2i(final_loss), spatial(**{curve_name: 20}))
        scatters = []
        for idx in non_batch(dset.ground_truth).meshgrid(names=True):  # for each parameter
            param_name = index_label(idx)
            scatters.append(vec(**{param_name: dset.ground_truth[idx], curve_name: final_loss}).example.as_instance())
        show(*scatters, PointCloud(bin_centers, histogram), same_scale=curve_name)
        vis.savefig(scene.subpath(f'plots/{curve_name} by params.jpg', create_parent=True), close=True)
        # --- Cumulative loss by threshold ---
        # Fraction of examples whose final loss is at most each of 50 thresholds between 1e-4 and the largest final loss.
        threshold = math.linspace(1e-4, math.max(final_loss, 'example'), spatial(threshold=50))
        win_rate = math.mean(final_loss <= threshold, 'example')
        show(vec(threshold=threshold, win_rate=win_rate), title=f"{training_type} training |D|={batch_size} lr={learning_rate:.0e}")
        vis.savefig(scene.subpath(f'plots/{curve_name} by threshold.jpg', create_parent=True), close=True)
    # --- Parameter trajectories if applicable ---
    # Only parameters for which a log file exists in the scene are plotted.
    existing_params = [p for p in dset.ground_truth.vector.item_names if os.path.exists(scene.subpath(f'log_{p}.txt'))]
    if existing_params:
        existing_params = wrap(existing_params, channel('variables'))
        # NOTE(review): passes `scene` here while the loss curves above pass `scene.paths` - confirm both are accepted.
        param_curves = vis.load_scalars(scene, existing_params, batch_dim=batch('example'))[selection]
        param_curves = rename_dims(param_curves, 'vector', channel(vector='iteration,value'))
        # The true parameter values are drawn at the last logged iteration.
        true_params = vec(iteration=param_curves.vector['iteration'].iteration[-1], value=dset.ground_truth[selection].vector.as_channel('variables'))
        color = math.range(channel(variables=param_curves.variables.size))
        show(param_curves, true_params, overlay='args', color=color)
        # NOTE(review): unlike every other savefig call in this function, the figure is not closed here - confirm this is intended.
        vis.savefig(scene.subpath(f"plots/param_trj.jpg", create_parent=True))
    print(f"Training visualization saved to {scene}")


def _load_param_curves(methods: str, log: str, batch_size: int, seed: int, learning_rate: Union[float, dict], experiment: Experiment, refined: str = None):
    """
    Load the logged scalar curves `log` of several methods and stack them along the batch dimension `method`.

    Args:
        methods: Comma-separated method identifiers. Known identifiers are mapped to display names, others are used as-is.
        log: Name of the logged scalar(s) to load.
        batch_size: Data set size, used to locate the scene.
        seed: Seed, used to locate the scene.
        learning_rate: Either one learning rate for all methods or a dict mapping method identifiers to learning rates.
            Methods missing from the dict use `1`.
        experiment: Problem definition, used to locate the scene.
        refined: Optional identifier of a refinement method. If given, every curve gets an additional batch dimension
            `refined` with items `not_refined,refined`; the refined curve is loaded from the scene of `method + '+' + refined`.
            The method `bfgs` (case-insensitive) is not loaded twice but expanded along that dimension.

    Returns:
        Curves of all methods, stacked along `method`.
    """
    display_names = {'net': "Reparameterized", 'sup': "Supervised", 'surrogate': "Neural Adjoint", 'guessing': "Random guessing"}
    lr_per_method = isinstance(learning_rate, dict)
    curves = {}
    for method in methods.split(','):
        label = display_names.get(method, method)
        lr = learning_rate.get(method, 1) if lr_per_method else learning_rate
        base_scene = _current_scene(method, batch_size, seed, lr, experiment=experiment)
        base_curve = vis.load_scalars(base_scene.paths, log, batch_dim=batch('example'))
        if not refined:
            curves[label] = base_curve
        elif method.lower() == 'bfgs':
            curves[label] = expand(base_curve, batch(refined='not_refined,refined'))
        else:
            refined_scene = _current_scene(method + '+' + refined, batch_size, seed, lr, experiment=experiment)
            refined_curve = vis.load_scalars(refined_scene.paths, log, batch_dim=batch('example'))
            curves[label] = stack([base_curve, refined_curve], batch(refined='not_refined,refined'))
    return stack(curves, batch('method'))


def plot_all_curves(methods: str, log: str, batch_size, seed, learning_rate, experiment: Experiment, refined: str = None, extend_curves=False, show_std=True):
    """
    Plot the example-averaged, unrefined curves of all `methods`, averaged over `seed`, and save them as PDF.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`.
        log: Name of the logged scalar to plot. Also used as `same_scale` argument and in the file name.
        batch_size: Data set size(s); shown in the plot title.
        seed: Seed(s) over which mean and standard deviation are computed.
        learning_rate: Learning rate(s), see `_load_param_curves()`.
        experiment: Problem definition.
        refined: Optional refinement method, see `_load_param_curves()`.
        extend_curves: Forwarded to `_pad_curve()` as `extend_steps`.
        show_std: Whether to draw the standard deviation over seeds as error.
    """
    loaded = _load_param_curves(methods, log, batch_size, seed, learning_rate, experiment, refined)
    curve = _pad_curve(loaded, extend_steps=extend_curves)
    # --- Mean curve ---
    unrefined = curve.refined['not_refined']
    mean_curve = math.finite_mean(unrefined, 'example')
    plt.rc('font', family='Arial', weight='normal', size=8)
    seed_mean = math.mean(mean_curve, shape(seed))
    if show_std:
        seed_std = math.std(mean_curve, shape(seed))
    else:
        seed_std = None
    show(seed_mean, err=seed_std, log_dims='steps,loss', overlay='method', same_scale=log,
         title=-f-f"$n$ = {batch_size}", size=(7, 3))
    vis.savefig(f"~/phi/RP/{experiment}/SI_{experiment}_{log}_curves_by_dset.pdf", close=True)


def _pad_curve(curves: Tensor, extend_steps=False):
    """
    Pad all curves along `iteration` to the length of the longest one by repeating their last value.

    Args:
        curves: Curves with an `iteration` dimension whose size may differ between entries.
        extend_steps: If `True`, each curve is split into its two `vector` components (steps, loss).
            The loss is padded and the steps are continued linearly from `step.max + 1` to the maximum
            iteration count, then both are recombined as `vec(steps=..., loss=...)`.
            If `False`, the whole entry is padded as-is.

    Returns:
        The padded curves, stacked along all dimensions except `iteration`, `vector` and `example`.
    """
    longest = int(curves.iteration.size.max)
    outer_dims = curves.shape.without('iteration,vector,example')
    padded_curves = []
    for idx in outer_dims.meshgrid():
        entry = curves[idx]
        if not extend_steps:
            widths = {'iteration': (0, longest - entry.iteration.size)}
            padded_curves.append(math.pad(entry, widths, ZERO_GRADIENT))
            continue
        step, loss = entry.vector
        padded_loss = math.pad(loss, {'iteration': (0, longest - loss.iteration.size)}, ZERO_GRADIENT)
        extra_steps = math.linspace(step.max+1, longest, spatial(iteration=longest - step.iteration.size))
        all_steps = math.concat([step, extra_steps], 'iteration')
        padded = vec(steps=all_steps, loss=padded_loss)
        assert padded.shape.is_uniform
        padded_curves.append(padded)
    return stack(padded_curves, outer_dims)


@math.broadcast
def visualize_optimizations(methods: str, log: str, batch_size, seed, learning_rate, experiment: Experiment, refined: str = None):
    """
    Compare the optimization curves of several methods and save the plots below `~/phi/RP/{experiment}/`.

    Produces the mean curve, individual curves (only for a single configuration), a scatter of final loss
    against each ground truth parameter, and the fraction of examples below a range of loss thresholds.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`. Also part of the output file names.
        log: Name of the logged scalar to evaluate.
        batch_size: Data set size.
        seed: Seed; converted with `int()` to regenerate the data set.
        learning_rate: Learning rate, see `_load_param_curves()`.
        experiment: Problem definition.
        refined: Optional refinement method, see `_load_param_curves()`.
    """
    # Only the ground truth is needed here, so the reference simulation is skipped.
    dset = dataset(batch_size, int(seed), experiment, True, compute_output=False)
    curve = _load_param_curves(methods, log, batch_size, seed, learning_rate, experiment, refined)
    # Value at the last logged iteration.
    final_loss = curve[log].iteration[-1]
    batch_dims = batch(batch_size) & batch(seed) & batch(learning_rate)
    path = f"~/phi/RP/{experiment}/{batch_size}_{methods}_" + ("refined_" if refined else "")
    # --- Mean curve ---
    unrefined = curve.refined['not_refined']
    show(math.finite_mean(unrefined, 'example'), err=math.std(unrefined, 'example'), log_dims='loss', overlay=batch_dims, same_scale=log)
    vis.savefig(path+f"{log} curve mean.jpg", close=True)
    # --- Individual curves ---
    # Drawn only when there is a single configuration; limited to the first 20 examples.
    if batch_dims.volume == 1:
        show(unrefined.example[:20].example.as_channel(), log_dims='loss', same_scale=log)
        vis.savefig(path+f"{log} curve all.jpg", close=True)
    # --- scatter loss vs parameters ---
    histogram, bin_edges, bin_centers = math.histogram(final_loss.example.as_instance(), spatial(**{log: 10}), same_bins=batch)
    scatters = []
    for idx in non_batch(dset.ground_truth).meshgrid(names=True):  # for each parameter
        param_name = index_label(idx)
        scatters.append(vec(**{param_name: dset.ground_truth[idx], log: final_loss}).example.as_instance())
    show(*scatters, PointCloud(bin_centers, histogram.method.as_channel('method_')), same_scale=log, overlay='method', log_dims='_')
    vis.savefig(path+f"{log} final by params.jpg", close=True)
    # --- Cumulative loss by threshold ---
    # 50 thresholds spaced evenly in log-space between 1e-6 and the largest finite final loss.
    threshold = math.exp(math.linspace(math.log(1e-6), math.log(math.finite_max(final_loss, 'example,method,refined')), spatial(threshold=50)))
    win_rate = math.mean(final_loss <= threshold, 'example')
    threshold_curve = vec(threshold=threshold, win_rate=win_rate).method.as_channel().refined.as_instance()
    # NOTE(review): the `:.0e` format requires a numeric `learning_rate`; a dict, as accepted by `_load_param_curves()`, would fail here - confirm.
    show(threshold_curve, log_dims='threshold', color=math.range(threshold_curve.shape['method']), title=f"|D| = {batch_size}, lr={learning_rate:.0e}")
    vis.savefig(path + f'{log} by threshold.jpg', close=True)
    print(f"Optimization plots saved to {path}")


def plot_all_optimizations(methods: str, log: str, batch_size: int, seeds, learning_rate, experiment: Experiment, refined: str = None):
    """
    Scatter the final value of `log` against every ground truth parameter, pooled over all `seeds`, and save the plot as PDF.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`.
        log: Name of the logged scalar to evaluate.
        batch_size: Data set size per seed.
        seeds: Seeds whose data sets and curves are concatenated along `example`.
        learning_rate: Learning rate(s), see `_load_param_curves()`.
        experiment: Problem definition.
        refined: Optional refinement method, see `_load_param_curves()`.
    """
    # Only the ground truth is required, so the reference simulations are skipped.
    per_seed_truth = [dataset(batch_size, seed, experiment, True, compute_output=False).ground_truth for seed in seeds]
    ground_truth = concat(per_seed_truth, 'example')
    curve = _load_param_curves(methods, log, batch_size, seeds, learning_rate, experiment, refined)
    last_value = curve[log].iteration[-1]
    per_seed_loss = math.unstack(last_value, shape(seeds))
    final_loss = math.concat(per_seed_loss, 'example')
    # --- scatter loss vs parameters ---
    histogram, _, bin_centers = math.histogram(final_loss.example.as_instance(), spatial(**{log: 10}), same_bins=batch)
    scatters = [
        vec(**{index_label(idx): ground_truth[idx], log: final_loss}).example.as_instance()
        for idx in non_batch(ground_truth).meshgrid(names=True)  # for each parameter
    ]
    histogram_points = PointCloud(bin_centers, histogram.method.as_channel('method_'))
    show(*scatters, histogram_points, same_scale=log, overlay='method', log_dims='_')
    vis.savefig(f"~/phi/RP/{experiment}/SI_{experiment}_{batch_size}_{log} final by params.pdf", close=True)


def direct_compare(methods: str, batch_size, seed, learning_rate, experiment: Experiment, refined: str = None, log='loss'):
    """
    Print how often each method reaches a lower or equal best value of `log` than the first method.

    The first entry of `methods` acts as reference. Fractions are printed averaged over everything
    and separately for every data set size.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`.
        batch_size: Data set size(s).
        seed: Seed(s).
        learning_rate: Learning rate(s), see `_load_param_curves()`.
        experiment: Problem definition.
        refined: Optional refinement method, see `_load_param_curves()`.
        log: Name of the logged scalar to compare.
    """
    curve = _load_param_curves(methods, log, batch_size, seed, learning_rate, experiment, refined)
    best_loss = math.finite_min(curve[log], 'iteration')
    ref_best = best_loss.method[0]
    method_names = best_loss.method.item_names
    ref_name = method_names[0]
    for m in method_names[1:]:
        m_best = best_loss.method[m]
        better = m_best < ref_best
        equal = m_best == ref_best
        print(f"Averaged over {batch_size}")
        all_dims = batch(seed) & batch('example') & batch(batch_size)
        better_fraction = math.mean(better, all_dims)
        print(f"{m} yields lower {log} than {ref_name} fraction: {better_fraction}")
        equal_fraction = math.mean(equal, batch(seed) & batch('example') & batch(batch_size))
        print(f"{m} yields equal {log} as {ref_name} fraction: {equal_fraction}")
        for size_idx in shape(batch_size).meshgrid(names=True):
            print(size_idx)
            better_at_size = math.mean(better[size_idx], batch(seed) & batch('example'))
            print(f"{m} yields lower {log} than {ref_name} fraction: {better_at_size}")
            equal_at_size = math.mean(equal[size_idx], batch(seed) & batch('example'))
            print(f"{m} yields equal {log} as {ref_name} fraction: {equal_at_size}")


# Title fragments along the batch dimension `refined`, used by `plot_by_dset_size()` to build one title per refinement state.
WITH_OR_WITHOUT_REFINEMENT = wrap(['without', 'with'], batch(refined='not_refined,refined'))


def plot_by_dset_size(methods: str, batch_size, seed, learning_rate, experiment: Experiment, refined: str = None, log='loss', threshold=1e-3):
    """
    Plot best loss, win rate and improvement over the first method as a function of the data set size.

    Figures are saved below `~/phi/RP/{experiment}/`.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`. Also part of the output file names.
        batch_size: Data set size(s).
        seed: Seed(s) over which mean and standard deviation are computed.
        learning_rate: Learning rate(s), see `_load_param_curves()`.
        experiment: Problem definition.
        refined: Optional refinement method, see `_load_param_curves()`.
        log: Name of the logged scalar to evaluate.
        threshold: An example counts as a win if its best loss is at most this value.
    """
    curve = _load_param_curves(methods, log, batch_size, seed, learning_rate, experiment, refined)
    # Best (smallest finite) value reached at any iteration.
    best_loss = math.finite_min(curve[log], 'iteration')
    path = f"~/phi/RP/{experiment}/{methods}_" + ("refined_" if refined else "")
    if refined:
        # Report the mean number of logged iterations of the refined runs per method.
        refined_curve = curve.refined['refined']
        mean_it_count = math.mean(refined_curve.iteration.size, shape(seed))
        print("Refinement Iterations:")
        for method, mitc in zip(refined_curve.method.item_names, mean_it_count.method):
            print(method)
            print(mitc)
    # --- Plot loss ---
    # NOTE(review): `refined` is indexed even when no refinement was requested - confirm this is a no-op for curves without that dimension.
    refined_best_loss = best_loss.refined['refined']
    mean_final_loss = math.mean(refined_best_loss, 'example,lr').method.as_channel().dataset_size.as_instance()
    plot(math.mean(mean_final_loss, shape(seed)), err=math.std(mean_final_loss, shape(seed)), title=f"Loss {'with ' + refined.upper() if refined else 'without'} refinement, tol={threshold:.0e}, {shape(seed).volume} seeds")
    plt.ylim((0, mean_final_loss.max * 1.1))
    show()
    vis.savefig(path + f'{log} by dataset.jpg', close=True)
    # --- Plot win rate ---
    # Fraction of examples whose best loss is at most `threshold`.
    win_rate = math.mean(refined_best_loss <= threshold, 'example,lr')
    win_rate = math.b2i(win_rate.method.as_channel())
    mean_win_rate = math.mean(win_rate, shape(seed))
    std_win_rate = math.std(win_rate, shape(seed))
    math.print(mean_win_rate, f"mean win-rate")
    plot(mean_win_rate, err=std_win_rate, title=f"Win rate {'with ' + refined.upper() if refined else 'without'} refinement, tol={threshold:.0e}, {shape(seed).volume} seeds")
    plt.ylim((0, 1))
    show()
    vis.savefig(path + f'{log} win-rate by dataset.jpg', close=True)
    # --- Plot improvement over BFGS ---
    # The reference is the first method in `methods`; the section title suggests this is expected to be BFGS.
    mean_final_loss = math.mean(best_loss, 'example,lr').method.as_channel().dataset_size.as_instance()
    improvement_rate = math.mean((best_loss < best_loss.method[0]), 'example,lr').method.as_channel().dataset_size.as_instance()
    plt.rc('font', family='Arial', weight='normal', size=8)
    plot((math.mean(mean_final_loss, shape(seed)), math.mean(improvement_rate, shape(seed))),
         err=(math.std(mean_final_loss, shape(seed)), math.std(improvement_rate, shape(seed))),
         title=(-f-f"Loss {WITH_OR_WITHOUT_REFINEMENT} refinement", -f-f"Improvement {WITH_OR_WITHOUT_REFINEMENT} refinement over BFGS"),
         row_dims='refined,tuple', size=(5.5, 6), same_scale=False)
    # plt.ylim((0, mean_final_loss.max * 1.1))
    # Only the legend of the first axes is removed.
    plt.gcf().axes[0].get_legend().remove()
    show()
    # The same figure is saved twice (JPG and PDF); it is closed only after the second save.
    vis.savefig(path + f"loss_and_improvement_by_dset.jpg")
    vis.savefig(f"~/phi/RP/{experiment}/SI_{experiment}_loss_and_improvement_by_dset.pdf", close=True)
    print(f"Dataset plots saved to {path}")


def paper_plot(plot_types: str,
               batch_sizes,
               seeds,
               learning_rate,
               experiment: Experiment,
               example=0,
               example_seed=0,
               example_batch_size=64,
               curves_batch_size=128,
               size=(5.5, 2.5),
               names={},
               param_range=None,
               methods='BFGS,net,sup,surrogate'):
    """
    Assemble a multi-panel figure from a comma-separated list of plot types and draw it with `plot()`.

    Supported plot types:

    * `landscape:<param>`: Loss and (possibly rescaled) gradient scanned along one parameter of a single example.
    * `loss-curves...`: Example-averaged loss curves of all `methods` for `curves_batch_size`.
    * `loss-by-n`, `refined-loss-by-n`: Best loss by data set size, the latter with `'bfgs'` refinement.
    * Anything else is delegated to `experiment.get_plot()`.

    The figure is not saved by this function.

    Args:
        plot_types: Comma-separated plot type identifiers; surrounding whitespace is stripped.
        batch_sizes: Data set sizes for the `*loss-by-n` plots.
        seeds: Seeds over which curves are averaged.
        learning_rate: Learning rate(s), see `_load_param_curves()`.
        experiment: Problem definition.
        example: Index of the example used for landscapes and experiment-specific plots.
        example_seed: Seed of the data set the example is taken from.
        example_batch_size: Size of the data set the example is taken from.
        curves_batch_size: Data set size for the `loss-curves` plot.
        size: Figure size passed to `plot()`.
        names: Display names for parameters, keyed by parameter name. The default dict is only read, never modified.
        param_range: Optional `(min, max)` pair replacing the scan range derived from the data set.
        methods: Comma-separated method identifiers, see `_load_param_curves()`.
    """
    dset = dataset(example_batch_size, example_seed, experiment, True)
    selection = {'example': example}
    initial_state = math.slice(dset.initial_state, selection)
    full_initial_state = math.slice(dset.full_initial_state, selection)
    ground_truth = dset.ground_truth[selection]
    ref_output = math.slice(dset.output, selection)
    # Each entry is a dict with the plotted object under 'obj' and optional extra plot arguments.
    subplots = []
    for plot_type in [p.strip() for p in plot_types.split(',')]:
        if plot_type.startswith('landscape:'):
            p_name = plot_type[len('landscape:'):].strip()
            p_idx = {'vector': p_name}
            # From here on `p_name` is the display name; `p_idx` keeps the original parameter name.
            p_name = names.get(p_name, p_name)
            gt_min = math.min(dset.ground_truth, 'example')
            gt_max = math.max(dset.ground_truth, 'example')
            if param_range:
                gt_min, gt_max = [wrap(t) for t in param_range]
            print(f"Plotting loss landscape for parameter {p_name} from {gt_min[p_idx]} to {gt_max[p_idx]}")
            def loss_function(param):
                # Replace only the scanned parameter, keep all others at their ground truth values.
                guess = math.scatter(expand(ground_truth, shape(param)), p_idx, param)
                output = experiment.forward_process(initial_state, guess)
                return experiment.loss_function(ref_output, output)
            loss_scan = CenteredGrid(loss_function, bounds=Box(**{p_name: (gt_min[p_idx], gt_max[p_idx])}), **{p_name: 200})
            gradient_function = math.functional_gradient(loss_function, 'param', get_output=False)
            grad_scan = loss_scan.with_values(lambda p: gradient_function(expand(p, batch(loss_scan))))
            # Ratio of the largest absolute gradient to the largest loss, used to bring both onto a similar scale.
            norm = abs(grad_scan.values).max / loss_scan.values.max
            # Candidate divisors 2, 5, 10, 20, 50, ..., 50000 (the leading 1 is dropped).
            allowable_norms = sum([(i, 2*i, 5*i) for i in [1, 10, 100, 1000, 10000]], ())[1:]
            try:
                # Largest candidate not exceeding the ratio; `max()` raises ValueError if there is none.
                norm = max([n for n in allowable_norms if n <= norm])
                subplots.append(dict(obj=stack([loss_scan, grad_scan / norm], channel(c=f'\u2112,∇\u2112 / {norm}'))))
            except ValueError:
                # NOTE(review): this also catches a ValueError raised while building the sub-plot, not only the empty `max()` - confirm intended.
                subplots.append(dict(obj=stack([loss_scan, grad_scan], channel(c='\u2112,∇\u2112'))))
        elif plot_type.startswith('loss-curves'):
            # NOTE(review): curves are loaded for `example_seed` but averaged over `batch(seeds)` below - confirm which seeds are intended.
            curve = _load_param_curves(methods, 'loss', curves_batch_size, example_seed, learning_rate, experiment, refined=None)
            max_iter = curve.shape.get_size('iteration').max
            mcurve = math.finite_mean(curve, 'example').method.as_channel()
            # Append one point at `max_iter` carrying the last loss value.
            mcurve = concat([mcurve, vec(steps=max_iter, loss=mcurve.vector['loss'].iteration[-1:])], 'iteration')
            subplots.append(dict(obj=math.finite_mean(mcurve, batch(seeds))))  # , err=math.std(mcurve, batch(seeds))
        elif plot_type in ['refined-loss-by-n', 'loss-by-n']:
            refined = 'bfgs' if 'refined' in plot_type else None
            curve = _load_param_curves(methods, 'loss', batch_sizes, seeds, learning_rate, experiment, refined=refined)
            if refined:
                curve = curve.refined['refined']
            best_loss = math.finite_min(curve['loss'], 'iteration')
            mean_final_loss = math.b2i(math.finite_mean(best_loss, 'example,lr').method.as_channel())
            # plt.ylim((0, mean_final_loss.max * 1.1))
            subplots.append(dict(obj=math.finite_mean(mean_final_loss, shape(seeds)), err=math.std(mean_final_loss, shape(seeds))))
        else:
            subplots.append(experiment.get_plot(plot_type, full_initial_state, ground_truth, ref_output))

    plt.rc('font', family='Arial', weight='normal', size=8)
    # Collect every extra plot argument used by any sub-plot; sub-plots lacking a key contribute `None`.
    all_keys = set().union(*subplots)
    objs = [s['obj'] for s in subplots]
    plt_args = {k: [subplots[i].get(k, None) for i in range(len(subplots))] for k in all_keys if k != 'obj'}
    # Panels are titled (a), (b), (c), ...
    plot(objs, size=size, same_scale='', log_dims='loss,steps', title=[f"({chr(97+i)})" for i in range(len(objs))], **plt_args)


@math.broadcast
def visualize_param_trj(methods: str, batch_size, seed, learning_rate, experiment: Experiment, selection={'example': slice(0, 4)}):
    """
    Plot the logged parameter trajectories of `methods` and their L2 distance to the ground truth.

    Both figures are saved below `~/phi/RP/{experiment}/`.

    Args:
        methods: Comma-separated method identifiers, see `_load_param_curves()`.
        batch_size: Data set size.
        seed: Seed; converted with `int()` to regenerate the data set.
        learning_rate: Learning rate, see `_load_param_curves()`.
        experiment: Problem definition.
        selection: Slice of examples to plot. The default dict is only read, never modified.
    """
    # Only the ground truth is needed, so the reference simulation is skipped.
    dset = dataset(batch_size, int(seed), experiment, True, compute_output=False)
    params = vec(batch('variables'), dset.ground_truth.vector.item_names)
    curve = _load_param_curves(methods, params, batch_size, seed, learning_rate, experiment, refined=None)[selection].method.as_channel()
    curve = rename_dims(curve, 'vector', channel(vector='iterations,value'))
    # Shift the iteration component by one; the iteration axis is drawn logarithmically below.
    curve += vec(iterations=1, value=0)
    gt_param = dset.ground_truth[selection].vector.as_batch('variables')
    # --- Plot parameter trajectories ---
    # The ground truth is marked as a single point at iteration 10.
    gt_point = vec(iterations=10, value=gt_param)
    show(curve, gt_point, overlay='args', log_dims='iterations', row_dims='variables', title=-f-f"{params}, example {math.range(curve.shape['example'])}", size=(14, 10))
    trj_path = f"~/phi/RP/{experiment}/{batch_size}_param_trj_seed{seed}.jpg"
    vis.savefig(trj_path, close=True)
    # --- Plot L2 norm to GT ---
    gt_norm = math.sqrt(math.sum((curve.vector['value'] - gt_param) ** 2, 'variables'))
    gt_curve = vec(iterations=curve.vector['iterations'], ground_truth_norm=gt_norm)
    show(gt_curve, same_scale='ground_truth_norm', log_dims='iterations', title=-f-f"Example {math.range(curve.shape['example'])}")
    vis.savefig(f"~/phi/RP/{experiment}/{batch_size}_gt_norm_seed{seed}.jpg", close=True)
    # Report the file that was actually written, including the seed suffix.
    print(f"Parameter trajectories saved to {trj_path}")


def plot_parameter_trajectories(methods: str, batch_size, seed, learning_rate, experiment: Experiment, selection={'example': slice(0, 4)}):
    """
    Create the parameter-trajectory figure and save it as `~/phi/RP/{experiment}/SI_{experiment}_param_trj.pdf`.

    One row is plotted per solution variable, showing the logged parameter values over the iterations
    together with the ground-truth value of the selected examples.

    Args:
        methods: Method name(s) forwarded to `_load_param_curves`.
        batch_size: Data set size, used to build the data set and to locate the logged curves.
        seed: Data set seed. Converted to `int` for `dataset()`.
        learning_rate: Learning rate forwarded to `_load_param_curves`.
        experiment: Experiment providing the data set; its string representation is part of the output path.
        selection: Slicing dict applied to both the loaded curves and the ground truth.
    """
    data = dataset(batch_size, int(seed), experiment, True, compute_output=False)
    variable_names = vec(batch('variables'), data.ground_truth.vector.item_names)
    # --- Load the logged curves and bring them into (iterations, value) form ---
    loaded = _load_param_curves(methods, variable_names, batch_size, seed, learning_rate, experiment, refined=None)
    trajectories = loaded[selection].method.as_channel()
    trajectories = rename_dims(trajectories, 'vector', channel(vector='iterations,value'))
    trajectories += vec(iterations=1, value=0)
    # --- Ground truth, drawn as a point at iterations=10 ---
    truth_values = data.ground_truth[selection].vector.as_batch('variables')
    truth_point = vec(iterations=10, value=truth_values)
    # --- Plot parameter trajectories ---
    plt.rc('font', family='Arial', weight='normal', size=8)
    figure_size = (9, 1 + data.ground_truth.vector.size)
    plot(trajectories, truth_point, overlay='args', log_dims='iterations', row_dims='variables',
         title=-f - f"{variable_names}, example {math.range(trajectories.shape['example'])}", size=figure_size)
    first_axes = plt.gcf().axes[0]
    first_axes.get_legend().remove()
    show()
    vis.savefig(f"~/phi/RP/{experiment}/SI_{experiment}_param_trj.pdf", close=True)


@math.broadcast
def visualize_results(batch_size, seed, learning_rate, experiment: Experiment, bfgs=True, supervised=False, nets=None, selection={'example': slice(0, 3)}):
    """
    Plot the inverse-network predictions next to the ground truth and, optionally, the stored BFGS solutions.

    For every scene returned by `_current_scene('net', ...)`, the network state is loaded from the scene's `net` sub-path,
    the network is evaluated on the selected examples and `experiment.plot_process` is called with all guesses stacked along `overlay`.

    Args:
        batch_size: Data set size, used to build the data set and to locate the scenes.
        seed: Data set seed. Converted to `int` for `dataset()`.
        learning_rate: Learning rate used to locate the network scenes.
        experiment: Experiment providing the network, the data set and the plotting routine.
        bfgs: Whether to load `bfgs_sol.txt` from the BFGS scene and include it in the comparison.
        supervised: Not used by this function.
        nets: Neural network library module; used to create the network and to load its state.
        selection: Slicing dict choosing the examples to visualize.
    """
    data = dataset(batch_size, int(seed), experiment, True)
    initial_state = math.slice(data.initial_state, selection)
    ground_truth = math.slice(data.ground_truth, selection)
    observed_data = data.observed[selection]
    scenes: Scene = _current_scene('net', batch_size, seed, learning_rate, experiment=experiment)
    if bfgs:
        bfgs_scene = _current_scene('bfgs', batch_size, seed, experiment=experiment)
        raw_solution = np.loadtxt(bfgs_scene.subpath("bfgs_sol.txt"))
        if raw_solution.ndim == 2:
            bfgs_sol = wrap(raw_solution, batch('example'), channel(ground_truth))
        else:  # file holds one value per example -> add the channel dims of the ground truth
            bfgs_sol = expand(wrap(raw_solution, batch('example')), channel(ground_truth))
        bfgs_sol = bfgs_sol[selection]
    net = experiment.create_inverse_net(nets)
    for result_number, idx in enumerate(scenes.shape.meshgrid(names=True), start=1):
        nets.load_state(net, scenes[idx].subpath('net'))
        prediction = experiment.run_inverse_net(net, initial_state, observed_data, ground_truth.shape)
        labelled_guesses = [("Ground Truth", ground_truth), ("Net", prediction)]
        if bfgs:
            labelled_guesses.append(("BFGS", bfgs_sol))
        overlay = stack(dict(labelled_guesses), channel('overlay'))
        experiment.plot_process(f"~/phi/RP/{experiment}/{batch_size}_net_vs_gt_", initial_state, overlay)
        print(f"Saved result {result_number} of {scenes.shape.volume} to ~/phi/RP/{experiment}/{batch_size}_net_vs_gt_...")


@math.broadcast
def _current_scene(training_type: str, batch_size, seed, learning_rate=None, experiment: Experiment = None) -> Scene:
    """
    Look up the scene belonging to a training / optimization run via `Scene.at(directory, -1)`.

    Direct methods (`'bfgs'`, `'gd'`, `'guessing'`, case-insensitive) are stored without a learning rate
    under `~/phi/RP/{experiment}/{batch_size}_{method}_seed{seed}`, where `method` is the lowercased `training_type`.
    This matches the directories created by `optimize_directly` and `random_guessing`.
    All other training types are stored under `~/phi/RP/{experiment}/{batch_size}_{training_type}_lr{learning_rate:.0e}_seed{seed}`.

    Args:
        training_type: Name of the method or training type.
        batch_size: Data set size encoded in the directory name.
        seed: Seed encoded in the directory name.
        learning_rate: Learning rate encoded in the directory name. Required unless `training_type` is a direct method.
        experiment: Experiment whose string representation names the parent directory.

    Returns:
        The `Scene` returned by `Scene.at(directory, -1)`.
    """
    method = training_type.lower()
    if method in ('bfgs', 'gd', 'guessing'):
        # Use the method's own directory. Previously every direct method resolved to the `_bfgs_` directory.
        return Scene.at(f"~/phi/RP/{experiment}/{batch_size}_{method}_seed{seed}", -1)
    assert learning_rate is not None, f"learning_rate is required for training type '{training_type}'"
    return Scene.at(f"~/phi/RP/{experiment}/{batch_size}_{training_type}_lr{learning_rate:.0e}_seed{seed}", -1)


@math.broadcast
def optimize_directly(method: str, batch_size, seed, experiment: Experiment, selection={'example': slice(0, 3)}):
    """
    Minimize the true loss of the test problems directly with a classical optimizer, starting from an all-zero guess.

    A new scene `~/phi/RP/{experiment}/{batch_size}_{method.lower()}_seed{seed}` is created.
    The loss and parameter values of every recorded iteration are logged as scalars,
    the final solutions are written to `{method.lower()}_sol.txt`
    and the loss / parameter plots are saved to `plots/{method}_loss.jpg` and `plots/{method}_params.jpg` inside the scene.

    If the optimizer raises `Diverged`, the last entry of the diverged trajectory is used as the solution.

    Args:
        method: `'BFGS'` or `'GD'`
        batch_size: Data set size.
        seed: Data set seed.
        experiment: Experiment providing the forward process and the loss function.
        selection: Slicing dict choosing the examples shown in the parameter-evolution plot.
    """
    dset = dataset(batch_size, seed, experiment, True)
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_{method.lower()}_seed{seed}")
    # --- Optimization ---
    print(f"{method}... {scene}")

    def true_objective_function(guess):
        output = experiment.forward_process(dset.initial_state, guess)
        return experiment.loss_function(dset.output, output)

    t = time.perf_counter()
    try:
        solver_config = Solve(method, 0, 1e-5, x0=dset.ground_truth * 0, max_iterations=100)
        with math.SolveTape(solver_config, record_trajectories=True) as solves:
            opt_sol = minimize(true_objective_function, solver_config)
        print(f"Finished in {solves[0].iterations} iterations -> {opt_sol}")
    except Diverged as div:
        print(f"Diverged in {div.result.iterations} iterations with error {div.result.msg}")
        opt_sol = div.result.x.trajectory[-1]
        print(div.result.x)
    print(f"{method} optimization took {time.perf_counter() - t:.2f} seconds")
    opt_loss = true_objective_function(opt_sol)
    # Label the output with the optimizer actually used instead of a hard-coded "BFGS".
    print(f"{method} Solution: {opt_sol}")
    print("Final loss", opt_loss)
    opt_trj = solves[0].x
    true_loss_trj = true_objective_function(opt_trj)  # should be equal to solves[0].residual
    # --- Save curves ---
    logger = SceneLog(scene)
    for i, (l, p) in enumerate(zip(true_loss_trj.trajectory, opt_trj.trajectory)):
        logger.log_scalars(i, reduce=None, loss=l, **p.vector)
    # --- Save solutions ---
    sol_path = scene.subpath(f"{method.lower()}_sol.txt")
    np.savetxt(sol_path, opt_sol.numpy('example,vector'))
    print(f"Solutions saved to {sol_path}")
    # --- Plot loss ---
    residual = true_loss_trj.trajectory.as_spatial()
    show(math.mean(residual, 'example'), log_dims='_', err=math.std(residual, 'example'), title=f"{method} on {experiment} |D|={batch_size}")
    vis.savefig(scene.subpath(f'plots/{method}_loss.jpg', create_parent=True))
    # --- Plot parameter evolution ---
    param_curves = CenteredGrid(opt_trj[selection].trajectory.as_spatial().vector.as_channel('variables'))
    color = math.range(channel(variables=opt_trj.vector.size))
    show(param_curves, vec(trajectory=0, value=dset.ground_truth[selection].vector.as_channel('variables')), overlay='args', color=color)
    vis.savefig(scene.subpath(f"plots/{method}_params.jpg", create_parent=True))


@math.broadcast
def random_guessing(batch_size, seed, experiment: Experiment, count=1, selection={'example': slice(0, 3)}):
    """
    Evaluate randomly generated solutions on the test problems and log their loss.

    A new scene `~/phi/RP/{experiment}/{batch_size}_guessing_seed{seed}` is created.
    For each of the `count` rounds, a fresh solution is drawn via `experiment.generate_problem(..., test=False)`,
    run through the forward process and its loss and parameter values are logged as scalars.

    Args:
        batch_size: Data set size.
        seed: Data set seed.
        experiment: Experiment providing problem generation, forward process and loss function.
        count: Number of random guesses to evaluate per example.
        selection: Not used by this function.
    """
    data = dataset(batch_size, seed, experiment, True)
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_guessing_seed{seed}")
    # --- Save curves ---
    scalar_log = SceneLog(scene)
    for guess_index in range(count):
        # Only the generated solution is needed; the generated initial states are discarded.
        _full_state, _masked_state, random_solution = experiment.generate_problem(batch(data.ground_truth), test=False)
        simulated = experiment.forward_process(data.initial_state, random_solution)
        guess_loss = experiment.loss_function(data.output, simulated)
        scalar_log.log_scalars(guess_index, reduce=None, loss=guess_loss, **random_solution.vector)
    print(f"Guesses saved to {scene}")


@math.broadcast
def neural_adjoint(batch_size, seed, learning_rate, optimizer_name, experiment: Experiment, nets=None, scene_id=-1, use_boundary_loss=True, selection={'example': slice(0, 3)}):
    """
    Run the neural-adjoint method: optimize the solution with BFGS through a trained surrogate (forward) network.

    The surrogate network is loaded from the scene `~/phi/RP/{experiment}/{batch_size}_surrogate_lr{learning_rate:.0e}_seed{seed}`
    selected by `scene_id`. The optimization starts from an all-zero guess.
    The surrogate loss, the true loss and the parameters of every recorded iteration are logged to that scene,
    the final solutions are written to `neural_adjoint_sol_{boundary|pure}.txt` and plots are saved under `plots/`.

    If the optimizer raises `Diverged`, the last entry of the diverged trajectory is used as the solution.

    Args:
        batch_size: Data set size.
        seed: Data set seed.
        learning_rate: Learning rate the surrogate network was trained with; part of the scene directory name.
        optimizer_name: Not used by this function.
        experiment: Experiment providing the networks, the forward process and the loss function.
        nets: Neural network library module; used to create the network and to load its state.
        scene_id: Index into the list of matching surrogate scenes.
        use_boundary_loss: Whether to add a soft-plus penalty for guesses outside the value range of the training solutions.
        selection: Slicing dict choosing the examples shown in the parameter-evolution plot.
    """
    test_data = dataset(batch_size, seed, experiment, True)
    surrogate_scenes = Scene.list(f"~/phi/RP/{experiment}/{batch_size}_surrogate_lr{learning_rate:.0e}_seed{seed}")
    scene = surrogate_scenes[scene_id]
    variant = 'boundary' if use_boundary_loss else 'pure'  # tag used in all output file names
    # --- Get training data for range of solution values ---
    train_data = dataset(batch_size, scene.properties['training_seed'], experiment, False)
    lower_bound = math.min(train_data.ground_truth, non_channel)
    upper_bound = math.max(train_data.ground_truth, non_channel)
    # --- Load network ---
    net = experiment.create_forward_net(nets)
    nets.load_state(net, scene.subpath('net'))

    def objective_function(guess):
        """Surrogate loss: network prediction vs. observed output, plus the optional boundary penalty."""
        relative_excess = math.maximum(guess - upper_bound, lower_bound - guess) / (upper_bound - lower_bound)
        boundary_loss = math.sum(math.soft_plus(relative_excess * 4) / 4, channel)
        predicted = experiment.run_forward_net(net, test_data.initial_state, guess)
        return experiment.loss_function(test_data.output, predicted) + boundary_loss * use_boundary_loss

    def true_objective_function(guess):
        """Loss obtained with the real forward process instead of the surrogate network."""
        simulated = experiment.forward_process(test_data.initial_state, guess)
        return experiment.loss_function(test_data.output, simulated)

    # --- Neural Adjoint ---
    print(f"Neural adjoint {'with' if use_boundary_loss else 'without'} boundary loss... {scene}")
    try:
        bfgs_config = Solve('BFGS', 0, 1e-5, x0=test_data.ground_truth * 0, max_iterations=100)
        with math.SolveTape(bfgs_config, record_trajectories=True) as solves:
            solution = minimize(objective_function, bfgs_config)
        print(f"Finished in {solves[0].iterations} iterations -> {solution}")
    except Diverged as div:
        print(f"Diverged in {div.result.iterations} iterations with error {div.result.msg}")
        solution = div.result.x.trajectory[-1]
        print(div.result.x)
    final_surrogate_loss = objective_function(solution)
    print("Final surrogate loss", final_surrogate_loss)
    solution_trj = solves[0].x
    true_losses = true_objective_function(solution_trj)
    print("Final true loss", true_losses.trajectory[-1])
    # --- Save curves ---
    scalar_log = SceneLog(scene)
    per_iteration = zip(solves[0].residual.trajectory, true_losses.trajectory, solution_trj.trajectory)
    for step, (surrogate_value, true_value, parameters) in enumerate(per_iteration):
        scalar_log.log_scalars(step, reduce=None, neural_adjoint_surrogate_loss=surrogate_value, neural_adjoint_true_loss=true_value, loss=true_value, **parameters.vector)
    # --- Save solutions ---
    sol_path = scene.subpath(f"neural_adjoint_sol_{variant}.txt")
    np.savetxt(sol_path, solution.numpy('example,vector'))
    print(f"Neural adjoint solutions saved to {sol_path}")
    # --- Plot loss ---
    surrogate_curve = solves[0].residual.trajectory.as_spatial("BFGS Iterations")
    true_curve = true_losses.trajectory.as_spatial("BFGS Iterations")
    losses = stack({"True": true_curve, "Surrogate": surrogate_curve}, channel('c'))
    show(math.maximum(1e-5, math.mean(losses, 'example')), log_dims='_', err=math.std(losses, 'example'), title=f"Neural adjoint on {experiment} |D|={batch_size} (BFGS)")
    vis.savefig(scene.subpath(f"plots/neural_adjoint_{variant}_bfgs_loss.jpg", create_parent=True))
    # --- Plot parameter evolution ---
    param_curves = CenteredGrid(solution_trj[selection].trajectory.as_spatial().vector.as_channel('variables'))
    color = math.range(channel(variables=solution_trj.vector.size))
    truth_marker = vec(trajectory=100, value=test_data.ground_truth[selection].vector.as_channel('variables'))
    show(param_curves, truth_marker, overlay='args', color=color)
    vis.savefig(scene.subpath(f"plots/neural_adjoint_{variant}_bfgs_params.jpg", create_parent=True))


def get_time_to_train(training_type: str, batch_size, seed, learning_rate, experiment: Experiment, selection={'example': slice(0, 3)}):
    """
    Sum up the logged `step_time` scalars of a run.

    Args:
        training_type: Training type or method name used to locate the scene.
        batch_size: Data set size used to locate the scene.
        seed: Seed used to locate the scene.
        learning_rate: Learning rate used to locate the scene.
        experiment: Experiment used to locate the scene.
        selection: Not used by this function.

    Returns:
        The `step_time` values summed over the `iteration` dimension.
    """
    scene_paths = _current_scene(training_type, batch_size, seed, learning_rate, experiment=experiment).paths
    step_times = vis.load_scalars(scene_paths, 'step_time', batch_dim=batch('example'), x=None)
    total_time = math.sum(step_times, 'iteration')
    return total_time


@math.broadcast
def delete_last_refined(training_type: str, method: str, batch_size, seed, learning_rate, experiment: Experiment, selection={'example': slice(0, 3)}):
    """
    Rename the last refinement scene of a run to `last_iteration`.

    The scenes are listed from `~/phi/RP/{experiment}/{batch_size}_{training_type}+{method.lower()}_lr{learning_rate:.0e}_seed{seed}`.
    NOTE(review): despite the function name, the only operation performed here is `Scene.rename` - confirm that this is intended.

    Args:
        training_type: Training type of the run that was refined.
        method: Refinement method; lowercased for the directory name.
        batch_size: Data set size encoded in the directory name.
        seed: Seed encoded in the directory name.
        learning_rate: Learning rate encoded in the directory name.
        experiment: Experiment whose string representation names the parent directory.
        selection: Not used by this function.
    """
    refined_dir = f"~/phi/RP/{experiment}/{batch_size}_{training_type}+{method.lower()}_lr{learning_rate:.0e}_seed{seed}"
    refined_scenes = Scene.list(refined_dir)
    latest_scene = refined_scenes[-1]
    latest_scene.rename('last_iteration')


@math.broadcast
def refine_directly(training_type: str, method: str, batch_size, seed, learning_rate, experiment: Experiment, selection={'example': slice(0, 3)}):
    """
    Refine the solutions of a previous run by minimizing the true loss with a classical optimizer.

    The initial guess is taken from the scalars logged by the previous run (see `_current_scene`):
    the parameters at the iteration with the best finite loss or, for `training_type == 'sup'`, the parameters of the last iteration.
    A new scene `~/phi/RP/{experiment}/{batch_size}_{training_type}+{method.lower()}_lr{learning_rate:.0e}_seed{seed}` is created.
    The loss and parameter values of every recorded iteration are logged there,
    the final solutions are written to `{method}_sol.txt` and plots are saved under `plots/`.

    If the optimizer raises `Diverged`, the last entry of the diverged trajectory is used as the solution.

    Args:
        training_type: `'net'` or `'sup'` or `'surrogate'`
        method: `'BFGS'` or `'GD'`
        batch_size: Data set size.
        seed: Data set seed.
        learning_rate: Learning rate of the previous run; part of both scene directory names.
        experiment: Experiment providing the forward process and the loss function.
        selection: Slicing dict choosing the examples shown in the parameter-evolution plot.
    """
    dset = dataset(batch_size, seed, experiment, True)
    scene = _current_scene(training_type, batch_size, seed, learning_rate, experiment=experiment)
    # --- Load previous solutions ---
    param_names = wrap(dset.ground_truth.vector.item_names, dset.ground_truth.shape['vector'])
    curve = vis.load_scalars(scene.paths, param_names, batch_dim=batch('example'), x=None)
    loss_curve = vis.load_scalars(scene.paths, 'loss', batch_dim=batch('example'), x=None)
    best_loss = math.finite_min(loss_curve, 'iteration')
    # NOTE(review): presumably picks the last index at which the loss equals the best finite loss;
    # the `is_nan(best_loss)` term makes examples without any finite loss match as well - confirm.
    best_loss_idx = math.nonzero((loss_curve == best_loss) | math.is_nan(best_loss)).nonzero[-1]
    initial_guess = curve[best_loss_idx]
    if training_type == 'sup':
        # For supervised runs, start from the last logged iteration instead of the best-loss one.
        initial_guess = curve.iteration[-1]
    # --- Optimization ---
    print(f"Refine with {method}... {scene}")

    def true_objective_function(guess):
        output = experiment.forward_process(dset.initial_state, guess)
        return experiment.loss_function(dset.output, output)

    # best_primary_loss = true_objective_function(initial_guess)
    # final_primary_loss = true_objective_function(curve.iteration[-1])
    # assert (best_primary_loss <= final_primary_loss).all, f"Selection of best minimum failed for some examples."
    # --- Create new scene ---
    # From here on, `scene` refers to the newly created refinement scene, not the one the guess was loaded from.
    scene = Scene.create(f"~/phi/RP/{experiment}/{batch_size}_{training_type}+{method.lower()}_lr{learning_rate:.0e}_seed{seed}")
    try:
        solver_config = Solve(method, 0, 1e-5, x0=initial_guess, max_iterations=100)
        with math.SolveTape(solver_config, record_trajectories=True) as solves:
            opt_sol = minimize(true_objective_function, solver_config)
        print(f"Finished in {solves[0].iterations} iterations -> {opt_sol}")
    except Diverged as div:
        print(f"Diverged in {div.result.iterations} iterations with error {div.result.msg}")
        opt_sol = div.result.x.trajectory[-1]
        print(div.result.x)
    opt_loss = true_objective_function(opt_sol)
    # Label the output with the optimizer actually used instead of a hard-coded "BFGS".
    print(f"{method} Solution: {opt_sol}")
    print("Final loss", opt_loss)
    opt_trj = solves[0].x
    true_loss_trj = true_objective_function(opt_trj)  # should be equal to solves[0].residual
    # --- Save curves ---
    logger = SceneLog(scene)
    for i, (l, p) in enumerate(zip(true_loss_trj.trajectory, opt_trj.trajectory)):
        logger.log_scalars(i, reduce=None, loss=l, **p.vector)
    # --- Save solutions ---
    # NOTE(review): unlike `optimize_directly`, the file name uses `method` as passed (not lowercased);
    # left unchanged so that existing readers of this file keep working.
    sol_path = scene.subpath(f"{method}_sol.txt")
    np.savetxt(sol_path, opt_sol.numpy('example,vector'))
    print(f"Solutions saved to {sol_path}")
    # --- Plot loss ---
    residual = true_loss_trj.trajectory.as_spatial()
    show(math.finite_mean(residual, 'example'), log_dims='_', err=math.std(residual, 'example'), title=f"{method} on {experiment} |D|={batch_size}")
    vis.savefig(scene.subpath(f'plots/{method}_loss.jpg', create_parent=True))
    # --- Plot parameter evolution ---
    param_curves = CenteredGrid(opt_trj[selection].trajectory.as_spatial().vector.as_channel('variables'))
    color = math.range(channel(variables=opt_trj.vector.size))
    show(param_curves, vec(trajectory=0, value=dset.ground_truth[selection].vector.as_channel('variables')), overlay='args', color=color)
    vis.savefig(scene.subpath(f"plots/{method}_params.jpg", create_parent=True))


def copy_all_plots(src_dir, target_dir, extensions=(".jpg", ".png", ".mp4", ".pdf"), delete_old=False):
    """
    Collect all plot files below `src_dir` and copy them into the single flat directory `target_dir`.

    The path of each file relative to `src_dir` is turned into the target file name by replacing
    path separators (`/` and `\\`) with spaces, e.g. `a/b/loss.jpg` becomes `a b loss.jpg`.
    `target_dir` is created if it does not exist. Existing files with the same name are overwritten.

    If `target_dir` is located inside `src_dir`, it is excluded from the search so that previously exported
    files are not exported again under a new name. A file is never copied onto itself.

    Args:
        src_dir: Directory to search recursively. A leading `~` is expanded.
        target_dir: Directory to copy the files to. A leading `~` is expanded.
        extensions: File name suffixes of the files to copy (and to delete if `delete_old`).
        delete_old: Whether to first delete all files with a matching suffix that are located directly in `target_dir`.
    """
    import shutil  # function-scope import: `shutil` is not imported at module level

    src_dir = os.path.expanduser(src_dir)
    target_dir = os.path.expanduser(target_dir)
    suffixes = tuple(extensions)  # `str.endswith` accepts a tuple of alternatives
    if os.path.exists(target_dir):
        if delete_old:
            for file_name in os.listdir(target_dir):
                if file_name.endswith(suffixes):
                    print(f"Deleting {file_name}")
                    os.remove(os.path.join(target_dir, file_name))
    else:
        os.makedirs(target_dir)
    # --- Collect source files ---
    target_real = os.path.realpath(target_dir)
    src_paths = []
    for dir_path, dirs, files in os.walk(src_dir):
        # Prune the target directory in-place so `os.walk` does not descend into our own output.
        dirs[:] = [d for d in dirs if os.path.realpath(os.path.join(dir_path, d)) != target_real]
        src_paths.extend(os.path.join(dir_path, file_name) for file_name in files if file_name.endswith(suffixes))
    # --- Copy with flattened names ---
    for src_path in src_paths:
        rel = os.path.relpath(src_path, src_dir)
        rel = rel.replace("/", " ").replace("\\", " ")
        target_path = os.path.join(target_dir, rel)
        if os.path.exists(target_path) and os.path.samefile(src_path, target_path):
            continue  # source and target are the same file (target_dir == src_dir); nothing to copy
        print(f"Copying {src_path} to {target_path}")
        shutil.copyfile(src_path, target_path)


def clip_gradient(x, q=.9):
    """
    Pass `x` through unchanged in the forward pass but clip its gradient in the backward pass.

    In the backward pass, gradient entries whose magnitude exceeds the `q`-quantile of the magnitudes
    along `example` are replaced by that quantile with the original sign.
    Non-finite gradient entries are then replaced by 0.

    Args:
        x: Value to pass through.
        q: Quantile of the absolute gradient, computed along `example`, used as the clipping threshold.

    Returns:
        The result of applying `math.identity` with the custom gradient to `x`.
    """
    def clipped_backward(input_dict, y, dy):  # y and dy have shape (example, vector)
        magnitude = abs(dy)
        threshold = math.quantile(magnitude, q, 'example')
        limited = math.where(magnitude > threshold, threshold * math.sign(dy), dy)
        sanitized = math.where(math.is_finite(limited), limited, 0)
        return {'x': sanitized}

    identity_with_clipping = math.custom_gradient(math.identity, clipped_backward)
    return identity_with_clipping(x)
