import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

from mld.utils.temos_utils import remove_padding
from .utils import calculate_skating_ratio, calculate_trajectory_error, control_l2

class DragMetrics(Metric):
    """Running evaluation metrics for drag-controlled motion.

    ``compute`` returns a dict of two scalar tensors:

    * ``"collision_rate"`` -- mean of the per-``update`` collision values.
    * ``"drag_error"`` -- mean of the per-``update`` drag-error values.

    Both are averaged over the number of ``update`` calls (not over samples).
    In distributed runs the sums and the call count are summed across
    processes before averaging.

    NOTE(review): ``update`` currently ignores ``motion`` and
    ``drag_control`` and records fixed placeholder values (0.1 and 0.2).
    TODO: replace these with real per-batch measurements.
    """

    # ``update`` only adds to the running sums and never reads them, so
    # torchmetrics may use its cheaper per-batch ``forward`` path.
    full_state_update = False

    def __init__(self, **kwargs):
        """Register the running-sum states.

        Args:
            **kwargs: Optional; forwarded unchanged to ``torchmetrics.Metric``
                (e.g. ``dist_sync_on_step``).
        """
        super().__init__(**kwargs)
        # Store sums plus a call counter (rather than the latest value) so the
        # result accumulates over batches and combines with a plain "sum"
        # across processes. Without a reduction function, synced per-process
        # values would not be combined into a single scalar.
        self.add_state("collision_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("drag_error_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_updates", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, motion, drag_control):
        """Accumulate metric values for one batch.

        Args:
            motion: Motion batch to evaluate (currently unused).
            drag_control: Drag-control inputs for the batch (currently unused).
        """
        # TODO: placeholder values -- compute real per-batch collision and
        # drag-error measurements from ``motion`` and ``drag_control``.
        batch_collision = 0.1
        batch_drag_error = 0.2

        # In-place adds keep the states on the metric's device; rebinding them
        # to freshly built tensors would silently move them to CPU.
        self.collision_sum += batch_collision
        self.drag_error_sum += batch_drag_error
        self.num_updates += 1

    def compute(self):
        """Return the accumulated metrics averaged over ``update`` calls.

        Returns:
            dict: ``{"collision_rate": Tensor, "drag_error": Tensor}``, each a
            scalar tensor. Both are 0 if ``update`` has never been called.
        """
        # clamp(min=1) reports 0 instead of NaN (0/0) before the first update.
        num_updates = self.num_updates.clamp(min=1)
        return {
            "collision_rate": self.collision_sum / num_updates,
            "drag_error": self.drag_error_sum / num_updates,
        }