# Implementation of Flow Matching Model
# https://arxiv.org/abs/2502.02538 (Flow Q-Learning)

import torch
import torch.nn as nn
import torch.nn.functional as F
from agents.helpers import Losses


class FlowMatching(nn.Module):
    """Flow-matching model used as a behaviour-cloning (BC) flow policy.

    The wrapped ``model`` is a velocity field that is called as
    ``model(x, t, state)``. Actions are generated by integrating that field
    with fixed-step Euler updates starting from Gaussian noise, and the field
    is trained by regressing onto the straight-line velocity between a noise
    sample and a data action.

    Args:
        state_dim: Stored on the instance; not otherwise used by this class.
        action_dim: Size of the last axis of the noise drawn by ``sample``.
        model: Callable velocity field, invoked as ``model(x, t, state)``.
        max_action: Symmetric bound used to clamp sampled actions when
            ``clip_denoised`` is true.
        n_timesteps: Default number of Euler steps used for sampling.
        loss_type: Key into ``Losses``; the selected entry is instantiated
            with no arguments and used as the training loss.
        clip_denoised: Whether sampled actions are clamped to
            ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, model, max_action,
                 n_timesteps=10, loss_type='l2', clip_denoised=True):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.model = model  # velocity field v_θ(t, s, x)

        self.n_timesteps = int(n_timesteps)  # M: number of Euler steps for sampling
        self.clip_denoised = clip_denoised

        self.loss_fn = Losses[loss_type]()

    # ------------------------------------------ sampling ------------------------------------------#

    def sample(self, state, n_steps=None):
        """Draw actions for ``state`` starting from fresh Gaussian noise.

        Args:
            state: Conditioning tensor; its first axis is the batch size and
                its device is used for the noise.
            n_steps: Number of Euler steps; ``None`` means ``self.n_timesteps``.

        Returns:
            Tensor of shape ``(state.shape[0], action_dim)``.
        """
        # Start from noise z ~ N(0, I); the default for n_steps is resolved
        # by sample_from_noise.
        z = torch.randn(state.shape[0], self.action_dim, device=state.device)
        return self.sample_from_noise(state, z, n_steps)

    def sample_from_noise(self, state, z, n_steps=None):
        """Integrate the velocity field from the given noise ``z``.

        Performs ``n_steps`` explicit Euler updates ``z <- z + v * dt`` with
        ``dt = 1 / n_steps``, evaluating the field at times
        ``0, dt, ..., (n_steps - 1) * dt``.

        Args:
            state: Conditioning tensor; its first axis is the batch size.
            z: Initial point of the integration.
            n_steps: Number of Euler steps; ``None`` means ``self.n_timesteps``.

        Returns:
            The integrated tensor, clamped to ``[-max_action, max_action]``
            when ``clip_denoised`` is true.

        Raises:
            ValueError: If the resolved number of steps is smaller than 1.
        """
        if n_steps is None:
            n_steps = self.n_timesteps
        if n_steps < 1:
            # 0 would divide by zero below and a negative count would
            # silently return the untouched noise.
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        batch_size = state.shape[0]
        device = state.device

        # Euler integration
        dt = 1.0 / n_steps
        for i in range(n_steps):
            # Keep the time tensor in the same dtype as the integrated
            # variable so the model does not receive mixed precisions.
            t = torch.full((batch_size,), i * dt, device=device, dtype=z.dtype)
            # v_θ(t, s, z)
            v = self.model(z, t, state)
            z = z + v * dt

        if self.clip_denoised:
            z = z.clamp(-self.max_action, self.max_action)

        return z

    @torch.no_grad()
    def sample_no_grad(self, state, n_steps=None):
        """Sample without gradient computation."""
        return self.sample(state, n_steps)

    # ------------------------------------------ training ------------------------------------------#

    # NOTE: the tensor default for ``weights`` is created once at definition
    # time and shared between calls; this method never modifies it.
    def loss(self, action, state, weights=torch.tensor(1.0)):
        """Compute the flow-matching regression loss for a batch.

        Draws ``x0 ~ N(0, I)`` and ``t ~ U[0, 1)`` per batch element, forms
        ``xt = (1 - t) * x0 + t * action`` and compares the model output at
        ``(xt, t, state)`` against the target velocity ``action - x0``.

        Args:
            action: Data actions; the first axis is the batch axis. Any
                number of trailing axes is supported for the interpolation.
            state: Conditioning tensor passed through to the model.
            weights: Forwarded unchanged as the third argument of the loss.

        Returns:
            Whatever ``self.loss_fn(pred_v, target_v, weights)`` returns.
        """
        batch_size = action.shape[0]

        x0 = torch.randn_like(action)
        x1 = action

        # One time value per batch element, in the same dtype as the actions.
        t = torch.rand(batch_size, device=action.device, dtype=action.dtype)
        # Shape (batch_size, 1, ..., 1) so the time always broadcasts along
        # the batch axis, whatever the rank of ``action`` is.
        t_expand = t.view(-1, *([1] * (action.dim() - 1)))

        xt = (1 - t_expand) * x0 + t_expand * x1

        target_v = x1 - x0
        pred_v = self.model(xt, t, state)

        return self.loss_fn(pred_v, target_v, weights)

    def forward(self, state, *args, **kwargs):
        """Forward pass for sampling."""
        return self.sample(state, *args, **kwargs)
