"""Sanity check for the dataloader_mortgage.py module."""
import sys
import os 
BASE_PATH = os.environ.get("BASE_PATH","")
sys.path.append(BASE_PATH) 
from src.dataloaders.dataloader_mortgage import PathGenerator, MortgageDataset, Mortgage, Split
import numpy as np
import torch
import copy

def mortgage_dataset_test(config):
    """Sanity-check MortgageDataset under the configuration it is given.

    Builds and sets up the dataset, then verifies on the training split the
    lengths and ranks of the bookkeeping arrays and the sequence start
    indices. It also checks that row 0 of ``X`` reproduces slices of the
    observed macro variable and that the transition probabilities and
    transitions have matching shapes. Finally it checks the test split's
    start indices and the layout of every ``(x, y, M, I)`` training sample.

    Args:
        config: Keyword arguments for ``MortgageDataset``. The keys read here
            are ``nr_steps``, ``num_samples``, ``num_seq``, ``val_split``,
            ``test_split``, ``num_states`` and ``loan_pool_size``.
    """
    data_module = MortgageDataset(**config)
    data_module.setup()
    train_split = data_module.dataset_train

    starts = train_split.sequence_start_index  # I: start index of every sequence
    observed = train_split.observed_macro_variable  # f: known macro variable
    hidden = train_split.unobserved_macro_variable  # v: unknown macro variable (not a feature of x)

    kept = 1 - config["val_split"] - config["test_split"]  # share of data in the train split
    expected_len = (config["nr_steps"] * config["num_samples"]) * kept
    assert observed.shape[0] == expected_len
    assert hidden.shape[0] == expected_len
    assert starts.shape[0] == config["num_seq"] * kept
    assert all(len(arr.shape) == 1 for arr in (observed, hidden, starts))
    assert torch.all(starts[:100] == torch.zeros(100).cuda().float())
    assert torch.all(starts[300:] == 300 * torch.ones(100).cuda().float())

    X_train = train_split.X
    # (first, seq): row 0 of X for sequence `seq` must equal f[first:first + 99].
    macro_checks = (
        (1, 0),
        (101, 100),
        (101, 130),
        (201, 200),
        (201, 230),
        (301, 300),
        (301, 333),
    )
    for first, seq in macro_checks:
        expected_macro = torch.tensor(observed[first:first + 99]).cuda().float()
        assert torch.all(expected_macro == X_train[seq, 0, :])

    probs = train_split.transition_probabilities  # M: true transition probabilities
    transitions = train_split.Y  # Y: true transitions
    assert probs.shape == transitions.shape

    test_starts = data_module.dataset_test.sequence_start_index
    assert torch.all(test_starts[:100] == 600 * torch.ones(100).cuda().float())
    assert torch.all(test_starts[100:200] == 700 * torch.ones(100).cuda().float())

    pool = config["loan_pool_size"]
    expected_x_shape = (config["num_states"] + 2, config["num_samples"] - 1)
    for x, y, _M, _I in train_split:
        assert x.shape == expected_x_shape
        # Rows 2 onward of x are the first loan's transitions, transposed.
        assert torch.all(torch.eq(y[0, :, :].T, x[2:, :])).item()
        # The loan-specific feature rows are non-zero at the first position.
        assert torch.all(x[1:1 + pool, 0] != torch.zeros(pool).cuda().float())
        #assert torch.all(x[1:1+config["loan_pool_size"],1:] == torch.zeros(x[1:1+config["loan_pool_size"],1:].shape).cuda().float())

def mortgage_dataset_no_feature_test(config):
    """Sanity-check the ``supereasy_2d`` level with deterministic debug dynamics.

    Mutates ``config``: the generator gets ``debug=True``, ``hard=False`` and
    ``level="supereasy_2d"``, and ``num_states`` is set to 3. It then builds
    the dataset and checks the training split's bookkeeping arrays and start
    indices. For every training sample it requires that:

    * row 0 of ``x`` is all zeros and row 1 is all ones;
    * rows 2+ reproduce the first loan's transitions ``y[0]``;
    * ``y[0]`` equals ``M[0]``.
    """
    config["generator"].update(debug=True, hard=False, level="supereasy_2d")
    config["num_states"] = 3

    data_module = MortgageDataset(**config)
    data_module.setup()
    train_split = data_module.dataset_train
    starts = train_split.sequence_start_index  # I: start index of every sequence
    observed = train_split.observed_macro_variable  # f: known macro variable
    hidden = train_split.unobserved_macro_variable  # v: unknown macro variable (not a feature of x)

    kept = 1 - config["val_split"] - config["test_split"]  # share of data in the train split
    expected_len = (config["nr_steps"] * config["num_samples"]) * kept
    assert observed.shape[0] == expected_len
    assert hidden.shape[0] == expected_len
    assert starts.shape[0] == config["num_seq"] * kept
    assert all(len(arr.shape) == 1 for arr in (observed, hidden, starts))
    assert torch.all(starts[:100] == torch.zeros(100).cuda().float())
    assert torch.all(starts[300:] == 300 * torch.ones(100).cuda().float())
    _inputs = train_split.X  # attribute is read but its values are not inspected here

    probs = train_split.transition_probabilities  # M: true transition probabilities
    transitions = train_split.Y  # Y: true transitions
    assert probs.shape == transitions.shape

    test_starts = data_module.dataset_test.sequence_start_index
    assert torch.all(test_starts[:100] == 600 * torch.ones(100).cuda().float())
    assert torch.all(test_starts[100:200] == 700 * torch.ones(100).cuda().float())

    n_states = config["num_states"]
    for x, y, M, _I in train_split:
        assert x.shape == (n_states + 2, config["num_samples"] - 1)
        first_loan = y[0, :, :].T
        assert torch.all(torch.eq(first_loan, x[2:, :])).item()
        macro_row = x[:1, :]
        loan_row = x[1:2, :]
        assert torch.all(macro_row == torch.zeros(macro_row.shape).cuda().float())
        assert torch.all(loan_row == torch.ones(loan_row.shape).cuda().float())
        # debug=True with hard=False makes the dynamics deterministic, so the
        # realised transitions y coincide with the probabilities M.
        assert torch.all(torch.eq(first_loan, M[0, :, :].T)).item()

