import subprocess, json, sys, os
import numpy as np
sys.path.append('/home/wh_linux/')
import sor.custom as cus
from numbers import Number

# Root directory of this experiment and the sub-directory that receives the
# JSON result files written further down in this script.
source_folder = '/home/wh_linux/sor/Bihar/amg'
folder = os.path.join(source_folder, 'tol_1e-5')

# exist_ok=True replaces the exists()/makedirs() pair: it is race-free when
# several runs start at once, and it fails loudly here (rather than later at
# the first open()) if the path exists but is not a directory.
os.makedirs(folder, exist_ok=True)

def run(var, m, n, tol=1e-5):
    """Invoke the external ``./e`` solver once with ``-pc_type gamg``.

    The ``-file_A`` / ``-file_b`` inputs are
    ``/home/wh_linux/Bih4225/A_<m>_<n>.dat`` and
    ``/home/wh_linux/Bih4225/b_<m>_<n>.dat``.

    Args:
        var: Value passed to the solver as ``-pc_gamg_threshold``.
        m: First component of the data-file name.
        n: Second component of the data-file name.
        tol: Value passed to the solver as ``-ksp_rtol``.

    Returns:
        A one-element list holding the seventh whitespace-separated token of
        the solver's stdout converted to ``float`` (presumably the iteration
        count -- confirm against the solver's output format).

    Raises:
        OSError: The ``./e`` executable could not be started.
        IndexError: The solver printed fewer than seven tokens.
        ValueError: The seventh token is not a valid float.
    """
    matrix_path = f'/home/wh_linux/Bih4225/A_{m!s}_{n!s}.dat'
    rhs_path = f'/home/wh_linux/Bih4225/b_{m!s}_{n!s}.dat'
    command = [
        './e',
        '-file_A', matrix_path,
        '-file_b', rhs_path,
        '-pc_type', 'gamg',
        '-pc_gamg_threshold', str(var),
        '-ksp_rtol', str(tol),
    ]
    # The exit status is not inspected; a failed solve surfaces only through
    # the stdout parsing below.
    completed = subprocess.run(command, capture_output=True, text=True)
    tokens = completed.stdout.split()
    return [float(tokens[6])]

def get_its(X, y, tol=1e-5):
    """Call ``run`` for every problem in ``X`` and collect the outcomes.

    Args:
        X: Sequence of problems; ``X[i][0]`` and ``X[i][1]`` are forwarded to
            ``run`` as ``m`` and ``n``.
        y: Either a single number, used unchanged as the threshold for every
            problem, or a sequence holding one threshold per problem. A
            per-problem value outside ``(0, 1]`` is replaced by ``1.0``.
        tol: Forwarded to ``run``.

    Returns:
        A list with one entry per problem, in order: the value returned by
        ``run``, or the string ``'error'`` if that call raised.
    """
    same_threshold_for_all = isinstance(y, Number)
    its = []
    for i, problem in enumerate(X):
        if same_threshold_for_all:
            threshold = y
        else:
            candidate = y[i]
            # Out-of-range per-problem thresholds fall back to 1.0.
            threshold = candidate if 0 < candidate <= 1 else 1.0
        try:
            it = run(threshold, problem[0], problem[1], tol=tol)
        except Exception:
            # Deliberately not a bare ``except``: Ctrl-C / SystemExit must
            # still stop a long sweep instead of being logged as 'error'.
            its.append('error')
            continue
        its.append(it)
        # Progress indicator; printed only for problems that succeeded.
        print(i)
    return its

def get_its_none(X, tol=1e-5):
    """Run the ``./e`` solver with ``-pc_type none`` for every problem in ``X``.

    Args:
        X: Sequence of problems; ``X[i][0]`` and ``X[i][1]`` select the
            ``A_<m>_<n>.dat`` / ``b_<m>_<n>.dat`` input files.
        tol: Value passed to the solver as ``-ksp_rtol``.

    Returns:
        A list with one entry per problem, in order: the seventh
        whitespace-separated token of the solver's stdout as a ``float``
        (presumably the iteration count -- confirm against the solver's
        output), or the string ``'error'`` if the solver could not be started
        or its output could not be parsed. Unlike ``get_its``, successful
        entries are bare floats rather than one-element lists.
    """
    its = []
    for i, problem in enumerate(X):
        m, n = problem[0], problem[1]
        cmd = ['./e',
               '-file_A', '/home/wh_linux/Bih4225/A_{}_{}.dat'.format(str(m), str(n)),
               # The right-hand side lives in b_*.dat; the previous code passed
               # the matrix file here by mistake.
               '-file_b', '/home/wh_linux/Bih4225/b_{}_{}.dat'.format(str(m), str(n)),
               '-pc_type', 'none', '-ksp_rtol', str(tol)]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
            value = float(result.stdout.split()[6])
        except (OSError, IndexError, ValueError):
            # Same convention as get_its: one failed problem must not discard
            # the results already collected for the others.
            its.append('error')
            continue
        its.append(value)
        print(i)
    # Previously missing: without this the caller always received None.
    return its

def _read_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    with open(path, 'r') as src:
        return json.load(src)


def _write_json(path, payload):
    """Write ``payload`` as JSON to ``path``, replacing any existing file."""
    with open(path, 'w') as dst:
        json.dump(payload, dst)


# Test problems and their reference thresholds; only the first 100 entries of
# each file are used.
X_0 = _read_json('X_test.json')
y_0 = _read_json('y_test.json')
X, y = X_0[:100], y_0[:100]

pars = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# 1) One fixed threshold applied to every problem, for each value in `pars`.
for fixed_threshold in pars:
    sweep_its = get_its(X, fixed_threshold, tol=1e-5)
    _write_json(
        '{}/tol_1e-5/rcond_equal_{}.json'.format(source_folder, str(fixed_threshold)),
        sweep_its,
    )

# 2) Per-problem thresholds read from y_test1_<k>.json and y_test2_<k>.json
#    for k = 1..10 (presumably model predictions -- confirm). For each k the
#    "test1" file is processed before the "test2" file.
_THRESHOLD_FILE_TEMPLATES = (
    ('y_test1_{}.json', '{}/tol_1e-5/rcond_test1_{}.json'),
    ('y_test2_{}.json', '{}/tol_1e-5/rcond_test2_{}.json'),
)
for k in range(1, 11):
    for in_template, out_template in _THRESHOLD_FILE_TEMPLATES:
        y_cal = _read_json(in_template.format(str(k)))
        loaded_its = get_its(X, y_cal)
        _write_json(out_template.format(source_folder, str(k)), loaded_its)

# 3) Per-problem thresholds taken from y_test.json.
best_its = get_its(X, y)
_write_json('{}/tol_1e-5/rcond_best.json'.format(source_folder), best_its)

# 4) Baseline without a preconditioner.
none_its = get_its_none(X)
_write_json('{}/tol_1e-5/rcond_none.json'.format(source_folder), none_its)