import glob
from unittest import TestCase
from unittest.mock import patch

import pytest
import requests
import yaml

from evaluate.hub import push_to_hub
from tests.test_metric import DummyMetric


# Metadata that TestHub.test_push_metric_required_arguments expects to be handed
# to ``metadata_update`` when ``push_to_hub`` receives only the required arguments.
_minimum_result = {
    "task": {"type": "dummy-task"},
    "dataset": {"type": "dataset_type", "name": "dataset_name"},
    "metrics": [{"type": "dummy_metric", "value": 1.0, "name": "Pretty Metric Name"}],
}
minimum_metadata = {"model-index": [{"results": [_minimum_result]}]}

# Metadata that TestHub.test_push_metric_extra_arguments expects to be handed to
# ``metadata_update`` when the optional task / dataset / metric arguments are
# supplied on top of the required ones.
_extras_task = {"type": "dummy-task", "name": "task_name"}
_extras_dataset = {
    "type": "dataset_type",
    "name": "dataset_name",
    "config": "fr",
    "split": "test",
    "revision": "abc",
    "args": {"a": 1, "b": 2},
}
_extras_metric = {
    "type": "dummy_metric",
    "value": 1.0,
    "name": "Pretty Metric Name",
    "config": "default",
    "args": {"hello": 1, "world": 2},
}
extras_metadata = {
    "model-index": [
        {"results": [{"task": _extras_task, "dataset": _extras_dataset, "metrics": [_extras_metric]}]},
    ]
}


@patch("evaluate.hub.HF_HUB_ALLOWED_TASKS", ["dummy-task"])
@patch("evaluate.hub.dataset_info", lambda x: True)
@patch("evaluate.hub.model_info", lambda x: True)
@patch("evaluate.hub.metadata_update")
class TestHub(TestCase):
    """Tests for ``push_to_hub`` with the Hub-facing names in ``evaluate.hub`` patched out.

    The class-level patches restrict the allowed tasks to ``"dummy-task"``, replace
    ``dataset_info`` and ``model_info`` with lambdas returning ``True``, and swap
    ``metadata_update`` for a mock. Only that last patch creates a mock object, so
    it is the single extra argument every ``test_*`` method receives.
    """

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Store pytest's ``caplog`` fixture on the instance for the unittest-style methods."""
        self._caplog = caplog

    def setUp(self):
        """Create a ``DummyMetric`` and compute the result used as ``metric_value``."""
        self.metric = DummyMetric()
        self.metric.add()
        self.args = {"hello": 1, "world": 2}
        self.result = self.metric.compute()

    def _required_kwargs(self):
        """Return a fresh dict of the keyword arguments used by the required-arguments test.

        A new dict is built on every call so a test can drop or override entries
        without affecting any other test.
        """
        return {
            "model_id": "username/repo",
            "metric_value": self.result["accuracy"],
            "metric_name": "Pretty Metric Name",
            "metric_type": self.metric.name,
            "dataset_name": "dataset_name",
            "dataset_type": "dataset_type",
            "task_type": "dummy-task",
        }

    def test_push_metric_required_arguments(self, metadata_update):
        """The required arguments alone produce the minimal metadata payload."""
        push_to_hub(**self._required_kwargs())

        metadata_update.assert_called_once_with(repo_id="username/repo", metadata=minimum_metadata, overwrite=False)

    def test_push_metric_missing_arguments(self, metadata_update):
        """Leaving out ``task_type`` raises ``TypeError`` and nothing is pushed."""
        kwargs = self._required_kwargs()
        del kwargs["task_type"]

        with pytest.raises(TypeError):
            push_to_hub(**kwargs)

        metadata_update.assert_not_called()

    def test_push_metric_invalid_arguments(self, metadata_update):
        """An unknown keyword argument raises ``TypeError`` and nothing is pushed."""
        with pytest.raises(TypeError):
            push_to_hub(**self._required_kwargs(), random_value="incorrect")

        metadata_update.assert_not_called()

    def test_push_metric_extra_arguments(self, metadata_update):
        """Optional task / dataset / metric arguments all appear in the pushed metadata."""
        push_to_hub(
            **self._required_kwargs(),
            dataset_config="fr",
            dataset_split="test",
            dataset_revision="abc",
            dataset_args={"a": 1, "b": 2},
            task_name="task_name",
            metric_config=self.metric.config_name,
            metric_args=self.args,
        )

        metadata_update.assert_called_once_with(repo_id="username/repo", metadata=extras_metadata, overwrite=False)

    def test_push_metric_invalid_task_type(self, metadata_update):
        """A task type outside the patched allow-list raises ``ValueError``; nothing is pushed."""
        kwargs = {**self._required_kwargs(), "task_type": "audio-classification"}

        with pytest.raises(ValueError):
            push_to_hub(**kwargs)

        metadata_update.assert_not_called()

    def test_push_metric_invalid_dataset_type(self, metadata_update):
        """A dataset lookup that raises ``HTTPError`` is logged, and the metadata is still pushed."""
        with patch("evaluate.hub.dataset_info") as mock_dataset_info:
            mock_dataset_info.side_effect = requests.HTTPError()
            push_to_hub(**self._required_kwargs())

            assert "Dataset dataset_type not found on the Hub at hf.co/datasets/dataset_type" in self._caplog.text
            metadata_update.assert_called_once_with(
                repo_id="username/repo", metadata=minimum_metadata, overwrite=False
            )

    def test_push_metric_invalid_model_id(self, metadata_update):
        """A model lookup that raises ``HTTPError`` surfaces as ``ValueError``; nothing is pushed."""
        kwargs = {**self._required_kwargs(), "model_id": "username/bad-repo"}

        with patch("evaluate.hub.model_info") as mock_model_info:
            mock_model_info.side_effect = requests.HTTPError()
            with pytest.raises(ValueError):
                push_to_hub(**kwargs)

        metadata_update.assert_not_called()


class ValidateYaml(TestCase):
    """Checks that the README.md cards under measurements/, metrics/ and comparisons/ parse as YAML."""

    def testLoadingCards(self):
        """Load the first YAML document of every matched README and require it to be a dict."""
        # NOTE(review): the patterns are relative to the current working directory,
        # so no file is checked when the tests are launched from another directory;
        # confirm that a silent pass is acceptable in that case.
        readme_filepaths = [
            readme_path
            for glob_path in ("measurements/*/README.md", "metrics/*/README.md", "comparisons/*/README.md")
            for readme_path in glob.glob(glob_path)
        ]
        for readme_file in readme_filepaths:
            with open(readme_file, encoding="utf8") as f_yaml:
                # safe_load_all is lazy, so the first document has to be pulled while
                # the file is still open. The None default turns an empty stream into
                # an ordinary assertion failure instead of a bare StopIteration.
                first_document = next(yaml.safe_load_all(f_yaml), None)
            self.assertIsInstance(
                first_document, dict, f"{readme_file}: first YAML document is not a mapping"
            )