def mortgage_dataset_test_debug(config):
    """Sanity-check MortgageDataset with the simplified (debug) dynamics.

    Sets ``debug=True`` and ``hard=False`` in ``config["generator"]``, which
    mutates ``config``. It then runs the same checks as the default sanity
    test. In addition, it requires that every sample starts in state 0 and is
    in state 0 or state 1 at every later step.
    """
    gen_cfg = config["generator"]
    gen_cfg["debug"] = True
    gen_cfg["hard"] = False

    data_module = MortgageDataset(**config)
    data_module.setup()
    train_split = data_module.dataset_train
    starts = train_split.sequence_start_index  # I: start index of every sequence
    observed = train_split.observed_macro_variable  # f: known macro variable
    hidden = train_split.unobserved_macro_variable  # v: unknown macro variable (not a feature of x)

    kept = 1 - config["val_split"] - config["test_split"]  # share of data in the train split
    expected_len = (config["nr_steps"] * config["num_samples"]) * kept
    assert observed.shape[0] == expected_len
    assert hidden.shape[0] == expected_len
    assert starts.shape[0] == config["num_seq"] * kept
    assert all(len(arr.shape) == 1 for arr in (observed, hidden, starts))
    assert torch.all(starts[:100] == torch.zeros(100).cuda().float())
    assert torch.all(starts[300:] == 300 * torch.ones(100).cuda().float())

    X_train = train_split.X
    # (first, seq): row 0 of X for sequence `seq` must equal f[first:first + 99].
    for first, seq in ((1, 0), (101, 100), (101, 130), (201, 200), (201, 230), (301, 300), (301, 333)):
        expected_macro = torch.tensor(observed[first:first + 99]).cuda().float()
        assert torch.all(expected_macro == X_train[seq, 0, :])

    probs, transitions = train_split.transition_probabilities, train_split.Y
    assert probs.shape == transitions.shape

    test_starts = data_module.dataset_test.sequence_start_index
    assert torch.all(test_starts[:100] == 600 * torch.ones(100).cuda().float())
    assert torch.all(test_starts[100:200] == 700 * torch.ones(100).cuda().float())

    pool = config["loan_pool_size"]
    for x, y, _M, _I in train_split:
        assert x.shape == (config["num_states"] + 2, config["num_samples"] - 1)
        # x[2:] mirrors y[0].T, so rows 2 and 3 are the indicators of states 0 and 1.
        assert x[2, 0] == 1  # every path starts in state 0
        assert torch.all(torch.eq(y[0, :, :].T, x[2:, :])).item()
        # In debug mode every later step is in state 0 or state 1.
        assert torch.all(torch.logical_or(x[2, 1:] == 1, x[3, 1:] == 1))
        # The loan-specific feature rows are non-zero at the first position.
        assert torch.all(x[1:1 + pool, 0] != torch.zeros(pool).cuda().float())


