"""
The MortgageDataset class, a Dataset for a synthetic mortgage timeseries task. 

The synthetic mortgage timeseries task has several properties:
1. Rare events are very important to model the task well.
2. There are many sources of stochasticity:
    - Each sequence has a feature that partly determines the dynamics.
    - There are two variables that are common among all loans and 
      depend on time. One is hidden and the other is a feature.   
"""
import sys
import os
BASE_PATH = os.environ.get("BASE_PATH", "")
sys.path.append(BASE_PATH)
import numpy as np
from torch.utils.data import Dataset
import torch
from src.dataloaders.base import SequenceDataset
import os
import json
import copy
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import copy
from scripts.notebooks.hawkes_kalman_filter import kalman_update

# Seed applied to the NumPy and PyTorch RNGs at the start of MortgageDataset.setup().
SEED_NR = 42

class MortgageDataset(SequenceDataset):
    """Creates a synthetic timeseries dataset.

    Attributes:
        dataset_train: Training dataset
        dataset_val: Validation dataset
        dataset_test: Test dataset 
    """
    _name_= "timeseries_synthetics"

    def init(self):
        """Do nothing.

        NOTE(review): presumably overrides an initialisation hook of
        ``SequenceDataset`` -- confirm against the base class.
        """
        pass

    def setup(self):
        """Seed the RNGs, then load or generate the data and split it.

        When ``self.load_saved_data`` is true, every ``*.json`` file in
        ``self.saved_data_directory`` is compared with the current config.
        Two configs match when their generator ``level`` is equal and all
        non-generator entries are equal (the remaining generator entries are
        taken from the saved file).  On a match the sibling ``*.np`` file is
        loaded.  Otherwise a new dataset is generated and, if
        ``self.save_data`` is set, written to the same directory.

        The result is stored in ``self.dataset_train``, ``self.dataset_val``
        and ``self.dataset_test``.

        NOTE: the ``*.np`` files are read with ``allow_pickle=True``, which
        unpickles arbitrary objects -- only point ``saved_data_directory`` at
        trusted files.
        """
        np.random.seed(SEED_NR)

        # Set the random seed for PyTorch (CPU)
        torch.manual_seed(SEED_NR)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Set the random seed for PyTorch (GPU, if available)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(SEED_NR)
            torch.cuda.manual_seed_all(SEED_NR)  # if using multi-GPU

        # Ensures reproducibility in some operations that use CUDA
        torch.backends.cudnn.deterministic = True

        if len(self._collate_arg_names) == 0:
            self._collate_arg_names.append("M")
            self._collate_arg_names.append("I")
        self.config = self._load_config()

        data_found = False
        if self.load_saved_data:
            path = self.saved_data_directory
            os.makedirs(path, exist_ok=True)
            # Full paths are used instead of os.chdir(path): changing the
            # working directory broke relative directories (the subsequent
            # listdir/open resolved them a second time) and leaked a
            # process-wide side effect.
            for file in os.listdir(path):
                if not file.endswith(".json"):
                    continue
                config_path = os.path.join(path, file)
                data_path = os.path.join(path, file[:-5] + ".np")
                try:
                    with open(config_path, "r") as config_file:
                        saved_config = json.load(config_file)
                    if saved_config["generator"]["level"] != self.generator["level"]:
                        continue
                    # Compare against a copy so a non-matching file cannot
                    # overwrite the generator settings used for generation.
                    saved_generator = dict(saved_config["generator"])
                    candidate = dict(self.config)
                    candidate["generator"] = saved_generator
                    if saved_config != candidate:
                        continue
                    with open(data_path, "rb") as data_file:
                        payload = np.load(data_file, allow_pickle=True).item()
                    X = payload.get("X")
                    Y = payload.get("Y")
                    M = payload.get("M")
                    I = payload.get("I")
                    f = payload.get("f")
                    v = payload.get("v")
                    hidden_path_var = payload.get("hidden_path_var")
                    partial_M = payload.get("partial_M")
                except Exception as exc:
                    # Best effort: an unreadable cache entry is reported and
                    # skipped; the data is regenerated if nothing else matches.
                    print("Failed to load data from {}: {!r}".format(config_path, exc))
                    continue
                # Adopt the plain-dict generator of the matching file.
                self.config["generator"] = saved_generator
                data_found = True
                print("Loaded data from {}".format(data_path))
                break

        if not data_found:
            print("Generating data")
            X, Y, M, I, f, v, hidden_path_var, partial_M = self.create_dataset()
            data_dict = {
                "X": X,
                "Y": Y,
                "M": M,
                "I": I,
                "f": f,
                "v": v,
                "hidden_path_var": hidden_path_var,
                "partial_M": partial_M,
                "n_obs_partial_hawkes": np.array(self.config["n_obs_partial_hawkes"]),
                }
            if self.save_data:
                self._save_to_path(data_dict, self.saved_data_directory)
        self.dataset_train, self.dataset_val, self.dataset_test = (
            self._split_dataset(X, Y, M, I, f, v, hidden_path_var, partial_M))
    
    def _save_to_path(self, data_dict, path):
        """Write the config (``<name>.json``) and the data (``<name>.np``) to ``path``.

        ``<name>`` is a random six-digit number.  Because the global NumPy RNG
        is seeded in ``setup``, two runs that consume the same number of
        random draws produce the same first candidate, so names are redrawn
        until neither target file exists; an existing dataset is never
        overwritten.

        Side effect: the list-like entries of ``self.config`` and its
        ``"generator"`` entry are replaced by plain ``list``/``dict`` copies so
        the config can be serialised as JSON.

        Args:
            data_dict: Dictionary of arrays/tensors, stored with ``np.save``.
            path: Target directory; created if missing.
        """
        os.makedirs(path, exist_ok=True)

        while True:
            name = str(np.random.randint(100000, 999999))
            config_path = os.path.join(path, name + ".json")
            data_path = os.path.join(path, name + ".np")
            if not (os.path.exists(config_path) or os.path.exists(data_path)):
                break

        # Plain containers only, so that json.dump can serialise the config.
        generator = self.config["generator"]
        config_generator = {key: generator[key] for key in generator}
        self.config["n_obs_partial_hawkes"] = list(self.config["n_obs_partial_hawkes"])
        self.config["random_input_size_options"] = list(self.config["random_input_size_options"])
        self.config["random_input_size_probabilities"] = list(self.config["random_input_size_probabilities"])
        self.config["generator"] = config_generator
        with open(config_path, "w") as config_file:
            json.dump(self.config, config_file)
            print("saved to path: {}".format(config_path))
        with open(data_path, "wb") as data_file:
            np.save(data_file, data_dict)
    
    
    def _load_config(self):
        """Collect the dataset settings from instance attributes into a dict.

        Optional attributes fall back to defaults when they are not set:
        ``partial_obs_method`` defaults to
        ``"partially_observed_hawkes_kalman"``; the continuous-mixture pair
        defaults to weight ``0`` and distribution ``"loguniform"`` unless both
        attributes are present.  When ``self.forecasting`` is true the number
        of simulation steps is fixed to 100.  The level-specific entries are
        added by ``set_level``.

        Returns:
            The config dict.  Its ``"generator"`` entry is ``self.generator``
            itself (not a copy), so ``set_level`` updates it in place.
        """
        config = {}
        config["generator"] = self.generator
        config["loan_pool_size"] = self.loan_pool_size
        config["_name_"] = "timeseries_synthetics"
        config["use_feature"] = self.use_feature
        config["num_seq"] = self.num_seq
        config["n_obs_partial_hawkes"] = self.n_obs_partial_hawkes
        config["val_split"] = self.val_split
        config["test_split"] = self.test_split
        config["nr_steps"] = self.nr_steps
        # Only a missing attribute triggers the fallback; other errors propagate.
        try:
            config["partial_obs_method"] = self.partial_obs_method
        except AttributeError:
            config["partial_obs_method"] = "partially_observed_hawkes_kalman"

        try:
            mixture_weight = self.continuous_mixture_weight
            mixture_distribution = self.continuous_mixture_distribution
        except AttributeError:
            config["continuous_mixture_weight"] = 0
            config["continuous_mixture_distribution"] = "loguniform"
        else:
            config["continuous_mixture_weight"] = mixture_weight
            config["continuous_mixture_distribution"] = mixture_distribution
            print("Using continuous mixture with weight {} and distribution {}".format(
                mixture_weight, mixture_distribution))

        config["use_random_input_size"] = self.use_random_input_size
        config["random_input_size_options"] = self.random_input_size_options
        config["random_input_size_probabilities"] = self.random_input_size_probabilities
        config["forecasting"] = self.forecasting
        if self.forecasting:
            config["simulation_steps"] = 100
        else:
            config["simulation_steps"] = self.simulation_steps
        config = self.set_level(config, self.generator["level"])
        return config

    
    def _split_dataset(self, X, Y, M, I, f, v, hidden_path_var, partial_M):
        """Partition the data into train/validation/test datasets.

        Per-sequence arrays (``X``, ``Y``, ``M``, ``partial_M``, ``I`` and
        ``hidden_path_var``) are split along the sequence axis; the macro
        paths ``f`` and ``v`` are split along time using the same fractions.
        Each of the validation and test partitions holds at least one
        sequence.

        With ``config["forecasting"]`` set, three
        ``SyntheticTimeSeriesForecasting`` datasets are returned, otherwise
        three ``SyntheticTimeSeries`` datasets.  In the latter case the
        validation set receives the single input-size option
        ``[X_test.shape[2]]`` with probability ``[1]`` and the constructor's
        default mixture settings, whereas train and test receive the
        configured options.

        Returns:
            Tuple ``(train, val, test)``.
        """
        n_sequences = X.shape[0]
        assert n_sequences == self.num_seq
        n_val = int(max(n_sequences * self.val_split, 1))
        n_test = int(max(n_sequences * self.test_split, 1))
        n_train = n_sequences - n_val - n_test
        assert n_train > 0
        assert n_val > 0
        assert n_test > 0

        # Split along the sequence axis.
        by_sequence = Split(n_train, n_val)
        X_tr, X_va, X_te = by_sequence.split(X)
        Y_tr, Y_va, Y_te = by_sequence.split(Y)
        M_tr, M_va, M_te = by_sequence.split(M)
        pM_tr, pM_va, pM_te = by_sequence.split(partial_M)
        I_tr, I_va, I_te = by_sequence.split(I)

        # Split the shared macro paths along time.
        n_steps = f.shape[0]
        steps_val = int(n_steps * self.val_split)
        steps_test = int(n_steps * self.test_split)
        steps_train = n_steps - steps_val - steps_test
        by_time = Split(steps_train, steps_val)
        f_tr, f_va, f_te = by_time.split(f)
        v_tr, v_va, v_te = by_time.split(v)

        h_tr, h_va, h_te = by_sequence.split(hidden_path_var)

        cfg = self.config
        is_forecasting = cfg["forecasting"]
        horizon = self.forecasting_horizon
        lookback = self.lookback_horizon

        if is_forecasting:
            # Multi-step forecasting objective.
            partitions = (
                (X_tr, Y_tr, M_tr, I_tr, f_tr, v_tr, h_tr),
                (X_va, Y_va, M_va, I_va, f_va, v_va, h_va),
                (X_te, Y_te, M_te, I_te, f_te, v_te, h_te),
            )
            return tuple(
                SyntheticTimeSeriesForecasting(
                    *partition,
                    lookback_horizon=lookback,
                    forecasting_horizon=horizon)
                for partition in partitions
            )

        # Single-step objective.
        train_set = SyntheticTimeSeries(
            X_tr, Y_tr, M_tr, I_tr, f_tr, v_tr, h_tr, pM_tr,
            cfg["n_obs_partial_hawkes"], cfg["use_random_input_size"],
            cfg["random_input_size_options"], cfg["random_input_size_probabilities"],
            cfg["continuous_mixture_weight"], cfg["continuous_mixture_distribution"])
        val_set = SyntheticTimeSeries(
            X_va, Y_va, M_va, I_va, f_va, v_va, h_va, pM_va,
            cfg["n_obs_partial_hawkes"], cfg["use_random_input_size"],
            [X_te.shape[2]], [1])
        test_set = SyntheticTimeSeries(
            X_te, Y_te, M_te, I_te, f_te, v_te, h_te, pM_te,
            cfg["n_obs_partial_hawkes"], cfg["use_random_input_size"],
            cfg["random_input_size_options"], cfg["random_input_size_probabilities"],
            cfg["continuous_mixture_weight"], cfg["continuous_mixture_distribution"])
        return (train_set, val_set, test_set)
        
    def stack(self,x):
        """Stack the tensors in ``x`` along a new leading axis, as float32 on ``self.device``."""
        stacked = torch.stack(x, dim=0)
        as_float = stacked.float()
        return as_float.to(self.device)
    
    def generate_path(self, path_length, sigma):
        """Simulate an AR(1) path ``y[i] = 0.5 * y[i-1] + N(0, sigma)`` with ``y[0] = 0``.

        Args:
            path_length: Number of points in the path.
            sigma: Standard deviation of the Gaussian innovations.  With
                ``sigma == 0`` the all-zero path is returned and no random
                numbers are drawn.

        Returns:
            A 1-D ``numpy`` array of length ``path_length``.
        """
        persistence = 0.5
        path = np.zeros(path_length)
        if sigma != 0:
            for step in range(1, path_length):
                shock = np.random.normal(scale=sigma)
                path[step] = persistence*path[step-1] + shock
        return path
    
    def set_level(self,config,level):
        """Write the task-difficulty preset for ``level`` into ``config``.

        The preset fixes ``sigma_v``, ``num_terminal_states``, ``num_states``,
        ``sigma_f``, ``use_feature`` and ``use_zero_one_x`` on ``config`` and
        ``path_dependency_dimension`` / ``use_loan_specific_feature`` on
        ``config["generator"]``.  Two presets deviate from the table pattern:
        ``"supereasy_2d"`` leaves ``path_dependency_dimension`` untouched, and
        ``"2d_with_stochasticity"`` reads ``sigma_v`` (default 0.0) and
        ``sigma_f`` (default 0.1) from ``self.generator``.

        Args:
            config: Task configuration dict; modified in place.
            level: Name of the preset.
        Returns:
            config: The same dict, updated.
        Raises:
            AssertionError: If ``level`` is not a known level name.
            ValueError: If ``level`` is a known name without a preset
                ("veasy", "easy", "vhard", "2d_path_dependency").
        """
        assert level in ["supereasy_1d", "supereasy_2d","supereasy_2d_long_lookback", "supereasy_2d_no_loan_specific_feature","veasy", "easy","easy_2d", "medium", "hard", "vhard", "2d_path_dependency", "2d_with_stochasticity"], "Level must be one of: veasy, easy, medium, hard, vhard"

        # level -> (num_terminal_states, num_states, sigma_f,
        #           path_dependency_dimension, use_zero_one_x,
        #           use_loan_specific_feature)
        # A path_dependency_dimension of None means "leave unchanged".
        presets = {
            "supereasy_1d": (1, 3, 0, 1, True, False),                            # Level 1
            "supereasy_2d_no_loan_specific_feature": (1, 3, 0, 2, True, False),  # Level 2
            "supereasy_2d": (1, 3, 0, None, True, True),                          # Level 3
            "supereasy_2d_long_lookback": (1, 3, 0, 2, True, True),               # Level 3.5
            "2d_with_stochasticity": (1, 3, None, 2, True, True),                 # Level 3.7
            "easy_2d": (2, 10, 0, 2, True, True),                                 # Level 4
            "medium": (2, 10, 0, 2, False, True),                                 # Level 5
            "hard": (2, 10, 0.1, 2, False, True),                                 # Level 6
        }
        if level not in presets:
            raise ValueError(
                "Level must be one of: supereasy,veasy, easy, medium, hard, vhard")

        (n_terminal, n_states, sigma_f, path_dimension,
         zero_one_x, loan_specific) = presets[level]
        stochastic = level == "2d_with_stochasticity"

        config["sigma_v"] = self.generator.get("sigma_v", 0.0) if stochastic else 0
        config["num_terminal_states"] = n_terminal
        config["num_states"] = n_states
        config["sigma_f"] = self.generator.get("sigma_f", 0.1) if stochastic else sigma_f
        config["use_feature"] = True
        if path_dimension is not None:
            config["generator"]["path_dependency_dimension"] = path_dimension
        config["use_zero_one_x"] = zero_one_x
        config["generator"]["use_loan_specific_feature"] = loan_specific
        return config

    def create_full_feature(self, feature, data):
        """Stack the shared feature, the per-unit feature and the states into one tensor.

        Args:
            feature: Array-like of shape (1 + units_per_simulation,
                simulation_steps); row 0 is shared by all units, the remaining
                rows are per unit.
            data: Tensor of shape (units_per_simulation, simulation_steps,
                num_states).

        Returns:
            Tensor of shape (num_states + 2, units_per_simulation,
            simulation_steps): plane 0 is the shared row repeated for every
            unit, plane 1 the per-unit rows, the remaining planes are ``data``
            with the state axis moved to the front.

        Raises:
            AssertionError: If ``data[:, 0, :]`` does not sum to
                ``self.loan_pool_size``.
        """
        feature = torch.tensor(feature)
        shared_row = feature[:1, :]   # (1, simulation_steps)
        per_unit = feature[1:, :]     # (units_per_simulation, simulation_steps)

        n_units, n_steps, n_states = data.shape
        # Copy the shared row once per unit, then add the leading plane axis.
        shared_plane = shared_row.repeat(n_units, 1)[None]
        per_unit_plane = per_unit[None]
        state_planes = data.permute(2, 0, 1)  # (num_states, units, steps)

        stacked = torch.cat([shared_plane, per_unit_plane, state_planes], dim=0)
        assert torch.sum(data[:, 0, :]).item() == self.loan_pool_size, "Sum of initial values does not match loan pool size."
        return stacked
        

    def create_dataset(self):
        """Simulate ``self.num_seq`` loan-pool sequences and stack them into tensors.

        One long path is generated for each macro variable (``f`` and ``v``)
        and cut into consecutive windows of ``simulation_steps`` points, one
        per sequence.  Each window drives a ``PathGenerator`` that samples the
        transitions of ``self.loan_pool_size`` loans.  When
        ``self.forecasting`` is true, ``self.simulation_steps`` is set to 100
        first.

        Returns:
            Tuple ``(X, Y, M, I, f, v, hidden_path_var, partial_M)``:
            stacked inputs, one-hot targets, transition probabilities,
            per-sequence start indices, the two macro paths, the hidden path
            variable and the partially observed transition probabilities
            (``None`` when ``config["n_obs_partial_hawkes"]`` is empty).
        """
        print("Creating dataset")
        config = copy.deepcopy(self.config)
        self.num_states = config["num_states"]
        num_seq = self.num_seq
        self.X = []
        self.Y = []
        self.transition_probabilities = []
        self.partial_transition_probabilities = []
        self.sequence_start_index = []  # start index of each sequence
        self.hidden_path_var = []
        if self.forecasting:
            self.simulation_steps = 100
        L = self.simulation_steps
        f = self.generate_path(num_seq*L, config["sigma_f"])
        v = self.generate_path(num_seq*L, config["sigma_v"])
        has_partial = len(self.config["n_obs_partial_hawkes"]) > 0

        for seq_idx in range(num_seq):
            start = seq_idx*L
            config["f"] = f[start:start + L]
            config["v"] = v[start:start + L]
            gen = PathGenerator(config)
            with torch.cuda.amp.autocast():
                gen.m_sample(self.loan_pool_size, self.simulation_steps)

            if config["use_feature"]:
                # f_mat: (1, simulation_steps-1); x_mat: (loan_pool_size, simulation_steps-1)
                feature = np.concatenate((gen.f_mat, gen.x_mat))
            else:
                feature = torch.zeros(
                    self.loan_pool_size+1,
                    self.simulation_steps-1)
            self.hidden_path_var.append(torch.tensor(gen.h_mat))

            # One-hot encode the sampled states along a trailing axis.
            one_hot_states = F.one_hot(
                torch.tensor(gen.transitions), num_classes = self.num_states)
            full_input = self.create_full_feature(feature, one_hot_states)
            self.transition_probabilities.append(torch.tensor(gen.probs))
            if has_partial:
                self.partial_transition_probabilities.append(
                    torch.tensor(gen.partially_observed_probs))
            self.X.append(full_input)
            self.Y.append(one_hot_states)
            self.sequence_start_index.append(start)

        X = self.stack(self.X)
        Y = self.stack(self.Y)
        M = self.stack(self.transition_probabilities)

        partial_M = None
        if has_partial:
            self.partial_transition_probabilities = self.stack(
                self.partial_transition_probabilities)
            partial_M = self.partial_transition_probabilities
        I = torch.tensor(self.sequence_start_index).float().to(self.device)
        hidden_path_var = (self.stack(self.hidden_path_var)).float().to(self.device)
        return X, Y, M, I, f, v, hidden_path_var, partial_M


