import torch
import torch.nn as nn
import math


class GuidanceMatcher:
    """Class implementing the Cov-G guidance method.

    Stores a base model plus optional auxiliary models and provides helpers
    that add a time-scheduled guidance term to a velocity prediction.

    Args:
        model: Base model. Stored on the instance; not used by any method of
            this class.
        action_dim: Action dimensionality. Stored only.
        model_z: Optional model called as ``model_z(x, cond, t)`` by
            ``_compute_z``. When ``None``, ``_compute_z`` returns ones.
        model_v: Optional model. Stored only.
        scale: Global multiplier applied to the guidance gradient in
            ``apply_guidance``.
        guidance_type: Name of the guidance variant. Stored only.
    """

    def __init__(
        self,
        model: nn.Module,
        action_dim: int,
        model_z: nn.Module = None,
        model_v: nn.Module = None,
        scale: float = 1.0,
        guidance_type: str = 'direct',
    ):
        self.model = model
        self.scale = scale
        self.guidance_type = guidance_type
        self.action_dim = action_dim
        self.model_z = model_z
        self.model_v = model_v

    def schedule_fn(self, t):
        """Return the cosine guidance weight ``0.5 * (1 + cos(pi * t))``.

        The weight equals 1 at ``t = 0``, 0.5 at ``t = 0.5`` and 0 at ``t = 1``.

        Args:
            t: Time, either a tensor of any shape or a plain Python number.

        Returns:
            A tensor with the same shape as ``t`` (0-dim for a Python number).
        """
        if not torch.is_tensor(t):
            # torch.cos only accepts tensors; promote plain floats/ints.
            t = torch.tensor(t, dtype=torch.get_default_dtype())
        return 0.5 * (1 + torch.cos(t * math.pi))

    def apply_guidance(self, xt, vt, grad_v, cond, t, values, eps=1e-8):
        """Add the time-scheduled guidance term to a velocity prediction.

        Computes ``vt + grad_v * scale * schedule_fn(t)``.

        Args:
            xt: Unused; accepted for interface compatibility.
            vt: Velocity tensor to be guided.
            grad_v: Guidance gradient tensor; presumably the same shape as
                ``vt`` (TODO confirm with callers).
            cond: Unused; accepted for interface compatibility.
            t: A scalar time (Python number or 0-dim tensor), a tensor that
                already broadcasts against ``grad_v``, or per-sample times of
                shape ``(B,)`` with ``B == grad_v.shape[0]``.
            values: Unused; accepted for interface compatibility.
            eps: Unused; accepted for interface compatibility.

        Returns:
            The guided velocity tensor.
        """
        weight = self.schedule_fn(t)
        if weight.dim() > 0:
            # Dimensioned CPU tensors cannot be combined with CUDA tensors.
            weight = weight.to(grad_v.device)
        if (
            weight.dim() == 1
            and grad_v.dim() > 1
            and weight.shape[0] == grad_v.shape[0]
        ):
            # A per-sample weight of shape (B,) would otherwise broadcast
            # against the LAST axis of grad_v; align it with the batch axis.
            weight = weight.view(-1, *([1] * (grad_v.dim() - 1)))
        return vt + grad_v * self.scale * weight

    def _compute_z(self, x, cond, t):
        """Return a strictly positive per-sample scalar ``z`` of shape ``(B,)``.

        Without ``model_z``, this is a vector of ones on ``x``'s device, using
        ``x``'s floating dtype (or the default dtype for non-float ``x``).
        Otherwise, the output of ``model_z(x, cond, t)`` is taken at the last
        horizon step and exponentiated.
        """
        if self.model_z is None:
            dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
            return torch.ones(x.shape[0], dtype=dtype, device=x.device)
        z_pred = self.model_z(x, cond, t)  # expected (B, horizon, 1)
        # Last horizon step -> exp; the floor keeps z > 0 even if exp
        # underflows to 0.
        return z_pred.squeeze(-1)[:, -1].exp().clamp(min=1e-8)  # (B,)
