import numpy as np
import json, sys, subprocess, heapq
sys.path.append('/home/wh_linux/')
import sor.custom as cus

def run(var):
    """Run ``./e`` once with the given Eisenstat relaxation factor and parse its stdout.

    ``./e`` is launched with ``-pc_type eisenstat -pc_eisenstat_omega <var>``
    (PETSc-style option names; presumably ``./e`` is a PETSc program - confirm).
    The process exit status is not checked: a failed run only surfaces as a
    parse error on its (missing) output.

    Args:
        var: relaxation factor, passed on the command line as ``str(var)``.

    Returns:
        ``[float, float, float, int]`` built from the first four whitespace-separated
        stdout tokens. ``best_ssor`` treats ``result[1] / result[0]`` as a
        convergence ratio and ``result[3]`` as an iteration count (presumably
        initial/final residual norms and the iteration count - confirm against ``./e``).

    Raises:
        IndexError: stdout holds fewer than four tokens.
        ValueError: a token is not numeric.
        OSError: ``./e`` cannot be executed.
    """
    argv = ['./e', '-pc_type', 'eisenstat', '-pc_eisenstat_omega', str(var)]
    proc = subprocess.run(argv, capture_output=True, text=True)
    tokens = proc.stdout.split()
    # Convert in token order so a malformed output fails on the same token as before.
    parsed = [float(tokens[pos]) for pos in range(3)]
    parsed.append(int(tokens[3]))
    return parsed

def _first_passing(omegas, tol):
    """Return the first omega in *omegas* for which ``run`` reports ``result[1] / result[0] < tol``.

    Stops at the first passing omega, so later ones are never run.
    Returns ``None`` if none passes.
    """
    for omega in omegas:
        result = run(omega)
        if result[1] / result[0] < tol:
            return omega
    return None


def _coarse_grid(lo, hi, n_intervals):
    """Return ``(points, spacing)`` for *n_intervals* equal steps covering ``[lo, hi]``.

    Both endpoints are included. A degenerate interval (``hi <= lo`` or the two
    numerically equal) yields the single point ``lo`` with spacing 0. Without
    this guard, the step would be zero or near zero, and ``np.arange`` would
    raise or try to build a huge array.
    """
    if hi <= lo or np.isclose(hi, lo):
        return np.array([lo]), 0.0
    return np.linspace(lo, hi, n_intervals + 1), (hi - lo) / n_intervals


def _fine_window(center, spacing, n_fine, lo, hi):
    """Return *n_fine* omegas spanning one coarse step around *center*, clipped to ``[lo, hi]``.

    The points are ``center - spacing/2 + j*spacing/n_fine`` for ``j = 0..n_fine-1``.
    They are computed from an integer count so the window length does not depend
    on float rounding. Falls back to ``[center]`` when the spacing is 0 or nothing
    survives clipping.
    """
    if spacing == 0:
        return [float(center)]
    window = center + (np.arange(n_fine) / n_fine - 0.5) * spacing
    window = window[(window >= lo) & (window <= hi)]
    if window.size == 0:
        return [float(center)]
    return window.tolist()


def best_ssor(A, *, tol=1e-5, n_coarse=50, n_best=5, n_fine=20):
    """Search for the Eisenstat/SSOR omega that minimises ``./e``'s reported iteration count.

    The matrix is first written with ``cus.writetobin(A)`` (presumably the input
    that ``./e`` reads - confirm). Then:

    1. Bracket: step omega by 0.001 upward from 1 and downward from 1. In each
       direction, keep the first omega for which ``run`` reports
       ``result[1] / result[0] < tol``.
    2. Coarse scan: run ``n_coarse + 1`` evenly spaced omegas on
       [0.001, lower bracket] and on [upper bracket, 1.999].
    3. Refine: around each of the ``n_best`` coarse points with the fewest
       iterations, run ``n_fine`` omegas spanning one coarse step. Points outside
       [0.001, 1.999] are discarded, because SSOR's relaxation factor is
       restricted to (0, 2).

    Args:
        A: matrix handed to ``cus.writetobin`` (type defined by that helper).
        tol: threshold on ``result[1] / result[0]`` for the bracket search.
        n_coarse: number of intervals in each coarse grid.
        n_best: number of coarse points to refine (ties go to the lower grid index).
        n_fine: number of points in each refinement window.

    Returns:
        The refined omega (Python ``float``) with the smallest ``result[3]``;
        the first one wins on ties.

    Raises:
        RuntimeError: either bracket search found no omega meeting *tol*.
        Exceptions from ``run`` / ``cus.writetobin`` propagate unchanged
        (including ZeroDivisionError if ``result[0]`` is 0).
    """
    step = 0.001
    lo_bound, hi_bound = 0.001, 1.999

    cus.writetobin(A)

    upper = _first_passing(np.arange(1, 2, step), tol)
    if upper is None:
        raise RuntimeError(f'no omega in [1, 2) reached tol={tol}')
    lower = _first_passing(np.arange(1, 0, -step), tol)
    if lower is None:
        raise RuntimeError(f'no omega in (0, 1] reached tol={tol}')

    grid_low, step_low = _coarse_grid(lo_bound, lower, n_coarse)
    grid_high, step_high = _coarse_grid(upper, hi_bound, n_coarse)
    coarse = [(omega, step_low) for omega in grid_low]
    coarse += [(omega, step_high) for omega in grid_high]
    coarse_its = [run(omega)[3] for omega, _ in coarse]

    # Pick exactly n_best indices. Selecting every index whose count is among
    # the n_best smallest *values* would also pull in all tied points, making
    # the number of refinement runs unbounded. Sorting keeps grid order.
    best = sorted(
        heapq.nsmallest(n_best, range(len(coarse)), key=coarse_its.__getitem__)
    )

    candidates = []
    for idx in best:
        center, spacing = coarse[idx]
        candidates.extend(_fine_window(center, spacing, n_fine, lo_bound, hi_bound))

    fine_its = [run(omega)[3] for omega in candidates]
    return float(candidates[int(np.argmin(fine_its))])

N = 72  # Finite-difference grid side length; the actual linear system's coefficient matrix has side length N^2.

# Generate coefficient matrices. The parameters here are drawn at random. To reuse
# the parameters from the original dataset instead, read the original JSON file
# (the six parameters, in order, are A-F).
# NOTE(review): the loop below draws 16 coefficients per sample, not six - confirm
# this comment still matches what cus.build_chebyshev expects.
N_SAMPLES = 100  # number of successful samples to collect
X = []
y = []
i = 0
while i < N_SAMPLES:
    k = np.random.rand(16)
    A = cus.build_chebyshev(4, N, k)
    X_0 = cus.build_darcy(A)
    X_1 = cus.symmetrize(X_0)
    try:
        omega = best_ssor(X_1)
    except Exception as err:
        # Skip samples whose search fails, but report why. Catching Exception
        # (not a bare except) keeps Ctrl-C / SystemExit able to stop the loop.
        print(f'skipping sample: {err!r}', file=sys.stderr)
        continue
    X.append(k.tolist())
    y.append(omega)
    i += 1
    print(i)
with open('X_1.json', 'w') as f:
    json.dump(X, f)
with open('y_1.json', 'w') as f:
    json.dump(y, f)