def mortgage_dataset_loan_pool_test(config):
    """Sanity-check MortgageDataset with a pool of five loans per sample.

    Sets ``config["loan_pool_size"] = 5``, which mutates ``config``, and
    builds the dataset. It then checks:

    * split sizes, array ranks and sequence start indices;
    * that the unobserved macro variable is all zeros;
    * the bounds and shape of ``hidden_path_var``;
    * the per-sample layout of ``x``: row 0, then one loan-feature row per
      loan, then one block of ``num_states`` rows per loan that reproduces
      that loan's transitions ``y[loan].T``.

    Args:
        config: Keyword arguments for ``MortgageDataset``. Must contain
            ``nr_steps``, ``num_samples``, ``num_seq``, ``val_split``,
            ``test_split``, ``num_states`` and ``generator["h_look_back"]``.
    """
    config["loan_pool_size"] = 5
    dataset = MortgageDataset(**config)
    dataset.setup()
    train_set = dataset.dataset_train
    I_train = train_set.sequence_start_index  # I: start index of each sequence
    f_train = train_set.observed_macro_variable  # f: observed (known) macro variable
    v_train = train_set.unobserved_macro_variable  # v: unobserved macro variable, not a feature of x

    train_fraction = 1 - config["val_split"] - config["test_split"]
    assert f_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert v_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert I_train.shape[0] == config["num_seq"] * train_fraction
    assert len(f_train.shape) == 1
    assert len(v_train.shape) == 1
    assert len(I_train.shape) == 1
    assert torch.all(I_train[:100] == torch.zeros(100).cuda().float())
    assert torch.all(I_train[300:] == 300 * torch.ones(100).cuda().float())

    v_train = torch.tensor(v_train).cuda().float()
    assert torch.all(v_train == torch.zeros(v_train.shape).cuda().float())

    hidden_path_var = train_set.hidden_path_var
    eps = 1e-5
    h_window = config["generator"]["h_look_back"] - 1
    assert torch.sum(hidden_path_var) <= config["loan_pool_size"] * (f_train.shape[0]) * h_window + eps
    assert torch.max(torch.sum(hidden_path_var, dim=2)) <= h_window + eps
    assert torch.min(torch.sum(hidden_path_var, dim=2)) >= 0
    # The leading axis counts training sequences (I_train), not macro time
    # steps; f_train's length only matches it when num_seq == nr_steps * num_samples.
    assert hidden_path_var.shape == (I_train.shape[0], config["loan_pool_size"], config["num_samples"] - 1)

    M_train = train_set.transition_probabilities  # true transition probabilities
    Y_train = train_set.Y  # true transitions
    assert M_train.shape == Y_train.shape
    I_test = dataset.dataset_test.sequence_start_index
    assert torch.all(I_test[:100] == 600 * torch.ones(100).cuda().float())
    assert torch.all(I_test[100:200] == 700 * torch.ones(100).cuda().float())

    num_states = config["num_states"]
    pool = config["loan_pool_size"]
    steps = config["num_samples"] - 1
    state_offset = 1 + pool  # row 0 is the macro row, rows 1..pool the loan features

    def state_rows(loan):
        """Row slice of x holding the state block of loan number ``loan``."""
        start = state_offset + loan * num_states
        return slice(start, start + num_states)

    for x, y, _M, _I in train_set:
        assert y.shape == (pool, steps, num_states)
        assert x.shape == (num_states * pool + 1 + pool, steps)
        # The first position of each loan-specific feature is non-zero.
        assert torch.all(x[1:1 + pool, 0] != torch.zeros(pool).cuda().float())
        for loan in range(pool):
            assert torch.all(torch.eq(y[loan, :, :].T, x[state_rows(loan), :])).item()
        # Loans 2 and 3 must not both reproduce the state block of loan 4.
        assert not (
            torch.all(torch.eq(y[3, :, :].T, x[state_rows(4), :])).item()
            and torch.all(torch.eq(y[2, :, :].T, x[state_rows(4), :])).item()
        )
        # Summed over all loans' state rows, each checked column totals the
        # pool size (consistent with one active state per loan per step).
        assert torch.sum(x[state_offset:, 0]) == pool
        assert torch.sum(x[state_offset:, 12]) == pool
        assert torch.sum(x[state_offset:, :]) == pool * steps


