"""Contains the DiffEqs used for experiments as torch datasets
"""
import time
import torch
import numpy as np

from scipy import io
from pyDOE import lhs
from abc import abstractmethod
from scipy.integrate import solve_ivp
from error_correction.utils import logger
from .diffoperators import gradient, laplace
from torch.utils.data import Dataset, DataLoader


class DifferentialEquation:
    """Common interface shared by every differential-equation dataset.

    Concrete equations override whichever hooks they need: sampling of
    initial/boundary points, the corresponding error terms, and the residual
    operator F[]. This class does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers are informational only. Instantiation is never
    blocked, and a hook that was not overridden raises ``NotImplementedError``
    when it is called.
    """

    def __init__(self) -> None:
        # Subclasses switch these on when they supply ICs / BCs.
        self.has_ics = False
        self.has_bcs = False

    @abstractmethod
    def sample_init(self, size):
        """Draw ``size`` points at which the initial condition is enforced."""
        raise NotImplementedError

    @abstractmethod
    def initial_condition(self, x, ic_estimate):
        """Return the deviation of ``ic_estimate`` from the prescribed initial condition."""
        raise NotImplementedError

    @abstractmethod
    def sample_boundary(self, size):
        """Draw ``size`` points at which the boundary condition is enforced."""
        raise NotImplementedError

    @abstractmethod
    def boundary_condition(self, x, bnd_estimate):
        """Return the deviation of ``bnd_estimate`` from the exact boundary condition."""
        raise NotImplementedError

    @abstractmethod
    def operator(self, x, u):
        """Evaluate the differential operator F[] of the equation on estimate ``u``."""
        raise NotImplementedError

    def dataloader(self, **kwargs):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        Every keyword argument is forwarded unchanged to ``DataLoader``.
        """
        loader = DataLoader(self, **kwargs)
        return loader


class NonlinearPBE(DifferentialEquation, Dataset):
    r"""Implements a form of the nonlinear Poisson-Boltzmann equation with operator:

        F[] = -laplacian[] + sinh[] + ff

    where ff is the forcing function that makes F[\phi] = 0 for the manufactured
    solution \phi(x_1, ..., x_D) = sin(w*x_1) * ... * sin(w*x_D)
    (see ``phi`` and ``forcing_fn``).

    In ``train`` mode every ``__getitem__`` draws a fresh uniform point in the
    hypercube ``[lims[0], lims[1]]^dims``. In ``test`` mode a fixed set is
    built: an N x N visualization grid followed by ``n_test_points`` random
    points, together with the exact solution at each of them.
    """
    def __init__(self, hparams, mode="train"):
        super().__init__()
        self.hparams = hparams
        self.mode = mode
        self.dims = hparams.data.dims
        self.slice_dims = hparams.data.slice_dims  # for visualization purposes
        self.N = hparams.data.resolution  # for testing & visualization purposes
        # NOTE(review): eval() on config strings -- only safe for trusted configs.
        self.lims = [eval(n) for n in hparams.data.lims]  # domain limits
        self.omega = hparams.data.omega  # solution frequency parameter
        self.n_test_points = hparams.data.n_test_points

        # must set these attributes
        self.has_bcs = hparams.data.has_bcs
        self.has_ics = hparams.data.has_ics
        hparams.model.dim.input = self.dims
        hparams.model.dim.output = 1

        # generates fixed test set. additionally concatenates
        # a visualization set on a regular grid
        if mode == "test":
            self.domain_lims = [self.lims, self.lims]  # for visualization
            self.domain_shape = (self.N, self.N)

            # trick: stack visualization slice onto randomly generated 'real' test data
            visualization_data = self.slice_data_2D()
            self.visualization_sol = self.phi(visualization_data)
            rand_data = torch.rand(self.n_test_points, self.dims)
            rand_data = (self.lims[1] - self.lims[0]) * rand_data + self.lims[0]
            self.data = torch.cat([visualization_data, rand_data])
            self.sol_data = self.phi(self.data)

            # for visualization
            self.xlabel = f"$x_{self.slice_dims[0]}$"
            self.ylabel = f"$x_{self.slice_dims[1]}$"

    def slice_data_2D(self):
        """Get data on a uniformly spaced N x N grid with only 2 dimensions populated.

        The two coordinates listed in ``slice_dims`` sweep ``lims``. Every other
        coordinate is held at pi/2, which puts those factors of ``phi`` at 1
        when ``omega == 1``.

        Returns:
            Tensor of shape ``(N * N, dims)``.

        NOTE: FOR VISUALIZATION PURPOSES ONLY
        """
        xs = torch.linspace(*self.lims, self.N)
        ys = torch.linspace(*self.lims, self.N)
        Y, X = torch.meshgrid(xs, ys, indexing='ij')
        points = torch.stack([X, Y], dim=-1).reshape(-1, 2)

        data = torch.ones(self.N * self.N, self.dims) * (np.pi / 2)
        data[..., self.slice_dims] = points

        return data

    def phi(self, x):
        """Implements the solution f(x_1, ..., x_D) = sin(w*x_1) * .... * sin(w*x_D).

        The product runs over the last axis of ``x``.
        """
        return torch.sin(self.omega * x).prod(dim=-1)

    def forcing_fn(self, x):
        """Implements the forcing function ff = -sinh(phi) - D * w^2 * phi.

        For the product-of-sines ``phi``, laplacian(phi) = -D * w^2 * phi, so
        -laplacian(phi) + sinh(phi) + ff vanishes identically.
        """
        phi = self.phi(x)
        return -torch.sinh(phi) - self.dims * (self.omega**2) * phi

    def operator(self, x, phi_hat):
        """Applies the nPBE operator -laplacian[] + sinh[] + ff to a differentiable estimate.

        Args:
            x: input points that ``phi_hat`` is differentiable with respect to
                (presumably ``(B, dims)`` with ``requires_grad`` -- confirm at caller).
            phi_hat: model estimate of the solution, either ``(B,)`` or ``(B, 1)``.

        Returns:
            The pointwise residual, one value per input point.
        """
        laplacian = laplace(phi_hat, x).squeeze()
        # The laplacian and forcing terms are one value per point. Drop the
        # trailing singleton of a (B, 1) estimate so sinh(phi_hat) lines up
        # with them instead of broadcasting the sum to a (B, B) matrix.
        phi_flat = phi_hat.squeeze(-1) if phi_hat.dim() > 1 else phi_hat
        return -laplacian + torch.sinh(phi_flat) + self.forcing_fn(x)

    def sample_boundary(self, size):
        """Samples points along the faces of the hypercube domain.

        Points are drawn uniformly in the domain and then split into ``dims``
        contiguous groups. In group ``i``, the first half is pinned to
        ``lims[0]`` along coordinate ``i`` and the second half to ``lims[1]``.
        The last group absorbs the remainder when ``size`` is not divisible by
        ``dims``. When ``size < dims``, each point is pinned to the upper face
        of a different coordinate.

        Args:
            size: number of boundary points to return.

        Returns:
            Tensor of shape ``(size, dims)``.
        """
        points = (self.lims[1]-self.lims[0]) * torch.rand(size, self.dims) + self.lims[0]
        # Clamp to 1: a zero step (size < dims) made range() raise ValueError.
        step = max(size // self.dims, 1)
        for i, j in zip(range(self.dims), range(0, size, step)):
            dimsize = step if i < self.dims - 1 else size - j
            points[j: j + dimsize // 2, i] = self.lims[0]
            points[j + dimsize // 2: j + dimsize, i] = self.lims[1]
        return points

    def boundary_condition(self, _, bnd_estimate):
        """Computes solution estimate deviation from the boundary conditions
        """
        # for the nPBE, zero-valued homogenous BCs => epsilon = (phi_boundary - 0)
        return bnd_estimate

    def __len__(self):
        """Programmatically define the length as this a randomly sampled dataset
        """
        return (
            self.hparams.batch_size if self.mode == 'train'  # 1 epoch == 1 randomly generated batch
            else self.N * self.N + self.n_test_points
        )

    def __getitem__(self, idx):
        """Randomly sample a point from the domain (train) or return a fixed test pair.
        """
        # The random draw also happens in test mode, which keeps the global
        # RNG stream identical to previous behavior.
        coord = (self.lims[1] - self.lims[0]) * torch.rand(self.dims) + self.lims[0]
        x = coord.float()

        if self.mode == 'test':
            return self.data[idx].float(), self.sol_data[idx].float()

        return x


class HenonHeiles(DifferentialEquation, Dataset):
    r"""Henon-Heiles chaotic Hamiltonian ODE.

    The state \vector{\phi} = [x y px py] evolves as

        \dot{\vector{\phi}} = J @ \grad{H}

    with Hamiltonian

        H = 1/2 * (px^2 + py^2) + 1/2 * (x^2 + y^2) + \lambda * (x^2 * y - y^3/3)

    and symplectic matrix

        J = [[ 0  0  1  0]
             [ 0  0  0  1]
             [-1  0  0  0]
             [ 0 -1  0  0]]

    In ``train`` mode every ``__getitem__`` draws a fresh uniform point in
    ``domain``. In ``test`` mode a reference trajectory is precomputed with
    ``scipy.integrate.solve_ivp``.
    """
    def __init__(self, hparams, mode="train"):
        super().__init__()
        self.hparams = hparams
        self.mode = mode
        cfg = hparams.data
        self.dims = cfg.dims
        # NOTE(review): eval() on config strings -- only safe for trusted configs.
        self.domain = [eval(n) for n in cfg.domain]
        self.lamb = cfg.L
        self.tol = cfg.tol
        self.eband = cfg.energy_band
        self.viz_coords = cfg.get('viz_coords', [0, 1])  # for visualization

        # attributes the surrounding framework expects to be set
        self.has_bcs = cfg.has_bcs
        self.has_ics = cfg.has_ics
        hparams.model.dim.input = self.dims
        hparams.model.dim.output = 4

        ic = np.array(cfg.init_cond or self.generate_ic())
        if self.mode == 'train':
            logger.info(f"[{self.hparams.data.name.upper()}] initial condition: {ic} | Hamiltonian = {self.hamiltonian(*ic):.4f}")
        self.initial_cond = torch.from_numpy(ic)

        if mode == "test":
            labels = ["$x$", "$y$", "$p_x$", "$p_y$"]
            self.xlabel, self.ylabel = [labels[k] for k in self.viz_coords]
            self.time, trajectory = self.solve(self.initial_cond.clone())
            self.domain_shape = self.time.shape
            self.domain_lims = self.domain
            self.n_test_points = len(self.time)
            self.data = torch.from_numpy(self.time).unsqueeze(-1)
            self.sol_data = torch.from_numpy(trajectory)
        elif mode != "train":
            raise NotImplementedError
        # train mode: points are generated on the fly in __getitem__

    def generate_ic(self):
        """Sample an initial condition whose orbit stays bounded.

        Some phase-space starting points produce orbits that escape to infinity
        in finite time. When the Hamiltonian is below 1/6 and (x, y) lies inside
        the corresponding equipotential triangle, the Henon-Heiles motion is
        bounded. Candidates are drawn from a standard normal (seeded by
        ``hparams.random_seed``) until one lies in that triangle and has an
        energy strictly inside ``energy_band``.

        See: https://jfuchs.hotell.kau.se/kurs/amek/prst/11_hehe.pdf
        """
        if self.mode == "train":
            logger.info(f"[{self.hparams.data.name.upper()}] Sampling valid initial conditions...")

        def accepted(state):
            energy = self.hamiltonian(*state)
            if not (energy > self.eband[0] and energy < self.eband[1]):  # chaotic energy regime
                return False
            x_abs = np.abs(state[0])
            if not (x_abs < np.sqrt(3)/2):  # x inside the 1/6 energy triangle
                return False
            # y inside the 1/6 energy triangle
            return -0.5 < state[1] and state[1] < (1 - np.sqrt(3)*x_abs)

        rng = np.random.default_rng(self.hparams.random_seed)
        n_coords = self.hparams.model.dim.output
        ic = np.ones(n_coords)
        while not accepted(ic):
            ic = rng.standard_normal(n_coords)
        return ic

    def hamiltonian(self, x, y, px, py):
        """Henon-Heiles Hamiltonian H(x, y, px, py)."""
        quadratic = x*x + y*y + px*px + py*py
        cubic = x*x*y - (y**3)/3
        return 0.5 * quadratic + self.lamb * cubic

    def solve(self, init_cond):
        """Compute the ground-truth trajectory with ``scipy.integrate.solve_ivp``.

        Returns:
            ``(t, y)`` where ``t`` holds the solver's time points and ``y`` has
            one row per time point (the transpose of scipy's ``sol.y``).

        Raises:
            Exception: if the solver reports failure.
        """
        solver = self.hparams.data.gt_solver
        tag = self.hparams.data.name.upper()
        t0 = time.time()
        result = solve_ivp(
            self.phi_dot, self.domain, init_cond, rtol=self.tol,
            method=solver, max_step=self.hparams.data.max_step
        )
        elapsed = time.time() - t0

        if not result.success:
            raise Exception(
                f"[{tag}] solve_ivp with method '{solver}'"
                f" failed with initial condition {init_cond.numpy()} |"
                f" Hamiltonian(IC) = {self.hamiltonian(*init_cond)} |"
                " Please try running again with a different seed."
            )
        logger.info(f"[{tag}] solve_ivp ({solver}) time: {elapsed:.4f}s")

        return result.t, result.y.T

    def sample_init(self, size):
        """Return ``size`` copies of t = 0, shaped ``(size, 1)``."""
        return torch.zeros((size, 1))

    def initial_condition(self, _, ic_estimate):
        """Return the difference between ``ic_estimate`` and the stored initial condition."""
        return ic_estimate - self.initial_cond

    def phi_dot(self, _, phi):
        """Exact time derivative of the state, J @ grad(H).

        The first argument is the time, unused but required by the scipy
        solver's callback signature. A batched (2-D) ``phi`` yields a stacked
        torch tensor; a single state yields a plain list.
        """
        batched = len(phi.shape) > 1
        x, y, px, py = phi.T if batched else phi
        lam = self.lamb
        components = [px, py, -x - 2*lam*x*y, -y - lam*(x**2 - y**2)]
        if batched:
            return torch.stack(components, dim=-1)
        return components

    def operator(self, t, phi):
        """Residual between the model's time derivative (via autograd) and the exact one."""
        learned = [gradient(component, t).squeeze() for component in phi.T]
        model_phi_dot = torch.stack(learned, dim=-1)
        return model_phi_dot - self.phi_dot(None, phi)

    def reparameterize(self, phi, t, error=False):
        """Map the raw output to output(t) * (1-exp(-t)) + initial_condition.

        With ``error=True`` the initial-condition offset is omitted.
        """
        ramp = 1 - torch.exp(-t)
        offset = (not error) * self.initial_cond.to(phi.device)
        return ramp * phi + offset

    def __len__(self):
        """Dataset length: one batch per epoch in train mode, the trajectory length in test mode."""
        if self.mode == 'train':
            return self.hparams.batch_size  # 1 epoch == 1 randomly generated batch
        return self.n_test_points

    def __getitem__(self, index):
        # A random draw is made in every mode so the global RNG stream is the
        # same regardless of which branch returns.
        span = self.domain[1] - self.domain[0]
        x = span * torch.rand(self.dims) + self.domain[0]

        if self.mode != 'test':
            return x.float()
        return self.data[index].float(), self.sol_data[index].float()


class NonlinearOscillator(DifferentialEquation, Dataset):
    r"""Implements a simple Hamiltonian system - a non-linear oscillator ODE

        \dot{\vector{\phi}} = J @ \grad{H}

        H = 1/2 * (x^2 + px^2) + 1/4 * x^4

        J = [[ 0  1]
             [-1  0]]

        \vector{\phi} = [x px]

    In ``train`` mode every ``__getitem__`` draws a fresh uniform point in
    ``domain``. In ``test`` mode a reference trajectory is precomputed with
    ``scipy.integrate.solve_ivp``.
    """
    def __init__(self, hparams, mode="train"):
        super().__init__()
        self.hparams = hparams
        self.mode = mode
        self.dims = hparams.data.dims
        # NOTE(review): eval() on config strings -- only safe for trusted configs.
        self.domain = [eval(n) for n in hparams.data.domain]
        self.ps_domain = hparams.data.phase_space_domain
        self.tol = hparams.data.tol
        self.viz_coords = hparams.data.get('viz_coords', [0, 1])  # for visualization

        # must set these attributes
        self.has_bcs = hparams.data.has_bcs
        self.has_ics = hparams.data.has_ics
        hparams.model.dim.input = self.dims
        hparams.model.dim.output = 2

        ic = np.array(hparams.data.init_cond or self.generate_ic())
        if self.mode == 'train':
            logger.info(f"[{self.hparams.data.name.upper()}] initial condition: {ic} | Hamiltonian = {self.hamiltonian(*ic):.4f}")
        self.initial_cond = torch.from_numpy(ic)

        if mode == "train":
            pass  # randomly generate training data
        elif mode == "test":
            # Axis labels follow the configured viz_coords (as in HenonHeiles).
            # With the default [0, 1] this is ["$x$", "$p_x$"], as before.
            self.xlabel, self.ylabel = [["$x$", "$p_x$"][i] for i in self.viz_coords]
            self.time, sol_data = self.solve(self.initial_cond.clone())
            self.domain_shape = self.time.shape
            self.domain_lims = self.domain
            self.n_test_points = len(self.time)
            self.data = torch.from_numpy(self.time).unsqueeze(-1)
            self.sol_data = torch.from_numpy(sol_data)
        else:
            raise NotImplementedError

    def generate_ic(self):
        """Generate a random initial condition, uniform in ``phase_space_domain``.

        Each of the ``model.dim.output`` coordinates is drawn independently
        from ``[ps_domain[0], ps_domain[1])``, using a generator seeded with
        ``hparams.random_seed``.
        """
        if self.mode == "train":
            logger.info(f"[{self.hparams.data.name.upper()}] Sampling initial conditions...")
        rng = np.random.default_rng(self.hparams.random_seed)
        return (self.ps_domain[1]-self.ps_domain[0])*rng.random(self.hparams.model.dim.output) + self.ps_domain[0]

    def hamiltonian(self, x, px):
        """Nonlinear oscillator Hamiltonian H = x^4/4 + (x^2 + px^2)/2.
        """
        return (x**4)/4 + (x*x + px*px)/2

    def solve(self, init_cond):
        """Groundtruth data via the scipy ``solve_ivp`` ODE solver.

        Returns:
            ``(t, y)`` where ``y`` has one row per time point in ``t``.

        Raises:
            Exception: if the solver reports failure.
        """
        start_time = time.time()
        sol_data = solve_ivp(
            self.phi_dot, self.domain, init_cond, rtol=self.tol,
            method=self.hparams.data.gt_solver, max_step=self.hparams.data.max_step
        )
        solve_time = time.time() - start_time

        if not sol_data.success:
            raise Exception(
                f"""[{self.hparams.data.name.upper()}] solve_ivp with method '{self.hparams.data.gt_solver}'"""
                f""" failed with initial condition {init_cond.numpy()} |"""
                f""" Hamiltonian(IC) = {self.hamiltonian(*init_cond)} |"""
                f""" Please try running again with a different seed."""
            )
        logger.info(f"[{self.hparams.data.name.upper()}] solve_ivp ({self.hparams.data.gt_solver}) time: {solve_time:.4f}s")

        return sol_data.t, sol_data.y.T

    def sample_init(self, size):
        """Samples points for the initial condition: ``size`` copies of t = 0, shaped ``(size, 1)``.
        """
        return torch.zeros(size, 1)

    def initial_condition(self, _, ic_estimate):
        """Computes error of some estimate to a specified initial condition
        """
        return (ic_estimate - self.initial_cond)

    def phi_dot(self, _, phi):
        """Computes exact derivatives of the dynamical coordinates: [px, -x^3 - x].

        The dummy first argument (time) is required by the scipy solver's
        callback signature. A batched (2-D) ``phi`` yields a stacked torch
        tensor; a single state yields a plain list.
        """
        if len(phi.shape) > 1:  # for pytorch
            x, px = phi.T
            deriv = torch.stack([px, -x**3 - x], dim=-1)
        else:  # for the scipy solver
            x, px = phi
            deriv = [px, -x**3 - x]
        return deriv

    def operator(self, t, phi):
        """Computes the residual between the model's autograd time derivative and the exact one.
        """
        model_phi_dot = torch.stack([gradient(coord, t).squeeze() for coord in phi.T], dim=-1)
        exact_phi_dot = self.phi_dot(None, phi)
        return model_phi_dot - exact_phi_dot

    def reparameterize(self, phi, t, error=False):
        """Reparameterize the output to be output(t) * (1-exp(-t)) + initial_condition.

        With ``error=True`` the initial-condition offset is omitted.
        """
        return (1 - torch.exp(-t))*phi + (not error)*self.initial_cond.to(phi.device)

    def __len__(self):
        """Programmatically define the length as this a randomly sampled dataset
        """
        return (
            self.hparams.batch_size if self.mode == 'train'  # 1 epoch == 1 randomly generated batch
            else self.n_test_points
        )

    def __getitem__(self, index):
        # The random draw also happens in test mode, which keeps the global
        # RNG stream identical to previous behavior.
        x = (self.domain[1]-self.domain[0]) * torch.rand(self.dims) + self.domain[0]

        if self.mode == 'test':
            return self.data[index].float(), self.sol_data[index].float()

        return x.float()

