import sys
import os
BASE_PATH = os.environ.get("BASE_PATH", "")
sys.path.append(BASE_PATH)
from src.dataloaders.dataloader_corelogic import LoanDataset
from scripts.notebooks.true_loss_level.get_transition_probabilities import load_model_corelogic
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader, ConcatDataset
import copy
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import pandas as pd
import matplotlib.pyplot as plt
import copy
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import gc
# Project root taken from the environment; one trailing '/' is dropped so that
# paths built as f"{BASE_PATH}/..." do not contain a double slash.
BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]

# Set global matplotlib settings for Computer Modern font.
# The same rc values are applied to matplotlib first and then to seaborn.
_SERIF_LATEX_RC = {
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "text.usetex": True,  # Ensures LaTeX-like font rendering
    "axes.grid": False,  # Remove grid lines
}
plt.rcParams.update(copy.deepcopy(_SERIF_LATEX_RC))
sns.set(rc=copy.deepcopy(_SERIF_LATEX_RC))

# Fixed seed for both NumPy's and PyTorch's global random generators.
SEED_NR = 42
np.random.seed(SEED_NR)
torch.manual_seed(SEED_NR)


def get_dataset(config, task='corelogic'):
    """Build the dataset for ``task`` and return its train/val/test splits.

    Parameters
    ----------
    config : dict
        Keyword arguments for the dataset constructor. A deep copy is passed
        on, so the constructor never receives the caller's own dict.
    task : str
        ``'corelogic'`` builds a ``LoanDataset``; ``'synthetic'`` builds a
        ``MortgageDataset``.

    Returns
    -------
    tuple
        ``(dataset_train, dataset_val, dataset_test)`` read from the dataset
        object after ``setup()`` has been called.

    Raises
    ------
    ValueError
        If ``task`` is not one of the supported names.
    """
    if task == 'corelogic':
        dataset_cls = LoanDataset
    elif task == 'synthetic':
        # Imported lazily so the synthetic loader is only needed when requested.
        from src.dataloaders.dataloader_mortgage import MortgageDataset
        dataset_cls = MortgageDataset
    else:
        # Previously an unknown task fell through and failed later with an
        # UnboundLocalError on ``dataset``.
        raise ValueError(
            f"Unknown task {task!r}; expected 'corelogic' or 'synthetic'"
        )

    dataset = dataset_cls(**copy.deepcopy(config))
    dataset.setup()
    return dataset.dataset_train, dataset.dataset_val, dataset.dataset_test


def inspect_data(train_set):
    """Print the shape of the input tensor of the first sample in ``train_set``.

    Each sample is unpacked as ``(x, y, valid_idx)``; only ``x.shape`` is
    printed. Nothing is printed when ``train_set`` yields no samples.
    """
    sample_iter = iter(train_set)
    try:
        first_sample = next(sample_iter)
    except StopIteration:
        return
    features, _targets, _valid_idx = first_sample
    print(features.shape)
    

def get_model_config(name):
    """Return the experiment/checkpoint/data configuration registered as ``name``.

    Parameters
    ----------
    name : str
        Key of a registered model configuration (see ``registry`` below).

    Returns
    -------
    dict
        A freshly built dict with the keys ``"experiment"``,
        ``"checkpoint_path"`` and ``"data_path"``; both paths are rooted at the
        module-level ``BASE_PATH``.

    Raises
    ------
    ValueError
        If ``name`` is not registered. An explicit exception is used instead of
        ``assert False`` because assertions are removed under ``python -O``,
        which would make the function silently return ``None``.
    """
    checkpoint_root = f"{BASE_PATH}/outputs/outputs"
    data_root = f"{BASE_PATH}/data/corelogic"
    full_data = f"{data_root}/loan_data_60000_unemployment_curr_balance_cir_spi.npz"
    top1_zip_data = f"{data_root}/loan_data_top1_zip_52.npz"

    # name -> (experiment, checkpoint relative to checkpoint_root, data file)
    registry = {
        # top 4 zip, 700 loans at a time
        "top4-jan29-50": (
            "timeseries/set_corelogic_top4",
            "2025-01-29/12-12-43/step_680.ckpt",
            full_data,
        ),
        # top 4 zip, 1 loan at a time, bz 700
        "top4-jan29-52": (
            "timeseries/set_corelogic_top4_52",
            "2025-01-29/17-40-51/step_720.ckpt",
            full_data,
        ),
        "top41lz-jan29-50": (
            "timeseries/set_corelogic_top4",
            "2025-01-30/01-17-46/step_760.ckpt",
            full_data,
        ),
        "top4-jan30-50": (
            "timeseries/set_corelogic_top4",
            "2025-01-30/12-22-08/step_920.ckpt",
            full_data,
        ),
        "top41lz-jan30-50": (
            "timeseries/set_corelogic_top4_1lz",
            "2025-01-30/12-52-51/step_920.ckpt",
            full_data,
        ),
        "top41lz-jan30-52": (
            "timeseries/set_corelogic_top4_52",
            "2025-01-30/13-00-30/step_1080.ckpt",
            full_data,
        ),
        # Related runs noted by the original author:
        #   Debug = False: 2025-01-31/15-21-07/step_1200.ckpt
        #   Debug = True: 2025-02-01/09-30-40/step_720.ckpt
        #   Extra train time no debug: 2025-02-01/09-55-29/last.ckpt
        "test-jan30-50": (
            "timeseries/set_corelogic_top4_exp",
            "2025-02-01/09-55-29/last.ckpt",
            full_data,
        ),
        # NOTE(review): no checkpoint file is set here, so the path resolves to
        # the outputs directory itself. A comment in the previous version
        # mentioned 2025-02-01/16-25-37/step_5110.ckpt - confirm before use.
        "onlycross-jan30-50": (
            "timeseries/set_corelogic_top4_exp",
            "",
            full_data,
        ),
        "randtrain-jan30-50": (
            "timeseries/set_corelogic_top4_exp",
            "2025-02-01/18-59-36/step_1200.ckpt",
            full_data,
        ),
        "randtrain2-jan30-50": (
            "timeseries/set_corelogic_top4_exp",
            "2025-02-02/11-41-20/step_11206.ckpt",
            full_data,
        ),
        "may_14_set_seq": (
            "timeseries/corelogic/cl_set-seq",
            "2025-05-12/10-17-27/step_756.ckpt",
            top1_zip_data,
        ),
        "may_14_nn_5layer": (
            "timeseries/corelogic/cl_5l_nn",
            "2025-05-14/12-20-07/step_819.ckpt",
            top1_zip_data,
        ),
    }

    if name not in registry:
        raise ValueError(
            f"Model config {name} not found; available: {sorted(registry)}"
        )

    experiment, checkpoint, data_path = registry[name]
    return {
        "experiment": experiment,
        "checkpoint_path": f"{checkpoint_root}/{checkpoint}",
        "data_path": data_path,
    }


# Reshape data for easier computation
def reshape_for_auc(data):
    """Collapse every leading axis of ``data`` into one, keeping the last axis.

    The result is two-dimensional with shape ``(-1, data.shape[-1])``.
    """
    last_dim = data.shape[-1]
    flattened = data.reshape(-1, last_dim)
    return flattened

def get_auc(probs, y_true):
    """Return scikit-learn's ROC AUC for flattened scores and labels.

    Both inputs are reshaped to ``(-1, last_dim)`` and handed to
    ``roc_auc_score`` with ``multi_class="ovr"`` and ``average=None``
    (no averaging across classes).
    """
    flat_scores = probs.reshape(-1, probs.shape[-1])
    flat_labels = y_true.reshape(-1, y_true.shape[-1])
    return roc_auc_score(flat_labels, flat_scores, multi_class="ovr", average=None)


def create_auc_table(columns, data, pretext):
    """Render ``data`` as a table image, save it as a PNG and show it.

    Parameters
    ----------
    columns : sequence
        Column headers of the table.
    data : sequence of rows
        Table contents, passed to ``pandas.DataFrame`` together with ``columns``.
    pretext : str
        Path prefix; the image is written to
        ``pretext + "/data/auc_table_set_seq.png"``.
    """
    frame = pd.DataFrame(data, columns=columns)
    column_indices = list(range(len(frame.columns)))

    # A bare axes (tight limits, no frame or ticks) acts as canvas for the table.
    fig, ax = plt.subplots(figsize=(8, 4))
    for axis_mode in ('tight', 'off'):
        ax.axis(axis_mode)

    rendered = ax.table(
        cellText=frame.values,
        colLabels=frame.columns,
        loc='center',
        cellLoc='center',
    )
    # Fixed 10pt font; column widths adapt to their content.
    rendered.auto_set_font_size(False)
    rendered.set_fontsize(10)
    rendered.auto_set_column_width(col=column_indices)

    target_file = pretext + "/data/auc_table_set_seq.png"
    plt.savefig(target_file, bbox_inches='tight', dpi=300)
    plt.show()

