import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import seaborn as sns
from matplotlib import cm

from margflow.marginal_flow import MarginalFlow


# TODO file needs cleanup


def plot_samples(samples, max_samples=10_000, bounds=1, alpha=0.1):
    """Scatter-plot the leading ``max_samples`` rows of 2-D or 3-D samples.

    The size of the last axis of ``samples`` selects the plot type: 2 gives a
    flat scatter, 3 gives a 3-D scatter with a cubic box aspect. Any other
    size only prints a notice and nothing is drawn. Every axis is limited to
    ``[-bounds, bounds]`` and the figure is displayed with ``plt.show()``.

    Args:
        samples: Object exposing ``.shape`` and supporting ``samples[:k, j]``
            indexing (e.g. a NumPy array).
        max_samples: Upper limit on the number of leading rows that get drawn.
        bounds: Half-width of the symmetric axis limits.
        alpha: Marker opacity forwarded to ``scatter``.
    """
    n_dims = samples.shape[-1]
    if n_dims not in (2, 3):
        print(f"Skipping plot because d={samples.shape[-1]}")
        return

    is_3d = n_dims == 3
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection="3d") if is_3d else fig.add_subplot()

    # One coordinate column per plotted dimension, truncated to max_samples rows.
    columns = [samples[:max_samples, k] for k in range(n_dims)]
    ax.scatter(*columns, marker=".", alpha=alpha)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    if is_3d:
        ax.set_zlabel("Z")
        ax.set_box_aspect([1, 1, 1])

    ax.set_xlim(-bounds, bounds)
    ax.set_ylim(-bounds, bounds)
    if is_3d:
        ax.set_zlim(-bounds, bounds)
    plt.show()


# def plot_samples_time(dataset, model, n_samples, n_inner_samples=2048):
#     samples, context = dataset.sample(n_samples)
#     samples_np = samples.detach().cpu().numpy()
#     context_np = context.detach().cpu().numpy()
#     time_points = np.unique(context.detach().cpu().numpy())
#     n_time_points = time_points.shape[0]
#     samples_per_time = n_samples // n_time_points
#     bounds = 10
#     for i, time_step in enumerate(time_points):
#         time_point = torch.tensor([[time_step]], device=model.device).repeat(samples_per_time,1)
#         _, samples_model, _ = model.sample(n_outer_samples=samples_per_time, n_inner_samples=n_inner_samples, context=time_point)
#         samples_model = samples_model.detach().cpu().numpy()
#         plt.scatter(samples_model[:, 0], samples_model[:, 1], label="flow")
#         sample_indices = (context_np == time_step)[:,0]
#         plt.scatter(samples_np[:, sample_indices][:,0], samples_np[:, sample_indices][:, 1], label="gt")
#         plt.xlim(-bounds, bounds)
#         plt.ylim(-bounds, bounds)
#         plt.show()


def plot_samples_time(
    dataset, model, n_samples, n_timesteps=16, n_outer_samples=512, n_inner_samples=512
):
    """Show one scatter figure per context row comparing model and dataset samples.

    Samples are drawn with ``dataset.sample_all`` and ``model.sample_all``. For
    each index ``k`` in ``range(context.shape[0])`` the slice ``[:, k]`` of both
    sample arrays is plotted (first two coordinates) on fixed ``[-10, 10]`` axes
    and shown with ``plt.show()``.

    NOTE(review): the loop length comes from ``context.shape[0]`` while the
    arrays are indexed along axis 1 -- presumably both equal the number of
    time steps; confirm against ``sample_all``.

    Args:
        dataset: Object providing ``sample_all(n, n_timesteps=..., ordered=True)``
            returning ``(samples, context)`` tensors.
        model: Object providing ``sample_all(n_mixtures=..., n_samples=...,
            context=...)`` returning a 3-tuple whose second item holds the samples.
        n_samples: Number of dataset samples requested.
        n_timesteps: Forwarded to ``dataset.sample_all``.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        n_inner_samples: Forwarded to the model as ``n_samples``.
    """
    gt_samples, context = dataset.sample_all(n_samples, n_timesteps=n_timesteps, ordered=True)
    gt_np = gt_samples.detach().cpu().numpy()

    _, flow_samples, _ = model.sample_all(
        n_mixtures=n_outer_samples, n_samples=n_inner_samples, context=context
    )
    flow_np = flow_samples.detach().cpu().numpy()

    axis_limit = 10
    for step in range(context.shape[0]):
        plt.scatter(flow_np[:, step, 0], flow_np[:, step, 1], label="flow")
        plt.scatter(gt_np[:, step, 0], gt_np[:, step, 1], label="gt")
        plt.xlim(-axis_limit, axis_limit)
        plt.ylim(-axis_limit, axis_limit)
        plt.legend()
        plt.show()


