import numpy as np
import torch
import einops
import random
from typing import Tuple
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph, knn_graph

def ode_scheduling(_int, _f, true_codes, t, epsilon, method="rk4", rng=None):
    """Integrate codes over ``t``, randomly restarting from the true codes.

    Without restarts, the whole trajectory is integrated from
    ``true_codes[..., 0]``. With restarts, each interior time index ``k``
    (``1 <= k <= len(t) - 2``) is independently chosen as a restart point
    with probability ``epsilon``. The integrated value at ``k`` comes from the
    segment ending there. Integration after ``k`` restarts from
    ``true_codes[..., k]``.

    Args:
        _int: Integrator, called as ``_int(_f, y0=..., t=..., method=...)``.
            Its output is concatenated along dim 0, so time must be the
            leading axis of the result. (Presumably ``torchdiffeq.odeint`` or
            similar; confirm against callers.)
        _f: Dynamics function forwarded unchanged to ``_int``.
        true_codes: Tensor whose last axis indexes the time points of ``t``.
        t: Sliceable 1-D sequence of time points.
        epsilon: Per-step restart probability. Values below ``1e-3`` disable
            restarts entirely.
        method: Solver name forwarded to ``_int``.
        rng: Optional object with a NumPy-style ``random(size)`` method, such
            as ``np.random.Generator``. Defaults to NumPy's global RNG.

    Returns:
        The integrated codes with the time axis moved from first to last,
        i.e. ``(t b l) -> (b l t)``.
    """
    if epsilon < 1e-3:
        # Restarts are disabled: a single integration from the first code.
        codes = _int(_f, y0=true_codes[..., 0], t=t, method=method)
        # (t b l) -> (b l t)
        return torch.movedim(codes, 0, -1)

    n_steps = len(t)
    draw = np.random.random if rng is None else rng.random
    restart_mask = draw(n_steps) < epsilon
    # The last time index is never a restart; index 0 is the initial condition.
    restart_mask[-1] = False
    restarts = [k for k in range(1, n_steps) if restart_mask[k]]

    # Segment j spans time indices boundaries[j] .. boundaries[j + 1] inclusive.
    boundaries = [0, *restarts, n_steps - 1]
    segments = []
    for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
        seg = _int(
            _f,
            y0=true_codes[..., seg_start],
            t=t[seg_start:seg_end + 1],
            method=method,
        )
        # Each segment's first point duplicates the previous segment's last
        # point, so drop it for all but the first segment.
        segments.append(seg[1:] if segments else seg)
    codes = torch.cat(segments, dim=0)
    # (t b l) -> (b l t)
    return torch.movedim(codes, 0, -1)