def data_stats(train, val, test, extensive=False, addition_to_path="_TOP4_ZIP"):
    """Print dataset statistics and save empirical transition-matrix images.

    Parameters
    ----------
    train, val, test
        Dataset splits exposing ``print_dataset_statistics()`` and
        ``save_empirical_transition_counts_image(...)``; ``train`` must also
        expose ``print_full_dataset_statistics()``.
    extensive : bool, optional
        When true, additionally save the sampled / exit-state / count variants
        for all three splits. Defaults to ``False``, which matches the previous
        hard-coded behaviour (only one image for ``train``).
    addition_to_path : str, optional
        Suffix forwarded to the image-saving calls that accept it. Defaults to
        the previously hard-coded ``"_TOP4_ZIP"``.
    """
    splits = (train, val, test)

    print("Full dataset stats:")
    train.print_full_dataset_statistics()
    for split in splits:
        split.print_dataset_statistics()

    train.save_empirical_transition_counts_image(
        within_sampling_bounds=True, include_exit=False, counts=False,
        addition_to_path=addition_to_path,
    )
    if not extensive:
        return

    # Same image as above, but based on sampling, for every split.
    for split in splits:
        split.save_empirical_transition_counts_image(
            within_sampling_bounds=True, include_exit=False, counts=False,
            addition_to_path=addition_to_path, use_sampling=True,
        )

    train.save_empirical_transition_counts_image(include_exit=False)

    for split in splits:
        split.save_empirical_transition_counts_image(
            within_sampling_bounds=True, include_exit=True,
        )

    # Sampled images including the exit state: first with counts=True, then
    # with counts=False. Only non-deterministic sampling is produced; the
    # deterministic_sampling=True variants were disabled in the original.
    for counts in (True, False):
        for split in splits:
            split.save_empirical_transition_counts_image(
                include_exit=True, counts=counts, within_sampling_bounds=True,
                use_sampling=True, deterministic_sampling=False,
                addition_to_path=addition_to_path,
            )

from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt

def get_auc_matrix(y_true, prob, name, save_heatmap=True):
    """Estimate a per-transition AUC matrix, with counts and mean predicted probabilities.

    For every observed transition "state ``a`` at step ``k`` -> state ``b`` at
    step ``k + 1``", the predicted probability of ``b`` at step ``k + 1`` is a
    positive score. The same probability at positions where ``a`` was followed
    by any other state is a negative score. The AUC is estimated as the
    fraction of 1000 randomly drawn (positive, negative) pairs in which the
    positive score is strictly larger.

    Parameters
    ----------
    y_true : array
        Labels; per the original author the layout is
        ``[nr_val_steps, nr_loans_to_sample, nr_time_steps, nr_classes]``.
        The state at each step is taken as the argmax over the last axis.
    prob : array
        Predicted class probabilities with the same layout as ``y_true``.
        The row/column selection below assumes 8 classes.
    name : str
        Tag inserted into the heatmap file names.
    save_heatmap : bool, optional
        When true, three PDF heatmaps are written below
        ``{BASE_PATH}/scripts/notebooks/data/corelogic/``. The flag used to be
        shadowed by a nested function of the same name, so heatmaps were always
        written; it is now honoured.

    Returns
    -------
    tuple of numpy.ndarray
        ``(auc_matrix, pos_count, predicted_prob_of_transition)``, each of
        shape ``(5, 7)``. Rows are the initial states F, 90dd, 60dd, 30dd,
        Current; columns are the end states Current, 30dd, 60dd, 90dd, F, REO,
        Paid Off. AUC and probability cells are ``nan`` when fewer than 10
        positive or 10 negative samples exist. Probabilities are in percent.
    """
    n_classes = prob.shape[-1]

    # Tiny Gaussian jitter so that exact ties between scores are broken at random.
    probs = copy.deepcopy(prob) + np.random.normal(0, 0.000001, prob.shape)
    # shape: [nr_val_steps, nr_loans_to_sample, nr_time_steps, nr_classes]
    y_t = y_true.reshape(-1, y_true.shape[-2], y_true.shape[-1])
    states = y_t.argmax(axis=2)
    probs_t = probs.reshape(-1, probs.shape[-2], probs.shape[-1])

    # (state at k, state at k + 1) -> list of (sequence index, k)
    positions = defaultdict(list)
    for j in range(states.shape[0]):
        for k in range(states.shape[1] - 1):
            positions[(states[j, k], states[j, k + 1])].append((j, k))

    all_classes = [0, 1, 2, 3, 4, 5, 6, 7]
    min_samples = 10   # below this, a cell is reported as nan
    n_draws = 1000     # number of random (positive, negative) pairs per cell

    def transition_stats(a, b):
        """Return (auc, positive count, mean predicted probability in %) for a -> b."""
        positives = [probs_t[j, k + 1, b] for (j, k) in positions[(a, b)]]
        negatives = [
            probs_t[j, k + 1, b]
            for c in all_classes if c != b
            for (j, k) in positions[(a, c)]
        ]
        n_positive = len(positives)
        if n_positive < min_samples or len(negatives) < min_samples:
            return np.nan, n_positive, np.nan
        # Resample with replacement and compare the draws pairwise.
        drawn_pos = np.random.choice(positives, n_draws, replace=True)
        drawn_neg = np.random.choice(negatives, n_draws, replace=True)
        auc = (1 / n_draws) * np.sum(drawn_pos > drawn_neg)
        return auc, n_positive, custom_round(100 * np.mean(drawn_pos))

    auc_matrix = np.zeros((n_classes, n_classes))
    pos_count = np.zeros((n_classes, n_classes), dtype=int)
    predicted_prob_of_transition = np.zeros((n_classes, n_classes))
    for a in [1, 2, 3, 4, 5]:
        for b in [0, 1, 2, 3, 4, 5, 6]:
            (auc_matrix[a, b],
             pos_count[a, b],
             predicted_prob_of_transition[a, b]) = transition_stats(a, b)

    # Keep only the displayed states, in display order. This equals the former
    # "reorder all 8 rows/columns, then delete the trailing ones" sequence.
    row_order = [5, 3, 2, 1, 4]
    col_order = [4, 1, 2, 3, 5, 6, 0]
    auc_matrix = auc_matrix[row_order][:, col_order]
    pos_count = pos_count[row_order][:, col_order]
    predicted_prob_of_transition = predicted_prob_of_transition[row_order][:, col_order]
    labels_y = ["F",  "90dd","60dd", "30dd", "Current" ]
    labels_x = ["Current", "30dd", "60dd", "90dd", "F",  "REO", "Paid Off",]

    def write_heatmap(data, filename, fmt_type, log_scale=False):
        """Draw ``data`` as an annotated heatmap and save it as a PDF."""
        plt.figure(figsize=(10, 7))
        sns.set_context("notebook")  # Default styling
        sns.set_style("white")  # Remove background grid

        norm = None
        if log_scale:
            positive = data[data > 0]
            # Without any positive entry there is no valid lower bound for a
            # log norm; fall back to the linear scale instead of crashing.
            if positive.size:
                norm = mcolors.LogNorm(vmin=max(1, np.min(positive)), vmax=np.max(data))  # Avoid log(0)
        cmap = LinearSegmentedColormap.from_list("custom_red", ["#f4c2c2", "#8B0000"])
        sns.heatmap(
            data, annot=True, fmt=fmt_type, cmap=cmap,
            xticklabels=labels_x, yticklabels=labels_y,
            annot_kws={"size": 14},  # Larger numbers inside cells
            linewidths=0.3, linecolor="white",  # Light dividers
            norm=norm  # Apply log scale if specified
        )

        plt.xticks(fontsize=14, rotation=45)
        plt.yticks(fontsize=14, rotation=0)
        plt.xlabel("End State", fontsize=16)
        plt.ylabel("Initial State", fontsize=16)

        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches="tight", format="pdf")  # Save as high-res PDF
        plt.close()

    if save_heatmap:
        base_path = f"{BASE_PATH}/scripts/notebooks/data/corelogic/"

        write_heatmap(auc_matrix, f"{base_path}auc_matrix_{name}.pdf", fmt_type=".2f")
        write_heatmap(predicted_prob_of_transition, f"{base_path}predicted_transition_matrix_{name}.pdf", fmt_type=".2f")
        write_heatmap(pos_count, f"{base_path}pos_count_matrix_{name}.pdf", fmt_type="d", log_scale=True)  # Log scale applied

    print(auc_matrix)
    return auc_matrix, pos_count, predicted_prob_of_transition



