"""Test to check that acquisition from the base model works as expected."""


import unittest

import torch
import torch.nn as nn

from models.base import BaseModel



class TestAcquisition(unittest.TestCase):
  """Step-by-step checks of the acquisition masks produced by `BaseModel.acquire`."""

  def setUp(self):
    # Seed the global RNG so random inputs (and any stochastic acquisition
    # behaviour inside the model) are reproducible when a test fails.
    torch.manual_seed(0)

  @staticmethod
  def _make_model():
    """Build a BaseModel over 5 continuous features, with a parameter so it has a device."""
    config = {
      "num_con_features": 5,
      "num_cat_features": 0,
      "most_categories": 0,
      "out_dim": 10,
      "max_dim": None
    }
    model = BaseModel(config)
    model.dummy_param = nn.Parameter(torch.randn(1))  # Used so model has a device.
    return model

  def test_no_missing_data(self):
    model = self._make_model()

    # Rows start with 4, 1, 2 and 0 of the 5 features already acquired.
    mask_acq = torch.tensor([
      [1.0, 1.0, 1.0, 1.0, 0.0],
      [0.0, 0.0, 0.0, 0.0, 1.0],
      [1.0, 0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 0.0, 0.0, 0.0]
    ])

    x = torch.randn_like(mask_acq)
    mask_data = torch.ones_like(mask_acq)  # Every feature is present in the data.

    # Expected number of newly acquired features per row after each call:
    # one per row until that row is fully acquired, then zero.
    expected_changes = [
      [1.0, 1.0, 1.0, 1.0],
      [0.0, 1.0, 1.0, 1.0],
      [0.0, 1.0, 1.0, 1.0],
      [0.0, 1.0, 0.0, 1.0],
      [0.0, 0.0, 0.0, 1.0],
      [0.0, 0.0, 0.0, 0.0],  # Final acquisition: nothing left to add.
    ]

    for step, expected in enumerate(expected_changes, start=1):
      mask_acq_next = model.acquire(x, mask_acq, mask_data)
      # A previously acquired feature must never be dropped. Without this
      # check, a row that lost one feature and gained two would still show a
      # net change of +1 below.
      self.assertTrue(
        torch.all(mask_acq_next >= mask_acq),
        f"acquisition step {step}: a previously acquired feature was dropped",
      )
      change = torch.sum(mask_acq_next - mask_acq, dim=-1)
      # Comparing via tolist() is strict about shape (no broadcasting) and
      # gives a readable diff on failure.
      self.assertEqual(change.tolist(), expected, msg=f"acquisition step {step}")
      mask_acq = mask_acq_next

    self.assertTrue(torch.all(mask_acq == 1.0))

  def test_missing_data(self):
    model = self._make_model()

    # 1 = feature present in the data, 0 = missing.
    mask_data = torch.tensor([
      [1.0, 1.0, 1.0, 1.0, 1.0],
      [1.0, 0.0, 0.0, 1.0, 1.0],
      [1.0, 1.0, 0.0, 1.0, 1.0],
      [0.0, 0.0, 0.0, 0.0, 0.0],
    ])

    mask_acq = torch.tensor([
      [1.0, 1.0, 0.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 1.0, 1.0],
      [0.0, 0.0, 0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0, 0.0, 0.0]
    ])

    # Make acquisition deterministic: use fixed per-feature scores.
    model.use_fixed_order = True
    model.fixed_order_scores = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])

    x = torch.randn_like(mask_acq)

    # Expected mask after each call. As encoded here, features present in the
    # data are taken in descending score order. Features missing from the data
    # are only acquired once no present feature remains, in ascending index
    # order (see rows 1-3).
    expected_masks = [
      [
        [1.0, 1.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0, 0.0, 0.0]
      ],
      [
        [1.0, 1.0, 0.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0, 1.0, 1.0],
        [1.0, 1.0, 0.0, 0.0, 0.0]
      ],
      [
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [0.0, 1.0, 0.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 0.0, 0.0]
      ],
      [
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 0.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 0.0]
      ],
      [
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0]
      ],
    ]

    for step, expected_mask in enumerate(expected_masks, start=1):
      mask_acq = model.acquire(x, mask_acq, mask_data)
      # Shape-strict, dtype-agnostic comparison with a readable diff on failure.
      self.assertEqual(
        mask_acq.tolist(), expected_mask, msg=f"acquisition step {step}"
      )


if __name__ == "__main__":
  # Script entry point: discover and run every TestCase defined in this module.
  unittest.main(module="__main__")