def mortgage_dataset_level_veasy_test(config):
    """Sanity-check the ``veasy`` level with a pool of five loans.

    Mutates ``config``: 3 states with 1 terminal state,
    ``loan_pool_size=5`` and level ``veasy``. It then builds the dataset and
    checks:

    * the training split's bookkeeping arrays;
    * that both macro variables are identically zero;
    * the bounds and shape of ``hidden_path_var``;
    * the test split's start indices;
    * the per-sample layout of ``x``: row 0, then one loan-feature row per
      loan, then one block of ``num_states`` rows per loan that reproduces
      that loan's transitions ``y[loan].T``.

    Args:
        config: Keyword arguments for ``MortgageDataset``. Must contain
            ``nr_steps``, ``num_samples``, ``num_seq``, ``val_split``,
            ``test_split`` and ``generator["h_look_back"]``.
    """
    config["num_states"] = 3
    config["num_terminal_states"] = 1
    config["loan_pool_size"] = 5
    config["generator"]["level"] = "veasy"
    dataset = MortgageDataset(**config)
    dataset.setup()
    train_set = dataset.dataset_train
    I_train = train_set.sequence_start_index  # I: start index of each sequence
    f_train = train_set.observed_macro_variable  # f: observed (known) macro variable
    v_train = train_set.unobserved_macro_variable  # v: unobserved macro variable, not a feature of x

    train_fraction = 1 - config["val_split"] - config["test_split"]
    assert f_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert v_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert I_train.shape[0] == config["num_seq"] * train_fraction
    assert len(f_train.shape) == 1
    assert len(v_train.shape) == 1
    assert len(I_train.shape) == 1
    assert torch.all(I_train[:100] == torch.zeros(100).cuda().float())
    assert torch.all(I_train[300:] == 300 * torch.ones(100).cuda().float())

    v_train = torch.tensor(v_train).cuda().float()
    assert torch.all(v_train == torch.zeros(v_train.shape).cuda().float())
    # At the veasy level the observed macro variable f is expected to be
    # identically zero as well. It must be built from observed_macro_variable;
    # building it from v_train would just repeat the check above.
    f_train = torch.as_tensor(f_train).cuda().float()
    assert torch.all(f_train == torch.zeros(f_train.shape).cuda().float())

    hidden_path_var = train_set.hidden_path_var
    eps = 1e-5
    h_window = config["generator"]["h_look_back"] - 1
    assert torch.sum(hidden_path_var) <= config["loan_pool_size"] * (f_train.shape[0]) * h_window + eps
    assert torch.max(torch.sum(hidden_path_var, dim=2)) <= h_window + eps
    assert torch.min(torch.sum(hidden_path_var, dim=2)) >= 0
    # The leading axis counts training sequences (I_train), not macro time
    # steps; f_train's length only matches it when num_seq == nr_steps * num_samples.
    assert hidden_path_var.shape == (I_train.shape[0], config["loan_pool_size"], config["num_samples"] - 1)

    M_train = train_set.transition_probabilities  # true transition probabilities
    Y_train = train_set.Y  # true transitions
    assert M_train.shape == Y_train.shape
    I_test = dataset.dataset_test.sequence_start_index
    assert torch.all(I_test[:100] == 600 * torch.ones(100).cuda().float())
    assert torch.all(I_test[100:200] == 700 * torch.ones(100).cuda().float())

    num_states = config["num_states"]
    pool = config["loan_pool_size"]
    steps = config["num_samples"] - 1
    state_offset = 1 + pool  # row 0 is the macro row, rows 1..pool the loan features

    for x, y, _M, _I in train_set:
        assert y.shape == (pool, steps, num_states)
        assert x.shape == (num_states * pool + 1 + pool, steps)
        # The first position of each loan-specific feature is non-zero.
        assert torch.all(x[1:1 + pool, 0] != torch.zeros(pool).cuda().float())
        for loan in range(pool):
            start = state_offset + loan * num_states
            assert torch.all(torch.eq(y[loan, :, :].T, x[start:start + num_states, :])).item()
        # Summed over all loans' state rows, each checked column totals the
        # pool size (consistent with one active state per loan per step).
        assert torch.sum(x[state_offset:, 0]) == pool
        assert torch.sum(x[state_offset:, 12]) == pool
        assert torch.sum(x[state_offset:, :]) == pool * steps

def get_number_of_transitions_in_states(y, states):
    """Count, per loan, how many entries of ``y`` equal 1 in any of ``states``.

    Args:
        y: A tensor of shape (loan_pool_size, num_samples-1, num_states).
            Entries equal to 1 are counted.
        states: An iterable of state indices, each between 0 and
            num_states-1. A state listed twice is counted twice.

    Returns:
        A float32 tensor of shape (loan_pool_size,) on ``y``'s device. For
        each loan ``i`` it holds the number of pairs ``(j, s)`` with ``s`` in
        ``states`` and ``y[i, j, s] == 1``.
    """
    # Vectorised instead of a per-element Python loop: on CUDA tensors every
    # `y[i, j, s] == 1` test in a Python `if` forced a host/device sync.
    state_index = torch.as_tensor(list(states), dtype=torch.long, device=y.device)
    selected = y[:, :, state_index]  # (loan_pool_size, num_samples-1, len(states))
    return (selected == 1).sum(dim=(1, 2)).float()



