import numpy as np
import scipy.sparse as sp
import struct, json, subprocess, heapq

# Grid points per side of the finite-difference mesh; the coefficient matrix
# of the resulting linear system is N**2 x N**2.
N = 200
# Dataset size.
M = 1000

# Build the matrix of a two-variable, second-order linear elliptic PDE.
def build_elliptic(A, B, C, D, E, F, n):
    """Assemble the nine-point finite-difference matrix on an ``n`` x ``n`` grid.

    Grid node ``(p, q)`` with ``0 <= p, q < n`` maps to row ``p * n + q``, so the
    result is ``n**2`` x ``n**2`` with these diagonals:

    * offsets ``-n`` / ``+n``: ``A*n - 2*D`` / ``A*n + 2*D``
    * offsets ``-1`` / ``+1``: ``C*n - 2*E`` / ``C*n + 2*E``
    * offsets ``-n-1`` and ``+n+1``: ``4*n*B``; offsets ``-n+1`` and ``+n-1``: ``-4*n*B``
    * main diagonal: ``2*(A + C)*n + F``

    Couplings that would wrap from the end of one block of ``n`` rows into the
    next are zeroed. Presumably A/B/C weight the second-derivative terms, D/E
    the first-derivative terms and F the zeroth-order term; the exact grid
    scaling is as coded here (TODO confirm against the derivation).

    Returns:
        A ``scipy.sparse`` COO matrix, negated when necessary so that its
        main-diagonal value is positive, or ``None`` when that value lies in
        ``[-1e-5, 1e-5]``.
    """
    p_minus = A * n - D * 2
    p_plus = A * n + D * 2
    corner = B * (4 * n)
    q_minus = C * n - E * 2
    q_plus = C * n + E * 2
    centre = 2 * (A + C) * n + F

    # One block of n entries whose first entry is zero; tiling it removes
    # every coupling that would cross a block boundary.
    block_mask = np.ones(n)
    block_mask[0] = 0
    lead = np.ones(n - 1)
    along_q = np.concatenate((lead, np.tile(block_mask, n - 1)))            # n**2 - 1 entries
    main_corner = np.concatenate((lead, np.tile(block_mask, n - 2)))        # n**2 - n - 1 entries
    anti_corner = np.concatenate((np.tile(block_mask, n - 1), np.zeros(1)))  # n**2 - n + 1 entries

    bands = [
        (-n - 1, corner * main_corner),
        (-n, p_minus * np.ones(n**2 - n)),
        (-n + 1, -corner * anti_corner),
        (-1, q_minus * along_q),
        (0, centre * np.ones(n**2)),
        (1, q_plus * along_q),
        (n - 1, -corner * anti_corner),
        (n, p_plus * np.ones(n**2 - n)),
        (n + 1, corner * main_corner),
    ]
    band_offsets, band_values = zip(*bands)
    matrix = sp.diags(list(band_values), offsets=list(band_offsets)).tocoo()

    # Normalise the sign so the main diagonal is positive; reject near-zero.
    if centre > 1e-5:
        return matrix
    if centre < -1e-5:
        return -matrix
    return None

def writetobin(A, filename='A.bin'):
    """Serialise sparse matrix ``A`` to a raw binary file in CSR layout.

    File layout (native byte order; ``i`` = int32, ``d`` = float64):
    ``rows, cols`` (2 x i), ``nnz`` (i), ``indptr`` (rows + 1 x i),
    ``indices`` (nnz x i), ``data`` (nnz x d).

    Args:
        A: any ``scipy.sparse`` matrix; converted to CSR when needed.
        filename: output path. Defaults to ``'A.bin'``, which the external
            ``./e`` solver presumably reads (TODO confirm).

    Returns:
        None.
    """
    if A.format != 'csr':
        A = A.tocsr()
    indptr = np.ascontiguousarray(A.indptr, dtype=np.int32)
    indices = np.ascontiguousarray(A.indices, dtype=np.int32)
    # Values are always stored as float64, so float32 or integer inputs are
    # not packed with the integer format (which would truncate them).
    data = np.ascontiguousarray(A.data, dtype=np.float64)
    with open(filename, 'wb') as f:
        f.write(struct.pack('ii', A.shape[0], A.shape[1]))
        f.write(struct.pack('i', len(data)))
        # tobytes() emits native-order bytes, identical to struct's native
        # 'i'/'d' codes, without unpacking each element into a Python object.
        f.write(indptr.tobytes())
        f.write(indices.tobytes())
        f.write(data.tobytes())
    return None

