import struct, subprocess, os, shutil
import numpy as np
import scipy.sparse as sp
import numpy.polynomial.chebyshev as cb

#创建二元二阶线性椭圆PDE
def build_elliptic(A, B, C, D, E, F, n):
    """Assemble the 9-point stencil matrix of a 2-D second-order linear elliptic PDE.

    The ``n x n`` grid unknowns are numbered in grid lines of ``n`` consecutive
    indices, giving an ``n**2 x n**2`` matrix.  ``A``/``D`` set the couplings at
    offsets ``-n``/``+n``, ``C``/``E`` the couplings at offsets ``-1``/``+1``,
    ``B`` the four corner couplings (offsets ``+-n+-1``), and ``F`` is added to
    the main diagonal.  Couplings that would wrap from the end of one grid line
    to the start of the next are zeroed.

    Returns:
        The matrix in COO format, negated when needed so that the main-diagonal
        value is positive, or ``None`` if that value is within 1e-5 of zero.
    """
    # Stencil weights.
    w_minus_n = A * n - D * 2
    w_plus_n = A * n + D * 2
    w_corner = B * (4 * n)
    w_minus_1 = C * n - E * 2
    w_plus_1 = C * n + E * 2
    w_centre = 2 * (A + C) * n + F

    lead = np.ones(n - 1)
    line_mask = np.ones(n)
    line_mask[0] = 0  # the zero falls where a neighbour would wrap to the next grid line

    # Mask patterns shared by several diagonals.
    no_wrap_long = np.concatenate((lead, np.tile(line_mask, n - 1)))
    no_wrap_short = np.concatenate((lead, np.tile(line_mask, n - 2)))
    no_wrap_anti = np.concatenate((np.tile(line_mask, n - 1), np.zeros(1)))

    offsets = [-n - 1, -n, -n + 1, -1, 0, 1, n - 1, n, n + 1]
    bands = [
        w_corner * no_wrap_short,
        w_minus_n * np.ones(n**2 - n),
        -w_corner * no_wrap_anti,
        w_minus_1 * no_wrap_long,
        w_centre * np.ones(n**2),
        w_plus_1 * no_wrap_long,
        -w_corner * no_wrap_anti,
        w_plus_n * np.ones(n**2 - n),
        w_corner * no_wrap_short,
    ]
    matrix = sp.diags(bands, offsets=offsets).tocoo()

    if w_centre > 1e-5:
        return matrix
    if w_centre < -1e-5:
        return -matrix
    return None

#将稀疏矩阵写入文件A.bin
def writetobin(A, filename='A.bin'):
    """Write a sparse matrix to a binary CSR file.

    File layout (native byte order):
        int32 rows, int32 cols, int32 nnz,
        (rows + 1) int32 row pointers,
        nnz int32 column indices,
        nnz float64 values.

    Args:
        A: scipy sparse matrix; converted to CSR if it is in another format.
        filename: output path (default ``'A.bin'``).

    Returns:
        None.
    """
    if A.format != 'csr':
        A = A.tocsr()
    nnz = A.nnz
    with open(filename, 'wb') as f:
        f.write(struct.pack('iii', A.shape[0], A.shape[1], nnz))
        # Cast by position, not by dtype: indices are always int32 and values
        # always float64.  Previously any non-float64 data (e.g. float32) was
        # written as int32, truncating values and breaking the layout.
        # tobytes() also avoids building a Python tuple of every entry.
        f.write(np.asarray(A.indptr, dtype=np.int32).tobytes())
        f.write(np.asarray(A.indices[:nnz], dtype=np.int32).tobytes())
        f.write(np.asarray(A.data[:nnz], dtype=np.float64).tobytes())
    return None

