import torch
from torchmetrics import Metric
from mld.utils.physics import all_physics_metrics, all_physics_metrics_fast

class PhysicsMetrics(Metric):
    """Running per-sample averages of physics metrics computed from meshes and joints.

    Every call to ``update`` asks ``all_physics_metrics`` for six batch values.
    The first three (penetration, float distance, skate) are each multiplied by
    the batch size ``joints.shape[0]`` and added to a ``*_sum`` state.
    ``compute`` divides every sum by the total number of samples seen. All
    states are registered with ``dist_reduce_fx="sum"``.

    NOTE(review): ``skate_std_sum``, ``skate_total_sum`` and ``skate_perc_sum``
    are registered and reported by ``compute`` but never incremented in
    ``update``. The ``skate_std``, ``skate_total`` and ``skate_percentage``
    entries therefore stay at zero. Confirm whether that is intentional.
    """

    def __init__(self, dist_sync_on_step: bool = True) -> None:
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "Physics Metrics"

        # Integer sample counter; the shared denominator used by compute().
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        # Float accumulators. A fresh zero tensor is created for each state so
        # that the in-place `+=` in update() never aliases two states.
        accumulator_names = (
            "penetration_sum",
            "float_sum",
            "skate_sum",
            "skate_std_sum",
            "skate_total_sum",
            "skate_perc_sum",
        )
        for state_name in accumulator_names:
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def compute(self) -> dict:
        """Return each accumulated sum divided by the number of samples seen.

        Returns:
            dict mapping "penetration", "float", "skate", "skate_std",
            "skate_total" and "skate_percentage" to tensors.
        """
        denominator = self.count
        labelled_sums = (
            ("penetration", self.penetration_sum),
            ("float", self.float_sum),
            ("skate", self.skate_sum),
            ("skate_std", self.skate_std_sum),
            ("skate_total", self.skate_total_sum),
            ("skate_percentage", self.skate_perc_sum),
        )
        return {label: total / denominator for label, total in labelled_sums}

    def update(self, verts: torch.Tensor, joints: torch.Tensor) -> None:
        """Fold one batch into the running sums.

        Args:
            verts: mesh vertices passed straight to ``all_physics_metrics``
                (layout not visible here — TODO confirm).
            joints: joint positions; dimension 0 is taken as the batch size,
                and its device is forwarded to the helper.
        """
        batch_size = joints.shape[0]
        self.count += batch_size

        # All six results are unpacked so an unexpected arity still raises;
        # only the first three are accumulated (see class NOTE).
        (
            penetration,
            float_dist,
            skate,
            _skate_std,
            _skate_total,
            _skate_perc,
        ) = all_physics_metrics(verts, joints, device=joints.device)

        # The helper's values are scaled by batch size so compute() yields a
        # sample-weighted mean across batches of different sizes.
        self.penetration_sum += penetration * batch_size
        self.float_sum += float_dist * batch_size
        self.skate_sum += skate * batch_size

class FastPhysicsMetrics(Metric):
    """Running physics metrics computed from joints only, via ``all_physics_metrics_fast``.

    ``update`` adds the helper's penetration, float-distance and skate values
    directly to their ``*_sum`` states, and adds the total of ``skate_total``.
    ``compute`` divides every sum by the number of samples seen. All states are
    registered with ``dist_reduce_fx="sum"``.

    NOTE(review): unlike ``PhysicsMetrics``, these values are NOT multiplied by
    the batch size before accumulation, yet ``compute`` divides by the total
    sample count. That is only a correct mean if ``all_physics_metrics_fast``
    returns per-batch sums rather than averages — confirm against the helper.
    ``skate_std_sum`` and ``skate_perc_sum`` are never incremented, so their
    reported values stay at zero.
    """

    def __init__(self, dist_sync_on_step: bool = True) -> None:
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "Fast Physics Metrics"

        # Integer sample counter; the shared denominator used by compute().
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        # Float accumulators, each with its own fresh zero tensor so the
        # in-place updates never alias one another.
        accumulator_names = (
            "penetration_sum",
            "float_sum",
            "skate_sum",
            "skate_std_sum",
            "skate_total_sum",
            "skate_perc_sum",
        )
        for state_name in accumulator_names:
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def compute(self) -> dict:
        """Return each accumulated sum divided by the number of samples seen.

        Returns:
            dict mapping "penetration", "float", "skate", "skate_std",
            "skate_total" and "skate_percentage" to tensors.
        """
        denominator = self.count
        labelled_sums = (
            ("penetration", self.penetration_sum),
            ("float", self.float_sum),
            ("skate", self.skate_sum),
            ("skate_std", self.skate_std_sum),
            ("skate_total", self.skate_total_sum),
            ("skate_percentage", self.skate_perc_sum),
        )
        return {label: total / denominator for label, total in labelled_sums}

    def update(self, joints: torch.Tensor) -> None:
        """Fold one batch of joints into the running sums.

        Args:
            joints: joint positions; dimension 0 is counted as the number of
                samples, and its device is forwarded to the helper.
        """
        self.count += joints.shape[0]

        # Unpack all six results so a change in the helper's arity still fails loudly.
        (
            penetration,
            float_dist,
            skate,
            _skate_std,
            skate_total,
            _skate_perc,
        ) = all_physics_metrics_fast(joints, device=joints.device)

        self.penetration_sum += penetration
        self.float_sum += float_dist
        self.skate_sum += skate
        # skate_total needs an explicit reduction before it can be added to the
        # 0-d accumulator (presumably it holds one entry per sample — confirm).
        self.skate_total_sum += skate_total.sum()