import numpy as np
from torch.utils.data import Dataset
import random


class DatasetSampler(Dataset):
    """
    PyTorch Dataset holding one simulated nonlinear multivariate
    autoregressive (NMAR) time series.

    The series has 15 nodes split into 3 clusters of 5. Each new value of a
    node combines:

    * an autoregressive term over its last 3 values, with fresh random lag
      weights;
    * a sin-transformed coupling to every other node's previous value, with
      stronger random weights inside its own cluster and weaker ones across
      clusters;
    * Gaussian noise.

    The result is then clipped to [-1, 1].

    The whole series is generated once in ``__init__`` from the global
    ``np.random`` state. Seed ``np.random`` beforehand for reproducible data.

    Args:
        n_timepoints: number of time points to keep (must be >= 0).
        noise_level: standard deviation of the Gaussian noise added at each
            step.

    Attributes:
        data: float array of shape ``(15, n_timepoints)``; ``data[:, t]`` is
            the state of all nodes at time ``t``.
    """

    def __init__(self,
                 n_timepoints,
                 noise_level=0.2
                 ):
        if n_timepoints < 0:
            raise ValueError(
                f"n_timepoints must be non-negative, got {n_timepoints}")
        self.n_timepoints = n_timepoints
        self.noise_level = noise_level
        self.data = self._simulate(n_timepoints)

    def __len__(self):
        """Return the number of time points (one sample per time point)."""
        return self.n_timepoints

    def _simulate(self, n_timepoints):
        """
        Generate the simulated series.

        Args:
            n_timepoints: number of time points to return.

        Returns:
            float64 array of shape ``(15, n_timepoints)``. The ``n_lags``
            random warm-up columns that seed the recursion are dropped.
        """
        # Parameters
        num_nodes = 15              # Total number of nodes
        num_clusters = 3            # Number of clusters
        cluster_size = 5            # Nodes per cluster
        n_lags = 3                  # Number of autoregressive lags
        num_points = n_timepoints + n_lags  # Warm-up columns + kept columns
        noise_level = self.noise_level      # Std of the Gaussian noise
        interaction_strength = 0.2  # Upper bound of intra-cluster weights
        weaker_interaction = 0.01   # Upper bound of inter-cluster weights

        def nonlinear_interaction(x):
            # sin keeps each coupling contribution within [-1, 1].
            return np.sin(x)

        # Node indices per cluster: [[0..4], [5..9], [10..14]].
        clusters = [list(range(k * cluster_size, (k + 1) * cluster_size))
                    for k in range(num_clusters)]

        # Preallocate the full series. Appending one column per step with
        # np.column_stack would copy the whole array every time (O(T^2)).
        data = np.empty((num_nodes, num_points))
        # Warm-up values, uniform in [-0.1, 0.1).
        data[:, :n_lags] = (np.random.rand(num_nodes, n_lags) - 0.5) / 5

        # The order of the np.random calls below is deliberate: a given seed
        # must always yield the same series.
        for t in range(n_lags, num_points):
            for i in range(num_nodes):
                # Autoregressive component with fresh random lag weights.
                autoregressive_term = np.sum(
                    data[i, t - n_lags:t] * np.random.uniform(0.2, 0.5, n_lags))

                # Coupling to the other nodes through their previous value.
                interaction_term = 0
                for cluster in clusters:
                    if i in cluster:
                        # Intra-cluster interactions (excluding self).
                        for j in cluster:
                            if j != i:
                                interaction_term += nonlinear_interaction(data[j, t - 1]) * np.random.uniform(0.02, interaction_strength)
                    else:
                        # Inter-cluster interactions.
                        for j in cluster:
                            interaction_term += nonlinear_interaction(data[j, t - 1]) * np.random.uniform(0.005, weaker_interaction)

                # Combine terms with added noise.
                new_value = (0.5 * autoregressive_term + 0.5 * interaction_term
                             + np.random.normal(0, noise_level))

                # Hard clip to [-1, 1]: v / |v| is exactly +/-1 when |v| > 1.
                if abs(new_value) > 1:
                    new_value /= abs(new_value)

                # Safe to write in place: step t only reads columns < t.
                data[i, t] = new_value

        # Drop the warm-up columns. Slicing with -n_timepoints would keep
        # every column when n_timepoints == 0.
        return data[:, n_lags:]

    def __getitem__(self, idx):
        """
        Return the state of all nodes at time point ``idx``.

        Args:
            idx: index into the time axis of ``self.data`` (numpy indexing
                rules apply, so negative indices count from the end).

        Returns:
            array of shape ``(15, 1)``: one column vector per time point.
        """
        sample = self.data[:, idx]
        # Keep a trailing singleton axis so each sample is a column vector.
        sample = sample.reshape(sample.shape[0], 1)

        return sample
