"""Here we are testing the input layers. These are layers that take arbitrary
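kinds of input (continuous, categorical or mixed) together with a mask and
prepare them for the standard layers. We test by changing one input value
where the mask is 1 and checking that the output changes.

We also repeat this with the mask set to 0 and check that the output does not
change. Finally, we compare against hand-computed outputs on small examples.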
"""


import unittest

import torch
import torch.nn as nn

from models.standard_layers import ContinuousInput, CategoricalInput, MixedInput


# Shared test dimensions.
batchsize = 5   # rows in each randomly generated batch
hidden_dim = 4  # output width of the Linear layer applied after each input layer

# Settings for the continuous-input tests.
num_con_features = 8

# Settings for the categorical-input tests.
num_cat_features = 6
most_categories = 10  # random category ids are drawn from [0, most_categories)


def produce_continuous_outputs(batchsize, num_features, hidden_dim, mask_val):
  """Feed two batches that differ only in feature 0 through ContinuousInput + Linear.

  Args:
    batchsize: number of rows in the random batch.
    num_features: number of continuous features per row.
    hidden_dim: output width of the Linear layer.
    mask_val: value written into column 0 of the mask for every row.

  Returns:
    A pair (out1, out2): the outputs for the original batch and for a copy whose
    first feature is negated. Both use the same layers and the same mask.
  """
  layer = ContinuousInput()
  # The Linear expects 2*num_features inputs from the continuous input layer.
  proj = nn.Linear(2 * num_features, hidden_dim)

  base = torch.randn((batchsize, num_features))
  flipped = base.clone()
  flipped[:, 0] = -base[:, 0]  # feature 0 is the only difference between the batches

  # Random 0/1 mask everywhere, except feature 0 which is pinned to mask_val.
  mask = torch.bernoulli(torch.full((batchsize, num_features), 0.5))
  mask[:, 0] = mask_val

  return tuple(proj(layer(batch, mask)) for batch in (base, flipped))


def produce_categorical_outputs(batchsize, num_features, most_categories, hidden_dim, mask_val):
  """Feed two batches that differ only in feature 0 through CategoricalInput + Linear.

  Args:
    batchsize: number of rows in the random batch.
    num_features: number of categorical features per row.
    most_categories: category ids are drawn from [0, most_categories).
    hidden_dim: output width of the Linear layer.
    mask_val: value written into column 0 of the mask for every row.

  Returns:
    A pair (out1, out2): the outputs for the original batch and for a copy whose
    first feature's category id is shifted by one (wrapping around). Both use the
    same layers and the same mask.
  """
  layer = CategoricalInput(most_categories)
  # The Linear expects most_categories + 1 columns per feature from the input layer.
  proj = nn.Linear(num_features * (most_categories + 1), hidden_dim)

  # Integer-valued category ids stored as float32; floor == truncation since values are >= 0.
  base = torch.floor(most_categories * torch.rand((batchsize, num_features)))
  shifted = base.clone()
  shifted[:, 0] = (base[:, 0] + 1) % most_categories  # move to the next category, wrapping

  # Random 0/1 mask everywhere, except feature 0 which is pinned to mask_val.
  mask = torch.bernoulli(torch.full((batchsize, num_features), 0.5))
  mask[:, 0] = mask_val

  return tuple(proj(layer(batch, mask)) for batch in (base, shifted))


# Continuous input.
class TestContinuousInput(unittest.TestCase):
  """Tests for ContinuousInput: sensitivity to masked/unmasked values and exact outputs."""

  def setUp(self):
    # Fix the RNG so the random inputs, masks and Linear weights used by the
    # mask tests are reproducible from run to run.
    torch.manual_seed(0)

  def test_mask1(self):
    # Feature 0 is visible (mask 1), so changing it must change the output.
    out1, out2 = produce_continuous_outputs(batchsize, num_con_features, hidden_dim, 1.0)
    self.assertFalse(torch.equal(out1, out2))

  def test_mask0(self):
    # Feature 0 is masked out (mask 0), so changing it must not change the output.
    out1, out2 = produce_continuous_outputs(batchsize, num_con_features, hidden_dim, 0.0)
    self.assertTrue(torch.equal(out1, out2))

  def test_values(self):
    con_input = ContinuousInput()
    x = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
    mask = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
    # Each expected row is x * mask followed by mask itself.
    true_out = torch.tensor([[0.0, 0.0, 2.0, 1.0, 0.0, 1.0], [0.0, 4.0, 0.0, 0.0, 1.0, 0.0]])
    test_out = con_input(x, mask)
    # assert_close reports the mismatching elements on failure, unlike assertTrue(allclose).
    torch.testing.assert_close(test_out, true_out)


# Categorical input.
class TestCategoricalInput(unittest.TestCase):
  """Tests for CategoricalInput: sensitivity to masked/unmasked values and exact outputs."""

  def setUp(self):
    # Fix the RNG so the random category ids, masks and Linear weights used by
    # the mask tests are reproducible from run to run.
    torch.manual_seed(0)

  def test_mask1(self):
    # Feature 0 is visible (mask 1), so changing its category must change the output.
    out1, out2 = produce_categorical_outputs(batchsize, num_cat_features, most_categories, hidden_dim, 1.0)
    self.assertFalse(torch.equal(out1, out2))

  def test_mask0(self):
    # Feature 0 is masked out (mask 0), so changing its category must not change the output.
    out1, out2 = produce_categorical_outputs(batchsize, num_cat_features, most_categories, hidden_dim, 0.0)
    self.assertTrue(torch.equal(out1, out2))

  def test_values(self):
    # Local name chosen so the module-level most_categories is not shadowed.
    num_categories = 3
    cat_input = CategoricalInput(num_categories)
    x = torch.tensor([[0.0, 2.0], [1.0, 1.0], [2.0, 0.0]])
    mask = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    # Each feature expands to num_categories + 1 expected columns: slot 0 is set
    # when the feature is masked out, otherwise slot (value + 1) is set.
    true_out = torch.tensor([[0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                             [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
                             [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]])
    test_out = cat_input(x, mask)
    # assert_close reports the mismatching elements on failure, unlike assertTrue(allclose).
    torch.testing.assert_close(test_out, true_out)


class TestMixedInput(unittest.TestCase):
  """Tests for MixedInput on a small hand-computed example."""

  def test_values(self):
    num_con = 3
    # Local name chosen so the module-level most_categories is not shadowed.
    num_categories = 3
    x_con = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
    x_cat = torch.tensor([[0.0, 2.0], [1.0, 1.0]])
    m_con = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
    m_cat = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    # The first num_con columns of x (and of mask) are continuous, the rest categorical.
    x = torch.cat([x_con, x_cat], dim=-1)
    mask = torch.cat([m_con, m_cat], dim=-1)
    # Expected rows: x_con * m_con, then m_con, then num_categories + 1 columns per
    # categorical feature (slot 0 when masked out, otherwise slot value + 1).
    true_out = torch.tensor([[0.0, 0.0, 2.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                             [0.0, 4.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
    mixed_input = MixedInput(num_con, num_categories)
    test_out = mixed_input(x, mask)
    # assert_close reports the mismatching elements on failure, unlike assertTrue(allclose).
    torch.testing.assert_close(test_out, true_out)


if __name__ == "__main__":
  # Run the test cases in this module when the file is executed directly.
  unittest.main()