import torch
from torchmetrics import Metric
import smplx
from mld.transforms.joints2rots import config
from drag_dev.shape_optimization.coap_selfpene_loss import COAPSelfPenetrationLoss

class SelfPeneMetrics(Metric):
    """Accumulate COAP-based self-penetration statistics over SMPL outputs.

    States (each reduced with ``sum`` across processes on sync):
        count: total of ``stats['total_frames']`` over all update items.
        penetration_sum: total of ``stats['all_penetration_losses'].sum()``.
        max_penetration_sum: total of ``stats['max_penetration']``, added once
            per update item.
        penetration_frames: total of ``stats['penetration_frames']``.

    ``stats`` is the dict returned by
    ``COAPSelfPenetrationLoss.compute_loss_with_stats``.
    """

    def __init__(
        self,
        dist_sync_on_step: bool = True,
        *,
        penetration_threshold: float = 0.01,
    ) -> None:
        """Create the metric.

        Args:
            dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
            penetration_threshold: Forwarded as ``penetration_threshold`` to
                ``COAPSelfPenetrationLoss.compute_loss_with_stats``. Defaults
                to the previously hard-coded value ``0.01``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "Self Penetration Metrics"

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("penetration_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("max_penetration_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("penetration_frames", default=torch.tensor(0), dist_reduce_fx="sum")

        self.penetration_threshold = penetration_threshold

    def compute(self) -> dict:
        """Return the accumulated statistics, each normalized by ``count``.

        Returns:
            dict with keys ``avg_penetration``, ``max_penetration`` and
            ``penetration_rate``. If no frames were accumulated
            (``count == 0``), the tensor division yields NaN.
        """
        # NOTE(review): max_penetration_sum grows once per update item, but it
        # is divided by the total frame count here. Confirm that this
        # normalization is the intended one.
        return {
            "avg_penetration": self.penetration_sum / self.count,
            "max_penetration": self.max_penetration_sum / self.count,
            "penetration_rate": self.penetration_frames / self.count,
        }

    def _ensure_coap_loss(self, device) -> None:
        """Build the COAP loss (and its SMPL body model) once, on ``device``."""
        if hasattr(self, "coap_loss"):
            return
        smpl_model = smplx.create(
            config.SMPL_MODEL_DIR,
            model_type="smpl",
            gender="neutral",
            ext="pkl",
            batch_size=1,
        ).to(device)
        self.coap_loss = COAPSelfPenetrationLoss(smpl_model=smpl_model, device=device)

    def update(self, smpl_outputs_list) -> None:
        """Accumulate statistics for every SMPL output in ``smpl_outputs_list``.

        Args:
            smpl_outputs_list: Sequence of SMPL forward outputs. Each must
                expose ``.vertices``; the first one's device is used when the
                COAP loss is built lazily.
        """
        # An empty batch contributes nothing. Returning early also avoids an
        # IndexError on smpl_outputs_list[0] below.
        if not smpl_outputs_list:
            return

        self._ensure_coap_loss(smpl_outputs_list[0].vertices.device)

        for smpl_output in smpl_outputs_list:
            stats = self.coap_loss.compute_loss_with_stats(
                smpl_output,
                penetration_threshold=self.penetration_threshold,
            )

            self.count += stats['total_frames']
            self.penetration_sum += stats['all_penetration_losses'].sum()
            self.max_penetration_sum += stats['max_penetration']
            self.penetration_frames += stats['penetration_frames']

            # Release cached, unused CUDA blocks between items (kept from the
            # original). This is a no-op when CUDA was never initialized.
            torch.cuda.empty_cache()