from typing import Dict, Optional, Tuple, Union

import torch
from jaxtyping import Bool, Float
from torch import Tensor

from proteinfoundation.flow_matching.base_flow_matcher import BaseFlowMatcher
from proteinfoundation.utils.align_utils import mean_w_mask


class RDNFlowMatcher(BaseFlowMatcher):
    """Flow matcher for point clouds of ``n`` points with ``dim`` coordinates each.

    Uses the linear interpolant ``x_t = (1 - t) * x_0 + t * x_1`` between a
    standard-Gaussian noise sample ``x_0`` (see ``sample_noise``) and a clean
    sample ``x_1``. Network outputs are passed around as dicts holding a clean
    sample prediction under ``"x_1"`` and/or a velocity under ``"v"``; the
    ``nn_out_add_*`` helpers derive one from the other.

    Optional boolean masks of shape ``(*, n)`` mark valid points; masked-out
    points are zeroed by every method that receives a mask.
    """

    def __init__(
        self,
        zero_com_noise: bool = False,
        guidance_enabled: bool = False,
        dim: int = 3,
        **kwargs,
    ):
        """
        Args:
            zero_com_noise: if True, ``sample_noise`` re-centers the noise so
                its mean over points is zero (masked mean when a mask is given).
            guidance_enabled: forwarded to ``BaseFlowMatcher``. When the
                resulting ``self.guidance_enabled`` is truthy,
                ``simulation_step`` integrates ``nn_out["v_guided"]``.
            dim: number of coordinates per point; forwarded to
                ``BaseFlowMatcher`` (``sample_noise`` reads it as ``self.dim``).
            **kwargs: accepted and silently ignored (not forwarded).
        """
        super().__init__(
            guidance_enabled=guidance_enabled,
            dim=dim,
        )
        self.zero_com_noise = zero_com_noise

    def _force_zero_com(
        self, x: Float[Tensor, "* n d"], mask: Optional[Bool[Tensor, "* n"]] = None
    ) -> Float[Tensor, "* n d"]:
        """Subtract the per-sample mean over points (axis ``-2``) from ``x``.

        With a mask, the mean comes from ``mean_w_mask`` (presumably a mean over
        valid points only — confirm in ``align_utils``). Masked-out points are
        zeroed again afterwards, because subtracting the mean moves them off zero.
        """
        if mask is None:
            return x - torch.mean(x, dim=-2, keepdim=True)
        centered = x - mean_w_mask(x, mask, keepdim=True)
        return self._apply_mask(centered, mask)

    def _apply_mask(
        self, x: Float[Tensor, "* n d"], mask: Optional[Bool[Tensor, "* n"]] = None
    ) -> Float[Tensor, "* n d"]:
        """Zero the coordinates of masked-out points; return ``x`` unchanged if ``mask`` is None."""
        if mask is None:
            return x
        return x * mask[..., None]

    def sample_noise(
        self,
        n: int,
        device: torch.device,
        shape: Tuple = tuple(),
        mask: Optional[Bool[Tensor, "* n"]] = None,
    ) -> Float[Tensor, "* n d"]:
        """Draw standard-Gaussian noise of shape ``(*shape, n, self.dim)``.

        Args:
            n: number of points.
            device: device for the returned tensor.
            shape: leading batch dimensions. Any sequence of ints is accepted
                (tuple, list, ``torch.Size``).
            mask: optional ``(*shape, n)`` validity mask; masked-out points are zeroed.

        Returns:
            The noise tensor. It is re-centered with ``_force_zero_com`` when
            ``self.zero_com_noise`` is set.
        """
        # Unpacking (rather than ``shape + (...)``) also works for lists.
        noise = torch.randn((*shape, n, self.dim), device=device)
        noise = self._apply_mask(noise, mask)
        if self.zero_com_noise:
            noise = self._force_zero_com(noise, mask)
        return noise

    def interpolate(
        self,
        x_0: Float[Tensor, "* n d"],
        x_1: Float[Tensor, "* n d"],
        t: Float[Tensor, "*"],
        mask: Optional[Bool[Tensor, "* n"]] = None,
    ) -> Float[Tensor, "* n d"]:
        """Return ``x_t = (1 - t) * x_0 + t * x_1`` with both endpoints masked first.

        ``t`` carries the batch dimensions only; it is broadcast over points and
        coordinates.
        """
        x_0 = self._apply_mask(x_0, mask)
        x_1 = self._apply_mask(x_1, mask)
        t = t[..., None, None]
        return (1.0 - t) * x_0 + t * x_1

    def nn_out_add_clean_sample_prediction(
        self,
        x_t: Float[Tensor, "* n d"],
        t: Float[Tensor, "*"],
        mask: Bool[Tensor, "* n"],
        nn_out: Dict[str, torch.Tensor],
    ) -> Dict[str, Float[Tensor, "* n d"]]:
        """Ensure ``nn_out["x_1"]`` exists and is masked; mutates and returns ``nn_out``.

        If only ``"v"`` is present, ``x_1 = x_t + (1 - t) * v``. This inverts the
        linear interpolant, whose velocity is ``x_1 - x_0``.

        Raises:
            IOError: if neither ``"x_1"`` nor ``"v"`` is in ``nn_out``. (An unusual
                type for a missing key; kept because callers may catch it.)
        """
        t = t[..., None, None]
        if "x_1" not in nn_out:
            if "v" not in nn_out:
                raise IOError(
                    f"Cannot compute clean sample prediction from keys {list(nn_out)}"
                )
            nn_out["x_1"] = x_t + (1.0 - t) * nn_out["v"]
        nn_out["x_1"] = nn_out["x_1"] * mask[..., None]
        return nn_out

    def nn_out_add_simulation_tensor(
        self,
        x_t: Float[Tensor, "* n d"],
        t: Float[Tensor, "*"],
        mask: Bool[Tensor, "* n"],
        nn_out: Dict[str, torch.Tensor],
    ) -> Dict[str, Float[Tensor, "* n d"]]:
        """Ensure ``nn_out["v"]`` exists and is masked; mutates and returns ``nn_out``.

        If only ``"x_1"`` is present, ``v = (x_1 - x_t) / (1 - t + 1e-5)``. The
        ``1e-5`` keeps the division finite at ``t = 1``.

        Raises:
            IOError: if neither ``"v"`` nor ``"x_1"`` is in ``nn_out``.
        """
        t = t[..., None, None]
        if "v" not in nn_out:
            if "x_1" not in nn_out:
                raise IOError(
                    f"Cannot compute simulation tensor (v) from keys {list(nn_out)}"
                )
            nn_out["v"] = (nn_out["x_1"] - x_t) / (1.0 - t + 1e-5)
        nn_out["v"] = nn_out["v"] * mask[..., None]
        return nn_out

    def compute_fm_loss(
        self,
        x_0: Float[Tensor, "* n d"],
        x_1: Float[Tensor, "* n d"],
        x_t: Float[Tensor, "* n d"],
        mask: Bool[Tensor, "* n"],
        t: Float[Tensor, "*"],
        nn_out: Dict[str, Float[Tensor, "* n d"]],
    ) -> Float[Tensor, "*"]:
        """Compute the per-sample weighted clean-sample regression loss.

        The loss is ``sum_{n,d} mask * (x_1 - x_1_hat)^2 / n_valid``, multiplied by
        ``1 / ((1 - t)^2 + 1e-5)``. ``x_0`` is accepted for interface
        compatibility but unused. ``nn_out`` is mutated (``"x_1"`` is added/masked).

        Returns:
            A loss tensor with the batch shape of ``t``. A sample with no valid
            points gets loss 0 rather than NaN.
        """
        nn_out = self.nn_out_add_clean_sample_prediction(
            x_t=x_t,
            t=t,
            mask=mask,
            nn_out=nn_out,
        )
        # Clamp so an all-masked sample gives 0 / 1 = 0 instead of 0 / 0 = NaN.
        nres = torch.sum(mask, dim=-1).clamp(min=1)
        err = (x_1 - nn_out["x_1"]) * mask[..., None]
        loss = torch.sum(err**2, dim=(-1, -2)) / nres
        # Suppose x_t comes from `interpolate` and x_1_hat = x_t + (1 - t) v_hat.
        # Then x_1 - x_1_hat = (1 - t) * (v - v_hat), so this weight turns the
        # loss into (approximately) a velocity MSE.
        total_loss_w = 1.0 / ((1.0 - t) ** 2 + 1e-5)
        return loss * total_loss_w

    def nn_out_add_guided_simulation_tensor(
        self,
        nn_out: Dict[str, torch.Tensor],
        nn_out_ag: Union[Dict[str, torch.Tensor], None],
        nn_out_ucond: Union[Dict[str, torch.Tensor], None],
        guidance_w: float,
        ag_ratio: float,
    ) -> Dict[str, torch.Tensor]:
        """Add the guided velocity ``nn_out["v_guided"]``; no-op when guidance is disabled.

        ``v_guided = w * v + (1 - w) * (ag_ratio * v_ag + (1 - ag_ratio) * v_ucond)``,
        where ``v_ag`` / ``v_ucond`` are read from ``nn_out_ag["v"]`` /
        ``nn_out_ucond["v"]``. A missing (None) auxiliary output is treated as a
        zero velocity. ``nn_out`` is mutated and returned.
        """
        assert "v" in nn_out, "`v` should be a key in the nn_out dict"
        if not self.guidance_enabled:
            return nn_out

        v = nn_out["v"]
        v_ag = torch.zeros_like(v) if nn_out_ag is None else nn_out_ag["v"]
        v_ucond = torch.zeros_like(v) if nn_out_ucond is None else nn_out_ucond["v"]

        nn_out["v_guided"] = guidance_w * v + (1 - guidance_w) * (
            ag_ratio * v_ag + (1 - ag_ratio) * v_ucond
        )
        return nn_out

    @staticmethod
    def _scaled_score_ode_delta(
        x_t: Float[Tensor, "* n d"],
        v: Float[Tensor, "* n d"],
        t: Float[Tensor, "*"],
        dt: float,
        score_scale: float,
    ) -> Float[Tensor, "* n d"]:
        """Deterministic increment using ``v`` with its implied score multiplied by ``score_scale``."""
        score = vf_to_score(x_t, v, t)
        v_scaled = score_to_vf(x_t, score * score_scale, t)
        return v_scaled * dt

    @staticmethod
    def _sde_delta(
        x_t: Float[Tensor, "* n d"],
        drift: Float[Tensor, "* n d"],
        score: Float[Tensor, "* n d"],
        dt: float,
        gt: float,
        noise_scale: float,
    ) -> Float[Tensor, "* n d"]:
        """Euler-Maruyama increment ``(drift + gt * score) * dt + sqrt(2 * gt * noise_scale * dt) * eps``."""
        eps = torch.randn(x_t.shape, dtype=x_t.dtype, device=x_t.device)
        var = 2 * gt * noise_scale * dt
        if not torch.is_tensor(var):
            # gt / dt may be plain floats (as annotated); torch.sqrt only accepts tensors.
            var = torch.tensor(var, dtype=x_t.dtype, device=x_t.device)
        std_eps = torch.sqrt(var)
        return (drift + gt * score) * dt + std_eps * eps

    def simulation_step(
        self,
        x_t: Float[Tensor, "* n d"],
        nn_out: Dict[str, Float[Tensor, "* n d"]],
        t: Float[Tensor, "*"],
        dt: float,
        gt: float,
        mask: Bool[Tensor, "* n"],
        simulation_step_params: Dict,
    ) -> Float[Tensor, "* n d"]:
        """Advance ``x_t`` by one integration step of size ``dt``.

        Uses ``nn_out["v_guided"]`` if ``self.guidance_enabled`` else ``nn_out["v"]``.
        All samples must share the same ``t``.

        ``simulation_step_params["sampling_mode"]`` selects the update:
            - ``"vf"``: Euler step ``v * dt``.
            - ``"vf_ss"``: if t < t_lim_ode_below, SDE step with noise scale fixed
              at 0.3; otherwise ODE step with the score scaled by ``sc_scale_score``.
            - ``"sc"``: if t > t_lim_ode, ODE step with the score scaled by a fixed
              1.5; otherwise SDE step with noise scale ``sc_scale_noise``.
            - ``"vf_ss_sc_sn"``: t > t_lim_ode -> ODE step with fixed score scale
              1.5; t < t_lim_ode_below -> SDE step with fixed noise scale 0.3;
              otherwise SDE step whose drift is the score-scaled velocity
              (``sc_scale_score``), plus ``gt`` times the unscaled score, with
              noise scale ``sc_scale_noise``.

        Args:
            dt, gt: step size and score/noise coefficient. Each may be a float or
                a scalar tensor; ``gt`` must be >= 0.

        Returns:
            The next state, with masked-out points zeroed and, if
            ``center_every_step`` is set, re-centered.
        """
        sampling_mode = simulation_step_params["sampling_mode"]
        sc_scale_noise = simulation_step_params["sc_scale_noise"]
        sc_scale_score = simulation_step_params["sc_scale_score"]
        t_lim_ode = simulation_step_params["t_lim_ode"]
        t_lim_ode_below = simulation_step_params["t_lim_ode_below"]
        center_every_step = simulation_step_params["center_every_step"]

        assert sampling_mode in [
            "vf",
            "sc",
            "vf_ss",
            "vf_ss_sc_sn",
        ], f"Invalid sampling mode {sampling_mode}, should be `vf`, `sc`, `vf_ss`, or `vf_ss_sc_sn`"
        assert (
            sc_scale_noise >= 0
        ), f"Scale noise for sampling should be >= 0, got {sc_scale_noise}"
        assert (
            sc_scale_score >= 0
        ), f"Scale score for sampling should be >= 0, got {sc_scale_score}"
        assert gt >= 0, f"gt for sampling should be >= 0, got {gt}"
        t_element = t.flatten()[0]
        assert torch.all(
            t_element == t
        ), "Sampling only implemented for same time for all samples"

        v = nn_out["v_guided"] if self.guidance_enabled else nn_out["v"]

        # NOTE(review): in some time windows these fixed values override the
        # user-supplied sc_scale_score / sc_scale_noise. Kept as-is; confirm intended.
        sc_scale_score_def = 1.5
        sc_scale_noise_def = 0.3

        if sampling_mode == "vf":
            delta_x = v * dt

        elif sampling_mode == "vf_ss":
            if t_element < t_lim_ode_below:
                score = vf_to_score(x_t, v, t)
                delta_x = self._sde_delta(x_t, v, score, dt, gt, sc_scale_noise_def)
            else:
                delta_x = self._scaled_score_ode_delta(x_t, v, t, dt, sc_scale_score)

        elif sampling_mode == "sc":
            if t_element > t_lim_ode:
                delta_x = self._scaled_score_ode_delta(
                    x_t, v, t, dt, sc_scale_score_def
                )
            else:
                score = vf_to_score(x_t, v, t)
                delta_x = self._sde_delta(x_t, v, score, dt, gt, sc_scale_noise)

        elif sampling_mode == "vf_ss_sc_sn":
            if t_element > t_lim_ode:
                delta_x = self._scaled_score_ode_delta(
                    x_t, v, t, dt, sc_scale_score_def
                )
            elif t_element < t_lim_ode_below:
                score = vf_to_score(x_t, v, t)
                delta_x = self._sde_delta(x_t, v, score, dt, gt, sc_scale_noise_def)
            else:
                score = vf_to_score(x_t, v, t)
                v_scaled = score_to_vf(x_t, score * sc_scale_score, t)
                delta_x = self._sde_delta(
                    x_t, v_scaled, score, dt, gt, sc_scale_noise
                )

        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError(f"Invalid sampling mode {sampling_mode}")

        x_next = self._apply_mask(x_t + delta_x, mask)
        if center_every_step:
            x_next = self._force_zero_com(x_next, mask)
        return x_next