def mortgage_dataset_veasy_stats_test(config):
    """Check aggregate transition statistics of the ``veasy`` level.

    Mutates ``config``: 3 states with 1 terminal state, a pool of 10 loans,
    160 sequences, ``use_feature=False``, and ``load_saved_data`` /
    ``save_data`` enabled. It then builds the dataset and checks the training
    split's bookkeeping arrays, that both macro variables are identically
    zero, and the bounds and shape of ``hidden_path_var``. Finally, over all
    training samples, it checks that:

    * at least half of the samples have a positive Pearson correlation
      between the non-terminal step counts of loans 0-4 and loans 5-9
      (paired by position; a NaN correlation counts as 1), and the mean of
      those correlations is positive;
    * on average more than 40% of steps are in non-terminal states;
    * on average more than 20% of steps are in state 0.
    """
    config["num_states"] = 3
    config["num_terminal_states"] = 1
    config["loan_pool_size"] = 10
    config["generator"]["level"] = "veasy"
    config["load_saved_data"] = True
    config["save_data"] = True
    config["num_seq"] = 160
    config["use_feature"] = False
    dataset = MortgageDataset(**config)
    dataset.setup()

    train_set = dataset.dataset_train
    I_train = train_set.sequence_start_index  # I: start index of each sequence
    f_train = train_set.observed_macro_variable  # f: observed (known) macro variable
    v_train = train_set.unobserved_macro_variable  # v: unobserved macro variable, not a feature of x

    train_fraction = 1 - config["val_split"] - config["test_split"]
    assert f_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert v_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert I_train.shape[0] == config["num_seq"] * train_fraction
    assert len(f_train.shape) == 1
    assert len(v_train.shape) == 1
    assert len(I_train.shape) == 1

    v_train = torch.tensor(v_train).cuda().float()
    assert torch.all(v_train == torch.zeros(v_train.shape).cuda().float())
    # At the veasy level the observed macro variable f is expected to be
    # identically zero as well. It must be built from observed_macro_variable;
    # building it from v_train would just repeat the check above.
    f_train = torch.as_tensor(f_train).cuda().float()
    assert torch.all(f_train == torch.zeros(f_train.shape).cuda().float())

    hidden_path_var = train_set.hidden_path_var
    eps = 1e-5
    h_window = config["generator"]["h_look_back"] - 1
    assert torch.sum(hidden_path_var) <= config["loan_pool_size"] * (f_train.shape[0]) * h_window + eps
    assert torch.max(torch.sum(hidden_path_var, dim=2)) <= h_window + eps
    assert torch.min(torch.sum(hidden_path_var, dim=2)) >= 0
    assert hidden_path_var.shape == (I_train.shape[0], config["loan_pool_size"], config["num_samples"] - 1)
    M_train = train_set.transition_probabilities  # true transition probabilities
    Y_train = train_set.Y  # true transitions
    assert M_train.shape == Y_train.shape

    pool = config["loan_pool_size"]
    steps = config["num_samples"] - 1
    terminal_states = [2]
    initial_states = [0]
    correlation_of_non_terminal_length = []
    avg_num_transitions_in_state_0 = []
    avg_num_transitions_in_non_terminal_states = []
    for x, y, _M, _I in train_set:
        assert y.shape == (pool, steps, config["num_states"])
        assert x.shape == (config["num_states"] * pool + 1 + pool, steps)

        num_non_terminal = (
            steps * torch.ones((pool)).cuda().float()
            - get_number_of_transitions_in_states(y, terminal_states)
        )
        num_in_initial = get_number_of_transitions_in_states(y, initial_states)
        avg_num_transitions_in_state_0.append(torch.mean(num_in_initial))
        avg_num_transitions_in_non_terminal_states.append(torch.mean(num_non_terminal))
        # torch.corrcoef treats rows as variables: loans 0-4 vs loans 5-9.
        mat = torch.stack((num_non_terminal[:5], num_non_terminal[5:]), dim=0)
        corr_coef = torch.corrcoef(mat)[0, 1]
        if torch.isnan(corr_coef):  # e.g. a row with zero variance
            corr_coef = torch.tensor(1).cuda().float()
        correlation_of_non_terminal_length.append(corr_coef.item())

    mean_0 = torch.mean(torch.tensor(avg_num_transitions_in_state_0))
    mean_non_terminal = torch.mean(torch.tensor(avg_num_transitions_in_non_terminal_states))
    fraction_positive_correlation = (
        torch.sum(torch.tensor(correlation_of_non_terminal_length) > 0)
        / len(correlation_of_non_terminal_length)
    )
    fraction_0 = mean_0.item() / steps
    fraction_non_terminal = mean_non_terminal.item() / steps

    assert fraction_positive_correlation >= 0.5  # due to loan path dependency
    assert np.mean(correlation_of_non_terminal_length) > 0.0  # due to loan path dependency
    assert fraction_non_terminal > 0.4  # we don't want too many defaults (low default rate)
    assert fraction_0 > 0.2  # many steps in the current (initial) state


