from abc import abstractmethod, ABC
from typing import List, Tuple, Callable, Optional

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

from dpconvcnp.data.data import SyntheticGenerator
from dpconvcnp.random import Seed, randn, randu
from dpconvcnp.utils import i32, f32, f64, to_tensor, cast

tfd = tfp.distributions


def _sawtooth_series(z: tf.Tensor, n: int):
    """
    Evaluate a truncated Fourier series of a unit-period sawtooth wave.

    Computes ``2 / pi * sum_{m=1}^{n} sin(2 * pi * m * z) / m``. As ``n``
    grows this approaches ``1 - 2 * frac(z)`` away from integer ``z`` (a wave
    falling from 1 to -1 once per unit of ``z``); small ``n`` gives a smoothed
    version of that wave.

    Arguments:
        z: Tensor at which to evaluate the series.
        n: Number of harmonics to include (harmonics 1, ..., n).

    Returns:
        Tensor with the same shape as ``z`` (for ``n >= 1``).
    """

    def _harmonic(m: int) -> tf.Tensor:
        # m-th sine harmonic, weighted by 1 / m.
        return tf.sin(2 * m * np.pi * z) / m

    total = tf.reduce_sum([_harmonic(m) for m in range(1, n + 1)], axis=0)
    return 2 / np.pi * total


def _sawtooth(
    d: tf.Tensor,
    x: tf.Tensor,
    phi: tf.Tensor,
    freq: tf.Tensor,
) -> tf.Tensor:
    """
    Evaluate a smoothed sawtooth wave along the direction `d`.

    The inputs `x` are projected onto `d`, shifted by the phase `phi`, scaled
    by `freq`, and passed through a two-harmonic sawtooth Fourier series.

    Arguments:
        d: Tensor of shape (batch_size, dim) containing normalised direction
            vectors.
        x: Tensor of shape (batch_size, num_ctx + num_trg, dim) containing
            the context and target inputs.
        phi: Tensor of shape (batch_size,) containing phase parameters,
            measured in periods of the wave.
        freq: Tensor of shape (batch_size,) containing frequencies.

    Returns:
        y: Tensor of shape (batch_size, num_ctx + num_trg, 1) containing
            the context and target outputs.
    """
    freq_b = freq[:, None, None]
    # Convert phi from periods into input units. Multiplying by the frequency
    # below turns it back into a shift of exactly `phi` periods.
    offset = (1 / freq_b) * phi[:, None, None]
    # Project every input point onto its batch element's direction vector.
    projection = tf.einsum("bd, bnd -> bn", d, x)[:, :, None]
    return _sawtooth_series((projection + offset) * freq_b, n=2)


def _tophat_series(z: tf.Tensor, n: int):
    """
    Evaluate a truncated Fourier series of a square ("tophat") wave.

    Computes ``4 / pi * sum_{m=0}^{n-1} sin(4 * pi * (2m + 1) * z) / (2m + 1)``.
    These are the first ``n`` odd harmonics of a +/-1 square wave with period
    1/2 in ``z``. A small ``n`` gives a smoothed square wave, with Gibbs
    overshoot near the jumps.

    Arguments:
        z: Tensor at which to evaluate the series.
        n: Number of odd harmonics to include (harmonics 1, 3, ..., 2n - 1).

    Returns:
        Tensor with the same shape as ``z`` (for ``n >= 1``).
    """
    # The square-wave series runs over odd harmonics k = 2m + 1 starting at
    # the fundamental k = 1, so m must start at 0. Starting m at 1 would drop
    # the fundamental. The result would no longer resemble a square wave, and
    # its lowest frequency would be three times the intended one.
    series_terms = [
        tf.sin(4 * (2 * m + 1) * np.pi * z) / (2 * m + 1)
        for m in range(n)
    ]
    return 4 / np.pi * tf.reduce_sum(series_terms, axis=0)


def _tophat(
    d: tf.Tensor,
    x: tf.Tensor,
    phi: tf.Tensor,
    freq: tf.Tensor,
) -> tf.Tensor:
    """
    Evaluate a smoothed tophat (square) wave along the direction `d`.

    The inputs `x` are projected onto `d` and shifted by `phi`. The result is
    rescaled by `freq / 2` (written as a division by `2 / freq`) and passed
    through a four-harmonic square-wave Fourier series.

    Arguments:
        d: Tensor of shape (batch_size, dim) containing normalised direction
            vectors.
        x: Tensor of shape (batch_size, num_ctx + num_trg, dim) containing
            the context and target inputs.
        phi: Tensor of shape (batch_size,) containing phase parameters.
        freq: Tensor of shape (batch_size,) containing frequencies.

    Returns:
        y: Tensor of shape (batch_size, num_ctx + num_trg, 1) containing
            the context and target outputs.
    """
    # Project every input point onto its batch element's direction vector.
    projection = tf.einsum("bd, bnd -> bn", d, x)[:, :, None]
    freq_b = freq[:, None, None]
    phase = phi[:, None, None]
    # NOTE: unlike `_sawtooth`, `phi` is added before any rescaling here and is
    # not divided by `freq`. The resulting phase shift is therefore `phi * freq`
    # periods rather than `phi` periods.
    z = (projection + phase) / (2 / freq_b)
    # A truncated (and hence smooth) Fourier series is used instead of an
    # exact piecewise-constant square wave.
    return _tophat_series(z, n=4)


