from abc import ABC, abstractmethod
from typing import Tuple, List

import numpy as np
import torch
from torch import Tensor, tensor
from torch.autograd import grad
from torch.distributions import Uniform

from enum import Enum, auto


class ST(Enum):
    """Enumeration of strategy identifiers used elsewhere in the project.

    The meaning of each member is not defined in this module; presumably they
    select a sampling / training strategy (TODO confirm against callers).
    ``str(member)`` yields the bare member name, e.g. ``"uniform"``.
    """

    # Values are spelled out explicitly; they are exactly the 1-based,
    # declaration-order integers that ``auto()`` would assign.
    dirichlet = 1
    none = 2
    uniform = 3
    gaussian = 4
    mh_pdpinn = 5
    pdpinn_euler = 6
    pdpinn_noise = 7
    importance_sampling = 8
    it_pdpinn = 9

    def __str__(self):
        member_name = self.name
        return member_name


class CircleDistribution(ABC, torch.nn.Module):
    """
    Abstract sampler for points on a disk of radius ``radius``.

    Subclasses implement ``sample_polar``, which returns (radius, angle, time)
    tensors of shape ``event_shape``. ``sample`` converts these to Cartesian
    coordinates stacked along a new last axis as (x, y, t).

    :param radius: disk radius.
    :param time_step: optional upper bound of the time dimension (``None`` if
        the sampler has no time dimension).
    :param dtype: torch dtype used by subclasses when creating samples.
    :param device: device the samples are moved to.
    """

    def __init__(self, radius, time_step=None, dtype=torch.float32, device="cpu"):
        super().__init__()
        # Plain configuration attributes read by the concrete samplers.
        self.radius = radius
        self.time_step = time_step
        self.dtype = dtype
        self.device = device

    def sample(self, event_shape: Tuple[int, ...], over_time=False) -> Tuple[Tensor, Tensor]:
        """Return ``(cartesian_coords, sampled_radius)`` for ``event_shape`` points."""
        radius, theta, timestamps = self.sample_polar(event_shape, over_time)
        return self.to_cartesian(radius, theta, timestamps), radius

    def to_cartesian(self, sampled_radius, sampled_theta, sampled_time):
        """Stack (r*cos(theta), r*sin(theta), t) along a new last dimension."""
        x_coord = sampled_radius * torch.cos(sampled_theta)
        y_coord = sampled_radius * torch.sin(sampled_theta)
        return torch.stack([x_coord, y_coord, sampled_time], dim=-1)

    @abstractmethod
    def sample_polar(self, event_shape: Tuple[int, ...], over_time=False) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(radius, theta, time)`` tensors, each of shape ``event_shape``."""
        ...


class CircleDistributionMC(CircleDistribution):
    """
    Monte Carlo sampler: radius, angle and (optionally) time are drawn
    independently from uniform distributions with pseudo-random numbers.

    NOTE: the radius is uniform on [0, radius], so points are *not*
    area-uniform over the disk (density increases towards the centre).
    """
    __doc__ += CircleDistribution.__doc__

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # radius ~ U(0, R), theta ~ U(0, 2*pi)
        self.rv_radius = Uniform(tensor(0, dtype=self.dtype),
                                 tensor(self.radius, dtype=self.dtype))
        self.rv_theta = Uniform(tensor(0, dtype=self.dtype),
                                tensor(2 * np.pi, dtype=self.dtype))

        # time ~ U(0, time_step) only when a time dimension is configured
        if self.time_step is not None:
            self.rv_time = Uniform(tensor(0, dtype=self.dtype),
                                   tensor(self.time_step, dtype=self.dtype))
        else:
            self.rv_time = None

    def sample_polar(self, event_shape: Tuple[int, ...], over_time=False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Draw polar samples from the disk.

        :param event_shape: shape of each returned tensor.
        :param over_time: if True, sample time uniformly in [0, time_step];
            otherwise the time component is all zeros.
        :return: ``(radius, theta, time)``, each of shape ``event_shape`` on
            ``self.device``.
        :raises ValueError: if ``over_time`` is requested but no ``time_step``
            was configured.
        """
        if over_time and self.rv_time is None:
            raise ValueError("over_time=True requires the distribution to be built with a time_step")

        sampled_radius = self.rv_radius.rsample(event_shape).to(self.device)
        sampled_theta = self.rv_theta.rsample(event_shape).to(self.device)
        # last dimension of the Cartesian output is time
        if over_time:
            sampled_time = self.rv_time.rsample(event_shape)
        else:
            # honour the configured dtype (torch.zeros would default to float32)
            sampled_time = torch.zeros(event_shape, dtype=self.dtype)
        sampled_time = sampled_time.to(self.device)

        return sampled_radius, sampled_theta, sampled_time