def mortgage_dataset_med_hard_stats_test(config):
    """Check aggregate transition statistics of the ``medium`` and ``hard`` levels.

    Mutates ``config``: a pool of 10 loans, 1000 sequences,
    ``use_feature=False``, and ``load_saved_data`` / ``save_data`` enabled.
    It builds two datasets from deep copies of it, one at level ``medium``
    and one at level ``hard``.

    The medium training split also gets the bookkeeping checks: sizes, ranks,
    a zero unobserved macro variable, and ``hidden_path_var`` bounds and
    shape. Both splits then go through ``pool_statistics`` (defined below);
    the hard level is held to looser thresholds.
    """
    config["loan_pool_size"] = 10
    config["generator"]["level"] = "medium"
    config["load_saved_data"] = True
    config["save_data"] = True
    config["num_seq"] = 1000
    config["use_feature"] = False
    dataset = MortgageDataset(**copy.deepcopy(config))
    dataset.setup()
    config["generator"]["level"] = "hard"
    dataset_hard = MortgageDataset(**copy.deepcopy(config))
    dataset_hard.setup()
    train_set_hard = dataset_hard.dataset_train

    pool = config["loan_pool_size"]
    num_states = config["num_states"]
    steps = config["num_samples"] - 1

    def pool_statistics(split, terminal_states, initial_states):
        """Aggregate per-sample transition statistics over ``split``.

        Returns ``(fraction_positive, mean_correlation, fraction_non_terminal,
        fraction_initial)``:

        * ``fraction_positive``: share of samples whose Pearson correlation
          between the non-terminal step counts of loans 0-4 and loans 5-9
          (paired by position) is positive; a NaN correlation counts as 1.
        * ``mean_correlation``: mean of those per-sample correlations.
        * ``fraction_non_terminal`` / ``fraction_initial``: average share of
          the ``steps`` time steps spent outside ``terminal_states`` / in
          ``initial_states``.
        """
        correlations = []
        mean_initial_counts = []
        mean_non_terminal_counts = []
        for x, y, _M, _I in split:
            assert y.shape == (pool, steps, num_states)
            assert x.shape == (num_states * pool + 1 + pool, steps)
            num_non_terminal = (
                steps * torch.ones((pool)).cuda().float()
                - get_number_of_transitions_in_states(y, terminal_states)
            )
            num_initial = get_number_of_transitions_in_states(y, initial_states)
            mean_initial_counts.append(torch.mean(num_initial))
            mean_non_terminal_counts.append(torch.mean(num_non_terminal))
            # torch.corrcoef treats rows as variables: loans 0-4 vs loans 5-9.
            mat = torch.stack((num_non_terminal[:5], num_non_terminal[5:]), dim=0)
            corr_coef = torch.corrcoef(mat)[0, 1]
            if torch.isnan(corr_coef):  # e.g. a row with zero variance
                corr_coef = torch.tensor(1).cuda().float()
            correlations.append(corr_coef.item())
        fraction_positive = torch.sum(torch.tensor(correlations) > 0) / len(correlations)
        fraction_initial = torch.mean(torch.tensor(mean_initial_counts)).item() / steps
        fraction_non_terminal = torch.mean(torch.tensor(mean_non_terminal_counts)).item() / steps
        return fraction_positive, np.mean(correlations), fraction_non_terminal, fraction_initial

    train_set = dataset.dataset_train
    I_train = train_set.sequence_start_index  # I: start index of each sequence
    f_train = train_set.observed_macro_variable  # f: observed (known) macro variable
    v_train = train_set.unobserved_macro_variable  # v: unobserved macro variable, not a feature of x

    train_fraction = 1 - config["val_split"] - config["test_split"]
    assert f_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert v_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert I_train.shape[0] == config["num_seq"] * train_fraction
    assert len(f_train.shape) == 1
    assert len(v_train.shape) == 1
    assert len(I_train.shape) == 1

    v_train = torch.tensor(v_train).cuda().float()
    assert torch.all(v_train == torch.zeros(v_train.shape).cuda().float())
    # NOTE(review): only the unobserved macro variable v is checked to be zero.
    # Nothing here establishes whether the observed macro variable f is zero at
    # the "medium" level, so no f check is made. Add one against
    # train_set.observed_macro_variable if that is intended.

    hidden_path_var = train_set.hidden_path_var
    eps = 1e-5
    h_window = config["generator"]["h_look_back"] - 1
    assert torch.sum(hidden_path_var) <= pool * (f_train.shape[0]) * h_window + eps
    assert torch.max(torch.sum(hidden_path_var, dim=2)) <= h_window + eps
    assert torch.min(torch.sum(hidden_path_var, dim=2)) >= 0
    assert hidden_path_var.shape == (I_train.shape[0], pool, steps)
    M_train = train_set.transition_probabilities  # true transition probabilities
    Y_train = train_set.Y  # true transitions
    assert M_train.shape == Y_train.shape

    # Medium level.
    fraction_positive, mean_correlation, fraction_non_terminal, fraction_0 = pool_statistics(
        train_set, terminal_states=[8, 9], initial_states=[0]
    )
    assert fraction_positive >= 0.5  # due to loan path dependency
    assert mean_correlation > 0  # due to loan path dependency
    assert fraction_non_terminal > 0.4  # we don't want too many defaults (low default rate)
    assert fraction_0 > 0.2  # many steps in the current (initial) state

    # Hard level: same statistics, looser thresholds.
    fraction_positive, mean_correlation, fraction_non_terminal, fraction_0 = pool_statistics(
        train_set_hard, terminal_states=[8, 9], initial_states=[0]
    )
    assert fraction_positive >= 0.35
    assert mean_correlation > -0.1
    assert fraction_non_terminal > 0.2
    assert fraction_0 > 0.1


