import subprocess, json, struct
import numpy as np
import scipy.sparse as sp

N = 200
M = 1000

def build_elliptic(A, B, C, D, E, F, n):
    """Assemble an ``n**2`` x ``n**2`` nine-point stencil matrix in CSR form.

    Unknowns live on an ``n`` x ``n`` grid numbered row-major (flat index
    ``k = i * n + j``).  Each unknown couples to itself (offset 0), to
    ``k +- 1``, to ``k +- n`` and to the four corners ``k +- (n - 1)`` and
    ``k +- (n + 1)``.  Couplings that would wrap from the end of one grid row
    to the start of the next are zeroed.  (Presumably a discretisation of an
    elliptic operator with coefficients A..F -- NOTE(review): confirm the
    intended scaling with ``n``.)

    Args:
        A, D: weights of the ``k +- n`` couplings (``A*n -/+ 2*D``).
        C, E: weights of the ``k +- 1`` couplings (``C*n -/+ 2*E``).
        B: corner weight ``4*n*B``; +B-type sign on ``+-(n + 1)``,
            negated on ``+-(n - 1)``.
        F: added to the main diagonal ``2*(A + C)*n + F``.
        n: grid points per side.

    Returns:
        The CSR matrix when its diagonal value exceeds 1e-5, its negation when
        the diagonal value is below -1e-5 (so the returned diagonal is always
        positive), and None otherwise (including a NaN diagonal value).
    """
    size = n * n
    diag_value = 2 * (A + C) * n + F

    # Neighbour weights; D and E make each +/- pair asymmetric.
    w_minus_n, w_plus_n = A * n - D * 2, A * n + D * 2
    w_minus_1, w_plus_1 = C * n - E * 2, C * n + E * 2
    w_corner = B * (4 * n)

    def keep_mask(length, dropped_residue):
        # 1.0 at every diagonal position p except where p % n == dropped_residue;
        # those positions are couplings that would cross a grid-row boundary.
        return (np.arange(length) % n != dropped_residue).astype(float)

    layout = [
        (-n - 1, w_corner * keep_mask(size - n - 1, n - 1)),
        (-n, w_minus_n * np.ones(size - n)),
        (-n + 1, -w_corner * keep_mask(size - n + 1, 0)),
        (-1, w_minus_1 * keep_mask(size - 1, n - 1)),
        (0, diag_value * np.ones(size)),
        (1, w_plus_1 * keep_mask(size - 1, n - 1)),
        (n - 1, -w_corner * keep_mask(size - n + 1, 0)),
        (n, w_plus_n * np.ones(size - n)),
        (n + 1, w_corner * keep_mask(size - n - 1, n - 1)),
    ]
    matrix = sp.diags(
        [values for _, values in layout],
        offsets=[offset for offset, _ in layout],
    ).tocsr()

    # Normalise the sign so the main diagonal is positive; reject ~0 diagonals.
    if diag_value > 1e-5:
        return matrix
    if diag_value < -1e-5:
        return -matrix
    return None

def writetobin(A, path='A.bin'):
    """Serialise a matrix to *path* in a raw CSR binary layout.

    Layout (native byte order; C ``int`` as packed by ``struct``, 4 bytes):
        int rows, int cols, int nnz,
        int32  indptr[rows + 1],
        int32  indices[nnz],
        float64 data[nnz]

    Args:
        A: a SciPy sparse matrix (any format, converted to CSR) or a dense
            array-like.
        path: output file; defaults to ``'A.bin'`` in the working directory.

    Raises:
        ValueError: if *A* is None (e.g. ``build_elliptic`` returns None when
            its diagonal value is within 1e-5 of zero).
    """
    if A is None:
        raise ValueError('writetobin: matrix is None, nothing to write')
    A = A.tocsr() if sp.issparse(A) else sp.csr_matrix(A)
    nnz = A.nnz
    with open(path, 'wb') as f:
        f.write(struct.pack('ii', A.shape[0], A.shape[1]))
        f.write(struct.pack('i', nnz))
        f.write(np.ascontiguousarray(A.indptr, dtype=np.int32).tobytes())
        f.write(np.ascontiguousarray(A.indices[:nnz], dtype=np.int32).tobytes())
        # Always emit doubles: a dtype test that only special-cases float64
        # would push e.g. float32 values through the int32 path and truncate them.
        f.write(np.ascontiguousarray(A.data[:nnz], dtype=np.float64).tobytes())
    return None

