import torch
from torch.utils.data import Dataset
import numpy as np

class SyntheticDataset(Dataset):
    """A dataset to run experiments with.

    Creates a dataset of torch tensors. Tensors must be loaded from numpy
    array files or passed in as numpy arrays.


    Attributes:
        self.x: The state as a torch.Tensor of shape (timesteps x x_dim).
        self.x_dot: The time derivative of the state as a torch.tensor of
            shape (timesteps x x_dim). The derivative is calculated using
            fourth order finite differences.
        self.x_dot_standard: A standardized version of self.x_dot:
            self.x_dot_standard = (self.x_dot - self.x_dot.mean(0)) / self.x_dot.std(0)
        self.x_lib: The SINDy library form of self.x, as a torch.Tensor of
            of shape (timesteps x library_dim).
        self.dataset: The name (str) of the dataset.
        self.t: A torch.tensor denoting each timepoint.
    """

    def __init__(self, library, dataset="lorenz", dt=0.01, model="HyperSINDy",
                 x=None, fpath=None, t=None, x_dot_noise=None):
        """Initializes the SyntheticDataset.

        Initializes the SyntheticDataset using the given parameters.

        Args:
            library: The SINDy library object (from library_utils)
                used to transform the data.
            dataset: A string for the name of the dataset. The default is lorenz.
            dt: The time between adjacent states (e.g. between x[0] and x[1], x[1] and x[2]).
                The default is 0.01.
            model: A string for the name of the model being trained. The default is HyperSINDy.
                This is an artifact of old code.
            x: A Numpy array of the data to use. This parameter is an alternate to fpath.
                The default is None. If x is not None, it will be used instead of
                the file located at fpath.
            fpath: The fll path to the data file of x. The default is None. I
            t: A numpy array of shape (x.size(0), ) denoting the corresponding
                timepoint for each (state, derivative) pair. The default is None.
            x_dot_noise: A numpy array of shape (x.size(0), x.size(1)) that can be
                added to x_dot to induce more noise. The default is None.

        Returns:
            A SyntheticDataset.
        """
        if x is not None:
            self.x = torch.from_numpy(x)
        elif fpath is not None:
            self.x = torch.from_numpy(np.load(fpath))
        else:
            print("ERROR: at least one of fpath or x must not be None.")
            exit()
        
        self.x_dot = self.fourth_order_diff(self.x, dt)
        if x_dot_noise is not None:
            self.x_dot = self.x_dot + x_dot_noise

        self.x_dot_standard = (self.x_dot - self.x_dot.mean(0)) / self.x_dot.std(0)

        self.x_lib = library.transform(self.x)
        self.dataset = dataset
        self.t = t
        if t is not None:
            self.t = torch.from_numpy(t)

    def __len__(self):
        """The length of the dataset.

        Gets the length of the dataset (in timesteps).

        Args:
            None

        Returns:
            The length of the dataset along dimension 0.
        """
        return len(self.x)
    
    def __getitem__(self, idx):
        """Gets the item.

        Gets the item at the current index.

        Args:
            idx: The integer index to access the data.

        Returns:
            If t was NOT given during construction of the dataset:
                A tuple of (tensor_a, tensor_b, tensor_c, tensor_d)
                where tensor_a is the state, tensor_b is the library,
                tensor_c is the derivative, and tensor_c is the standardized
                derivative.
            If t was given during construction:
                A tuple of (tensor_a, tensor_b, tensor_c, tensor_d, tensor_e)
                where tensors a, b, c, and d are the same as above, and tensor_e
                is the associated timepoints.
        """
        if self.t is None:
            return self.x[idx], self.x_lib[idx], self.x_dot[idx], self.x_dot_standard[idx]
        else:
            return self.x[idx], self.x_lib[idx], self.x_dot[idx], self.x_dot_standard[idx], self.t[idx]

    def fourth_order_diff(self, x, dt):
        """Gets the derivatives of the data.

        Gets the derivative of x with respect to time using fourth order
        differentiation.
        The code for this function was taken from:
        https://github.com/urban-fasel/EnsembleSINDy

        Args:
            x: The data (torch.Tensor of shape (timesteps x x_dim)) to
                differentiate.
            dt: The amount of time between two adjacent data points (i.e.,
                the time between x[0] and x[1], or x[1] and x[2]).

        Returns:
            A torch.tensor of the derivatives of x.
        """
        dx = torch.zeros(x.size())
        dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3
        dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3
        dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]
        dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0
        dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0
        return dx / dt