import subprocess
import json
import heapq
import numpy as np


def find_longest_consecutive(nums):
    """Locate the longest run of consecutive integers in ``nums``.

    A run is a maximal stretch of adjacent elements in which each element
    equals the previous one plus 1 (e.g. ``[4, 5, 6]``). When several runs
    share the maximum length, the earliest one is reported.

    Args:
        nums: Sequence of numbers supporting ``len`` and integer indexing.

    Returns:
        ``(start, length)``: the index in ``nums`` where the longest run
        begins and the number of elements in it. An empty input yields
        ``(0, 0)``.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(nums) == 0:
        # Without this guard the initial length of 1 would report a run in
        # an input that has no elements at all.
        return (0, 0)
    max_length = 1
    current_length = 1
    max_start = 0
    current_start = 0
    for i in range(1, len(nums)):
        if nums[i] == nums[i - 1] + 1:
            current_length += 1
        else:
            # Chain broken: strictly longer runs replace the best so far, so
            # ties keep the earliest run.
            if current_length > max_length:
                max_length = current_length
                max_start = current_start
            current_length = 1
            current_start = i
    # The final run is never closed inside the loop; compare it here.
    if current_length > max_length:
        max_length = current_length
        max_start = current_start
    return (max_start, max_length)


def _run_sor(cmd, omega, maxit):
    """Run ``cmd`` with ``omega`` appended and return the iteration count.

    The count is parsed as the second whitespace-separated token of the
    command's stdout. If stdout has fewer than two tokens, or that token is
    not an integer, ``maxit`` is returned instead.
    """
    result = subprocess.run(cmd + [str(omega)], capture_output=True, text=True)
    try:
        return int(result.stdout.split()[1])
    except (IndexError, ValueError):
        # Unparseable output (e.g. the solver failed): treat the run as
        # having hit the iteration cap.
        return maxit


def best_sor(n, dataname, maxit):
    """Search for the SOR omega that minimises the solver's iteration count.

    Runs ``./e`` on ``/home/wh_linux/<dataname>/A_<r>_<c>.dat`` and
    ``.../b_<r>_<c>.dat`` (``r = n % 60 + 1``, ``c = n // 60 + 1``) with
    ``-pc_type sor`` for each omega on the coarse grid
    ``np.arange(0.01, 2, 0.01)``, then re-runs on a 0.001-step neighbourhood
    of the best coarse points.

    Args:
        n: Linear index of the problem instance, mapped to ``(r, c)`` above.
        dataname: Sub-directory of ``/home/wh_linux`` holding the data files.
        maxit: Passed as ``-ksp_max_it``; also used as the iteration count of
            any run whose output cannot be parsed.

    Returns:
        The omega with the fewest iterations. If the five smallest coarse
        counts are all equal, the midpoint of the longest contiguous run of
        coarse grid points in that set is returned without refinement.
    """
    row, col = n % 60 + 1, n // 60 + 1
    fileA = '/home/wh_linux/{}/A_{}_{}.dat'.format(dataname, row, col)
    fileb = '/home/wh_linux/{}/b_{}_{}.dat'.format(dataname, row, col)
    cmd = ['./e', '-file_A', fileA, '-file_b', fileb, '-ksp_max_it',
           str(maxit), '-pc_type', 'sor', '-pc_sor_omega']

    # Coarse sweep over (0, 2).
    pars = np.arange(0.01, 2, 0.01)
    its = [_run_sor(cmd, par, maxit) for par in pars]

    # nsmallest returns its result sorted ascending: [0] is min, [-1] is max.
    smallest_5 = heapq.nsmallest(5, its)
    best = smallest_5[0]
    positions = [i for i, it in enumerate(its) if it in smallest_5]
    if best == smallest_5[-1]:
        # Flat optimum: take the centre of the widest plateau.
        start, length = find_longest_consecutive(positions)
        return (pars[positions[start]]
                + pars[positions[start + length - 1]]) / 2
    if len(positions) >= 10:
        # Too many candidates to refine; keep only the exact minimisers.
        positions = [i for i, it in enumerate(its) if it == best]

    # Build the fine grid around each candidate.
    last = len(pars) - 1
    pars_p = []
    for i in positions:
        if i == 0:
            pars_p.extend([0.001, 0.002, 0.003, 0.004])
        elif i == last:
            # Last valid index (the old `i == len(pars)` test never matched).
            pars_p.extend([1.996, 1.997, 1.998, 1.999])
        else:
            # NOTE(review): float-step arange may yield 10 or 11 points.
            pars_p.extend(
                np.arange(pars[i] - 0.005, pars[i] + 0.005, 0.001).tolist())
    its_p = [_run_sor(cmd, par, maxit) for par in pars_p]
    return pars_p[int(np.argmin(its_p))]


# Driver: for each of the first 800 problem instances whose iteration count
# stored in its.json lies strictly between 40 and 510, search for the best SOR
# omega and record the instance's (n % 60 + 1, n // 60 + 1) coordinates.
dataname = 'Bih4225'
X = []  # feature rows: [i % 60 + 1, i // 60 + 1]
y = []  # best omega returned by best_sor for the matching row of X

# NOTE(review): its.json is presumably a list of per-instance iteration counts
# (the name suggests runs with omega = 1) -- confirm.
with open('/home/wh_linux/sor/Bihar/its.json', 'r', encoding='utf-8') as f:
    its_equal1 = json.load(f)

# Slicing instead of indexing range(800) avoids an IndexError -- which would
# discard every result computed so far -- when its.json has fewer entries.
for i, its_base in enumerate(its_equal1[:800]):
    if 40 < its_base < 510:
        print(i)  # progress indicator
        # Iteration cap: the larger of count + 30 and 1.3 * count.
        par_best = best_sor(i, dataname, int(
            max(its_base + 30, its_base * 1.3)))
        X.append([i % 60 + 1, i // 60 + 1])
        y.append(par_best)

with open('X.json', 'w') as f:
    json.dump(X, f)
with open('y.json', 'w') as f:
    json.dump(y, f)
