from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
from torch import nn, optim
from torchsde import SDEIto

from .sdeintloss import SdeIntLoss
from .piis import PathIntLoss


@dataclass
class TrainingResult:
    """Per-step metrics collected during training.

    All fields default to empty lists, so ``TrainingResult()`` creates an
    empty record; positional construction (``TrainingResult([], [], [])``)
    still works.

    Attributes:
        clock_time: Cumulative timed seconds recorded after each training step.
        loss: Training loss value recorded at each training step.
        mse: Evaluation loss values, recorded only on evaluation steps, so this
            list is typically shorter than ``loss``.
    """

    clock_time: List[float] = field(default_factory=list)
    loss: List[float] = field(default_factory=list)
    mse: List[float] = field(default_factory=list)


class Timer:
    """Pausable stopwatch that accumulates elapsed seconds over start/stop intervals.

    Intervals are measured with ``time.perf_counter``, which is monotonic, so
    measurements are not disturbed by system clock adjustments.
    """

    cum_time = 0  # seconds accumulated over completed start/stop intervals
    last_time = 0  # clock reading taken at the most recent start()
    _running = False

    @staticmethod
    def _now():
        """Return the current monotonic clock reading in seconds."""
        from time import perf_counter

        return perf_counter()

    def start(self):
        """Begin or resume timing.

        Does nothing if the timer is already running, so the in-progress
        interval is not discarded.
        """
        if self._running:
            return
        self.last_time = self._now()
        self._running = True

    def stop(self):
        """Pause timing and add the in-progress interval to ``cum_time``.

        Does nothing if the timer is not running, so repeated calls do not
        double-count the same interval.
        """
        if not self._running:
            return
        self.cum_time += self._now() - self.last_time
        self._running = False

    @property
    def running(self):
        """Whether the timer is currently accumulating time."""
        return self._running

    @property
    def current(self):
        """Total elapsed seconds, including the in-progress interval if running."""
        if not self._running:
            # Stopped (or never started): the accumulated total is final.
            return self.cum_time
        return self.cum_time + self._now() - self.last_time


def train(
    sde: SDEIto,
    obs_times,
    obs,
    *,
    callback: Optional[Callable[[int, TrainingResult], None]] = None,
    use_sdeint=True,
    lr=1e-3,
    steps=10_000,
    dt=1e-2,
    n_samples=100,
    eval_every=10,
    **kwargs,
) -> TrainingResult:
    """Fit ``sde``'s parameters to observations with Adam.

    Args:
        sde: Model whose ``parameters()`` are optimized.
        obs_times: Observation times, passed through to the loss functions.
        obs: Observed values, passed through to the loss functions.
        callback: Optional ``callback(step, result)`` invoked (under
            ``torch.no_grad()``) after each evaluation.
        use_sdeint: If True, train with ``SdeIntLoss`` wrapping ``nn.MSELoss``.
            If False, train with ``PathIntLoss``. Evaluation always uses
            ``SdeIntLoss``.
        lr: Adam learning rate.
        steps: Number of optimization steps.
        dt: Step size forwarded to the loss constructors.
        n_samples: Sample count forwarded to the loss constructors.
        eval_every: Evaluate (and call ``callback``) every this many steps.
        **kwargs: Extra keyword arguments forwarded to ``SdeIntLoss`` only.

    Returns:
        TrainingResult with per-step clock time and loss, plus the evaluation
        loss (as a float) for every ``eval_every``-th step.

    Raises:
        ValueError: If ``eval_every`` is less than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    opt = optim.Adam(sde.parameters(), lr=lr)
    if use_sdeint:
        loss_fn = SdeIntLoss(nn.MSELoss(), n_samples=n_samples, dt=dt, **kwargs)
        mse_fn = loss_fn
    else:
        loss_fn = PathIntLoss(sde, obs_times, obs, dt, n_samples)
        mse_fn = SdeIntLoss(nn.MSELoss(), n_samples=n_samples, dt=dt, **kwargs)

    res = TrainingResult([], [], [])

    timer = Timer()
    timer.start()
    for step in range(1, steps + 1):
        opt.zero_grad()
        loss = loss_fn(sde, obs_times, obs)
        loss.backward()
        opt.step()

        # .item() waits for pending device work, so read it before sampling
        # the clock; otherwise async GPU time leaks into the next step.
        loss_value = loss.item()
        res.clock_time.append(timer.current)
        res.loss.append(loss_value)

        if step % eval_every == 0:
            # Evaluation and callback time are excluded from the training clock.
            timer.stop()

            with torch.no_grad():
                mse = mse_fn(sde, obs_times, obs)
                # Store a plain float: matches List[float] and avoids keeping
                # device tensors alive for the whole run.
                res.mse.append(float(mse))
                if callback is not None:
                    callback(step, res)

            timer.start()

    return res