def mortgage_2d_path_dependency_test(config):
    """Check pool statistics of the ``2d_path_dependency`` level.

    Mutates ``config``: a pool of 10 loans, 100 sequences, 3 states,
    ``use_feature=True``, ``save_data=True`` and ``load_saved_data=False``.
    It builds two datasets from identical deep copies of it.

    The first training split gets the bookkeeping checks (sizes, ranks, a
    zero unobserved macro variable, ``hidden_path_var`` bounds) and the pool
    statistics with the stricter thresholds. The second split gets the pool
    statistics with looser thresholds.
    """
    config["loan_pool_size"] = 10
    config["generator"]["level"] = "2d_path_dependency"
    config["save_data"] = True
    config["num_seq"] = 100
    config["num_states"] = 3
    config["use_feature"] = True
    config["load_saved_data"] = False
    dataset = MortgageDataset(**copy.deepcopy(config))
    dataset.setup()
    # NOTE(review): this second dataset is built from an identical copy of
    # config (no level change) yet is checked below with looser thresholds.
    # Confirm whether a different level was intended.
    dataset_repeat = MortgageDataset(**copy.deepcopy(config))
    dataset_repeat.setup()
    train_set_repeat = dataset_repeat.dataset_train

    pool = config["loan_pool_size"]
    num_states = config["num_states"]
    steps = config["num_samples"] - 1

    def pool_statistics(split, terminal_states, initial_states):
        """Aggregate per-sample transition statistics over ``split``.

        Returns ``(fraction_positive, mean_correlation, fraction_non_terminal,
        fraction_initial)``:

        * ``fraction_positive``: share of samples whose Pearson correlation
          between the non-terminal step counts of loans 0-4 and loans 5-9
          (paired by position) is positive; a NaN correlation counts as 1.
        * ``mean_correlation``: mean of those per-sample correlations.
        * ``fraction_non_terminal`` / ``fraction_initial``: average share of
          the ``steps`` time steps spent outside ``terminal_states`` / in
          ``initial_states``.
        """
        correlations = []
        mean_initial_counts = []
        mean_non_terminal_counts = []
        for x, y, _M, _I in split:
            assert y.shape == (pool, steps, num_states)
            assert x.shape == (num_states * pool + 1 + pool, steps)
            num_non_terminal = (
                steps * torch.ones((pool)).cuda().float()
                - get_number_of_transitions_in_states(y, terminal_states)
            )
            num_initial = get_number_of_transitions_in_states(y, initial_states)
            mean_initial_counts.append(torch.mean(num_initial))
            mean_non_terminal_counts.append(torch.mean(num_non_terminal))
            # torch.corrcoef treats rows as variables: loans 0-4 vs loans 5-9.
            mat = torch.stack((num_non_terminal[:5], num_non_terminal[5:]), dim=0)
            corr_coef = torch.corrcoef(mat)[0, 1]
            if torch.isnan(corr_coef):  # e.g. a row with zero variance
                corr_coef = torch.tensor(1).cuda().float()
            correlations.append(corr_coef.item())
        fraction_positive = torch.sum(torch.tensor(correlations) > 0) / len(correlations)
        fraction_initial = torch.mean(torch.tensor(mean_initial_counts)).item() / steps
        fraction_non_terminal = torch.mean(torch.tensor(mean_non_terminal_counts)).item() / steps
        return fraction_positive, np.mean(correlations), fraction_non_terminal, fraction_initial

    train_set = dataset.dataset_train
    I_train = train_set.sequence_start_index  # I: start index of each sequence
    f_train = train_set.observed_macro_variable  # f: observed (known) macro variable
    v_train = train_set.unobserved_macro_variable  # v: unobserved macro variable, not a feature of x

    train_fraction = 1 - config["val_split"] - config["test_split"]
    assert f_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert v_train.shape[0] == (config["nr_steps"] * config["num_samples"]) * train_fraction
    assert I_train.shape[0] == config["num_seq"] * train_fraction
    assert len(f_train.shape) == 1
    assert len(v_train.shape) == 1
    assert len(I_train.shape) == 1
    X_train = train_set.X

    v_train = torch.tensor(v_train).cuda().float()
    assert torch.all(v_train == torch.zeros(v_train.shape).cuda().float())
    # NOTE(review): only the unobserved macro variable v is checked to be zero.
    # Nothing here establishes whether the observed macro variable f is zero at
    # this level (use_feature=True), so no f check is made. Add one against
    # train_set.observed_macro_variable if that is intended.

    hidden_path_var = train_set.hidden_path_var
    eps = 1e-3
    h_window = config["generator"]["h_look_back"] - 1
    assert torch.sum(hidden_path_var) <= 2 * hidden_path_var.shape[1] * (X_train.shape[0]) * h_window + eps
    assert torch.max(torch.sum(hidden_path_var, dim=2)) <= h_window + eps
    assert torch.min(torch.sum(hidden_path_var, dim=2)) >= 0
    M_train = train_set.transition_probabilities  # true transition probabilities
    Y_train = train_set.Y  # true transitions
    assert M_train.shape == Y_train.shape

    # First dataset: stricter thresholds.
    fraction_positive, mean_correlation, fraction_non_terminal, fraction_0 = pool_statistics(
        train_set, terminal_states=[2], initial_states=[0]
    )
    assert fraction_positive >= 0.4  # due to loan path dependency
    assert mean_correlation > 0  # due to loan path dependency
    assert fraction_non_terminal > 0.4  # we don't want too many defaults (low default rate)
    assert fraction_0 > 0.2  # many steps in the current (initial) state

    # Second dataset (same config): looser thresholds.
    fraction_positive, mean_correlation, fraction_non_terminal, fraction_0 = pool_statistics(
        train_set_repeat, terminal_states=[2], initial_states=[0]
    )
    assert fraction_positive >= 0.35
    assert mean_correlation > -0.1
    assert fraction_non_terminal > 0.2
    assert fraction_0 > 0.1