class WaveformGenerator(SyntheticGenerator, ABC):
    """
    Generator of noisy waveforms evaluated along random input directions.

    For every batch element, a direction vector, a frequency and a phase are
    drawn. The outputs are the selected waveform evaluated at the projection of
    the inputs onto that direction, plus additive observation noise.

    Arguments:
        waveform_func: Name of the waveform, either "sawtooth" or "tophat".
        min_frequency: Lower bound passed to `randu` when sampling frequencies.
        max_frequency: Upper bound passed to `randu` when sampling frequencies.
        noise_std: Standard deviation passed to `randn` for the observation
            noise.
        dim: Input dimensionality. It is stored on the instance, while
            `sample_outputs` reads the dimension from `x` instead.
        smooth_power: Even integer or None. Only validated and stored here.
            NOTE(review): presumably consumed elsewhere — confirm.
        **kwargs: Forwarded to `SyntheticGenerator.__init__`.

    Raises:
        ValueError: If `waveform_func` is unknown or `smooth_power` is odd.
    """

    def __init__(
        self,
        *,
        waveform_func: str,
        min_frequency: float,
        max_frequency: float,
        noise_std: float,
        dim: int,
        smooth_power: Optional[int] = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if waveform_func == "sawtooth":
            self.waveform_func = _sawtooth

        elif waveform_func == "tophat":
            self.waveform_func = _tophat

        else:
            raise ValueError(f"Unknown waveform function: {waveform_func}")

        # Validate with an explicit exception rather than `assert`, which is
        # stripped under `python -O`. `None` is accepted, matching the
        # `Optional[int]` annotation.
        if smooth_power is not None and smooth_power % 2 != 0:
            raise ValueError(
                f"smooth_power must be an even integer, got {smooth_power}"
            )

        self.min_frequency = min_frequency
        self.max_frequency = max_frequency
        self.noise_std = noise_std
        self.dim = dim
        self.smooth_power = smooth_power

    def sample_outputs(
        self,
        seed: Seed,
        x: tf.Tensor,
    ) -> Tuple[Seed, tf.Tensor, Optional[Callable]]:
        """Sample context and target outputs, given the inputs `x`.

        Arguments:
            seed: Random seed.
            x: Tensor of shape (batch_size, num_ctx + num_trg, dim) containing
                the context and target inputs.

        Returns:
            seed: Random seed generated by splitting.
            y: Tensor of shape (batch_size, num_ctx + num_trg, 1) containing
                the context and target outputs.
            None: This generator always returns None as the third element.
        """

        # NOTE(review): uses the static shape, so all three dimensions of `x`
        # are assumed to be known.
        B, N, D = x.shape

        # Draw a direction vector per batch element and normalise it. The
        # 1e-6 is added to every squared component, so the resulting norm
        # is marginally below 1. It guards against division by zero.
        seed, d = randn(
            shape=(B, D),
            seed=seed,
            mean=tf.zeros((B, D), dtype=x.dtype),
            stddev=tf.ones((B, D), dtype=x.dtype),
        )
        norm = tf.reduce_sum(d**2.0 + 1e-6, axis=-1, keepdims=True) ** 0.5
        d = d / norm

        # Draw a frequency per batch element in [min_frequency, max_frequency].
        seed, freq = randu(
            shape=(B,),
            seed=seed,
            minval=self.min_frequency * tf.ones((B,), dtype=x.dtype),
            maxval=self.max_frequency * tf.ones((B,), dtype=x.dtype),
        )

        # Draw a phase parameter per batch element in [0, 1].
        seed, phi = randu(
            shape=(B,),
            seed=seed,
            minval=tf.zeros((B,), dtype=x.dtype),
            maxval=tf.ones((B,), dtype=x.dtype),
        )

        # Sample per-point observation noise.
        seed, noise = randn(
            shape=(B, N, 1),
            seed=seed,
            mean=tf.zeros((B, N, 1), dtype=x.dtype),
            stddev=self.noise_std * tf.ones((B, N, 1), dtype=x.dtype),
        )

        # Use waveform function to generate outputs and add noise.
        y = self.waveform_func(d, x, phi, freq) + noise

        return seed, y, None