def sample_animation(
    dataset, model, n_samples, model_name, n_timesteps=16, n_outer_samples=512, n_inner_samples=256
):
    """Save a GIF that steps through model samples and dataset samples.

    One frame is rendered per row of the context returned by
    ``dataset.sample``; frame ``k`` scatters index ``k`` along the first axis
    of both sample arrays (first two coordinates only). Axis limits are the
    joint min/max over all frames. The animation is written to
    ``f"{model_name}.gif"`` in the current working directory.

    Args:
        dataset: Object providing ``sample(n, n_timesteps=..., ordered=True)``
            returning ``(samples, context)`` tensors.
        model: Object providing ``sample(...)`` returning a tensor of samples.
        n_samples: Number of dataset samples requested.
        model_name: ``"marginal_flow"`` or ``"normalizing_flow"``; selects the
            ``model.sample`` call signature and the output file name.
        n_timesteps: Forwarded to ``dataset.sample``.
        n_outer_samples: Sample count used for the marginal flow call.
        n_inner_samples: Second sample count (see the call sites below).

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Validate with an exception rather than ``assert``: asserts vanish under
    # ``python -O``, and failing early avoids sampling before the error.
    if model_name not in ("marginal_flow", "normalizing_flow"):
        raise ValueError("model name must be one of 'marginal_flow', 'normalizing_flow'")

    import matplotlib.animation as animation

    samples, context = dataset.sample(n_samples, n_timesteps=n_timesteps, ordered=True)
    samples_np = samples.detach().cpu().numpy()
    if model_name == "marginal_flow":
        # NOTE(review): other helpers in this file pass n_mixtures=n_outer_samples
        # and n_samples=<inner>; here the two keywords are the other way round.
        # Left unchanged -- confirm which mapping is intended.
        samples_model = model.sample(
            n_samples=n_outer_samples, n_mixtures=n_inner_samples, context=context
        )
    else:
        samples_model = model.sample(n_samples=n_inner_samples, context=context)
    samples_model_np = samples_model.detach().cpu().numpy()

    # Fixed limits across all frames so the view does not jump during playback.
    min_x = min(samples_np[..., 0].min(), samples_model_np[..., 0].min())
    max_x = max(samples_np[..., 0].max(), samples_model_np[..., 0].max())
    min_y = min(samples_np[..., 1].min(), samples_model_np[..., 1].min())
    max_y = max(samples_np[..., 1].max(), samples_model_np[..., 1].max())

    fig, ax = plt.subplots()
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)

    # Draw frame 0 once; later frames only move the existing markers.
    sc1 = ax.scatter(
        samples_model_np[0, :, 0], samples_model_np[0, :, 1], alpha=0.5, label="flow"
    )
    sc2 = ax.scatter(samples_np[0, :, 0], samples_np[0, :, 1], alpha=0.5, label="gt")

    ax.legend()

    def update(frame):
        """Move both scatter collections to the points of ``frame``."""
        sc1.set_offsets(np.c_[samples_model_np[frame, :, 0], samples_model_np[frame, :, 1]])
        sc2.set_offsets(np.c_[samples_np[frame, :, 0], samples_np[frame, :, 1]])
        return sc1, sc2

    num_frames = context.shape[0]
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=100, blit=True)
    writergif = animation.PillowWriter(fps=4)
    ani.save(f"{model_name}.gif", writer=writergif)


def plot_mog(
    model: MarginalFlow,
    target,
    n_outer_samples,
    flow=None,
    n_gridpoints=100,
    plot_surface=False,
    bound=1.0,
    idxs="",
):
    """Compare a 2-D model against its target (and optionally another flow).

    For ``model.x_dim != 2`` the call is delegated to ``plot_nd_slice`` (the
    ``flow`` argument is not forwarded). Otherwise densities are evaluated on
    an ``n_gridpoints x n_gridpoints`` grid spanning ``[-1.1*bound, 1.1*bound]``
    and a 2-row figure is produced: density images on top, sample/mean
    scatters below. The figure is saved as ``imshow_{idxs}.pdf`` and shown.
    With ``plot_surface=True`` a second figure with 3-D density surfaces of
    the target and the model is saved as ``surface_{idxs}.pdf``.

    Args:
        model: Model exposing ``x_dim``, ``device``, ``dtype`` and
            ``sample_and_log_prob(n_mixtures=..., n_samples=..., x=...)``.
        target: Target exposing ``sample_all``, ``log_prob`` and ``means``.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        flow: Optional extra model plotted in a third column. Two call
            conventions are tried (see the ``try`` block below).
        n_gridpoints: Grid resolution per axis; ``5 * n_gridpoints`` samples
            are drawn from every distribution.
        plot_surface: Whether to also draw the 3-D surface figure.
        bound: Half-width of the region of interest before the 1.1 margin.
        idxs: Suffix inserted into the output file names.
    """
    if model.x_dim != 2:
        # contains some projection code but is mostly the same.
        plot_nd_slice(model, target, n_outer_samples, n_gridpoints, plot_surface, bound, idxs)
        return

    n_draws = 5 * n_gridpoints
    x_range = torch.linspace(
        -1.1 * bound, 1.1 * bound, n_gridpoints, device=model.device, dtype=model.dtype
    )
    grid = torch.cartesian_prod(x_range, -x_range)
    means_flow, samples_flow, log_prob_flow, _ = model.sample_and_log_prob(
        n_mixtures=n_outer_samples, n_samples=n_draws, x=grid
    )
    samples_flow = samples_flow.detach().cpu().numpy()
    log_prob_flow = log_prob_flow.detach().cpu().numpy()
    means_flow_np = means_flow.detach().cpu().numpy()

    samples_target = target.sample_all(n_samples=n_draws).detach().cpu().numpy()
    log_prob_target = target.log_prob(x=grid).detach().cpu().numpy()
    target_means_np = target.means.detach().cpu().numpy()

    if flow is not None:
        try:
            samples_nflow = flow.sample_all(n_samples=n_draws).detach().cpu().numpy()
            log_prob_nflow = flow.log_prob(x=grid).detach().cpu().numpy()
        except Exception:
            # Fallback call convention for flows with a different API. Only
            # ``Exception`` is caught so Ctrl-C / SystemExit still propagate.
            samples_nflow = (
                flow.sample_all(sample_shape=torch.Size([n_draws])).detach().cpu().numpy()
            )
            log_prob_nflow = flow.exact_log_prob(x=grid).log_prob.detach().cpu().numpy()
        fig, axes = plt.subplots(nrows=2, ncols=3)
    else:
        fig, axes = plt.subplots(nrows=2, ncols=2)

    def density_image(log_prob):
        """Turn flat grid log-probabilities into a square density image."""
        return np.exp(log_prob.reshape(n_gridpoints, n_gridpoints, order="F"))

    # Top row: densities on the grid.
    axes[0, 0].imshow(density_image(log_prob_target))
    axes[0, 1].imshow(density_image(log_prob_flow))
    if flow is not None:
        axes[0, 2].imshow(density_image(log_prob_nflow))

    # Bottom-left: target means and target samples.
    axes[1, 0].scatter(target_means_np[:, 0], target_means_np[:, 1], s=30, label="target means")
    axes[1, 0].scatter(
        samples_target[:, 0], samples_target[:, 1], s=5, alpha=0.5, label="target samples"
    )

    # Bottom-middle: target means, model samples and model means.
    axes[1, 1].scatter(target_means_np[:, 0], target_means_np[:, 1], s=30, label="target means")
    axes[1, 1].scatter(
        samples_flow[:, 0], samples_flow[:, 1], s=5, alpha=0.5, label="flow samples"
    )
    axes[1, 1].scatter(
        means_flow_np[:, 0], means_flow_np[:, 1], s=5, label="flow means", alpha=0.1
    )

    if flow is not None:
        axes[1, 2].scatter(
            samples_nflow[:, 0], samples_nflow[:, 1], s=5, alpha=0.5, label="normflow samples"
        )
    plt.savefig(f"imshow_{idxs}.pdf")
    plt.show()

    if plot_surface:
        grid_x = grid[:, 0].detach().cpu().numpy().reshape(n_gridpoints, n_gridpoints)
        grid_y = grid[:, 1].detach().cpu().numpy().reshape(n_gridpoints, n_gridpoints)
        fig, axes = plt.subplots(subplot_kw={"projection": "3d"}, nrows=1, ncols=2)
        # Left: target density, right: model density.
        for ax, log_prob in zip(axes, (log_prob_target, log_prob_flow)):
            ax.plot_surface(
                grid_x,
                grid_y,
                np.exp(log_prob).reshape(n_gridpoints, n_gridpoints),
                cmap=cm.coolwarm,
                linewidth=0,
                antialiased=False,
            )
        plt.savefig(f"surface_{idxs}.pdf")
        plt.show()


def plot_nd_slice(
    model: MarginalFlow,
    target,
    n_outer_samples,
    n_gridpoints=100,
    plot_surface=False,
    bound=1.0,
    idxs="",
):
    """Plot model and target densities on a 2-D grid, slicing if ``x_dim > 2``.

    For ``x_dim == 2`` the grid spans ``[-1.1*bound, 1.1*bound]`` directly.
    For ``x_dim > 2`` a first batch of model means is drawn, a rank-2
    projection matrix is obtained from ``project(..., type="svd")``, and a
    ``[-2.5, 2.5]`` latent grid is lifted into the ambient space around the
    mean of those means; samples and means are projected back with the same
    matrix for the scatter plots. The figure is saved as
    ``imshow_slice_{idxs}.pdf`` and shown; with ``plot_surface=True`` a second
    figure is saved as ``surface_slice_{idxs}.pdf``.

    Args:
        model: Model exposing ``x_dim``, ``device``, ``dtype`` and
            ``sample_and_log_prob``.
        target: Target exposing ``sample_all``, ``log_prob`` and ``means``.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        n_gridpoints: Grid resolution per axis; ``5 * n_gridpoints`` samples
            are drawn from each distribution.
        plot_surface: Whether to also draw 3-D density surfaces.
        bound: Half-width of the 2-D grid (only used when ``x_dim == 2``).
        idxs: Suffix inserted into the output file names.

    Raises:
        ValueError: If ``model.x_dim`` is smaller than 2.
    """
    # Neither branch below handles x_dim < 2; fail with a clear message
    # instead of an UnboundLocalError further down.
    if model.x_dim < 2:
        raise ValueError(f"plot_nd_slice requires model.x_dim >= 2, got {model.x_dim}")

    n_draws = 5 * n_gridpoints
    projm = None  # stays None for the plain 2-D case (no projection needed)
    if model.x_dim == 2:
        x_range = torch.linspace(
            -1.1 * bound, 1.1 * bound, n_gridpoints, device=model.device, dtype=model.dtype
        )
        grid = torch.cartesian_prod(x_range, -x_range)
    else:
        # First pass without a grid: only the means are needed to find the slice.
        means_flow, _, _, _ = model.sample_and_log_prob(
            n_mixtures=n_outer_samples, n_samples=n_draws
        )
        zs_mean = means_flow.mean(0)
        _, projm, _ = project(means_flow, d=2, type="svd")
        # TODO projm is used once for projecting but inverse not used
        zs_range = torch.linspace(-2.5, 2.5, n_gridpoints, device=model.device, dtype=model.dtype)
        zs_grid = torch.cartesian_prod(zs_range, -zs_range)
        grid = zs_mean[None] + (zs_grid @ projm.t())  # uplifts the grid

    means_flow, samples_flow, log_prob_flow, _ = model.sample_and_log_prob(
        n_mixtures=n_outer_samples, n_samples=n_draws, x=grid
    )
    if projm is not None:
        means_flow = means_flow @ projm
        samples_flow = samples_flow @ projm

    samples_flow = samples_flow.detach().cpu().numpy()
    log_prob_flow = log_prob_flow.detach().cpu().numpy()
    means_flow_np = means_flow.detach().cpu().numpy()

    # evaluate true distribution
    samples_target = target.sample_all(n_samples=n_draws)
    if projm is not None:
        samples_target = samples_target @ projm
    samples_target = samples_target.detach().cpu().numpy()
    log_prob_target = target.log_prob(x=grid).detach().cpu().numpy()
    target_means = target.means
    if projm is not None:
        target_means = target_means @ projm
    target_means_np = target_means.detach().cpu().numpy()

    fig, axes = plt.subplots(nrows=2, ncols=2)
    axes[0, 0].imshow(np.exp(log_prob_target.reshape(n_gridpoints, n_gridpoints, order="F")))
    axes[0, 1].imshow(np.exp(log_prob_flow.reshape(n_gridpoints, n_gridpoints, order="F")))
    # Bottom-left: target means with target and model samples.
    axes[1, 0].scatter(target_means_np[:, 0], target_means_np[:, 1], s=30)
    axes[1, 0].scatter(samples_target[:, 0], samples_target[:, 1], s=5, alpha=0.5)
    axes[1, 0].scatter(samples_flow[:, 0], samples_flow[:, 1], s=5, alpha=0.5)
    # Bottom-right: target means against model means.
    axes[1, 1].scatter(target_means_np[:, 0], target_means_np[:, 1], s=30)
    axes[1, 1].scatter(means_flow_np[:, 0], means_flow_np[:, 1], s=5, alpha=0.5)
    plt.savefig(f"imshow_slice_{idxs}.pdf")
    plt.show()

    if plot_surface:
        # NOTE(review): for x_dim > 2 these are the first two ambient
        # coordinates of the lifted grid, not the latent slice coordinates --
        # confirm this is the intended surface parameterisation.
        grid_x = grid[:, 0].detach().cpu().numpy().reshape(n_gridpoints, n_gridpoints)
        grid_y = grid[:, 1].detach().cpu().numpy().reshape(n_gridpoints, n_gridpoints)
        fig, axes = plt.subplots(subplot_kw={"projection": "3d"}, nrows=1, ncols=2)
        # Left: model density, right: target density.
        for ax, log_prob in zip(axes, (log_prob_flow, log_prob_target)):
            ax.plot_surface(
                grid_x,
                grid_y,
                np.exp(log_prob).reshape(n_gridpoints, n_gridpoints),
                cmap=cm.coolwarm,
                linewidth=0,
                antialiased=False,
            )
        plt.savefig(f"surface_slice_{idxs}.pdf")
        plt.show()


def plot_nd_project(
    model: MarginalFlow, target, n_outer_samples, n_gridpoints=100, n_projects=4, idxs=""
):
    """Scatter model and target samples/means after projecting them to 2-D.

    Model means and samples are drawn once, a rank-2 projection matrix is
    computed from the means via ``project(..., type="svd")``, and the same
    matrix is applied to model samples, model means, target samples and
    target means. The two-panel figure is saved as
    ``imshow_project_{idxs}.pdf`` and shown.

    Args:
        model: Model with ``sample_and_log_prob``; if it has a ``dim`` (or
            otherwise an ``x_dim``) attribute it is asserted to exceed 2.
        target: Target exposing ``sample_all`` and ``means``.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        n_gridpoints: ``5 * n_gridpoints`` samples are drawn per distribution.
        n_projects: Currently unused (see the TODO below).
        idxs: Suffix inserted into the output file name.
    """
    if hasattr(model, "dim"):
        assert model.dim > 2
    elif hasattr(model, "x_dim"):
        assert model.x_dim > 2

    n_draws = 5 * n_gridpoints
    flow_means_full, flow_samples_full, _, _ = model.sample_and_log_prob(
        n_mixtures=n_outer_samples, n_samples=n_draws
    )
    # _, projm, allbase = project(means_flow_free, d=2, type="eig")
    _, projm, _ = project(flow_means_full, d=2, type="svd")
    # TODO projm is used once for projecting but inverse not used

    # TODO project to n_projects number of dimensions and add these to the plot
    flow_means_2d = (flow_means_full @ projm).detach().cpu().numpy()
    flow_samples_2d = (flow_samples_full @ projm).detach().cpu().numpy()

    # evaluate true distribution
    target_samples_2d = (target.sample_all(n_samples=n_draws) @ projm).detach().cpu().numpy()
    target_means_2d = (target.means @ projm).detach().cpu().numpy()

    fig, (ax_samples, ax_means) = plt.subplots(nrows=1, ncols=2)
    # Left panel: target means with target and model samples.
    ax_samples.scatter(target_means_2d[:, 0], target_means_2d[:, 1], s=30)
    ax_samples.scatter(target_samples_2d[:, 0], target_samples_2d[:, 1], s=5, alpha=0.5)
    ax_samples.scatter(flow_samples_2d[:, 0], flow_samples_2d[:, 1], s=5, alpha=0.5)
    # Right panel: target means against model means.
    ax_means.scatter(target_means_2d[:, 0], target_means_2d[:, 1], s=30)
    ax_means.scatter(flow_means_2d[:, 0], flow_means_2d[:, 1], s=5, alpha=0.5)
    plt.savefig(f"imshow_project_{idxs}.pdf")
    plt.show()


def plot_samples_2D(
    model: MarginalFlow,
    target,
    n_samples,
    n_outer_samples,
    other_models,
    bound=2,
    signature=None,
    manifold=False,
    repeat=1,
):
    """Plot sample scatters of the target, its datasets, the model and other models.

    Panels, left to right: target samples ("gt"), training data, test data,
    marginal-flow samples (optionally overlaid with the model means), then one
    panel per entry of ``other_models``. The figure is saved as
    ``./plots/{signature}_samples_{repeat}.pdf`` and ``.png`` and then shown.

    Args:
        model: 2-D model exposing ``x_dim`` and ``sample_and_log_prob``.
        target: Object exposing ``load_dataset()`` (3-tuple, first and last
            items are plotted as train/test data) and ``sample(n_samples=...)``.
        n_samples: Number of samples drawn from every distribution.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        other_models: Mapping ``name -> model`` or ``None``. Each model is
            sampled via ``sample(n_samples=...)`` with a fallback to
            ``sample(torch.Size([n_samples]))``.
        bound: Half-width of the square axis limits.
        signature: Prefix of the output file names.
        manifold: If true, overlay the model means on the model panel.
        repeat: Index appended to the output file names.
    """
    assert model.x_dim == 2, "plot function only for 2D models/datasets"
    other_names = list(other_models.keys()) if other_models is not None else []
    n_other_models = len(other_names)

    with torch.no_grad():
        means_flow, samples_flow, _, _ = model.sample_and_log_prob(
            n_mixtures=n_outer_samples, n_samples=n_samples
        )
        samples_flow = samples_flow.detach().cpu().numpy()
        train_data, _, test_data = target.load_dataset()
        samples_target = target.sample(n_samples=n_samples).detach().cpu().numpy()
        samples_other_models = []
        for model_ in (other_models[name] for name in other_names):
            try:
                samples = model_.sample(n_samples=n_samples).detach().cpu().numpy()
            except Exception:
                # Fallback call convention; only ``Exception`` is caught so
                # Ctrl-C / SystemExit are not swallowed.
                samples = model_.sample(torch.Size([n_samples])).detach().cpu().numpy()
            samples_other_models.append(samples)

    figsize = (20, 10)
    marker_size = 50
    n_fixed_panels = 4
    fig, axes = plt.subplots(nrows=1, ncols=n_fixed_panels + n_other_models, figsize=figsize)
    axes[0].scatter(samples_target[:, 0], samples_target[:, 1], s=5, alpha=0.15)
    axes[1].scatter(train_data[:, 0], train_data[:, 1], s=5, alpha=1)
    axes[2].scatter(test_data[:, 0], test_data[:, 1], s=5, alpha=1)
    axes[3].scatter(samples_flow[:, 0], samples_flow[:, 1], s=5, alpha=0.15)
    if manifold:
        means_flow_np = means_flow.detach().cpu().numpy()
        axes[3].scatter(
            means_flow_np[:, 0],
            means_flow_np[:, 1],
            s=marker_size // 2,
            label="learnt manifold",
            alpha=0.45,
        )
    for offset, samples in enumerate(samples_other_models):
        axes[n_fixed_panels + offset].scatter(samples[:, 0], samples[:, 1], s=5, alpha=0.15)

    titles = ["gt", "train samples", "test samples", "marginal flow", *other_names]
    for ax, model_title in zip(axes, titles):
        ax.legend(fontsize="13")
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(-bound, bound)
        ax.set_ylim(-bound, bound)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(model_title, fontdict={"fontsize": 25})
    plt.tight_layout()
    plt.savefig(f"./plots/{signature}_samples_{repeat}.pdf", dpi=300)
    plt.savefig(f"./plots/{signature}_samples_{repeat}.png", dpi=300)
    plt.show()


def plot_likelihood_2D(
    model: MarginalFlow, target, grid_size, n_outer_samples, other_models, bound=2, signature=None
):
    """Plot density images of the target, the model and optional other models.

    Log-probabilities are evaluated on a ``grid_size x grid_size`` grid over
    ``[-bound, bound]^2``, exponentiated and shown with ``imshow``. Panels,
    left to right: target ("gt"), marginal flow, then one per entry of
    ``other_models``. The figure is saved as ``./plots/{signature}_loglik.pdf``
    and ``.png`` and then shown.

    Args:
        model: 2-D model exposing ``x_dim``, ``device`` and
            ``log_prob(n_mixtures=..., x=...)``.
        target: Object exposing ``log_prob(grid)``.
        grid_size: Number of grid points per axis.
        n_outer_samples: Forwarded to the model as ``n_mixtures``.
        other_models: Mapping ``name -> model`` or ``None``. Three call
            conventions are tried in order (see the loop below).
        bound: Half-width of the evaluated square.
        signature: Prefix of the output file names.
    """
    assert model.x_dim == 2, "plot function only for 2D models/datasets"
    other_names = list(other_models.keys()) if other_models is not None else []
    n_other_models = len(other_names)

    x = torch.linspace(-bound, bound, grid_size)
    grid = torch.meshgrid(x, x, indexing="xy")
    grid = torch.stack([grid[0].flatten(), grid[1].flatten()], dim=1).to(model.device)

    gt_logp = target.log_prob(grid)
    gt_lik = torch.exp(gt_logp).cpu().reshape(grid_size, grid_size).detach().numpy()
    model_logp = model.log_prob(n_mixtures=n_outer_samples, x=grid)
    model_lik = torch.exp(model_logp).cpu().reshape(grid_size, grid_size).detach().numpy()

    extent = (-bound, bound, -bound, bound)
    n_fixed_panels = 2
    fig, axes = plt.subplots(nrows=1, ncols=n_fixed_panels + n_other_models, figsize=(20, 10))
    axes[0].imshow(gt_lik, extent=extent, origin="lower")
    axes[1].imshow(model_lik, extent=extent, origin="lower")
    for offset, model_name in enumerate(other_names):
        other_model = other_models[model_name]
        # Try the known log-probability APIs in turn. Only ``Exception`` is
        # caught so Ctrl-C / SystemExit are not swallowed.
        try:
            other_model_logp = other_model.log_prob(grid, exact=True)
        except Exception:
            try:
                other_model_logp = other_model.exact_log_prob(grid)[2]
            except Exception:
                other_model_logp = other_model.log_prob(grid)
        other_model_lik = (
            torch.exp(other_model_logp).cpu().reshape(grid_size, grid_size).detach().numpy()
        )
        # free form flow sometimes returns spiked values --> remove outliers
        # by clipping at mean + 5 std of the strictly positive entries.
        lik_flat = other_model_lik.flatten()
        positive = lik_flat[lik_flat > 0]
        mean = positive.mean()
        std = positive.std()
        other_model_lik = np.clip(other_model_lik, a_min=0, a_max=mean + 5 * std)
        axes[n_fixed_panels + offset].imshow(
            other_model_lik,
            extent=extent,
            origin="lower",
            label=model_name,
        )

    titles = ["gt", "marginal flow", *other_names]
    for ax, model_title in zip(axes, titles):
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(-bound, bound)
        ax.set_ylim(-bound, bound)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(model_title, fontdict={"fontsize": 25})
    # NOTE(review): this mappable is created independently of the images
    # above, so the colorbar is not tied to any panel's colour scale (each
    # imshow autoscales separately) -- confirm whether a shared norm is wanted.
    fig.colorbar(
        cm.ScalarMappable(cmap="viridis"), ax=axes, orientation="vertical", label="density"
    )
    plt.savefig(f"./plots/{signature}_loglik.pdf", dpi=300)
    plt.savefig(f"./plots/{signature}_loglik.png", dpi=300)
    plt.show()


def plot_metrics_runtime(models, metrics):
    """Plot each metric against runtime on log-log axes, one line per model.

    For every entry of ``models`` the file ``f"{model.model_path}.json"`` is
    read with ``pandas.read_json``. The frames are concatenated with the model
    name as a ``"models"`` column, and one figure per metric is shown.

    Args:
        models: Mapping ``name -> object`` where each object has a
            ``model_path`` attribute.
        metrics: Iterable of column names to plot against ``"runtime"``.
    """
    per_model = {
        name: pd.read_json(f"{entry.model_path}.json") for name, entry in models.items()
    }
    # The dict keys become the outer index level, then a regular column.
    combined = pd.concat(per_model, names=["models"]).reset_index(level=0)
    for metric in metrics:
        sns.lineplot(data=combined, x="runtime", y=metric, hue="models", marker="o")
        plt.xscale("log")
        plt.yscale("log")
        plt.show()


def project(x: torch.Tensor, d=2, type="svd") -> (torch.Tensor, torch.Tensor, torch.Tensor):
    xm = x.mean(-1, keepdim=True)
    xc = x - xm
    cov = torch.mm(xc.t(), xc) / (xc.size(0) - 1)

    match type:
        case "svd":
            lu, fu = project_svd_cov(cov, d=d)
        case "eig":
            lu, fu = project_pca_cov(cov, d=d)
        case _:
            raise ValueError(f"Type {type} not supported")
    return torch.mm(x, lu), lu, fu


def project_pca_cov(cov: torch.Tensor, d=2) -> (torch.Tensor, torch.Tensor):
    # find important eigenvectors. the eigenvalues are in ascending order
    eigvals, eigvecs = torch.linalg.eigh(cov)
    eigvals = torch.flip(eigvals, dims=[-1])
    eigvecs = torch.flip(eigvecs, dims=[-1])
    seigv = torch.sqrt(eigvals)[None] * eigvecs
    # CAREFUL! all eigvectors with dim > d are normalized and not scaled by any eigenvalue
    seigv[:, d:] = eigvecs[:, d:]
    lseigv = seigv[:, :d]
    return lseigv, seigv


def project_svd_cov(cov: torch.Tensor, d=2) -> (torch.Tensor, torch.Tensor):
    # find svd decomposition. descending order of importance
    svd = torch.linalg.svd(cov, full_matrices=False)
    svh = (torch.sqrt(svd.S[:, None]) * svd.Vh).t()
    # CAREFUL! all eigvectors with dim > d are normalized and not scaled by any eigenvalue
    svh[:, d:] = svd.Vh.t()[:, d:]
    lsvh = svh[:, :d]
    return lsvh, svh


def get_orthogonal_vector(v):
    """Return a random unit vector orthogonal to ``v``.

    A Gaussian vector shaped like ``v`` is drawn with ``torch.randn_like`` and
    passed to ``orthonormalize_second``, which removes its component along
    ``v`` and normalises the remainder.
    """
    candidate = torch.randn_like(v)
    orthogonal = orthonormalize_second(v, candidate)
    return orthogonal


def orthonormalize_second(v, v2):
    """Return the unit vector left after removing from ``v2`` its part along ``v``.

    This is a single Gram-Schmidt step on 1-D tensors (``torch.dot`` is used).
    No checks are performed: a zero ``v`` or a ``v2`` parallel to ``v`` leads
    to a division by zero.
    """
    # gram schmidt of second vector -> no checks in case it doesn't work
    coefficient = torch.dot(v, v2) / torch.dot(v, v)
    residual = v2 - coefficient * v
    length = torch.norm(residual)
    return residual / length
