import os
import warnings
from typing import Dict, Any

import torch

from ..guidance_pl import get_sample_guide_fn
from ..models import NoisyCuboidTransformerEncoder


# Absolute path of the ``pretrained`` directory that sits five directory levels
# above this file. ``NbodyGuidanceEnergy`` joins ``model_ckpt_path`` onto it.
pretrained_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), *([".."] * 5), "pretrained")
)

class NoisyCuboidTransformerEncoderWrapper(NoisyCuboidTransformerEncoder):
    r"""
    ``NoisyCuboidTransformerEncoder`` that receives the context latent separately.

    The context ``zc`` is prepended to the noisy latent ``z`` along dim 1 (the
    time axis), and the joint sequence is fed to the parent encoder.
    """

    def __init__(self, **kwargs):
        r"""
        Forward every keyword argument unchanged to ``NoisyCuboidTransformerEncoder``.

        Because the parent sees context and target frames concatenated, the
        sequence length in its ``input_shape`` (T, H, W, C) must equal
        ``in_len + out_len``.
        """
        super().__init__(**kwargs)

    def forward(self, z, t, zc, verbose=False, **kwargs):
        r"""
        Parameters
        ----------
        z:  torch.Tensor
            Noisy latent, shape (B, T_out, H, W, C).
        t:  torch.Tensor
            Timestep tensor, shape (B, ).
        zc: torch.Tensor
            Context latent, shape (B, T_in, H, W, C).
        verbose:    bool
            Passed through to the parent ``forward``.
        kwargs:
            Any extra keyword arguments, passed through to the parent ``forward``.

        Returns
        -------
        out:    torch.Tensor
            Output of the parent encoder on the concatenated sequence:
            shape (B, T, C) if ``self.readout_seq`` is True, where T is
            ``self.out_len`` if given and otherwise the input length;
            shape (B, C) if ``self.readout_seq`` is False.
        """
        joint_seq = torch.cat((zc, z), dim=1)  # (B, T_in + T_out, H, W, C)
        return super().forward(joint_seq, t, verbose=verbose, **kwargs)

