import abc
import unittest

import torch
from copy import deepcopy

import nflows.flows.base

import nfmc_jax.sinf.SINF
from nfmc_jax.flows.base import (
    SINFInterface,
    SNFInterface,
    MAFInterface,
    RealNVPInterface,
    RQNSFInterface,
    HierarchicalFlowInterface,
    TRENFInterface,
    HierarchicalSINFInterface,
    SimplifiedSNFInterface
)
from nfmc_jax.utils.torch_distributions import Funnel, TestDistribution1


class InterfaceTestCases:
    class InterfaceTestCase(unittest.TestCase, abc.ABC):
        n_dim: int = 2
        n_train_samples: int = 100
        device = torch.device('cpu')
        atol: float = 1e-4

        @abc.abstractmethod
        def test_create_flow(self):
            """
            Can we initialize and create the flow?
            """
            raise NotImplementedError

        def test_train_flow(self):
            """
            Can we train the flow?
            """
            torch.manual_seed(0)
            interface = self.make_interface_and_create_flow()
            x, weights = self.make_train_data()[:2]

            interface.train_flow(x=x, n_epochs=3)
            interface.train_flow(x=x, n_epochs=3, weights=weights)

        def test_forward_inverse(self):
            """
            Push data to latent space and back.
            We check whether:
            * the shapes in data and latent space are consistent,
            * the original data points are approximately equal to reconstructed data points,
            * the log determinant of the Jacobian for the forward pass is equal to the negative log determinant of the
                Jacobian for the inverse pass.
            """
            torch.manual_seed(0)
            interface = self.make_interface_and_create_flow()
            x, weights = self.make_train_data()[:2]

            z = interface.forward(x)
            x_reconstructed = interface.inverse(z)
            self.assertEqual(
                x.shape, z.shape,
                msg=f"x.shape = {x.shape}, z.shape = {z.shape}"
            )
            self.assertEqual(
                x.shape, x_reconstructed.shape,
                msg=f"x.shape = {x.shape}, x_reconstructed.shape = {x_reconstructed.shape}"
            )
            self.assertTrue(
                torch.allclose(x.to(x_reconstructed.device), x_reconstructed, atol=self.atol),
                msg=f"Max reconstruction error: {torch.max(torch.abs(x.to(x_reconstructed.device) - x_reconstructed))}"
                    f"Mean reconstruction error: {torch.mean(torch.abs(x.to(x_reconstructed.device) - x_reconstructed))}"
            )

            logj_forward = interface.logj_forward(x)
            logj_backward = interface.logj_backward(z)
            self.assertTrue(
                torch.allclose(logj_forward, -logj_backward, atol=self.atol),
                msg=f"Max difference: {torch.max(logj_forward - (-logj_backward))} | "
                    f"Mean difference: {torch.mean(logj_forward - (-logj_backward))}"
            )

        def test_grad(self):
            """
            We check whether grad_x_logq, grad_z_logp, grad_z_logj methods run.
            We check whether .grad and .requires_grad attributes are appropriately set.
            We check whether gradient shapes are consistent with data and latent point shapes.
            """
            torch.manual_seed(0)
            interface = self.make_interface_and_create_flow()
            x, weights = self.make_train_data()[:2]

            z = interface.forward(x)

            grad_x_logq = interface.grad_x_logq(x)
            self.assertIsNotNone(grad_x_logq)
            self.assertFalse(grad_x_logq.requires_grad)
            self.assertFalse(x.requires_grad)
            self.assertEqual(grad_x_logq.shape, x.shape)

            grad_z_logp = interface.grad_z_logp(z, grad_x_logq)
            self.assertIsNotNone(grad_z_logp)
            self.assertFalse(grad_z_logp.requires_grad)
            self.assertFalse(z.requires_grad)
            self.assertEqual(grad_z_logp.shape, z.shape)

            grad_z_logj = interface.grad_z_logj(z)
            self.assertIsNotNone(grad_z_logj)
            self.assertFalse(grad_z_logj.requires_grad)
            self.assertFalse(z.requires_grad)
            self.assertEqual(grad_z_logj.shape, z.shape)

        @abc.abstractmethod
        def make_interface_and_create_flow(self, *args, **kwargs):
            """
            Abstract helper method to create the interface and the flow. To be used in tests.
            """
            raise NotImplementedError

        def make_train_data(self):
            """
            Abstract helper method to create training data. Defaults to a Gaussian, but can be overwritten by
                subclasses. To be used in tests.
            """
            x = torch.randn(self.n_train_samples, self.n_dim) * 3 + 4
            weights = torch.rand(self.n_train_samples)
            return x, weights

        def test_no_modify_data(self):
            """
            Check that the flow does not modify data when doing forward and inverse operations.
            """
            torch.manual_seed(0)

            interface = self.make_interface_and_create_flow()
            x, weights = self.make_train_data()[:2]
            x_original = deepcopy(x)

            interface.forward(x)
            self.assertTrue(torch.equal(x, x_original))

            interface.inverse(interface.forward(x))
            self.assertTrue(torch.equal(x, x_original))

        def test_shape(self):
            """
            Another test for shape consistency among data points, latent points, and gradients.
            """
            torch.manual_seed(0)

            interface = self.make_interface_and_create_flow()
            x, weights = self.make_train_data()[:2]

            # Test shapes before training
            z = interface.forward(x)
            self.assertEqual(x.shape, z.shape)

            x_reconstructed = interface.inverse(z)
            self.assertEqual(x_reconstructed.shape, x.shape)

            grad_x_logq = interface.grad_x_logq(x)
            self.assertEqual(x.shape, grad_x_logq.shape)

            grad_z_logp = interface.grad_z_logp(z, grad_x_logq)
            self.assertEqual(z.shape, grad_z_logp.shape)

            grad_z_logj = interface.grad_z_logj(z)
            self.assertEqual(z.shape, grad_z_logj.shape)

            # Train the flow
            interface.train_flow(x=x, n_epochs=3)

            # Test shapes after training
            z = interface.forward(x)
            self.assertEqual(x.shape, z.shape)

            x_reconstructed = interface.inverse(z)
            self.assertEqual(x_reconstructed.shape, x.shape)

            grad_x_logq = interface.grad_x_logq(x)
            self.assertEqual(x.shape, grad_x_logq.shape)

            grad_z_logp = interface.grad_z_logp(z, grad_x_logq)
            self.assertEqual(z.shape, grad_z_logp.shape)

            grad_z_logj = interface.grad_z_logj(z)
            self.assertEqual(z.shape, grad_z_logj.shape)


class MAFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``MAFInterface``."""

    def test_create_flow(self):
        """The flow stays unset until ``create_flow`` builds an nflows ``Flow``."""
        maf = MAFInterface(n_dim=self.n_dim, device=self.device)
        self.assertIsNone(maf.flow)

        maf.create_flow()
        self.assertIsNotNone(maf.flow)
        self.assertIs(type(maf.flow), nflows.flows.base.Flow)

    def make_interface_and_create_flow(self):
        """Return a freshly seeded ``MAFInterface`` whose flow has already been created."""
        torch.manual_seed(0)
        maf = MAFInterface(n_dim=self.n_dim, device=self.device)
        maf.create_flow()
        return maf


class SINFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``SINFInterface``."""

    def test_create_flow(self):
        """
        Creating the flow from data, with and without sample weights, yields a SINF model.
        """
        # Seed like the sibling test cases so the generated training data is reproducible.
        torch.manual_seed(0)
        x, weights = self.make_train_data()

        for weight_kwargs in ({}, {'weights': weights}):
            with self.subTest(weighted=bool(weight_kwargs)):
                interface = SINFInterface(device=self.device)
                self.assertIsNone(interface.flow)
                interface.create_flow(x=x, iteration=3, **weight_kwargs)
                self.assertIsNotNone(interface.flow)
                self.assertIs(type(interface.flow), nfmc_jax.sinf.SINF.SINF)

    def make_interface_and_create_flow(self):
        """Return a seeded ``SINFInterface`` whose flow was created from freshly generated training data."""
        torch.manual_seed(0)
        x = self.make_train_data()[0]
        interface = SINFInterface(device=self.device)
        interface.create_flow(x=x, iteration=3)
        return interface

    def test_train_flow(self):
        """
        Overrides the shared version: training is invoked without ``n_epochs``, with and without weights.
        """
        torch.manual_seed(0)
        interface = self.make_interface_and_create_flow()
        x, weights = self.make_train_data()

        interface.train_flow(x=x)
        interface.train_flow(x=x, weights=weights)

    def test_shape(self):
        """
        The shared shape test trains with ``n_epochs=3``, which this case never passes to ``train_flow``.
        Report it as skipped rather than silently passing.
        """
        self.skipTest("Base test_shape calls train_flow(n_epochs=...), which is not used for SINFInterface.")


class SNFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``SNFInterface``."""

    def _check_create_flow(self, **extra_kwargs):
        """
        Seed, generate data, then build two fresh interfaces: one without and one with sample weights.

        Each must start without a flow and end up holding a SINF model. ``extra_kwargs`` are forwarded to
        ``create_flow`` in both cases.
        """
        torch.manual_seed(0)
        x, weights = self.make_train_data()

        for weight_kwargs in ({}, {'weights': weights}):
            snf = SNFInterface(device=self.device)
            self.assertIsNone(snf.flow)
            snf.create_flow(x=x, iteration=3, **weight_kwargs, **extra_kwargs)
            self.assertIsNotNone(snf.flow)
            self.assertIs(type(snf.flow), nfmc_jax.sinf.SINF.SINF)

    def test_create_flow(self):
        self._check_create_flow()

    def test_create_flow_random_init(self):
        self._check_create_flow(random_init=True)

    def make_interface_and_create_flow(self):
        """Return a seeded ``SNFInterface`` whose flow was created from freshly generated training data."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()
        snf = SNFInterface(device=self.device)
        snf.create_flow(x=x_train, iteration=3)
        return snf


class HierarchicalSINFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``HierarchicalSINFInterface`` with a random grouping and DAG."""

    # How many random (parent, child) pairs to draw when building the DAG edges.
    n_dag_edge_draws: int = 5

    def make_train_data(self):
        """
        Create training data plus a random variable mask and DAG edges.

        The mask assigns each of the ``n_dim`` dimensions to a group id. Dimension 0 is always in group 0, and the
        ids are relabelled to be contiguous (0, 1, ..., n_groups - 1). Edges are drawn only between existing groups,
        never from a group to itself, and always point from the lower to the higher group id. The edge list is
        therefore acyclic.

        :return: tuple ``(x, weights, rv_mask, dag_edges)``.
        """
        x = torch.randn(self.n_train_samples, self.n_dim) * 3 + 4
        weights = torch.rand(self.n_train_samples)

        rv_mask = torch.randint(low=0, high=self.n_dim, size=(self.n_dim,))
        rv_mask[0] = 0
        # Relabel ids to be contiguous; 0 is the smallest id present, so dimension 0 stays in group 0.
        rv_mask = torch.unique(rv_mask, return_inverse=True)[1]
        n_groups = int(rv_mask.max()) + 1

        dag_edges = set()
        for _ in range(self.n_dag_edge_draws):
            a, b = (int(v) for v in torch.randint(low=0, high=n_groups, size=(2,)))
            if a != b:  # self-loops are not allowed in a DAG
                dag_edges.add((min(a, b), max(a, b)))
        return x, weights, rv_mask, sorted(dag_edges)

    def test_create_flow(self):
        """
        ``create_flow`` runs with and without sample weights; for this interface ``flow`` is asserted to stay None.
        """
        torch.manual_seed(0)
        x, weights, rv_mask, dag_edges = self.make_train_data()

        for weight_kwargs in ({}, {'weights': weights}):
            with self.subTest(weighted=bool(weight_kwargs)):
                interface = HierarchicalSINFInterface(rv_mask=rv_mask, dag_edges=dag_edges, device=self.device)
                self.assertIsNone(interface.flow)
                interface.create_flow(x=x, iteration=3, **weight_kwargs)
                self.assertIsNone(interface.flow)

    def make_interface_and_create_flow(self):
        """Return a seeded ``HierarchicalSINFInterface`` whose flow was created from generated training data."""
        torch.manual_seed(0)
        x, _, rv_mask, dag_edges = self.make_train_data()
        interface = HierarchicalSINFInterface(rv_mask=rv_mask, dag_edges=dag_edges, device=self.device)
        interface.create_flow(x=x, iteration=3)
        return interface


class SimplifiedSNFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``SimplifiedSNFInterface``."""

    def test_create_flow(self):
        """Creating the flow, unweighted and then weighted, yields a SINF model each time."""
        torch.manual_seed(0)
        x_train, sample_weights = self.make_train_data()

        for create_kwargs in ({}, {'weights': sample_weights}):
            simplified = SimplifiedSNFInterface(device=self.device)
            self.assertIsNone(simplified.flow)
            simplified.create_flow(x=x_train, iteration=3, **create_kwargs)
            self.assertIsNotNone(simplified.flow)
            self.assertIs(type(simplified.flow), nfmc_jax.sinf.SINF.SINF)

    def make_interface_and_create_flow(self):
        """Return a seeded ``SimplifiedSNFInterface`` whose flow was created from generated training data."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()
        simplified = SimplifiedSNFInterface(device=self.device)
        simplified.create_flow(x=x_train, iteration=3)
        return simplified


class RealNVPTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``RealNVPInterface``."""

    def test_create_flow(self):
        """``create_flow`` turns an empty interface into one holding an nflows ``SimpleRealNVP``."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()

        real_nvp = RealNVPInterface(n_dim=self.n_dim, device=self.device)
        self.assertIsNone(real_nvp.flow)
        real_nvp.create_flow(x_train=x_train, iteration=3)
        self.assertIsNotNone(real_nvp.flow)
        self.assertIs(type(real_nvp.flow), nflows.flows.realnvp.SimpleRealNVP)

    def make_interface_and_create_flow(self):
        """Return a seeded ``RealNVPInterface`` whose flow has already been created."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()
        real_nvp = RealNVPInterface(n_dim=self.n_dim, device=self.device)
        real_nvp.create_flow(x_train=x_train, iteration=3)
        return real_nvp


class RQNSFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Runs the shared interface tests against ``RQNSFInterface``."""

    def test_create_flow(self):
        """``create_flow`` turns an empty interface into one holding an nflows ``Flow``."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()

        rqnsf = RQNSFInterface(n_dim=self.n_dim)
        self.assertIsNone(rqnsf.flow)
        rqnsf.create_flow(x_train=x_train, iteration=3)
        self.assertIsNotNone(rqnsf.flow)
        self.assertIs(type(rqnsf.flow), nflows.flows.base.Flow)

    def make_interface_and_create_flow(self):
        """Return a seeded ``RQNSFInterface`` whose flow has already been created."""
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()
        rqnsf = RQNSFInterface(n_dim=self.n_dim)
        # TODO get rid of the `iteration` kwarg, it is meant for SINF.
        rqnsf.create_flow(x_train=x_train, iteration=3)
        return rqnsf


@unittest.skip("Need to implement make_interface_and_create_flow")
class TRENFTestCase(InterfaceTestCases.InterfaceTestCase):
    """Placeholder tests for ``TRENFInterface``; skipped until the interface setup below is completed."""

    def test_create_flow(self):
        pass

    def make_interface_and_create_flow(self, *args, **kwargs):
        torch.manual_seed(0)
        x_train, _ = self.make_train_data()
        trenf = TRENFInterface()
        # TODO finish this: `layers` is still an Ellipsis placeholder.
        trenf.create_flow(ndim=x_train.shape[1], layers=...)
        return trenf


class HierarchicalFlowTestCases:
    """
    Container for the shared hierarchical-flow test case.

    Nesting the base class inside a plain class keeps unittest's module-level discovery from collecting it directly.
    """

    class HierarchicalFlowTestCase(InterfaceTestCases.InterfaceTestCase, abc.ABC):
        """
        Shared tests for ``HierarchicalFlowInterface``. Each concrete case supplies data, a variable mask, DAG edges
        and one sub-interface per group through ``prepare_case_data``.
        """

        @abc.abstractmethod
        def prepare_case_data(self):
            # Return x_train, mask, edges, interfaces
            raise NotImplementedError

        def test_create_flow(self):
            """Building the hierarchical interface and calling ``create_flow`` must not raise."""
            self.make_interface_and_create_flow()

        def make_interface_and_create_flow(self, *args, **kwargs):
            """Assemble a ``HierarchicalFlowInterface`` from ``prepare_case_data`` and create its flow."""
            _, rv_mask, dag_edges, sub_interfaces = self.prepare_case_data()
            hierarchical = HierarchicalFlowInterface(
                rv_mask=rv_mask,
                dag_edges=dag_edges,
                interfaces=sub_interfaces
            )
            hierarchical.create_flow()
            return hierarchical

        def make_train_data(self):
            """Return freshly prepared case data ``x_train`` together with unit weights."""
            x_train = self.prepare_case_data()[0]
            return x_train, torch.ones(len(x_train))


