import sys
import numpy as np
import torch
from sorcerun.git_utils import get_repo

# Locate the enclosing git repository (via sorcerun) and put its working
# directory on sys.path so that `globals` below can be imported.
# NOTE(review): presumably globals.py lives at the repo root — confirm.
# This must stay above the `from globals import ...` line.
repo = get_repo()
ROOT = repo.working_dir
if ROOT not in sys.path:
    # append (not insert): entries already on sys.path keep precedence.
    sys.path.append(ROOT)

# Keys under which the loss functions in this module are registered.
from globals import SIGN, SQRT, INV, PROOT, RESHAPE

# Registry: key -> loss callable with signature (x, y, custom_loss=False).
LOSSES = {}


# %%
def sign_loss(x, y, custom_loss=False):
    """Relative loss for a sign approximation ``x`` (ideal entries are +-1).

    Parameters
    ----------
    x : np.ndarray or torch.Tensor
        Candidate output. ``x.shape[0]`` is used for normalisation.
    y : np.ndarray or torch.Tensor
        Accepted for signature parity with the other ``LOSSES`` entries;
        the computation does not use it.
    custom_loss : bool, default False
        NumPy only: replace the norm by ``sum(max(| |x| - 1 | - 0.3, 0))``,
        i.e. only deviations of ``|x|`` from 1 larger than 0.3 are penalised.
        Ignored for torch inputs.

    Returns
    -------
    float
        ``loss / sqrt(x.shape[0])`` -- a NumPy scalar for ndarray input and a
        Python float for torch input.
    """
    d = x.shape[0]

    if isinstance(x, np.ndarray):
        if custom_loss:
            # Hinge on the distance of |x| from 1 with a 0.3 dead zone. Clip
            # against the scalar 0 (not zeros_like(y)): the penalty depends only
            # on x, so y's shape must not broadcast into the sum.
            loss = np.sum(np.maximum(np.abs(np.abs(x) - 1) - 0.3, 0.0))
        else:
            # A perfect sign has x*x == 1 entrywise.
            loss = np.linalg.norm(x * x - np.ones_like(x))
    else:
        # NOTE(review): custom_loss is not implemented for torch tensors.
        loss = torch.linalg.norm(x * x - torch.ones_like(x)).item()

    # For 1-D x, sqrt(d) is the norm of the all-ones target -> relative loss.
    return loss / np.sqrt(d).item()


# Register sign_loss in the module registry under the SIGN key.
LOSSES[SIGN] = sign_loss


# %%
def sqrt_loss(x, y, custom_loss=False):
    """Relative residual of ``x`` as an elementwise square root of ``y``.

    Returns ``||x * x - y|| / ||y||``. It uses ``numpy.linalg.norm`` when ``x``
    is an ndarray and ``torch.linalg.norm`` otherwise. NumPy input yields a
    NumPy scalar; torch input yields a tensor. ``custom_loss`` is accepted for
    interface parity with the other ``LOSSES`` entries and is ignored.
    """
    linalg = np.linalg if isinstance(x, np.ndarray) else torch.linalg
    residual = x * x - y
    return linalg.norm(residual) / linalg.norm(y)


# Register sqrt_loss in the module registry under the SQRT key.
LOSSES[SQRT] = sqrt_loss


# %%
def inv_loss(x, y, custom_loss=False):
    """Residual of ``x`` as an elementwise reciprocal of ``y``.

    Returns ``||x * y - 1|| / ||y||``. It uses ``torch.linalg.norm`` unless
    ``x`` is an ndarray, in which case it uses ``numpy.linalg.norm``.
    ``custom_loss`` is accepted for interface parity and is ignored.

    NOTE(review): the numerator is normalised by ||y|| rather than by the
    norm of the all-ones target, so the value depends on the scale of y;
    confirm this normalisation is intended.
    """
    if not isinstance(x, np.ndarray):
        return torch.linalg.norm(x * y - 1) / torch.linalg.norm(y)
    return np.linalg.norm(x * y - 1) / np.linalg.norm(y)


# Register inv_loss in the module registry under the INV key.
LOSSES[INV] = inv_loss


def proot_loss(x, y, custom_loss=False, *, p=3):
    """Relative residual of ``x`` as an elementwise ``p``-th root of ``y``.

    Parameters
    ----------
    x, y : np.ndarray or torch.Tensor
        Candidate root and target. The backend is chosen from the type of ``x``.
    custom_loss : bool, default False
        Accepted for interface parity with the other ``LOSSES`` entries;
        ignored.
    p : int, keyword-only, default 3
        Positive integer root order. The default reproduces the previous
        hard-coded cube-root loss exactly.

    Returns
    -------
    ``||x**p - y|| / ||y||``: a NumPy scalar for ndarray input, a tensor for
    torch input.

    Raises
    ------
    ValueError
        If ``p`` is not a positive integer.
    """
    if not isinstance(p, (int, np.integer)) or p < 1:
        raise ValueError(f"p must be a positive integer, got {p!r}")

    # Build x**p by repeated multiplication ((x*x)*x for p=3) so the default
    # is bit-identical to the original x * x * x for both backends.
    x_pow = x
    for _ in range(p - 1):
        x_pow = x_pow * x

    if isinstance(x, np.ndarray):
        return np.linalg.norm(x_pow - y) / np.linalg.norm(y)
    return torch.linalg.norm(x_pow - y) / torch.linalg.norm(y)


# Register proot_loss in the module registry under the PROOT key.
LOSSES[PROOT] = proot_loss

def reshape_spectrum_loss(x, y, custom_loss = False):
    """Relative distance of ``x`` from ``y`` with entries below 0.5 raised to 0.5.

    The target is ``y`` with every entry smaller than 0.5 replaced by 0.5. The
    result is ``||x - target|| / ||y||``; note that it is normalised by the norm
    of the original ``y``, not of the target. The backend is chosen from the
    type of ``x``. ``custom_loss`` is accepted for interface parity and ignored.
    """
    if not isinstance(x, np.ndarray):
        # Scalar fill value matching y's dtype/device so torch.where can mix it.
        floor = torch.tensor(0.5, dtype=y.dtype, device=y.device)
        target = torch.where(y < 0.5, floor, y)
        return torch.linalg.norm(x - target) / torch.linalg.norm(y)

    target = np.where(y < 0.5, 0.5, y)
    return np.linalg.norm(x - target) / np.linalg.norm(y)

# Register reshape_spectrum_loss in the module registry under the RESHAPE key.
LOSSES[RESHAPE] = reshape_spectrum_loss