from collections import OrderedDict

import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, root_mean_squared_error, accuracy_score, balanced_accuracy_score
from collections import OrderedDict
from .base import EvalBase


class ImputationScore(EvalBase):
    """
    Evaluate imputation performance on categorical and numerical features.

    It requires the ground-truth DataFrame (gt) to have valid (non-NaN) values
    only at the positions where imputation was needed. All other positions should
    remain NaN, ensuring that the metrics are only computed on the originally
    missing cells.

    Attributes:
        cat_metrics (list):
            Metric functions applied to categorical features
            (accuracy_score, balanced_accuracy_score). Reassigned on every
            call to `_evaluation`.
        num_metrics (list):
            Metric functions applied to standardized numerical features
            (mean_absolute_error, root_mean_squared_error). Reassigned on every
            call to `_evaluation`.
        num_outlier_threshold (float):
            Ground-truth numerical cells whose standardized value lies outside
            [-num_outlier_threshold, num_outlier_threshold] are excluded from
            the numerical metrics. Defaults to 100.

    Inherits:
        EvalBase: A base class that provides an `_evaluation` hook.
    """

    # Cut-off (in standardized units) beyond which ground-truth numerical
    # cells are ignored. Kept as a class attribute so subclasses/instances
    # can override it without changing the `_evaluation` signature.
    num_outlier_threshold = 100

    def _evaluation(self, gt: pd.DataFrame, fake: pd.DataFrame) -> dict:
        """
        Compare ground-truth values against the corresponding imputed values.

        Metrics are computed only at positions where `gt` holds a non-NaN
        value; those entries mark the cells that were originally missing and
        therefore need to be scored.

        Args:
            gt (pd.DataFrame):
                The "ground-truth" DataFrame. Must have valid (non-NaN) entries only where
                imputation results need to be compared. All other positions should be NaN.
            fake (pd.DataFrame):
                The imputed DataFrame to evaluate. Must have the same number of rows as `gt`.

        Returns:
            dict: An OrderedDict with
                - "<metric_name>" for each categorical metric and
                  "<metric_name>.zscore" for each numerical metric: the mean of
                  the per-column scores (NaN if no column had any valid entry);
                - "column_results": an OrderedDict mapping
                  "<column>.<metric_name>[.zscore]" to the per-column score.

        Raises:
            AssertionError: If `gt` and `fake` have a different number of rows.
        """
        assert len(gt) == len(fake), f"Mismatch in number of rows: {len(gt)} vs {len(fake)}"
        self.cat_metrics = [accuracy_score, balanced_accuracy_score]
        self.num_metrics = [mean_absolute_error, root_mean_squared_error]

        # Transform ground-truth and fake data (splitting into numerical/categorical arrays).
        gt_num, gt_cat = self.transform.transform(gt, scaler='standard', onehot=False)
        fake_num, fake_cat = self.transform.transform(fake, scaler='standard', onehot=False)

        # Exclude extreme standardized ground-truth values from scoring by
        # turning them into NaN (only `gt` is filtered; `fake` is left as-is).
        outlier = np.abs(gt_num) > self.num_outlier_threshold
        gt_num[outlier] = np.nan

        # Valid (non-missing) ground-truth positions. pd.isna matches np.isnan
        # on float arrays but also works on object-dtype arrays.
        mask_cat = ~pd.isna(gt_cat)
        mask_num = ~pd.isna(gt_num)

        scores = OrderedDict()
        column_results = OrderedDict()
        # (metrics, gt array, fake array, valid mask, key suffix, column names)
        groups = (
            (self.cat_metrics, gt_cat, fake_cat, mask_cat, '',
             self.transform.categorical_columns),
            (self.num_metrics, gt_num, fake_num, mask_num, '.zscore',
             self.transform.numerical_columns),
        )
        for metric_list, arr_gt, arr_fake, mask, suffix, col_names in groups:
            for metric in metric_list:
                key = metric.__name__ + suffix
                metric_scores = []
                # Columns are iterated via the transposed arrays.
                for col_gt, col_fake, col_mask, col in zip(arr_gt.T, arr_fake.T, mask.T, col_names):
                    if not col_mask.any():
                        # Nothing to score in this column.
                        continue
                    score = metric(col_gt[col_mask], col_fake[col_mask])
                    metric_scores.append(score)
                    column_results[f'{col}.{key}'] = score
                scores[key] = np.mean(metric_scores) if metric_scores else np.nan
        scores['column_results'] = column_results
        return scores

