from typing import TYPE_CHECKING, Any, Iterator, Optional, Tuple

import torch

if TYPE_CHECKING:
    from ito_vision.methods.iterative_refinement_method import IterativeRefinementMethod
from ito_vision.samplers.sampler import Sampler


class EulerMaruyamaSampler(Sampler):
    """Euler-Maruyama integrator for a Lambda-parameterised family of SDEs.

    Each step advances the state by

        dx = f_bar(x, t) * dt + Lambda * g(t) * dW,
        f_bar(x, t) = f(x, t, y) - 0.5 * (1 + Lambda**2) * g(t)**2 * score,

    where ``f`` and ``g`` come from the refinement ``method`` and ``score`` is
    returned by ``method.score``. ``Lambda`` controls how much noise is
    injected: ``0`` gives a deterministic update (the noise term is zeroed),
    ``1`` a fully stochastic one, and values in ``(0, 1)`` scale the noise.

    NOTE(review): assuming ``method.f`` / ``method.g`` are the forward SDE
    drift/diffusion and ``score`` approximates ``grad_x log p_t(x)``, then
    ``Lambda = 0`` matches the probability-flow ODE and ``Lambda = 1`` the
    reverse-time SDE.

    Args:
        N: Forwarded to ``Sampler.__init__`` (presumably the number of
            sampling steps; confirm against ``Sampler``).
        Lambda: Noise scale, as described above.
        quiet: Forwarded to ``Sampler.__init__``.
    """

    def __init__(
        self,
        N: int = 30,
        Lambda: float = 1,  # 0 - deterministic, 1 - stochastic, (0,1) scale of noise
        quiet: bool = False,
    ):
        super().__init__(N, quiet)
        self.Lambda = Lambda

    def f_bar(
        self,
        method: "IterativeRefinementMethod",
        xt: torch.Tensor,
        t: torch.Tensor,
        score: torch.Tensor,
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the Lambda-adjusted drift used by the integrator.

        Computes ``method.f(xt, t, y) - 0.5 * (1 + Lambda**2) * method.g(t)**2 * score``.

        Args:
            method: Supplies the drift ``f`` and diffusion coefficient ``g``.
            xt: Current state.
            t: Current time.
            score: Score estimate at ``(xt, t)`` as returned by ``method.score``.
            y: Optional conditioning, forwarded to ``method.f``.

        Returns:
            The adjusted drift tensor.
        """
        return (
            method.f(xt, t, y) - 0.5 * (1 + self.Lambda**2) * method.g(t) ** 2 * score
        )

    def __iter__(
        self,
        method: "IterativeRefinementMethod",
        model: torch.nn.Module,
        x1: torch.Tensor,
        ts: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Integrate from ``x1`` across the time grid ``ts``, yielding every step.

        This generator takes arguments, so callers invoke it explicitly rather
        than through the plain iterator protocol.

        Args:
            method: Provides ``f``, ``g``, ``score`` and ``pred_x0``.
            model: Network forwarded to ``method.score`` and ``method.pred_x0``.
            x1: Initial state at ``ts[0]``. It is cloned and never modified.
            ts: Time grid, presumably 1-D. The step is
                ``dt = ts[i + 1] - ts[i]`` and may be negative (e.g. when ``ts``
                decreases towards 0), so the Brownian increment uses ``|dt|``.
            y: Optional conditioning, forwarded to every ``method`` call.
            **kwargs: Extra arguments for ``method.score`` and ``method.pred_x0``.

        Yields:
            After each of the ``len(ts) - 1`` steps, the pair
            ``(x, pred_x0)``, where ``pred_x0`` comes from ``method.score`` at
            the start of that step. Each yielded ``x`` is a distinct tensor and
            is not modified by later steps. Finally, yields
            ``(final_prediction, final_prediction)``, where
            ``final_prediction = method.pred_x0(model, x, ts[-1], y, **kwargs)``.
        """
        dts = torch.diff(ts)
        x = x1.clone()

        for t, dt in zip(ts[:-1], dts):
            score, pred_x0 = method.score(model, x, t, y, **kwargs)

            # |dt| because the grid may run backwards in time; variance of the
            # Brownian increment is always the (positive) step length.
            dw = torch.randn_like(x) * torch.sqrt(torch.abs(dt))
            f_value = self.f_bar(method, x, t, score, y)
            g_value = method.g(t) * self.Lambda
            dx = f_value * dt + g_value * dw
            # Out-of-place update: an in-place `x += dx` would mutate tensors
            # already yielded to the caller, so a collected trajectory would
            # alias the final state at every step.
            x = x + dx

            yield x, pred_x0

        final_prediction = method.pred_x0(model, x, ts[-1], y, **kwargs)
        yield final_prediction, final_prediction
