import numpy as np

from deepxde import config
from deepxde.utils import get_num_args, run_if_all_none

import tensorflow as tf
import deepxde as dde

# Name of the backend this module targets; it matches the `import tensorflow as tf` above.
backend_name = "tensorflow"

import abc


class Data(abc.ABC):
    """Abstract base class for datasets.

    Concrete subclasses must implement :meth:`sample_test_points`; the class
    cannot be instantiated until they do.
    """

    @abc.abstractmethod
    def sample_test_points(self):
        """Return the test dataset.

        Abstract: every concrete subclass overrides this. The base body is
        empty, so calling it through ``super()`` yields ``None``.
        """


class Tuple(Data):
    """Dataset whose points are stored as (input, output) tuples.

    Each data tuple is split into an input part (x) and an output part (y).
    The complete training and test splits are kept as given and handed back
    unchanged.
    """

    def __init__(self, train_x, train_y, test_x, test_y):
        # Store both splits by reference; nothing is copied or validated.
        self.train_x, self.train_y = train_x, train_y
        self.test_x, self.test_y = test_x, test_y

    def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
        """Return ``loss_fn(targets, outputs)``.

        ``inputs``, ``model`` and ``aux`` are accepted only so the signature
        matches other data classes; they are not used here.
        """
        loss = loss_fn(targets, outputs)
        return loss

    def train_next_batch(self, batch_size=None):
        """Return the entire training split; ``batch_size`` is ignored."""
        batch = (self.train_x, self.train_y)
        return batch

    def sample_test_points(self):
        """Return the entire test split as ``(test_x, test_y)``."""
        test_split = (self.test_x, self.test_y)
        return test_split


