import numpy as np
import json, sys, subprocess, heapq
sys.path.append('/home/wh_linux/')
import sor.custom as cus
from numbers import Number

def run(var, exe='./e', max_it=1000):
    """Run the external solver once with the given Eisenstat/SSOR omega.

    Invokes ``<exe> -ksp_max_it <max_it> -pc_eisenstat_omega <var>`` and
    parses the first four whitespace-separated tokens of its stdout as three
    floats followed by an int.  Callers in this file use element ``[3]`` as
    the iteration count.

    Args:
        var: relaxation parameter passed to ``-pc_eisenstat_omega``.
        exe: solver executable to launch; defaults to ``./e`` in the current
            working directory.
        max_it: value passed to ``-ksp_max_it``.

    Returns:
        ``[float, float, float, int]`` on success.  If stdout has fewer than
        four tokens or a token is not numeric, ``var`` itself is returned so
        callers can detect the failure with ``isinstance(result, Number)``.

    Raises:
        OSError (e.g. FileNotFoundError) if ``exe`` cannot be launched; this
        is deliberately not caught, matching the original behaviour.
    """
    cmd = [exe, '-ksp_max_it', str(max_it), '-pc_eisenstat_omega', str(var)]
    result = subprocess.run(cmd, capture_output=True, text=True)
    # Tokenize once instead of re-splitting stdout for every field.
    tokens = result.stdout.split()
    try:
        return [float(tokens[0]), float(tokens[1]), float(tokens[2]), int(tokens[3])]
    except (IndexError, ValueError):
        # Too few tokens (e.g. empty output) or non-numeric output: report
        # failure with the sentinel instead of swallowing *every* exception.
        return var

def best_ssor(A, n_best=5):
    """Search for the Eisenstat/SSOR omega that minimizes solver iterations.

    ``A`` is first handed to ``cus.writetobin`` (presumably so the ``./e``
    solver launched by ``run`` can read it — confirm).  Then:

    1. coarse sweep over omega = 0.05, 0.10, ..., 1.95;
    2. fine sweep with step 0.001 over ``[omega - 0.025, omega + 0.025)``
       around the ``n_best`` coarse omegas with the fewest iterations.

    Args:
        A: matrix passed through to ``cus.writetobin``.
        n_best: number of coarse candidates refined in the fine sweep.

    Returns:
        The fine-sweep omega with the fewest iterations (first one on ties).
        ``-1`` if every coarse omega hit the 1000-iteration cap that ``run``
        passes as ``-ksp_max_it``.
        ``omega + 2`` if a solver run produced unparsable output; the caller
        subtracts 2 to recover the failing omega.
    """
    cus.writetobin(A)

    # Coarse sweep.
    pars = np.arange(0.05, 2, 0.05)
    its = []
    for par in pars:
        result = run(par)
        if isinstance(result, Number):
            return par + 2
        its.append(result[3])
    # 1000 matches the -ksp_max_it cap used by run().
    if min(its) == 1000:
        return -1

    # Pick exactly n_best coarse positions.  The old membership test
    # (its[i] in nsmallest(5, its)) also pulled in every position tied with
    # one of those values -- e.g. all non-converged omegas when fewer than
    # five converged -- each costing ~50 extra solver runs.  nsmallest with a
    # key is stable, so ties go to the lower index; re-sorting by index keeps
    # the original fine-sweep order (and hence argmin tie-breaking).
    positions = sorted(heapq.nsmallest(n_best, range(len(its)), key=its.__getitem__))

    # Fine sweep around the selected coarse omegas.
    pars_p = []
    for i in positions:
        pars_p.extend(np.arange(pars[i] - 0.025, pars[i] + 0.025, 0.001).tolist())
    its_p = []
    for par in pars_p:
        result = run(par)
        if isinstance(result, Number):
            return par + 2
        its_p.append(result[3])
    return pars_p[int(np.argmin(its_p))]

# Side length of the finite-difference grid; the coefficient matrix of the
# resulting linear system has side length N^2.
N = 72

# Generate coefficient matrices. The parameters here are drawn at random; to
# reuse the parameters of the original dataset instead, read the original json
# file, whose six parameters are, in order, A-F.
# NOTE(review): this loop draws 16 coefficients for build_chebyshev(4, N, k),
# which does not obviously match "six parameters A-F" -- confirm the note
# above is still current.
X, y = [], []
X_error, y_error = [], []
i = 0
while i < 100:
    k = np.random.rand(16)
    A = cus.build_chebyshev(4, N, k)
    X_0 = cus.build_darcy(A)
    omega = best_ssor(X_0)
    if omega == -1:
        # Every coarse omega hit the iteration cap: discard this sample.
        continue
    if omega > 2:
        # best_ssor signals an unparsable solver run by returning omega + 2;
        # record the parameters together with the omega that failed.
        X_error.append(k.tolist())
        y_error.append(omega - 2)
        continue
    X.append(k.tolist())
    y.append(omega)
    i += 1
    print(i)

# Persist successful samples and failing cases, in the same file order.
for fname, payload in (('X.json', X), ('y.json', y),
                       ('X_error.json', X_error), ('y_error.json', y_error)):
    with open(fname, 'w') as f:
        json.dump(payload, f)