import pandas as pd
from datetime import datetime
import numpy as np
import os
import sys
# Project root from the BASE_PATH environment variable (default ""); one trailing
# slash is stripped so paths built as f"{BASE_PATH}/data/..." have a single separator.
BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH and BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
# CoreLogic data directory, overridable via the CORELOGIC_DATA_PATH environment variable.
CORELOGIC_DATA_PATH = os.environ.get("CORELOGIC_DATA_PATH", "/share/data/llm_mortgages/original_data")
# Make the project root importable; this must run before the `src.dataloaders` import below.
# (An empty BASE_PATH appends "", which Python treats as the current directory.)
sys.path.append(BASE_PATH)
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from src.dataloaders.base import SequenceDataset
import seaborn as sns
import matplotlib.pyplot as plt
import pyarrow.csv as pac
import matplotlib.dates as mdates
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as mcolors
# Seed NumPy's and PyTorch's global RNGs at import time so runs are reproducible.
SEED_NR = 42
np.random.seed(SEED_NR)
torch.manual_seed(SEED_NR)

class CreateLoanData():
    def __init__(
            self, 
            path_origination, 
            path_performance,
            feature_set,
            database_size = 1000,
            max_to_sample = 10,
            nr_sampling_timesteps = 15,
            start_year = 1988,  # We are counting from January 1988
            end_year = 2023,    # Exclusive: we are counting until December 2022
            nr_classes = 8,  # 7 classes + 1 for padding
            normalize_data = True,
            columns_to_normalize_origination = None,
            columns_to_normalize_performance = None,
            verbose = True
            ):
        """Load the loan CSVs, clean/normalize them and build the loan database via ``setup``.

        Args:
            path_origination: Path to the loan-level origination CSV.
            path_performance: Path to the monthly performance CSV (read by
                ``loan_id``, ``loan_age`` and ``mba_delinquency_status`` elsewhere).
            feature_set: Ordered feature names passed to ``get_features``.
            database_size: Maximum number of loans ``setup`` adds to the database.
            max_to_sample: Stored as ``self.max_to_sample``; used outside this view.
            nr_sampling_timesteps: Stored as ``self.nr_sampling_timesteps``; used
                outside this view.
            start_year: Year whose January is month 0 of the monthly macro series;
                loans originated earlier are rejected.
            end_year: Exclusive end year; loans originated in or after it are rejected.
            nr_classes: Size of the one-hot status encoding (7 states + 1 padding).
            normalize_data: If True, z-score the listed columns and read the
                normalized macro-series columns in ``get_features``.
            columns_to_normalize_origination: Origination columns to z-score.
            columns_to_normalize_performance: Performance columns to z-score.
            verbose: Print progress information.

        Raises:
            KeyError: If a column listed for normalization is missing.
            TypeError: If a column listed for normalization cannot be averaged.
        """
        self.verbose = verbose
        # Switch to True to parse the CSVs with pyarrow instead of pandas.
        use_pac = False
        if use_pac:
            self.origination = pac.read_csv(path_origination).to_pandas()
            self.performance = pac.read_csv(path_performance).to_pandas()
            self.unemployment_rate = pac.read_csv(f"{BASE_PATH}/data/unemployment.csv").to_pandas()
            self.national_mortgage_rate = pac.read_csv(f"{BASE_PATH}/data/national_mortgage_rate.csv").to_pandas()
        else:
            self.origination = pd.read_csv(path_origination)
            self.performance = pd.read_csv(path_performance)
            self.unemployment_rate = pd.read_csv(f"{BASE_PATH}/data/unemployment.csv")
            self.national_mortgage_rate = pd.read_csv(f"{BASE_PATH}/data/national_mortgage_rate.csv")
        if self.verbose:
            print("Initial size of origination data: ", len(self.origination))
        # States with an unemployment series, i.e. columns named "<STATE>UR".
        self.available_states = [col[:-2] for col in self.unemployment_rate.columns if col.endswith('UR')]
        # Loans missing any of these fields cannot be featurized, so drop them up front.
        self.required_origination_columns = ["loan_id", "origination_date", "fico_score_at_origination", "original_balance", "initial_interest_rate", "original_ltv"]
        self.origination = self.origination.dropna(subset=self.required_origination_columns)
        self.origination = self.origination.reset_index(drop=True)
        if self.verbose:
            print("Size of origination data after dropping missing values in required variables: ", len(self.origination))

        # Sort by loan age so each loan's rows come out of the later groupby in age order.
        self.performance = self.performance.sort_values(by=["loan_age"])
        self.end_year = end_year

        self.database_size = database_size
        self.start_year = start_year
        self.nr_classes = nr_classes
        self.normalize_data = normalize_data
        self.feature_set = feature_set
        self.column_to_normalize_origination = columns_to_normalize_origination
        if self.normalize_data and columns_to_normalize_origination is not None:
            self._standardize_columns(self.origination, columns_to_normalize_origination)
        if self.normalize_data and columns_to_normalize_performance is not None:
            self._standardize_columns(self.performance, columns_to_normalize_performance)
        self.nr_sampling_timesteps = nr_sampling_timesteps
        self.max_to_sample = max_to_sample
        # mba_delinquency_status code -> one-hot class index; the last index
        # (nr_classes - 1) is left for padding.
        self.mapping = {
            "0": 0,
            "3": 1,
            "6": 2,
            "9": 3,
            "C": 4,
            "F": 5,
            "R": 6,
        }
        self.setup()

    @staticmethod
    def _standardize_columns(frame, columns):
        """Z-score each column in *columns* of *frame* in place, ignoring NaNs.

        Object-dtype columns are first coerced to numbers (unparseable entries
        become NaN). A column whose standard deviation is zero or undefined is
        only centred, so it becomes 0 instead of NaN from 0/0.
        """
        for col in columns:
            if col not in frame.columns:
                raise KeyError(f"Column {col!r} listed for normalization is not in the data")
            if frame[col].dtype == 'object':
                frame[col] = pd.to_numeric(frame[col], errors='coerce')
            try:
                col_mean = frame[col].mean(skipna=True)
                col_std = frame[col].std(skipna=True)
            except TypeError as err:
                raise TypeError(f"Column {col!r} cannot be normalized: {err}") from err
            if pd.isna(col_std) or col_std == 0:
                frame[col] = frame[col] - col_mean
            else:
                frame[col] = (frame[col] - col_mean) / col_std
    
    def one_hot_encode(self, feature, nr_classes):
        """Return a (nr_classes, len(feature)) one-hot matrix for integer class codes.

        Column j is the indicator vector of ``feature[j]``.
        """
        indices = np.array(feature)
        lookup = np.eye(nr_classes)
        return lookup[indices].transpose()
    
    def truncate_performance_data(self, performance_data):
        """Return the rows up to and including the first paid-off ("0") or REO ("R") status.

        The input DataFrame is left untouched; the result has a fresh 0..n-1 index.
        Without any "0"/"R" row, all rows are returned.

        Args:
            performance_data: Monthly rows for one loan with an
                ``mba_delinquency_status`` column.

        Returns:
            A new, re-indexed DataFrame.
        """
        # Copy with a positional index instead of resetting the caller's frame in place.
        performance_data = performance_data.reset_index(drop=True)
        is_terminal = performance_data["mba_delinquency_status"].isin(["0", "R"]).to_numpy()
        terminal_positions = np.flatnonzero(is_terminal)
        if terminal_positions.size == 0:
            return performance_data
        # Keep the first terminal row itself (inclusive cut).
        return performance_data.iloc[: terminal_positions[0] + 1].reset_index(drop=True)


    def get_features(self, origination_data, performance, feature_set):
        """Build the per-timestep feature matrix and one-hot status targets for one loan.

        ``performance`` is first truncated at its first paid-off ("0") or REO
        ("R") row, inclusive. Each name in ``feature_set`` then appends one or
        more rows of length ``nr_timepoints``. Rows are stacked in
        ``feature_set`` order and in the per-branch order below. Later code
        (e.g. ``add_zip_code_lagged_foreclosure_prepayment``) indexes rows by
        position, so that order must not change. Unknown names are ignored.

        Args:
            origination_data: Origination record for the loan, indexed by column name.
            performance: Monthly performance rows for the loan.
            feature_set: Iterable of feature names.

        Returns:
            (X, Y): X has shape (nr_features, nr_timepoints). Y is the one-hot
            delinquency status, shape (nr_classes, nr_timepoints).
            X[:, i] is used to predict Y[:, i + 1].

        Raises:
            ValueError: If the feature rows have incompatible lengths.
            AssertionError: If X contains NaN or a value above 100.
        """
        # Truncate at the first paid-off or REO state (see DL for Mortgage Risk paper).
        performance = self.truncate_performance_data(performance)
        count_0 = (performance["mba_delinquency_status"] == "0").sum()
        assert not count_0 > 1, "There are multiple paid-off states in the performance data"

        # Features still to add:
        #   Lagged Prime Default Rate in Zip Code
        #   Lagged Subprime Default Rate in Zip Code

        # Each entry has shape (feature_dim, nr_timepoints); feature_dim varies per feature.
        features = []
        nr_timepoints = len(performance["mba_delinquency_status"])

        def constant_row(value):
            # Repeat a per-loan scalar over every timestep as a (1, nr_timepoints) row.
            return np.expand_dims(np.ones(nr_timepoints) * value, axis=0)

        for feature in feature_set:
            if feature == "current_state":
                status = performance["mba_delinquency_status"].map(self.mapping)
                features.append(self.one_hot_encode(status, self.nr_classes))
            # Static origination values, repeated over time
            elif feature in {"original_balance", "fico_score_at_origination", "initial_interest_rate", "original_ltv"}:
                features.append(constant_row(origination_data[feature]))
            # Monthly macro series, sliced from the origination month onwards
            elif feature in {"unemployment_rate", "national_mortgage_rate"}:
                # origination_date is numeric YYYYMM, possibly a float such as 200001.0
                year_month = str(int(origination_data["origination_date"]))
                year = int(year_month[:4])
                month = int(year_month[4:6])
                # Series row 0 is January of self.start_year
                start = (year - self.start_year) * 12 + month - 1
                stop = start + nr_timepoints
                if feature == "unemployment_rate":
                    state = origination_data["state"]
                    # With normalize_data the bare state column is read, otherwise "<STATE>UR"
                    column = state if self.normalize_data else state + "UR"
                    series = self.unemployment_rate[column]
                else:
                    column = "national_mortgage_rate_normalized" if self.normalize_data else "national_mortgage_rate"
                    series = self.national_mortgage_rate[column]
                features.append(np.expand_dims(series.iloc[start:stop].values, axis=0))
            # Binary origination flags (e.g. prime vs subprime): missing indicator, then value
            elif feature in {"inferred_collateral_type", "convertible_flag", "pool_insurance_flag", "io_flag", "prepay_penalty_flag", "negative_amortization_flag", "buydown_flag"}:
                value = origination_data[feature]
                # Y/P -> 1, N/S -> 0; "U", NaN/None and any unrecognised code -> 0.5.
                # Only "U" raises the missing indicator.
                # NOTE(review): an earlier comment said NaN should raise it as well; that was
                # never implemented and is kept as-is so features match previously built data.
                missing_indicator = int(value == "U")
                if pd.isna(value):
                    encoded = 0.5
                else:
                    encoded = {"Y": 1, "P": 1, "N": 0, "S": 0}.get(value, 0.5)
                features.append(constant_row(missing_indicator))
                features.append(constant_row(encoded))
            # Continuous performance values: missing indicator, then forward-filled value
            elif feature in {"current_interest_rate", "current_balance", "scheduled_monthly_pi", "scheduled_principal", "mba_days_delinquent"}:
                values = performance[feature]
                missing = values.isnull().astype(int)
                if missing.sum() == len(missing):
                    values = np.zeros(len(missing))
                else:
                    # Carry the last observed value forward; leading gaps become 0.
                    # (Replaces the deprecated replace(np.nan) / fillna(method='ffill').)
                    values = values.ffill().fillna(0)
                if feature == "scheduled_monthly_pi":
                    # Clip the values to between -5 and 5
                    values = np.clip(values, -5, 5)
                features.append(np.expand_dims(np.asarray(missing), axis=0))
                features.append(np.expand_dims(np.asarray(values), axis=0))
            elif feature in {"times_30dd", "times_60dd", "times_90dd", "times_current", "times_foreclosure"}:  # in past 12 months
                history = performance["delinquency_history_string"]
                code = {"times_30dd": "3", "times_60dd": "6", "times_90dd": "9", "times_current": "C", "times_foreclosure": "F"}[feature]
                missing = history.isnull().astype(int)
                # Occurrences of the status code, divided by 4 as normalization
                times = history.fillna("0").apply(lambda s, code=code: s.count(code) / 4)
                if feature == "times_30dd":
                    # A single missing-history indicator is emitted, ahead of the 30dd count only
                    features.append(np.expand_dims(np.asarray(missing), axis=0))
                features.append(np.expand_dims(np.asarray(times), axis=0))
            elif feature == "original_term":
                original_term = origination_data["original_term"]
                missing_indicator = int(pd.isna(original_term))
                if missing_indicator:
                    original_term = 360  # fallback term when missing
                less_than_200 = int(original_term < 200)
                features.append(constant_row(missing_indicator))
                features.append(constant_row(less_than_200))
            elif feature == "loan_age":
                loan_age = performance["loan_age"]
                missing = loan_age.isnull().astype(int)
                loan_age = loan_age.fillna(0)
                features.append(np.expand_dims(np.asarray(missing), axis=0))
                # Indicators for loan age below 12, 60 and 120 months
                for threshold in (12, 60, 120):
                    below = (loan_age < threshold).astype(int)
                    features.append(np.expand_dims(np.asarray(below), axis=0))
            elif feature == "zip-code":
                # The ZIP Codes are: 92677, with 109642 loans in Orange County;
                # 93065, with 98673 loans in Simi Valley;
                # 91709, with 95497 loans in Chino Hills;
                # 92336 with 94794 loans in Fontana.
                zip_code = origination_data["property_zip"]
                # NOTE: indicator rows follow this *set's* iteration order, not the order
                # written. The container is kept as a set so that row order (which later
                # code indexes by position) stays exactly as before.
                target_zips = {92677, 93065, 91709, 92336}
                indicators = np.zeros((len(target_zips) + 1, nr_timepoints))
                matched_row = len(target_zips)  # last row: other / missing zip
                if not pd.isna(zip_code):
                    for row, target_zip in enumerate(target_zips):
                        if zip_code == target_zip:
                            matched_row = row
                            break
                indicators[matched_row, :] = 1
                features.append(indicators)

        try:
            X = np.concatenate(features, axis=0)
        except ValueError as err:
            shapes = [np.shape(block) for block in features]
            raise ValueError(f"Feature blocks have incompatible shapes: {shapes}") from err
        nan_count = int(np.isnan(X).sum())
        assert nan_count == 0, f"Feature matrix contains {nan_count} NaN values"
        assert np.max(X) <= 100, (
            "Feature values exceed 100; per-block maxima: "
            f"{[float(np.max(block)) for block in features]}"
        )
        # Targets come from the (truncated) status itself, independent of feature_set.
        delinquency_status = performance["mba_delinquency_status"].map(self.mapping)
        Y = self.one_hot_encode(delinquency_status, self.nr_classes)
        # X[i] is used to predict Y[i+1]
        return X, Y
    
    def get_foreclosure_and_prepayment_rates(self):
        """Return the per-timestep portfolio foreclosure and prepayment rates.

        ``self.db_Y`` is indexed as (class, timestep, loan). At each timestep a
        loan counts as active when any class other than the last one is set.
        Class 5 ("F") marks foreclosure and class 0 ("0") prepayment. Each rate
        is ``count / (active + 1)``.

        Returns:
            (foreclosure_rate, prepayment_rate): 1-D arrays over timesteps.
        """
        labels = self.db_Y
        denominator = np.sum(np.max(labels[:-1, :, :], axis=0), axis=1) + 1
        rates = tuple(
            np.sum(labels[cls, :, :], axis=1) / denominator
            for cls in (5, 0)
        )
        for rate in rates:
            assert np.isnan(rate).sum() == 0
        return rates

    def add_lagged_prepayment_and_foreclosure_rates(self, normalize=True):
        """Append lagged portfolio-wide foreclosure and prepayment rates to ``self.db``.

        Two rows, foreclosure then prepayment, are appended along axis 0.
        For every timestep ``t >= 1`` and every loan with ``self.db_Y[-1, t, i] != 1``
        (last one-hot class not set), the value at ``t`` is the rate measured at
        ``t - 1``. All other cells, including ``t = 0``, are 0.

        Args:
            normalize: If True, z-score each rate series over time before it is
                broadcast. A series with zero variance is only centred, so it
                becomes all zeros instead of NaN from 0/0.
        """
        # NOTE(review): the original author suspected this lag is off by one. The state
        # observed at t is already used to predict t + 1, so the current-period rate
        # may be the intended feature. Kept as-is.
        foreclosure_rate, prepayment_rate = self.get_foreclosure_and_prepayment_rates()

        def standardize(rate):
            centred = rate - np.mean(rate)
            std = np.std(rate)
            return centred / std if std > 0 else centred

        if normalize:
            foreclosure_rate = standardize(foreclosure_rate)
            prepayment_rate = standardize(prepayment_rate)

        nr_timesteps, nr_loans = self.db.shape[1], self.db.shape[2]
        lag_len = max(nr_timesteps - 1, 0)
        foreclosure_feature = np.zeros((1, nr_timesteps, nr_loans))
        prepayment_feature = np.zeros((1, nr_timesteps, nr_loans))
        # active[k, i]: loan i is not flagged in the last class at timestep k + 1.
        active = self.db_Y[-1, 1:nr_timesteps, :nr_loans] != 1
        # Rate at t - 1 broadcast over loans, zeroed where the loan is not active at t.
        foreclosure_feature[0, 1:, :] = np.where(active, foreclosure_rate[:lag_len, np.newaxis], 0.0)
        prepayment_feature[0, 1:, :] = np.where(active, prepayment_rate[:lag_len, np.newaxis], 0.0)
        self.db = np.concatenate((self.db, foreclosure_feature, prepayment_feature), axis=0)
    
    def add_zip_code_lagged_foreclosure_prepayment(self, normalize=True):
        """
        Append zip-code x prime/subprime foreclosure and prepayment rate features to self.db.

        Three rows are concatenated along axis 0 of self.db, in this order:
        
        1. Foreclosure rate for loans in the same target zip code and the same prime/subprime category
        2. Prepayment rate for loans in the same target zip code and the same prime/subprime category
        3. Missing indicator: 1 for an active loan whose row 54 (zip missing/other) or
           row 24 (prime-flag missing) is set, else 0
        
        For each timestep t >= 1, a group's rate is the number of group loans with
        db_Y[5] == 1 (foreclosure) or db_Y[0] == 1 (prepayment) at t, divided by
        (number of active group loans at t + 1). "Active" means db_Y[-1, t, i] != 1.
        Before normalization, timestep 0, inactive loans and loans with no matching
        group hold 0.
        NOTE(review): despite the method name and the "prev_" variable names, the
        counts are taken at the same timestep t, not t-1. Confirm which is intended.
        
        The zip codes are features 50-53 with missing indicator feature 54.
        The prime-flag is feature 25, with missing indicator feature 24.
        These row indices are hard-coded and assume a feature_set whose rows land
        there. TODO confirm against the configured feature_set.
        
        Args:
            normalize (bool): Whether to normalize the rates. Default is True. When
                True, the rates of active loans at t >= 1 (zeros included) are
                z-scored with a mean/std pooled over all such cells. A zero std
                leaves them unchanged; with <= 1 value, mean 0 / std 1 is used.
        """
        # Target zip codes; list position k is paired with feature row 50 + k below.
        # NOTE(review): the codes are only used as group keys here. Confirm that row 50 + k
        # really encodes target_zips[k] where the indicator rows are built.
        target_zips = [92677, 93065, 91709, 92336]
        
        # Shape: (nr_features, nr_timesteps, nr_loans)
        nr_timesteps = self.db.shape[1]
        nr_loans = self.db.shape[2]
        
        # Output rows, each shaped (1, nr_timesteps, nr_loans)
        foreclosure_rates = np.zeros((1, nr_timesteps, nr_loans))
        prepayment_rates = np.zeros((1, nr_timesteps, nr_loans))
        missing_indicator = np.zeros((1, nr_timesteps, nr_loans))
        
        # For each timestep
        for t in range(nr_timesteps):
            # Timestep 0 is skipped and keeps zero rates
            if t == 0:
                continue
                
            # Maps (zip_code, is_prime) -> (foreclosure_rate, prepayment_rate) at timestep t
            rates_by_group = {}
            
            # First pass: calculate rates for each group
            for zip_idx, zip_code in enumerate(target_zips):
                for is_prime in [0, 1]:  # 0 = subprime, 1 = prime
                    # Masks over all loans at timestep t: active, in this zip, of this loan type
                    active_mask = (self.db_Y[-1, t, :] != 1)  # last class (presumably padding/exited) not set
                    zip_mask = (self.db[50 + zip_idx, t, :] == 1)
                    loan_type_mask = (self.db[25, t, :] == is_prime)
                    
                    # Missing indicator masks
                    zip_missing = (self.db[54, t, :] == 1)
                    loan_type_missing = (self.db[24, t, :] == 1)
                    
                    # Combined mask for loans of this zip and type
                    combined_mask = active_mask & zip_mask & loan_type_mask & (~zip_missing) & (~loan_type_missing)
                    
                    # Active loans in this group at timestep t (not t-1, despite the name)
                    prev_active_loans = combined_mask.sum()
                    
                    # Empty group: no entry is stored, so its loans keep a 0 rate
                    if prev_active_loans == 0:
                        continue
                    
                    # Foreclosures / prepayments in this group at timestep t
                    foreclosed_mask = (self.db_Y[5, t, :] == 1)
                    prepaid_mask = (self.db_Y[0, t, :] == 1)
                    
                    prev_foreclosures = (foreclosed_mask & combined_mask).sum()
                    prev_prepayments = (prepaid_mask & combined_mask).sum()
                    
                    # Raw (unnormalized) rates. The group is non-empty here, so the +1 is not needed
                    # to avoid division by zero; it shrinks small-group rates toward 0.
                    foreclosure_rate = prev_foreclosures / (prev_active_loans + 1)
                    prepayment_rate = prev_prepayments / (prev_active_loans + 1)
                    
                    # Store the rates for this group
                    rates_by_group[(zip_code, is_prime)] = (foreclosure_rate, prepayment_rate)
            
            # Second pass: assign rates to each loan based on its zip code and prime/subprime status
            for i in range(nr_loans):
                # Skip if loan is not active
                if self.db_Y[-1, t, i] == 1:
                    continue
                
                # Check if zip code or loan type is missing
                zip_missing = (self.db[54, t, i] == 1)
                loan_type_missing = (self.db[24, t, i] == 1)
                
                if zip_missing or loan_type_missing:
                    # Flag the loan. If row 54 is the zip-code "other/missing" row, loans
                    # outside the target zips are flagged here too.
                    missing_indicator[0, t, i] = 1
                    continue
                
                # Determine the loan's zip code (None if no target-zip row is set)
                loan_zip = None
                for zip_idx, zip_code in enumerate(target_zips):
                    if self.db[50 + zip_idx, t, i] == 1:
                        loan_zip = zip_code
                        break
                
                # NumPy bool; expected to hash/compare like the int keys 0/1 used in the first pass
                is_prime = (self.db[25, t, i] == 1)
                
                # If we have rates for this group, assign them
                if loan_zip and (loan_zip, is_prime) in rates_by_group:
                    foreclosure_rate, prepayment_rate = rates_by_group[(loan_zip, is_prime)]
                    foreclosure_rates[0, t, i] = foreclosure_rate
                    prepayment_rates[0, t, i] = prepayment_rate
        
        if normalize:
            # 1) Gather ALL foreclosure and prepayment values for active loans at t >= 1
            #    (t = 0 is always zero and is excluded).
            fore_values = []
            prep_values = []
            
            nr_timesteps = self.db.shape[1]
            nr_loans = self.db.shape[2]
            
            for t in range(1, nr_timesteps):
                # Active loans are those not flagged in the last class at time t
                active_mask = (self.db_Y[-1, t, :] != 1)
                
                # Extract only the active subset
                fore_t_active = foreclosure_rates[0, t, active_mask]
                prep_t_active = prepayment_rates[0, t, active_mask]
                
                # Extend our global lists
                fore_values.extend(fore_t_active)
                prep_values.extend(prep_t_active)
            
            # 2) Pooled mean/std, falling back to an identity transform with <= 1 value
            fore_values = np.array(fore_values)
            prep_values = np.array(prep_values)
            
            if len(fore_values) > 1:
                mean_fore = fore_values.mean()
                std_fore = fore_values.std()
            else:
                mean_fore, std_fore = 0.0, 1.0  # fallback
            
            if len(prep_values) > 1:
                mean_prep = prep_values.mean()
                std_prep = prep_values.std()
            else:
                mean_prep, std_prep = 0.0, 1.0  # fallback

            # 3) Apply that mean/std to each timestep's active loans
            for t in range(1, nr_timesteps):
                active_mask = (self.db_Y[-1, t, :] != 1)
                
                # Foreclosure
                f = foreclosure_rates[0, t, active_mask]
                # Zeros (loans with no group rate, or flagged missing) are standardized too,
                # so they become -mean/std rather than staying 0.
                # To keep them at 0 instead, one could standardize only the non-zero entries:
                #   non_zero_mask = f != 0
                #   f[non_zero_mask] = (f[non_zero_mask] - mean_fore) / std_fore
                # A zero std leaves the values unchanged.
                if std_fore > 0:
                    f = (f - mean_fore) / std_fore
                
                # Prepayment
                p = prepayment_rates[0, t, active_mask]
                if std_prep > 0:
                    p = (p - mean_prep) / std_prep
                
                # Assign back only to active subset
                foreclosure_rates[0, t, active_mask] = f
                prepayment_rates[0, t, active_mask] = p

        
        # Append the three rows to self.db along the feature axis
        self.db = np.concatenate((self.db, foreclosure_rates, prepayment_rates, missing_indicator), axis=0)
    

    def add_lagged_rates_two_indicators(self):
        """Append four lagged quartile-threshold indicator rows to ``self.db``.

        Using the portfolio-wide rates from ``get_foreclosure_and_prepayment_rates``,
        the new rows are, in order:
            1) foreclosure_rate[t-1] < 25th percentile of foreclosure_rate
            2) foreclosure_rate[t-1] < 75th percentile of foreclosure_rate
            3) prepayment_rate[t-1]  < 25th percentile of prepayment_rate
            4) prepayment_rate[t-1]  < 75th percentile of prepayment_rate
        A cell is 1 only when its condition holds and loan i is not flagged in the
        last one-hot class at t (``db_Y[-1, t, i] != 1``). Timestep 0 is always 0.
        Rows use ``self.db``'s dtype, and ``self.db`` grows from [F, T, L] to [F + 4, T, L].
        """
        fore_rates, prep_rates = self.get_foreclosure_and_prepayment_rates()
        nr_t, nr_l = self.db.shape[1], self.db.shape[2]
        lag_len = max(nr_t - 1, 0)
        # not_exited[k, i]: loan i is still active at timestep k + 1.
        not_exited = self.db_Y[-1, 1:nr_t, :nr_l] != 1

        new_rows = []
        for rates in (fore_rates, prep_rates):
            # Value at t - 1 aligned with timestep t, broadcast over loans.
            lagged = rates[:lag_len, np.newaxis]
            for pct in (25, 75):
                threshold = np.percentile(rates, pct)
                row = np.zeros((1, nr_t, nr_l), dtype=self.db.dtype)
                row[0, 1:, :] = not_exited & (lagged < threshold)
                new_rows.append(row)

        self.db = np.concatenate((self.db, *new_rows), axis=0)
            
    
    def valid_data(self, performance, origination_data):
        """Return True if a loan's records are usable for building the database.

        A loan is rejected when:
          * its performance history is empty or contains "T", "X", "Z" or "S";
          * any required origination field (FICO, balance, rate, LTV, date) is missing;
          * any ``loan_age`` is negative;
          * its state has no unemployment series and cannot be inferred from a
            known zip code;
          * ``check_performance_data`` fails; or
          * it was originated outside [start_year, end_year).

        Side effect: for a few known zip codes, ``origination_data["state"]`` is
        overwritten in place ("CA" or "CO") so later lookups find a series.
        """
        delinquency_status = performance["mba_delinquency_status"]
        if len(delinquency_status) == 0:
            return False
        if delinquency_status.isin(["T", "X", "Z", "S"]).any():
            return False

        # Check that fico_score_at_origination, original_balance, initial_interest_rate,
        # original_ltv and origination_date are not missing (pd.isna also handles None).
        required = (
            "fico_score_at_origination",
            "original_balance",
            "initial_interest_rate",
            "original_ltv",
            "origination_date",
        )
        if any(pd.isna(origination_data[name]) for name in required):
            return False

        # Negative loan ages indicate corrupt performance records.
        if (performance["loan_age"] < 0).any():
            return False

        if origination_data["state"] not in self.available_states:
            zip_code = origination_data["property_zip"]
            if zip_code in {92677, 93065, 91709, 92336}:
                origination_data["state"] = "CA"
            elif zip_code in {80013, 80015}:
                origination_data["state"] = "CO"
            else:
                # Previously dropped into the debugger here, hanging batch runs; reject instead.
                if self.verbose:
                    print(f"Skipping loan: no unemployment series for state {origination_data['state']!r}")
                return False

        origination_date = origination_data["origination_date"]
        if not self.check_performance_data(performance, origination_date):
            return False
        # origination_date is numeric YYYYMM (possibly float); the first 4 digits are the year.
        year = int(str(int(origination_date))[:4])
        return self.start_year <= year < self.end_year
    
    def check_performance_data(self, performance_data, origination_date):
        """Return False for histories that stop too early or contain impossible jumps.

        A history is rejected when
          * its last observed month lies before December of ``end_year - 1``
            (month index ``12 * (end_year - start_year) - 1``, counted from
            January of ``start_year``) while its final status is neither "0"
            (paid off) nor "R" (REO); or
          * it contains a direct jump current -> 60dd ("C6"), current -> 90dd
            ("C9") or 30dd -> 90dd ("39").

        Args:
            performance_data: Monthly rows for one loan; its string
                ``mba_delinquency_status`` column must be non-empty.
            origination_date: Numeric YYYYMM, possibly a float such as 200001.0.
        """
        stamp = str(int(origination_date))  # drop any decimal part
        origination_month = datetime(int(stamp[:4]), int(stamp[4:6]), 1)
        window_start = datetime(self.start_year, 1, 1)
        statuses = performance_data["mba_delinquency_status"]
        # Month index (from January of start_year) of the last performance row.
        last_month_index = (
            (origination_month.year - window_start.year) * 12
            + (origination_month.month - window_start.month)
            + len(performance_data) - 1
        )
        final_status = statuses.iloc[-1]
        # -1 for margin
        window_last_index = 12 * (self.end_year - self.start_year) - 1
        if last_month_index < window_last_index and final_status not in ["0", "R"]:
            return False

        history = "".join(list(statuses))
        return not any(jump in history for jump in ("C6", "C9", "39"))


    def setup(self):
        temp_db = []
        self.diagonal = []
        added_loans = 0
        i = 0
        performance_grouped = self.performance.groupby("loan_id")
        while added_loans < self.database_size and i < len(self.origination):
            loan_id = self.origination.loc[i, "loan_id"]
            try:
                performance_data = performance_grouped.get_group(loan_id)
            except:
                i += 1
                continue
            
            #self.performance[self.performance["loan_id"] == loan_id]
            origination_data = self.origination.loc[i]
            
            if not self.valid_data(performance_data, origination_data):
                i += 1
                continue
            
            
            
            origination_date = origination_data["origination_date"]
            year_month = str(int(origination_date))  # Remove the decimal part
            year = int(year_month[:4])  # First 4 characters represent the year
            month = int(year_month[4:6])  # Next 2 characters represent the month
            # Create a datetime object for the given year and month
            date = datetime(year, month, 1)
            # Create a datetime object for start year
            epoch = datetime(self.start_year, 1, 1)
            # Calculate the number of months since start year
            months_since_epoch = (date.year - epoch.year) * 12 + (date.month - epoch.month)
            if months_since_epoch + len(performance_data) > 12*(self.end_year - self.start_year):
                
                # Truncate performance_data
                steps_to_truncate = len(performance_data)+months_since_epoch - 12*(self.end_year - self.start_year)
                performance_data = performance_data.iloc[:-steps_to_truncate]  

            features, ground_truth = self.get_features(origination_data, performance_data, self.feature_set)

            if added_loans == 0:
                self.nr_features = features.shape[0]
            temp_db.append((months_since_epoch, features, ground_truth))
            self.diagonal.append((months_since_epoch,added_loans))
            i+=1
            added_loans += 1
        
        self.database_size = added_loans
        if self.verbose:
            print("Added loans: ", added_loans)
            print("Loans filtered out:", i - added_loans)
        # sort the database by the number of months since start year
        temp_db = sorted(temp_db, key=lambda x: x[0])
        self.diagonal = sorted(self.diagonal, key=lambda x: x[0])
        self.nr_timesteps = 12*(self.end_year - self.start_year)
        self.db = np.zeros((self.nr_features,  self.nr_timesteps, self.database_size))
        self.db_Y = np.zeros((self.nr_classes, self.nr_timesteps, self.database_size))
        self.db_Y[-1,:,:] = 1
        self.db[7,:,:] = 1
        for i in range(self.database_size):     
                start_time = self.diagonal[i][0]
                end_time = start_time + len(temp_db[i][2][0,:])
                if end_time > self.nr_timesteps:
                    # breakpoint()  # We should cut off the performance data vec rather than reach time point
                    self.db[:,start_time: ,i] = temp_db[i][1][:,:self.nr_timesteps-start_time]
                    self.db_Y[:,start_time: , i] = temp_db[i][2][:,:self.nr_timesteps-start_time]
                else:
                    self.db[:,start_time: end_time ,i] = temp_db[i][1]
                    self.db_Y[:,start_time: end_time, i] = temp_db[i][2]
        
        
        if "lagged_foreclosure_rate" in  self.feature_set and "lagged_prepayment_rate" in self.feature_set:
            #self.add_lagged_prepayment_and_foreclosure_rates(self.normalize_data)
            #self.add_lagged_rates_two_indicators()
            self.add_zip_code_lagged_foreclosure_prepayment(self.normalize_data)

        # Check if self.db contains NaN values
        assert np.isnan(self.db).sum() == 0
        # Check if self.db_Y contains NaN values 
        assert np.isnan(self.db_Y).sum() == 0
        


    def return_data_dict(self):
        """Bundle the built tensors, the diagonal and the dataset settings.

        Returns:
            dict: ``"X"`` -> ``self.db`` (nr_features, nr_timesteps, nr_loans),
            ``"Y"`` -> ``self.db_Y`` (nr_classes, nr_timesteps, nr_loans),
            ``"diagonal"`` -> ``self.diagonal`` (nr_loans entries of two values),
            ``"dataset_config"`` -> year range, class count and feature count.
        """
        config = dict(
            start_year=self.start_year,
            end_year=self.end_year,
            nr_classes=self.nr_classes,
            nr_features=self.nr_features,
        )
        return dict(X=self.db, Y=self.db_Y, diagonal=self.diagonal, dataset_config=config)