class TimePDE(Data):
    """Training/test data for a time-dependent PDE on a space-time geometry.

    Training points are drawn from ``geometryxtime`` in three groups: interior
    ("domain") points, boundary points and initial points (via the geometry's
    ``*_initial_points`` methods). When a reference ``solution`` is given, the
    target values of every group are evaluated from it.

    Args:
        geometryxtime: Instance of ``dde.geometry.GeometryXTime``.
        pde: A global PDE or a list of PDEs. ``None`` if no global PDE. Stored
            as ``self.pde``; not evaluated by this class.
        ic_bcs: An initial/boundary condition or a list of them. Use ``[]`` if
            there are none. A single condition is wrapped in a list.
        num_domain (int): The number of training points sampled inside the domain.
        num_boundary (int): The number of training points sampled on the boundary.
        num_initial (int): The number of training points sampled with the
            geometry's ``uniform_initial_points``/``random_initial_points``.
        train_distribution (string): The distribution to sample training points. One of
            the following: "uniform" (equispaced grid), "pseudo" (pseudorandom), "LHS"
            (Latin hypercube sampling), "Halton" (Halton sequence), "Hammersley"
            (Hammersley sequence), or "Sobol" (Sobol sequence).
        anchors: A Numpy array of additional observed x points. They are
            returned by ``get_data_x_train`` after the initial points.
        exclusions: A Numpy array of points that ``exclude_points`` removes.
            ``None`` means nothing is excluded.
        solution: The reference solution, a callable mapping x points to y
            values; used to compute training and test targets.
        num_test: The number of uniformly spaced interior points used for
            testing. If ``None``, the interior training points are reused.
        auxiliary_var_function: A function that inputs `train_x` or `test_x` and
            outputs auxiliary variables. Stored as ``self.auxiliary_var_fn``;
            not called by this class.

    Raises:
        ValueError: If ``train_distribution`` is not one of the choices above.
    """

    # Sampling schemes accepted for ``train_distribution``.
    _TRAIN_DISTRIBUTIONS = ("uniform", "pseudo", "LHS", "Halton", "Hammersley", "Sobol")

    def __init__(
            self,
            geometryxtime: dde.geometry.GeometryXTime,
            pde,
            ic_bcs,
            num_domain=0,
            num_boundary=0,
            num_initial=0,
            train_distribution="Sobol",
            anchors=None,
            exclusions=None,
            solution=None,
            num_test=None,
            auxiliary_var_function=None, ):
        self.geom = geometryxtime
        self.pde = pde
        self.bcs = ic_bcs if isinstance(ic_bcs, (list, tuple)) else [ic_bcs]

        self.num_domain = num_domain
        self.num_boundary = num_boundary
        self.num_initial = num_initial

        if train_distribution not in self._TRAIN_DISTRIBUTIONS:
            raise ValueError(
                "train_distribution == {} is not available choices.".format(
                    train_distribution
                )
            )
        self.train_distribution = train_distribution
        self.anchors = None if anchors is None else anchors.astype(config.real(np))
        # Observed y values matching ``get_data_x_train``; not filled in by this
        # class, but consulted by ``get_data_y_train`` when there is no
        # ``solution``. Initialized so that lookup never raises AttributeError.
        self.anchors_y = None
        self.exclusions = exclusions

        self.soln = solution
        self.num_test = num_test

        self.auxiliary_var_fn = auxiliary_var_function

        # Placeholders; populated by resample_train()/resample_test() below.
        self.train_x_boundary, self.train_x_domain, self.train_x_initial = None, None, None
        self.train_y_boundary, self.train_y_domain, self.train_y_initial = None, None, None

        self.test_x, self.test_y = None, None
        self.num_bcs = len(self.bcs)
        self.resample_train()
        self.resample_test()

    def get_pde_x_train(self):
        """Return every sampled training point (initial, domain, boundary) as float32."""
        return np.concatenate(
            [self.train_x_initial, self.train_x_domain, self.train_x_boundary], 0
        ).astype(np.float32)

    def get_data_x_train(self):
        """Return the observed x points that are trained on directly.

        These are the initial points from the last ``resample_train`` call,
        followed by ``anchors`` (if any). The cached initial points are reused
        rather than re-sampled. Otherwise random samplers would yield a
        different set on every call, and the x points would no longer match
        the y values from ``get_data_y_train``.

        :return: Array of x points.
        """
        init_points = self.train_x_initial
        if init_points is None:
            # Only reachable if resample_train() has not run yet.
            init_points = self.sample_initial_points()

        if self.anchors is None:
            return init_points
        return np.concatenate([init_points, self.anchors], 0)

    def get_data_y_train(self):
        """Return the y values corresponding to ``get_data_x_train()``.

        :return: ``solution`` evaluated at those points if a solution was
            given, otherwise ``self.anchors_y`` if it has been set.
        :raises NotImplementedError: If neither source of y values exists.
        """
        if self.soln is not None:
            return self.soln(self.get_data_x_train())
        # getattr guards against instances that bypassed __init__; the explicit
        # ``is not None`` avoids ambiguous truth-value errors on numpy arrays.
        anchors_y = getattr(self, "anchors_y", None)
        if anchors_y is not None:
            return anchors_y
        raise NotImplementedError("Found no y values and no solution function.")

    def resample_train(self):
        """Redraw the boundary, domain and initial training points.

        If a solution is available, also recompute their targets.
        """
        self.train_x_boundary = self.sample_boundary_points()
        self.train_x_domain = self.sample_domain_points()
        self.train_x_initial = self.sample_initial_points()

        if self.soln is not None:
            self.train_y_boundary = self.soln(self.train_x_boundary)
            self.train_y_domain = self.soln(self.train_x_domain)
            self.train_y_initial = self.soln(self.train_x_initial)

    def resample_test(self):
        """Refresh ``test_x``/``test_y`` from ``sample_test_points``."""
        self.test_x, self.test_y = self.sample_test_points()

    def sample_test_points(self):
        """Return ``(test_x, test_y)``.

        ``test_x`` is the interior training points when ``num_test`` is
        ``None``; otherwise it is ``num_test`` uniform interior points.
        ``test_y`` is ``None`` when no solution is available.
        """
        if self.num_test is None:
            test_x = self.train_x_domain
        else:
            test_x = self.test_points()

        test_y = self.soln(test_x) if self.soln is not None else None
        return test_x, test_y

    def add_anchors(self, anchors):
        """Prepend ``anchors`` to the stored anchor points.

        Already-sampled training points are not resampled.
        """
        anchors = anchors.astype(config.real(np))
        if self.anchors is None:
            self.anchors = anchors
        else:
            self.anchors = np.vstack((anchors, self.anchors))

    def replace_with_anchors(self, anchors):
        """Replace the stored anchor points with ``anchors``.

        Already-sampled training points are not changed.
        """
        self.anchors = anchors.astype(config.real(np))

    def _empty_points(self):
        """Return an empty ``(0, geom.dim)`` point array in the configured dtype."""
        return np.empty((0, self.geom.dim), dtype=config.real(np))

    def sample_domain_points(self) -> np.ndarray:
        """Sample ``num_domain`` interior points (empty array if ``num_domain`` <= 0)."""
        if self.num_domain <= 0:
            return self._empty_points()
        if self.train_distribution == "uniform":
            return self.geom.uniform_points(self.num_domain, boundary=False)
        return self.geom.random_points(
            self.num_domain, random=self.train_distribution
        )

    def sample_boundary_points(self) -> np.ndarray:
        """
        Sample ``num_boundary`` points on the boundary of the domain.
        :return: Array of boundary points; empty if ``num_boundary`` <= 0.
        """
        if self.num_boundary <= 0:
            return self._empty_points()
        if self.train_distribution == "uniform":
            return self.geom.uniform_boundary_points(self.num_boundary)
        return self.geom.random_boundary_points(
            self.num_boundary, random=self.train_distribution
        )

    def is_not_excluded(self, x):
        """Return True unless ``x`` is allclose to one of ``self.exclusions``."""
        if self.exclusions is None:
            return True
        return not any(np.allclose(x, y) for y in self.exclusions)

    def exclude_points(self, X):
        """Return a copy of ``X`` without the rows matching ``self.exclusions``.

        Boolean-mask indexing keeps the trailing dimensions, so the result has
        shape ``(0, dim)`` rather than ``(0,)`` when every row is excluded.
        """
        X = np.array(X)
        if self.exclusions is None:
            return X
        keep = np.array([self.is_not_excluded(x) for x in X], dtype=bool)
        return X[keep]

    def test_points(self) -> np.ndarray:
        """Return ``num_test`` uniformly spaced interior points for testing."""
        return self.geom.uniform_points(self.num_test, boundary=False)

    def sample_initial_points(self) -> np.ndarray:
        """Sample ``num_initial`` initial points (empty array if ``num_initial`` <= 0)."""
        if self.num_initial <= 0:
            return self._empty_points()
        if self.train_distribution == "uniform":
            return self.geom.uniform_initial_points(self.num_initial)
        return self.geom.random_initial_points(
            self.num_initial, random=self.train_distribution
        )
