# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torchdiffeq import odeint

from flow_matching.solver.solver import Solver
from flow_matching.utils import gradient, ModelWrapper

import os
import math
from PIL import Image
import numpy as np


def _to_uint8_nhwc_from_minus1_1(x: torch.Tensor) -> np.ndarray:
    """Convert a channel-first image batch in [-1, 1] to a channel-last uint8 array.

    Args:
        x: 4-D tensor laid out as BCHW. Values outside [-1, 1] are clamped.

    Returns:
        NHWC ``np.uint8`` array with values in [0, 255]. The float-to-integer
        conversion truncates (it does not round).
    """
    # Detach from autograd and move to host memory before handing off to numpy.
    clipped = torch.clamp(x.detach().cpu(), -1., 1.)
    # Affine map [-1, 1] -> [0, 1].
    unit_range = (clipped + 1.) / 2.
    # BCHW -> NHWC; numpy conversion needs a contiguous buffer.
    channel_last = unit_range.permute(0, 2, 3, 1).contiguous().numpy()
    return (255.0 * channel_last).astype(np.uint8)


def _make_grid_pil(bchw: torch.Tensor, nrow: int = 5, pad: int = 2, pad_val: int = 255) -> Image.Image:
    """Tile a batch of images in [-1, 1] into a single padded PIL image.

    Images are placed left-to-right, top-to-bottom, with ``pad`` pixels of
    ``pad_val`` around and between tiles.

    Args:
        bchw: BCHW tensor with values expected in [-1, 1].
        nrow: Number of images per grid row (i.e. the number of columns).
        pad: Padding, in pixels, between tiles and around the border.
        pad_val: uint8 fill value used for the padding.

    Returns:
        The tiled grid as a PIL image.

    Raises:
        ValueError: If ``nrow`` is smaller than 1.
    """
    if nrow < 1:
        raise ValueError(f"nrow must be a positive integer. Got {nrow}")

    arr = _to_uint8_nhwc_from_minus1_1(bchw)  # [B, H, W, C]
    n_images, height, width, channels = arr.shape
    per_row = nrow
    n_rows = (n_images + per_row - 1) // per_row  # ceil(B / per_row)
    grid_h = n_rows * height + (n_rows + 1) * pad
    grid_w = per_row * width + (per_row + 1) * pad
    canvas = np.full((grid_h, grid_w, channels), pad_val, dtype=np.uint8)

    for idx, tile in enumerate(arr):
        r, c = divmod(idx, per_row)
        y0 = pad + r * (height + pad)
        x0 = pad + c * (width + pad)
        canvas[y0:y0 + height, x0:x0 + width, :] = tile

    if channels == 1:
        # Image.fromarray does not accept an (H, W, 1) array; drop the channel
        # axis so single-channel batches become a grayscale ("L") image.
        canvas = canvas[:, :, 0]

    return Image.fromarray(canvas)


