import json
import math
import shutil
from collections import defaultdict
from pathlib import Path

import colorful
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import trimesh
from scipy.spatial import cKDTree as KDTree
from tabulate import tabulate

from meta_learning_sdf import click, load_model, mkdir
from meta_learning_sdf.datasets.functions import (read_mesh_data,
                                                  read_uniform_point_cloud_data
                                                  )
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    CombinedDataset, MeshData, MinibatchGenerator, UniformPointCloudData,
    UniformPointCloudDataset)
from meta_learning_sdf.experiments.baseline import (
    AdaptiveLatentOptimization, LatentOptimization,
    LatentOptimizationHyperparameters, LossFunction, MarchingCubes, Model,
    ModelHyperparameters, TrainingHyperparameters, setup_model)
from meta_learning_sdf.point_cloud import render_point_cloud


def _normalize(vec: np.ndarray):
    """Return ``vec`` scaled to unit Euclidean length.

    No guard for zero-length input: a zero vector yields NaNs.
    """
    length = np.linalg.norm(vec)
    return vec / length


def _look_at(eye: np.ndarray, center: np.ndarray, up: np.ndarray):
    """Build a look-at camera frame.

    Args:
        eye: Camera position.
        center: Point the camera looks at.
        up: Approximate up direction used to fix the camera roll.

    Returns:
        ``(rotation_matrix, translation_vector)``. The columns of the float32
        ``rotation_matrix`` are the unit axes ``x = up × z``, ``y = z × x``
        and ``z`` (pointing from ``center`` towards ``eye``);
        ``translation_vector`` holds ``-axis · eye`` for each of those axes.
    """
    def unit(v):
        return v / np.linalg.norm(v)

    eye, center, up = (np.asanyarray(arg) for arg in (eye, center, up))

    z_axis = unit(eye - center)
    # y is derived from the *unnormalized* x before both are normalized.
    raw_x = np.cross(up, z_axis)
    raw_y = np.cross(z_axis, raw_x)
    x_axis, y_axis = unit(raw_x), unit(raw_y)

    rotation_matrix = np.column_stack(
        (x_axis, y_axis, z_axis)).astype(np.float32)
    translation_vector = np.array(
        [-(axis @ eye) for axis in (x_axis, y_axis, z_axis)])
    return rotation_matrix, translation_vector


def _sample_surface_points(faces: np.ndarray, vertices: np.ndarray,
                           num_point_samples: int):
    """Sample ``num_point_samples`` points on the surface of a triangle mesh.

    The mesh is assembled with ``trimesh.Trimesh`` and sampled with
    ``trimesh.sample.sample_surface``; the returned per-sample face indices
    are discarded, only the point coordinates are returned.
    """
    surface_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    points, _face_ids = trimesh.sample.sample_surface(surface_mesh,
                                                      num_point_samples)
    return points


def _compute_chamfer_distance(gt_surface_points: np.ndarray,
                              pred_surface_points: np.ndarray):
    """Symmetric squared Chamfer distance between two point sets.

    Sum of the mean squared nearest-neighbour distance from the predicted
    points to the ground truth and from the ground truth to the predicted
    points. Nearest neighbours are found with a KD-tree.
    """
    def one_way(reference, queries):
        nearest_distances, _ = KDTree(reference).query(queries)
        return np.mean(np.square(nearest_distances))

    pred_to_gt = one_way(gt_surface_points, pred_surface_points)
    gt_to_pred = one_way(pred_surface_points, gt_surface_points)
    return pred_to_gt + gt_to_pred


def _compute_non_squared_chamfer_distance(gt_surface_points: np.ndarray,
                                          pred_surface_points: np.ndarray):
    """Symmetric (non-squared) Chamfer distance between two point sets.

    Sum of the mean nearest-neighbour Euclidean distance from the predicted
    points to the ground truth and from the ground truth to the predicted
    points.
    """
    pred_to_gt = KDTree(gt_surface_points).query(pred_surface_points)[0]
    gt_to_pred = KDTree(pred_surface_points).query(gt_surface_points)[0]
    return np.mean(pred_to_gt) + np.mean(gt_to_pred)


def _summarize(result_directory: str):
    """Collect the per-model metrics of one evaluation run.

    Args:
        result_directory: Directory containing ``args.json`` and a
            ``metrics`` sub-directory of ``*.json`` files, each holding a
            flat ``{key: value}`` mapping.

    Returns:
        ``(table, args, num_data)``:
        ``table`` maps every key to the list of values gathered across the
        metrics files, in sorted file-name order.
        ``args`` is the parsed ``args.json``.
        ``num_data`` is the number of metrics files that were parsed
        successfully, i.e. the number of entries behind ``table``.

    Raises:
        FileNotFoundError: If ``args.json`` does not exist.
    """
    result_directory = Path(result_directory)
    metrics_directory = result_directory / "metrics"
    args_path = result_directory / "args.json"
    # Explicit check rather than ``assert`` so it is not stripped by -O.
    if not args_path.is_file():
        raise FileNotFoundError(args_path)
    with open(args_path) as f:
        args = json.load(f)

    table = defaultdict(list)
    num_data = 0
    # Sorted so the order of collected values is deterministic.
    for metrics_path in sorted(metrics_directory.glob("*.json")):
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            print("Error:", metrics_path)
            continue
        # Only count files that actually contributed values.
        num_data += 1
        for key, value in metrics.items():
            table[key].append(value)

    return table, args, num_data


class ChamferDistance:
    """Squared Chamfer distance between a ground-truth mesh and the mesh
    reconstructed from a point cloud.

    The latent code is inferred by ``model.encoder`` from
    ``latent_optimization_num_samples`` context points, refined by
    ``optimize_latent_func`` and turned into a mesh by ``MarchingCubes`` on a
    ``grid_size`` grid spanning ``[-1, 1]``. ``chamfer_distance_num_samples``
    points are sampled from each mesh surface and compared with
    ``_compute_chamfer_distance``.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int,
                 visualize: bool = False):
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        # Opt-in debug aid: show both sampled point sets after each call.
        # Replaces a previously hard-disabled ``if False:`` block.
        self.visualize = visualize

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Evaluate one shape.

        Returns:
            The squared Chamfer distance, or ``None`` when marching cubes (or
            the subsequent vertex rescaling) raises ``ValueError``.
        """
        gt_faces = mesh_data.vertex_indices
        gt_vertices = mesh_data.vertices

        data = self.minibatch_generator(
            [pc_data],
            min_num_context=self.latent_optimization_num_samples,
            max_num_context=self.latent_optimization_num_samples,
            min_num_target=0,
            max_num_target=0)
        with torch.no_grad():
            initial_h, _ = self.model.encoder(data.context_points_list[0])
        h = self.optimize_latent(data, initial_h)

        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)
        try:
            mc_vertices, mc_faces = marching_cubes(h)
            # Map vertices back with the point cloud's scale/offset
            # (presumably undoing its normalization) to match the GT mesh.
            mc_vertices = mc_vertices / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        gt_surface_points = _sample_surface_points(
            gt_faces, gt_vertices, self.chamfer_distance_num_samples)
        pred_surface_points = _sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)

        chamfer_distance = _compute_chamfer_distance(gt_surface_points,
                                                     pred_surface_points)

        if self.visualize:
            self._show_point_clouds(gt_surface_points, pred_surface_points)

        return chamfer_distance

    @staticmethod
    def _show_point_clouds(gt_surface_points: np.ndarray,
                           pred_surface_points: np.ndarray):
        """Render ground-truth (left) and predicted (right) samples."""
        camera_theta = math.pi / 3
        camera_phi = -math.pi / 4
        camera_r = 1
        eye = [
            camera_r * math.sin(camera_theta) * math.cos(camera_phi),
            camera_r * math.cos(camera_theta),
            camera_r * math.sin(camera_theta) * math.sin(camera_phi),
        ]
        rotation_matrix, translation_vector = _look_at(eye=eye,
                                                       center=[0, 0, 0],
                                                       up=[0, 1, 0])
        rotation_matrix = np.linalg.inv(rotation_matrix)

        fig, axes = plt.subplots(1, 2)
        for axis, surface_points in zip(
                axes, (gt_surface_points, pred_surface_points)):
            points = (rotation_matrix
                      @ surface_points.T).T + translation_vector[None, :]
            image = render_point_cloud(points=points,
                                       colors=np.zeros_like(points),
                                       camera_mag=1,
                                       point_size=1)
            axis.imshow(image)
        plt.show()
        plt.close(fig)


class NonSquaredChamferDistance:
    """Non-squared Chamfer distance between a ground-truth mesh and the mesh
    reconstructed from a point cloud.

    The pipeline is: encode context points, optimize the latent code, run
    marching cubes on ``[-1, 1]``, sample both surfaces, then score with
    ``_compute_non_squared_chamfer_distance``.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.grid_size = grid_size
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.latent_optimization_num_samples = latent_optimization_num_samples

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Return the non-squared Chamfer distance, or ``None`` if marching
        cubes (or the vertex rescaling) raises ``ValueError``."""
        reference_faces = mesh_data.vertex_indices
        reference_vertices = mesh_data.vertices

        n_context = self.latent_optimization_num_samples
        batch = self.minibatch_generator([pc_data],
                                         min_num_context=n_context,
                                         max_num_context=n_context,
                                         min_num_target=0,
                                         max_num_target=0)
        with torch.no_grad():
            h0, _unused_dist = self.model.encoder(batch.context_points_list[0])
        latent = self.optimize_latent(batch, h0)

        extractor = MarchingCubes(model=self.model,
                                  grid_size=self.grid_size,
                                  grid_max_value=1,
                                  grid_min_value=-1)
        try:
            verts, faces = extractor(latent)
            verts = verts / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        n_samples = self.chamfer_distance_num_samples
        reference_points = _sample_surface_points(reference_faces,
                                                  reference_vertices,
                                                  n_samples)
        predicted_points = _sample_surface_points(faces, verts, n_samples)
        return _compute_non_squared_chamfer_distance(reference_points,
                                                     predicted_points)