#调用文件e求解线性方程组
def run(filerun='e', pre='sor', maxit=1000, var=None, var_value=1.0):
    """Run the solver executable ``./<filerun>`` and parse its standard output.

    The command line is ``./<filerun> -ksp_max_it <maxit> -pc_type <pre>``,
    followed by ``<var> <var_value>`` when ``var`` is not None.

    Returns:
        ``[float(tok0), int(tok1), float(tok6)]``, where ``tokN`` is the N-th
        whitespace-separated token of stdout.  NOTE(review): presumably
        residual, iteration count and timing — confirm against the solver.

    Raises:
        RuntimeError: if stdout has fewer than 7 tokens (e.g. the solver
            crashed).  The message includes the command, exit code and stderr.
    """
    cmd = ['./{}'.format(filerun), '-ksp_max_it', str(maxit), '-pc_type', pre]
    if var is not None:
        cmd += [var, str(var_value)]
    result = subprocess.run(cmd, capture_output=True, text=True)
    tokens = result.stdout.split()
    if len(tokens) < 7:
        # Report the cause instead of failing later with an opaque IndexError.
        raise RuntimeError(
            'unexpected solver output from {!r} (exit code {}): stdout={!r}, stderr={!r}'.format(
                ' '.join(cmd), result.returncode, result.stdout, result.stderr))
    return [float(tokens[0]), int(tokens[1]), float(tokens[6])]

#二分法求最优参数(只适用于凸函数的参数选取,三个函数代表三种不同策略,分别是中心优先,左侧优先,右侧优先)
def _dichotomy_search(A, filename, filerun, maxit, pre, var, var_value, accuracy,
                      index, is_min, strict, prefer_right):
    """Shared five-point bracketing search behind the best_dichotomy* functions.

    Keeps five equally spaced points x0 < xi < xj < xk < x4 spanning the current
    bracket.  Each round halves the bracket around whichever of xi, xj, xk is
    best, then re-runs the solver only at the two new inner points.

    Args:
        strict: compare with ``<``/``>`` (ties keep the centre) instead of
            ``<=``/``>=`` (ties move the bracket to that side).
        prefer_right: test the right point before the left one.
        (All other arguments as documented on best_dichotomy.)

    Returns:
        The centre point xj after ``accuracy`` rounds.
    """
    writetobin(A, filename=filename)

    def evaluate(x):
        return run(filerun=filerun, maxit=maxit, pre=pre, var=var, var_value=x)[index]

    def better(a, b):
        # True when metric ``a`` beats metric ``b`` in the requested direction.
        if is_min:
            return a < b if strict else a <= b
        return a > b if strict else a >= b

    lo = var_value[0]
    step = (var_value[1] - lo) / 4
    x0, xi, xj, xk, x4 = lo, lo + step, lo + 2 * step, lo + 3 * step, var_value[1]
    fi, fj, fk = evaluate(xi), evaluate(xj), evaluate(xk)
    for _ in range(accuracy):
        left, right = better(fi, fj), better(fk, fj)
        if prefer_right:
            move = 'right' if right else ('left' if left else 'centre')
        else:
            move = 'left' if left else ('right' if right else 'centre')
        if move == 'left':      # optimum in [x0, xj]; old xi becomes the centre
            x4, xj, fj = xj, xi, fi
        elif move == 'right':   # optimum in [xj, x4]; old xk becomes the centre
            x0, xj, fj = xj, xk, fk
        else:                   # optimum in [xi, xk]; both ends shrink
            # Both new ends come from the OLD xi/xk.  The original code updated k
            # before copying it into k_0, which collapsed the right half.
            x0, x4 = xi, xk
        xi, xk = (x0 + xj) / 2, (xj + x4) / 2
        fi, fk = evaluate(xi), evaluate(xk)
    return xj


def best_dichotomy(A, filename='A.bin', filerun='e', maxit=1000, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), accuracy=11, index=0, is_min=True):
    """Find the best value of solver option ``var`` by bracketing search (centre priority).

    Writes ``A`` to ``filename`` with writetobin(), then optimises
    ``run(...)[index]`` over ``var_value[0] .. var_value[1]``.  Assumes the
    metric is unimodal in the parameter (the original note says "convex");
    otherwise a local optimum may be returned.  Ties keep the bracket centred
    (strict comparisons).

    Args:
        A: sparse matrix passed to writetobin().
        filename: file that A is written to.  NOTE(review): run() does not pass
            this name to the solver, so the solver presumably reads its own
            default file — confirm before using a non-default name.
        filerun, maxit, pre, var: forwarded to run().
        var_value: (low, high) search interval.
        accuracy: number of halving rounds; the final bracket width is
            (high - low) / 2**accuracy.
        index: which entry of run()'s result list to optimise.
        is_min: minimise when True, maximise when False.

    Returns:
        The centre of the final bracket.
    """
    return _dichotomy_search(A, filename, filerun, maxit, pre, var, var_value, accuracy,
                             index, is_min, strict=True, prefer_right=False)


