from __future__ import annotations
import numpy as np
from abc import ABC, abstractmethod
from typing import Optional
from numpy.typing import NDArray
from dataclasses import dataclass


@dataclass()
class TimeSeriesDataset:
    """Bundle of arrays produced by one ``DataGenerator.sample`` call.

    The three arrays are returned together by every generator in this module;
    the exact meaning of each depends on the generator that produced it.
    """

    # Observed series values (for the synthetic generators: signal plus noise).
    noisyObservations: NDArray
    # Reference values paired with the observations (for the synthetic
    # generators: the signal before the newest noise draw is added).
    groundTruth: NDArray
    # Separately reported component of the signal; its content is
    # generator-specific (e.g. all zeros for a pure AR process).
    nonstationary: NDArray


class DataGenerator(ABC):
    """Abstract interface for sources that produce a ``TimeSeriesDataset``."""

    @abstractmethod
    def sample(self, num_steps: int) -> TimeSeriesDataset:
        """Produce a dataset spanning ``num_steps`` time steps.

        Concrete subclasses must override this; the base version has no
        behaviour of its own and simply returns ``None``.
        """


class ARGenerator(DataGenerator):
    """Generate a univariate autoregressive (AR) time series.

    Each step is a linear combination of the previous ``order`` noisy
    observations plus i.i.d. Gaussian noise::

        groundTruth[i]       = coefs @ noisyObservations[i - order : i]
        noisyObservations[i] = groundTruth[i] + noiseStd * N(0, 1)

    Note that ``coefs[0]`` multiplies the *oldest* value in the window and
    ``coefs[-1]`` the most recent one.
    """

    def __init__(self, coefs: NDArray, noiseStd: float = 1.0) -> None:
        """
        Args:
            coefs: AR coefficients (array-like); its size is the model order.
            noiseStd: standard deviation of the additive Gaussian noise, also
                used for the initial values that seed the recursion.
        """
        super().__init__()
        # asarray returns an existing ndarray unchanged, but also lets callers
        # pass a plain list (which has no `.size`).
        self.coefs = np.asarray(coefs)
        self.order = self.coefs.size
        self.noiseStd = noiseStd

    def sample(self, num_steps: int) -> TimeSeriesDataset:
        """Draw ``num_steps`` consecutive steps of the AR process.

        The process is first run for a burn-in period that is discarded,
        presumably so the returned window depends less on the random initial
        values. The burn-in is at least ``num_steps`` and at least ``order``
        steps, so every returned value comes from the recursion rather than
        from the initialisation.

        Args:
            num_steps: number of time steps to return (non-negative).

        Returns:
            TimeSeriesDataset whose arrays have length ``num_steps``:
            the noisy observations, the noise-free prediction from the
            preceding observations (``groundTruth``), and an all-zero
            ``nonstationary`` component.

        Raises:
            ValueError: if ``num_steps`` is negative.
        """
        if num_steps < 0:
            raise ValueError(f"num_steps must be non-negative, got {num_steps}")
        # For order <= num_steps this is exactly the historical 2 * num_steps
        # length; a longer burn-in is only used when the order demands it.
        burnIn = max(num_steps, self.order)
        total = burnIn + num_steps
        # zeros (not empty): the first `order` entries are never computed.
        groundTruth = np.zeros(total)
        noisyObservations = np.empty(total)
        # Seed the recursion with pure noise for the first `order` steps.
        noisyObservations[: self.order] = self.noiseStd * np.random.randn(self.order)
        for i in range(self.order, total):
            groundTruth[i] = self.coefs @ noisyObservations[i - self.order : i]
            noisyObservations[i] = groundTruth[i] + self.noiseStd * np.random.randn()
        return TimeSeriesDataset(
            noisyObservations[burnIn:],
            groundTruth[burnIn:],
            np.zeros(num_steps),
        )