def compute_xentropy(y_true, y_pred):
    """Return the total cross-entropy as the sum of the per-class contributions.

    The per-class values come from ``compute_class_contributions_xentropy``.
    If the total is NaN or infinite, a message is printed and the debugger is
    entered via ``breakpoint()`` before the value is returned.
    """
    per_class = compute_class_contributions_xentropy(y_true, y_pred)
    total = np.sum(per_class)
    # Check that the total is neither nan nor inf
    is_invalid = np.isnan(total) or np.isinf(total)
    if is_invalid:
        print("Cross entropy is nan or inf")
        breakpoint()
    return total

def compute_xentropy1(y_true, y_pred):
    """Loop-based mean cross-entropy that ignores the last (exit) class.

    After flattening to ``(-1, n_classes)``, a row ``j >= 1`` contributes
    ``-log(y_pred[j, i])`` for a non-exit class ``i`` when ``y_true[j, i] == 1``
    and the previous row is not in the exit class. The sum of the terms is
    divided by their number. Every term is asserted to be below 10 in absolute
    value.
    """
    labels = y_true.reshape(-1, y_true.shape[-1])
    preds = y_pred.reshape(-1, y_pred.shape[-1])
    n_rows, n_classes = labels.shape[0], labels.shape[-1]

    total = 0
    n_terms = 0
    for cls in range(n_classes - 1):  # ignore exit state
        for row in range(1, n_rows):
            if labels[row - 1, -1] == 1 or labels[row, cls] != 1:
                continue
            term = -(np.log(preds[row, cls]))
            assert np.abs(term) < 10
            total += term
            n_terms += 1
    return total / n_terms

def compute_class_contributions_xentropy(y_true, y_pred, tol=1e-15):
    """
    Vectorized per-class contributions to the mean cross-entropy.

    Both inputs are flattened to ``(-1, n_classes)``. The last class (exit
    state) is ignored. For each remaining class ``i`` the function sums
    ``-log(y_pred[j, i])`` over all rows ``j >= 1`` where the previous row is
    not in the exit class and ``y_true[j, i] == 1``. Each class sum is then
    divided by the total number of terms over all classes, so the returned
    values add up to the mean cross-entropy.

    Parameters
    ----------
    y_true, y_pred : arrays whose last axis indexes the classes
    tol : float
        Unused; no epsilon is added inside the log so that invalid
        probabilities are exposed rather than hidden.

    Returns
    -------
    list of floats
        One entry per non-exit class.

    Notes
    -----
    Asserts that there are at least two rows and at least one selected term.
    A NaN or infinite class sum triggers ``breakpoint()``.
    """
    # Flatten
    labels = y_true.reshape(-1, y_true.shape[-1])
    preds = y_pred.reshape(-1, y_pred.shape[-1])

    n_rows, n_classes = labels.shape
    assert  n_rows >= 2

    # Row j (j >= 1) is eligible only if row j - 1 is not in the exit class.
    prev_not_exit = labels[:-1, -1] != 1
    current_labels = labels[1:]
    current_preds = preds[1:]

    per_class_sum = np.zeros(n_classes - 1, dtype=np.float64)
    n_terms = 0
    for cls in range(n_classes - 1):
        selected = prev_not_exit & (current_labels[:, cls] == 1)
        class_probs = current_preds[:, cls]
        # Dont add 1e-15, want errors to be exposed!
        class_loss = -np.log(class_probs[selected]).sum()
        if np.isnan(class_loss) or np.isinf(class_loss):
            breakpoint()
        per_class_sum[cls] = class_loss
        n_terms += selected.sum()

    assert n_terms != 0

    # Dividing every class sum by the overall term count makes the entries
    # add up to the mean cross-entropy.
    return list(per_class_sum / n_terms)


def compute_class_contributions_xentropy1(y_true, y_pred):
    """Loop-based per-class contributions to the mean cross-entropy.

    After flattening to ``(-1, n_classes)``, each non-exit class ``i``
    accumulates ``-log(y_pred[j, i])`` over rows ``j >= 1`` with
    ``y_true[j, i] == 1`` whose previous row is not in the exit (last) class.
    Every class sum is divided by the total number of terms across all
    classes. Each term is asserted to be below 10 in absolute value.

    Returns a list with one value per non-exit class.
    """
    labels = y_true.reshape(-1, y_true.shape[-1])
    preds = y_pred.reshape(-1, y_pred.shape[-1])
    n_rows, n_classes = labels.shape[0], labels.shape[-1]

    class_sums = []
    n_terms = 0
    for cls in range(n_classes - 1):  # ignore exit state
        class_total = 0
        for row in range(1, n_rows):
            if labels[row - 1, -1] == 1 or labels[row, cls] != 1:
                continue
            term = -(np.log(preds[row, cls]))
            assert np.abs(term) < 10
            class_total += term
            n_terms += 1
        class_sums.append(class_total)

    return [class_total / n_terms for class_total in class_sums]


def get_metrics_cl(y_true, y_pred, name, get_table=False, save_heatmap=True):
    """
    Compute, print and collect summary metrics for one set of predictions.

    Parameters
    ----------
    y_true : np.ndarray
        One-hot labels; the last axis indexes the classes. Class index 7 is
        excluded from the "correct class" probability averages below.
    y_pred : np.ndarray
        Predicted probabilities with the same shape as ``y_true``.
    name : str
        Label used in the printed messages and stored under ``"name"``.
    get_table : bool
        If True, also build the AUC comparison table via ``create_auc_table``.
    save_heatmap : bool
        Forwarded to ``get_auc_matrix``.

    Returns
    -------
    dict
        Rounded metrics: cross-entropy, selected AUC entries, averages and the
        per-class lists ``avg_prob_class_i`` and
        ``per_class_contributions_xentropy``.

    Raises
    ------
    AssertionError
        If the per-class cross-entropy contributions do not sum to the total
        cross-entropy (tolerance 1e-10).
    """
    xentropy = compute_xentropy(y_true, y_pred)
    print(f"Cross entropy on {name}: {xentropy}")

    per_class_contributions_xentropy = compute_class_contributions_xentropy(y_true, y_pred)

    # Author's note: initially the "current" class (or class 7) makes the full
    # contribution to the cross entropy; after some time the loss should aim to
    # spread this evenly across all classes.
    # The contributions are normalised by the total count, so they must add up
    # to the overall cross-entropy.
    assert np.abs(np.sum(per_class_contributions_xentropy) - xentropy) < 1e-10
    per_class_contributions_xentropy = [custom_round(x, 3) for x in per_class_contributions_xentropy]

    auc_matrix, pos_cnt, pred_probs = get_auc_matrix(y_true, y_pred, name, save_heatmap=save_heatmap)
    average_non_nan_auc = np.nanmean(auc_matrix)
    print(f"Average AUC on {name}: {average_non_nan_auc}")
    average_non_nan_predicted_prob = np.nanmean(pred_probs)
    print(f"Average predicted probability of transition on {name}: {average_non_nan_predicted_prob}")

    # Probability the model assigned to the true class of every row.
    flat_y_true = y_true.reshape(-1, y_true.shape[-1])
    am_flat_y_true = np.argmax(flat_y_true, axis=1)
    flat_y_pred = y_pred.reshape(-1, y_pred.shape[-1])
    flat_y_pred = flat_y_pred[np.arange(flat_y_pred.shape[0]), am_flat_y_true]
    avg_prob_correct_class = np.mean(flat_y_pred[am_flat_y_true != 7])
    avg_prob_class_i = [
        custom_round(np.mean(flat_y_pred[am_flat_y_true == i]), 3) for i in range(7)
    ]

    if get_table:
        # NOTE(review): row/column meaning of auc_matrix is inferred from the
        # column headers and metric keys used here — confirm against get_auc_matrix.
        data = ["Set-Seq Model", auc_matrix[4,6], auc_matrix[3,6], auc_matrix[2,6], auc_matrix[1,6], auc_matrix[0,6]]
        columns = ["Model", "C→P", "30dd→P", "60dd→P", "90+dd→P", "F→P"]
        full_data = [
            data,
            ["0 hidden layer", 0.65, 0.77, 0.68, 0.59, 0.57],
            ["1 hidden layer", 0.72, 0.79, 0.71, 0.76, 0.68],
            ["3 hidden layer", 0.74, 0.81, 0.73, 0.79, 0.73],
            ["5 hidden layer", 0.74, 0.81, 0.73, 0.79, 0.73],
            ["Ensemble", 0.76, 0.83, 0.74, 0.79, 0.74]
        ]
        print("AUC data:", data)
        # BASE_PATH has its trailing slash stripped at import time, so the
        # separator has to be added explicitly.
        pretext = os.path.join(BASE_PATH, "scripts", "notebooks")
        create_auc_table(columns, full_data, pretext)

    metrics = {
        "name": name,
        "Xentropy": custom_round(xentropy,3),
        "c-30dd": custom_round(auc_matrix[4,1],3),
        "c-c": custom_round(auc_matrix[4,0],3),
        "c-paid-off": custom_round(auc_matrix[4,6],3),
        "30dd-60dd": custom_round(auc_matrix[3,2],3),
        "Avg. AUC": custom_round(average_non_nan_auc,3),
        "Avg. SC Transition Prob": custom_round(average_non_nan_predicted_prob,3),
        "Avg. Transition Prob": custom_round(avg_prob_correct_class,3),
        "avg_prob_class_i": avg_prob_class_i,
        "per_class_contributions_xentropy": per_class_contributions_xentropy,
    }
    return metrics