class SurfaceChamferDistance:
    """Squared Chamfer distance against points taken from the point cloud.

    The reference points are the target points that ``minibatch_generator``
    draws from ``pc_data`` with ``RandomState(0)``. They are mapped back with
    ``pc_data.scale`` / ``pc_data.offset`` and compared with samples from the
    marching-cubes reconstruction.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 use_sampled_h: bool, chamfer_distance_num_samples: int):
        self.model = model
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.grid_size = grid_size
        self.use_sampled_h = use_sampled_h
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: UniformPointCloudData):
        """Return the squared Chamfer distance, or ``None`` if marching cubes
        (or the vertex rescaling) raises ``ValueError``."""
        n_context = self.latent_optimization_num_samples
        n_target = self.chamfer_distance_num_samples
        minibatch = self.minibatch_generator(
            [pc_data],
            min_num_context=n_context,
            max_num_context=n_context,
            min_num_target=n_target,
            max_num_target=n_target,
            random_state=np.random.RandomState(0))

        reference = minibatch.target_points.detach().cpu().numpy()[0]
        reference = reference / pc_data.scale - pc_data.offset

        # Drop targets and context normals before latent optimization.
        minibatch.target_points = None
        minibatch.target_normals = None
        minibatch.context_normals_list = []

        with torch.no_grad():
            sampled_h, h_dist = self.model.encoder(
                minibatch.context_points_list[0])
            # Either the sampled latent or the distribution mean.
            initial_h = sampled_h if self.use_sampled_h else h_dist.mean
        latent = self.optimize_latent(minibatch, initial_h)

        extractor = MarchingCubes(model=self.model,
                                  grid_size=self.grid_size,
                                  grid_max_value=1,
                                  grid_min_value=-1)
        try:
            verts, faces = extractor(latent)
            verts = verts / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        predicted = _sample_surface_points(faces, verts, n_target)
        return _compute_chamfer_distance(reference, predicted)


class SurfaceNonSquaredChamferDistance:
    """Non-squared Chamfer distance against points taken from the point cloud.

    Same pipeline as ``SurfaceChamferDistance`` (reference points are the
    target points drawn from ``pc_data`` with ``RandomState(0)``), but scored
    with ``_compute_non_squared_chamfer_distance``.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 use_sampled_h: bool, chamfer_distance_num_samples: int):
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.use_sampled_h = use_sampled_h
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples

    def __call__(self, pc_data: UniformPointCloudData):
        """Return the non-squared Chamfer distance, or ``None`` if marching
        cubes (or the vertex rescaling) raises ``ValueError``."""
        random_state = np.random.RandomState(0)
        minibatch = self.minibatch_generator(
            [pc_data],
            min_num_context=self.latent_optimization_num_samples,
            max_num_context=self.latent_optimization_num_samples,
            min_num_target=self.chamfer_distance_num_samples,
            max_num_target=self.chamfer_distance_num_samples,
            random_state=random_state)

        # Convert to NumPy immediately, like SurfaceChamferDistance, instead of
        # carrying a torch tensor into the NumPy/SciPy Chamfer computation.
        gt_surface_points = minibatch.target_points.detach().cpu().numpy()[0]
        gt_surface_points = gt_surface_points / pc_data.scale - pc_data.offset

        # Drop targets and context normals before latent optimization.
        minibatch.target_points = None
        minibatch.target_normals = None
        minibatch.context_normals_list = []

        with torch.no_grad():
            context_points = minibatch.context_points_list[0]
            if self.use_sampled_h:
                initial_h, _ = self.model.encoder(context_points)
            else:
                _, h_dist = self.model.encoder(context_points)
                initial_h = h_dist.mean
        h = self.optimize_latent(minibatch, initial_h)

        marching_cubes = MarchingCubes(model=self.model,
                                       grid_size=self.grid_size,
                                       grid_max_value=1,
                                       grid_min_value=-1)
        try:
            mc_vertices, mc_faces = marching_cubes(h)
            mc_vertices = mc_vertices / pc_data.scale - pc_data.offset
        except ValueError:
            return None

        pred_surface_points = _sample_surface_points(
            mc_faces, mc_vertices, self.chamfer_distance_num_samples)

        return _compute_non_squared_chamfer_distance(gt_surface_points,
                                                     pred_surface_points)


