from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Rule:
    """A single objective: a callable producing (prediction, target) plus a loss.

    Calling a ``Rule`` picks its inputs out of the keyword arguments it is
    given, runs ``f`` on them, flattens the resulting pair and scores it
    with ``loss_function``.
    """

    # Callable returning a ``(prediction, target)`` pair; invoked with the
    # keyword arguments selected through ``arguments``.
    f: Callable[..., tuple[Any, Any]]
    # Maps each parameter name of ``f`` to the key under which its value is
    # looked up in the keyword arguments passed to ``__call__``.
    arguments: dict[str, str]
    # Called as ``loss_function(prediction_2d, target_1d)`` (see __call__).
    loss_function: Callable[[Any, Any], Any]

    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        """Evaluate the rule on the supplied inputs.

        Args:
            **kwargs: Pool of named inputs. Only the keys listed as values of
                ``self.arguments`` are used; any other keys are ignored.

        Returns:
            ``{'loss': ..., 'prediction': ..., 'target': ...}`` where
            ``prediction`` and ``target`` are exactly what ``f`` returned
            (not reshaped) and ``loss`` is the ``loss_function`` result.

        Raises:
            ValueError: If a key required by ``self.arguments`` is missing
                from ``kwargs``; the message names the missing key.
        """
        try:
            arguments = {
                name: kwargs[key] for name, key in self.arguments.items()
            }
        except KeyError as e:
            raise ValueError(f'Missing input: {e.args[0]!r}') from e

        prediction, target = self.f(**arguments)

        # Flatten the target to 1-D (N,) and fold the prediction into one row
        # per target element, (N, -1). This matches class-index losses such
        # as cross-entropy — assumed usage, confirm against callers.
        # NOTE: assumes torch-like tensors exposing ``reshape``. ``reshape``
        # (rather than ``view``) also works on non-contiguous tensors, e.g.
        # after a transpose, and returns a view whenever one is possible.
        target_flat = target.reshape(-1)
        prediction_flat = prediction.reshape(target_flat.shape[0], -1)
        loss = self.loss_function(prediction_flat, target_flat)

        return {'loss': loss, 'prediction': prediction, 'target': target}


@dataclass
class System:
    """A named collection of rules evaluated together on the same inputs."""

    # Rules keyed by name; results of ``__call__`` follow this dict's
    # iteration (insertion) order.
    rules: dict[str, Rule] = field(default_factory=dict)

    def add_rule(
        self,
        name: str,
        f: Callable[..., tuple[Any, Any]],
        loss_function: Callable[[Any, Any], Any],
        /,
        **arguments: str,
    ) -> None:
        """Register a rule under ``name``, replacing any existing one.

        Args:
            name: Key for the rule in ``self.rules`` and in the results of
                ``__call__``.
            f: Callable returning a ``(prediction, target)`` pair.
            loss_function: Loss applied to the flattened prediction/target.
            **arguments: Maps each parameter name of ``f`` to the input key
                its value is read from when the system is called.
        """
        # TODO add introspection
        self.rules[name] = Rule(f, arguments, loss_function)

    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        """Evaluate every registered rule on the same keyword inputs.

        Returns:
            ``{rule_name: rule(**kwargs)}`` for each entry of ``self.rules``.

        Raises:
            Any exception raised by an individual rule propagates unchanged.
        """
        return {name: rule(**kwargs) for name, rule in self.rules.items()}