def _save_grid(bchw: torch.Tensor, out_path: str, nrow: int = 5) -> None:
    """Tile a BCHW batch into a grid image and write it to ``out_path``.

    Args:
        bchw: BCHW tensor with values expected in [-1, 1].
        out_path: Destination file path. Its parent directory is created if
            it does not exist.
        nrow: Number of images per grid row.
    """
    parent_dir = os.path.dirname(out_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    _make_grid_pil(bchw, nrow=nrow).save(out_path)

class ODESolver(Solver):
    """A class to solve ordinary differential equations (ODEs) using a specified velocity model.

    This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ode solvers.

    Args:
        velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
    """

    def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
        super().__init__()
        self.velocity_model = velocity_model

    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE with the velocity field.

        Example:

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import ODESolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return torch.ones_like(x) * 3.0 * t**2

            velocity_model = DummyModel()
            solver = ODESolver(velocity_model=velocity_model)
            x_init = torch.tensor([0.0, 0.0])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
        """

        time_grid = time_grid.to(x_init.device)

        def ode_func(t, x):
            # torchdiffeq calls the dynamics as f(t, x); the model takes keywords.
            return self.velocity_model(x=x, t=t, **model_extras)

        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            # Approximate ODE solution with numerical ODE solver
            sol = odeint(
                ode_func,
                x_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        if return_intermediates:
            return sol
        else:
            return sol[-1]

    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for log likelihood given a target sample at :math:`t=1`.

        Works similarly to sample, but solves the ODE in reverse (from :math:`t=1` to :math:`t=0`) to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`).
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        # Keep the time grid on the same device as the state, as sample() does.
        time_grid = time_grid.to(x_1.device)

        # Fix the random projection for the Hutchinson divergence estimator
        if not exact_divergence:
            # Rademacher (+1 / -1) noise, drawn once and reused at every step.
            z = (torch.randn_like(x_1).to(x_1.device) < 0) * 2.0 - 1.0

        def ode_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        def dynamics_func(t, states):
            xt = states[0]
            # The divergence needs d(ut)/d(xt), so gradients are always enabled
            # here regardless of the outer enable_grad setting.
            with torch.set_grad_enabled(True):
                xt.requires_grad_()
                ut = ode_func(xt, t)

                if exact_divergence:
                    # Compute exact divergence as the trace of the Jacobian.
                    # Index the velocity and its gradient in flattened form so
                    # inputs with several feature dimensions (e.g. BCHW) work;
                    # for [batch, dim] inputs this matches plain indexing.
                    ut_flat = ut.flatten(start_dim=1)
                    div = 0
                    for i in range(ut_flat.shape[1]):
                        grad_i = gradient(ut_flat[:, i], xt, create_graph=True)
                        div += grad_i.flatten(start_dim=1)[:, i]
                else:
                    # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
                    ut_dot_z = torch.einsum(
                        "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
                    )
                    grad_ut_dot_z = gradient(ut_dot_z, xt)
                    div = torch.einsum(
                        "ij,ij->i",
                        grad_ut_dot_z.flatten(start_dim=1),
                        z.flatten(start_dim=1),
                    )

            return ut.detach(), div.detach()

        # Augmented state: (sample, accumulated divergence integral per batch item).
        y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            sol, log_det = odeint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + log_det[-1]
        else:
            return sol[-1], source_log_p + log_det[-1]

    @torch.no_grad()
    def sample_out(
        self,
        x_init: Tensor,
        step_size: Optional[float] = None,
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        out_dir: str = "/data/baek/flow_matching3/cifar10/progress",
        save_every: int = 1,
        grid_nrow: int = 5,
        save_init: bool = True,
        filename_prefix: str = "progress",
        enable_grad: bool = False,
        to_image_fn: Optional[Callable[[Tensor, str], None]] = None,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE and save a progress image of the state at the time-grid points.

        - Adaptive solvers (e.g. "dopri5") do not expose their internal steps, so
          images are only saved at the points of ``time_grid``.
        - To save at fixed intervals, pass ``step_size``: a uniform time grid from
          ``time_grid[0]`` to ``time_grid[-1]`` is generated automatically and
          replaces the given grid.

        Args:
            x_init (Tensor): Initial state. Assumed to be BCHW with values in [-1, 1] (for visualization).
            step_size (Optional[float]): Fixed step size. When given, ``time_grid`` is replaced by a uniform grid with ``ceil(|t1 - t0| / step_size)`` steps.
            method (str): A method supported by torchdiffeq (euler, heun3, midpoint, dopri5, ...).
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): Time points at which the state is evaluated and saved. A descending grid is allowed.
            out_dir (str): Directory the progress images are written to; created if missing. The default is an absolute, machine-specific path, so callers will normally want to override it.
            save_every (int): Save every ``save_every``-th time index. The last index is always saved. Must be >= 1.
            grid_nrow (int): Number of images per row in the saved grid (used by the default saver only).
            save_init (bool): Whether to also save the initial state (index 0).
            filename_prefix (str): Prefix of the generated file names.
            enable_grad (bool): Whether to compute gradients during sampling. Defaults to False.
            to_image_fn (Optional[Callable[[Tensor, str], None]]): Optional custom saver with signature ``(bchw, out_path) -> None``. Defaults to saving a tiled PNG grid.
            **model_extras: Additional input for the model.

        Returns:
            Tensor: The state at the last time point.

        Raises:
            ValueError: If ``save_every`` is smaller than 1.
        """
        # Validate before solving so a bad argument does not waste a full solve.
        if save_every < 1:
            raise ValueError(f"save_every must be a positive integer. Got {save_every}")

        device = x_init.device
        time_grid = time_grid.to(device)

        # When step_size is given, build a uniform time grid between the endpoints.
        if step_size is not None:
            t0 = float(time_grid[0].item())
            t1 = float(time_grid[-1].item())
            n_steps = max(1, math.ceil(abs(t1 - t0) / float(step_size)))
            # linspace keeps the direction (ascending or descending) of the grid.
            time_grid = torch.linspace(t0, t1, n_steps + 1, device=device)

        def ode_func(t, x):
            return self.velocity_model(x=x, t=t, **model_extras)

        ode_opts = {"step_size": step_size} if step_size is not None else {}

        os.makedirs(out_dir, exist_ok=True)

        # Pick the function used to write each snapshot to disk.
        if to_image_fn is not None:
            saver = to_image_fn
        else:
            def saver(bchw, path):
                _save_grid(bchw, path, nrow=grid_nrow)

        # The method is decorated with no_grad; this inner context is what
        # actually applies the enable_grad flag during the solve.
        with torch.set_grad_enabled(enable_grad):
            # Integrate over the whole time grid in one call; the result holds
            # the state at every grid point, stacked along the first dimension.
            sol = odeint(
                ode_func,
                x_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        # Save progress snapshots; the final state is always written.
        last_idx = len(time_grid) - 1
        for idx, (t, x_t) in enumerate(zip(time_grid, sol)):
            if idx == 0 and not save_init:
                continue
            if idx % save_every == 0 or idx == last_idx:
                t_val = float(t.item())
                out_path = os.path.join(out_dir, f"{filename_prefix}_idx{idx:04}_t{t_val:+.6f}.png")
                saver(x_t, out_path)

        # Return the final state.
        return sol[-1]
