### Imports ###########################################################################################################

import pytest
import numpy as np
import torch
from timm import create_model
from gcontrol.utils.hugging_utils import parse_timm_preprocess_config, get_timm_config
from fixtures.grad_preprocessors_fixtures import *

#######################################################################################################################

# Absolute and relative tolerances, presumably intended for floating-point
# comparisons (e.g. np.allclose / torch.allclose).
# NOTE(review): neither constant is referenced by the tests in this part of the
# file — confirm they are used elsewhere, or remove them.
ATOL = 1e-8
RTOL = 1e-8

### Tests #############################################################################################################


@pytest.mark.parametrize(
    "timm_config, expected_config",
    [
        # Both mean and std provided -> normalization enabled.
        (
            {
                "input_size": (3, 299, 299),
                "interpolation": "bicubic",
                "mean": (0.5, 0.5, 0.5),
                "std": (0.5, 0.5, 0.5),
                "crop_pct": 0.875,
                "crop_mode": "center",
            },
            {
                "size": (3, 299, 299),
                "image_mean": [0.5, 0.5, 0.5],
                "image_std": [0.5, 0.5, 0.5],
                "crop_pct": 0.875,
                "rescale_factor": 1 / 255.0,
                "do_normalize": True,
                "do_rescale": True,
                "do_resize": True,
            },
        ),
        # Neither mean nor std (nor crop_pct) -> normalization disabled.
        (
            {
                "input_size": (3, 299, 299),
                "interpolation": "bicubic",
                "crop_mode": "center",
            },
            {
                "size": (3, 299, 299),
                "rescale_factor": 1 / 255.0,
                "do_normalize": False,
                "do_rescale": True,
                "do_resize": True,
            },
        ),
        # Mean only -> mean is carried over, but normalization stays disabled.
        (
            {
                "input_size": (3, 299, 299),
                "interpolation": "bicubic",
                "mean": (0.5, 0.5, 0.5),
                "crop_pct": 0.875,
                "crop_mode": "center",
            },
            {
                "size": (3, 299, 299),
                "image_mean": [0.5, 0.5, 0.5],
                "crop_pct": 0.875,
                "rescale_factor": 1 / 255.0,
                "do_normalize": False,
                "do_rescale": True,
                "do_resize": True,
            },
        ),
        # Std only -> std is carried over, but normalization stays disabled.
        (
            {
                "input_size": (3, 299, 299),
                "interpolation": "bicubic",
                "std": (0.5, 0.5, 0.5),
                "crop_pct": 0.875,
                "crop_mode": "center",
            },
            {
                "size": (3, 299, 299),
                "image_std": [0.5, 0.5, 0.5],
                "crop_pct": 0.875,
                "rescale_factor": 1 / 255.0,
                "do_normalize": False,
                "do_rescale": True,
                "do_resize": True,
            },
        ),
    ],
)
def test_timm_preprocess_config(timm_config, expected_config):
    """Check that ``parse_timm_preprocess_config`` produces the expected mapping.

    The parsed result must have exactly as many entries as ``expected_config``,
    and every expected key must map to an equal value.
    """
    parsed = parse_timm_preprocess_config(timm_config)

    assert len(parsed) == len(expected_config)

    # Gather every mismatching key so a failure reports all of them at once;
    # a key absent from ``parsed`` still surfaces as a KeyError.
    mismatched_keys = [
        key for key, expected in expected_config.items() if expected != parsed[key]
    ]
    assert not mismatched_keys


@pytest.mark.parametrize(
    "timm_config",
    [
        # Omits "input_size", the one key every successfully parsed config
        # in this file provides.
        {
            "interpolation": "bicubic",
            "mean": (0.5, 0.5, 0.5),
            "std": (0.5, 0.5, 0.5),
            "crop_pct": 0.875,
            "crop_mode": "center",
        },
    ],
)
def test_timm_preprocess_config_exception(timm_config):
    """Check that ``parse_timm_preprocess_config`` raises ``KeyError`` on an incomplete config.

    ``pytest.raises`` fails the test with a clear "DID NOT RAISE" message if no
    exception occurs, and lets any other exception type propagate as an error.
    """
    with pytest.raises(KeyError):
        parse_timm_preprocess_config(timm_config)


@pytest.mark.parametrize(
    "timm_model_name, expected_config",
    [
        (
            "inception_v4",
            {
                "size": (3, 299, 299),
                "image_mean": [0.5, 0.5, 0.5],
                "image_std": [0.5, 0.5, 0.5],
                "crop_pct": 0.875,
                "rescale_factor": 1 / 255.0,
                "do_normalize": True,
                "do_rescale": True,
                "do_resize": True,
            },
        ),
    ],
)
def test_get_timm_config(timm_model_name, expected_config):
    """Check that ``get_timm_config`` extracts the expected preprocessing config from a timm model.

    This test must not be named ``test_timm_preprocess_config``. A second
    ``def`` with that name rebinds the module attribute, and pytest then
    silently never collects the earlier parametrized parser test.
    """
    # NOTE(review): pretrained=True downloads weights over the network. If
    # get_timm_config only reads the model's pretrained config, then
    # pretrained=False may suffice — confirm before changing.
    model = create_model(
        model_name=timm_model_name,
        pretrained=True,
    )
    config = get_timm_config(model)

    # Same number of keys, and every expected key maps to an equal value.
    assert len(config) == len(expected_config)
    for key, val in expected_config.items():
        assert val == config[key]


#######################################################################################################################
