import unittest
import torch as th
from torch import nn
from torch.autograd import Variable
from nn.background_v2 import UncertantyBackground


class UncertantyBackgroundTest(unittest.TestCase):
    """Smoke tests for constructing ``UncertantyBackground`` and checking its output shapes."""

    def setUp(self):
        """Create the constructor arguments and random input tensors shared by every test."""
        # Constructor hyper-parameters, passed positionally (in this order) by _build_model.
        self.masking_ratio = 0.5
        self.uncertainty_noise_ratio = 0.1
        self.motion_context_size = 128
        self.depth_context_size = 128
        self.latent_channels = 64
        self.num_layers = 6
        self.hyper_channels = 32
        self.hyper_layers = 3
        self.depth_input = False
        self.loci = False

        # Input geometry; also used to build the expected output shapes in test_forward.
        self.batch_size = 10
        self.height = 224
        self.width = 224

        # Random inputs laid out as (batch, channels, height, width): 3-channel images and
        # 1-channel uncertainty maps. Plain tensors are used because
        # torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4.
        image_shape = (self.batch_size, 3, self.height, self.width)
        uncertainty_shape = (self.batch_size, 1, self.height, self.width)
        self.source = th.randn(*image_shape)
        self.target = th.randn(*image_shape)
        self.source_uncertainty = th.randn(*uncertainty_shape)
        self.target_uncertainty = th.randn(*uncertainty_shape)

    def _build_model(self):
        """Construct an ``UncertantyBackground`` from the hyper-parameters set in setUp."""
        return UncertantyBackground(
            self.masking_ratio,
            self.uncertainty_noise_ratio,
            self.motion_context_size,
            self.depth_context_size,
            self.latent_channels,
            self.num_layers,
            self.hyper_channels,
            self.hyper_layers,
            self.depth_input,
            self.loci,
        )

    def test_initialization(self):
        """Check that the module can be constructed without raising."""
        try:
            self._build_model()
        except Exception as e:
            self.fail(f"Initialization failed with exception: {e}")

    def test_forward(self):
        """Check that forward returns four tensors (rgb, depth, motion/depth context) of the expected shapes."""
        model = self._build_model()
        b, h, w = self.batch_size, self.height, self.width

        # Guard only the forward call (and unpacking of its four outputs). The shape checks
        # live in `else:` so that a mismatch is reported as its own assertion failure rather
        # than being swallowed and re-labelled as "Forward pass failed".
        try:
            rgb, depth, motion_context, depth_context = model(
                self.source, self.target, self.source_uncertainty, self.target_uncertainty
            )
        except Exception as e:
            self.fail(f"Forward pass failed with exception: {e}")
        else:
            self.assertEqual(tuple(rgb.shape), (b, 3, h, w))
            self.assertEqual(tuple(depth.shape), (b, 1, h, w))
            self.assertEqual(tuple(motion_context.shape), (b, self.motion_context_size))
            self.assertEqual(tuple(depth_context.shape), (b, self.depth_context_size))

    # TODO: add tests for the module's other methods.

if __name__ == "__main__":
    # Only when this file is executed directly (not imported by a test collector):
    # delegate to unittest's command-line runner for every TestCase in this module.
    unittest.main()