def _print_sanity_xentropy(split, model, empirical_prob_matrix, X, Y, tol=0.0, n_states=8):
    """Print the cross-entropy of the logistic model and of the empirical matrix on one split."""
    y_pred_lr = model.predict_proba(X)

    n_rows = Y.shape[0]
    y_one_hot = np.zeros((n_rows, n_states))
    y_one_hot[np.arange(n_rows), Y] = 1

    xentropy_lr = compute_xentropy(y_one_hot, y_pred_lr)
    print(f"Cross entropy on {split} logistic regression SANITY CHECK: {xentropy_lr}")

    # The current state is read from the first `n_states` feature columns.
    current_state = np.argmax(X[:, :n_states], axis=1)
    y_pred_emp = np.zeros((n_rows, n_states))
    # Row 0 stays all-zero, exactly like the original loop that started at
    # index 1 (presumably because the metric only scores rows that have a
    # predecessor — TODO confirm against compute_xentropy).
    y_pred_emp[1:] = empirical_prob_matrix[current_state[1:]]
    if tol:
        y_pred_emp = y_pred_emp + tol

    xentropy_emp = compute_xentropy(y_one_hot, y_pred_emp)
    print(f"Cross entropy on {split} empirical SANITY CHECK: {xentropy_emp}")


def train_logistic_model(train, val, test):
    """
    Fit the logistic-regression baseline and the empirical transition matrix.

    Parameters
    ----------
    train, val, test : dataset objects
        Must provide ``get_logistic_regression_data()`` returning ``(X, Y)``;
        ``train`` must also provide
        ``get_full_dataset_empirical_transition_counts(...)``.

    Returns
    -------
    model : sklearn.linear_model.LogisticRegression
        Model fitted on the training split.
    empirical_prob_matrix : np.ndarray
        Row-normalised empirical transition counts from the training split.

    Notes
    -----
    As a sanity check, the cross-entropy of both baselines is printed for the
    train and test splits. The test-split empirical probabilities are offset
    by 1e-8 before scoring; the train-split ones are not.
    """
    X_train, Y_train = train.get_logistic_regression_data() # update the sampling here!
    X_val, Y_val = val.get_logistic_regression_data()
    X_test, Y_test = test.get_logistic_regression_data()

    print("getting empirical transition matrix")
    emp_count_matrix = train.get_full_dataset_empirical_transition_counts(
        within_sampling_bounds=True, use_sampling=True, deterministic_sampling=False
    )
    empirical_prob_matrix = emp_count_matrix/emp_count_matrix.sum(axis=1, keepdims=True)
    print(empirical_prob_matrix)

    # Features and labels must describe the same number of observations.
    assert len(X_val) == len(Y_val)
    assert len(X_test) == len(Y_test)

    print("Training logistic regression model")
    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, Y_train)

    _print_sanity_xentropy("train", model, empirical_prob_matrix, X_train, Y_train)
    _print_sanity_xentropy("test", model, empirical_prob_matrix, X_test, Y_test, tol=1e-8)

    return model, empirical_prob_matrix



def make_prediction(
    model,
    model_name,
    x_batch,    # PyTorch tensor, shape [B, n_features, n_loans, n_time_steps], possibly on GPU
    y_batch=None,
    all_units_in_batch_dim = False,
):
    """
    Produce predictions given x_batch, depending on the model_name.

    Parameters
    ----------
    model : object or np.ndarray
        - If model_name is 'logistic-regression' or 'neural_network', an object
          exposing ``predict_proba(X)``.
        - If model_name == 'empirical_transition_probabilities', a transition
          matrix (np.ndarray) of shape [num_states, n_classes].
        - If model_name == 'set-seq', a PyTorch model called as
          ``model((x_batch, {}))[0]``.
    model_name : str
        One of {"logistic-regression", "neural_network",
        "empirical_transition_probabilities", "set-seq"}.
    x_batch : torch.Tensor
        Shape [B, n_features, n_loans, n_time_steps]. The current state is read
        from the first 8 features by the empirical baseline.
    y_batch : torch.Tensor, optional
        Shape [B, n_loans, n_time_steps, n_classes]. When given, it fixes the
        output shape and enables the probability check of the empirical
        baseline. When omitted, the number of classes is taken from the
        model output.
    all_units_in_batch_dim : bool
        Only used for "set-seq": feed every loan as its own batch element
        ([B*n_loans, n_features, 1, n_time_steps]).

    Returns
    -------
    y_pred_torch : torch.Tensor
        Shape [B, n_loans, n_time_steps, n_classes]. For "set-seq" it lives on
        the device of ``x_batch``; for the other models it is a CPU tensor and
        the last time step is all zeros (there is no next state to predict).

    Raises
    ------
    ValueError
        If ``model_name`` is not supported.
    AssertionError
        If the empirical matrix gives probability <= 1e-8 to an observed
        next state (only checked when ``y_batch`` is given).
    """
    # 1. Neural set-seq model: stay in torch, on the device of x_batch.
    if model_name == "set-seq":
        device = x_batch.device
        n_batch, n_units = x_batch.shape[0], x_batch.shape[2]
        if all_units_in_batch_dim:
            # Reshape to [B*n_loans, n_features, 1, n_time_steps]
            x_batch = x_batch.permute(0, 2, 1, 3).reshape(-1, x_batch.shape[1], 1, x_batch.shape[3])

        model = model.to(device)
        model.eval()
        # NOTE(review): presumably clears state cached by a previous forward pass — confirm.
        model._state = None

        with torch.no_grad():
            logits = model((x_batch, {}))[0]
            y_pred = F.softmax(logits, dim=3)

            if y_batch is not None:
                # Also undoes the all_units_in_batch_dim flattening.
                y_pred = y_pred.reshape(y_batch.shape)
            elif all_units_in_batch_dim:
                y_pred = y_pred.reshape(n_batch, n_units, *y_pred.shape[-2:])
        return y_pred

    # 2. scikit-learn style model or empirical transition matrix: work in NumPy on CPU.
    x_cpu = x_batch.cpu().numpy()  # shape [B, n_features, n_loans, n_time_steps]
    y_cpu = y_batch.cpu().numpy() if y_batch is not None else None
    B, n_features, n_loans, n_time_steps = x_cpu.shape

    # The last time step is skipped in both branches: it has no next state.
    if model_name == "empirical_transition_probabilities":
        transition_matrix = model
        # Current state per [b, loan, t], taken from the first 8 features.
        current_states = np.argmax(x_cpu[:, :8, :, :], axis=1)
        current_states_flat = current_states[:, :, :-1].reshape(-1)
        # One lookup for all rows => [B*n_loans*(n_time_steps-1), n_classes]
        probs_flat = transition_matrix[current_states_flat, :]

        if y_cpu is not None:
            next_states = np.argmax(y_cpu[:, :, 1:, :], axis=-1).reshape(-1)
            prob_of_true = probs_flat[np.arange(len(probs_flat)), next_states]
            if not np.all(prob_of_true > 1e-8):
                raise AssertionError("Empirical prob too small for at least one true state.")

    elif model_name == "logistic-regression" or model_name == "neural_network":
        # [B, n_features, n_loans, T-1] -> [B, n_loans, T-1, n_features] -> 2-D
        X_no_last = np.transpose(x_cpu[:, :, :, :-1], (0, 2, 3, 1))
        X_flat = X_no_last.reshape(-1, n_features)
        # Single predict_proba call => [B*n_loans*(n_time_steps-1), n_classes]
        probs_flat = model.predict_proba(X_flat)

    else:
        raise ValueError(f"Unsupported model_name: {model_name}")

    # Without labels the class count comes from the model output itself.
    n_classes = y_cpu.shape[-1] if y_cpu is not None else probs_flat.shape[-1]

    y_pred_np = np.zeros((B, n_loans, n_time_steps, n_classes), dtype=np.float32)
    y_pred_np[:, :, :-1, :] = probs_flat.reshape(B, n_loans, (n_time_steps - 1), n_classes)

    # Predictions of the NumPy-based baselines are kept on CPU.
    return torch.from_numpy(y_pred_np)