# Per-problem omega values (y[i] is the value recorded under 'best') and the
# per-problem coefficient rows X[i], whose first six entries feed build_elliptic.
with open('/root/sor/data/data5/y.json', 'r') as f:
    y = json.load(f)
with open('/root/sor/data/data5/X.json', 'r') as f:
    X = json.load(f)

# One problem-independent omega: the mean over all y values.
omega_test = np.mean(np.array(y))

# Base solver invocation and the preconditioners tried via -pc_type.
cmd = ['./e', '-ksp_max_it', '1000']
pre = ['none', 'jacobi', 'bjacobi', 'pbjacobi', 'kaczmarz', 'deflation', 'ilu', 'gamg']
# Indices j of the candidate omega files y<j>.json.
files = list(range(1, 11))
key = ['time', 'iter']

# Result buckets, one per configuration, each with parallel time/iter lists.
value = {
    label: {'time': [], 'iter': []}
    for label in ['best', 'mean', 'equal1', *pre, *map(str, files)]
}

# Load each candidate omega file y<j>.json once, rather than re-reading all
# of them on every one of the M iterations below.
y_candidates = {}
for j in files:
    with open('/root/sor/data/data5/y{}.json'.format(str(j)), 'r') as f:
        y_candidates[j] = json.load(f)


def _run_solver(extra_args, bucket):
    """Run ``cmd + extra_args`` and append the reported time/iterations to *bucket*.

    The solver's stdout must start with two whitespace-separated tokens: a
    float time followed by an integer iteration count.

    Raises:
        RuntimeError: if stdout cannot be parsed that way (stderr is included
            in the message to make solver failures diagnosable).
    """
    argv = cmd + extra_args
    proc = subprocess.run(argv, capture_output=True, text=True)
    fields = proc.stdout.split()
    try:
        elapsed, iters = float(fields[0]), int(fields[1])
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            'could not parse solver output for {}: stdout={!r} stderr={!r}'.format(
                argv, proc.stdout, proc.stderr)) from err
    bucket['time'].append(elapsed)
    bucket['iter'].append(iters)


for i in range(M):
    print(i)
    # The solver reads the system matrix from A.bin written here.
    writetobin(build_elliptic(X[i][0], X[i][1], X[i][2], X[i][3], X[i][4], X[i][5], N))
    _run_solver(['-pc_sor_omega', str(y[i])], value['best'])
    _run_solver(['-pc_sor_omega', str(omega_test)], value['mean'])
    _run_solver(['-pc_sor_omega', '1.0'], value['equal1'])
    for pre_run in pre:
        _run_solver(['-pc_type', pre_run], value[pre_run])
    for j in files:
        omega_j = y_candidates[j][i]
        if omega_j <= 0.0 or omega_j >= 2.0:
            # Omega outside the open interval (0, 2): skip the run, record -1.
            value[str(j)]['time'].append(-1.0)
            value[str(j)]['iter'].append(-1)
        else:
            _run_solver(['-pc_sor_omega', str(omega_j)], value[str(j)])

# Dump one JSON file per configuration, each {'time': [...], 'iter': [...]}:
#   time_sor_<best|mean|equal1>.json - runs with -pc_sor_omega y[i], mean(y), 1.0
#   time_<pc>.json                   - runs with -pc_type <pc>
#   time_dso_<j>.json                - runs with omegas from y<j>.json (-1 = skipped)
jobs = [('time_sor_{}.json'.format(label), label) for label in ['best', 'mean', 'equal1']]
jobs += [('time_{}.json'.format(label), label) for label in pre]
jobs += [('time_dso_{}.json'.format(j), str(j)) for j in files]
for out_name, label in jobs:
    with open(out_name, 'w') as f:
        json.dump(value[label], f)