class PathGenerator():
    def __init__(self, config):
        self.config = config
        self.num_terminal_states = config["num_terminal_states"]
        self.num_states = config["num_states"]
        self.unobserved_macro_variable = config["v"]
        self.observed_macro_variable= config["f"]
        self.generator = config["generator"]

        # Set the random seed for NumPy


    def get_rows_vectorized(self, states, h, f, v, x):
        """Build one transition-probability row per loan for a single time step.

        For a loan in a non-terminal state the unnormalised weights are 1 for
        every state, ``1 + x[i, 0]`` for staying in the current state and
        ``h[i] * (1 + 0.1 * x[i, 0])`` for each terminal state; the row is
        then normalised to sum to 1.  A loan in a terminal state moves to
        state 0 with probability 1.

        Args:
            states: 1-D integer array with the current state of each loan.
            h: Array with the hidden path variable of each loan.
            f: Unused.
            v: Unused.
            x: 2-D array; column 0 holds the loan-specific feature.

        Returns:
            Array of shape ``(states.size, num_states)``.
        """
        n_terminal = self.num_terminal_states
        n_states = self.num_states

        # The terminal states are the last n_terminal state indices.
        terminal_mask = states >= n_states - n_terminal
        rows = np.ones((states.size, n_states))

        active = np.nonzero(~terminal_mask)[0]
        if active.size > 0:
            loan_feature = x[active, 0]
            # Up-weight staying in the current state.
            rows[active, states[active]] *= 1 + loan_feature
            # Scale every terminal column by the loan's hazard.
            hazard = h[active] * (1 + 0.1 * loan_feature)
            for offset in range(1, n_terminal + 1):
                rows[active, n_states - offset] *= hazard
            rows[active] /= rows[active].sum(axis=1, keepdims=True)

        # Terminal loans return to the initial state.
        rows[terminal_mask, :] = 0
        rows[terminal_mask, 0] = 1

        assert np.all(rows >= 0), f"Probabilities must be positive, but got {rows}"
        assert np.allclose(rows.sum(axis=1), 1), f"Probabilities must sum to 1, but got {rows.sum(axis=1)}"
        return rows

    def step_vectorized1(self, p_matrix):
        """Sample one new state per row of ``p_matrix`` using ``np.random.choice``.

        Row-by-row counterpart of ``step_vectorized``.
        """
        draws = []
        for row_probabilities in p_matrix:
            draws.append(np.random.choice(self.num_states, p=row_probabilities))
        return np.array(draws)
    
    def step_vectorized(self, p_matrix):
        """
        Vectorized sampling of new states from a probability matrix.

        Uses inverse-CDF sampling with one uniform draw per row.  If a draw is
        not below the row's final cumulative value (floating-point rounding,
        or a row that sums to slightly less than 1), the last state with
        positive probability is selected.  Without this guard the comparison
        is False everywhere and ``argmax`` would silently return state 0.

        Args:
            p_matrix (numpy.ndarray): 2D array of shape (n, num_states) where each row is a probability distribution.

        Returns:
            numpy.ndarray: 1D array of new states sampled from the probability distributions.
        """
        # Generate cumulative probabilities
        cdf_matrix = np.cumsum(p_matrix, axis=1)

        # Generate random uniform samples for each row
        random_values = np.random.rand(p_matrix.shape[0]).reshape(-1, 1)

        # Find the first index whose cumulative probability exceeds the draw
        new_states = (random_values < cdf_matrix).argmax(axis=1)

        # Rows where the draw overshoots the whole CDF
        overshoot = random_values[:, 0] >= cdf_matrix[:, -1]
        if overshoot.any():
            support = np.asarray(p_matrix) > 0
            # Index of the last state with positive probability in each row
            last_supported = support.shape[1] - 1 - support[:, ::-1].argmax(axis=1)
            new_states[overshoot] = last_supported[overshoot]
        return new_states

    def get_hidden_path_var_value(self, samples, time_step, x, h_prev, method="hawkes", path_var_params=None, p_prev=None):
        """Advance the hidden path variable ``h`` by one time step.

        The update has the form
        ``h = mu + beta * (h_prev - mu) + alpha * (number of loans in a terminal state)``,
        where the count is taken from ``samples[:, time_step]``.  With
        ``path_dependency_dimension == 2`` the loans are divided into two
        groups (``x[:, 0] < 0.5`` and ``x[:, 0] >= 0.5``), each group uses its
        own count and ``alpha`` is halved.

        Args:
            samples: Integer array of states, shape (M, L).
            time_step: Column of ``samples`` to read.
            x: Array whose column 0 holds the loan-specific feature.
            h_prev: Previous value of ``h`` for each loan.
            method: ``"hawkes"`` (all loans observed),
                ``"partially_observed_hawkes"`` (only the first ``n_obs`` loans
                are observed; counts are scaled by ``M / n_obs``) or
                ``"partially_observed_hawkes_kalman"`` (as before, but the
                update is delegated to ``kalman_update``).
            path_var_params: Dict with keys ``"alpha"``, ``"beta"``, ``"mu"``
                and ``"n_obs_samples_partial_hawkes"``; defaults are used when
                ``None``.
            p_prev: Previous state-estimation variance for each loan; required
                for the Kalman method.

        Returns:
            Tuple ``(h, info)``.  ``info`` is ``None`` except for the Kalman
            method, where it is a dict with keys ``"P_t"`` and ``"K_t"``.

        Raises:
            ValueError: If ``path_dependency_dimension`` is neither 1 nor 2,
                or if ``method`` is unknown.
        """
        if path_var_params is None:
            print("Using default path_var_params")
            path_var_params = {
                "alpha": 0.001,
                "beta": 0.85,
                "mu": 1/400,
                "n_obs_samples_partial_hawkes": self.config["n_obs_partial_hawkes"]
            }
        alpha = path_var_params["alpha"]
        beta = path_var_params["beta"]
        mu = path_var_params["mu"]
        n_obs_samples_partial_hawkes = path_var_params["n_obs_samples_partial_hawkes"]
        M = samples.shape[0]
        terminal_states = [self.num_states-1-i for i in range(self.num_terminal_states)]

        def count_terminal(states):
            # Number of entries of ``states`` that are terminal states.
            return np.sum(np.isin(states, terminal_states))

        def dimension_error():
            # The dimension lives under config["generator"]; the previous
            # lookup of config["path_dependency_dimension"] raised KeyError.
            return ValueError(
                "path_dependency_dimension must be 1 or 2, but is {}".format(
                    self.config["generator"]["path_dependency_dimension"]))

        if method == "hawkes":
            relevant_samples = samples[:,time_step]
            dimension = self.config["generator"]["path_dependency_dimension"]
            if dimension == 1:
                cnt = count_terminal(relevant_samples)
                h = mu + beta*(h_prev-mu) + alpha*cnt
                return h, None
            if dimension == 2:
                alpha = alpha/2
                low = x[:,0] < 0.5
                high = x[:,0] >= 0.5
                h = np.zeros(M)
                h[low] = mu + beta*(h_prev[low]-mu) + alpha*count_terminal(relevant_samples[low])
                h[high] = mu + beta*(h_prev[high]-mu) + alpha*count_terminal(relevant_samples[high])
                return h, None
            raise dimension_error()

        if method == "partially_observed_hawkes":
            relevant_samples = samples[:n_obs_samples_partial_hawkes,time_step]
            if n_obs_samples_partial_hawkes == 0:
                # Nothing is observed: stay at the baseline.  Returned as a
                # pair so callers can unpack it like every other result.
                return np.ones(M)*mu, None
            diff_factor = samples.shape[0]/n_obs_samples_partial_hawkes
            dimension = self.config["generator"]["path_dependency_dimension"]
            if dimension == 1:
                cnt = count_terminal(relevant_samples)
                h = mu + beta*(h_prev-mu) + alpha*cnt*diff_factor
                return h, None
            if dimension == 2:
                alpha = alpha/2
                obs_low = x[:n_obs_samples_partial_hawkes,0] < 0.5
                obs_high = x[:n_obs_samples_partial_hawkes,0] >= 0.5
                cnt_low = count_terminal(relevant_samples[obs_low])
                cnt_high = count_terminal(relevant_samples[obs_high])
                low = x[:,0] < 0.5
                high = x[:,0] >= 0.5
                h = np.zeros(M)
                h[low] = mu + beta*(h_prev[low]-mu) + alpha*cnt_low*diff_factor
                h[high] = mu + beta*(h_prev[high]-mu) + alpha*cnt_high*diff_factor
                return h, None
            raise dimension_error()

        if method == "partially_observed_hawkes_kalman":
            assert p_prev is not None
            relevant_samples = samples[:n_obs_samples_partial_hawkes,time_step]
            if n_obs_samples_partial_hawkes == 0:
                # NOTE(review): without observations no Kalman update is
                # possible; h stays at the baseline, the variance is carried
                # over unchanged and the gain is reported as zero -- confirm
                # that this is the intended treatment of the special case.
                return np.ones(M)*mu, {"P_t": p_prev, "K_t": np.zeros(M)}
            diff_factor = samples.shape[0]/n_obs_samples_partial_hawkes
            dimension = self.config["generator"]["path_dependency_dimension"]
            if dimension == 1:
                cnt = count_terminal(relevant_samples)
                barN_t = cnt*diff_factor*alpha
                K_t = [0]*M
                P_t = [0]*M
                h = [0]*M
                for i in range(M):
                    h[i], _, P_t[i], K_t[i] = kalman_update(n_obs_samples_partial_hawkes, M, x[i,0], alpha, beta, mu, h_prev[i], barN_t, p_prev[i])

                d = {"P_t": P_t, "K_t": K_t}
                return h, d
            if dimension == 2:
                alpha = alpha/2
                obs_low = x[:n_obs_samples_partial_hawkes,0] < 0.5
                obs_high = x[:n_obs_samples_partial_hawkes,0] >= 0.5
                low = x[:,0] < 0.5
                high = x[:,0] >= 0.5

                cnt_defaults_low = count_terminal(relevant_samples[obs_low])
                cnt_defaults_high = count_terminal(relevant_samples[obs_high])
                cnt_samples_low = np.sum(obs_low)
                cnt_samples_high = np.sum(obs_high)
                total_cnt_low = np.sum(low)
                total_cnt_high = np.sum(high)

                # Group sizes (+1) broadcast to one entry per loan.
                total_cnt = np.zeros(M)
                total_cnt[low] = total_cnt_low+1
                total_cnt[high] = total_cnt_high+1
                cnt_samples = np.zeros(M)
                cnt_samples[low] = cnt_samples_low+1
                cnt_samples[high] = cnt_samples_high+1
                # Observed defaults scaled up to the size of the whole group.
                cnt = np.zeros(M)
                cnt[low] = cnt_defaults_low*(total_cnt_low+1)/(cnt_samples_low+1)
                cnt[high] = cnt_defaults_high*(total_cnt_high+1)/(cnt_samples_high+1)

                h, _, P_t, K_t = kalman_update(
                    cnt_samples, total_cnt, x[:,0], alpha, beta,
                        mu, h_prev[:], alpha*cnt, p_prev[:])

                d = {"P_t": P_t, "K_t": K_t}
                return h, d
            raise dimension_error()

        raise ValueError("Unknown method: {}".format(method))

    def m_sample(self, M, L):
        """Simulate the states of a pool of ``M`` loans over ``L`` time points.

        Every loan starts in state 0.  At each step the hidden path variable
        is advanced from the current states, a transition row is built for
        every loan, and the next states are sampled from those rows.  For
        every entry of ``config["n_obs_partial_hawkes"]`` a second, partially
        observed version of the hidden variable is tracked with
        ``config["partial_obs_method"]`` and its transition rows are recorded
        as well.

        Results are stored on the instance (last time point dropped):
            transitions: states, shape (M, L-1).
            probs: transition probabilities, shape (M, L-1, num_states).
            partially_observed_probs: shape (M, L-1, num_states, n_variants);
                only set when ``n_obs_partial_hawkes`` is non-empty.
            h_mat: hidden path variable, shape (M, L-1).
            f_mat: observed macro variable without its first point, shape (1, L-1).
            x_mat: loan feature repeated over time, shape (M, L-1); all ones
                when ``generator["use_loan_specific_feature"]`` is false.
        """
        cfg = self.config
        generator_cfg = cfg["generator"]
        n_states = self.num_states
        macro_observed = self.observed_macro_variable
        macro_unobserved = self.unobserved_macro_variable

        # Only column 0 of x carries the loan-specific feature.
        x = np.zeros((M, L-1))
        if cfg["use_feature"]:
            if cfg["use_zero_one_x"]:
                x[:, 0] = np.random.choice([0, 1], M)
            else:
                x[:, 0] = np.random.uniform(0, 1, M)

        samples = np.zeros((M, L), dtype=int)
        probs = np.zeros((M, L, n_states))
        probs[:, 0, 0] = np.ones(M)  # every loan starts in state 0

        path_var_params = {
            "alpha": generator_cfg["alpha"],
            "beta": generator_cfg["beta"],
            "mu": generator_cfg["mu"],
            "n_obs_samples_partial_hawkes": cfg["n_obs_partial_hawkes"],  # list of integers
        }
        baseline = path_var_params["mu"]

        observed_counts = cfg["n_obs_partial_hawkes"]
        track_partial = len(observed_counts) > 0
        if track_partial:
            n_variants = len(observed_counts)
            partially_observed_probs = np.zeros((M, L, n_states, n_variants))
            partially_observed_probs[:, 0, 0, :] = np.ones((M, n_variants))
            h_partially_observed = np.ones((M, n_variants))*baseline
            state_estimation_var = np.ones((M, n_variants))*0.01
            gain_matrix = np.zeros((M, n_variants, L-1))

        hidden_path_var = np.zeros((M, L-1))
        h = np.ones(M)*baseline
        for time_step in range(L-1):
            next_step = time_step + 1

            # Fully observed hidden variable and the resulting transition rows.
            h, _ = self.get_hidden_path_var_value(
                samples, time_step, x, h, path_var_params=path_var_params)
            hidden_path_var[:, time_step] = h
            states = samples[:, time_step].astype(int)
            row_matrix = self.get_rows_vectorized(
                states, h, macro_observed[next_step],
                macro_unobserved[next_step], x)

            # Partially observed variants, one per observed-sample count.
            for idx, n_observed in enumerate(observed_counts):
                path_var_params["n_obs_samples_partial_hawkes"] = n_observed
                h_partially_observed[:, idx], return_dict = self.get_hidden_path_var_value(
                    samples,
                    time_step,
                    x,
                    h_partially_observed[:, idx],
                    method=cfg["partial_obs_method"],
                    path_var_params=path_var_params,
                    p_prev=state_estimation_var[:, idx])
                if cfg["partial_obs_method"] == "partially_observed_hawkes_kalman":
                    state_estimation_var[:, idx] = return_dict["P_t"]
                    gain_matrix[:, idx, time_step] = return_dict["K_t"]
                partial_rows = self.get_rows_vectorized(
                    states, h_partially_observed[:, idx],
                    macro_observed[next_step],
                    macro_unobserved[next_step], x)
                partially_observed_probs[:, next_step, :, idx] = partial_rows

            # Sample the next states and record the rows they were drawn from.
            samples[:, next_step] = self.step_vectorized(row_matrix)
            probs[:, next_step, :] = row_matrix

        # Public attributes
        self.transitions = samples[:, :-1]
        self.probs = probs[:, :-1, :]
        if track_partial:
            self.partially_observed_probs = partially_observed_probs[:, :-1, :, :]
        self.h_mat = np.array(hidden_path_var)
        self.f_mat = np.array([macro_observed[1:]])
        if self.generator["use_loan_specific_feature"]:
            # Repeat each loan's feature value over the L-1 time steps.
            self.x_mat = np.repeat(x[:, 0][np.newaxis, :], L-1, axis=0).T
        else:
            self.x_mat = np.ones((M, L-1))


