import numpy as np



class MetricLogger:
    """Accumulate per-sample evaluation metrics and export them to CSV.

    Each call to :meth:`update` records one sample: its name is appended to
    ``self.names`` and exactly one float is appended to every list in
    ``self.metrics``, so index ``i`` of every list refers to the same sample.
    Optional metrics that were not supplied for a sample are stored as NaN to
    keep the lists aligned; :meth:`average` skips those NaN entries.
    """

    # (CSV header, metrics key) pairs in the column order written by save_csv.
    # 'normal_consistency_mesh' is tracked and averaged but not exported.
    _CSV_COLUMNS = (
        ('IoU', 'iou'),
        ('Chamfer', 'chamfer'),
        ('F1', 'f1'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('MSE', 'mse'),
        ('MAE', 'mae'),
        ('Normal Consistency', 'normal_consistency'),
        ('EMD', 'emd'),
        ('Hausdorff', 'hausdorff'),
    )

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded samples and start with empty metric lists."""
        self.metrics = {
            'iou': [],
            'chamfer': [],
            'emd': [],
            'f1': [],
            'precision': [],
            'recall': [],
            'mse': [],
            'mae': [],
            'normal_consistency': [],
            'normal_consistency_mesh': [],
            'hausdorff': [],
        }
        self.names = []

    def update(self, name, iou, chamfer, f1, precision, recall, mse, mae,
               normal_consistency, emd=None, hausdorff=None,
               normal_consistency_mesh=None):
        """Record the metrics of one sample.

        Args:
            name: sample identifier; stored as ``str(name)``.
            iou, chamfer, f1, precision, recall, mse, mae, normal_consistency:
                metric values, each converted with ``float``.
            emd, hausdorff, normal_consistency_mesh: optional metric values;
                ``None`` is stored as NaN so every list stays aligned with
                ``self.names``.

        Raises:
            TypeError / ValueError: a value cannot be converted to float. In
                that case nothing is recorded.
        """
        values = {
            'iou': float(iou),
            'chamfer': float(chamfer),
            'f1': float(f1),
            'precision': float(precision),
            'recall': float(recall),
            'mse': float(mse),
            'mae': float(mae),
            'normal_consistency': float(normal_consistency),
        }
        optional = {
            'emd': emd,
            'hausdorff': hausdorff,
            'normal_consistency_mesh': normal_consistency_mesh,
        }
        for key, value in optional.items():
            values[key] = float('nan') if value is None else float(value)

        # All conversions happen before any mutation, so a bad input cannot
        # leave the lists with different lengths.
        self.names.append(str(name))
        for key, value in values.items():
            self.metrics[key].append(value)

    def update_metrics(self, metrics, name=None):
        """Append one value to each metric named in *metrics*.

        Args:
            metrics: mapping of metric key -> value convertible to float. Every
                key must already exist in ``self.metrics``.
            name: optional sample name appended to ``self.names`` so CSV rows
                can be labelled.

        Raises:
            KeyError: a key of *metrics* is not a known metric. Nothing is
                recorded in that case.
        """
        values = {}
        for key in metrics:
            if key not in self.metrics:
                raise KeyError(key)
            values[key] = float(metrics[key])
        if name is not None:
            self.names.append(str(name))
        for key, value in values.items():
            self.metrics[key].append(value)

    def average(self):
        """Return the mean of every metric, ignoring NaN (missing) entries.

        A metric with no recorded non-NaN values averages to 0.0.
        """
        result = {}
        for key, values in self.metrics.items():
            valid = [v for v in values if not np.isnan(v)]
            result[key] = float(np.mean(valid)) if valid else 0.0
        return result

    def save_csv(self, filename="metrics.csv"):
        """Write per-sample rows plus an 'Average' row to a timestamped CSV.

        The timestamp is inserted before the extension, e.g. ``out/metrics.csv``
        becomes ``out/metrics_YYYYmmdd_HHMMSS.csv``. Rows are indexed by sample
        position; a cell whose list is shorter than the row count is left
        empty.

        Returns:
            The path of the file written.
        """
        import csv
        import os
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # splitext strips only the final extension, so dots elsewhere in the
        # path (e.g. "./runs/v1.2/metrics.csv") are preserved.
        stem = os.path.splitext(filename)[0]
        filename = f"{stem}_{timestamp}.csv"

        keys = [key for _, key in self._CSV_COLUMNS]
        n_rows = max([len(self.names)] + [len(self.metrics[k]) for k in keys])

        def cell(seq, i):
            # Empty cell when this list has fewer entries than the row count.
            return seq[i] if i < len(seq) else ''

        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Name'] + [header for header, _ in self._CSV_COLUMNS])
            for i in range(n_rows):
                writer.writerow(
                    [cell(self.names, i)] + [cell(self.metrics[k], i) for k in keys]
                )
            avg = self.average()
            writer.writerow(['Average'] + [avg[k] for k in keys])
        return filename