import sys
from typing import Optional, Union

import numpy as np

sys.path.append('./')
from est.models.Model import Model
from est.simulator.Stepper import EulerMaruyamaStepper


class Simulator(object):
    def __init__(self,
                 dim: int,
                 m: int,
                 x0: Union[float, np.ndarray],
                 n: int,
                 dt: float,
                 model: Model,
                 seed: Optional[int] = None):
        """
        Class for simulating paths of a diffusion (SDE) process.

        :param dim: int, dimension of the SDE system
        :param m: int, dimension of the Wiener process
        :param x0: Union[float, np.ndarray], initial value of the process; it is broadcast
            into the first point of every path, so a scalar or an array of shape (dim,) works
        :param n: int, number of time steps (each path has n+1 points, as it contains x0)
        :param dt: float, time step size
        :param model: obj, the model
        :param seed: Optional[int], seed for numpy's global RNG (used for reproducibility of
            experiments). When None, the global RNG is left untouched so that any seeding
            the caller did beforehand is preserved.
        """
        self._dim = dim
        self._m = m
        self._x0 = x0
        self._n = n
        self._dt = dt
        self._model = model

        # Only reseed when a seed was actually given: np.random.seed(None) reseeds the
        # *global* RNG from OS entropy, which would silently discard a seed set by the caller.
        if seed is not None:
            self.set_seed(seed=seed)

    def set_seed(self, seed: Optional[int] = None) -> 'Simulator':
        """
        Seed numpy's global random number generator (np.random).

        Note that this affects every user of np.random in the process, not only this simulator.
        :param seed: Optional[int], the seed; None makes numpy reseed from OS entropy
        :return: self, to allow call chaining
        """
        np.random.seed(seed=seed)
        return self

    @property
    def model(self) -> Model:
        """ Access the underlying model """
        return self._model

    def sim_path(self, num_paths: int) -> np.ndarray:
        """
        Simulate new independent path(s), each with n + 1 points (the first one being x0).
        :param num_paths: int, number of independent paths to simulate.
        :return: array of shape (num_paths, n + 1, dim), the simulated path(s)
        """
        stepper = EulerMaruyamaStepper(model=self._model)
        dt = self._dt
        path = self._init_path(path_shape=(num_paths, self._n + 1, self._dim))
        # Standard normal draws of shape (num_paths, n, m), taken from numpy's global RNG in a
        # single call so the stream for a given seed is reproducible. They are not scaled by
        # sqrt(dt) here -- presumably the stepper does that; confirm in EulerMaruyamaStepper.
        norms = np.random.multivariate_normal(mean=np.zeros(self._m),
                                              cov=np.identity(self._m),
                                              size=(num_paths, self._n))
        for k in range(num_paths):
            for i in range(self._n):
                # Step i advances the state from time i*dt to (i+1)*dt.
                path[k, i + 1, :] = stepper(t=i * dt, dt=dt, x=path[k, i, :], dZ=norms[k, i, :])
        return path

    def _init_path(self, path_shape: tuple) -> np.ndarray:
        """
        Allocate a zero-filled path array whose first time index holds x0 for every path.
        :param path_shape: tuple, (num_paths, n + 1, dim)
        :return: array of the given shape
        """
        path = np.zeros(shape=path_shape)
        path[:, 0, :] = self._x0
        return path