import warnings
from collections.abc import Sequence
from copy import deepcopy
import os
from filelock import FileLock

import torch

from src.evaluation.smoothing_measures.mad import mad_batch
from src.evaluation.smoothing_measures.node_similarity import wu_smoothness
from src.evaluation.smoothing_measures.dirichlet_energy import dirichlet_energy_pyg
from src.evaluation.smoothing_measures.mad_gap import mad_madgap_batch
from src.utils.path_io import get_path_up_to

# Silence every warning process-wide as a side effect of importing this module.
warnings.filterwarnings("ignore")

ROOT_PATH = get_path_up_to(__file__, "src")
# Base path of the file lock used by Measurements; ".lock" is appended at use site.
LOCK_PATH: str = os.path.join(ROOT_PATH, "runs", "measure_lock")

# Registry: the 'function' name used in a measure config -> its implementation.
# "jacobian" is registered but has no implementation (None).
MEASURE_MAPPING = dict(
    dirichlet=dirichlet_energy_pyg,
    mad_gap=mad_madgap_batch,
    mad=mad_batch,
    jacobian=None,
    wu_smoothness=wu_smoothness,
)


class Measurements:
    """Compute a configurable set of measures on a batch and its embeddings.

    Each configured measure selects an implementation from ``MEASURE_MAPPING``
    and supplies the keyword arguments it is called with.
    """

    def __init__(self, measures: dict):
        """
        Store the measure configuration.

        Args:
            measures (dict): Maps an output name to a dict of parameters for that
                measure. The ``'function'`` entry selects the implementation from
                ``MEASURE_MAPPING`` (defaults to the output name itself when
                omitted); every other entry is forwarded as a keyword argument.
                A ``None`` parameter dict is treated as empty.
        """

        self.metric_params = measures

    def __call__(self, batch, embeddings: torch.Tensor) -> dict[str, float]:
        """Alias for :meth:`calc_test_measure`."""
        return self.calc_test_measure(batch, embeddings)

    @staticmethod
    def _resolve(measure: str, params):
        """Return ``(function, kwargs)`` for one measure config entry, validating the name.

        :param measure: Configured output name of the measure.
        :param params: Parameter dict for the measure, or ``None`` for no parameters.
        :raises KeyError: If the selected function name is not in ``MEASURE_MAPPING``.
        :raises NotImplementedError: If the name is registered with ``None``.
        """
        # Deep copy so popping 'function' (or a callee mutating nested values)
        # never alters the stored configuration.
        kwargs = deepcopy(params) if params is not None else {}
        name = kwargs.pop('function', measure)
        if name not in MEASURE_MAPPING:
            raise KeyError(
                f"Unknown measure function {name!r} for measure {measure!r}; "
                f"available: {sorted(MEASURE_MAPPING)}"
            )
        fn = MEASURE_MAPPING[name]
        if fn is None:
            raise NotImplementedError(
                f"Measure function {name!r} (requested by {measure!r}) is not implemented"
            )
        return fn, kwargs

    def calc_test_measure(self,
                          batch,
                          embeddings: torch.Tensor) -> dict[str, float]:
        """
        Compute every configured measure for ``batch`` and ``embeddings``.

        Each measure function is called as ``fn(batch, embeddings=embeddings, **params)``.
        If it returns a dict, its items are merged into the result (later measures
        overwrite keys of the same name); otherwise the value is stored under the
        measure's configured name.

        :param batch: Passed unchanged as the first argument of every measure function.
        :param embeddings: Passed as the ``embeddings`` keyword argument.
        :return: Dict of measure results plus ``'n'``, the value of ``len(batch)``.
        :raises KeyError: If a measure selects a function not in ``MEASURE_MAPPING``.
        :raises NotImplementedError: If a selected function is registered as ``None``.
        """

        # Resolve all measures first so a configuration error surfaces before any
        # measure is computed and before the lock is acquired.
        resolved = [(measure, *self._resolve(measure, params))
                    for measure, params in self.metric_params.items()]

        measures = {}

        # File lock: only one holder of this lock file computes measures at a time.
        with FileLock(LOCK_PATH + ".lock"):

            for measure, fn, kwargs in resolved:
                print(f"[MEASURE]: Calculate {measure}")

                res = fn(batch, embeddings=embeddings, **kwargs)
                if isinstance(res, dict):
                    measures.update(res)
                else:
                    measures[measure] = res

            # NOTE(review): what len() counts depends on the batch type — confirm it
            # is the intended sample count.
            measures['n'] = len(batch)

        return measures