class NbodyGuidanceEnergy():
    r"""
    Energy-based guidance for latent diffusion sampling on N-body data.

    A learned model (``self.model``) makes a prediction from the noisy latent
    ``zt``, the timestep ``t`` and the encoded context ``zc``. The guidance value
    is the L2 norm of the difference between the last-axis sums of that
    prediction and of the target ``energy``, where the target is restricted to
    its last ``out_len`` time steps. ``get_mean_shift`` scales the output of the
    guide function built from ``guidance_fn`` by ``-guide_scale``.
    """

    def __init__(self,
                 guide_type: str = "sum_energy",
                 out_len: int = 10,
                 guide_scale: float = 1.0,
                 model_type: str = "cuboid",
                 model_args: Dict[str, Any] = None,
                 model_ckpt_path: str = None,
                 ):
        r"""

        Parameters
        ----------
        guide_type: str
            Guidance objective. Only ``"sum_energy"`` is implemented.
        out_len: int
            Number of trailing time steps (dim 1) of ``energy`` used as the target.
        guide_scale:    float
            Multiplier applied (with a negative sign) in ``get_mean_shift``.
        model_type: str
            Only ``"cuboid"`` (``NoisyCuboidTransformerEncoderWrapper``) is implemented.
        model_args: Dict[str, Any]
            Keyword arguments for the model constructor; ``None`` means ``{}``.
        model_ckpt_path:    str
            If not None, a state-dict file, resolved with
            ``os.path.join(pretrained_dir, model_ckpt_path)``, that is loaded into
            the model. If that file does not exist, a ``UserWarning`` is emitted
            and the model keeps its freshly constructed weights.

        Raises
        ------
        AssertionError
            If ``guide_type`` is not implemented.
        NotImplementedError
            If ``model_type`` is not implemented.
        """
        super().__init__()
        assert guide_type in ["sum_energy", ], \
            f"guide_type={guide_type} is not implemented!"
        self.guide_type = guide_type
        self.out_len = out_len
        self.guide_scale = guide_scale
        if model_args is None:
            model_args = {}
        if model_type == "cuboid":
            self.model = NoisyCuboidTransformerEncoderWrapper(**model_args)
        else:
            raise NotImplementedError(f"model_type={model_type} is not implemented")
        if model_ckpt_path is not None:
            ckpt_path = os.path.join(pretrained_dir, model_ckpt_path)
            if os.path.exists(ckpt_path):
                # NOTE(review): torch.load unpickles arbitrary objects; only load
                # trusted checkpoints (consider weights_only=True on newer torch).
                self.model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
            else:
                # Best-effort: keep going, but make the missing checkpoint visible
                # instead of silently guiding with untrained weights.
                warnings.warn(
                    f"model_ckpt_path={model_ckpt_path!r} not found at {ckpt_path!r}; "
                    f"the guidance model keeps its initial weights.",
                    stacklevel=2,
                )

    def _future_energy(self, energy):
        r"""
        Return the last ``self.out_len`` steps of ``energy`` along dim 1.

        Raises ``TypeError`` if ``energy`` is None (i.e. not supplied by the caller).
        """
        if energy is None:
            raise TypeError("`energy` keyword argument is required by NbodyGuidanceEnergy")
        # Explicit start index: `energy[:, -0:]` would select the whole sequence
        # when out_len == 0. Clamping at 0 matches `-out_len:` when out_len
        # exceeds the sequence length.
        start = max(energy.shape[1] - self.out_len, 0)
        return energy[:, start:, :]

    def model_objective(self, x, y=None, **kwargs):
        r"""
        Parameters
        ----------
        x, y
            not used
        kwargs: Dict[str, Any]
            must contain `energy`, presumably a torch.Tensor with shape (b, t, 2)

        Returns
        -------
        energy: torch.Tensor
            ``energy`` restricted to its last ``out_len`` steps along dim 1.

        Raises
        ------
        TypeError
            If `energy` is missing.
        """
        return self._future_energy(kwargs.get("energy"))

    def guidance_fn(self, zt, t, y=None, zc=None, **kwargs):
        r"""
        Transform the learned model into the final guidance \mathcal{F}.

        Parameters
        ----------
        zt: torch.Tensor
            noisy latent z
        t:  torch.Tensor
            timestep
        y:  torch.Tensor
            context sequence in pixel space
        zc: torch.Tensor
            encoded context sequence in latent space
        kwargs: Dict[str, Any]
            additional arguments, all forwarded to ``self.model``; must contain
            `energy` (the target, sliced to its last ``out_len`` steps).

        Returns
        -------
        ret:    torch.Tensor
            Scalar L2 norm, over all remaining elements, of
            ``pred.sum(-1) - target.sum(-1)``.
        """
        pred = self.model(zt, t, zc=zc, y=y, **kwargs)
        if self.guide_type == "sum_energy":
            target = self._future_energy(kwargs.get("energy"))
        else:
            raise NotImplementedError
        # vector_norm with dim=None flattens the whole tensor, so batch and
        # time are reduced together into one scalar.
        ret = torch.linalg.vector_norm(pred.sum(dim=-1) - target.sum(dim=-1),
                                       ord=2)
        return ret

    def get_mean_shift(self, zt, t, y=None, zc=None, **kwargs):
        r"""
        Parameters
        ----------
        zt: torch.Tensor
            noisy latent z
        t:  torch.Tensor
            timestep
        y:  torch.Tensor
            context sequence in pixel space
        zc: torch.Tensor
            encoded context sequence in latent space
        kwargs: Dict[str, Any]
            forwarded to ``guidance_fn`` (must contain `energy`)

        Returns
        -------
        ret:    torch.Tensor
            ``-guide_scale`` times the output of
            ``get_sample_guide_fn(self.guidance_fn)``, which is presumably
            \nabla_zt of ``guidance_fn`` (TODO confirm against guidance_pl).
        """
        grad_fn = get_sample_guide_fn(self.guidance_fn)
        grad = grad_fn(zt, t, y=y, zc=zc, **kwargs)
        return -self.guide_scale * grad
