import torch
from .MultKAN import *


def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None):

    result = {}
    result['test_loss'] = []
    result['c'] = []
    result['G'] = []
    result['id'] = []
    if metrics != None:
        for i in range(len(metrics)):
            result[metrics[i].__name__] = []

    def collect(evaluation):
        result['test_loss'].append(evaluation['test_loss'])
        result['c'].append(evaluation['n_edge'])
        result['G'].append(evaluation['n_grid'])
        result['id'].append(f'{model.round}.{model.state_id}')
        if metrics != None:
            for i in range(len(metrics)):
                result[metrics[i].__name__].append(metrics[i](model, dataset).item())

    for i in range(prune_round):
        # train and prune
        if i == 0:
            model = KAN(width=width, grid=grids[0])
        else:
            model = model.rewind(f'{i-1}.{2*i}')

        model.fit(dataset, steps=steps, lamb=lamb)
        model = model.prune(edge_th=edge_th, node_th=node_th)
        evaluation = model.evaluate(dataset)
        collect(evaluation)

        for j in range(refine_round):
            model = model.refine(grids[j])
            model.fit(dataset, steps=steps)
            evaluation = model.evaluate(dataset)
            collect(evaluation)

    for key in list(result.keys()):
        result[key] = np.array(result[key])

    return result


def pareto_frontier(x,y):

    pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
    x_pf = x[pf_id]
    y_pf = y[pf_id]

    return x_pf, y_pf, pf_id