class HierarchicalFunnelTestCase(HierarchicalFlowTestCases.HierarchicalFlowTestCase):
    """Hierarchical case on ``Funnel(4)`` samples: dimension 0 is group 0, dimensions 1-3 are group 1."""

    def prepare_case_data(self):
        x_train = Funnel(4).sample(self.n_train_samples)
        mask = torch.tensor([0, 1, 1, 1])
        edges = [(0, 1)]
        interfaces = [MAFInterface(n_dim=1), MAFInterface(n_dim=3)]
        for sub_interface in interfaces:
            sub_interface.create_flow()
        return x_train, mask, edges, interfaces


class HierarchicalGaussianTestCase1(HierarchicalFlowTestCases.HierarchicalFlowTestCase):
    """Six-dimensional standard Gaussian split into three independent 2-D groups (MAF, RealNVP, RQ-NSF)."""

    def prepare_case_data(self):
        x_train = torch.randn(self.n_train_samples, 6)
        mask = torch.tensor([0, 0, 1, 1, 2, 2])
        edges = []
        interfaces = [
            MAFInterface(n_dim=2),
            RealNVPInterface(n_dim=2),
            RQNSFInterface(n_dim=2)
        ]
        for sub_interface in interfaces:
            sub_interface.create_flow()
        return x_train, mask, edges, interfaces


class HierarchicalGaussianTestCase2(HierarchicalFlowTestCases.HierarchicalFlowTestCase):
    """Six-dimensional standard Gaussian with every dimension in its own 1-D group and no edges."""

    def prepare_case_data(self):
        x_train = torch.randn(self.n_train_samples, 6)
        mask = torch.tensor([0, 1, 2, 3, 4, 5])
        edges = []
        interface_types = (
            MAFInterface,
            RealNVPInterface,
            RealNVPInterface,
            MAFInterface,
            RealNVPInterface,
            RealNVPInterface,
        )
        interfaces = [interface_type(n_dim=1) for interface_type in interface_types]
        for sub_interface in interfaces:
            sub_interface.create_flow()
        return x_train, mask, edges, interfaces


class HierarchicalGaussianTestCase3(HierarchicalFlowTestCases.HierarchicalFlowTestCase):
    """
    Eight-dimensional standard Gaussian in five groups with mixed interface types.

    The SNF groups (first and last) are built from their data slices; the others are created without data.
    """

    def prepare_case_data(self):
        x_train = torch.randn(self.n_train_samples, 8)
        mask = torch.tensor([0, 1, 1, 2, 3, 3, 4, 4])
        edges = []
        interfaces = [
            SNFInterface(),
            RQNSFInterface(n_dim=2),
            RealNVPInterface(n_dim=1),
            MAFInterface(n_dim=2),
            SNFInterface()
        ]
        first_snf, *data_free, last_snf = interfaces
        first_snf.create_flow(x=x_train[:, 0].reshape(-1, 1), iteration=5)  # group 0 -> column 0
        for sub_interface in data_free:
            sub_interface.create_flow()
        last_snf.create_flow(x=x_train[:, -2:].reshape(-1, 2), iteration=5)  # group 4 -> last two columns
        return x_train, mask, edges, interfaces


class HierarchicalTestDistribution1TestCase(HierarchicalFlowTestCases.HierarchicalFlowTestCase):
    """
    Hierarchical case on ``TestDistribution1`` samples: four 1-D groups connected by five DAG edges.

    The SNF groups (0 and 3) are built from their data columns; the others are created without data.
    """

    def prepare_case_data(self):
        x_train = TestDistribution1().sample(self.n_train_samples)
        mask = torch.tensor([0, 1, 2, 3])
        edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]
        interfaces = [
            SNFInterface(),
            RealNVPInterface(n_dim=1),
            MAFInterface(n_dim=1),
            SNFInterface()
        ]
        first_snf, *data_free, last_snf = interfaces
        first_snf.create_flow(x=x_train[:, 0].reshape(-1, 1), iteration=5)
        for sub_interface in data_free:
            sub_interface.create_flow()
        last_snf.create_flow(x=x_train[:, 3].reshape(-1, 1), iteration=5)
        return x_train, mask, edges, interfaces


# Allow running this module directly as a script, in addition to discovery by a test runner.
if __name__ == '__main__':
    unittest.main()