def run(max_it=1000, pre='sor', var=None, var_value=1.0, exe='./e'):
    """Run the external solver once and return three values parsed from stdout.

    The command line is ``[exe, '-ksp_max_it', str(max_it), '-pc_type', pre]``.
    When ``var`` is given, ``var`` and ``str(var_value)`` are appended
    (e.g. ``'-pc_sor_omega', '1.2'``). The command is printed before it runs.

    Args:
        max_it: value passed as ``-ksp_max_it``.
        pre: value passed as ``-pc_type``.
        var: optional option name for a tunable parameter.
        var_value: value for ``var``; ignored when ``var`` is None.
        exe: solver executable (defaults to ``'./e'``).

    Returns:
        ``[float(tok[0]), int(tok[1]), float(tok[6])]``, where ``tok`` is the
        whitespace-split stdout. Callers in this script compare ``tok[1]`` with
        the iteration cap (so it is presumably the iteration count) and call
        ``tok[6]`` "rcond". The meaning of ``tok[0]`` is defined by the solver
        (TODO confirm).

    Raises:
        RuntimeError: if stdout cannot be parsed. The message includes the
            exit code and stderr.
    """
    cmd = [exe, '-ksp_max_it', str(max_it), '-pc_type', pre]
    if var is not None:
        cmd += [var, str(var_value)]
    print(cmd)
    result = subprocess.run(cmd, capture_output=True, text=True)
    tokens = result.stdout.split()
    try:
        return [float(tokens[0]), int(tokens[1]), float(tokens[6])]
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            f'could not parse output of {cmd} (exit code {result.returncode}): '
            f'stdout={result.stdout!r} stderr={result.stderr!r}'
        ) from err

def best_dichotomy(A, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), accuracy=11, index=0, is_min=True, max_it=1000):
    """Search ``[var_value[0], var_value[1]]`` for the solver parameter that
    optimises one entry of ``run``'s result by repeatedly halving a bracket.

    Five equally spaced points ``lo < left < mid < right < hi`` are tracked.
    Each round keeps the half-width bracket centred on the best of ``left``,
    ``mid`` and ``right``. Ties go to ``left`` first, then to ``right``, so
    ``mid`` is kept only when it is strictly better than both. The search
    assumes the objective is roughly unimodal on the interval.

    Args:
        A: sparse matrix, written to disk with ``writetobin`` before probing.
        pre, var: preconditioner type and parameter option, forwarded to ``run``.
        var_value: ``(lo, hi)`` search interval.
        accuracy: number of halving rounds.
        index: which entry of ``run``'s result is optimised.
        is_min: minimise when true, otherwise maximise.
        max_it: iteration cap forwarded to ``run``.

    Returns:
        The final bracket centre, or ``-1`` if all three initial probes report
        ``max_it`` in result field 1 (presumably the iteration count).
    """
    writetobin(A)
    lo, hi = var_value[0], var_value[1]
    step = (hi - lo) / 4
    left, mid, right = lo + step, lo + 2 * step, lo + 3 * step

    def probe(x):
        return run(max_it=max_it, pre=pre, var=var, var_value=x)

    def at_least_as_good(a, b):
        return a[index] <= b[index] if is_min else a[index] >= b[index]

    res_left, res_mid, res_right = probe(left), probe(mid), probe(right)
    if res_left[1] == max_it and res_mid[1] == max_it and res_right[1] == max_it:
        return -1
    for _ in range(accuracy):
        if at_least_as_good(res_left, res_mid):
            # Keep [lo, mid]; the old left point becomes the centre.
            hi = mid
            right = (left + mid) / 2
            mid = left
            left = (lo + mid) / 2
            res_mid = res_left
        elif at_least_as_good(res_right, res_mid):
            # Keep [mid, hi]; the old right point becomes the centre.
            lo = mid
            left = (mid + right) / 2
            mid = right
            right = (mid + hi) / 2
            res_mid = res_right
        else:
            # Keep [left, right] around the current centre. Both endpoints are
            # captured before left/right move, so the bracket stays centred.
            lo, hi = left, right
            left = (lo + mid) / 2
            right = (mid + hi) / 2
        res_left, res_right = probe(left), probe(right)
    return mid

