from typing import Callable

import torch

from margflow.utils.plot_utils import get_orthogonal_vector


def possible_mu_f(
    like: torch.Tensor, dim: int, kind: str, bound: float
) -> Callable[[torch.Tensor], torch.Tensor]:
    match kind:
        case "line":
            z = like.new_zeros(dim)
            o = like.new_ones(dim)
            v = o - z

            def mu_f(t: torch.Tensor) -> torch.Tensor:
                return (2 * bound * t[:, None] - bound) * v

            return mu_f

        case "sin":
            z = like.new_zeros(dim)
            o = like.new_ones(dim)
            v1 = o - z
            v2 = get_orthogonal_vector(v1)
            v1, v2 = v1[None], v2[None]

            def helper(t: torch.Tensor) -> torch.Tensor:
                mu = (2 * bound * t[:, None] - bound) * v1 + 0.4 * bound * torch.sin(
                    4 * torch.pi * t[:, None]
                ) * v2
                return mu

            def mu_f(t: torch.Tensor) -> torch.Tensor:
                curx = helper(t)
                curt = t
                for _ in range(100):
                    ldists = (curx[0:-2] - curx[1:-1]).norm(dim=-1)
                    rdists = (curx[1:-1] - curx[2:]).norm(dim=-1)
                    correction = rdists - ldists
                    corcumsum = torch.cumsum(correction, dim=0)
                    curt[1:-1] = curt[1:-1] + 0.2 * corcumsum / (bound * dim)
                    curx = helper(curt)
                return curx

            return mu_f
        case "circle":

            def mu_f(t: torch.Tensor) -> torch.Tensor:
                x = torch.cos(2 * torch.pi * t[:, None])
                y = torch.sin(2 * torch.pi * t[:, None])
                # the rest of the dimensions are just a linear map -> spiral in higher dim
                z = (2 * t[:, None] - 1).expand((t.shape[0], dim - 2))
                return bound * torch.concatenate([x, y, z], dim=-1)

            return mu_f
        case "spiral":

            def helper(t: torch.Tensor) -> torch.Tensor:
                r = 1 - 0.95 * t[:, None]
                x = r * torch.cos(5 * torch.pi * t[:, None])
                y = r * torch.sin(5 * torch.pi * t[:, None])
                # the rest of the dimensions are just a linear map
                z = (2 * t[:, None] - 1).expand((t.shape[0], dim - 2))
                return bound * torch.concatenate([x, y, z], dim=-1)

            def mu_f(t: torch.Tensor) -> torch.Tensor:
                curx = helper(t)
                curt = t
                eps = 1e-2
                for _ in range(10_000):
                    ldists = (curx[0:-2] - curx[1:-1]).norm(dim=-1)
                    rdists = (curx[1:-1] - curx[2:]).norm(dim=-1)
                    correction = rdists - ldists
                    corcumsum = torch.cumsum(correction, dim=0)
                    curt[1:-1] = curt[1:-1] + eps * corcumsum / (bound * dim)
                    curx = helper(curt)
                return curx

            return mu_f
        case _:
            return lambda x: x
