import numpy as np
from sklearn.metrics import accuracy_score

# def preprocess_logits_for_metrics(outputs, labels):
#     """Preprocess prediction results
    
#     Returns:
#         tuple: (classification_predictions, regression_predictions)
#         - classification_predictions: Predictions for classification task (if exists)
#         - regression_predictions: Predictions for regression task (if exists)
#     """
#     print(len(outputs), outputs[0].shape, outputs[1].shape, outputs[2].shape)
#     if hasattr(outputs, 'classification_logits') and outputs.classification_logits is not None:
#         classification_preds = outputs.classification_logits.argmax(dim=-1)
#     else:
#         classification_preds = None
        
#     regression_preds = outputs.regression_logits if hasattr(outputs, 'regression_logits') else None
    
#     print(classification_preds.shape if classification_preds is not None else None)
#     print(regression_preds.shape if regression_preds is not None else None)
    
#     return (classification_preds, regression_preds)


def preprocess_logits_for_metrics(outputs, labels):
    """Reduce model logits to predicted class indices before metric accumulation.

    Only the classification logits are used; any other model outputs
    (e.g. regression heads) are discarded here, so the result is a single
    tensor of class indices rather than a tuple.

    Args:
        outputs: Either a single logits tensor, or a tuple/list of model
            outputs whose first element is the classification logits.
            Presumably the logits end in a class dimension — TODO confirm
            against the model.
        labels: Unused. Presumably accepted because the metrics hook that
            calls this function passes it positionally.

    Returns:
        Predicted class indices: ``argmax`` over the last dimension of the
        classification logits.
    """
    # A lone tensor must not be indexed: ``outputs[0]`` would pick the first
    # batch element instead of the classification logits. Tuples/lists (and
    # other containers without ``argmax``) still use their first element.
    logits = outputs if hasattr(outputs, "argmax") else outputs[0]
    return logits.argmax(dim=-1)


def compute_metrics(eval_preds, ignore_index=-100):
    """Calculate evaluation metrics.

    Args:
        eval_preds: Pair of (predictions, labels) that unpacks into two items.
            - predictions: Class indices as produced by
              ``preprocess_logits_for_metrics``, or None to skip
              classification metrics.
            - labels: Either the classification labels directly, or a tuple
              whose first element is the classification labels, e.g.
              ``(classification_labels, regression_labels)``.
        ignore_index: Label value marking positions excluded from scoring.

    Returns:
        dict: ``{"classification_error": float}`` holding the fraction of
        non-ignored positions where the prediction differs from the label.
        The dict is empty when predictions is None or every label equals
        ``ignore_index``.
    """
    predictions, labels = eval_preds

    metrics = {}

    # Calculate classification metrics
    if predictions is not None:
        # Coerce to arrays so plain Python lists get element-wise comparison
        # and boolean-mask indexing (a list compared to an int yields a single
        # bool, which has no ``.any()``).
        classification_preds = np.asarray(predictions)
        classification_labels = np.asarray(
            labels[0] if isinstance(labels, tuple) else labels
        )
        valid_mask = classification_labels != ignore_index

        if valid_mask.any():
            valid_preds = classification_preds[valid_mask]
            valid_labels = classification_labels[valid_mask]
            # Error rate = fraction of mismatched positions (i.e. 1 - accuracy).
            metrics["classification_error"] = float(
                np.mean(valid_preds != valid_labels)
            )

    # NOTE: regression metrics are disabled. preprocess_logits_for_metrics
    # currently returns only classification predictions, so `regression_preds`
    # is not defined here. The sketch below is kept for reference.
    # if regression_preds is not None:
    #     regression_labels = labels[-1] if isinstance(labels, tuple) else None
    #     if regression_labels is not None:
    #         valid_mask = np.isfinite(regression_labels) & (regression_labels != ignore_index)
    #
    #         if valid_mask.any():
    #             valid_preds = regression_preds[valid_mask]
    #             valid_labels = regression_labels[valid_mask]
    #
    #             # Calculate MSE
    #             mse = np.mean((valid_preds - valid_labels) ** 2)
    #             metrics["regression_mse"] = mse
    #
    #             # Threshold-based error rate (if needed)
    #             errors = (valid_labels != np.round(valid_preds)).any(axis=-1).mean()
    #             metrics["regression_error"] = errors

    return metrics