def zero_out_time_range(y_pred_torch, y, start_time, end_time, exit_channel=7):
    """
    Force every time step outside [start_time, end_time) into one fixed state.

    For all positions whose time index is < start_time or >= end_time, every
    channel of both tensors is set to 0 and channel ``exit_channel`` is set
    to 1. Both tensors are modified IN PLACE and also returned.

    Parameters
    ----------
    y_pred_torch, y : torch.Tensor
        Shape (B, I, T, D); time is the third dimension (index 2).
    start_time, end_time : int or 0-dim tensor
        Half-open range of time steps that is left untouched.
    exit_channel : int
        Channel set to 1 outside the range (default 7, the previously
        hard-coded value).

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        The same ``y_pred_torch`` and ``y`` objects that were passed in.
    """
    _, _, T, _ = y_pred_torch.shape

    # Time steps outside the kept range, shape (T,)
    time = torch.arange(T, device=y_pred_torch.device)
    outside = (time < start_time) | (time >= end_time)

    for tensor in (y_pred_torch, y):
        # Use a mask on the tensor's own device, in case the two differ.
        mask = outside.to(tensor.device)
        tensor[:, :, mask, :] = 0
        tensor[:, :, mask, exit_channel] = 1

    return y_pred_torch, y

def custom_round(x, nr_decimals=2):
    """
    Round ``x`` so that small values keep ``nr_decimals`` digits after their leading zeros.

    - NaN: a message is printed and 0 is returned.
    - 0: returns 0.
    - |x| >= 1: rounded to ``nr_decimals`` decimal places.
    - |x| < 1: rounded to ``nr_decimals`` places counted from the first
      non-zero digit (e.g. 0.0012345 -> 0.00123 for ``nr_decimals=2``).

    The magnitude test uses ``abs(x)`` so that negative values are rounded
    like their positive counterparts.
    """
    if np.isnan(x):
        print(f"x is nan: {x}")
        return 0  # TODO figure out the issue
    if x == 0:
        return 0  # Explicitly return 0 if x is 0

    magnitude = abs(x)
    if magnitude >= 1:
        return round(x, nr_decimals)

    # Position of the first non-zero digit after the decimal point
    exponent = -int(np.floor(np.log10(magnitude)))
    return round(x, exponent + nr_decimals)


def evaluate_model_cl(
    model,         # logistic model, set-seq model, or np.ndarray (transition matrix)
    model_name, 
    val_set, 
    test_set=None, 
    batch_size=1,
    fix_seed=True,
    all_units_in_batch_dim=False
):
    """
    Evaluate the given model on val_set (and optionally test_set).

    Parameters
    ----------
    model : object
        Passed through to ``make_prediction``.
    model_name : str
        Passed through to ``make_prediction``; for "set-seq" the inputs are
        moved to CUDA when available.
    val_set, test_set : Dataset
        Each item must unpack to ``(x, y, valid_indices)``. If ``test_set`` is
        given, both sets are concatenated.
    batch_size : int
        DataLoader batch size. Incomplete final batches are dropped
        (``drop_last=True``).
    fix_seed : bool
        Re-seed torch and NumPy with ``SEED_NR`` before iterating.
    all_units_in_batch_dim : bool
        Passed through to ``make_prediction``.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Predictions and labels of all batches, concatenated along dim 1.
        Time steps outside each batch's valid range are forced into the
        fixed state written by ``zero_out_time_range``.

    Raises
    ------
    ValueError
        If predictions and labels of a batch differ in shape.
    """
    if fix_seed:
        torch.manual_seed(SEED_NR)
        np.random.seed(SEED_NR)

    # Merge val/test sets if needed
    combined_dataset = val_set if test_set is None else ConcatDataset([val_set, test_set])

    dataloader = DataLoader(
        combined_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        drop_last=True
    )

    # Only the set-seq model is evaluated on an accelerator.
    device = None
    if model_name == "set-seq":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_y_pred = []
    all_y_true = []

    for x, y, valid_indices in dataloader:
        if device is not None:
            x = x.to(device)
            y = y.to(device)

        y_pred_torch = make_prediction(
            model=model,
            model_name=model_name,
            x_batch=x,
            y_batch=y,
            all_units_in_batch_dim=all_units_in_batch_dim
        )
        if y.shape != y_pred_torch.shape:
            raise ValueError(
                f"Shapes of y and y_pred_torch do not match: {y.shape} vs {y_pred_torch.shape}"
            )

        # NOTE(review): the bounds of the first element are applied to the whole
        # batch — confirm this is valid when batch_size > 1.
        start_time = valid_indices[0][0]
        end_time = valid_indices[1][0]
        # make y_pred_torch and y be in state 7 outside of the sampling bounds
        y_pred_torch, y = zero_out_time_range(y_pred_torch, y, start_time, end_time)

        all_y_pred.append(y_pred_torch)
        all_y_true.append(y)

    # Concatenate along dim 1 (the n_loans axis of [B, n_loans, n_time_steps, n_classes])
    y_pred_concat = torch.cat(all_y_pred, dim=1)
    y_true_concat = torch.cat(all_y_true, dim=1)

    return y_pred_concat, y_true_concat


def print_metrics(metric_list):
    """
    Write the collected metrics to CSV/LaTeX files and print the compact LaTeX table.

    Parameters
    ----------
    metric_list : list of dict
        One dict per model. A ``"name"`` key, if present, becomes the table
        index. Columns containing lists are only included in the "full" outputs.

    Side effects
    ------------
    Writes ``metrics.csv``, ``metrics_full.csv``, ``metrics.tex`` and
    ``metrics_full.tex`` to the current working directory and prints the
    compact LaTeX table.
    """
    import pandas as pd  # Ensure pandas is imported

    def map_cells(frame, func):
        # DataFrame.applymap was deprecated in pandas 2.1 in favour of DataFrame.map.
        if hasattr(pd.DataFrame, "map"):
            return frame.map(func)
        return frame.applymap(func)

    # Convert list of dicts to DataFrame
    df = pd.DataFrame(metric_list)

    # Set the "name" column as the index if it exists
    if 'name' in df.columns:
        df.set_index('name', inplace=True)

    # Rename index from "name" to "Model"
    df.index.name = "Model"

    # Identify columns with lists
    list_cols = [col for col in df.columns if df[col].apply(lambda x: isinstance(x, list)).any()]

    # Convert lists to readable strings in the full version
    df_full = df.copy()
    if list_cols:
        df_full[list_cols] = map_cells(
            df_full[list_cols],
            lambda x: ', '.join(map(str, x)) if isinstance(x, list) else x,
        )

    # Create the version without list values
    df_no_lists = df.drop(columns=list_cols)

    # Format numeric values to remove trailing zeros
    def format_number(val):
        if isinstance(val, float):
            return f"{val:.5g}"  # Uses general format, removing unnecessary zeros
        return val

    df_no_lists = map_cells(df_no_lists, format_number)
    df_full = map_cells(df_full, format_number)

    # LaTeX table formatting
    def format_latex_table(df, label):
        return f"""
    \\begin{{table}}[h]
        \\centering
        \\caption{{Comparing Performance of Different Models on the Mortgage Risk Task}}
        \\label{{{label}}}
        {df.to_latex(escape=False, index=True)}
    \\end{{table}}
        """.strip()

    # Generate LaTeX tables
    latex_no_lists = format_latex_table(df_no_lists, "tab:metrics")
    latex_full = format_latex_table(df_full, "tab:metrics_full")

    # Save CSV files
    df_no_lists.to_csv("metrics.csv")
    df_full.to_csv("metrics_full.csv")

    # Save LaTeX files
    with open("metrics.tex", "w") as f:
        f.write(latex_no_lists)

    with open("metrics_full.tex", "w") as f:
        f.write(latex_full)

    # Print only the non-list metrics LaTeX table
    print("\nCopy the following into Overleaf for metrics (without lists):\n")
    print(latex_no_lists)

    print("\nMetrics table saved to metrics.csv and metrics_full.csv")
    print("LaTeX table saved to metrics.tex and metrics_full.tex")