def vf_to_score(
    x_t: Float[Tensor, "* n d"],
    v: Float[Tensor, "* n d"],
    t: Float[Tensor, "*"],
) -> Float[Tensor, "* n d"]:
    """Convert a velocity field into the score it implies.

    Computes ``score = (t * v - x_t) / (1 - t)``. This is the relation for the
    linear interpolant ``x_t = (1 - t) x_0 + t x_1`` with standard-Gaussian
    ``x_0``, and it is the inverse of ``score_to_vf``. ``t`` carries batch
    dimensions only and is broadcast over the trailing ``(n, d)`` axes.

    Raises:
        AssertionError: unless every entry of ``t`` is strictly below 1.
    """
    assert torch.all(t < 1.0), "vf_to_score requires t < 1 (strict)"
    t_b = t[..., None, None]
    return (t_b * v - x_t) / (1.0 - t_b)


def score_to_vf(
    x_t: Float[Tensor, "* n d"],
    score: Float[Tensor, "* n d"],
    t: Float[Tensor, "*"],
) -> Float[Tensor, "* n d"]:
    """Convert a score back into the velocity field it implies.

    Computes ``v = (x_t + (1 - t) * score) / t``, the inverse of
    ``vf_to_score``. ``t`` carries batch dimensions only and is broadcast over
    the trailing ``(n, d)`` axes.

    Raises:
        AssertionError: unless every entry of ``t`` is strictly above 0.
    """
    assert torch.all(t > 0.0), "score_to_vf requires t > 0 (strict)"
    t_b = t[..., None, None]
    one_minus_t = 1.0 - t_b
    return (x_t + one_minus_t * score) / t_b
