import unittest
from unittest import TestCase
from functools import partial
import numpy as np
import math

import torch

from src.dl.fixed_point_solvers.anderson import anderson, normalized_l2_norm
from src.dl.fixed_point_solvers.anderson import M, LAM, THRESHOLD, EPS, BETA, DEFAULT_STOP_MODE
from src.utils.testing_utils.bilevel_quadratic_test_case import test_solver_correctness_using_bilevel_quadratic_problem


def _reference_anderson(f, x0, m=M, lam=LAM, threshold=THRESHOLD, eps=EPS, stop_mode=DEFAULT_STOP_MODE, beta=BETA,
                        **kwargs):
    """Reference Anderson acceleration for the fixed-point iteration ``z = f(z)``.

    Serves as ground truth for the codebase's ``anderson`` solver. It keeps a ring buffer of the last ``m``
    iterates ``X`` and their images ``F = f(X)``. At every step it solves a small regularized, bordered linear
    system for mixing weights ``alpha`` (constrained to sum to 1). It also tracks the iterate with the lowest
    residual according to ``stop_mode``.

    Args:
        f: The map whose fixed point is sought; called on tensors shaped like ``x0``.
        x0: Initial guess. Must be 3-D, since it is unpacked as ``(bsz, d, L)``.
        m: Number of past iterates kept in the history buffer.
        lam: Regularization added to the diagonal of the residual Gram matrix.
        threshold: Exclusive upper bound on the iteration index; must be at least 3.
        eps: Iteration stops as soon as the ``stop_mode`` residual drops below this value.
        stop_mode: ``'abs'`` or ``'rel'``; the residual used for stopping and for selecting the result.
        beta: Damping; the new iterate is ``beta * alpha@F + (1 - beta) * alpha@X``.
        **kwargs: Ignored; accepted so the call signature matches other solvers.

    Returns:
        A dict with these entries:
            - ``"result"``: the lowest-residual iterate (shaped like ``x0``, detached).
            - ``"lowest"``: its residual.
            - ``"nstep"``: the step at which it was found.
            - ``"abs_trace"`` and ``"rel_trace"``: per-step residuals, each of length ``threshold - 2``.
            - ``"eps"`` and ``"threshold"``: the settings used.
            - ``"prot_break"``: always ``False``.

    Raises:
        ValueError: If ``threshold <= 2``.
    """
    if threshold <= 2:
        message = "The minimum number of anderson iterations (threshold) must be at least 3. "
        raise ValueError(message)

    bsz, d, L = x0.shape
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'

    # History of the last m iterates (X) and their images under f (F), flattened per batch element.
    X = torch.zeros(bsz, m, d * L, dtype=x0.dtype, device=x0.device)
    F = torch.zeros(bsz, m, d * L, dtype=x0.dtype, device=x0.device)
    # Seed the history with two plain fixed-point steps: x0 -> f(x0) -> f(f(x0)).
    X[:, 0], F[:, 0] = x0.reshape(bsz, -1), f(x0).reshape(bsz, -1)
    X[:, 1], F[:, 1] = F[:, 0], f(F[:, 0].reshape_as(x0)).reshape(bsz, -1)

    # Bordered system [[0, 1^T], [1, G G^T + lam I]] @ [nu; alpha] = [1; 0].
    # Its first row enforces sum(alpha) == 1.
    H = torch.zeros(bsz, m + 1, m + 1, dtype=x0.dtype, device=x0.device)
    H[:, 0, 1:] = H[:, 1:, 0] = 1
    y = torch.zeros(bsz, m + 1, 1, dtype=x0.dtype, device=x0.device)
    y[:, 0] = 1

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}

    for k in range(2, threshold):
        n = min(k, m)
        # Residuals of the (up to m) stored iterates; their regularized Gram matrix fills H's lower-right block.
        G = F[:, :n] - X[:, :n]
        eye = torch.eye(n, dtype=x0.dtype, device=x0.device)[None]
        H[:, 1:n + 1, 1:n + 1] = torch.bmm(G, G.transpose(1, 2)) + lam * eye
        # torch.solve(B, A) has been removed from PyTorch; torch.linalg.solve(A, B) replaces it.
        # The argument order is swapped, and the solution is returned directly instead of as element [0].
        alpha = torch.linalg.solve(H[:, :n + 1, :n + 1], y[:, :n + 1])[:, 1:n + 1, 0]  # (bsz x n)

        # Damped alpha-weighted mix of images and iterates; slot k % m is reused as a ring buffer.
        X[:, k % m] = beta * (alpha[:, None] @ F[:, :n])[:, 0] + (1 - beta) * (alpha[:, None] @ X[:, :n])[:, 0]
        F[:, k % m] = f(X[:, k % m].reshape_as(x0)).reshape(bsz, -1)
        gx = (F[:, k % m] - X[:, k % m]).view_as(x0)
        abs_diff = gx.norm().item()
        rel_diff = abs_diff / (1e-5 + F[:, k % m].norm().item())
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)

        # Track the best residual in both modes; only the stop_mode best determines the returned iterate.
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode:
                    lowest_xest = X[:, k % m].view_as(x0).clone().detach()
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = k

        if trace_dict[stop_mode][-1] < eps:
            # Pad both traces with the best values so they always end up with threshold - 2 entries.
            for _ in range(threshold - 1 - k):
                trace_dict[stop_mode].append(lowest_dict[stop_mode])
                trace_dict[alternative_mode].append(lowest_dict[alternative_mode])
            break

    # NOTE(review): lowest_xest is first assigned on step 2, provided that step's residual is below the
    # 1e8 sentinel. A NaN residual at every step would leave it unbound here.
    return {"result": lowest_xest,
            "lowest": lowest_dict[stop_mode],
            "nstep": lowest_step_dict[stop_mode],
            "prot_break": False,
            "abs_trace": trace_dict['abs'],
            "rel_trace": trace_dict['rel'],
            "eps": eps,
            "threshold": threshold}