def _slice_preds_and_labels(pred_torch, true_torch, tol=0.0):
    """Drop the last predicted step and the first label step so both line up; add `tol` to the predictions."""
    pred = pred_torch[:, :, :-1, :].detach().cpu().numpy()
    if tol:
        pred = pred + tol
    true = true_torch[:, :, 1:, :].detach().cpu().numpy()
    return pred, true


def run_evaluation(
        train, 
        val, 
        test,
        data_validation=False,
        data_sanity_checks=False,
        plot_active_and_defualt_loans=False,
        train_baseline_nn=False,
        train_baseline_logistic=False,
        model_eval=False,
        dataset="top4",
        date="jan29",
        nr_features=50,
        batch_size=1,
        all_units_in_batch_dimension=False
        ):
    """
    Run the selected evaluation stages and collect their metrics.

    Parameters
    ----------
    train, val, test : dataset objects
        The three data splits.
    data_validation : bool
        Call ``data_stats(train, val, test)``.
    data_sanity_checks : bool
        Call ``train.feature_visualization(nr_loans_to_sample=10000)``.
    plot_active_and_defualt_loans : bool
        Call ``train.get_nr_active_loans_plot("new_sampling_2")``.
    train_baseline_nn : bool
        Train the simple neural-network baseline and evaluate it on ``test``.
    train_baseline_logistic : bool
        Train the logistic-regression baseline and evaluate it, together with
        the empirical transition matrix, on ``test``.
    model_eval : bool
        Load the set-seq model and evaluate it on ``test`` and on ``train``.
    dataset, date, nr_features
        Only used to build the metric names and the model-config name.
    batch_size, all_units_in_batch_dimension
        Forwarded to ``evaluate_model_cl`` for the set-seq model.

    Returns
    -------
    list of dict
        One metrics dict (from ``get_metrics_cl``) per evaluated model/split.
    """
    # Small offset added to predicted probabilities before they are scored.
    tol = 5e-8

    if data_validation:
        data_stats(train,val,test)

    if data_sanity_checks:
        train.feature_visualization(nr_loans_to_sample=10000)

    if plot_active_and_defualt_loans:
        train.get_nr_active_loans_plot("new_sampling_2")

    metric_list = []

    if train_baseline_nn:
        from train_simple_nn import train_nn
        X_train, Y_train = train.get_logistic_regression_data()
        X_val, Y_val = val.get_logistic_regression_data()
        nn_model = train_nn(X_train, Y_train, batch_size=1000, epochs=15, X_val=X_val, Y_val=Y_val)
        nn_model.to("cpu")

        y_nn_probs_torch, y_true_nn_torch = evaluate_model_cl(
            model=nn_model,
            model_name="neural_network",
            val_set=test,
            test_set=None
        )
        y_nn_probs, y_true_nn = _slice_preds_and_labels(y_nn_probs_torch, y_true_nn_torch, tol)
        name = f"{dataset}-neural-network-test-{date}-{nr_features}"
        metric_list.append(get_metrics_cl(y_true_nn, y_nn_probs, name))

    if train_baseline_logistic:
        name = f"{dataset}-logistic-regression-test-{date}-{nr_features}"

        log_model, empirical_prob_matrix = train_logistic_model(train, val, test)

        # Logistic regression on the test split
        y_log_probs_torch, y_true_log_reg_torch = evaluate_model_cl(
            model=log_model,
            model_name="logistic-regression",
            val_set=test,
            test_set=None
        )
        y_log_probs, y_true_log_reg = _slice_preds_and_labels(
            y_log_probs_torch, y_true_log_reg_torch, tol
        )

        # Empirical transition probabilities on the test split; here the
        # offset is applied to the matrix itself rather than to the output.
        y_emp_probs_torch, y_true_emp_torch = evaluate_model_cl(
            model=empirical_prob_matrix+tol,  # pass the matrix itself!
            model_name="empirical_transition_probabilities",
            val_set=test,
            test_set=None
        )
        empirical_probabilities, y_true_emp = _slice_preds_and_labels(
            y_emp_probs_torch, y_true_emp_torch
        )

        # Both baselines must have been scored against identical labels.
        assert (y_true_emp == y_true_log_reg).all()

        metrics_logistic_regression = get_metrics_cl(y_true_log_reg, y_log_probs, name)
        metric_list.append(metrics_logistic_regression)
        name = f"{dataset}-empirical-transition-matrix-test-{date}-{nr_features}"
        metrics_empirical_probabilities = get_metrics_cl(y_true_emp, empirical_probabilities, name)
        metric_list.append(metrics_empirical_probabilities)

    if model_eval:
        name_val = f"{dataset}-test-set-seq-{date}-{nr_features}"
        name_train = f"{dataset}-train-set-seq-{date}-{nr_features}"
        model_config = get_model_config(name=dataset+"-"+date+"-"+str(nr_features))
        model = load_model_corelogic(**model_config)

        # Evaluate on the test split only
        probs_torch, y_true_torch = evaluate_model_cl(
            model=model,
            model_name="set-seq",
            val_set=test,
            test_set=None,
            batch_size=batch_size,
            fix_seed=True,
            all_units_in_batch_dim=all_units_in_batch_dimension
        )
        probs, y_true = _slice_preds_and_labels(probs_torch, y_true_torch, tol)
        metric_list.append(get_metrics_cl(y_true, probs, name=name_val))

        # Evaluate on train
        probs_train_torch, y_true_train_torch = evaluate_model_cl(
            model=model,
            model_name="set-seq",
            val_set=train,   # pass train as "val_set"
            test_set=None,
            batch_size=batch_size,
            fix_seed=True,
            all_units_in_batch_dim=all_units_in_batch_dimension
        )
        probs_train, y_true_train = _slice_preds_and_labels(
            probs_train_torch, y_true_train_torch, tol
        )
        metric_list.append(get_metrics_cl(y_true_train, probs_train, name_train))

    return metric_list


def float_to_str_no_leading_zero(value, precision=3):
    """
    Convert a floating-point number into a string with given precision,
    removing the leading '0' when the formatted magnitude starts with "0.".

    The decision is made on the *formatted* text, so a value that rounds up
    to 1 keeps its integer digit, and ``precision=0`` (no decimal point) is
    returned unchanged.

    Examples:
      0.0123 -> ".012"
      -0.456 -> "-.456"
      1.2345 -> "1.234"
      -9.876 -> "-9.876"
      0.0    -> ".000"
      0.9996 -> "1.000"
    """
    sign = "-" if value < 0 else ""
    fmt_str = f"{abs(value):.{precision}f}"   # e.g. '0.012'
    if fmt_str.startswith("0."):
        # '0.012' -> '.012'
        fmt_str = fmt_str[1:]
    return sign + fmt_str