class ChamferDistanceGradientDescent:
    """Mean displacement needed to project sampled points onto the predicted
    zero level set.

    Target points drawn by ``minibatch_generator`` are repeatedly moved along
    the normalized decoder gradient by the predicted signed distance, for up
    to ``gd_steps`` steps. Points whose predicted ``|distance|`` is below
    ``1e-6`` are accepted. The result is the mean Euclidean distance between
    each accepted point and its starting position (first batch element only).
    """

    def __init__(self,
                 minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization,
                 model: Model,
                 grid_size: int,
                 latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int,
                 gd_steps: int = 20):
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.model = model
        self.grid_size = grid_size
        self.latent_optimization_num_samples = latent_optimization_num_samples
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.gd_steps = gd_steps

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Evaluate one shape.

        Returns:
            The mean displacement of accepted points as ``float``, or ``None``
            if no point was accepted (including when ``gd_steps`` is 0).
        """
        minibatch = self.minibatch_generator(
            [pc_data],
            min_num_context=self.latent_optimization_num_samples,
            max_num_context=self.latent_optimization_num_samples,
            min_num_target=self.chamfer_distance_num_samples,
            max_num_target=self.chamfer_distance_num_samples,
            tuple_target=False)
        mesh_surface_points = minibatch.target_points

        minibatch.target_points = None
        minibatch.target_normals = None

        with torch.no_grad():
            context_points = minibatch.context_points_list[0]
            initial_h, _ = self.model.encoder(context_points)
        h = self.optimize_latent(minibatch, initial_h)

        _h = h[:, None, :].expand(
            (h.shape[0], self.chamfer_distance_num_samples, h.shape[1]))

        distance_eps = 1e-6
        # Avoids NaN directions where the SDF gradient vanishes.
        grad_norm_eps = 1e-12
        batch_index = 0

        # Nothing accepted yet, so the result is defined even if the loop
        # below never runs (previously a NameError for gd_steps == 0).
        accepted = torch.zeros(mesh_surface_points.shape[1],
                               dtype=torch.bool,
                               device=mesh_surface_points.device)
        sdf_surface_points = mesh_surface_points[batch_index][accepted]

        moving_points = mesh_surface_points.detach()
        for k in range(self.gd_steps):
            print("step", k)
            # Fresh leaf each step: the gradient w.r.t. the current points is
            # the same, but the autograd graph no longer grows across steps.
            moving_points = moving_points.detach().requires_grad_(True)
            distance = self.model.decoder(X=moving_points, h=_h)
            distance_grad = torch.autograd.grad(distance.sum(),
                                                moving_points,
                                                create_graph=False)[0]
            grad_norm = torch.norm(distance_grad, dim=2, keepdim=True)
            direction = distance_grad / grad_norm.clamp_min(grad_norm_eps)
            moving_points = (moving_points - direction * distance).detach()
            distance = self.model.decoder(X=moving_points, h=_h).squeeze(dim=2)

            accepted = torch.abs(distance[batch_index]) < distance_eps
            sdf_surface_points = moving_points[batch_index][accepted]
            if len(sdf_surface_points) == self.chamfer_distance_num_samples:
                break

        # Apply the mask on the tensor's own device before moving to CPU;
        # indexing a CPU tensor with a device-resident mask fails on GPU.
        mesh_surface_points = mesh_surface_points[batch_index][accepted]
        sdf_surface_points = sdf_surface_points.detach().cpu().numpy()
        mesh_surface_points = mesh_surface_points.detach().cpu().numpy()

        print(len(sdf_surface_points), "<->", len(mesh_surface_points))

        if len(sdf_surface_points) == 0:
            return None

        l2_norm = np.linalg.norm(sdf_surface_points - mesh_surface_points,
                                 axis=1)
        return float(np.mean(l2_norm))


class ChamferDistanceMeanSdf:
    """Mean absolute predicted SDF value at sampled target points.

    The target points drawn by ``minibatch_generator`` are fed to
    ``model.decoder`` together with the optimized latent code, and the mean of
    ``|sdf|`` is returned as a Python float. ``grid_size`` and ``mesh_data``
    are accepted for interface parity but not used.
    """

    def __init__(self, minibatch_generator: MinibatchGenerator,
                 optimize_latent_func: LatentOptimization, model: Model,
                 grid_size: int, latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int):
        self.model = model
        self.grid_size = grid_size
        self.minibatch_generator = minibatch_generator
        self.optimize_latent = optimize_latent_func
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.latent_optimization_num_samples = latent_optimization_num_samples

    def __call__(self, pc_data: UniformPointCloudData, mesh_data: MeshData):
        """Return the mean absolute SDF value at the sampled target points."""
        n_context = self.latent_optimization_num_samples
        n_target = self.chamfer_distance_num_samples
        batch = self.minibatch_generator([pc_data],
                                         min_num_context=n_context,
                                         max_num_context=n_context,
                                         min_num_target=n_target,
                                         max_num_target=n_target,
                                         tuple_target=False)
        query_points = batch.target_points
        query_points.requires_grad_(True)

        batch.target_points = None
        batch.target_normals = None

        with torch.no_grad():
            h0, _unused_dist = self.model.encoder(batch.context_points_list[0])
        latent = self.optimize_latent(batch, h0)

        # One copy of the latent per query point.
        per_point_latent = latent.unsqueeze(1).expand(
            (latent.shape[0], n_target, latent.shape[1]))

        sdf = self.model.decoder(X=query_points, h=per_point_latent)
        return sdf.abs().detach().cpu().mean().item()


class ChamferDistanceSteps:
    """Track the mean squared SDF value at target points during latent
    optimization.

    For each ``(step, h)`` yielded by ``latent_optimization.steps`` the
    decoder is evaluated at the sampled target points; the mean of ``sdf**2``
    is recorded and printed. Sampling is seeded with ``seed`` (NumPy
    ``RandomState`` and ``torch.manual_seed``).
    """

    def __init__(self,
                 minibatch_generator: MinibatchGenerator,
                 latent_optimization: LatentOptimization,
                 model: Model,
                 grid_size: int,
                 latent_optimization_num_samples: int,
                 chamfer_distance_num_samples: int,
                 use_sampled_h: bool,
                 seed: int = 0):
        self.model = model
        self.latent_optimization = latent_optimization
        self.minibatch_generator = minibatch_generator
        self.grid_size = grid_size
        self.use_sampled_h = use_sampled_h
        self.seed = seed
        self.chamfer_distance_num_samples = chamfer_distance_num_samples
        self.latent_optimization_num_samples = latent_optimization_num_samples

    def __call__(self, pc_data: UniformPointCloudData):
        """Return a list of ``(step, mean squared SDF)`` tuples."""
        n_context = self.latent_optimization_num_samples
        n_target = self.chamfer_distance_num_samples
        batch = self.minibatch_generator(
            [pc_data],
            min_num_context=n_context,
            max_num_context=n_context,
            min_num_target=n_target,
            max_num_target=n_target,
            random_state=np.random.RandomState(self.seed))
        torch.manual_seed(self.seed)

        query_points = batch.target_points
        batch.target_points = None
        batch.target_normals = None

        context = batch.context_points_list[0]
        with torch.no_grad():
            sampled_h, h_dist = self.model.encoder(context)
            initial_h = sampled_h if self.use_sampled_h else h_dist.mean

        history = []
        for step, latent in self.latent_optimization.steps(batch, initial_h):
            per_point_latent = latent.unsqueeze(1).expand(
                (latent.shape[0], n_target, latent.shape[1]))
            sdf = self.model.decoder(X=query_points, h=per_point_latent)
            squared = sdf.abs().detach().cpu().pow(2)
            chamfer_distance = squared.mean().item()

            history.append((step, chamfer_distance))
            print("step:", step, "chamfer_distance:", chamfer_distance,
                  flush=True)

        return history


@click.group()
def client():
    # Root command group; sub-commands are registered via ``@client.command``.
    ...


@client.command()
@click.argument("--result-directory", type=str, required=True)
def summarize(args):
    """Print best/worst models, a histogram and mean ± std of each metric
    for one evaluation run."""
    metric_list = ["chamfer_distance"]
    settings_list = [
        "latent_optimization_num_samples", "latent_optimization_iterations"
    ]
    # ``run_args`` is the run's args.json; kept separate from the CLI ``args``.
    result, run_args, num_data = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(run_args)

    model_ids = result["model_id"]
    rankings = {}
    for metric_name in metric_list:
        # ``None`` entries are excluded from rankings and statistics.
        rankings[metric_name] = sorted(
            ((value, model_id)
             for value, model_id in zip(result[metric_name], model_ids)
             if value is not None),
            key=lambda item: item[0])

    # find top-k
    top_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][:top_k]:
            print(value, model_id)

    # find worst-k
    worst_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][::-1][:worst_k]:
            print(value, model_id)

    data = {
        # Not displayed; ``.get`` avoids a KeyError for runs without it.
        "learning_rate": run_args.get("learning_rate"),
        "num_data": num_data,
    }
    for metric_name in metric_list:
        values = np.array(
            [value for value in result[metric_name] if value is not None])
        if values.size == 0:
            # np.min/np.max raise on empty arrays; report NaN instead.
            data[f"{metric_name}_mean"] = float("nan")
            data[f"{metric_name}_std"] = float("nan")
            continue
        print(np.min(values), np.max(values))
        print(colorful.bold("Histogram:"))
        print(np.histogram(values))
        data[f"{metric_name}_mean"] = np.mean(values)
        data[f"{metric_name}_std"] = np.std(values)
    for key in settings_list:
        values = result[key]
        data[key] = values[0] if values else None

    tabulate_row = [data["num_data"]]
    for key in settings_list:
        tabulate_row.append(f"{data[key]}")
    for metric_name in metric_list:
        mean = data[f"{metric_name}_mean"]
        std = data[f"{metric_name}_std"]
        tabulate_row.append(f"{mean:.06f} (±{std:.06f})")

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["# of data"] + settings_list + metric_list,
                 tablefmt="github"))


@client.command(name="summarize_gt")
@click.argument("--result-directory", type=str, required=True)
def summarize_gt(args):
    """Summarize Chamfer distances from ``metrics/*.json`` (no ``args.json``
    required).

    Prints the best and worst models and a table with mean ± std.
    ``None`` and negative values are treated as invalid and excluded from
    both the rankings and the statistics.
    """
    metrics_directory = Path(args.result_directory) / "metrics"
    result = defaultdict(list)
    for metrics_path in sorted(metrics_directory.glob("*.json")):
        try:
            with open(metrics_path) as f:
                metrics = json.load(f)
        except json.decoder.JSONDecodeError:
            print("Error:", metrics_path)
            continue
        for key, value in metrics.items():
            result[key].append(value)

    # One validity rule for rankings and statistics (previously rankings
    # used ``> 0``, statistics ``>= 0``, and ``None`` raised TypeError).
    valid = [(value, model_id) for value, model_id in zip(
        result["chamfer_distance"], result["model_id"])
             if value is not None and value >= 0]
    ranking = sorted(valid, key=lambda item: item[0])

    # find top-k
    top_k = 10
    print(colorful.bold(f"Top {top_k} chamfer_distance:"))
    for value, model_id in ranking[:top_k]:
        print(value, model_id)

    # find worst-k
    worst_k = 10
    print(colorful.bold(f"Worst {worst_k} chamfer_distance:"))
    for value, model_id in ranking[::-1][:worst_k]:
        print(value, model_id)

    values = np.array([value for value, _ in valid], dtype=float)
    if values.size:
        mean = np.mean(values)
        std = np.std(values)
    else:
        mean = std = float("nan")

    tabulate_row = [f"{mean:.06f} (±{std:.06f})"]

    print(colorful.bold("Result:"))
    print(
        tabulate([tabulate_row],
                 headers=["chamfer_distance"],
                 tablefmt="github"))


@client.command(name="summarize_by_category")
@click.argument("--result-directory", type=str, required=True)
def summarize_by_category(args):
    """Print best/worst models and the mean metric per ShapeNet category.

    Model ids are split on ``_`` into ``<category_id>_<object_id>``.
    ``None`` and negative metric values are treated as invalid and excluded.
    """
    metric_list = ["chamfer_distance"]
    # ``run_args`` is the run's args.json; kept separate from the CLI ``args``.
    result, run_args, num_data = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(run_args)

    def first_value(key):
        values = result[key]
        return values[0] if values else None

    def is_valid(value):
        return value is not None and value >= 0

    latent_optimization_num_samples = first_value(
        "latent_optimization_num_samples")
    latent_optimization_iterations = first_value(
        "latent_optimization_iterations")

    model_ids = result["model_id"]
    rankings = {
        metric_name: sorted(
            ((value, model_id)
             for value, model_id in zip(result[metric_name], model_ids)
             if is_valid(value)),
            key=lambda item: item[0])
        for metric_name in metric_list
    }

    # find top-k
    top_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Top {top_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][:top_k]:
            print(value, model_id)

    # find worst-k
    worst_k = 10
    for metric_name in metric_list:
        print(colorful.bold(f"Worst {worst_k} {metric_name}:"))
        for value, model_id in rankings[metric_name][::-1][:worst_k]:
            print(value, model_id)

    map_category_name = {
        "02691156": "Plane",
        "04256520": "Sofa",
        "04379243": "Table",
        "04530566": "Vessel",
        "02958343": "Car",
        "03001627": "Chair",
        "02933112": "Cabinet",
        "03636649": "Lamp",
        "02818832": "Bed",
        "02828884": "Bench",
        "02871439": "Bookshelf",
        "02924116": "Bus",
        "03467517": "Guitar",
        "03790512": "Motorbike",
        "03948459": "Pistol",
        "04225987": "Skateboard",
    }

    data_list = []
    category_id_set = set()
    for k, model_id in enumerate(model_ids):
        parts = model_id.split("_")
        category_id = parts[0]
        object_id = parts[1]
        category_id_set.add(category_id)
        metric_values = {
            metric_name: result[metric_name][k]
            for metric_name in metric_list
        }
        if not all(is_valid(value) for value in metric_values.values()):
            continue
        data = {"category_id": category_id, "object_id": object_id}
        data.update(metric_values)
        data_list.append(data)

    print(colorful.bold("Result:"))
    # Explicit columns keep column lookups valid even if every row was skipped.
    table = pd.DataFrame(data_list,
                         columns=["category_id", "object_id"] + metric_list)
    for metric_name in metric_list:
        print(colorful.bold_green(metric_name))
        tabulate_header = [
            "latent_optimization_num_samples", "latent_optimization_iterations"
        ]
        tabulate_row = [
            latent_optimization_num_samples, latent_optimization_iterations
        ]
        for category_id in sorted(category_id_set):
            # Unknown ids fall back to the raw id instead of raising KeyError.
            category_name = map_category_name.get(category_id, category_id)
            df = table[table["category_id"] == category_id]
            mean = df[metric_name].mean()
            category_count = df[metric_name].count()
            print(category_name, category_count)
            tabulate_header.append(category_name)
            tabulate_row.append(f"{mean:.06f}")
        print(
            tabulate([tabulate_row],
                     headers=tabulate_header,
                     tablefmt="github"))


@client.command(name="summarize_all")
@click.argument("--result-root-directory", type=str, required=True)
def summarize_all(args):
    """Tabulate mean ± std of each metric for every run under a root
    directory, one table per learning rate in ``lr_list``.

    ``None`` and negative metric values are excluded from the statistics.
    """
    metric_list = ["chamfer_distance"]
    result_root_directory = Path(args.result_root_directory)

    data_list = []
    for result_directory in sorted(result_root_directory.iterdir()):
        # Only directories holding an args.json can be summarized.
        if not (result_directory / "args.json").is_file():
            continue
        # ``_summarize`` takes only the directory (passing ``metric_list``
        # raised TypeError).
        result, run_args, num_data = _summarize(result_directory)
        data = {
            "max_num_target": run_args["max_num_target"],
            "kld_weight": run_args["loss_kld_initial_weight"],
            "learning_rate": run_args["learning_rate"],
            "num_samples": run_args["max_num_target"],
            "num_data": num_data,
        }
        for metric_name in metric_list:
            values = np.array([
                value for value in result[metric_name]
                if value is not None and value >= 0
            ],
                              dtype=float)
            if values.size:
                data[f"{metric_name}_mean"] = np.mean(values)
                data[f"{metric_name}_std"] = np.std(values)
            else:
                data[f"{metric_name}_mean"] = float("nan")
                data[f"{metric_name}_std"] = float("nan")
        data_list.append(data)

    # Explicit columns keep the filters below valid when no run was found.
    columns = [
        "max_num_target", "kld_weight", "learning_rate", "num_samples",
        "num_data"
    ]
    for metric_name in metric_list:
        columns += [f"{metric_name}_mean", f"{metric_name}_std"]
    table = pd.DataFrame(data_list, columns=columns)

    lr_list = [0.0001, 0.00032]
    for lr in lr_list:
        df = table[table["learning_rate"] == lr]
        tabulate_rows = []
        print(lr)
        for _, row in df.iterrows():
            tabulate_row = [
                row["num_data"], row["num_samples"], row["kld_weight"]
            ]
            for metric_name in metric_list:
                mean = row[f"{metric_name}_mean"]
                std = row[f"{metric_name}_std"]
                tabulate_row.append(f"{mean:.06f} (±{std:.06f})")
            tabulate_rows.append(tabulate_row)

        print(
            tabulate(tabulate_rows,
                     headers=["# of data", "# of samples", "kld_weight"] +
                     metric_list,
                     tablefmt="github"))


@client.command(name="plot_xy")
@click.argument("--output-directory", type=str, required=True)
def plot_xy(args):
    """Scatter-plot per-model Chamfer distances of two evaluation methods.

    For each (symmetric, target) pair of evaluation directories, models with
    a valid (non-None, non-negative) Chamfer distance under both methods are
    plotted with the target method on x and the symmetric method on y,
    coloured by the number of context samples. The ``worst_k`` models with
    the largest symmetric/target ratio are printed. The figure is saved to
    ``<output-directory>/<method>_optim_<iterations>.png``.
    """
    output_directory = Path(args.output_directory)
    mkdir(output_directory)

    latent_optimization_iterations = 0
    method = "mean_sdf"

    symmetric_directories = [
        f"evaluations/method_comparison_surface/non_squared_symmetric/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/100_samples/lr_0.0001/f23fd0e76fa0",
        f"evaluations/method_comparison_surface/non_squared_symmetric/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/1000_samples/lr_0.0001/f23fd0e76fa0",
        f"evaluations/method_comparison_surface/non_squared_symmetric/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/30000_samples/lr_0.005/f23fd0e76fa0",
    ]
    target_directories = [
        f"evaluations/method_comparison/{method}/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/100_samples/lr_0.0001/f23fd0e76fa0",
        f"evaluations/method_comparison/{method}/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/1000_samples/lr_0.0001/f23fd0e76fa0",
        f"evaluations/method_comparison/{method}/cd_30000/sv2_cars/epoch_5000/latent_optim_{latent_optimization_iterations}/30000_samples/lr_0.005/f23fd0e76fa0",
    ]

    def is_valid(chamfer_distance):
        # Failed evaluations are recorded as None or as a negative value.
        return chamfer_distance is not None and not chamfer_distance < 0

    data = []
    ranking = []
    for symmetric_directory, target_directory in zip(symmetric_directories,
                                                     target_directories):
        print(symmetric_directory, target_directory)
        symmetric_result = _summarize(symmetric_directory)[0]
        # Only valid distances are stored, so a successful lookup below never
        # needs to be re-validated.
        symmetric_map_model_cd = {
            model_id: chamfer_distance
            for model_id, chamfer_distance in zip(
                symmetric_result["model_id"],
                symmetric_result["chamfer_distance"])
            if is_valid(chamfer_distance)
        }

        result = _summarize(target_directory)[0]
        latent_optimization_num_samples = result[
            "latent_optimization_num_samples"][0]
        for model_id, chamfer_distance in zip(result["model_id"],
                                              result["chamfer_distance"]):
            symmetric_chamfer_distance = symmetric_map_model_cd.get(model_id)
            if symmetric_chamfer_distance is None:
                continue
            if not is_valid(chamfer_distance):
                continue
            data.append({
                "symmetric": symmetric_chamfer_distance,
                f"{method}": chamfer_distance,
                "# of context": str(latent_optimization_num_samples),
            })
            if chamfer_distance == 0:
                # Avoid ZeroDivisionError: a zero target distance gets the
                # largest possible ratio unless both distances are zero.
                ratio = math.inf if symmetric_chamfer_distance > 0 else 1.0
            else:
                ratio = symmetric_chamfer_distance / chamfer_distance
            ranking.append((ratio, symmetric_chamfer_distance,
                            chamfer_distance, model_id))

    worst_k = 10
    ranking = sorted(ranking, key=lambda item: item[0],
                     reverse=True)[:worst_k]
    print(colorful.bold(f"Worst {worst_k}:"))
    for ratio, symmetric_chamfer_distance, chamfer_distance, model_id in ranking:
        print(ratio, symmetric_chamfer_distance, chamfer_distance, model_id)

    if not data:
        # seaborn cannot plot columns that do not exist in an empty frame.
        print("No models with valid Chamfer distances for both methods.")
        return

    figsize_px = np.array([1000, 1000])
    dpi = 100
    figsize_inch = figsize_px / dpi
    plt.figure(figsize=figsize_inch)

    df = pd.DataFrame(data)
    print(df)
    sns.scatterplot(
        x=method,
        y="symmetric",
        data=df,
        hue="# of context",
        style="# of context",
        legend="full",
        palette="colorblind",
    )
    plt.xlim([0, 0.1])
    plt.ylim([0, 0.1])
    plt.suptitle(
        f"latent_optimization_iterations={latent_optimization_iterations}")
    plt.savefig(output_directory /
                f"{method}_optim_{latent_optimization_iterations}.png")
    # Release the figure so it does not stay open in pyplot's registry.
    plt.close()


@client.command(name="chamfer_distance_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--chamfer-distance-method",
                type=click.Choice([
                    "symmetric", "non_squared_symmetric", "gradient_descent",
                    "mean_sdf"
                ]),
                required=True)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Evaluate the Chamfer distance for one point-cloud / mesh pair.

    Loads the checkpoint selected by ``--checkpoint-directory`` and
    ``--checkpoint-epoch`` onto CUDA device 0, then builds the evaluator named
    by ``--chamfer-distance-method``. It writes one JSON metrics file per data
    item into ``<output-directory>/metrics``. The checkpoint's ``args.json``
    is copied into the output directory unless a copy already exists there.
    """
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    mesh_data_path = Path(args.mesh_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    for directory in (output_directory, metrics_directory):
        mkdir(directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    model_filename = ("model.pt" if args.checkpoint_epoch is None else
                      f"model.{args.checkpoint_epoch}.pt")
    model_path = checkpoint_directory / model_filename
    assert args_path.is_file()
    assert model_path.is_file()

    copied_args_path = output_directory / "args.json"
    if not copied_args_path.exists():
        shutil.copyfile(args_path, copied_args_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(pc_data_path, mesh_data_path)],
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final KLD weight when training annealed it.
        return (training_hyperparams.loss_kld_final_weight
                if training_hyperparams.anneal_kld_weight else
                training_hyperparams.loss_kld_initial_weight)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=0,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = LatentOptimization(model=model,
                                             loss_function=loss_function,
                                             params=latent_optimization_params)

    # Every evaluator shares the same constructor arguments; pick the class
    # first, then build it once.
    method = args.chamfer_distance_method
    if method == "symmetric":
        evaluator_class = ChamferDistance
    elif method == "non_squared_symmetric":
        evaluator_class = NonSquaredChamferDistance
    elif method == "gradient_descent":
        evaluator_class = ChamferDistanceGradientDescent
    elif method == "mean_sdf":
        evaluator_class = ChamferDistanceMeanSdf
    else:
        raise NotImplementedError()
    compute_chamfer_distance = evaluator_class(
        minibatch_generator=minibatch_generator,
        optimize_latent_func=latent_optimization,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=args.latent_optimization_num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for data_tuple in dataset:
        pc_data, mesh_data = data_tuple[0], data_tuple[1]
        chamfer_distance = compute_chamfer_distance(pc_data, mesh_data)

        path_parts = str(pc_data.path).split("/")
        category_id, model_id = path_parts[-3], path_parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        item_name = f"{category_id}_{model_id}"
        result = {
            "model_id": item_name,
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "latent_optimization_num_samples":
            args.latent_optimization_num_samples,
            "grid_size": args.grid_size,
            "latent_optimization_iterations":
            args.latent_optimization_iterations,
            "latent_optimization_initial_lr":
            args.latent_optimization_initial_lr,
            "chamfer_distance_method": method,
        }
        with open(metrics_directory / f"{item_name}.json", "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_surface_data")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--use-sampled-h", is_flag=True)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--chamfer-distance-method",
                type=click.Choice(["symmetric", "non_squared_symmetric"]),
                required=True)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_surface_data(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Evaluate point-cloud-only (surface) Chamfer distances.

    Loads the selected checkpoint onto CUDA device 0 and optimizes latents
    with ``AdaptiveLatentOptimization``. Each point cloud from
    ``--pc-data-path`` is scored with the surface evaluator named by
    ``--chamfer-distance-method``, and one JSON metrics file per item is
    written into ``<output-directory>/metrics``.
    """
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    for directory in (output_directory, metrics_directory):
        mkdir(directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    model_filename = ("model.pt" if args.checkpoint_epoch is None else
                      f"model.{args.checkpoint_epoch}.pt")
    model_path = checkpoint_directory / model_filename
    assert args_path.is_file()
    assert model_path.is_file()

    copied_args_path = output_directory / "args.json"
    if not copied_args_path.exists():
        shutil.copyfile(args_path, copied_args_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = UniformPointCloudDataset([pc_data_path])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final KLD weight when training annealed it.
        return (training_hyperparams.loss_kld_final_weight
                if training_hyperparams.anneal_kld_weight else
                training_hyperparams.loss_kld_initial_weight)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=0,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params)

    # Both surface evaluators take identical constructor arguments.
    method = args.chamfer_distance_method
    if method == "symmetric":
        evaluator_class = SurfaceChamferDistance
    elif method == "non_squared_symmetric":
        evaluator_class = SurfaceNonSquaredChamferDistance
    else:
        raise NotImplementedError()
    compute_chamfer_distance = evaluator_class(
        minibatch_generator=minibatch_generator,
        optimize_latent_func=latent_optimization,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=args.latent_optimization_num_samples,
        use_sampled_h=args.use_sampled_h,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    params = latent_optimization_params
    for pc_data in dataset:
        path_parts = str(pc_data.path).split("/")
        category_id, model_id = path_parts[-3], path_parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        chamfer_distance = compute_chamfer_distance(pc_data)
        print(category_id, model_id, chamfer_distance, flush=True)
        print(result_path)

        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "grid_size": args.grid_size,
            "use_sampled_h": args.use_sampled_h,
            "latent_optimization_num_samples": params.num_samples,
            "latent_optimization_iterations": params.iterations,
            "latent_optimization_initial_lr": params.initial_lr,
            "latent_optimization_decrease_lr_every": params.decrease_lr_every,
            "latent_optimization_lr_decay_factor": params.lr_decay_factor,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_dataset")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-dataset-directory", type=str, required=True)
@click.argument("--mesh-dataset-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=512)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_dataset(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Evaluate Chamfer distances over every model listed in a test split.

    The split JSON maps category ids to lists of model ids. Models missing
    either ``point_cloud.npz`` or ``models/model_normalized.obj`` are
    skipped. One JSON metrics file per evaluated model is written into
    ``<output-directory>/metrics``.
    """
    device = torch.device("cuda", 0)
    pc_dataset_directory = Path(args.pc_dataset_directory)
    mesh_dataset_directory = Path(args.mesh_dataset_directory)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    for directory in (output_directory, metrics_directory):
        mkdir(directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    model_filename = ("model.pt" if args.checkpoint_epoch is None else
                      f"model.{args.checkpoint_epoch}.pt")
    model_path = checkpoint_directory / model_filename
    assert args_path.is_file()
    assert model_path.is_file()

    copied_args_path = output_directory / "args.json"
    if not copied_args_path.exists():
        shutil.copyfile(args_path, copied_args_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    assert test_split_path.is_file()

    with open(test_split_path) as f:
        split = json.load(f)

    combined_data_path_list = []
    for category in split:
        for model_id in split[category]:
            model_directory_pc = pc_dataset_directory / category / model_id
            model_directory_mesh = mesh_dataset_directory / category / model_id
            pc_data_path = model_directory_pc / "point_cloud.npz"
            mesh_data_path = (model_directory_mesh / "models" /
                              "model_normalized.obj")
            # Only keep models for which both files are present.
            if pc_data_path.exists() and mesh_data_path.exists():
                combined_data_path_list.append((pc_data_path, mesh_data_path))

    print(len(combined_data_path_list))
    dataset = CombinedDataset(combined_data_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final KLD weight when training annealed it.
        return (training_hyperparams.loss_kld_final_weight
                if training_hyperparams.anneal_kld_weight else
                training_hyperparams.loss_kld_initial_weight)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = LatentOptimization(model=model,
                                             loss_function=loss_function,
                                             params=latent_optimization_params)
    # NOTE(review): the other ChamferDistanceSteps call sites in this file
    # (chamfer_distance_data_for_each_steps and its "_" variant) pass
    # use_sampled_h, do not pass loss_function, and call the evaluator with
    # the point cloud only. This construction and the two-argument call
    # below may be stale; verify them against ChamferDistanceSteps' signature.
    compute_chamfer_distance = ChamferDistanceSteps(
        loss_function=loss_function,
        minibatch_generator=minibatch_generator,
        latent_optimization=latent_optimization,
        model=model,
        grid_size=args.grid_size,
        latent_optimization_num_samples=args.latent_optimization_num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for data_tuple in dataset:
        pc_data, mesh_data = data_tuple[0], data_tuple[1]
        chamfer_distance = compute_chamfer_distance(pc_data, mesh_data)

        path_parts = str(pc_data.path).split("/")
        category_id, model_id = path_parts[-3], path_parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        item_name = f"{category_id}_{model_id}"
        result = {
            "model_id": item_name,
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "latent_optimization_num_samples":
            args.latent_optimization_num_samples,
            "grid_size": args.grid_size,
            "latent_optimization_iterations":
            args.latent_optimization_iterations,
            "latent_optimization_initial_lr":
            args.latent_optimization_initial_lr
        }
        with open(metrics_directory / f"{item_name}.json", "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="gt_chamfer_distance_dataset")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
def gt_chamfer_distance_dataset(args):
    """Compute the Chamfer distance between two samplings of each GT mesh.

    For every model in the test split with both an ``.npz`` point cloud and an
    ``.obj`` mesh, two independent surface samples are drawn from the
    ground-truth mesh and compared. The point cloud is loaded only to derive
    the category/model ids from its path. One JSON file per model is written
    into ``<output-directory>/metrics``.
    """
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    for directory in (output_directory, metrics_directory):
        mkdir(directory)

    test_split_path = Path(args.test_split_path)
    assert test_split_path.is_file()

    with open(test_split_path) as f:
        split = json.load(f)

    combined_data_path_list = []
    for category in split:
        for model_id in split[category]:
            pc_data_path = (npz_dataset_directory / category / model_id /
                            "point_cloud.npz")
            mesh_data_path = (obj_dataset_directory / category / model_id /
                              "models" / "model_normalized.obj")
            # Only keep models for which both files are present.
            if pc_data_path.exists() and mesh_data_path.exists():
                combined_data_path_list.append((pc_data_path, mesh_data_path))

    print(len(combined_data_path_list))

    dataset = CombinedDataset(combined_data_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])

    num_samples = args.chamfer_distance_num_samples
    for data_tuple in dataset:
        pc_data, mesh_data = data_tuple[0], data_tuple[1]

        faces = mesh_data.vertex_indices
        vertices = mesh_data.vertices

        # Two independent samplings of the same surface.
        first_samples = _sample_surface_points(faces, vertices, num_samples)
        second_samples = _sample_surface_points(faces, vertices, num_samples)
        chamfer_distance = _compute_chamfer_distance(first_samples,
                                                     second_samples)

        path_parts = str(pc_data.path).split("/")
        category_id, model_id = path_parts[-3], path_parts[-2]
        print(category_id, model_id, chamfer_distance, flush=True)
        item_name = f"{category_id}_{model_id}"
        result = {
            "model_id": item_name,
            "chamfer_distance": chamfer_distance,
            "chamfer_distance_num_samples": num_samples
        }
        with open(metrics_directory / f"{item_name}.json", "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_data_for_each_steps")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--use-sampled-h", is_flag=True)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--log-every", type=int, default=10)
@click.argument("--skip-if-exists", is_flag=True)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_data_for_each_steps(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Record Chamfer distances along the latent-optimization trajectory.

    ``LatentOptimization`` is built with ``yield_every=--log-every``, and
    ``ChamferDistanceSteps`` evaluates the point cloud at the yielded steps.
    The collected values are written to ``<output-directory>/metrics``.
    With ``--skip-if-exists``, items whose metrics file already exists are
    skipped.
    """
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    mesh_data_path = Path(args.mesh_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    assert args_path.is_file()
    assert model_path.is_file()

    args_dest_path = output_directory / "args.json"
    if not args_dest_path.exists():
        shutil.copyfile(args_path, args_dest_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    combined_data_path_list = [(pc_data_path, mesh_data_path)]
    dataset = CombinedDataset(combined_data_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final KLD weight when training annealed it.
        if training_hyperparams.anneal_kld_weight:
            return training_hyperparams.loss_kld_final_weight
        return training_hyperparams.loss_kld_initial_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = LatentOptimization(model=model,
                                             loss_function=loss_function,
                                             params=latent_optimization_params,
                                             yield_every=args.log_every)
    compute_chamfer_distance_for_each_step = ChamferDistanceSteps(
        minibatch_generator=minibatch_generator,
        latent_optimization=latent_optimization,
        model=model,
        use_sampled_h=args.use_sampled_h,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    for data_tuple in dataset.shuffle():
        # data_tuple[1] is the MeshData; the per-step evaluator only consumes
        # the point cloud.
        pc_data: UniformPointCloudData = data_tuple[0]

        parts = str(pc_data.path).split("/")
        category_id = parts[-3]
        model_id = parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        if args.skip_if_exists and result_path.exists():
            print(result_path, "skipped.")
            continue

        chamfer_distance_for_each_steps = compute_chamfer_distance_for_each_step(
            pc_data)

        print(category_id, model_id, "done", flush=True)
        result = {
            "model_id":
            f"{category_id}_{model_id}",
            "chamfer_distance_for_each_steps":
            chamfer_distance_for_each_steps,
            "chamfer_distance_num_samples":
            args.chamfer_distance_num_samples,
            "grid_size":
            args.grid_size,
            "use_sampled_h":
            args.use_sampled_h,
            "latent_optimization_num_samples":
            latent_optimization_params.num_samples,
            "latent_optimization_iterations":
            latent_optimization_params.iterations,
            "latent_optimization_initial_lr":
            latent_optimization_params.initial_lr,
            "latent_optimization_decrease_lr_every":
            latent_optimization_params.decrease_lr_every,
            "latent_optimization_lr_decay_factor":
            latent_optimization_params.lr_decay_factor,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="chamfer_distance_data_for_each_steps_")
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--output-directory", type=str, required=True)
@click.argument("--pc-data-path", type=str, required=True)
@click.argument("--mesh-data-path", type=str, required=True)
@click.argument("--grid-size", type=int, default=256)
@click.argument("--use-sampled-h", is_flag=True)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.argument("--log-every", type=int, default=10)
@click.argument("--skip-if-exists", is_flag=True)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def chamfer_distance_data_for_each_steps_(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Per-step Chamfer distances using ``AdaptiveLatentOptimization``.

    This variant of ``chamfer_distance_data_for_each_steps`` swaps in the
    adaptive latent optimizer. It also records
    ``args.latent_optimization_optimizer`` under the ``"optimizer"`` key of
    each metrics file written to ``<output-directory>/metrics``.
    """
    device = torch.device("cuda", 0)
    pc_data_path = Path(args.pc_data_path)
    mesh_data_path = Path(args.mesh_data_path)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    for directory in (output_directory, metrics_directory):
        mkdir(directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    model_filename = ("model.pt" if args.checkpoint_epoch is None else
                      f"model.{args.checkpoint_epoch}.pt")
    model_path = checkpoint_directory / model_filename
    assert args_path.is_file()
    assert model_path.is_file()

    copied_args_path = output_directory / "args.json"
    if not copied_args_path.exists():
        shutil.copyfile(args_path, copied_args_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    dataset = CombinedDataset([(pc_data_path, mesh_data_path)],
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the final KLD weight when training annealed it.
        return (training_hyperparams.loss_kld_final_weight
                if training_hyperparams.anneal_kld_weight else
                training_hyperparams.loss_kld_initial_weight)

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params,
        yield_every=args.log_every)
    compute_chamfer_distance_for_each_step = ChamferDistanceSteps(
        minibatch_generator=minibatch_generator,
        latent_optimization=latent_optimization,
        model=model,
        use_sampled_h=args.use_sampled_h,
        grid_size=args.grid_size,
        latent_optimization_num_samples=latent_optimization_params.num_samples,
        chamfer_distance_num_samples=args.chamfer_distance_num_samples)

    params = latent_optimization_params
    for data_tuple in dataset.shuffle():
        pc_data: UniformPointCloudData = data_tuple[0]

        path_parts = str(pc_data.path).split("/")
        category_id, model_id = path_parts[-3], path_parts[-2]
        result_path = metrics_directory / f"{category_id}_{model_id}.json"

        if args.skip_if_exists and result_path.exists():
            print(result_path, "skipped.")
            continue

        per_step_distances = compute_chamfer_distance_for_each_step(pc_data)

        print(category_id, model_id, "done", flush=True)
        result = {
            "model_id": f"{category_id}_{model_id}",
            "chamfer_distance_for_each_steps": per_step_distances,
            "chamfer_distance_num_samples": args.chamfer_distance_num_samples,
            "grid_size": args.grid_size,
            "use_sampled_h": args.use_sampled_h,
            "optimizer": args.latent_optimization_optimizer,
            "latent_optimization_num_samples": params.num_samples,
            "latent_optimization_iterations": params.iterations,
            "latent_optimization_initial_lr": params.initial_lr,
            "latent_optimization_decrease_lr_every": params.decrease_lr_every,
            "latent_optimization_lr_decay_factor": params.lr_decay_factor,
        }
        with open(result_path, "w") as f:
            json.dump(result, f, indent=4, sort_keys=True)


@client.command(name="check_loss_step_0")
@click.argument("--output-directory", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, required=True)
@click.argument("--checkpoint-epoch", type=int, default=None)
@click.argument("--test-split-path", type=str, required=True)
@click.argument("--npz-dataset-directory", type=str, required=True)
@click.argument("--obj-dataset-directory", type=str, required=True)
@click.argument("--use-sampled-h", is_flag=True)
@click.argument("--chamfer-distance-num-samples", type=int, default=30000)
@click.hyperparameter_class(LatentOptimizationHyperparameters)
def check_loss_step_0(
        args, latent_optimization_params: LatentOptimizationHyperparameters):
    """Print a reconstruction-error proxy at every yielded optimisation step.

    For every shape listed in the test split this command:

    1. draws one context/target minibatch with fixed seeds,
    2. initialises the latent code with the encoder's posterior mean E[h|D],
    3. runs ``AdaptiveLatentOptimization`` (constructed with
       ``yield_every=10``), and
    4. each time the optimisation yields, evaluates the decoder at the
       target points and prints the mean of the squared outputs.

    The printed quantity is labelled "chamfer_distance" to match the other
    commands. It is the mean squared decoder output at the target points,
    not a two-sided Chamfer distance. Results are only printed; nothing is
    written to ``<output-directory>/metrics``. The ``--use-sampled-h`` flag
    is accepted but not used by this command.

    Args:
        args: Parsed command-line options declared by the decorators.
        latent_optimization_params: Latent-optimisation hyperparameters; its
            ``num_samples`` fixes the number of context points.
    """
    device = torch.device("cuda", 0)
    npz_dataset_directory = Path(args.npz_dataset_directory)
    obj_dataset_directory = Path(args.obj_dataset_directory)
    output_directory = Path(args.output_directory)
    metrics_directory = output_directory / "metrics"
    mkdir(output_directory)
    mkdir(metrics_directory)

    checkpoint_directory = Path(args.checkpoint_directory)
    args_path = checkpoint_directory / "args.json"
    if args.checkpoint_epoch is None:
        model_path = checkpoint_directory / "model.pt"
    else:
        model_path = checkpoint_directory / f"model.{args.checkpoint_epoch}.pt"
    assert args_path.is_file()
    assert model_path.is_file()

    # Keep a copy of the checkpoint's training arguments next to the outputs.
    args_dest_path = output_directory / "args.json"
    if not args_dest_path.exists():
        shutil.copyfile(args_path, args_dest_path)

    model_hyperparams = ModelHyperparameters.load_json(args_path)
    training_hyperparams = TrainingHyperparameters.load_json(args_path)

    model = setup_model(model_hyperparams)
    load_model(model_path, model)
    model.to(device)
    model.eval()

    test_split_path = Path(args.test_split_path)
    assert test_split_path.is_file()

    with open(test_split_path) as f:
        split = json.load(f)

    # The split maps category id -> list of model ids.
    combined_data_path_list = []
    for category in split:
        model_id_list = split[category]
        for model_id in model_id_list:
            pc_data_path = npz_dataset_directory / category / model_id / "point_cloud.npz"
            mesh_data_path = obj_dataset_directory / category / model_id / "models" / "model_normalized.obj"
            combined_data_path_list.append((pc_data_path, mesh_data_path))

    print(len(combined_data_path_list))
    dataset = CombinedDataset(combined_data_path_list,
                              [read_uniform_point_cloud_data, read_mesh_data])
    minibatch_generator = MinibatchGenerator(device=device)

    def kld_weight_func():
        # Use the KL weight reached at the end of training: the final
        # (annealed) weight if annealing was enabled, else the initial one.
        if training_hyperparams.anneal_kld_weight:
            return training_hyperparams.loss_kld_final_weight
        return training_hyperparams.loss_kld_initial_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    latent_optimization = AdaptiveLatentOptimization(
        model=model,
        loss_function=loss_function,
        params=latent_optimization_params,
        yield_every=10)

    for data_tuple in dataset:
        pc_data = data_tuple[0]

        # Fixed seeds give every shape a reproducible context/target draw.
        random_state = np.random.RandomState(0)
        minibatch = minibatch_generator(
            [pc_data],
            min_num_context=latent_optimization_params.num_samples,
            max_num_context=latent_optimization_params.num_samples,
            min_num_target=args.chamfer_distance_num_samples,
            max_num_target=args.chamfer_distance_num_samples,
            random_state=random_state)
        torch.manual_seed(0)

        mesh_surface_points = minibatch.target_points

        # Hide the targets from the minibatch handed to the optimisation;
        # they are only used for evaluation below.
        minibatch.target_points = None
        minibatch.target_normals = None

        context_points = minibatch.context_points_list[0]

        with torch.no_grad():
            _, h_dist = model.encoder(context_points)
            # use E[h|D] as initial h
            initial_h = h_dist.mean

        for step, h in latent_optimization.steps(minibatch, initial_h):
            # Evaluation only, so no autograd graph is needed. The generator
            # resumes outside this context, so the optimisation is unaffected.
            with torch.no_grad():
                # Broadcast the latent code to every target point.
                _h = h[:, None, :].expand(
                    (h.shape[0], args.chamfer_distance_num_samples,
                     h.shape[1]))
                distance = model.decoder(X=mesh_surface_points, h=_h)
                distance = abs(distance).cpu()**2
                chamfer_distance = distance.mean().item()

            print("step:",
                  step,
                  "chamfer_distance:",
                  chamfer_distance,
                  flush=True)


@client.command(name="plot_chamfer_distance")
@click.argument("--result-directory", type=str, required=True)
@click.argument("--relative-value", is_flag=True)
def plot_chamfer_distance(args):
    """Plot the per-step Chamfer distance stored in one result directory.

    Reads the summary returned by ``_summarize`` for ``--result-directory``.
    It plots every model's Chamfer distance against the optimisation step,
    stopping at step 800, and seaborn aggregates the models at each step. The
    curve is labelled with the initial learning rate recorded in the results.

    With ``--relative-value`` each model's values are divided by its own
    step-0 value, so all curves start at 1. Models whose entry is ``-1``, or
    that have no step-0 value, are skipped. The plot is shown interactively.
    """
    result, hyperparams, _ = _summarize(args.result_directory)
    print(colorful.bold("Hyperparameters:"))
    print(hyperparams)

    # Label the curve with the learning rate recorded in the results.
    # Fall back to the previously hard-coded value when it was not recorded.
    if "latent_optimization_initial_lr" in result:
        latent_optimization_initial_lr = str(
            result["latent_optimization_initial_lr"][0])
    else:
        latent_optimization_initial_lr = "0.005"
    print(latent_optimization_initial_lr)

    data_list = []
    for model_id, chamfer_distance_for_each_steps in zip(
            result["model_id"], result["chamfer_distance_for_each_steps"]):
        # -1 marks a model without per-step results.
        if chamfer_distance_for_each_steps == -1:
            continue
        for step, chamfer_distance in chamfer_distance_for_each_steps:
            if chamfer_distance is None:
                # Without a step-0 value there is no baseline; skip the model.
                if step == 0:
                    break
            else:
                if args.relative_value:
                    if step == 0:
                        base_chamfer_distance = chamfer_distance
                        chamfer_distance = 1
                    else:
                        chamfer_distance = chamfer_distance / base_chamfer_distance
                data_list.append({
                    "model_id": model_id,
                    "step": step,
                    "chamfer_distance": chamfer_distance,
                    "lr": latent_optimization_initial_lr
                })
            # Only the first 800 optimisation steps are plotted.
            if step == 800:
                break

    df = pd.DataFrame(data_list)
    grouped = df.groupby("step")
    print(grouped.size())
    # numeric_only: the string columns (model_id, lr) cannot be averaged and
    # raise a TypeError on pandas >= 2.0.
    mean = grouped.mean(numeric_only=True)
    print(mean)

    sns.lineplot(x="step",
                 y="chamfer_distance",
                 hue="lr",
                 legend="full",
                 palette="colorblind",
                 data=df)
    plt.show()


@client.command(name="plot_chamfer_distance_steps")
@click.argument("--output-directory", type=str)
def plot_chamfer_distance_steps(args):
    """Compare per-step Chamfer distance curves across initial learning rates.

    Loads the hard-coded evaluation directories for ``run_id`` (with
    ``num_samples`` context points) via ``_summarize``. It draws two panels:
    the absolute Chamfer distance per step, and the distance relative to
    each model's step-0 value. Curves are coloured and dashed by initial
    learning rate. The figure is saved to
    ``<output-directory>/baseline/<run_id>/<num_samples>_samples.png``.
    """
    run_id = "f23fd0e76fa0"
    output_directory = Path(args.output_directory) / "baseline" / run_id
    num_samples = 50
    mkdir(output_directory)
    directories = [
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.004/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.003/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.002/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.00125/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.001/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.00075/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0015/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.00175/{num_samples}_samples/{run_id}",
    ]

    def plot(relative: bool):
        """Draw one panel on the current axes and return the seaborn Axes.

        When ``relative`` is true, every value is divided by the same model's
        step-0 value, so all curves start at 1.
        """
        param_set = set()
        data_list = []
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        for directory in directories:
            result, _, _ = _summarize(directory)
            decrease_lr_every = result[
                "latent_optimization_decrease_lr_every"][0]
            lr = result["latent_optimization_initial_lr"][0]

            for model_id, chamfer_distance_for_each_steps in zip(
                    result["model_id"],
                    result["chamfer_distance_for_each_steps"]):
                for step, chamfer_distance in chamfer_distance_for_each_steps:
                    if chamfer_distance is None:
                        # No step-0 value means no baseline: skip the model.
                        # Missing later steps are simply omitted.
                        if step == 0:
                            break
                        continue
                    if relative:
                        if step == 0:
                            base_chamfer_distance = chamfer_distance
                            chamfer_distance = 1
                        else:
                            chamfer_distance = chamfer_distance / base_chamfer_distance
                    param = f"{lr}_{decrease_lr_every}"
                    data_list.append({
                        "model_id": model_id,
                        "step": step,
                        metric: chamfer_distance,
                        "lr": lr,
                        "decrease_lr_every": decrease_lr_every,
                        "param": param
                    })
                    param_set.add(lr)

        df = pd.DataFrame(data_list)
        grouped = df.groupby("step")
        print(grouped.size())
        # numeric_only: the string columns (model_id, param) cannot be
        # averaged and raise a TypeError on pandas >= 2.0.
        print(grouped.mean(numeric_only=True))

        solid = ()
        loosely_dotted = (1, 10)
        dotted = (1, 1)
        loosely_dashed = (5, 10)
        dashed = (5, 5)
        densely_dashed = (5, 1)
        loosely_dashdotted = (3, 10, 1, 10)
        dashdotted = (3, 3, 1, 3)
        densely_dashdotted = (3, 1, 1, 1)
        dashdotdotted = (3, 3, 1, 3, 1, 3)
        loosely_dashdotdotted = (3, 1, 1, 1, 1, 1)

        # Styles are handed out from the end of the list (solid first).
        dash_style_list = [
            dashdotdotted, loosely_dashdotdotted, loosely_dashdotted,
            loosely_dashed, densely_dashdotted, loosely_dotted, dashdotted,
            densely_dashed, dashed, dotted, solid
        ]
        dash_styles = {}
        # Assign styles in lr order, as the other *_steps_* commands do.
        for param in sorted(param_set):
            dash_styles[param] = dash_style_list.pop()

        ax = sns.lineplot(x="step",
                          y=metric,
                          hue="lr",
                          style="lr",
                          legend="full",
                          palette="colorblind",
                          dashes=dash_styles,
                          ci=None,
                          data=df)
        return ax

    figsize_px = np.array([1600, 800])
    dpi = 100
    figsize_inch = figsize_px / dpi
    fig = plt.figure(figsize=figsize_inch)
    plt.subplot(1, 2, 1)
    plot(relative=False)
    plt.subplot(1, 2, 2)
    plot(relative=True)
    figure_path = output_directory / f"{num_samples}_samples.png"
    plt.suptitle(
        f"uniform sparse sampling\nbaseline\nnum_context={num_samples}")
    plt.savefig(figure_path)
    plt.close(fig)


@client.command(name="plot_chamfer_distance_steps_adaptive_lr")
@click.argument("--output-directory", type=str)
def plot_chamfer_distance_steps_adaptive_lr(args):
    """Compare constant- and adaptive-lr latent optimisation per step.

    Loads the hard-coded constant-lr (``directories``) and adaptive-lr
    (``adaptive_directories``) evaluations for ``run_id`` via ``_summarize``.
    It draws two panels: the absolute Chamfer distance per step, and the
    distance relative to each model's step-0 value. Hue encodes the lr type
    and dash style encodes the initial lr. The figure is saved to
    ``<output-directory>/baseline_adaptive_comparison/<run_id>/<num_samples>_samples.png``.
    """
    run_id = "f23fd0e76fa0"
    output_directory = Path(
        args.output_directory) / "baseline_adaptive_comparison" / run_id
    num_samples = 1000
    mkdir(output_directory)
    directories = [
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0005/{num_samples}_samples/{run_id}",
    ]
    adaptive_directories = [
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0025/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.01/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.025/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.05/{num_samples}_samples/{run_id}",
    ]

    def add(directories, data_list, param_set, relative, adaptive):
        """Append one row per (model, step) from ``directories`` to ``data_list``.

        Each initial lr that contributes a row is added to ``param_set``.
        """
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        lr_type = "adaptive" if adaptive else "constant"
        for directory in directories:
            result, _, _ = _summarize(directory)
            decrease_lr_every = result[
                "latent_optimization_decrease_lr_every"][0]
            lr = result["latent_optimization_initial_lr"][0]

            for model_id, chamfer_distance_for_each_steps in zip(
                    result["model_id"],
                    result["chamfer_distance_for_each_steps"]):
                for step, chamfer_distance in chamfer_distance_for_each_steps:
                    if chamfer_distance is None:
                        # No step-0 value means no baseline: skip the model.
                        if step == 0:
                            break
                        continue
                    if relative:
                        if step == 0:
                            base_chamfer_distance = chamfer_distance
                            chamfer_distance = 1
                        else:
                            chamfer_distance = chamfer_distance / base_chamfer_distance
                    data_list.append({
                        "model_id": model_id,
                        "step": step,
                        metric: chamfer_distance,
                        "lr": lr,
                        "lr_type": lr_type,
                        "decrease_lr_every": decrease_lr_every,
                    })
                    param_set.add(lr)

    def plot(relative: bool):
        """Draw one panel on the current axes and return the seaborn Axes."""
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        param_set = set()
        data_list = []
        add(directories, data_list, param_set, relative, False)
        add(adaptive_directories, data_list, param_set, relative, True)

        df = pd.DataFrame(data_list)
        grouped = df.groupby("step")
        print(grouped.size())
        # numeric_only: the string columns (model_id, lr_type) cannot be
        # averaged and raise a TypeError on pandas >= 2.0.
        print(grouped.mean(numeric_only=True))

        solid = ()
        loosely_dotted = (1, 10)
        dotted = (1, 1)
        loosely_dashed = (5, 10)
        dashed = (5, 5)
        densely_dashed = (5, 1)
        loosely_dashdotted = (3, 10, 1, 10)
        dashdotted = (3, 3, 1, 3)
        densely_dashdotted = (3, 1, 1, 1)
        dashdotdotted = (3, 3, 1, 3, 1, 3)
        loosely_dashdotdotted = (3, 1, 1, 1, 1, 1)

        # Styles are handed out from the end of the list (solid first).
        dash_style_list = [
            dashdotted, loosely_dashdotdotted, loosely_dashdotted,
            loosely_dashed, loosely_dotted, dashdotdotted, densely_dashdotted,
            densely_dashed, dashed, dotted, solid
        ]
        dash_styles = {}
        for param in sorted(param_set):
            dash_styles[param] = dash_style_list.pop()

        ax = sns.lineplot(x="step",
                          y=metric,
                          hue="lr_type",
                          style="lr",
                          legend="full",
                          palette="colorblind",
                          dashes=dash_styles,
                          ci=None,
                          data=df)
        return ax

    figsize_px = np.array([1600, 800])
    dpi = 100
    figsize_inch = figsize_px / dpi
    fig = plt.figure(figsize=figsize_inch)
    plt.subplot(1, 2, 1)
    plot(relative=False)
    plt.subplot(1, 2, 2)
    plot(relative=True)
    figure_path = output_directory / f"{num_samples}_samples.png"
    plt.suptitle(
        f"uniform sparse sampling\nbaseline\nnum_context={num_samples}")
    plt.savefig(figure_path)
    plt.close(fig)


@client.command(name="plot_chamfer_distance_steps_sampled_h_vs_mu")
@click.argument("--output-directory", type=str)
def plot_chamfer_distance_steps_sampled_h_vs_mu(args):
    """Compare latent optimisation started from a sampled h vs. from E[h|D].

    Loads the hard-coded ``use_sampled_h_True`` and ``use_sampled_h_False``
    evaluations for ``run_id`` via ``_summarize``. It draws two panels: the
    absolute Chamfer distance per step, and the distance relative to each
    model's step-0 value. Hue encodes the initialisation ("sampled" or
    "E[h|D]", read from each result's ``use_sampled_h``) and dash style
    encodes the initial lr. The figure is saved to
    ``<output-directory>/baseline_sampled_h_vs_mu/<run_id>/<num_samples>_samples.png``.
    """
    run_id = "f23fd0e76fa0"
    output_directory = Path(
        args.output_directory) / "baseline_sampled_h_vs_mu" / run_id
    num_samples = 50
    mkdir(output_directory)
    sampled_h_directories = [
        f"evaluations/cd_steps/use_sampled_h_True/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.001/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_True/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.002/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_True/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0005/{num_samples}_samples/{run_id}",
    ]
    mu_directories = [
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.003/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.002/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.001/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0001/{num_samples}_samples/{run_id}",
    ]

    def add(directories, data_list, param_set, relative):
        """Append one row per (model, step) from ``directories`` to ``data_list``.

        Each initial lr that contributes a row is added to ``param_set``.
        """
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        for directory in directories:
            result, _, _ = _summarize(directory)
            decrease_lr_every = result[
                "latent_optimization_decrease_lr_every"][0]
            lr = result["latent_optimization_initial_lr"][0]
            use_sampled_h = result["use_sampled_h"][0]
            h_type = "sampled" if use_sampled_h else "E[h|D]"

            for model_id, chamfer_distance_for_each_steps in zip(
                    result["model_id"],
                    result["chamfer_distance_for_each_steps"]):
                for step, chamfer_distance in chamfer_distance_for_each_steps:
                    if chamfer_distance is None:
                        # No step-0 value means no baseline: skip the model.
                        if step == 0:
                            break
                        continue
                    if relative:
                        if step == 0:
                            base_chamfer_distance = chamfer_distance
                            chamfer_distance = 1
                        else:
                            chamfer_distance = chamfer_distance / base_chamfer_distance
                    data_list.append({
                        "model_id": model_id,
                        "step": step,
                        metric: chamfer_distance,
                        "lr": lr,
                        "h": h_type,
                        "decrease_lr_every": decrease_lr_every,
                    })
                    param_set.add(lr)

    def plot(relative: bool):
        """Draw one panel on the current axes and return the seaborn Axes."""
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        param_set = set()
        data_list = []
        add(sampled_h_directories, data_list, param_set, relative)
        add(mu_directories, data_list, param_set, relative)

        df = pd.DataFrame(data_list)
        grouped = df.groupby("step")
        print(grouped.size())
        # numeric_only: the string columns (model_id, h) cannot be averaged
        # and raise a TypeError on pandas >= 2.0.
        print(grouped.mean(numeric_only=True))

        solid = ()
        loosely_dotted = (1, 10)
        dotted = (1, 1)
        loosely_dashed = (5, 10)
        dashed = (5, 5)
        densely_dashed = (5, 1)
        loosely_dashdotted = (3, 10, 1, 10)
        dashdotted = (3, 3, 1, 3)
        densely_dashdotted = (3, 1, 1, 1)
        dashdotdotted = (3, 3, 1, 3, 1, 3)
        loosely_dashdotdotted = (3, 1, 1, 1, 1, 1)

        # Styles are handed out from the end of the list (solid first).
        dash_style_list = [
            loosely_dashdotdotted, loosely_dashdotted, loosely_dashed, dashed,
            densely_dashdotted, loosely_dotted, dashdotdotted, dashdotted,
            densely_dashed, dotted, solid
        ]
        dash_styles = {}
        for param in sorted(param_set):
            dash_styles[param] = dash_style_list.pop()

        ax = sns.lineplot(x="step",
                          y=metric,
                          hue="h",
                          style="lr",
                          legend="full",
                          palette="colorblind",
                          dashes=dash_styles,
                          ci=None,
                          data=df)
        return ax

    figsize_px = np.array([1600, 800])
    dpi = 100
    figsize_inch = figsize_px / dpi
    fig = plt.figure(figsize=figsize_inch)
    plt.subplot(1, 2, 1)
    plot(relative=False)
    plt.subplot(1, 2, 2)
    plot(relative=True)
    figure_path = output_directory / f"{num_samples}_samples.png"
    plt.suptitle(
        f"uniform sparse sampling\nbaseline\nnum_context={num_samples}")
    plt.savefig(figure_path)
    plt.close(fig)


@client.command(name="plot_chamfer_distance_steps_adam_vs_sgd")
@click.argument("--output-directory", type=str)
def plot_chamfer_distance_steps_adam_vs_sgd(args):
    """Compare Adam and SGD latent optimisation per step.

    Loads the hard-coded Adam (``adam_directories``) and SGD
    (``sgd_directories``) evaluations for ``run_id`` via ``_summarize``. It
    draws two panels: the absolute Chamfer distance per step, and the
    distance relative to each model's step-0 value. Hue encodes the optimiser
    and dash style encodes the initial lr. The figure is saved to
    ``<output-directory>/baseline_adam_vs_sgd/<run_id>/<num_samples>_samples.png``.
    """
    run_id = "f23fd0e76fa0"
    output_directory = Path(
        args.output_directory) / "baseline_adam_vs_sgd" / run_id
    num_samples = 300
    mkdir(output_directory)
    adam_directories = [
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0025/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.01/{num_samples}_samples/{run_id}",
    ]
    sgd_directories = [
        f"evaluations/cd_steps/optimizer_sgd/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.001/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/optimizer_sgd/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.00075/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/optimizer_sgd/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0005/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/optimizer_sgd/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.00025/{num_samples}_samples/{run_id}",
        f"evaluations/cd_steps/optimizer_sgd/use_sampled_h_False/adaptive_lr/uniform_sparse_sampling/baseline/0501/symmetric/decrease_lr_every_800/decay_factor_0.1/sv2_cars/epoch_5000/latent_optim_800/lr_0.0001/{num_samples}_samples/{run_id}",
    ]

    def add(directories, data_list, param_set, relative, optimizer):
        """Append one row per (model, step) from ``directories`` to ``data_list``.

        ``optimizer`` is stored verbatim in the "optimizer" column. Each
        initial lr that contributes a row is added to ``param_set``.
        """
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        for directory in directories:
            result, _, _ = _summarize(directory)
            decrease_lr_every = result[
                "latent_optimization_decrease_lr_every"][0]
            lr = result["latent_optimization_initial_lr"][0]

            for model_id, chamfer_distance_for_each_steps in zip(
                    result["model_id"],
                    result["chamfer_distance_for_each_steps"]):
                for step, chamfer_distance in chamfer_distance_for_each_steps:
                    if chamfer_distance is None:
                        # No step-0 value means no baseline: skip the model.
                        if step == 0:
                            break
                        continue
                    if relative:
                        if step == 0:
                            base_chamfer_distance = chamfer_distance
                            chamfer_distance = 1
                        else:
                            chamfer_distance = chamfer_distance / base_chamfer_distance
                    data_list.append({
                        "model_id": model_id,
                        "step": step,
                        metric: chamfer_distance,
                        "lr": lr,
                        "optimizer": optimizer,
                        "decrease_lr_every": decrease_lr_every,
                    })
                    param_set.add(lr)

    def plot(relative: bool):
        """Draw one panel on the current axes and return the seaborn Axes."""
        metric = "relative_chamfer_distance" if relative else "chamfer_distance"
        param_set = set()
        data_list = []
        add(adam_directories, data_list, param_set, relative, "adam")
        add(sgd_directories, data_list, param_set, relative, "sgd")

        df = pd.DataFrame(data_list)
        grouped = df.groupby("step")
        print(grouped.size())
        # numeric_only: the string columns (model_id, optimizer) cannot be
        # averaged and raise a TypeError on pandas >= 2.0.
        print(grouped.mean(numeric_only=True))

        solid = ()
        loosely_dotted = (1, 10)
        dotted = (1, 1)
        loosely_dashed = (5, 10)
        dashed = (5, 5)
        densely_dashed = (5, 1)
        loosely_dashdotted = (3, 10, 1, 10)
        dashdotted = (3, 3, 1, 3)
        densely_dashdotted = (3, 1, 1, 1)
        dashdotdotted = (3, 3, 1, 3, 1, 3)
        loosely_dashdotdotted = (3, 1, 1, 1, 1, 1)

        # Styles are handed out from the end of the list (solid first).
        dash_style_list = [
            loosely_dashdotdotted, loosely_dashdotted, loosely_dashed, dashed,
            densely_dashdotted, loosely_dotted, dashdotdotted, dashdotted,
            densely_dashed, dotted, solid
        ]
        dash_styles = {}
        for param in sorted(param_set):
            dash_styles[param] = dash_style_list.pop()

        ax = sns.lineplot(x="step",
                          y=metric,
                          hue="optimizer",
                          style="lr",
                          legend="full",
                          palette="colorblind",
                          dashes=dash_styles,
                          ci=None,
                          data=df)
        return ax

    figsize_px = np.array([1600, 800])
    dpi = 100
    figsize_inch = figsize_px / dpi
    fig = plt.figure(figsize=figsize_inch)
    plt.subplot(1, 2, 1)
    plot(relative=False)
    plt.subplot(1, 2, 2)
    plot(relative=True)
    figure_path = output_directory / f"{num_samples}_samples.png"
    plt.suptitle(
        f"uniform sparse sampling\nbaseline\nnum_context={num_samples}")
    plt.savefig(figure_path)
    plt.close(fig)


if __name__ == "__main__":
    # Dispatch to whichever sub-command (registered above via
    # ``@client.command``) is named on the command line.
    client()