class TestAnderson(TestCase):
    """Tests for the Anderson fixed-point solver and its ``normalized_l2_norm`` helper."""

    def test_check_if_torch_linear_solver_is_differentiable(self):
        """Check if torch.linalg.solve is differentiable."""
        # Set up a simple linear system.
        A = torch.tensor([[1., 0.], [0., 1.]])
        b = torch.tensor([[1.], [1.]])
        b.requires_grad = True

        # Solve it, and compute a loss using its output. Backprop.
        x = torch.linalg.solve(A, b)
        loss = 0.5 * torch.sum(x ** 2)
        loss.backward()

        # Check that the computed grad matches the analytic grad: d/db 0.5*||A^{-1} b||^2 = A^{-T} A^{-1} b.
        computed_grad = b.grad
        correct_grad = torch.linalg.inv(A.T) @ torch.linalg.inv(A) @ b
        self.assertTrue(torch.allclose(computed_grad, correct_grad))

    def test_fixed_point_and_grads_on_bilevel_quadratic_use_ift(self):
        """Test correctness of fixed point and grads using bilevel quadratic problem. Use implicit FT. for gradients."""
        # Define the solver.
        solver = partial(anderson, m=5, threshold=2000, eps=1e-8, make_differentiable=True)

        # Test it, while using IFT to compute the gradients.
        correctness_dicts = test_solver_correctness_using_bilevel_quadratic_problem(solver_fn=solver,
                                                                                    use_backward_solver=True)
        for correctness_results in correctness_dicts:
            for k, v in correctness_results.items():
                self.assertTrue(v, msg=f"Failure to show {k}.")

    def test_fixed_point_and_grads_on_bilevel_quadratic_use_unrolled_gradients(self):
        """Test correctness of fixed point and grads using bilevel quadratic problem. Use unrolled gradients."""
        # Define the solver.
        solver = partial(anderson, m=5, threshold=2000, eps=1e-8, make_differentiable=True)

        # Test it, while backpropping through the solver.
        correctness_dicts = test_solver_correctness_using_bilevel_quadratic_problem(solver_fn=solver,
                                                                                    use_backward_solver=False)
        for correctness_results in correctness_dicts:
            for k, v in correctness_results.items():
                self.assertTrue(v, msg=f"Failure to show {k}.")

    def test_compare_with_reference_implementation(self):
        """Check that ``anderson`` reproduces the reference implementation's result for several thresholds."""
        # Define a nontrivial discrete map and a starting point for running the solvers. Prepare the problem.
        min_num_iters = 8
        num_iters = 20
        z0 = torch.tensor([0.5])[None, None]

        def discrete_map(z):
            return 3 * z * (1 - z)

        # Run the reference solver and the implementation used in this codebase with the same iteration budget.
        for i in range(min_num_iters, num_iters):
            reference_z = _reference_anderson(discrete_map, x0=z0, threshold=i)["result"]
            test_z = anderson(discrete_map, x0=z0, threshold=i)["result"]
            # Reduce the elementwise comparison with .all(). assertEqual on lists of tensors relies on bool()
            # of an elementwise comparison, which only works for single-element tensors.
            self.assertTrue(bool((reference_z == test_z).all()),
                            msg=f"Mismatch at threshold={i}: reference={reference_z}, test={test_z}.")

    def test_normalized_l2_norm(self):
        """Check that the normalized L2 norm is computed correctly on torch tensors of varying sizes."""
        a1 = torch.arange(24).float()
        a2 = a1.view(2, 3, 4)
        a3 = a1.view(6, 4)
        correct_l2_norm_a = float(np.linalg.norm(a1.numpy(), ord=2))
        num_dims_a = 24
        correct_normalized_l2_norm_sqrt = correct_l2_norm_a / math.sqrt(num_dims_a)
        correct_normalized_l2_norm_non_sqrt = correct_l2_norm_a / num_dims_a

        test_cases = [a1, a2, a3]
        for a in test_cases:
            # The expected value comes from NumPy and the tested one from torch. Float32 reductions in different
            # libraries may differ in the last bits, so compare with a tolerance instead of exact equality.
            self.assertAlmostEqual(float(normalized_l2_norm(a, normalize_by_sqrt=True)),
                                   correct_normalized_l2_norm_sqrt, places=5)
            self.assertAlmostEqual(float(normalized_l2_norm(a, normalize_by_sqrt=False)),
                                   correct_normalized_l2_norm_non_sqrt, places=5)


if __name__ == "__main__":
    # Run the tests with:
    #     python -m unittest -v src.dl.fixed_point_solvers.test_anderson
    unittest.main()
