import numbers
from collections import defaultdict
from typing import Dict, Any

import numpy as np
import pandas as pd
from impugen.metrics.high_order import *


def _is_numeric(x: Any) -> bool:
    """스칼라·넘파이 배열·판다스 Series/DataFrame 모두 ‘숫자형’인지 판별."""
    if isinstance(x, numbers.Number):
        return True
    if isinstance(x, (np.ndarray, pd.Series)):
        return np.issubdtype(x.dtype, np.number)
    if isinstance(x, pd.DataFrame):
        return all(np.issubdtype(dt, np.number) for dt in x.dtypes)
    return False


def _add_dict(a: Dict, b: Dict):
    """dict a += b  (숫자/배열/DataFrame은 요소별 합, nested dict은 재귀)."""
    for k, v in b.items():
        if k not in a:
            a[k] = v.copy() if isinstance(v, (pd.DataFrame, np.ndarray)) else v
            continue
        if isinstance(v, dict):
            _add_dict(a[k], v)
        elif _is_numeric(v):
            a[k] = a[k] + v
    return a


def _div_dict(a: Dict, n: int):
    """dict a / n (숫자/배열/DataFrame은 요소별 나눗셈, nested dict 재귀)."""
    for k, v in a.items():
        if isinstance(v, dict):
            _div_dict(v, n)
        elif _is_numeric(v):
            a[k] = v / n
    return a


class ClassConditionalMetric:
    """Mixin that evaluates a metric separately for each class of the target.

    It relies on attributes and methods it does not define itself. These are
    presumably provided by the co-base metric class (verify against it):

    - ``self.transform.target_column``: name of the target column, or ``None``.
    - ``self.transform.numerical_columns``: collection of numerical column names.
    - ``self._evaluation(subset, *args, **kwargs)``: returns a (possibly
      nested) dict of results for one row subset of ``data``.
    """

    def evaluation(self, data: pd.DataFrame, *args, **kwargs) -> Dict[str, Any]:
        """Run ``self._evaluation`` per target class and add a macro average.

        Args:
            data: Frame containing the column ``self.transform.target_column``.
            *args, **kwargs: Forwarded unchanged to ``self._evaluation``.

        Returns:
            ``{str(class): result, ..., "macro_avg": averaged}``. Numeric leaves
            of ``averaged`` are summed over classes and then divided by the
            number of evaluated classes. A key missing from some class results
            is still divided by that total count. Non-numeric leaves keep the
            value from the first class in which the key appears.

            An empty dict is returned when any of these holds:
            - there is no target column;
            - the target is a numerical column;
            - the target has no non-null values.
        """
        tgt = self.transform.target_column
        # Class-conditional evaluation only makes sense for a categorical target.
        if tgt is None or tgt in self.transform.numerical_columns:
            return {}

        classes = data[tgt].dropna().unique()
        per_class: Dict[str, Any] = {}

        # Plain dict rather than ``defaultdict(lambda: 0)``. _add_dict never
        # relies on a default value, and this accumulator is handed back to
        # callers as ``per_class["macro_avg"]``. There, a lambda-backed
        # defaultdict cannot be pickled and would silently report 0 for any
        # missing metric key.
        macro_sum: Dict[str, Any] = {}
        valid = 0

        for c in classes:
            subset = data[data[tgt] == c]
            if subset.empty:
                continue
            res = self._evaluation(subset, *args, **kwargs)
            per_class[str(c)] = res
            _add_dict(macro_sum, res)
            valid += 1

        if valid == 0:
            return {}

        per_class["macro_avg"] = _div_dict(macro_sum, valid)
        return per_class


class MacroAlphaPrecision(ClassConditionalMetric, AlphaPrecision):
    """``AlphaPrecision`` evaluated per target class, plus a macro average.

    ``ClassConditionalMetric`` precedes ``AlphaPrecision`` in the MRO, so its
    ``evaluation`` is the one used. It delegates each class subset to
    ``self._evaluation``, which is presumably implemented by ``AlphaPrecision``
    (verify against ``impugen.metrics.high_order``).
    """