class SyntheticTimeSeries(Dataset):
    """Map-style dataset over pre-generated synthetic time series.

    Indexing returns the tuple
    ``(X, Y, transition_probabilities, sequence_start_index)`` for one entry
    of the leading axis. With ``use_random_input_size`` enabled, only a
    randomly drawn subset of the units axis is returned (axis 2 of ``X``,
    axis 1 of ``Y`` and of ``transition_probabilities``).

    Attributes:
        sequence_start_index: Lengthwise start index of the data
        observed_macro_variable: Visible macro variable
        unobserved_macro_variable: Hidden macro variable
        hidden_path_var: Hidden path variable, stored as given
        transition_probabilities: Ground truth transition probabilities
        partially_observed_transition_probabilities: Stored as given and
            handed back by ``get_partial_transition_probs``
        n_obs_partial_hawkes: Stored as given and handed back by
            ``get_partial_transition_probs``
        X: states and features, cast to float32
        Y: transitions, cast to float32
        len: Number of data sequences (``X.shape[0]``)
        use_random_input_size: Whether ``__getitem__`` subsamples units
        discrete_mixture_values: Candidate unit counts for the discrete draw
        discrete_mixture_weights: Probabilities of those candidates
        continuous_mixture_weight: Probability of drawing the unit count
            from the continuous distribution instead of the discrete one
        continuous_mixture_distribution: ``'loguniform'`` or ``'uniform'``
    """

    def __init__(
        self,
        X,
        Y,
        transition_probabilities,
        sequence_start_index,
        observed_macro_variable,
        unobserved_macro_variable,
        hidden_path_var,
        partially_observed_transition_probabilities,
        n_obs_partial_hawkes,
        use_random_input_size,
        random_input_size_options,
        random_input_size_probabilities,
        continuous_mixture_weight=0,
        continuous_mixture_distribution='loguniform',
    ):
        super().__init__()

        # Core tensors; the dataset length is the size of X's leading axis.
        self.X = X
        self.Y = Y
        self.transition_probabilities = transition_probabilities
        self.len = self.X.shape[0]

        # Side information that is only stored, never transformed here.
        self.sequence_start_index = sequence_start_index
        self.observed_macro_variable = observed_macro_variable
        self.unobserved_macro_variable = unobserved_macro_variable
        self.hidden_path_var = hidden_path_var
        self.partially_observed_transition_probabilities = partially_observed_transition_probabilities
        self.n_obs_partial_hawkes = n_obs_partial_hawkes

        # Sanity checks: reject complex dtypes first, then NaNs in X and Y.
        complex_checks = (
            (self.X, "X contains complex values"),
            (self.Y, "Y contains complex values"),
            (self.transition_probabilities, "Transition probabilities contain complex values"),
        )
        for tensor, message in complex_checks:
            assert not torch.is_complex(tensor), message
        assert not torch.isnan(self.X).any(), "NaN in input X"
        assert not torch.isnan(self.Y).any(), "NaN in input Y"

        self.X = self.X.float()
        self.Y = self.Y.float()

        # Settings controlling the random unit subsampling in __getitem__.
        self.use_random_input_size = use_random_input_size
        self.discrete_mixture_values = random_input_size_options
        self.discrete_mixture_weights = random_input_size_probabilities
        self.continuous_mixture_weight = continuous_mixture_weight
        self.continuous_mixture_distribution = continuous_mixture_distribution

    def _draw_nr_units(self, max_units):
        """Sample how many units to keep, using NumPy's global RNG.

        With probability ``continuous_mixture_weight`` the count comes from
        the continuous distribution, otherwise from the discrete options.

        Raises:
            ValueError: if the continuous distribution name is unsupported.
        """
        if not (np.random.rand() < self.continuous_mixture_weight):
            # Discrete draw, capped at the number of available units.
            picked_option = np.random.choice(
                a=self.discrete_mixture_values, p=self.discrete_mixture_weights
            )
            return min(picked_option, max_units)

        if self.continuous_mixture_distribution == 'loguniform':
            # Uniform in log-space between 1 and max_units, truncated to int.
            lower, upper = np.log(1), np.log(max_units)
            drawn = int(np.exp(np.random.uniform(lower, upper)))
            return min(drawn, max_units)
        if self.continuous_mixture_distribution == 'uniform':
            # Truncation makes this uniform over the integers 1..max_units.
            return int(np.random.uniform(1, max_units + 1))
        raise ValueError(f"Unsupported continuous mixture distribution: {self.continuous_mixture_distribution}")

    def __getitem__(self, idx):
        """Return ``(X, Y, transition_probabilities, sequence_start_index)``.

        ``X`` is the model input and ``Y`` holds the transitions to predict.
        Without random input size the full entries at ``idx`` are returned.
        Otherwise the same random unit indices are applied to axis 1 of
        ``X[idx]`` and axis 0 of ``Y[idx]`` and
        ``transition_probabilities[idx]``.
        """
        if not self.use_random_input_size:
            return (
                self.X[idx, :, :, :],
                self.Y[idx, :, :, :],
                self.transition_probabilities[idx, :, :, :],
                self.sequence_start_index[idx],
            )

        total_units = self.X.shape[2]
        keep = self._draw_nr_units(total_units)

        # A random permutation truncated to `keep` picks units without repeats.
        chosen_units = torch.randperm(total_units)[:keep]

        features = self.X[idx][:, chosen_units, :]
        targets = self.Y[idx][chosen_units, :, :]
        true_probs = self.transition_probabilities[idx][chosen_units, :, :]
        return (features, targets, true_probs, self.sequence_start_index[idx])

    def __len__(self):
        """Number of data sequences."""
        return self.len

    def size(self):
        """Product of the first three dimensions of ``X``."""
        dim0, dim1, dim2 = self.X.shape[:3]
        return dim0 * dim1 * dim2

    def num_data_per_time_step(self):
        """Size of axis 2 of ``X`` (the axis subsampled in ``__getitem__``)."""
        return self.X.shape[2]

    def get_partial_transition_probs(self):
        """Return the stored partially observed probabilities and their observation counts."""
        return self.partially_observed_transition_probabilities, self.n_obs_partial_hawkes