def best_dichotomy1(A, filerun='e', filename='A.bin', maxit=1000, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), accuracy=11, index=0, is_min=True):
    """Bracketing search like best_dichotomy, but with left priority.

    Uses non-strict comparisons and tests the left point first, so ties move
    the bracket to the left.  Arguments and return value are as for
    best_dichotomy (note the different positional order of ``filerun`` and
    ``filename``).
    """
    return _dichotomy_search(A, filename, filerun, maxit, pre, var, var_value, accuracy,
                             index, is_min, strict=False, prefer_right=False)


def best_dichotomy2(A, filerun='e', filename='A.bin', maxit=1000, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), accuracy=11, index=0, is_min=True):
    """Bracketing search like best_dichotomy, but with right priority.

    Uses non-strict comparisons and tests the right point first, so ties move
    the bracket to the right.  Arguments and return value are as for
    best_dichotomy (note the different positional order of ``filerun`` and
    ``filename``).
    """
    return _dichotomy_search(A, filename, filerun, maxit, pre, var, var_value, accuracy,
                             index, is_min, strict=False, prefer_right=True)

#网格法求最优参数
def best_enumeration(A, filerun='e', filename='A.bin', maxit=1000, pre='sor', var='-pc_sor_omega', var_value=(0.0, 2.0), index=0, is_max=False):
    """Grid search for the best solver parameter, followed by local refinement.

    1. Coarse pass: evaluate every 0.02 strictly inside ``var_value``.
    2. Fine pass: around each coarse point whose metric is within 5% of the
       best, evaluate the 0.001-spaced window ``[p - 0.01, p + 0.02)``.
       Points shared by overlapping windows are evaluated only once.

    Args:
        A: sparse matrix written with writetobin() before any run.
        filerun, filename, maxit, pre, var: as for run()/writetobin().
        var_value: (low, high) search interval.
        index: which entry of run()'s result list to optimise.
        is_max: maximise when True, minimise (default) when False.

    Returns:
        The fine-grid parameter with the best metric.

    Raises:
        ValueError: if the interval is too narrow to hold any coarse point.
    """
    writetobin(A, filename=filename)

    def evaluate(par):
        return run(filerun=filerun, maxit=maxit, pre=pre, var=var, var_value=par)[index]

    pars = np.arange(var_value[0] + 0.02, var_value[1], 0.02)
    results = [evaluate(par) for par in pars]
    if not results:
        raise ValueError('var_value range is too narrow for a 0.02 grid: {!r}'.format(var_value))

    best = max(results) if is_max else min(results)
    margin = 0.05 * abs(best)

    def near_best(value):
        # Additive 5% margin.  The old ratio test value/best divided by zero
        # when best == 0 and then selected nothing, crashing in argmin.
        return value >= best - margin if is_max else value <= best + margin

    # An insertion-ordered dict drops duplicates produced by overlapping windows
    # (each window is 0.03 wide but the coarse spacing is 0.02).
    candidates = {}
    for par, value in zip(pars, results):
        if near_best(value):
            for fine in np.arange(par - 0.01, par + 0.02, 0.001):
                candidates.setdefault(round(float(fine), 10), None)
    pars_precise = list(candidates)
    results_precise = [evaluate(par) for par in pars_precise]
    pick = np.argmax if is_max else np.argmin
    return pars_precise[int(pick(np.array(results_precise)))]