def save_heatmap_diff(
    diff_mean_matrix,
    diff_str_matrix,
    filename,
    labels_x,
    labels_y,
    title="AUC Difference (Seq - Logistic)",
    cmap="coolwarm"
):
    """
    Save a PDF heatmap of ``diff_mean_matrix``.

    Each cell is coloured by its value in ``diff_mean_matrix`` and annotated
    with the ready-made string from ``diff_str_matrix``. The colour scale
    spans the NaN-ignoring minimum and maximum of the matrix.

    NOTE: the ``cmap`` argument is currently ignored; a fixed light-to-dark
    red colormap is always used.
    """
    plt.figure(figsize=(10, 7))
    sns.set_context("notebook")
    sns.set_style("white")  # no background grid

    # Colour limits follow the data range.
    lowest = np.nanmin((diff_mean_matrix))
    highest = np.nanmax((diff_mean_matrix))
    cmap = LinearSegmentedColormap.from_list("custom_red", ["#f4c2c2", "#8B0000"])

    heatmap_options = dict(
        annot=diff_str_matrix,
        fmt="",  # annotations are already fully formatted strings
        cmap=cmap,
        xticklabels=labels_x,
        yticklabels=labels_y,
        annot_kws={"size": 14},  # cell text size
        linewidths=0.3,
        linecolor="white",
        vmin=lowest,
        vmax=highest,
    )
    sns.heatmap(diff_mean_matrix, **heatmap_options)

    plt.xticks(fontsize=14, rotation=45)
    plt.yticks(fontsize=14, rotation=0)
    plt.xlabel("End State", fontsize=16)
    plt.ylabel("Initial State", fontsize=16)
    plt.title(title, fontsize=18, pad=15)

    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches="tight", format="pdf")
    plt.close()


def _auc_matrix_for_model(model, model_name, test_set, batch_size, all_units_in_batch_dim, name):
    """Evaluate one model on ``test_set`` and return its AUC transition matrix."""
    y_pred, y_true = evaluate_model_cl(
        model=model,
        model_name=model_name,
        val_set=test_set,
        test_set=None,
        batch_size=batch_size,
        fix_seed=False,  # the caller controls sampling through test_set.eval_seed
        all_units_in_batch_dim=all_units_in_batch_dim
    )
    # Slice off last time-step from predictions, first from labels, so that a
    # prediction is compared with the label one step later.
    # (assumes axis 2 is the time axis -- TODO confirm against evaluate_model_cl)
    y_pred = y_pred[:, :, :-1, :].cpu().numpy()
    y_true = y_true[:, :, 1:, :].cpu().numpy()

    auc_matrix, _, _ = get_auc_matrix(y_true, y_pred, name=name)
    return auc_matrix


def get_auc_diff(
    model_set_seq,
    model_name1, 
    model_logistic,
    model_name2,
    test_set,
    seeds=None,
    batch_size=1,
    all_units_in_batch_dim=False,
    heatmap_filename="diff_matrix_mean_std.pdf"
):
    """
    Compare the AUC transition matrices of two models over several seeds.

    For each seed:

      - Set ``test_set.eval_seed = seed`` (the attribute is left at the last
        seed when the function returns).
      - Evaluate both models and compute their AUC transition matrices.
      - Store ``auc_seq - auc_other`` and ``auc_seq``.

    The per-seed differences are then averaged, their standard deviation is
    computed, a text summary is printed, and a heatmap annotated with
    'mean (std)' (leading zeros omitted) is saved to ``heatmap_filename``.

    Parameters
    ----------
    model_set_seq : model
        The Set-Seq model (first model of the comparison).
    model_name1 : str
        Name passed to ``evaluate_model_cl`` for ``model_set_seq``.
    model_logistic : model
        The baseline model (second model of the comparison).
    model_name2 : str
        Name passed to ``evaluate_model_cl`` for ``model_logistic``. The value
        "logistic-regression" selects the logistic-regression plot title; any
        other value selects the neural-network title.
    test_set : dataset
        Dataset to evaluate on; must expose a writable ``eval_seed``.
    seeds : iterable of int, optional
        Seeds for the repeated evaluation. Defaults to 1000, 2000, ..., 10000.
    batch_size : int
        Forwarded to ``evaluate_model_cl``.
    all_units_in_batch_dim : bool
        Forwarded to ``evaluate_model_cl`` for the Set-Seq model only; the
        baseline model is always evaluated with ``False``.
    heatmap_filename : str
        Path of the heatmap PDF.

    Returns
    -------
    mean_diff_matrix : np.ndarray
        Elementwise mean of (seq_auc - baseline_auc) across seeds.
    std_diff_matrix : np.ndarray
        Elementwise standard deviation of that difference.
    diff_matrix_as_str : np.ndarray of str (dtype=object)
        Each cell is "mean (std)" with leading zeros removed.
    auc_matrices : np.ndarray
        Per-seed Set-Seq AUC matrices stacked on axis 0.

    Raises
    ------
    ValueError
        If ``seeds`` is empty.
    """
    if seeds is None:
        seeds = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
    seeds = list(seeds)
    if not seeds:
        raise ValueError("get_auc_diff needs at least one seed")

    diff_matrices = []
    auc_matrices = []
    # We assume the get_auc_matrix returns a 5x7 matrix after reordering/deleting rows & columns:
    labels_y = ["F", "90dd", "60dd", "30dd", "Current"]
    labels_x = ["Current", "30dd", "60dd", "90dd", "F", "REO", "Paid Off"]

    for seed in seeds:
        # Fix dataset seed to produce a new random sample for each iteration
        test_set.eval_seed = seed

        auc_seq = _auc_matrix_for_model(
            model_set_seq, model_name1, test_set, batch_size,
            all_units_in_batch_dim, name="temp_seq"
        )
        auc_log = _auc_matrix_for_model(
            model_logistic, model_name2, test_set, batch_size,
            False, name="temp_log"
        )

        assert auc_seq.shape == auc_log.shape, (
            f"AUC matrix shape mismatch: set-seq {auc_seq.shape}, logistic {auc_log.shape}"
        )

        diff_matrices.append(auc_seq - auc_log)
        auc_matrices.append(auc_seq)

    diff_matrices = np.stack(diff_matrices, axis=0)  # shape = [nSeeds, row, col]
    auc_matrices = np.stack(auc_matrices, axis=0)  # shape = [nSeeds, row, col]
    mean_diff_matrix = np.mean(diff_matrices, axis=0)
    std_diff_matrix = np.std(diff_matrices, axis=0)

    # Build matrix of strings => "mean (std)" without leading zeros
    rows, cols = mean_diff_matrix.shape
    diff_matrix_as_str = np.empty((rows, cols), dtype=object)
    for i in range(rows):
        for j in range(cols):
            mval = float_to_str_no_leading_zero(mean_diff_matrix[i, j], precision=3)
            sval = float_to_str_no_leading_zero(std_diff_matrix[i, j], precision=3)
            diff_matrix_as_str[i, j] = f"{mval} ({sval})"

    # Print summary in console
    print("\nAverage AUC Difference [Seq - Logistic] (Mean ± Std):\n")
    for row in diff_matrix_as_str:
        print("  ".join(row))

    # Save heatmap as a PDF
    if model_name2 == "logistic-regression":
        baseline_label = "Logistic Regression"
    else:
        baseline_label = "Neural Network"
    title = f"Mean AUC Difference (Set-Seq - {baseline_label}) Across {len(seeds)} Seeds"
    save_heatmap_diff(
        diff_mean_matrix=mean_diff_matrix,
        diff_str_matrix=diff_matrix_as_str,
        filename=heatmap_filename,
        labels_x=labels_x,
        labels_y=labels_y,
        title=title,
        cmap="coolwarm"
    )

    return mean_diff_matrix, std_diff_matrix, diff_matrix_as_str, auc_matrices
    

def _build_corelogic_config():
    """Return the dataset / sampling configuration dict passed to ``get_dataset``."""
    # The trailing numbers on the feature lines are the original author's
    # per-feature annotations (presumably encoded widths -- unverified).
    dataset_config = {
        "path_origination": "/share/data/llm_mortgages/data/filtered_origination_data_top_4_zips.csv",
        "path_performance": "/share/data/llm_mortgages/data/filtered_performance_data_top_4_zips.csv",
        "normalize_data": True,
        "database_size": 300000,
        "start_year": 1988,  # correct start year
        "end_year": 2023,  # correct end year
        "columns_to_normalize_origination": [
            "fico_score_at_origination",
            "original_balance",
            "initial_interest_rate",
            "original_ltv",
        ],
        "columns_to_normalize_performance": [
            "current_balance",
            "current_interest_rate",
            "scheduled_monthly_pi",
            "scheduled_principal",
            "mba_days_delinquent",
        ],
        "feature_set": [
            "current_state",  # 8
            "fico_score_at_origination",  # 1
            "original_balance",  # 1
            "initial_interest_rate",  # 1
            "original_ltv",  # 1
            "unemployment_rate",  # 1
            "national_mortgage_rate",  # 1
            "current_balance",  # 2
            "current_interest_rate",  # 2
            "scheduled_monthly_pi",  # 2
            "scheduled_principal",  # 2
            "mba_days_delinquent",  # 2
            "inferred_collateral_type",  # 2
            "convertible_flag",  # 2
            "pool_insurance_flag",  # 2
            "io_flag",  # 2
            "prepay_penalty_flag",  # 2
            "negative_amortization_flag",  # 2
            "buydown_flag",  # 2
            "loan_age",  # 4
            "original_term",  # 2
            "times_30dd",  # 2
            "times_60dd",  # 1
            "times_90dd",  # 1
            "times_current",  # 1
            "times_foreclosure",  # 1
            "zip-code",  # 5
            "lagged_foreclosure_rate",  # 1
            "lagged_prepayment_rate",  # 2
        ],  # Total 58, Total 55
        "nr_classes": 8,
        "verbose": True,
    }

    return {
        "_name_": "corelogic_loan_dataset",
        "dataset_config": dataset_config,
        "val_split": 0.1,
        "test_split": 0.3,
        "val_split_date": "2009-06",
        "test_split_date": "2009-12",
        "load_data": True,
        "save_data": False,
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_52.npz",  # 52 # is this correct data? May be 52?!
        "max_to_sample": 4500,  # Total nr loans
        "nr_sampling_timesteps": 50,
        "nr_loans_to_sample": 2500,
        "steps_per_epoch": 70,
        "sample_random_loan_index": True,
        "sample_random_time_index": True,
        "eval_mode": True,
        "eval_seed": 3000,
    }