class LoanDataLoader(Dataset):

    def __init__(self, data_dict, limits, sampling_limits, config, name, steps_per_epoch):
        super(LoanDataLoader,self).__init__()
        #super(LoanDataLoader, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        #self.X = torch.tensor(data_dict["X"]).half().to(self.device) # shape (nr_features, nr_timesteps, nr_loans)
        #self.Y = torch.tensor(data_dict["Y"]).half().to(self.device) # shape (nr_classes, nr_timesteps, nr_loans)
        self.store_on_cpu = True
        if self.store_on_cpu:
            self.X = torch.tensor(data_dict["X"]).half()  # stored in fp16 on CPU
            self.Y = torch.tensor(data_dict["Y"]).half()  # stored in fp16 on CPU
        else:
            self.X = torch.tensor(data_dict["X"]).half().to(self.device)  # stored in fp16 on GPU
            self.Y = torch.tensor(data_dict["Y"]).half().to(self.device)  # stored in fp16 on GPU
        
        try:
            self.dataset_config_dict = data_dict["dataset_config"].item()
        except:
            self.dataset_config_dict = data_dict["dataset_config"]

        self.nr_features = self.dataset_config_dict["nr_features"]
        self.diagonal = data_dict["diagonal"]
        self.max_to_sample = config["max_to_sample"]
        self.lower_bound = limits[0]
        self.upper_bound = limits[1]
        self.lower_sampling_bound = sampling_limits[0]
        self.upper_sampling_bound = sampling_limits[1]
        self.nr_sampling_timesteps = config["nr_sampling_timesteps"]
        self.nr_loans_to_sample = config["nr_loans_to_sample"]
        self.sample_random_loan_index = config["sample_random_loan_index"]
        self.sample_random_time_index = config["sample_random_time_index"]
        self.use_time_weighting = config["use_time_weighting"]
        self.time_weighting_half_life = config["time_weighting_half_life"]
        self.eval_mode = config["eval_mode"]
        self.eval_seed = config["eval_seed"]
        self.name = name
        self.steps_per_epoch = steps_per_epoch
        self.start_year = self.dataset_config_dict["start_year"]
        self.end_year = self.dataset_config_dict["end_year"]
        self.return_start_time = False
    
    def print_counts_per_year(self):
        """Print a histogram of loan start years over [start_year, end_year).

        Each ``self.diagonal`` entry's first value is a month offset from the
        start of ``start_year``. Integer division by 12 gives the year bucket.
        """
        loans_per_year = [0] * (self.end_year - self.start_year)
        for entry in self.diagonal:
            loans_per_year[entry[0] // 12] += 1
        header = "Count per start year: (first year {}, last year {})".format(self.start_year, self.end_year)
        print(header, loans_per_year)
    
    def feature_visualization(self, nr_loans_to_sample=1000):
        """Plot mean feature trajectories before "C -> Paid Off" vs. "C -> C".

        Per the original notes, class 0 is "paid off" and class 4 is "current".
        Candidate loans are drawn at random (``3 * nr_loans_to_sample`` per
        group). Up to ``nr_loans_to_sample`` windows of ``lookback + 1`` steps
        are collected for each group:

        * windows ending at the first step whose label (argmax of ``self.Y``)
          is class 0, at least ``lookback`` steps after the loan's start;
        * windows ending at a random class-4 step taken before the loan first
          reaches class 0, 6 or 7.

        For a fixed list of feature indices, one PNG per feature is saved under
        ``{BASE_PATH}/scripts/notebooks/data/corelogic/``. Each PNG compares the
        two mean trajectories, reversed so the transition step comes first.

        Args:
            nr_loans_to_sample: Maximum number of windows per group.

        Raises:
            ValueError: If no qualifying window was found for either group.
        """
        # Ynp has shape (nr_classes, nr_timesteps, nr_loans)
        Ynp = self.Y.cpu().numpy()
        nr_timesteps = Ynp.shape[1]
        n = nr_loans_to_sample
        lookback = 10
        nr_loans = len(self.diagonal)
        # np.random.choice(k, ...) samples from range(k). Passing nr_loans keeps
        # the last loan eligible; nr_loans - 1 silently excluded it.
        indices1 = np.random.choice(nr_loans, n * 3, replace=True)
        indices2 = np.random.choice(nr_loans, n * 3, replace=True)

        # Windows ending at the first class-0 step.
        transitions_ending_in_state_0 = []
        for i in indices1:
            if len(transitions_ending_in_state_0) >= n:
                break
            start_time = self.diagonal[i][0]
            curr = start_time
            while curr < nr_timesteps and np.argmax(Ynp[:, curr, i]) != 0:
                curr += 1
            # Leaving the loop with curr < nr_timesteps means step curr is class 0.
            if curr < nr_timesteps and curr - start_time >= lookback:
                transitions_ending_in_state_0.append((i, curr - lookback, curr))

        # Windows ending at a random class-4 step before the loan first hits 0, 6 or 7.
        transitions_ending_in_state_4 = []
        for i in indices2:
            if len(transitions_ending_in_state_4) >= n:
                break
            start_time = self.diagonal[i][0]
            curr = start_time
            while curr < nr_timesteps and np.argmax(Ynp[:, curr, i]) not in (0, 6, 7):
                curr += 1
            if curr - start_time >= lookback + 1:
                end_time = np.random.choice(range(start_time + lookback, curr))
                if np.argmax(Ynp[:, end_time, i]) == 4:
                    transitions_ending_in_state_4.append((i, end_time - lookback, end_time))

        if not transitions_ending_in_state_0 or not transitions_ending_in_state_4:
            raise ValueError(
                "feature_visualization found {} 'C -> Paid Off' and {} 'C -> C' windows; "
                "at least one of each is required.".format(
                    len(transitions_ending_in_state_0), len(transitions_ending_in_state_4)
                )
            )

        def _stack_windows(transitions):
            # Result shape: (nr_windows, nr_features, lookback + 1)
            return np.stack([
                self.X[:, start:end + 1, loan_idx].cpu().numpy()
                for loan_idx, start, end in transitions
            ])

        windows_end_0 = _stack_windows(transitions_ending_in_state_0)
        windows_end_4 = _stack_windows(transitions_ending_in_state_4)
        feature_vec_end_0 = np.mean(windows_end_0, axis=0)
        assert feature_vec_end_0.shape == (self.nr_features, lookback + 1)
        feature_vec_end_4 = np.mean(windows_end_4, axis=0)
        assert feature_vec_end_4.shape == (self.nr_features, lookback + 1)

        index_to_feature_mapping = {
            14: "missing_ind_curr_balance",
            15: "current_balance",
            16: "missing_ind_ir",
            17: "current_interest_rate",
            33: "prepay_penalty_flag",
            32: "missing_ind_prepay_penalty_flag",
            13: "national_mortgage_rate",
            12: "unemployment_rate",
            18: "missing_ind_scheduled_monthly_pi",
            19: "scheduled_monthly_pi",
        }
        normalized_features = [15, 17]
        features = [14, 15, 32, 33, 16, 17, 12, 13]
        for index_feature_to_consider in features:
            feature_name = index_to_feature_mapping[index_feature_to_consider]
            min_feature_value = np.min(windows_end_0[:, index_feature_to_consider, :])

            plt.plot(feature_vec_end_0[index_feature_to_consider, :][::-1], label="C -> Paid Off")
            plt.plot(feature_vec_end_4[index_feature_to_consider, :][::-1], label="C -> C")
            # Horizontal reference line at the minimum value seen in the C -> Paid Off group.
            plt.plot(np.ones(lookback + 1) * min_feature_value, label="Minimal feature value: {}".format(feature_name))
            plt.legend()
            plt.title(f"Average {feature_name} for loans transitioning \n C -> Paid Off and C -> C, in the last plotted period")
            plt.xlabel("Periods before transition")
            normalized = " (normalized)" if index_feature_to_consider in normalized_features else ""
            plt.ylabel("{}{}".format(feature_name, normalized))
            plt.savefig("{}/scripts/notebooks/data/corelogic/feature_visualization_{}.png".format(BASE_PATH, feature_name))
            plt.close()


    
    def print_full_dataset_statistics(self, input=None):
        """Print per-class time-point counts for ``input`` or for ``self.Y``.

        Args:
            input: Optional label tensor of shape (nr_classes, nr_timesteps,
                nr_loans). When omitted, the full ``self.Y`` is used. The
                per-year loan counts are printed first, followed by the
                sequence-length, loan-count and REO-fraction summaries.
        """
        use_full_dataset = input is None
        if use_full_dataset:
            self.print_counts_per_year()
            labels = self.Y.cpu().numpy()
        else:
            labels = input.cpu().numpy()

        class_names = (
            "Paid Off",
            "# 30 days late",
            "# 60 days late",
            "# 90 days late",
            "Current",
            "Foreclosed",
            "REO",
            "End State after REO / Paid off",
        )
        # Number of (time step, loan) cells labelled with each class.
        class_counts = [np.sum(labels[c, :, :] == 1) for c in range(8)]
        print(dict(zip(class_names, class_counts)))
        ever_foreclosed = np.sum(np.max(labels[5, :, :], axis=0))
        print("Number of Loans in the Foreclosed state at least once:", ever_foreclosed)

        if not use_full_dataset:
            return
        nr_loans = len(self.diagonal)
        print("Average sequence length: ", np.sum(class_counts[:-1]) / nr_loans)
        print("Total number of loans: ", nr_loans)
        print("Fraction of Loans REO: ", class_counts[6] / nr_loans)
    
    def get_foreclosure_and_prepayment_rates(self):
        """Return per-time-step foreclosure and prepayment rates.

        At each time step, the number of loans labelled class 5 ("Foreclosed")
        and class 0 ("Paid Off") is divided by ``active + 1``. Here ``active``
        counts loans with any class other than the last one set. The ``+ 1``
        keeps steps without active loans from dividing by zero.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(foreclosure_rate, prepayment_rate)``,
            each of length nr_timesteps.
        """
        labels = self.Y.cpu().numpy()  # (nr_classes, nr_timesteps, nr_loans)
        denominator = labels[:-1].max(axis=0).sum(axis=1) + 1
        foreclosure_rate = labels[5].sum(axis=1) / denominator
        prepayment_rate = labels[0].sum(axis=1) / denominator
        assert not np.isnan(foreclosure_rate).any()
        assert not np.isnan(prepayment_rate).any()
        return foreclosure_rate, prepayment_rate
    
    
    def get_nr_active_loans_plot(self, name=""):
        """Plot active-loan counts and foreclosure/prepayment rates over time.

        A loan counts as active at a step when any class other than the last
        one is set in ``self.Y``. The figure is saved as
        ``active_loans_per_timepoint_{name}.pdf`` under
        ``{BASE_PATH}/scripts/notebooks/data/corelogic/``. The figure is then
        closed so repeated calls do not accumulate open figures.

        Args:
            name: Suffix appended to the output filename.
        """
        # Ynp has shape (nr_classes, nr_timesteps, nr_loans)
        Ynp = self.Y.cpu().numpy()
        months_since_1988 = list(range(0, Ynp.shape[1]))

        nr_active_loans = np.sum(np.max(Ynp[:-1, :, :], axis=0), axis=1)
        foreclosure_rate, prepayment_rate = self.get_foreclosure_and_prepayment_rates()

        # NOTE(review): month 0 is mapped to 1987-12-01 while the axis starts at
        # 1988-01-01. Confirm whether this one-month offset is intended.
        start_date = pd.Timestamp("1987-12-01")
        end_date = pd.Timestamp("2023-12-01")  # Restrict to 2023
        dates = [start_date + pd.DateOffset(months=month) for month in months_since_1988]

        # Keep only dates within [start_date, end_date].
        valid_indices = [i for i, date in enumerate(dates) if start_date <= date <= end_date]
        dates = [dates[i] for i in valid_indices]
        nr_active_loans = nr_active_loans[valid_indices]
        foreclosure_rate = foreclosure_rate[valid_indices]
        prepayment_rate = prepayment_rate[valid_indices]

        plt.rcParams.update({'font.size': 14})
        fig, ax1 = plt.subplots(figsize=(12, 6))

        # Active loans on the primary y-axis. The raw string keeps the LaTeX
        # escape "\#" without an invalid-escape warning.
        ax1.plot(dates, nr_active_loans, label=r"\# Active Loans", color="blue")
        ax1.set_xlabel("Year", fontsize=16)
        ax1.set_ylabel("Number of Active Loans", fontsize=16, color="black")
        ax1.tick_params(axis='y', labelcolor="black")
        ax1.tick_params(axis='both', which='major', labelsize=14)
        plt.xticks(rotation=45)

        ax1.xaxis.set_major_locator(mdates.YearLocator())
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        ax1.set_xlim(pd.Timestamp("1988-01-01"), pd.Timestamp("2023-12-31"))

        # Rates on a secondary y-axis.
        ax2 = ax1.twinx()
        ax2.plot(dates, foreclosure_rate, label="Foreclosure Rate", color="red")
        ax2.plot(dates, prepayment_rate, label="Prepayment Rate", color="green")
        ax2.set_ylabel("Rates", fontsize=16, color="black")
        ax2.tick_params(axis='y', labelcolor="black")

        # Single legend covering both axes.
        lines_1, labels_1 = ax1.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax2.legend(lines_1 + lines_2, labels_1 + labels_2, loc="upper left", fontsize=14)

        plt.title("Active Loans (Total {}), Foreclosure Rate, and Prepayment Rate Over Time".format(len(self.diagonal)), fontsize=18)
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.tight_layout()

        plt.savefig("{}/scripts/notebooks/data/corelogic/active_loans_per_timepoint_{}.pdf".format(BASE_PATH, name), format="pdf", dpi=300)
        plt.close(fig)
            
    
    def get_full_dataset_empirical_transition_counts(self, within_sampling_bounds, use_sampling, deterministic_sampling):
        """Count observed one-step class transitions (argmax of the labels).

        Args:
            within_sampling_bounds: If False, use all of ``self.Y`` and ignore
                the sampling flags. If True, restrict to this split.
            use_sampling: With ``within_sampling_bounds``, build the labels from
                ``__getitem__`` samples instead of the time-sliced ``self.Y``.
            deterministic_sampling: With sampling, request one loan per item
                with random loan/time selection disabled.

        Returns:
            np.ndarray: (nr_classes, nr_classes) matrix. Entry ``[a, b]`` counts
            steps where a loan moves from class ``a`` to class ``b``.
        """
        if not within_sampling_bounds:
            Ynp = self.Y.cpu().numpy()
        elif not use_sampling:
            Ynp = self.Y.cpu().numpy()
            time_upper_bound = self.diagonal[self.upper_bound][0] + self.nr_sampling_timesteps
            time_lower_bound = self.diagonal[self.lower_bound][0]
            Ynp = Ynp[:, time_lower_bound:time_upper_bound, :]
        else:
            # Collect batches and concatenate once, avoiding quadratic
            # concatenation inside the loop.
            batches = []
            if not deterministic_sampling:
                for i in range(self.__len__()):
                    # Y has shape (nr_loans_to_sample, nr_timesteps, nr_classes)
                    X, Y, valid_indices = self.__getitem__(i)
                    Y = Y.cpu().numpy()
                    X, Y = self.make_valid(X, Y, valid_indices[0], valid_indices[1])
                    batches.append(Y)
            else:
                # NOTE(review): the original marked this item count as "not legit",
                # and make_valid is not applied here. Confirm intent.
                for i in range(self.upper_bound - self.lower_bound):
                    X, Y, valid_indices = self.__getitem__(i, sample_random_loan_index=False, sample_random_time_index=False, nr_loans_to_sample=1)
                    batches.append(Y.cpu().numpy())
            # Sampled batches are (nr_loans, nr_timesteps, nr_classes); transpose
            # to (nr_classes, nr_timesteps, nr_loans). This only happens here,
            # never for the un-sampled views above.
            Ynp = np.transpose(np.concatenate(batches, axis=0), (2, 1, 0))

        try:
            nr_classes = self.dataset_config["nr_classes"]
        except (AttributeError, KeyError, TypeError):
            nr_classes = self.dataset_config_dict["nr_classes"]
        empirical_transition_counts = np.zeros((nr_classes, nr_classes))

        size = len(self.diagonal) if not within_sampling_bounds else Ynp.shape[2]
        states = np.argmax(Ynp, axis=0)[:, :size]  # (nr_timesteps, size)
        # Unbuffered accumulation so repeated (from, to) pairs are all counted.
        np.add.at(empirical_transition_counts, (states[:-1].ravel(), states[1:].ravel()), 1)
        return empirical_transition_counts
    
    def custom_round(self, x):
        """Round ``x`` for heatmap annotations.

        * ``0`` returns the integer ``0``.
        * Values with ``|x| >= 1`` are rounded to one decimal place. Negative
          values are treated symmetrically with positive ones.
        * Other values are rounded to their first significant digit
          (``0.0456 -> 0.05``).
        * Non-finite values (NaN, +/-inf, e.g. from normalising an all-zero
          row) are returned unchanged instead of making ``int()`` raise.
        """
        if x == 0:
            return 0  # Explicitly return 0 if x is 0
        if not np.isfinite(x):
            return x
        if abs(x) >= 1:
            return round(x, 1)
        # -floor(log10|x|) decimals keeps exactly the first non-zero digit.
        return round(x, -int(np.floor(np.log10(abs(x)))))
    
    def save_empirical_transition_counts_image(
            self,
            path=None,
            include_exit=True,
            counts=True,
            within_sampling_bounds=False,
            use_sampling=False,
            deterministic_sampling=False,
            addition_to_path=""
            ):
        """Save the empirical transition matrix as a log-scaled heatmap PDF.

        Rows are initial states (F, 90dd, 60dd, 30dd, Current[, Exit]); the
        Paid Off and REO rows are dropped. Columns are end states (Current,
        30dd, 60dd, 90dd, F, REO, Paid Off[, Exit]).

        Args:
            path: Output PDF path. Defaults to
                ``{BASE_PATH}/scripts/notebooks/data/corelogic/auc_matrix_{suffix}.pdf``,
                where the suffix encodes the flags below. An explicit path is
                now honoured; previously it was ignored.
            include_exit: Keep the final (exit) class row/column.
            counts: Show raw counts. If False, each row is normalised to
                percentages. Rows without observed transitions become 0
                instead of NaN.
            within_sampling_bounds, use_sampling, deterministic_sampling:
                Forwarded to ``get_full_dataset_empirical_transition_counts``.
            addition_to_path: Extra suffix appended to the default filename.
        """
        text = "_exit" if include_exit else ""
        text += "_percentages" if not counts else ""
        text += "_" + self.name if within_sampling_bounds else ""
        text += "_sampled" if use_sampling and within_sampling_bounds else ""
        text += "_deterministic" if deterministic_sampling and use_sampling and within_sampling_bounds else ""
        text += addition_to_path
        base_path = f"{BASE_PATH}/scripts/notebooks/data/corelogic/"
        if path is None:
            path = f"{base_path}auc_matrix_{text}.pdf"

        empirical_transition_counts = self.get_full_dataset_empirical_transition_counts(within_sampling_bounds, use_sampling, deterministic_sampling)
        if not include_exit:
            empirical_transition_counts = empirical_transition_counts[:-1, :-1]
        sns.set(rc={
            "font.family": "serif",
            "font.serif": ["Computer Modern"],
            "text.usetex": True,
            "axes.grid": False
        })
        if include_exit:
            reorder_x = [4, 1, 2, 3, 5, 6, 0, 7]
            reorder_y = [5, 3, 2, 1, 4, 0, 6, 7]
        else:
            reorder_x = [4, 1, 2, 3, 5, 6, 0]
            reorder_y = [5, 3, 2, 1, 4, 0, 6]

        matrix = empirical_transition_counts[reorder_y, :][:, reorder_x]
        # Positions 5 and 6 of reorder_y are classes 0 (Paid Off) and 6 (REO).
        matrix = np.delete(matrix, [5, 6], axis=0)

        if not counts:
            row_sums = np.sum(matrix, axis=1, keepdims=True)
            # Rows with no observed transitions stay 0 instead of 0/0 = NaN.
            matrix = np.divide(matrix, row_sums, out=np.zeros_like(matrix, dtype=float), where=row_sums > 0)
            vectorized_round = np.vectorize(self.custom_round)
            matrix = vectorized_round(matrix * 100)
        labels_y = ["F", "90dd", "60dd", "30dd", "Current", "Exit"]
        labels_x = ["Current", "30dd", "60dd", "90dd", "F", "REO", "Paid Off", "Exit"]
        if not include_exit:
            labels_y = labels_y[:-1]
            labels_x = labels_x[:-1]

        def save_heatmap(data, title, filename, fmt_type, log_scale=False):
            plt.figure(figsize=(10, 7))
            sns.set_context("notebook")
            sns.set_style("white")

            norm = None
            positive = data[data > 0]
            # LogNorm needs at least one positive value; otherwise use a linear scale.
            if log_scale and positive.size > 0:
                norm = mcolors.LogNorm(vmin=max(1, np.min(positive)), vmax=max(np.max(data), 1.05))
            cmap = LinearSegmentedColormap.from_list("custom_red", ["#f4c2c2", "#8B0000"])
            sns.heatmap(
                data, annot=True, fmt=fmt_type, cmap=cmap,
                xticklabels=labels_x, yticklabels=labels_y,
                annot_kws={"size": 14},
                linewidths=0.3, linecolor="white",
                norm=norm
            )

            plt.xticks(fontsize=14, rotation=45)
            plt.yticks(fontsize=14, rotation=0)
            plt.xlabel("End State", fontsize=16)
            plt.ylabel("Initial State", fontsize=16)
            plt.title(title, fontsize=18, pad=15)

            plt.tight_layout()
            plt.savefig(filename, dpi=300, bbox_inches="tight", format="pdf")
            plt.close()

        save_heatmap(matrix, "Empirical Transition Probabilities", path, fmt_type=".2f", log_scale=True)


    def print_dataset_statistics(self):
        """Print this split's year range and index bounds, then its class counts.

        The time window runs from the start month of loan ``lower_bound`` to
        the start month of loan ``upper_bound`` plus ``nr_sampling_timesteps``.
        The labels in that window are passed to
        ``print_full_dataset_statistics``.
        """
        labels = self.Y.cpu().numpy()  # (nr_classes, nr_timesteps, nr_loans)
        last_step = self.diagonal[self.upper_bound][0] + self.nr_sampling_timesteps
        first_step = self.diagonal[self.lower_bound][0]
        window = labels[:, first_step:last_step, :]

        print("Dataset statistics for ", self.name)
        print("Start Year: ", self.start_year + first_step // 12)
        print("End Year: ", min(self.start_year + last_step // 12, self.end_year))
        print("start_idx: ", self.lower_bound)
        print("end_idx: ", self.upper_bound)
        self.print_full_dataset_statistics(input=torch.tensor(window).float().to(self.device))
    
    
    def make_valid(self, X, Y, start_time, end_time):
        """Overwrite steps outside ``[start_time, end_time)`` with the padding encoding.

        ``X`` is indexed as (nr_features, nr_loans, nr_timesteps) and ``Y`` as
        (nr_loans, nr_timesteps, nr_classes). Outside the window, every
        feature/class is zeroed, then feature 7 and class 7 are set to 1.
        Both arrays are modified in place and also returned.
        """
        steps = np.arange(X.shape[2])
        outside = np.logical_or(steps < start_time, steps >= end_time)

        X[:, :, outside] = 0
        X[7, :, outside] = 1

        Y[:, outside, :] = 0
        Y[:, outside, 7] = 1
        return X, Y
    
    def make_valid1(self, X, Y, start_time, end_time):
        """Loop-based equivalent of ``make_valid`` (in-place padding outside the window).

        For every loan column ``b`` of ``X`` and every step outside
        ``[start_time, end_time)``, zero ``X[:, b, t]`` and ``Y[b, t, :]``, then
        set ``X[7, b, t]`` and ``Y[b, t, 7]`` to 1. Returns ``(X, Y)``.
        """
        padded_steps = [
            step for step in range(X.shape[2])
            if step < start_time or step >= end_time
        ]
        for loan in range(X.shape[1]):
            for step in padded_steps:
                X[:, loan, step] = 0
                Y[loan, step, :] = 0
                X[7, loan, step] = 1
                Y[loan, step, 7] = 1
        return X, Y

    def get_logistic_regression_data(self):
        """Flatten ``__len__()`` samples into a tabular (features, next-class) dataset.

        For each sample, features at step ``t`` are paired with the label at
        step ``t + 1``: the last feature step and the first label step are
        dropped. ``make_valid`` is then applied with the sample's
        ``valid_indices``.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``X`` of shape
            (nr_loans * nr_steps, nr_features) and ``y`` of shape
            (nr_loans * nr_steps,), holding argmax class indices.
        """
        feature_chunks = []
        label_chunks = []
        for sample_idx in range(self.__len__()):
            X, Y, valid_indices = self.__getitem__(sample_idx)
            X = X[:, :, :-1]
            Y = Y[:, 1:, :]
            X, Y = self.make_valid(X, Y, valid_indices[0], valid_indices[1])
            feature_chunks.append(X.cpu().numpy())  # (nr_features, nr_loans_in_sample, nr_steps)
            label_chunks.append(Y.cpu().numpy())    # (nr_loans_in_sample, nr_steps, nr_classes)

        stacked_X = np.concatenate(feature_chunks, axis=1)
        stacked_Y = np.concatenate(label_chunks, axis=0)
        # (nr_features, nr_loans, nr_steps) -> (nr_loans, nr_steps, nr_features)
        stacked_X = stacked_X.transpose(1, 2, 0)
        X_samples = stacked_X.reshape(-1, stacked_X.shape[2])
        Y_samples = stacked_Y.reshape(-1, stacked_Y.shape[2]).argmax(axis=1)
        return X_samples, Y_samples
    
    
    def __len__(self):
        #return self.upper_bound - self.lower_bound
        return max(self.steps_per_epoch,1)
    
    
    def __getitem__(
        self, 
        index,
        sample_random_time_index= None,
        nr_loans_to_sample = None,
        sample_random_loan_index = None
    ):
        """
        Unified __getitem__ that can either sample a random time index or use the
        passed-in 'index'. Also can randomly sample loans or pick the 'top'
        (last) loans, but only from those that are 'active' in the first time-step.

        Returns
        -------
        X : torch.Tensor
            Shape [nr_features, nr_loans_to_sample, nr_timesteps].
        Y : torch.Tensor
            Shape [nr_loans_to_sample, nr_timesteps, nr_classes].
        """
        # Set the seed based on the index
        # This is a problem, same samples are seen each epoch
        # Change this to be active only for eval model
        if (self.name in ["val", "test"]) or self.eval_mode:
            
            old_state = np.random.get_state()
            np.random.seed(self.eval_seed + index)


        if nr_loans_to_sample is None:
            nr_loans_to_sample = self.nr_loans_to_sample
        if sample_random_time_index is None:
            sample_random_time_index = self.sample_random_time_index
        if sample_random_loan_index is None:
            sample_random_loan_index = self.sample_random_loan_index
        
        # 1) Decide the "time index" based on sample_random_time_index
        if sample_random_time_index:
            if self.use_time_weighting and self.name == "train":
                # 1. Determine the valid range of indices
                valid_indices = np.arange(self.lower_sampling_bound, self.upper_sampling_bound + 1)
                # 2. Get the corresponding times in months
                times = self.diagonal[self.lower_sampling_bound : self.upper_sampling_bound + 1, 0]
                
                # 3. Set up exponential weighting with a 24-month half-life
                #    The formula below ensures that if a time is 24 months older, 
                #    its weight is exactly 1/2 that of the newest time.
                scale = np.log(2) / self.time_weighting_half_life  # negative ensures older times get smaller weights
                
                # The "newest" (largest) time in our slice
                t_max = times[-1]
                
                # Compute weights = exp(scale * (t - t_max))
                # - For the newest time, (t - t_max) = 0 => weight = exp(0) = 1
                # - For time 24 months older, weight = exp(scale * -24) = exp(-ln(2)) = 0.5
                weights = np.exp(scale * (times - t_max))
                # Normalize
                weights /= weights.sum()
                
                # 4. Sample according to the computed weights
                chosen_local_idx = np.random.choice(len(valid_indices), p=weights)
                time_index = valid_indices[chosen_local_idx]
            else:
                # Plain uniform random sampling if not using the time weighting
                time_index = np.random.randint(self.lower_sampling_bound, self.upper_sampling_bound + 1)
        else:
            time_index = index + self.lower_sampling_bound
        
        # Basic checks
        assert time_index >= self.lower_sampling_bound, "Time index < lower bound."
        assert time_index <= self.upper_sampling_bound, "Time index > upper bound."

        # 2) Determine the start_time from your 'diagonal' array
        #    (Matches your old logic, but in a single function)
        start_time = self.diagonal[time_index][0]

        # 3) Fetch the slice of data from [start_time : start_time+nr_sampling_timesteps]
        #    and from loans [0..time_index]. Adjust if you have different indexing logic.
        #    Shapes here: 
        #        fetched_loans:   [nr_features, nr_timesteps, time_index]
        #        fetched_loans_Y: [nr_classes,  nr_timesteps, time_index]
        fetched_loans = self.X[:, start_time : start_time + self.nr_sampling_timesteps, :time_index]
        fetched_loans_Y = self.Y[:, start_time : start_time + self.nr_sampling_timesteps, :time_index]

        # 4) Identify the set of 'active' loans at the first sampled time-step
        #    We say a loan is active if it is NOT in the exit state (the last class)
        #    for that first time-step. For example, the first time-step is index=0
        #    => check fetched_loans_Y[-1, 0, loan_idx] != 1
        assert fetched_loans.shape[2] != 0

        is_active_mask = (fetched_loans_Y[-1, 0, :] != 1)
        active_indices = np.flatnonzero(is_active_mask.cpu().numpy())

        assert len(active_indices) != 0

        # 5) Among these active loans, choose the final subset of size nr_loans_to_sample
        nr_loans_to_sample = min(nr_loans_to_sample, len(active_indices)) # This allows us to sample max available loans
        #assert nr_loans_to_sample <= len(active_indices)

        if sample_random_loan_index:
            # Randomly choose from active loans
            chosen = np.random.choice(active_indices, size=nr_loans_to_sample, replace=True)
        else:
            # 'Top' means "the last nr_loans_to_sample" columns from active_indices
            # If you meant something else by "top," adjust here.
            chosen = active_indices[-nr_loans_to_sample:]
            # Another way, pick the first nr_loans_to_sample from all loans (not just active)
            chosen = np.arange(min(nr_loans_to_sample, fetched_loans.shape[2])) # TEMP may 7

        # 6) Slice out the chosen loans
        final_loans_X = fetched_loans[:, :, chosen]   # shape: [nr_features, nr_timesteps, nr_loans_chosen]
        final_loans_Y = fetched_loans_Y[:, :, chosen] # shape: [nr_classes,  nr_timesteps, nr_loans_chosen]

        # 7) Permute to match your desired output shape
        #    X => [nr_features, nr_loans, nr_timesteps]
        #    Y => [nr_loans, nr_timesteps, nr_classes]
        X = torch.permute(final_loans_X.float().to(self.device), (0, 2, 1))  # convert to fp32
        Y = torch.permute(final_loans_Y.float().to(self.device), (2, 1, 0))  # convert to fp32

        #return X, Y
        # --------------------------------------------------
        # 6) Figure out which time-steps in [0..nr_sampling_timesteps-1] 
        #    actually fall within [lower_bound_month, upper_bound_month].
        #
        #    local i => global month = start_month + i
        #    We want all i s.t. lower_bound_month <= start_month + i <= upper_bound_month
        # --------------------------------------------------

        start_month = self.diagonal[time_index][0]
        lower_bound_month = self.diagonal[self.lower_bound][0]
        upper_bound_month = self.diagonal[self.upper_bound][0]
        valid_indices = [
            i for i in range(self.nr_sampling_timesteps)
            if (lower_bound_month <= (start_month + i) <= upper_bound_month)
        ]
        if len(valid_indices) == 0:
            breakpoint()
        # e.g. if valid_indices = [2, 3, 4, 5], then valid_time_span = (2, 6)
        # meaning timesteps 2..5 are "valid" in X and Y.
        valid_time_span = (valid_indices[0], valid_indices[-1] + 1)
        # --------------------------------------------------
        # 7) Return X, Y, plus the valid_time_span
        # --------------------------------------------------
        if (self.name in ["val", "test"]) or self.eval_mode:
            np.random.set_state(old_state)
        if self.return_start_time:
            return X, Y, valid_time_span, start_month
        else:
            return X, Y, valid_time_span

class LoanDataset(SequenceDataset):
    """CoreLogic loan dataset split into train/val/test by calendar month.

    Configuration is read from instance attributes (see ``_load_config``).
    ``setup`` loads cached data from ``data_path`` or builds it with
    ``CreateLoanData``, optionally writes it back to ``data_path``, computes
    the split boundaries (``_split_data``) and exposes one ``LoanDataLoader``
    per split as ``dataset_train`` / ``dataset_val`` / ``dataset_test``.
    """

    _name_ = "corelogic_loan_dataset"

    def setup(self):
        """Load or build the data, optionally save it, and create the splits.

        Populates ``config``, ``data_dict``, ``diagonal``, ``dataset_size``,
        ``nr_sampling_timesteps``, the split boundaries (``limits_*`` and
        ``sampling_*``) and ``dataset_train`` / ``dataset_val`` /
        ``dataset_test``.
        """
        # NOTE(review): presumably registers the extra per-item value (the
        # valid time span returned by LoanDataLoader) with the base class'
        # collate logic -- confirm against SequenceDataset.
        if len(self._collate_arg_names) == 0:
            self._collate_arg_names.append("valid_indices")
        self.config = self._load_config()

        status, data_dict = self._load_saved_data()
        if status:
            self.data_dict = data_dict
        else:
            # No usable cache: build the arrays from the raw CSV files.
            self.loan_fetcher = CreateLoanData(**self.config["dataset_config"])
            self.data_dict = self.loan_fetcher.return_data_dict()
        self.diagonal = self.data_dict["diagonal"]
        self.dataset_size = len(self.diagonal)
        self.nr_sampling_timesteps = self.config["nr_sampling_timesteps"]

        # Save data efficiently if enabled.
        if self.config["save_data"]:
            use_fp16 = self.config.get("use_fp16", True)  # Default to True
            print("use_fp16", use_fp16)
            use_compression = self.config.get("use_compression", False)  # Default to False

            # Optionally down-cast floating-point arrays to float16 to shrink
            # the file; arrays of every other dtype are written unchanged.
            optimized_data = {}
            for key, array in self.data_dict.items():
                if use_fp16 and np.issubdtype(array.dtype, np.floating):
                    optimized_data[key] = array.astype(np.float16)
                else:
                    optimized_data[key] = array

            if use_compression:
                print("Saving with compression")
                np.savez_compressed(self.config["data_path"], **optimized_data)
            else:
                print("Saving without compression")
                np.savez(self.config["data_path"], **optimized_data)

            print(f"Saved data to {self.config['data_path']} (fp16={use_fp16}, compression={use_compression})")
        self._split_data()

        self.dataset_train, self.dataset_val, self.dataset_test = self.get_data()

    def _load_config(self):
        """Collect the dataset hyper-parameters set on ``self`` into a dict.

        These attributes are required and raise ``AttributeError`` when missing:
        ``dataset_config``, ``val_split``, ``test_split``, ``load_data``,
        ``save_data``, ``data_path``, ``max_to_sample``,
        ``nr_sampling_timesteps``, ``nr_loans_to_sample``, ``steps_per_epoch``,
        ``sample_random_time_index``, ``sample_random_loan_index``.

        These attributes are optional, with the defaults shown:
        ``val_split_date`` ("2010-01"), ``test_split_date`` ("2011-01"),
        ``eval_mode`` (False), ``eval_seed`` (1000),
        ``use_time_weighting`` (True), ``time_weighting_half_life`` (48),
        ``use_fp16`` (True), ``use_compression`` (False).

        Side effect: unless all three of ``rolling_model``,
        ``rolling_start_epoch`` and ``rolling_epoch_interval`` are set, all
        three are reset to ``False`` / 0 / 0.

        Returns:
            dict: the configuration consumed by ``setup`` / ``get_data``.
        """
        # A missing split date falls back to the default instead of halting
        # the run in the debugger.
        if not (hasattr(self, "val_split_date") and hasattr(self, "test_split_date")):
            print("No val_split_date/test_split_date found in config; using defaults where missing")
        val_split_date = getattr(self, "val_split_date", "2010-01")
        test_split_date = getattr(self, "test_split_date", "2011-01")
        eval_mode = getattr(self, "eval_mode", False)
        eval_seed = getattr(self, "eval_seed", 1000)

        # The rolling-model settings are only honoured as a complete triple.
        rolling_attrs = ("rolling_model", "rolling_start_epoch", "rolling_epoch_interval")
        if not all(hasattr(self, attr) for attr in rolling_attrs):
            self.rolling_model = False
            self.rolling_start_epoch = 0
            self.rolling_epoch_interval = 0

        if hasattr(self, "use_time_weighting"):
            use_time_weighting = self.use_time_weighting
            # The half-life may be omitted even when weighting is configured.
            time_weighting_half_life = getattr(self, "time_weighting_half_life", 48)
        else:
            print("No use_time_weighting or time_weighting_half_life found in config")
            use_time_weighting = True
            time_weighting_half_life = 48

        config = {
            "_name_": "corelogic_loan_dataset",
            "dataset_config": self.dataset_config,
            "val_split": self.val_split,
            "test_split": self.test_split,
            "load_data": self.load_data,
            "save_data": self.save_data,
            "data_path": self.data_path,
            "max_to_sample": self.max_to_sample,
            "nr_sampling_timesteps": self.nr_sampling_timesteps,
            "nr_loans_to_sample": self.nr_loans_to_sample,
            "steps_per_epoch": self.steps_per_epoch,
            "sample_random_time_index": self.sample_random_time_index,
            "sample_random_loan_index": self.sample_random_loan_index,
            "val_split_date": val_split_date,
            "test_split_date": test_split_date,
            "eval_mode": eval_mode,
            "eval_seed": eval_seed,
            "use_time_weighting": use_time_weighting,
            "time_weighting_half_life": time_weighting_half_life,
            # Previously unreachable save options; defaults match setup().
            "use_fp16": getattr(self, "use_fp16", True),
            "use_compression": getattr(self, "use_compression", False),
        }
        return config

    def init(self):
        """No extra initialisation; all work happens in ``setup``."""
        pass

    def _split_data(self):
        """Split the ``diagonal`` timeline into train/val/test index ranges.

        Times are integer months since 1988-01. The earliest usable time is
        ``self.diagonal[max_to_sample][0]`` (leaves warm-up data); the last is
        ``self.data_dict["X"].shape[1]``.

        Strict regions (times, inclusive):
        - Train: earliest_time .. val_split_time - 1
        - Val:   val_split_time .. test_split_time - 1
        - Test:  test_split_time .. last_time

        Sampling regions (times) from which window start points are drawn:
        - Train: earliest_time .. train_u_time - nr_sampling_timesteps // 2
        - Val:   val_l_time - 3 * nr_sampling_timesteps // 4 .. val_u_time - 10
        - Test:  test_l_time - 3 * nr_sampling_timesteps // 4
                 .. last_time - nr_sampling_timesteps

        Every time is mapped to the first ``diagonal`` index whose time is
        >= it (or to the last index if none is). The method sets
        ``limits_train/val/test`` and ``sampling_train/val/test`` as
        ``(lower_idx, upper_idx)`` tuples and prints them.

        Raises:
            AssertionError: if a region is empty or two strict regions overlap.
        """
        def date_to_months_since_1988(date_str: str) -> int:
            """Converts 'YYYY-MM' to integer months since 1988-01."""
            year, month = map(int, date_str.split('-'))
            return (year - 1988) * 12 + (month - 1)

        def find_index_for_time(target_time: int) -> int:
            """Return the first index in self.diagonal whose time >= target_time.

            If all times are < target_time, return the last index so the result
            is always a valid index.
            """
            for i, (t, _) in enumerate(self.diagonal):
                if t >= target_time:
                    return i
            return len(self.diagonal) - 1

        # -----------------------------------------------------------------
        # 1) Earliest and last times from the data
        # -----------------------------------------------------------------
        earliest_time = self.diagonal[self.config["max_to_sample"]][0]
        # NOTE(review): this is the size of X along axis 1 (presumably the
        # number of months), i.e. one past the last month index. That is
        # harmless here because find_index_for_time clamps to the last index.
        last_time = self.data_dict["X"].shape[1]

        # Convert the configured split dates to integer month indices.
        val_split_time = date_to_months_since_1988(self.config["val_split_date"])
        test_split_time = date_to_months_since_1988(self.config["test_split_date"])

        # -----------------------------------------------------------------
        # 2) Strict train/val/test boundaries (in terms of "time")
        # -----------------------------------------------------------------
        train_l_time = earliest_time
        train_u_time = val_split_time - 1

        val_l_time = val_split_time
        val_u_time = test_split_time - 1

        test_l_time = test_split_time
        test_u_time = last_time  # i.e. the end of the dataset

        assert train_l_time <= train_u_time, "Train region is invalid (val_split_time too early?)"
        assert val_l_time <= val_u_time, "Val region is invalid"
        assert test_l_time <= test_u_time, "Test region is invalid"
        assert train_u_time < val_l_time, "Train/Val overlap"
        assert val_u_time < test_l_time, "Val/Test overlap"

        # -----------------------------------------------------------------
        # 3) Sampling boundaries (where sampled windows may start)
        # -----------------------------------------------------------------
        # a) Train: pull the upper start point back by half a window so that
        #    windows starting near the end reach less far into the val period.
        train_samp_l_time = train_l_time
        train_samp_u_time = train_u_time - self.nr_sampling_timesteps // 2

        # b) Val: windows may start up to 3/4 of a window before the val
        #    period. NOTE(review): the 10-month upper margin is a hard-coded
        #    constant whose rationale is not stated here -- confirm.
        val_samp_l_time = val_l_time - int(3 * self.nr_sampling_timesteps // 4)
        val_samp_u_time = val_u_time - 10

        # c) Test: start up to 3/4 of a window before the test period, and
        #    stop one full window before the end of the data.
        test_samp_l_time = test_l_time - int(3 * self.nr_sampling_timesteps // 4)
        test_samp_u_time = test_u_time - self.nr_sampling_timesteps

        assert test_samp_u_time <= last_time, "Test sampling region exceeds dataset bounds"

        # -----------------------------------------------------------------
        # 4) Convert times to diagonal indices
        # -----------------------------------------------------------------
        train_l_idx = find_index_for_time(train_l_time)
        train_u_idx = find_index_for_time(train_u_time)
        val_l_idx = find_index_for_time(val_l_time)
        val_u_idx = find_index_for_time(val_u_time)
        test_l_idx = find_index_for_time(test_l_time)
        test_u_idx = find_index_for_time(test_u_time)
        assert train_u_idx < val_l_idx, "Train/Val overlap"
        assert val_u_idx < test_l_idx, "Val/Test overlap"
        assert train_l_idx <= train_u_idx, "Train region is invalid"
        assert val_l_idx <= val_u_idx, "Val region is invalid"
        assert test_l_idx <= test_u_idx, "Test region is invalid"

        self.limits_train = (train_l_idx, train_u_idx)
        self.limits_val = (val_l_idx, val_u_idx)
        self.limits_test = (test_l_idx, test_u_idx)

        self.sampling_train = (find_index_for_time(train_samp_l_time), find_index_for_time(train_samp_u_time))
        self.sampling_val = (find_index_for_time(val_samp_l_time), find_index_for_time(val_samp_u_time))
        self.sampling_test = (find_index_for_time(test_samp_l_time), find_index_for_time(test_samp_u_time))

        # -----------------------------------------------------------------
        # 5) Report the boundaries
        # -----------------------------------------------------------------
        print("==== Strict Boundaries (indices) ====")
        print("Train:", self.limits_train)
        print("Val:  ", self.limits_val)
        print("Test: ", self.limits_test)

        print("==== Sampling Boundaries (indices) ====")
        print("Train:", self.sampling_train)
        print("Val:  ", self.sampling_val)
        print("Test: ", self.sampling_test)

        print("==== Strict Boundaries (time) ====")
        print("Train:", self.diagonal[self.limits_train[0]][0], self.diagonal[self.limits_train[1]][0])
        print("Val:  ", self.diagonal[self.limits_val[0]][0], self.diagonal[self.limits_val[1]][0])
        print("Test: ", self.diagonal[self.limits_test[0]][0], self.diagonal[self.limits_test[1]][0])

    def get_data(self):
        """Build one ``LoanDataLoader`` per split.

        ``steps_per_epoch`` is divided between the splits in proportion to
        ``val_split`` and ``test_split``; train gets the remainder.

        Returns:
            tuple: ``(train, val, test)`` ``LoanDataLoader`` instances.
        """
        steps = self.config["steps_per_epoch"]
        val_frac = self.config["val_split"]
        test_frac = self.config["test_split"]
        return (
            LoanDataLoader(self.data_dict, self.limits_train, self.sampling_train, self.config, name="train", steps_per_epoch=int(steps * (1 - val_frac - test_frac))),
            LoanDataLoader(self.data_dict, self.limits_val, self.sampling_val, self.config, name="val", steps_per_epoch=int(steps * val_frac)),
            LoanDataLoader(self.data_dict, self.limits_test, self.sampling_test, self.config, name="test", steps_per_epoch=int(steps * test_frac)),
        )

    def _load_saved_data(self):
        """
        Loads saved data from the .npz file at ``self.config["data_path"]``.

        Loading is skipped when ``load_data`` is falsy or the file does not exist.

        Returns:
            bool: True if data was successfully loaded, False otherwise.
            dict: The loaded data dictionary or an empty dictionary on failure.
        """
        data_path = self.config.get("data_path", "")
        if not self.config.get("load_data", False) or not os.path.exists(data_path):
            return False, {}
        try:
            # NOTE: allow_pickle=True lets object arrays be unpickled; only load
            # files produced by this pipeline, never untrusted ones.
            with np.load(data_path, allow_pickle=True) as data:
                print(f"Loaded data from {data_path}")
                # Materialise every array before the file handle is closed.
                return True, {key: data[key] for key in data.files}
        except Exception as e:
            # Best effort: fall back to rebuilding the data from scratch.
            print(f"Error loading data from {data_path}: {e}")
            return False, {}


def main():
    """Smoke-test ``LoanDataset``: build the splits, check bounds and batch shapes.

    Loads (or builds) the CoreLogic loan data and asserts that the
    train/val/test index bounds are ordered and non-empty. It then iterates
    the train split through a torch ``DataLoader``, checking tensor shapes,
    and finally prints dataset statistics from a fresh ``CreateLoanData``.
    """
    dataset_config = {
        "path_origination": f"{CORELOGIC_DATA_PATH}/filtered_origination_data_top_4_zips.csv",
        "path_performance": f"{CORELOGIC_DATA_PATH}/filtered_performance_data_top_4_zips.csv",
        "database_size": 400000,
        "start_year": 1970,
        "columns_to_normalize_origination": ["fico_score_at_origination", "original_balance", "initial_interest_rate", "original_ltv"],
        "feature_set": ["current_state", 'fico_score_at_origination', "original_balance", "initial_interest_rate", "original_ltv"],
        "nr_classes": 8,
        "verbose": False
    }

    config = {
        "_name_": "corelogic_loan_dataset",
        "dataset_config": dataset_config,
        "val_split": 0.1,
        "test_split": 0.1,
        "load_data": True,
        "save_data": False,
        "data_path": "./../../data/corelogic/loan_data.npz",
        "max_to_sample": 10,
        "nr_sampling_timesteps": 20,
        "nr_loans_to_sample": 5,
        "steps_per_epoch": 20000,
        # Read without defaults by LoanDataset._load_config.
        "sample_random_time_index": True,
        "sample_random_loan_index": True,
    }

    loan_dataset = LoanDataset(**config)
    # get_data() needs the split boundaries that only setup() computes.
    if not hasattr(loan_dataset, "limits_train"):
        loan_dataset.setup()
    train_loader, val_loader, test_loader = loan_dataset.get_data()

    train_lower_bound = train_loader.lower_bound
    train_upper_bound = train_loader.upper_bound
    val_lower_bound = val_loader.lower_bound
    val_upper_bound = val_loader.upper_bound
    test_lower_bound = test_loader.lower_bound
    test_upper_bound = test_loader.upper_bound
    assert train_upper_bound < val_lower_bound
    assert val_upper_bound < test_lower_bound
    assert train_upper_bound > train_lower_bound
    assert val_upper_bound > val_lower_bound
    assert test_upper_bound > test_lower_bound
    print("Train lower bound (idx): ", train_lower_bound)
    print("Train upper bound (idx): ", train_upper_bound)
    print("Val lower bound (idx): ", val_lower_bound)
    print("Val upper bound (idx): ", val_upper_bound)
    print("Test lower bound (idx): ", test_lower_bound)
    print("Test upper bound (idx): ", test_upper_bound)

    assert len(train_loader) == train_upper_bound - train_lower_bound
    assert len(val_loader) == val_upper_bound - val_lower_bound
    assert len(test_loader) == test_upper_bound - test_lower_bound

    batch_size = 2
    dataloader = DataLoader(train_loader, batch_size=batch_size, shuffle=True)

    for batch in dataloader:
        # Items carry extra entries after X and Y (the valid time span and
        # optionally the start month), so keep only the first two.
        x_batch, y_batch, *_ = batch

        # X: (batch, nr_features, nr_loans_to_sample, nr_timesteps)
        # Y: (batch, nr_loans_to_sample, nr_timesteps, nr_classes)
        assert x_batch.shape == (batch_size, train_loader.nr_features, config["nr_loans_to_sample"], config["nr_sampling_timesteps"])
        assert y_batch.shape == (batch_size, config["nr_loans_to_sample"], config["nr_sampling_timesteps"], config["dataset_config"]["nr_classes"])
    print("DataLoader works as expected")

    loan_fetcher = CreateLoanData(**dataset_config)
    loan_fetcher.print_dataset_statistics()

    verbose = False
    if verbose:
        origination = pd.read_csv(f"{CORELOGIC_DATA_PATH}/filtered_origination_data_top_4_zips.csv")
        performance = pd.read_csv(f"{CORELOGIC_DATA_PATH}/filtered_performance_data_top_4_zips.csv")
        print(origination.head())
        print(performance.head())
        print(origination.columns)
        print(performance.columns)

if __name__ == "__main__":
    # Profile one full run of main() and report the 30 most expensive
    # entries ordered by cumulative time.
    import cProfile
    import pstats

    run_profile = cProfile.Profile()
    run_profile.runcall(main)
    report = pstats.Stats(run_profile)
    report.sort_stats('cumulative')
    report.print_stats(30)