import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

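# Off-diagonal search: find x, y in the 2x2 matrix M = [[w0, x], [y, w1]]
# (diagonal fixed) so that d @ M @ p and d @ M @ M @ p match a random target pair.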


# Base variance of every random draw; rand() divides it by its ``order`` argument.
G: float = 0.1


def get_a(d, b):
    """Return the x**2 coefficient of the quadratic for the off-diagonal entry x.

    The quadratic comes from substituting the linear relation for y (see
    get_y) into ``d @ M @ M @ b == target[1]`` with ``M = [[w0, x], [y, w1]]``.
    This coefficient depends only on the vectors ``d`` and ``b`` (run() passes
    its ``p`` vector as ``b``), not on the diagonal or the targets.
    """
    d0, d1 = d[0], d[1]
    b0, b1 = b[0], b[1]
    term_d = d0 ** 2 * b1 / d1
    term_b = b1 ** 2 * d0 / b0
    return -term_d - term_b


def get_b(d, b, w, a):
    """Return the linear (x) coefficient of the quadratic for the off-diagonal entry x.

    Args:
        d: left vector of the bilinear form (two entries used).
        b: right vector of the bilinear form (run() passes ``p``).
        w: the two fixed diagonal entries ``(w0, w1)``.
        a: target pair; only ``a[0]`` (the target for ``d @ M @ b``) is used here.
    """
    d0, d1 = d[0], d[1]
    b0, b1 = b[0], b[1]
    cross = d0 * b1
    from_target = a[0] * (d0 / d1 + b1 / b0)
    from_w1 = w[1] * (cross + b1 ** 2 * d1 / b0)
    from_w0 = w[0] * (cross + d0 ** 2 * b0 / d1)
    return from_target - from_w1 - from_w0


def get_c(d, b, w, a):
    """Return the constant term of the quadratic for the off-diagonal entry x.

    Equals ``a[0]*(w0 + w1) - w0*w1*(d . b) - a[1]``, where ``a`` is the target
    pair for ``d @ M @ b`` and ``d @ M @ M @ b`` and ``w`` holds the fixed
    diagonal entries.
    """
    w0, w1 = w[0], w[1]
    overlap = d[0] * b[0] + d[1] * b[1]
    return a[0] * (w0 + w1) - w0 * w1 * overlap - a[1]


def solve(a, b, c):
    """Return the real roots of ``a*x**2 + b*x + c == 0``.

    Returns:
        ``(r_minus, r_plus)`` where ``r_minus = (-b - sqrt(D)) / (2a)`` and
        ``r_plus = (-b + sqrt(D)) / (2a)`` with ``D = b**2 - 4ac``, or ``None``
        when ``D < 0``. In the degenerate linear case ``a == 0`` the single root
        ``-c / b`` is returned twice. ``None`` is also returned when
        ``a == b == 0``, where there is no unique root.
    """
    if a == 0:
        # The quadratic formula would divide by zero; fall back to the linear equation.
        if b == 0:
            return None
        root = -c / b
        return root, root
    discriminant = b ** 2 - 4 * a * c
    if discriminant < 0:
        return None
    root_disc = math.sqrt(discriminant)
    # Compute the root whose numerator is a sum of same-signed terms directly,
    # and derive the other from x1 * x2 = c / a. The textbook formula loses
    # most of its precision to cancellation when |b| >> |4ac|.
    if b >= 0:
        q = -b - root_disc
        r_minus = q / (2 * a)
        # q == 0 only when b == 0 and D == 0, i.e. a double root at 0.
        r_plus = 2 * c / q if q != 0 else r_minus
    else:
        q = -b + root_disc  # strictly positive here, never zero
        r_plus = q / (2 * a)
        r_minus = 2 * c / q
    return r_minus, r_plus


def get_y(d, b, w, a, x):
    """Return the off-diagonal entry y implied by x and the first target.

    Solves the linear constraint ``d @ [[w0, x], [y, w1]] @ b == a[0]`` for y.
    """
    d0, d1 = d[0], d[1]
    b0, b1 = b[0], b[1]
    denom = d1 * b0
    slope_term = -d0 * b1 * x / denom
    offset_term = (a[0] - d1 * b1 * w[1] - d0 * b0 * w[0]) / denom
    return slope_term + offset_term


def rand(size, order=1):
    """Draw ``size`` zero-mean normal samples with variance ``G / order``.

    Uses NumPy's global (legacy) random state, so ``np.random.seed`` controls it.
    """
    std = math.sqrt(G / order)
    return np.random.normal(scale=std, size=size)


def rand_network():
    """Draw three independent length-2 random vectors ``(p, d, w)``.

    ``p`` and ``w`` have variance ``G / 2``; ``d`` has variance ``G``.
    The draws happen in the order p, d, w.
    """
    p_vec = rand(2, 2)
    d_vec = rand(2, 1)
    w_vec = rand(2, 2)
    return p_vec, d_vec, w_vec


def create_results():
    """Sample a target pair from an independently drawn random network.

    Builds a full 2x2 matrix ``w``. Its first row is the ``w`` vector from
    ``rand_network()`` and its second row is a fresh ``rand(2, 2)`` draw.

    Returns:
        ``np.array([d @ w @ p, d @ w @ w @ p])`` for that network.
    """
    p, d, w_1 = rand_network()
    w_2 = rand(2, 2)
    # np.vstack replaces np.row_stack, a deprecated alias since NumPy 2.0.
    w = np.vstack([w_1, w_2])
    return np.array([d @ w @ p, d @ w @ w @ p])


def run():
    """Run one trial: solve for the off-diagonal entries against a random target.

    Draws a network ``(p, d, w)``, where ``w`` holds the fixed diagonal, and an
    independent target pair from create_results(). It then solves the quadratic
    for x and derives y from the first target.

    Returns:
        ``(found, (a, b, c))``. ``found`` is False when the quadratic has no
        real root.

    Raises:
        AssertionError: if the recovered matrix fails to reproduce either target
            (only when assertions are enabled).
    """
    p, d, w = rand_network()
    r = create_results()
    coeffs = (get_a(d, p), get_b(d, p, w, r), get_c(d, p, w, r))
    a, b, c = coeffs
    roots = solve(*coeffs)
    print(f'{a}, {b}, {c} -> {roots is not None}')
    if roots is None:
        return False, coeffs
    x = roots[0]
    y = get_y(d, p, w, r, x)
    mat = np.array([[w[0], x], [y, w[1]]])
    # Sanity check: both constraints must hold for the recovered matrix.
    assert np.isclose(d @ mat @ p, r[0])
    assert np.isclose(d @ mat @ mat @ p, r[1])
    return True, coeffs


def main(runs=1000000):
    """Estimate how often run() finds a real off-diagonal solution.

    Repeats run() ``runs`` times and prints the fraction of successful trials.
    It then prints the mean and variance of each quadratic coefficient (a, b, c)
    and shows one KDE plot per coefficient. The noise scale comes from the
    module-level ``G``; rebind it before calling to sweep other scales.

    Args:
        runs: number of independent trials; must be positive.

    Raises:
        ValueError: if ``runs`` is not positive (avoids dividing by zero).
    """
    if runs <= 0:
        raise ValueError(f'runs must be positive, got {runs}')
    successes = 0
    coeff = [[] for _ in range(3)]
    for _ in range(runs):
        res, coeff_now = run()
        if res:
            successes += 1
        for values, v in zip(coeff, coeff_now):
            values.append(v)
    print(successes / runs)
    for letter, values in zip(['a', 'b', 'c'], coeff):
        plt.figure()
        sns.kdeplot(values)
        plt.title(letter)
        print(f'{letter} - mean {np.mean(values)}, var {np.var(values)}')
    plt.show()


# Run the randomized experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