class CircleDistributionQMC(CircleDistribution):
    """
    Quasi-Monte Carlo sampler: radius, angle and (optionally) time come from
    a scrambled Sobol sequence mapped onto [0, radius] x [0, 2*pi] x [0, time_step].

    NOTE: the radius is uniform on [0, radius], so points are *not*
    area-uniform over the disk (density increases towards the centre).
    """
    __doc__ += CircleDistribution.__doc__

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # One Sobol dimension each for radius and angle, plus one for time if configured.
        if self.time_step is not None:
            sample_dim = 3
        else:
            sample_dim = 2

        self.sobol_engine = torch.quasirandom.SobolEngine(dimension=sample_dim, scramble=True)

    def sample_polar(self, event_shape: Tuple[int, ...], over_time=False) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Draw polar samples from the disk using the Sobol engine.

        :param event_shape: shape of each returned tensor.
        :param over_time: if True, use the third Sobol coordinate scaled to
            [0, time_step] as time; otherwise the time component is all zeros
            (consistent with ``CircleDistributionMC``).
        :return: ``(radius, theta, time)``, each of shape ``event_shape`` on
            ``self.device``.
        :raises ValueError: if ``over_time`` is requested but no ``time_step``
            was configured (the engine then has no time dimension).
        """
        if over_time and self.time_step is None:
            raise ValueError("over_time=True requires the distribution to be built with a time_step")

        # np.prod returns a numpy scalar (a float 1.0 for an empty shape); draw() wants an int.
        num_samples = int(np.prod(event_shape))
        samples = self.sobol_engine.draw(n=num_samples, dtype=self.dtype).to(self.device)
        sampled_radius = samples[..., 0].reshape(event_shape) * self.radius
        sampled_theta = samples[..., 1].reshape(event_shape) * 2 * np.pi

        if over_time:
            # reshape so it stacks with radius/theta in to_cartesian
            sampled_time = samples[..., 2].reshape(event_shape) * self.time_step
        else:
            sampled_time = torch.zeros(event_shape, dtype=self.dtype)
        sampled_time = sampled_time.to(self.device)

        return sampled_radius, sampled_theta, sampled_time


def laplace(y, x, x_offset=1):
    """Laplacian of ``y`` over the coordinates of ``x`` from index ``x_offset`` on.

    The first ``x_offset`` coordinates of ``x`` (e.g. a leading time column)
    are excluded. The gradient is sliced to the remaining coordinates so that
    channel ``i`` of the gradient lines up with coordinate ``i + x_offset`` as
    ``divergence`` expects. Previously the full gradient was passed, so the
    last channel indexed past the end of ``x`` for any ``x_offset > 0``.

    :param y: tensor differentiable w.r.t. ``x``.
    :param x: input tensor with ``requires_grad=True``.
    :param x_offset: number of leading coordinates of ``x`` to skip.
    :return: sum of the unmixed second derivatives, keeping a trailing
        size-1 dimension (as produced by ``divergence``).
    """
    spatial_grad = gradient(y, x)[..., x_offset:]
    return divergence(spatial_grad, x, x_offset=x_offset)


def divergence(y, x, x_offset=1):
    """Divergence of the vector field ``y`` w.r.t. ``x``.

    Channel ``c`` of ``y`` is differentiated w.r.t. coordinate ``c + x_offset``
    of ``x``, and the results are summed. Each term keeps a trailing size-1
    dimension (list indexing). If ``y`` has no channels, ``0.`` is returned.
    The graph is retained (``create_graph=True``) so the result can be
    differentiated again.
    """
    total = 0.
    for channel in range(y.shape[-1]):
        component = y[..., channel]
        d_component = grad(component, x, torch.ones_like(component), create_graph=True)[0]
        total = total + d_component[..., [channel + x_offset]]
    return total


def gradient(y, x, grad_outputs=None):
    """Vector-Jacobian product of ``y`` w.r.t. ``x``, with the graph kept.

    :param grad_outputs: weights for the VJP; defaults to ones shaped like ``y``.
    :return: tensor shaped like ``x``.
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)
    return dy_dx


def custom_hessian(y, x):
    """Per-observation Hessian of each output channel of ``y`` w.r.t. ``x``.

    Expects ``y`` of shape ``(num_observations, channels)`` and ``x`` whose
    gradient has shape ``(num_observations, dims)``. This is implied by the
    buffer layout; it is not validated.

    :return: ``(hess, status)``. ``hess`` has shape
        ``(num_observations, channels, dims, dims)`` with the same dtype and
        device as ``y``. ``status`` is ``-1`` if any entry is NaN, else ``0``.
    """
    num_observations = y.shape[0]
    num_outputs, num_inputs = y.shape[-1], x.shape[-1]
    # new_zeros keeps y's dtype and device; plain torch.zeros would silently
    # downcast float64 Hessians to the default float32.
    hess = y.new_zeros(num_observations, num_outputs, num_inputs, num_inputs)
    grad_y = torch.ones_like(y[..., 0])

    for output_i in range(num_outputs):
        # first derivatives of this output channel w.r.t. every input dim
        dydx = grad(y[..., output_i], x, grad_y, create_graph=True)[0]

        # differentiate each first derivative again to fill one Hessian row
        for dim_j in range(num_inputs):
            hess[..., output_i, dim_j, :] = grad(dydx[..., dim_j], x, grad_y, create_graph=True)[0]

    status = -1 if torch.isnan(hess).any() else 0
    return hess, status


def hessian(y, x):
    ''' Hessian of each output channel of y w.r.t. x.

    y: shape (meta_batch_size, num_observations, channels)
    x: shape (meta_batch_size, num_observations, dims)

    Returns (h, status). h has shape
    (meta_batch_size, num_observations, channels, dims, dims) and the same
    dtype/device as y. status is -1 if any entry is NaN, else 0.
    '''
    meta_batch_size, num_observations = y.shape[:2]
    num_outputs, num_inputs = y.shape[-1], x.shape[-1]
    grad_y = torch.ones_like(y[..., 0])
    # new_zeros keeps y's dtype and device; plain torch.zeros would silently
    # downcast float64 Hessians to the default float32.
    h = y.new_zeros(meta_batch_size, num_observations, num_outputs, num_inputs, num_inputs)
    for i in range(num_outputs):
        # first derivatives of output channel i w.r.t. every input dim
        dydx = grad(y[..., i], x, grad_y, create_graph=True)[0]

        # differentiate each first derivative again to fill one Hessian row
        for j in range(num_inputs):
            h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0]

    status = -1 if torch.isnan(h).any() else 0
    return h, status
