### Imports ###########################################################################################################

import pytest
import numpy as np
import torch
from gcontrol.guidance_controllers.controller_utils import GController
from fixtures.grad_preprocessors_fixtures import *

#######################################################################################################################

# Shared absolute (ATOL) and relative (RTOL) tolerances, both 1e-8.
# NOTE(review): not referenced by any test in this module — confirm they are still needed.
ATOL = RTOL = 1e-8

### Tests #############################################################################################################


# GController tests
def test_GController_instantiate():
    """A bare ``GController`` can be constructed without any arguments."""
    assert isinstance(GController(), GController)


def test_GController_from_config():
    """``GController.from_config`` rebuilds a controller from another instance's ``config``."""
    source = GController()
    rebuilt = GController.from_config(source.config)
    assert isinstance(rebuilt, GController)


def test_GController_None_dtype():
    """A plain ``GController()`` reports ``torch.float32`` as its dtype.

    Presumably this is the fallback used when the controller holds no
    floating-point parameters of its own — confirm against ``GController.dtype``.
    """
    expected = torch.float32
    assert GController().dtype == expected


def test_GController_Parameter_dtype_register_parameter():
    """Mixed-precision parameters added via ``register_parameter`` yield ``dtype == float16``.

    The float16 parameter is registered first, ahead of a float32 parameter and a
    float64 ``Linear`` submodule. Presumably ``dtype`` follows the first parameter
    it encounters.
    """

    class TestController(GController):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            # Registration order matters: "f16" must come before "f32".
            for param_name, param_dtype in (("f16", torch.float16), ("f32", torch.float32)):
                self.register_parameter(
                    param_name,
                    torch.nn.Parameter(torch.tensor(2, dtype=param_dtype), requires_grad=False),
                )
            self.register_module("f64", torch.nn.Linear(10, 10, dtype=torch.float64))

    controller = TestController()
    assert controller.dtype == torch.float16


def test_GController_Parameter_dtype_self():
    """Parameters assigned as plain attributes (``self.x = Parameter(...)``) also drive ``dtype``.

    This mirrors the ``register_parameter`` variant. Assigning an
    ``nn.Parameter`` to a module attribute registers it, so the expected dtype
    is again float16.
    """

    def frozen_scalar(dtype):
        # Non-trainable scalar parameter with the requested dtype.
        return torch.nn.Parameter(torch.tensor(2, dtype=dtype), requires_grad=False)

    class TestController(GController):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            # Deliberately use attribute assignment rather than register_parameter.
            self.f16 = frozen_scalar(torch.float16)
            self.f32 = frozen_scalar(torch.float32)
            self.register_module("f64", torch.nn.Linear(10, 10, dtype=torch.float64))

    assert TestController().dtype == torch.float16


def test_GController_Module_dtype():
    """With only a float64 ``Linear`` submodule, ``dtype`` reports ``torch.float64``."""

    class TestController(GController):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            linear64 = torch.nn.Linear(10, 10, dtype=torch.float64)
            self.register_module("f64", linear64)

    controller = TestController()
    assert controller.dtype == torch.float64


def test_GController_dtype_conversion():
    """``.to(dtype=torch.float32)`` casts the controller and all of its floating-point tensors.

    The controller starts with mixed precision: a float16 parameter, a float32
    parameter, and a float64 ``Linear`` submodule. After conversion, the
    aggregate ``dtype`` property and every individual tensor must be float32.
    """

    class TestController(GController):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            self.register_parameter(
                "f16", torch.nn.Parameter(torch.tensor(2, dtype=torch.float16), requires_grad=False)
            )
            self.register_parameter(
                "f32", torch.nn.Parameter(torch.tensor(2, dtype=torch.float32), requires_grad=False)
            )
            self.register_module("f64", torch.nn.Linear(10, 10, dtype=torch.float64))

    gcontroller = TestController()
    gcontroller = gcontroller.to(dtype=torch.float32)
    assert gcontroller.dtype == torch.float32

    # Also check the individual tensors, not only the aggregate ``dtype`` property.
    # Otherwise a ``dtype`` that disagrees with the actual tensors would go unnoticed.
    assert gcontroller.f16.dtype == torch.float32
    assert gcontroller.f32.dtype == torch.float32
    assert all(p.dtype == torch.float32 for p in gcontroller.f64.parameters())


def test_GController_cpu_device():
    """A freshly constructed ``GController`` reports the CPU as its device."""
    expected = torch.device("cpu")
    assert GController().device == expected


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
def test_GController_cuda_device():
    """``.to("cuda")`` moves the controller, its own parameters and its submodule to the GPU."""

    def frozen_scalar(dtype):
        # Non-trainable scalar parameter with the requested dtype.
        return torch.nn.Parameter(torch.tensor(2, dtype=dtype), requires_grad=False)

    class TestController(GController):

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            self.register_parameter("f16", frozen_scalar(torch.float16))
            self.register_parameter("f32", frozen_scalar(torch.float32))
            self.register_module("f64", torch.nn.Linear(10, 10, dtype=torch.float64))

    controller = TestController().to("cuda")

    # The reported device and every tensor the controller owns must be on CUDA.
    observed_devices = (
        controller.device,
        controller.f16.device,
        controller.f32.device,
        next(controller.f64.parameters()).device,
    )
    for device in observed_devices:
        assert "cuda" in str(device)


#######################################################################################################################