class SyntheticTimeSeriesForecasting(Dataset):
    """Mortgage dataset for multi-step forecasting.

    Every item is a window cut out of a single series. With
    ``T = X.shape[2]``, each series contributes
    ``T - lookback_horizon - forecasting_horizon`` windows, and the dataset
    length is that count times the number of series (``X.shape[0]``).
    """

    def __init__(
        self,
        X,
        Y,
        transition_probabilities,
        sequence_start_index,
        observed_macro_variable,
        unobserved_macro_variable,
        hidden_path_var,
        lookback_horizon=25,
        forecasting_horizon=5,
    ):
        super().__init__()
        self.X = X
        self.Y = Y
        self.transition_probabilities = transition_probabilities
        self.hidden_path_var = hidden_path_var
        self.sequence_start_index = sequence_start_index
        self.observed_macro_variable = observed_macro_variable
        self.unobserved_macro_variable = unobserved_macro_variable

        # Window geometry.
        self.lookback_horizon = lookback_horizon
        self.forecasting_horizon = forecasting_horizon

        # Flat length = (windows per series) * (number of series).
        self.len_x = self.X.shape[0]
        windows_per_series = (
            self.X.shape[2] - self.lookback_horizon - self.forecasting_horizon
        )
        self.len = windows_per_series * self.len_x

    def __getitem__(self, idx):
        """Return the window addressed by the flat index ``idx``.

        ``idx`` is decomposed into a series index and a start offset. The
        result is ``(X_window, Y_window, prob_window, sequence_start_index)``
        where ``X_window`` covers ``lookback_horizon`` steps from the offset
        along axis 2 of ``X``, and the ``Y`` / probability windows cover
        ``forecasting_horizon`` steps along their axis 2, beginning at the
        last lookback step.
        """
        past = self.lookback_horizon
        future = self.forecasting_horizon
        windows_per_series = self.X.shape[2] - past - future

        series = idx // windows_per_series
        offset = idx % windows_per_series

        # Targets start at the final lookback step and span `future` steps.
        first_target = offset + past - 1
        target_span = slice(first_target, first_target + future)

        return (
            self.X[series, :, offset:offset + past],
            self.Y[series, :, target_span, :],
            self.transition_probabilities[series, :, target_span, :],
            self.sequence_start_index[series],
        )

    def __len__(self):
        """Total number of windows over all series."""
        return self.len
        

class Split:
    """Cuts data along its first axis into train, validation and test parts.

    The first ``num_train`` entries form the training part, the following
    ``num_val`` entries the validation part, and everything after that the
    test part.
    """

    def __init__(self, num_train, num_val):
        self.num_train = num_train
        self.num_val = num_val

    def split(self, X):
        """Return ``(train, val, test)`` slices of ``X``.

        ``None`` input yields ``(None, None, None)``.
        """
        if X is None:
            return (None,) * 3
        train_end = self.num_train
        val_end = train_end + self.num_val
        # (start, stop) pairs along the leading axis, in train/val/test order.
        bounds = ((None, train_end), (train_end, val_end), (val_end, None))
        return tuple(X[slice(lo, hi), ...] for lo, hi in bounds)