def main():
    """
    Run the experiments selected by the boolean switches below.

    With the current switches only the Set-Seq vs. 5-layer-NN AUC comparison
    runs: it loads the data and both checkpoints, calls ``get_auc_diff`` and
    stores the mean / std difference matrices as ``.npy`` files next to the
    heatmap PDF. The remaining branches (logistic / simple-NN comparisons and
    the three ``run_evaluation`` experiments) are kept but disabled.
    """
    config = _build_corelogic_config()
    date = "jan30"
    # Seeds shared by the 10-seed comparisons.
    ten_seeds = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]

    diff_auc = True
    create_dataset = False
    if create_dataset:
        print("Loading data")
        train, val, test = get_dataset(config)
        print("Data loaded")
    if diff_auc:
        train, val, test = get_dataset(config)
        print("Data loaded")

        model_config = get_model_config(name="may_14_set_seq")
        model = load_model_corelogic(**model_config)
        print("Set Model loaded")

        compare_set_seq_logistic = False
        if compare_set_seq_logistic:
            train.eval_seed = 1000
            log_model, _ = train_logistic_model(train, val, test)
            print("Log Model loaded")
            base = f"{BASE_PATH}/scripts/notebooks/data/corelogic/"
            mean_diff_matrix, std_diff_matrix, diff_matrix_as_str, _ = get_auc_diff(
                model_set_seq=model,
                model_name1="set-seq",
                model_logistic=log_model,
                model_name2="logistic-regression",
                test_set=test,
                seeds=ten_seeds,
                batch_size=1,
                all_units_in_batch_dim=False,
                heatmap_filename=base + "diff_matrix_set_seq_log_reg_mean_std_feb2.pdf"
            )
            print(mean_diff_matrix)
            print(std_diff_matrix)

        compare_set_seq_nn = False
        if compare_set_seq_nn:
            # Imported lazily: only this branch needs the simple-NN trainer.
            from train_simple_nn import train_nn
            train.eval_seed = 1000
            X_train, Y_train = train.get_logistic_regression_data()
            X_val, Y_val = val.get_logistic_regression_data()
            nn_model = train_nn(X_train, Y_train, batch_size=1000, epochs=15, X_val=X_val, Y_val=Y_val)
            # Drop the flattened training arrays before the evaluation runs.
            del X_train, Y_train, X_val, Y_val
            gc.collect()
            base = f"{BASE_PATH}/scripts/notebooks/data/corelogic/neurips/"
            mean_diff_matrix, std_diff_matrix, diff_matrix_as_str, _ = get_auc_diff(
                model_set_seq=model,
                model_name1="set-seq",
                model_logistic=nn_model,
                model_name2="neural_network",
                test_set=test,
                seeds=ten_seeds,
                batch_size=1,
                all_units_in_batch_dim=False,
                heatmap_filename=base + "diff_matrix_set_seq_nn_mean_std_feb2.pdf"
            )

        compare_set_seq_nn5_layer = True
        if compare_set_seq_nn5_layer:
            train.eval_seed = 1000
            model_config_nn5layer = get_model_config(name="may_14_nn_5layer")
            model_nn5layer = load_model_corelogic(**model_config_nn5layer)
            print("NN 5-layer Model loaded")
            base = f"{BASE_PATH}/scripts/notebooks/data/corelogic/neurips/"
            # NOTE(review): the 5-layer NN is evaluated under the name
            # "set-seq", presumably because it is a checkpoint served by the
            # same evaluation path -- confirm against evaluate_model_cl.
            mean_diff_matrix, std_diff_matrix, diff_matrix_as_str, _ = get_auc_diff(
                model_set_seq=model,
                model_name1="set-seq",
                model_logistic=model_nn5layer,
                model_name2="set-seq",
                test_set=test,
                seeds=[10, 20, 30, 40, 50],
                batch_size=1,
                all_units_in_batch_dim=False,
                heatmap_filename=base + "diff_matrix_set_seq_nn_mean_std_may19.pdf"
            )
            # Persist the aggregated matrices alongside the heatmap.
            np.save(base + "diff_matrix_set_seq_nn_mean_std_may19.npy", mean_diff_matrix)
            np.save(base + "std_diff_matrix_set_seq_nn_may19.npy", std_diff_matrix)

    metric_list = []

    first_experiment = False
    if first_experiment:
        print("Loading data")
        train, val, test = get_dataset(config)
        print("Data loaded")
        metric_list = run_evaluation(
            train,
            val,
            test,
            data_validation=False,
            data_sanity_checks=False,
            plot_active_and_defualt_loans=False,
            train_baseline_nn=True,
            train_baseline_logistic=True,
            model_eval=True,
            dataset="top4",
            date=date,
            nr_features=50
        )
        # Free memory
        del train, val, test
        gc.collect()

    second_experiment = False
    if second_experiment:
        config["data_path"] = f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_52.npz"
        nr_features = 52
        train, val, test = get_dataset(config)

        metric_list_2 = run_evaluation(
            train,
            val,
            test,
            data_validation=False,
            data_sanity_checks=False,
            plot_active_and_defualt_loans=False,
            train_baseline_nn=False,
            train_baseline_logistic=False,
            model_eval=True,
            dataset="top41lz",
            date=date,
            nr_features=nr_features,
            all_units_in_batch_dimension=True
        )

        # Free memory
        del train, val, test
        gc.collect()

        metric_list.extend(metric_list_2)

    third_experiment = False
    if third_experiment:
        config["data_path"] = f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_55.npz"  # check this
        print("loading started")
        train, val, test = get_dataset(config)
        test.eval_seed = 3000
        train.eval_seed = 3000
        print("loading done")
        nr_features = 50
        metric_list_3 = run_evaluation(
            train,
            val,
            test,
            data_validation=False,
            data_sanity_checks=False,
            plot_active_and_defualt_loans=False,
            train_baseline_nn=False,
            train_baseline_logistic=False,
            model_eval=False,
            dataset="randtrain",
            date=date,
            nr_features=nr_features,
            all_units_in_batch_dimension=False
        )
        metric_list.extend(metric_list_3)

    if len(metric_list) > 0:
        print_metrics(metric_list)
    

    
if __name__ == "__main__":
    # Run the whole script under the profiler, then report the 60 entries
    # with the largest cumulative time.
    import cProfile
    import pstats

    run_profiler = cProfile.Profile()
    run_profiler.runcall(main)  # enable -> main() -> disable
    report = pstats.Stats(run_profiler)
    report.sort_stats('cumulative')
    report.print_stats(60)

    # State index legend:
    # 0: Paid Off
    # 1: 30 days delinquent
    # 2: 60 days delinquent
    # 3: 90 days delinquent
    # 4: Current
    # 5: Foreclosure
    # 6: REO
    # 7: End State after either REO or Paid off

    # Prepayment transitions of interest:
    # Current -> prepaid
    # 30dd -> prepaid
    # 60dd -> prepaid
    # 90dd -> prepaid
    # F -> prepaid
    # (start_state, end_state): (4,0), (1,0), (2,0), (3,0), (5,0): these are the AUCs we want