#为GAMG预处理特性设计的最优参数求解
def best_gamg(A, filerun='e', filename='A.bin', maxit=1000, pre='gamg', var='-pc_gamg_threshold'):
    """Three-stage grid search for the GAMG threshold that minimises run()[2].

    Stage 1 tries thresholds 0.0, 0.1, ..., 1.0.  If 1.0 is within 1e-5 of the
    best, the smallest threshold within 1e-5 of the minimum is returned
    immediately.  Stage 2 samples a 0.01 grid around every stage-1 threshold
    whose score is less than 1 above the minimum.  Stage 3 samples a 0.001 grid
    around every stage-2 threshold less than 0.5 above the stage-2 minimum.
    Every trial prints the threshold and the raw run() result.

    Returns:
        The stage-3 threshold with the smallest run()[2], or the stage-1
        threshold in the early-exit case.
    """
    writetobin(A, filename=filename)

    def trial(threshold):
        # Echo the threshold and the full run() result, then score by entry 2.
        print(threshold)
        outcome = run(filerun=filerun, maxit=maxit, pre=pre, var=var, var_value=threshold)
        print(outcome)
        return outcome[2]

    # Stage 1: coarse 0.1 grid over [0, 1].
    coarse = np.arange(0, 1.1, 0.1)
    coarse_scores = [trial(t) for t in coarse]
    coarse_best = np.min(np.array(coarse_scores))
    last = len(coarse_scores) - 1
    if coarse_scores[last] - coarse_best < 1e-5:
        for t, score in zip(coarse, coarse_scores):
            if abs(score - coarse_best) < 1e-5:
                return t

    # Stage 2: 0.01-spaced windows around promising coarse thresholds.
    medium = []
    for pos, score in enumerate(coarse_scores):
        if score - coarse_best < 1:
            if pos == 0:
                medium.extend([0.0, 0.01, 0.02, 0.03, 0.04])
            elif pos == last:
                medium.extend([0.95, 0.96, 0.97, 0.98, 0.99, 1.0])
            else:
                medium.extend(np.arange(pos/10-0.05, pos/10+0.05, 0.01).tolist())
    medium_scores = [trial(t) for t in medium]

    # Stage 3: 0.001-spaced windows around promising stage-2 thresholds.
    fine = []
    if medium_scores:
        medium_best = np.min(np.array(medium_scores))
        for centre, score in zip(medium, medium_scores):
            if score - medium_best < 0.5:
                if centre == 0.0:
                    fine.extend([0.0, 0.001, 0.002, 0.003, 0.004])
                elif centre == 1.0:
                    fine.extend([0.995, 0.996, 0.997, 0.998, 0.999, 1.0])
                else:
                    fine.extend(np.arange(centre-0.005, centre+0.005, 0.001).tolist())
    fine_scores = [trial(t) for t in fine]
    return fine[np.argmin(np.array(fine_scores))]

#创建二维Chebyshev多项式逼近的函数的采样矩阵
def build_chebyshev(n, N, k):
    """Sample a 2-D Chebyshev series on an N x N grid.

    The series is ``f(x, y) = sum_{a, b < n} k[a + n*b] * T_a(x) * T_b(y)``,
    where ``T_m`` is the degree-m Chebyshev polynomial of the first kind.

    Args:
        n: number of Chebyshev terms per dimension; ``k`` needs at least n*n entries.
        N: number of samples per dimension.
        k: flat coefficient sequence; entry ``a + n*b`` multiplies T_a(x) T_b(y).

    Returns:
        ``(N, N)`` float array with ``out[i, j] = f(j*h, i*h)``, where h = 1/(N+1).
        NOTE(review): the samples therefore cover [0, (N-1)/(N+1)] in each
        direction rather than [0, 1].  This is kept as in the original
        implementation — confirm it is intended.
    """
    if n <= 0:
        return np.zeros((N, N))
    # coeffs[a, b] multiplies T_a(x) * T_b(y), the layout chebval2d expects.
    coeffs = np.asarray(k).ravel()[:n * n].reshape(n, n).T
    axis = np.arange(N) * (1 / (N + 1))  # same sample points as the old arange grid
    X, Y = np.meshgrid(axis, axis)
    # One vectorised evaluation replaces building two Chebyshev objects per
    # (sample, term) pair.
    return cb.chebval2d(X, Y, coeffs)

