from example_model_off_diagonal import solve, rand_network, create_results, G
import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# diagonal


# When True, run() additionally checks every solution it finds by rebuilding
# the 2x2 matrix and asserting that it reproduces both target values.
DEBUG = False


def get_a(d, p):
    """Return the x**2 coefficient of the quadratic passed to ``solve``.

    Only the first two entries of ``d`` and ``p`` are read.  The value is
    ``p0*d0 + (p0*d0)**2 / (p1*d1)``, computed as written below.
    """
    d0, d1 = d[0], d[1]
    p0, p1 = p[0], p[1]
    linear_term = p0 * d0
    ratio_term = (p0 ** 2 * d0 ** 2) / (p1 * d1)
    return linear_term + ratio_term


def get_b(d, p, w, r):
    """Return the linear (x**1) coefficient of the quadratic passed to ``solve``.

    Reads the first two entries of ``d``, ``p`` and ``w`` and the first
    entry of ``r``.
    """
    d0, d1 = d[0], d[1]
    p0, p1 = p[0], p[1]
    # Contribution proportional to the first target value r[0].
    target_term = (-2 * r[0] * d0 * p0) / (d1 * p1)
    # Weighted sum of the two off-diagonal entries w[0] and w[1].
    off_diagonal = d0 * p1 * w[0] + d1 * p0 * w[1]
    scale = (d0 * p0) / (d1 * p1) + 1
    return target_term + off_diagonal * scale


def get_c(d, p, w, r):
    """Return the constant term of the quadratic passed to ``solve``.

    Reads the first two entries of each of ``d``, ``p``, ``w`` and ``r``.
    """
    d0, d1 = d[0], d[1]
    p0, p1 = p[0], p[1]
    w0, w1 = w[0], w[1]
    r0 = r[0]
    # r[0] with both off-diagonal contributions removed.
    remainder = r0 - (d0 * p1 * w0) - (d1 * p0 * w1)
    target_term = (r0 * remainder) / (d1 * p1)
    cross_term = d1 * p1 * w0 * w1 + d0 * p0 * w0 * w1
    return target_term + cross_term - r[1]


def get_y(d, p, w, r, x):
    """Return the lower-right matrix entry ``y`` implied by ``x``.

    With ``M = [[x, w[0]], [w[1], y]]`` this is the ``y`` that makes
    ``d @ M @ p`` equal ``r[0]`` (first two entries of ``d`` and ``p`` only).
    """
    d0, d1 = d[0], d[1]
    p0, p1 = p[0], p[1]
    # Everything in d @ M @ p that does not involve y.
    known = d0 * p0 * x + d0 * p1 * w[0] + d1 * p0 * w[1]
    return (r[0] - known) / (d1 * p1)


def run():
    """Draw one random instance and try to solve its quadratic.

    Obtains ``p, d, w`` from ``rand_network()`` and ``r`` from
    ``create_results()``, builds the coefficients ``(a, b, c)`` and hands
    them to ``solve``.  One line is printed per call with the coefficients
    and whether ``solve`` returned something other than ``None``.

    Returns:
        ``(found, (a, b, c))`` where ``found`` is ``False`` exactly when
        ``solve`` returned ``None``.
    """
    p, d, w = rand_network()
    r = create_results()

    a = get_a(d, p)
    b = get_b(d, p, w, r)
    c = get_c(d, p, w, r)
    coefficients = (a, b, c)

    x = solve(a, b, c)
    print(f'{a}, {b}, {c} -> {x is not None}')
    if x is None:
        return False, coefficients

    if DEBUG:
        # Rebuild the matrix from the first returned solution and check that
        # it reproduces both target values.
        root = x[0]
        y = get_y(d, p, w, r, root)
        mat = np.array([[root, w[0]], [w[1], y]])
        first_value = d @ mat @ p
        assert np.isclose(first_value, r[0])
        second_value = d @ mat @ mat @ p
        assert np.isclose(second_value, r[1])

    return True, coefficients


def main():
    """Estimate, for several values of ``G``, how often ``run()`` succeeds.

    For each ``g`` the experiment is repeated ``runs`` times; the fraction
    of successful runs is printed and stored in a dict keyed by ``g``,
    which is printed at the end.
    """
    global G
    # Local import so the setting can be written into the module that owns
    # it.  ``from example_model_off_diagonal import G`` only copies the
    # binding into this module, so assigning the local global alone is not
    # visible to functions defined over there.
    # NOTE(review): assumes rand_network/create_results read G as a global of
    # example_model_off_diagonal - confirm against that module.
    import example_model_off_diagonal as model

    runs = 1000000
    per_g = {}
    for g in [0.1, 0.5, 1, 10]:
        G = g  # keep this module's copy in sync
        model.G = g
        counter = 0
        # Coefficient samples; only consumed by the disabled plotting code
        # at the end of this function.
        coeff = [[] for _ in range(3)]
        for _ in range(runs):
            res, coeff_now = run()
            if res:
                counter += 1
            for bucket, value in zip(coeff, coeff_now):
                bucket.append(value)
        success_rate = counter / runs
        print(success_rate)
        per_g[g] = success_rate
    print(per_g)
    # for i, letter in enumerate(['a', 'b', 'c']):
    #     plt.figure()
    #     sns.kdeplot(coeff[i])
    #     plt.title(letter)
    #     print(f'{letter} - mean {np.mean(coeff[i])}, var {np.var(coeff[i])}')
    # plt.show()


# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
