import gym
import numpy as np
import os
from pathlib import Path
import unittest

from src.rllib.models.preprocessors import GenericPixelPreprocessor
from src.rllib.models.torch.modules.convtranspose2d_stack import \
    ConvTranspose2DStack
from src.rllib.utils.framework import try_import_torch, try_import_tf
from src.rllib.utils.images import imread

torch, nn = try_import_torch()
tf1, tf, tfv = try_import_tf()


class TestConvTranspose2DStack(unittest.TestCase):
    """Tests our ConvTranspose2D Stack modules/layers."""

    def test_convtranspose2d_stack(self):
        """Tests whether the ConvTranspose2D stack can learn to predict an image.

        Random normal input vectors of size `input_size` are fed through the
        stack. The stack returns a distribution object exposing `log_prob`.
        A few Adam steps minimize the negative mean log-likelihood of one
        fixed target image, repeated across the batch, under that
        distribution. The test asserts that the last loss is lower than the
        loss from the first iteration.
        """
        batch_size = 128
        input_size = 1
        module = ConvTranspose2DStack(input_size=input_size)
        preprocessor = GenericPixelPreprocessor(
            gym.spaces.Box(0, 255, (64, 64, 3), np.uint8), options={"dim": 64})
        optim = torch.optim.Adam(module.parameters(), lr=0.0001)

        rllib_dir = Path(__file__).parent.parent.parent
        img_file = os.path.join(rllib_dir,
                                "tests/data/images/obstacle_tower.png")
        img = imread(img_file)
        # Preprocess (presumably rescales to options["dim"] = 64; TODO confirm
        # against GenericPixelPreprocessor).
        img = preprocessor.transform(img)
        # Make channels first: (H, W, C) -> (C, H, W), matching torch's
        # channels-first image layout.
        img = np.transpose(img, (2, 0, 1))
        # Add batch rank and repeat so every batch item has the same target.
        imgs = np.reshape(img, (1, ) + img.shape)
        imgs = np.repeat(imgs, batch_size, axis=0)
        # Move to torch.
        imgs = torch.from_numpy(imgs)
        init_loss = None
        loss_value = None
        for _ in range(10):
            # Random inputs.
            inputs = torch.from_numpy(
                np.random.normal(0.0, 1.0, (batch_size, input_size))).float()
            distribution = module(inputs)
            # Construct a loss: negative mean log-likelihood of the targets.
            loss = -torch.mean(distribution.log_prob(imgs))
            # Track plain Python floats so the final comparison does not keep
            # autograd tensors alive and a failure message is readable.
            loss_value = loss.item()
            if init_loss is None:
                init_loss = loss_value
            print("loss={}".format(loss_value))
            # Minimize loss. Clear gradients first: `backward()` accumulates
            # into `.grad`, so without this each step would apply the sum of
            # all previous iterations' gradients.
            optim.zero_grad()
            loss.backward()
            optim.step()
        self.assertLess(loss_value, init_loss)


if __name__ == "__main__":
    # Allow running this file directly: delegate to pytest (verbose mode)
    # on this module and propagate its exit status to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
