import unittest
from unittest import TestCase

import torch

from src.dl.fixed_point_solvers.fixed_point_iterator import fixed_point_iterator, truncated_fixed_point_iterator


class TestFixedPointIterator(TestCase):
    """Unit tests for ``fixed_point_iterator`` and ``truncated_fixed_point_iterator``."""

    def test_fixed_point_iterator(self):
        """The plain iterator should apply ``f`` exactly ``num_iters`` times to ``x0``."""
        # Simple discrete map: x -> x + 1 applied num_iters times from x0 lands on x0 + num_iters.
        f_map = lambda x: x + 1
        num_iters = 5
        x0 = 0

        correct_soln = x0 + num_iters

        iterator_output = fixed_point_iterator(f=f_map, x0=x0, num_iters=num_iters)["result"]
        self.assertEqual(iterator_output, correct_soln)

    def test_trunceted_fixed_point_iterator(self):
        """Check the forward value and the truncated gradient of the truncated iterator.

        The map is ``f(x) = alpha * x * y``, so ``num_iters`` iterations from ``z0 = 1``
        give ``(alpha * y) ** num_iters``. The expected gradient below assumes that only
        the last ``num_keep_grads_iters`` iterations stay in the autograd graph, and that
        the earlier iterations produce a starting point treated as a constant.
        """
        # NOTE: the method name keeps its original spelling so existing test selectors still work.
        y = torch.tensor([1.], requires_grad=True)
        alpha = 1.2
        f_map = lambda x: alpha * x * y

        # Run the solver and check that 1) the forward pass computes the right value and
        # 2) the backward pass computes the right (truncated) gradient.
        num_iters = 15
        num_keep_grads_iters = 7
        # Derived rather than hard-coded so it stays consistent with the two settings above.
        num_no_grad_iters = num_iters - num_keep_grads_iters
        z0 = torch.tensor([1.])

        z_out = truncated_fixed_point_iterator(f=f_map, x0=z0, num_iters=num_iters,
                                               num_keep_grads_iters=num_keep_grads_iters)['result']
        z_out.backward()

        # Forward pass: truncating the graph must not change the computed value.
        self.assertTrue(torch.allclose(z_out, (alpha * y) ** num_iters))

        # Backward pass: with k = num_keep_grads_iters and new_z0 held constant,
        # z_out = new_z0 * alpha**k * y**k, so d z_out / d y = new_z0 * alpha**k * k * y**(k - 1).
        new_z0 = (alpha * y) ** num_no_grad_iters
        correct_grad = new_z0 * alpha ** num_keep_grads_iters * num_keep_grads_iters * y ** (num_keep_grads_iters - 1)
        self.assertTrue(torch.allclose(y.grad, correct_grad))

        # Edge cases: keeping all iterations and keeping none of them in the graph must still
        # produce the correct forward value. Only the forward pass is checked here; no backward
        # is run, so y.grad from above is not modified.
        for keep in (num_iters, 0):
            with self.subTest(num_keep_grads_iters=keep):
                z_out = truncated_fixed_point_iterator(f=f_map, x0=z0, num_iters=num_iters,
                                                       num_keep_grads_iters=keep)['result']
                self.assertTrue(torch.allclose(z_out, (alpha * y) ** num_iters))


if __name__ == "__main__":
    # Entry point for running this test module directly.
    # From the repository root, run:
    #     python -m unittest -v src.dl.fixed_point_solvers.test_fixed_point_iterator
    unittest.main()
