import os

import torch

from tests.problems import Ex4
from torchsde import BrownianInterval
from torchsde.settings import LEVY_AREA_APPROXIMATIONS, SDE_TYPES
from . import inspection
from . import utils


def main(device=None):
    """Run the Stratonovich diagnostics on the `Ex4` test problem.

    An `Ex4` SDE of Stratonovich type is handed, together with a
    `BrownianInterval` that uses the Foster Levy area approximation, to
    `inspection.inspect_samples` (small batch) and then to
    `inspection.inspect_orders` (large batch) for the 'midpoint' and
    'log_ode' methods. Both calls receive the 'plots/stratonovich_general'
    directory located next to this file.

    Args:
        device: Device on which the SDE, the initial states and the Brownian
            motions are created. When None (the default), CUDA is used if
            it is available and the CPU otherwise, so the function can be
            called without relying on any module-level state.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    small_batch_size, large_batch_size, d, m = 16, 16384, 3, 5
    t0, t1, steps, dt = 0., 2., 10, 1e-1
    ts = torch.linspace(t0, t1, steps=steps, device=device)
    dts = tuple(2 ** -i for i in range(1, 7))  # For checking strong order.
    sde = Ex4(d=d, m=m, sde_type=SDE_TYPES.stratonovich).to(device)
    methods = ('midpoint', 'log_ode')
    img_dir = os.path.join(os.path.dirname(__file__), 'plots', 'stratonovich_general')

    def make_inputs(batch_size):
        # Initial state and a matching Brownian motion for one batch size;
        # the Brownian motion takes its dtype from the initial state.
        y0 = torch.full((batch_size, d), fill_value=0.1, device=device)
        bm = BrownianInterval(
            t0=t0, t1=t1, shape=(batch_size, m), dtype=y0.dtype, device=device,
            levy_area_approximation=LEVY_AREA_APPROXIMATIONS.foster
        )
        return y0, bm

    # Sample inspection on the small batch, with a fixed step size `dt`.
    y0, bm = make_inputs(small_batch_size)
    inspection.inspect_samples(y0, ts, dt, sde, bm, img_dir, methods)

    # Order inspection on the large batch, over the range of step sizes `dts`.
    y0, bm = make_inputs(large_batch_size)
    inspection.inspect_orders(y0, t0, t1, dts, sde, bm, img_dir, methods)


if __name__ == '__main__':
    # Prefer the GPU when CUDA is available; otherwise run on the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Make float64 the default floating-point dtype for tensors created below.
    torch.set_default_dtype(torch.float64)
    # Seed through the project's helper (defined in `utils`, not shown here).
    utils.manual_seed()

    main()
