import numpy as np
import torch
from .network import ScalarNetVec
import multiprocessing

from joblib import Parallel, delayed

def eos_global_min(eta, mu, copies):
    """Return a closed-form (x, y) pair for step size ``eta`` and target ``mu``.

    Only ``copies`` equal to 1 or 2 is supported; any other value yields the
    sentinel ``(0, 0)``. In both supported branches the returned pair
    satisfies ``(x * y) ** copies == mu`` algebraically (up to rounding).
    The name suggests this is the global minimum selected at the edge of
    stability -- NOTE(review): confirm against the derivation.
    NaN is returned (via ``np.sqrt``) when the inner discriminant is negative.
    """
    if copies not in (1, 2):
        return 0, 0

    if copies == 1:
        x_val = np.sqrt((1 + np.sqrt(1 - eta**2 * mu**2)) / eta)
        return x_val, mu / x_val

    # copies == 2: solve for s = x**copies, then split mu between x and y.
    s_val = (1 / (mu * eta) + np.sqrt(1 / (mu**2 * eta**2) - 4 * mu)) / 2
    exponent = 1 / copies
    return s_val ** exponent, (mu / s_val) ** exponent

def comp_sharpness(x, y, mu):
    """Return ``ScalarNetVec([x, y, x, y]).sharpness(mu)``.

    The (x, y) pair is repeated twice to form the network's parameter list;
    what ``sharpness`` computes is defined by ScalarNetVec (not visible here).
    """
    weights = [x, y, x, y]
    model = ScalarNetVec(weights)
    return model.sharpness(mu)

def get_sharpness(xs, ys, mu, verbose=False):
    """Evaluate ``comp_sharpness(xs[k], ys[k], mu)`` for every index k in parallel.

    Runs through joblib with one worker per CPU reported by
    ``multiprocessing.cpu_count()``; ``verbose`` is forwarded to
    ``joblib.Parallel``. Returns joblib's list of results, one per index.
    ``xs`` and ``ys`` must have equal length (checked with an assert).
    """
    assert len(xs) == len(ys)
    n_points = len(xs)
    runner = Parallel(n_jobs=multiprocessing.cpu_count(), verbose=verbose)
    # Index (rather than zip) so inputs are accessed exactly as xs[k] / ys[k].
    jobs = (delayed(comp_sharpness)(xs[k], ys[k], mu) for k in range(n_points))
    return runner(jobs)

def get_residual(xs, ys, mu):
    """Return the elementwise residual ``xs**2 * ys**2 - mu``.

    ``xs`` and ``ys`` must have equal length (checked with an assert); they
    are expected to support elementwise arithmetic, e.g. NumPy arrays.
    """
    assert len(xs) == len(ys)
    xs_sq = xs ** 2
    ys_sq = ys ** 2
    return xs_sq * ys_sq - mu

def comp_closed_form_sharpness(x, y, mu):
    """Return four closed-form quantities ``(l1, l2, l3, l4)`` at point (x, y).

    With ``v = x**2 + y**2`` and ``m = x**2 * y**2``:
      l1, l2 = ((3m - mu) v -/+ sqrt(disc)) / 2,
               disc = v**2 (mu - 3m)**2 + 4 m (3 mu - 7 m)(mu - m)
      l3 = x**2 (mu - m),  l4 = y**2 (mu - m)
    Presumably these are the Hessian eigenvalues whose maximum is the
    sharpness (per the function name) -- NOTE(review): confirm.
    Works elementwise on NumPy arrays; ``disc < 0`` gives NaN for l1/l2.
    """
    sum_sq = x**2 + y**2
    prod_sq = x**2 * y**2
    gap = mu - prod_sq
    disc = sum_sq**2 * (mu - 3 * prod_sq)**2 + 4 * prod_sq * (3 * mu - 7 * prod_sq) * gap
    center = (3 * prod_sq - mu) * sum_sq
    root = np.sqrt(disc)
    return (center - root) / 2, (center + root) / 2, x**2 * gap, y**2 * gap


def get_sharpness_grid(x_grid, y_grid, mu):
    """Evaluate sharpness at every point of the Cartesian grid x_grid x y_grid.

    Uses ``np.meshgrid`` with its default ``'xy'`` indexing, so the result has
    shape ``(len(y_grid), len(x_grid))`` and entry ``[i, j]`` corresponds to
    the point ``(x_grid[j], y_grid[i])``. This is the layout expected by
    matplotlib's ``contour(x_grid, y_grid, Z)``.
    Points are evaluated via ``get_sharpness`` with joblib verbosity 3.
    """
    x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)
    # np.c_ stacks both columns into one array (common dtype) before evaluation.
    xy_grid = np.c_[x_mesh.ravel(), y_mesh.ravel()]
    flat = np.array(get_sharpness(xy_grid[:, 0], xy_grid[:, 1], mu, verbose=3))
    # Reshape to the mesh's own shape (ny, nx). Reshaping to (nx, ny) would
    # silently misplace values whenever the two grids differ in length.
    return flat.reshape(x_mesh.shape)

def get_eos_mins(eta, mu, copies):
    """Return ``[x, y]`` from a closed form valid only for ``mu == 1``.

    x = sqrt(sqrt(1/eta**2 - copies**2) + 1/eta) / sqrt(copies)
    y = 1 / x            (so x * y == mu == 1)

    For eta > 0 and copies in {1, 2}, this agrees with ``eos_global_min`` at mu == 1.
    When ``copies**2 > 1/eta**2`` the inner square root is of a negative
    number and NumPy yields NaN.

    Raises:
        AssertionError: if ``mu != 1`` (the closed form assumes mu == 1).
    """
    # Kept as an assert for backward compatibility with callers; note that
    # asserts are stripped when Python runs with -O.
    assert mu == 1
    x = np.sqrt(np.sqrt(-copies**2 + 1 / eta**2) + 1 / eta) / np.sqrt(copies)
    y = 1 / x
    return [x, y]

def crd2angle(x, y):
    """Return the polar angle of point(s) (x, y) in degrees, in [0, 360).

    Measured counter-clockwise from the positive x-axis; the origin maps to 0.
    Works elementwise on NumPy arrays as well as on scalars. Values a hair
    below 0 may round up to 360.0.
    """
    # arctan2 resolves the quadrant from the signs of both arguments, so no
    # division (and no epsilon guard against x == 0) is needed. The previous
    # arctan(y/x) + 180*(y < 0) form misplaced quadrant II/IV points and the
    # negative axes.
    return np.mod(np.degrees(np.arctan2(y, x)), 360.0)