class SinusoidMixtureGenerator(DataGenerator):
    """Generate multivariate series as random linear mixtures of sinusoids.

    ``numFundamental`` latent components, each a sine wave with random period
    and phase plus a linear trend, are mixed into ``numSeries`` output series
    through a random Gaussian mixing matrix. The observation noise is either
    i.i.d. Gaussian or, when ``autoregressive`` is set, an AR(1) process per
    series. All random parameters are drawn once, at construction, from
    NumPy's global random state.
    """

    def __init__(
        self,
        numFundamental: int,
        numSeries: int,
        periodLower: float,
        periodUpper: float,
        noiseStd: float = 1.0,
        autoregressive: bool = False,
        maxSlope: float = 0.0,
    ) -> None:
        """
        Args:
            numFundamental: number of latent sinusoidal components.
            numSeries: number of output series (columns of each sample).
            periodLower: lower bound of the uniform distribution each
                component's period (in time steps) is drawn from.
            periodUpper: upper bound of that distribution.
            noiseStd: standard deviation of the observation noise.
            autoregressive: if true, each series' noise is an AR(1) process
                with coefficient -0.5 instead of i.i.d. Gaussian noise.
            maxSlope: upper bound of the uniform distribution (lower bound 0)
                each component's linear-trend slope is drawn from.
        """
        super().__init__()
        self.numFundamental = numFundamental
        self.numSeries = numSeries
        self.noiseStd = noiseStd
        self.periods = np.random.uniform(periodLower, periodUpper, numFundamental)
        self.angularFreqs = 2 * np.pi / self.periods
        self.phases = np.random.uniform(0, 2 * np.pi, numFundamental)
        self.slopes = np.random.uniform(0.0, maxSlope, numFundamental)
        # Column j of the mixing matrix holds the weights of output series j.
        self.mixtureCoefs = np.random.randn(numFundamental, numSeries)
        self.autoregressive = autoregressive
        if autoregressive:
            self.arGenerators = [
                ARGenerator(np.array([-0.5]), noiseStd=self.noiseStd)
                for _ in range(numSeries)
            ]

    def sample(self, num_steps: int) -> TimeSeriesDataset:
        """Draw ``num_steps`` time steps of every series.

        Returns:
            TimeSeriesDataset whose arrays all have shape
            ``(num_steps, numSeries)``:

            * ``noisyObservations``: the mixtures plus observation noise;
            * ``groundTruth``: the mixtures, plus the noise-free part of the
              AR noise when ``autoregressive`` is set;
            * ``nonstationary``: the sinusoid-plus-trend mixtures alone.
        """
        # Column vector of time indices; broadcasting it against the
        # (numFundamental,) parameter vectors yields (num_steps, numFundamental).
        t = np.arange(num_steps)[:, np.newaxis]
        fundamentals = np.sin(self.angularFreqs * t + self.phases) + self.slopes * t
        mixtures = fundamentals @ self.mixtureCoefs
        # Snapshot before any AR contribution is added to `mixtures` below.
        nonstationary = mixtures.copy()
        if self.autoregressive:
            arData = [gen.sample(num_steps) for gen in self.arGenerators]
            noisyMixtures = mixtures.copy()
            for col, ar in enumerate(arData):
                noisyMixtures[:, col] += ar.noisyObservations
                mixtures[:, col] += ar.groundTruth
        else:
            noisyMixtures = mixtures + self.noiseStd * np.random.randn(*mixtures.shape)
        return TimeSeriesDataset(noisyMixtures, mixtures, nonstationary)


class RealDataGenerator(DataGenerator):
    """Serve a prefix of a pre-recorded multivariate series loaded from disk.

    The loaded array is indexed as ``(time, series)``: ``sample`` takes rows as
    time steps, and the constructor selects a range of columns as series.
    """

    def __init__(
        self,
        dataPath: str,
        startIndexIncl: Optional[int] = None,
        endIndexExcl: Optional[int] = None,
        standardize: bool = False,
    ) -> None:
        """
        Args:
            dataPath: path passed to ``np.load``; must yield an array with at
                least two dimensions. NOTE(review): a ``.npz`` archive would
                not work here (it has no column slicing), confirm only ``.npy``
                files are used.
            startIndexIncl: first series (column) to keep, inclusive;
                ``None`` means from the first column.
            endIndexExcl: series (column) to stop before, exclusive;
                ``None`` means through the last column.
            standardize: if true, shift each kept series to zero mean and scale
                it to unit standard deviation over the whole recording.
                Constant series (zero std) are only centred, to avoid
                dividing by zero.
        """
        super().__init__()
        data = np.load(dataPath)
        # Slice bounds of None already mean "start"/"end"; slicing directly
        # (instead of `x or default`) keeps an explicit 0 meaning 0.
        self.data = data[:, startIndexIncl:endIndexExcl]
        # Number of series `sample` actually returns (after column selection).
        self.numSeries = self.data.shape[1]
        if standardize:
            std = self.data.std(axis=0)
            std = np.where(std == 0, 1.0, std)
            self.data = (self.data - self.data.mean(axis=0)) / std

    def sample(self, num_steps: int) -> TimeSeriesDataset:
        """Return the first ``num_steps`` rows of the stored data.

        The same rows are used as observations, ground truth and
        nonstationary component. All three are views of the stored array, so
        mutating one alters the others and this generator's data.

        Raises:
            AssertionError: if more rows are requested than are stored
                (NOTE: this check is skipped when Python runs with ``-O``).
        """
        assert num_steps <= self.data.shape[0], "I don't have that much data for you!"
        window = self.data[:num_steps]
        return TimeSeriesDataset(window, window, window)