def best_enumeration(A, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), index=0, is_max=False):
    """Find the best solver parameter in ``(var_value[0], var_value[1])`` by a
    coarse-then-fine grid search.

    1. Coarse pass: evaluate ``run`` at ``lo + 0.02, lo + 0.04, ...`` below ``hi``.
    2. Every coarse point scoring within 5% of the best coarse score
       (measured as ``0.05 * |best|``) is re-sampled on a 0.001 grid over
       ``[t - 0.01, t + 0.01)``.
    3. The fine-grid parameter with the best score is returned.

    Args:
        A: sparse matrix, written to disk with ``writetobin`` before probing.
        pre, var: preconditioner type and parameter option, forwarded to ``run``.
        var_value: ``(lo, hi)`` search interval.
        index: which entry of ``run``'s result is optimised.
        is_max: maximise when true, otherwise minimise.

    Raises:
        ValueError: if the interval is too narrow to contain a coarse grid point.
    """
    writetobin(A)
    lo, hi = var_value[0], var_value[1]

    def score(x):
        return run(pre=pre, var=var, var_value=x)[index]

    coarse = np.arange(lo + 0.02, hi, 0.02)
    if coarse.size == 0:
        raise ValueError(f'interval {var_value!r} contains no 0.02 grid point')
    coarse_scores = np.array([score(p) for p in coarse])

    # Additive band around the best score; unlike dividing by it, this stays
    # well defined when the best score is 0 and always keeps the best point.
    best = coarse_scores.max() if is_max else coarse_scores.min()
    band = 0.05 * abs(best)
    if is_max:
        keep = coarse_scores >= best - band
    else:
        keep = coarse_scores <= best + band

    fine = []
    for centre in coarse[keep]:
        fine.extend(np.arange(centre - 0.01, centre + 0.01, 0.001).tolist())
    fine_scores = [score(p) for p in fine]
    pick = np.argmax if is_max else np.argmin
    return fine[int(pick(fine_scores))]


def best_gamg(A):
    """Pick a ``-pc_gamg_threshold`` for the GAMG preconditioner by scanning
    the value the solver reports as "rcond" (field 6 of its stdout).

    Procedure:

    1. Run once with threshold 1 to get a baseline rcond.
    2. Coarse scan thresholds ``0, 0.01, ..., 0.99``. Stop at the first one
       whose rcond is within 1e-3 of the baseline. This presumably assumes
       larger thresholds behave like 1 from then on (TODO confirm).
    3. Around the coarse points holding the five smallest rcond values, scan
       a 0.001 grid (``[0, 0.004]`` for threshold 0, otherwise
       ``[t - 0.005, t + 0.005)``), with the same early stop.
    4. Return the fine-grid threshold with the smallest rcond.
       NOTE(review): this minimises rcond; confirm that is the intended direction.

    Args:
        A: sparse matrix, written to disk with ``writetobin`` before probing.
    """
    writetobin(A)

    def rcond_at(threshold):
        # run() builds ['./e', '-ksp_max_it', '1000', '-pc_type', 'gamg',
        # '-pc_gamg_threshold', str(threshold)] and returns field 6 last.
        return run(pre='gamg', var='-pc_gamg_threshold', var_value=threshold)[2]

    baseline = rcond_at(1)

    def scan(thresholds):
        # Evaluate in order; the first value matching the baseline is recorded
        # and ends the scan.
        values = []
        for threshold in thresholds:
            value = rcond_at(threshold)
            values.append(value)
            if abs(value - baseline) < 1e-3:
                break
        return values

    coarse = np.arange(0, 1, 0.01)
    coarse_rconds = scan(coarse)
    smallest_5 = heapq.nsmallest(5, coarse_rconds)
    positions = [pos for pos, value in enumerate(coarse_rconds) if value in smallest_5]

    fine = []
    for pos in positions:
        if pos == 0:
            # A centred window would start below 0; use [0, 0.004] instead.
            fine += [0, 0.001, 0.002, 0.003, 0.004]
        else:
            fine += np.arange(coarse[pos] - 0.005, coarse[pos] + 0.005, 0.001).tolist()
    fine_rconds = scan(fine)
    return fine[int(np.argmin(fine_rconds))]

# Generate the dataset. The PDE coefficients are drawn at random here; to reuse
# the parameters of an existing dataset instead, load its JSON file, where each
# entry lists the six coefficients A-F in order.
X = []
y = []
i = 0
while i < 200:
    A = np.random.random()
    B = np.random.random() * 2 / N
    C = np.random.random()
    D = np.random.random() * 2 - 1
    E = np.random.random() * 2 - 1
    F = np.random.random() * 2 - 1
    X_0 = build_elliptic(A, B, C, D, E, F, N)
    if X_0 is None:
        # Main-diagonal value too close to zero: no matrix was built, and
        # passing None on to be written to disk would crash. Draw again.
        continue
    omega = best_dichotomy(X_0, pre='eisenstat', var='-pc_eisenstat_omega', index=1)
    # -1 signals that every initial probe hit the iteration cap; also reject
    # anything outside [0, 2].
    if omega < 0 or omega > 2:
        continue
    X.append([A, B, C, D, E, F])
    y.append(omega)
    i = i + 1
    print(i)
with open('X_test.json', 'w') as f:
    json.dump(X, f)
with open('y_test.json', 'w') as f:
    json.dump(y, f)