def _run_checks(checks, config, announce=True):
    """Run each sanity check on its own deep copy of ``config``.

    Args:
        checks: Iterable of callables, each taking a single config dict.
        config: Base configuration. It is never handed to a check directly;
            every check gets a fresh ``copy.deepcopy`` so that any mutation a
            check performs cannot leak into later checks.
        announce: If True, print ``"<check name> done"`` after each check
            returns. The name makes the log identify the completed check.
    """
    for check in checks:
        check(copy.deepcopy(config))
        if announce:
            print(f"{check.__name__} done")


def main():
    """Run all mortgage dataloader sanity checks.

    The checks run in three passes over the same base config:

    1. ``load_saved_data=True``, ``save_data=True``: loan-pool, very-easy level,
       basic dataset and medium/hard statistics checks.
    2. ``load_saved_data=False``, ``save_data=True``: no-feature, 2-D path
       dependency and very-easy statistics checks.
    3. ``load_saved_data=False``, ``save_data=False``: the debug-dynamics
       check plus a re-run of the loan-pool, very-easy level, basic dataset
       and no-feature checks, without printing per-check progress.

    Each check asserts internally, so reaching the final print means every
    check passed.
    """
    config = {
        "_name_": "timeseries_synthetics",
        "num_states": 10, #total number of states
        "num_terminal_states": 2, #number of terminal states
        "use_feature": True, #If false, will not include the macro variable as a feature, and also not include the loan specific features
        "num_samples": 100, #length of each sequence
        "loan_pool_size": 1, #pool size
        "load_saved_data": True,
        "saved_data_directory": f"{BASE_PATH}/data/mortgage_new/",
        "save_data": True,
        "num_seq": 1000, #number of sequences
        "val_split": 0.2, #fraction of samples in the validation split
        "test_split": 0.4, #fraction of samples in the test split
        "dataset_name": "timeseries_synthetics",
        "nr_steps": 10, #number of different starting points
        "seed": 42, # For validation split
        "generator": {
            "use_loan_specific_feature": True,
            "level": "medium",
            "path_dependency_dimension": 1,
            "h_look_back": 10,
            "h": 1,
            "h_kappa": 0,
            "k_bias": 0.65, #0.75
            "a_bias": 0.03,
            "b_bias": 0.3,
            "k_scale": 0.22,
            "a_scale": -0.03,
            "b_scale": 0.3,
            "debug": False, #if true the dynamics are simplified
            "hard": True, #only used if debug=True. If False, the dynamics are deterministic
        }
    }

    # The *_stats_test checks gather per-level statistics:
    #   1. the average number of non-terminal transitions,
    #   2. the average number of transitions in state 0,
    #   3. the correlation of sequence lengths of samples in the same sequence.

    # Pass 1: saved data may be loaded, and data is saved.
    _run_checks(
        (
            mortgage_dataset_loan_pool_test,
            mortgage_dataset_level_veasy_test,
            mortgage_dataset_test,
            mortgage_dataset_med_hard_stats_test,
        ),
        config,
    )

    # Pass 2: force regeneration instead of loading; data is still saved.
    config["load_saved_data"] = False
    _run_checks(
        (
            mortgage_dataset_no_feature_test,
            mortgage_2d_path_dependency_test,
            mortgage_dataset_veasy_stats_test,
        ),
        config,
    )

    # Pass 3: regenerate without saving anything.
    config["save_data"] = False
    _run_checks(
        (
            mortgage_dataset_test_debug,
            mortgage_dataset_loan_pool_test,
            mortgage_dataset_level_veasy_test,
            mortgage_dataset_test,
            mortgage_dataset_no_feature_test,
        ),
        config,
        announce=False,
    )

    print("tests passed")

if __name__ == "__main__":
    # Run the sanity checks only when executed as a script, not when imported.
    main()