from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Tuple

import torch

if TYPE_CHECKING:
    from ito_vision.methods.iterative_refinement_method import IterativeRefinementMethod
from ito_vision.samplers.sampler import Sampler


class EIODESampler(Sampler):
    """Deterministic ODE sampler built on an exponential-integrator (EI) style update.

    Every step queries ``method.pred_x0`` once and then advances the state with
    :meth:`sample_from_EI`. The scalar time integrals that update needs are
    approximated numerically by :meth:`approx_int` with ``f_evals`` points.

    NOTE(review): the five terms in :meth:`sample_from_EI` look like a
    variation-of-constants solution of a probability-flow-style ODE whose drift
    is linear in ``x`` and ``y`` (the ``x`` coefficient of the score term is
    frozen at the start-of-step value). Confirm against
    ``IterativeRefinementMethod``.
    """

    def __init__(self, N: int = 30, f_evals: int = 100, quiet: bool = False):
        """Create the sampler.

        Args:
            N: Forwarded to ``Sampler.__init__`` (presumably the number of
                sampling steps — confirm against ``Sampler``).
            f_evals: Number of sub-intervals/evaluations used by
                :meth:`approx_int` for every integral. Must be at least 1.
            quiet: Forwarded to ``Sampler.__init__``.

        Raises:
            ValueError: If ``f_evals`` is smaller than 1.
        """
        # Fail fast: with f_evals <= 0, approx_int would divide by zero (or
        # silently return an empty sum) on the first sampling step.
        if f_evals < 1:
            raise ValueError(f"f_evals must be a positive integer, got {f_evals}")
        super().__init__(N, quiet)
        self.F_EVALS = f_evals

    def approx_int(
        self,
        f: Callable[[float], float],
        lower_bound: float,
        upper_bound: float,
    ) -> float:
        """Approximate the integral of ``f`` from ``lower_bound`` to ``upper_bound``.

        Uses the composite midpoint rule with ``self.F_EVALS`` equal
        sub-intervals. For the same number of ``f`` evaluations as a left
        Riemann sum, its error decreases quadratically (instead of linearly)
        with the sub-interval width. It also never evaluates ``f`` exactly at
        either bound, which helps if an integrand is singular at an end of the
        time grid (e.g. if ``transition_std`` vanishes there — not verified
        here).

        The bounds may be Python floats or 0-d tensors (``__iter__`` passes
        elements of ``ts``). If ``upper_bound < lower_bound`` the step width is
        negative, so the oriented integral is returned.

        Args:
            f: Integrand, evaluated at ``self.F_EVALS`` interior points.
            lower_bound: Start of the integration interval.
            upper_bound: End of the integration interval.

        Returns:
            The approximate integral. Its type follows whatever ``f`` and the
            bounds produce (float or tensor).
        """
        n_points = self.F_EVALS
        width = (upper_bound - lower_bound) / n_points
        total = 0.0

        for i in range(n_points):
            # Centre of the i-th sub-interval.
            midpoint = lower_bound + (i + 0.5) * width
            total += f(midpoint) * width

        return total

    def sample_from_EI(
        self,
        method: "IterativeRefinementMethod",
        x: torch.Tensor,
        pred_x0: torch.Tensor,
        y: Optional[torch.Tensor],
        s: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Advance the state ``x`` (at time ``t``) to time ``s``.

        Args:
            method: Provides the schedule functions ``transition_lambda_x``,
                ``transition_lambda_y``, ``transition_std``, ``b`` and ``g``.
            x: Current state at time ``t``.
            pred_x0: Model prediction of the clean sample at time ``t``.
            y: Conditioning tensor. It is required: every update multiplies it.
            s: Target time of the step.
            t: Current time of the step.

        Returns:
            The updated state at time ``s``.

        Raises:
            ValueError: If ``y`` is ``None``.
        """
        if y is None:
            # Without this check the arithmetic below fails with an opaque
            # "bad operand type" TypeError.
            raise ValueError("EIODESampler requires a conditioning tensor `y`; got None")

        def phi(ti):
            # Coefficient method.transition_lambda_x(s, s=ti), shared by all integrals.
            return method.transition_lambda_x(s, s=ti)

        def score_weight(ti):
            # Common factor phi * g^2 / std^2 of the three score-related integrals.
            return phi(ti) * method.g(ti) ** 2 / method.transition_std(ti) ** 2

        # Homogeneous part: propagate x from t to s.
        first_term = method.transition_lambda_x(s, s=t) * x

        # Contribution of the y-dependent drift b(ti).
        second_term = -y * self.approx_int(lambda ti: phi(ti) * method.b(ti), s, t)

        # Score-related contributions, split by what they multiply.
        third_term = pred_x0 * self.approx_int(
            lambda ti: score_weight(ti) * method.transition_lambda_x(ti), s, t
        )
        fourth_term = y * self.approx_int(
            lambda ti: score_weight(ti) * method.transition_lambda_y(ti), s, t
        )
        # x is held fixed at its start-of-step value across the interval.
        fifth_term = -x * self.approx_int(score_weight, s, t)

        return first_term + second_term + 0.5 * (third_term + fourth_term + fifth_term)

    def __iter__(
        self,
        method: "IterativeRefinementMethod",
        model: torch.nn.Module,
        x1: torch.Tensor,
        ts: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Run the sampler along the time grid ``ts``, starting from ``x1``.

        For each consecutive pair ``(ts[i], ts[i + 1])``, the sampler predicts
        ``x0`` at ``ts[i]`` and then moves the state to ``ts[i + 1]``.

        Args:
            method: Supplies ``pred_x0`` and the schedule used by
                :meth:`sample_from_EI`.
            model: Network passed to ``method.pred_x0``.
            x1: Initial state at time ``ts[0]``. It is cloned, not modified.
            ts: Time grid. Must contain at least one element.
            y: Conditioning tensor, forwarded to ``method.pred_x0``. It is
                required by :meth:`sample_from_EI` whenever ``ts`` has more
                than one element.
            **kwargs: Extra keyword arguments forwarded to ``method.pred_x0``.

        Yields:
            ``(x, pred_x0)`` after each step. The last item is
            ``(final_prediction, final_prediction)``, where
            ``final_prediction`` is ``method.pred_x0`` evaluated at ``ts[-1]``.
        """
        x = x1.clone()

        # Pair grid points directly rather than reconstructing ts[i + 1] as
        # ts[i] + diff(ts)[i]: the target time then matches the grid exactly,
        # with no floating-point round-off.
        for t_cur, t_next in zip(ts[:-1], ts[1:]):
            pred_x0 = method.pred_x0(model, x, t_cur, y, **kwargs)
            x = self.sample_from_EI(method, x, pred_x0, y, t_next, t_cur)
            yield x, pred_x0

        final_prediction = method.pred_x0(model, x, ts[-1], y, **kwargs)
        yield final_prediction, final_prediction