#利用系数函数的采样矩阵创建darcyflow方程
def build_darcy(coef):
    K = coef.shape[0]
    s = K - 2
    diag_list = []
    off_diag_list = []
    for j in range(1, K-1):
        diag_values = np.array([
            np.concatenate((-0.5 * (coef[1:K-2, j] + coef[2:K-1, j]),[0])),
            0.5 * (coef[0:K-2, j] + coef[1:K-1, j]) + 0.5 * (coef[2:K, j] + coef[1:K-1, j]) + \
            0.5 * (coef[1:K-1, j-1] + coef[1:K-1, j]) + 0.5 * (coef[1:K-1, j+1] + coef[1:K-1, j]),
            np.concatenate((-0.5 * (coef[1:K-2, j] + coef[2:K-1, j]),[0]))
        ])
        diag_list.append(diag_values)
        if j != K-2:
            off_diag = -0.5 * (coef[1:K-1, j] + coef[1:K-1, j+1])
            off_diag_list.append(off_diag)
    diag_output = np.concatenate(diag_list,axis=1)
    off_diag_output = np.concatenate(off_diag_list,axis=0)
    A = (sp.diags(diag_output,[-1,0,1],(s**2,s**2)) + sp.diags((off_diag_output,off_diag_output),[-(K-2),(K-2)],(s**2,s**2))) * (K-1)**2
    return A

#将文件在目标文件夹中复制若干份一样的并重命名
def filecopy(source_file_path, target_folder_path, num):
    """Copy ``source_file_path`` into ``target_folder_path`` ``num`` times.

    The copies are named ``<stem>_1<ext>`` through ``<stem>_<num><ext>``.  The
    target folder (and any missing parents) is created first if needed.

    Returns:
        None.
    """
    os.makedirs(target_folder_path, exist_ok=True)
    stem, ext = os.path.splitext(os.path.basename(source_file_path))
    destinations = (
        os.path.join(target_folder_path, f'{stem}_{idx}{ext}')
        for idx in range(1, num + 1)
    )
    for destination in destinations:
        shutil.copy(source_file_path, destination)

#对称化系数矩阵
def symmetrize(A):
    """Embed A in the symmetric block matrix ``[[0, A^T], [A, I]]``.

    For A of shape (n, m) the result has shape (m + n, m + n).  It consists of
    an m x m zero block, A^T (m x n), A (n x m) and the n x n identity.  The
    result is symmetric for any A, presumably so that solvers or
    preconditioners that require symmetry can be used.  For square A this is
    identical to the previous behaviour; rectangular A used to fail because
    the zero block was n x n.

    Args:
        A: scipy sparse matrix or dense 2-D array.

    Returns:
        The block matrix as produced by ``scipy.sparse.bmat`` with its default format.
    """
    A = sp.csr_matrix(A)
    n, m = A.shape
    zero = sp.csr_matrix((m, m))
    identity = sp.eye(n, format='csr')
    return sp.bmat([[zero, A.transpose()], [A, identity]])

# 找到一个排好序的,无重复的,全是整数的列表里,最长的连续整数列的位置(用于结合heapq库判断网格粗筛时的数据结构)
def find_longest_consecutive(nums):
    """Locate the longest run of consecutive integers in a sorted, duplicate-free list.

    Args:
        nums: sorted list of distinct integers.

    Returns:
        ``(start, length)``: the index of the first element of the longest run
        (each element exactly one more than the previous) and its length.  Ties
        go to the earliest run.  An empty list gives ``(0, 0)``; previously it
        reported a run of length 1.
    """
    best_start, best_len = 0, 0
    run_start = 0
    total = len(nums)
    for idx in range(1, total + 1):
        # A run ends at the end of the list or where the +1 chain breaks.
        if idx == total or nums[idx] != nums[idx - 1] + 1:
            if idx - run_start > best_len:  # strict: earliest run wins ties
                best_start, best_len = run_start, idx - run_start
            run_start = idx
    return (best_start, best_len)
    