"""
Sampling strategies for the generator matching framework.
""" 
import torch
from abc import ABC
from flow_matching.solver import MixtureDiscreteEulerSolver
from torch import nn, Tensor


# Model Wrapper
class ModelWrapper(ABC, nn.Module):
    """
    Base class that holds another ``nn.Module`` and lets subclasses customize
    how inputs are routed into it.

    The wrapped module is stored as ``self.model``. Assigning an ``nn.Module``
    attribute registers it as a child module of the wrapper.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        r"""
        Pass the inputs straight through to the wrapped model.

        The default implementation assumes the wrapped model accepts keyword
        arguments named ``x`` and ``t``, plus any additional keyword arguments.
        Subclasses can override this to, for example:
            - check that :math:`t` has the dimensions the model expects,
            - rename or reshape the inputs,
            - post-process the model output.

        Args:
            x (Tensor): input data to the model (batch_size, ...).
            t (Tensor): time (batch_size).
            **extras: extra keyword arguments forwarded unchanged to the
                wrapped model, e.g. a text condition.

        Returns:
            Tensor: whatever the wrapped model returns.
        """
        model_kwargs = {"x": x, "t": t, **extras}
        return self.model(**model_kwargs)

class WrappedModel(ModelWrapper):
    """
    Adapter that exposes a text denoiser through the ``(x, t)`` interface and
    returns probabilities over the last dimension instead of raw logits.
    """

    def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
        """
        Run the wrapped model and normalize its output with a softmax.

        The wrapped model is called as ``model(x_t=x, time=t, **extras)``.

        Args:
            x (Tensor): current sample, passed to the model as ``x_t``.
            t (Tensor): time, passed to the model as ``time``.
            **extras: extra keyword arguments forwarded to the model.

        Returns:
            Tensor: float32 softmax of the model output over the last dimension.
        """
        logits = self.model(x_t=x, time=t, **extras)
        # Upcast to float32 before normalizing: logit precision is important.
        return logits.float().softmax(dim=-1)

# ------------------------------------
# Euler sampling - default 
# ------------------------------------ 
def euler_sampler(
        text_model,
        image_embedding,
        initial_text,
        vocab_size,
        path,
        num_steps=100,
        trajectory=False
):
    """
    Sample text tokens conditioned on an image embedding with a discrete Euler solver.

    The interval [0, 1] is split into ``num_steps`` equal sub-intervals. The
    ``MixtureDiscreteEulerSolver`` is run over each sub-interval in turn, and
    the output of one step becomes the input of the next.

    Args:
        text_model: text denoiser. It is wrapped in ``WrappedModel``, so it is
            called as ``text_model(x_t=..., time=..., img_tokens=...)`` and its
            output is softmax-normalized over the last dimension.
            Note: ``nn.Module.to`` works in place, so the caller's model is
            cast to ``image_embedding.dtype`` as a side effect.
        image_embedding: conditioning tensor, forwarded to the model as
            ``img_tokens``. Its dtype is used for the time grid and for
            categorical sampling; its device is where the text tokens are moved.
        initial_text: initial token tensor (the sample at t = 0).
        vocab_size: vocabulary size passed to the solver.
        path: probability path passed to the solver (presumably a discrete
            mixture path from ``flow_matching`` -- confirm with the caller).
        num_steps: number of Euler steps, i.e. sub-intervals of [0, 1].
        trajectory: if True, also return the sample produced after every step.

    Returns:
        The final token tensor. When ``trajectory`` is True, returns
        ``(final_tokens, per_step_tokens)`` instead, where ``per_step_tokens``
        is a list with one tensor per step (the initial text is not included).
    """
    image_dtype = image_embedding.dtype
    device = image_embedding.device

    # The denoiser runs in the same dtype as the conditioning embedding.
    wrapped_probability_denoiser = WrappedModel(model=text_model.to(dtype=image_dtype))

    # The solver's constructor arguments are loop-invariant, so build it once.
    solver = MixtureDiscreteEulerSolver(
        model=wrapped_probability_denoiser,
        path=path,
        vocabulary_size=vocab_size,
    )

    # Uniform time grid 0 = t_0 < t_1 < ... < t_num_steps = 1.
    t_steps = torch.linspace(0, 1, num_steps + 1, dtype=image_dtype)

    txt_next = initial_text.to(device)
    txt_vec = [] if trajectory else None

    with torch.no_grad():
        for i in range(num_steps):
            # NOTE(review): in flow_matching's MixtureDiscreteEulerSolver, the
            # last step of a `sample` call appears to return the sampled x_1
            # directly rather than applying an Euler jump. With a two-point grid,
            # every call here would then be a full resample from p(x_1 | x_t).
            # Confirm against the installed version; a single `sample` call over
            # the whole `t_steps` grid may be what is intended.
            # NOTE(review): WrappedModel upcasts to float32, but sampling uses
            # `image_dtype` (possibly bf16/fp16) -- confirm this is intended.
            txt_next = solver.sample(
                x_init=txt_next,
                step_size=None,
                dtype_categorical=image_dtype,
                # [t_i, t_{i+1}]; clone so the solver never holds a view of t_steps.
                time_grid=t_steps[i:i + 2].clone(),
                img_tokens=image_embedding,
            )

            if trajectory:
                txt_vec.append(txt_next)

    if trajectory:
        return txt_next, txt_vec
    return txt_next

