### Imports ###########################################################################################################

import pytest
import numpy as np
import torch
from gcontrol.guidance_controllers.controller_utils import GController
from gcontrol.guidance_controllers.common.classifier_free import ClassifierFreeGuidance
from fixtures.grad_preprocessors_fixtures import *

#######################################################################################################################

# Absolute and relative tolerances passed to torch.allclose when the forward
# test compares the controller output against the reference tensor.
ATOL = 1e-8
RTOL = 1e-8

### Tests #############################################################################################################


# Guidance controller tests (ClassifierFreeGuidance and GController)
def test_ClassifierFreeGuidannce_instantiate():
    """ClassifierFreeGuidance can be constructed without any arguments."""

    # Build the controller inline; the only requirement is that construction
    # succeeds and yields an instance of the expected class.
    assert isinstance(ClassifierFreeGuidance(), ClassifierFreeGuidance)


def test_ClassifierFreeGuidannce_from_config():
    """Rebuild a GController from the config of a default-constructed one.

    NOTE(review): despite the test name, this exercises GController rather
    than ClassifierFreeGuidance -- confirm which class is intended.
    """

    source_config = GController().config
    rebuilt = GController.from_config(source_config)

    assert isinstance(rebuilt, GController)


@pytest.mark.parametrize(
    "u_noise, c_noise, in_shape, guidance_scale",
    [
        (
            torch.linspace(0, 100, 1 * 3 * 20 * 25, dtype=torch.float16),
            torch.linspace(200, 250, 1 * 3 * 20 * 25, dtype=torch.float16),
            (1, 3, 20, 25),
            6.2,
        ),
        (
            torch.linspace(0, 100, 5 * 20 * 25, dtype=torch.float32),
            torch.linspace(200, 250, 5 * 20 * 25, dtype=torch.float32),
            (5, 20, 25),
            3,
        ),
        (
            torch.linspace(0, 100, 20 * 25, dtype=torch.float64),
            torch.linspace(200, 250, 20 * 25, dtype=torch.float64),
            (20, 25),
            1.5,
        ),
    ],
)
def test_ClassifierFreeGuidance_forward(u_noise, c_noise, in_shape, guidance_scale):
    """Compare the controller's output with ``u + scale * (c - u)``.

    The flat ``u_noise`` / ``c_noise`` tensors are reshaped to ``in_shape``
    and the controller is called as ``(None, u_noise, c_noise, guidance_scale)``.
    The result must match the reference in value (within ATOL / RTOL), shape
    and dtype. The cases cover float16 / float32 / float64 inputs of rank 4,
    3 and 2.
    """

    uncond = u_noise.reshape(in_shape)
    cond = c_noise.reshape(in_shape)

    # Reference value, computed with the same operand order the original test used.
    expected = uncond + guidance_scale * (cond - uncond)

    # The first positional argument is passed as None here; its meaning is
    # defined by the controller implementation, which is not visible in this file.
    actual = ClassifierFreeGuidance()(None, uncond, cond, guidance_scale)

    assert torch.allclose(actual, expected, rtol=RTOL, atol=ATOL)
    assert actual.shape == expected.shape
    assert actual.dtype == expected.dtype


@pytest.mark.parametrize(
    "guidance_scale, gtruth",
    [(2, True), (1, False), (5.2, True), (7.8, True), (1.0, False)],
)
def test_ClassifierFreeGuidance_do_gcontrol(guidance_scale, gtruth):
    """Check the flag returned by ``do_gcontrol`` for several guidance scales.

    The cases expect ``False`` for a scale of exactly 1 (int or float) and
    ``True`` for the larger scales listed.

    This test was previously also named ``test_ClassifierFreeGuidance_forward``.
    Because Python rebinds a module-level name on redefinition, that duplicate
    name replaced the forward-pass test, which pytest then never collected.
    The name now reflects the method under test.
    """

    cf_gcontroller = ClassifierFreeGuidance()

    assert cf_gcontroller.do_gcontrol(guidance_scale) == gtruth


#######################################################################################################################
