"""Evaluate a model that with checkpoints, each checkpoint covering a larger 
time period than the previous checkpoint."""

# 1. Load the model

# What are the plots we want to make:
# How does the AUC vary on the different partitions?
# Are there some partitions where the set-seq model has an additional edge over the NN baseline?
# What is the best way to do continual training with the sliding window?
# - No continual training
# - Retrain for each new window
# - Training window increases in time
# - Sliding window (the start of the training window is removed continually)
# - Use the full sliding window, but sample time more heavily towards the end
# Starting in, say, 2009, compare for each year the refitted model with a model trained only on data up to 2009.
# Also, does upsampling later periods help learning?


# 1. Does continual training help?
# 2. Compare no continual training with continual training, and continual training with half life

import sys
import os
# Repository root, read from the BASE_PATH environment variable (empty string when unset).
BASE_PATH = os.environ.get("BASE_PATH", "")
# Drop one trailing slash so the "/..." suffixes appended below do not produce "//".
if BASE_PATH and BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
# Location of the raw CoreLogic data; falls back to the hard-coded default when the variable is unset.
CORELOGIC_DATA_PATH = os.environ.get("CORELOGIC_DATA_PATH", "/share/data/llm_mortgages/original_data")
# Extend the module search path so the `src.*` and `scripts.*` imports that follow can be resolved.
# NOTE(review): with BASE_PATH unset these entries become "", "/src/", "/scripts/notebooks/", ... --
# presumably the script is always launched with BASE_PATH set; confirm.
sys.path.append(BASE_PATH)
sys.path.append(BASE_PATH+"/src/")
sys.path.append(BASE_PATH+"/scripts/notebooks/")
sys.path.append(BASE_PATH+"/scripts/notebooks/data/")
sys.path.append(BASE_PATH+"/scripts/notebooks/true_loss_level/")
from src.dataloaders.dataloader_corelogic import LoanDataset
from scripts.notebooks.true_loss_level.get_transition_probabilities import load_model_corelogic
from scripts.notebooks.true_loss_level.get_corelogic import evaluate_model, get_metrics
from sklearn.metrics import roc_auc_score
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import gc
import os
from datetime import datetime
# Set up plotting style similar to get_corelogic
# NOTE: "text.usetex": True makes matplotlib typeset all text through LaTeX, so a LaTeX
# installation has to be available wherever figures are rendered or saved.
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "text.usetex": True,
    "axes.grid": False
})

# Seed the NumPy and PyTorch global random number generators with one fixed value.
SEED_NR = 42
np.random.seed(SEED_NR)
torch.manual_seed(SEED_NR)

def get_model_config(name):
    """Return the evaluation config registered under ``name``.

    Every config is a dict holding an ``"experiment"`` name, a ``"data_path"``
    and either a single ``"checkpoint_path"`` (``"base_model_config"``) or a
    list of ``"checkpoint_paths"`` produced by one forward-shifting training
    run, one checkpoint per validation year.

    Earlier definitions of ``rolled_model_config``, ``rolled_model_half_life_24``
    and ``rolled_nn_model_config`` that were overwritten before being used have
    been dropped; the values returned for those names are unchanged.

    Parameters
    ----------
    name : str
        Key of the requested config (see ``model_config_dict`` below).

    Returns
    -------
    dict
        A freshly built config dict; callers may mutate it freely.

    Raises
    ------
    AssertionError
        If ``name`` is not registered. Raised explicitly (instead of via an
        ``assert`` statement) so the check is not stripped under ``python -O``.
    """
    data_path = f"{BASE_PATH}/data/corelogic/loan_data_60000_unemployment_curr_balance_cir_spi.npz"
    data_path_lp = f"{BASE_PATH}/data/corelogic/loan_data_60000_unemployment_curr_balance_cir_spi_lp.npz"

    def rolled(experiment, run_date, run_time, data=data_path_lp,
               start_epoch=16, step_freq=1, nr_epochs=32, start_year=2002):
        """Build the config of one forward-shifting run (one checkpoint per validation year)."""
        nr_checkpoints = (nr_epochs - start_epoch) // step_freq
        return {
            "experiment": experiment,
            "checkpoint_paths": [
                f"{BASE_PATH}/outputs/outputs/{run_date}/"
                f"shift_ckpt_epoch_{start_epoch + i * step_freq}_{run_time}_val_{start_year + i}-06.ckpt"
                for i in range(nr_checkpoints)
            ],
            "data_path": data,
        }

    # Static baseline: a single checkpoint, no continual training.
    rand_train2_top4 = {
        "experiment": "timeseries/set_corelogic_top4_exp",
        "checkpoint_path": f"{BASE_PATH}/outputs/outputs/2025-02-02/11-41-20/step_11206.ckpt",
        "data_path": data_path,
    }

    # Models to compare
    # 1. Gated Selection Set-Seq wo extra features (55 total)
    # 2. Gated Selection Set-Seq extra features (58 total)
    # 3. NN extra features (58 total)
    # 4. NN wo extra features (55 total)
    # 5. Linear Set-Seq extra features (58 total)
    model_config_dict = {
        "base_model_config": rand_train2_top4,
        # Experiments from Feb 24 (epochs 16..31, validation years 2002..2017).
        "rolled_model_config": rolled(
            "timeseries/set_cl_forward_shift", "2025-02-24", "17-28-18"),
        # Half-life run from Feb 20: epochs 21..31, validation years 2009..2019,
        # evaluated on the data file without the "_lp" suffix.
        "rolled_model_half_life_24": rolled(
            "timeseries/set_corelogic_forward_shifting_train", "2025-02-20", "15-16-53",
            data=data_path, start_epoch=21, start_year=2009),
        "rolled_nn_model_config": rolled(
            "timeseries/nn_cl_forward_shift", "2025-02-24", "15-22-42"),
        "rolled_linear_model_config": rolled(
            "timeseries/linear_cl_forward_shift", "2025-02-24", "15-24-36"),
        "rolled_gated_selection_model_config": rolled(
            "timeseries/gated_selection_cl_forward_shift", "2025-02-24", "21-36-04"),
        # Same Feb 24 checkpoints, evaluated through the "_55" experiment configs.
        "rolled_gated_selection_model_config_55": rolled(
            "timeseries/gated_selection_cl_forward_shift_55", "2025-02-24", "21-36-04"),
        "rolled_nn_model_config_55": rolled(
            "timeseries/nn_cl_forward_shift_55", "2025-02-24", "15-22-42"),
        # Experiments from Feb 25.
        "rolled_nn_model_config_58": rolled(
            "timeseries/nn_cl_forward_shift", "2025-02-25", "18-28-09"),
        "linear_model_config_58": rolled(
            "timeseries/linear_cl_forward_shift", "2025-02-25", "18-25-47"),
        "gated_selection_model_config_58": rolled(
            "timeseries/gated_selection_cl_forward_shift", "2025-02-25", "18-39-43"),
        # Experiment from March 3.
        "rolled_gated_selection_model_config_55_visualize_layer": rolled(
            "timeseries/gated_selection_cl_kl_1", "2025-03-03", "12-10-05"),
        # Experiments from March 13.
        "gated_selection_dmod32": rolled(
            "timeseries/gated_selection_cl_kl_1", "2025-03-13", "11-39-29"),
        "debug_gated_selection_dmod32": rolled(
            "timeseries/gated_selection_cl_kl_1_debug", "2025-03-13", "11-09-22"),
        "gated_selection_dmod32_kl_50": rolled(
            "timeseries/gated_selection_cl_kl_50", "2025-03-13", "11-29-41"),
        "debug_gated_selection_dmod32_kl_50": rolled(
            "timeseries/gated_selection_cl_kl_50_debug", "2025-03-13", "11-38-50"),
    }
    if name in model_config_dict:
        return model_config_dict[name]

    # Explicit raise: an `assert False` would vanish under `python -O` and
    # the function would then silently return None.
    raise AssertionError(f"Model config {name} not found")

def get_dataset(config):
    """Build a ``LoanDataset`` from ``config`` and run its ``setup()``.

    Returns a 4-tuple ``(train, val, test, data_module)`` where the first
    three entries are the ``dataset_train`` / ``dataset_val`` /
    ``dataset_test`` attributes of the freshly set-up data module.
    """
    data_module = LoanDataset(**config)
    data_module.setup()
    splits = (
        data_module.dataset_train,
        data_module.dataset_val,
        data_module.dataset_test,
    )
    return (*splits, data_module)

def update_data(train, val, test, dataset, val_date):
    """Move the validation split to ``val_date`` and the test split one year later.

    Does nothing when ``dataset`` has no ``rolling_model`` attribute.
    Otherwise the two split dates are written into ``dataset.config``,
    ``dataset._split_data()`` is re-run, and the resulting ``limits_*`` /
    ``sampling_*`` pairs are copied onto the ``train`` / ``val`` / ``test``
    objects as ``lower_bound`` / ``upper_bound`` and
    ``lower_sampling_bound`` / ``upper_sampling_bound``.

    ``val_date`` is a ``"%Y-%m"`` string such as ``"2005-06"``.
    """
    if not hasattr(dataset, 'rolling_model'):
        return

    # Parse, shift by one calendar year for the test split, and re-format.
    val_start = datetime.strptime(val_date, "%Y-%m")
    test_start = val_start.replace(year=val_start.year + 1)
    val_date_str = val_start.strftime("%Y-%m")
    test_date_str = test_start.strftime("%Y-%m")

    # Re-split the underlying data with the new dates.
    dataset.config["val_split_date"] = val_date_str
    dataset.config["test_split_date"] = test_date_str
    dataset._split_data()

    named_splits = (("train", train), ("val", val), ("test", test))

    # Index bounds of every split.
    for split_name, split in named_splits:
        limits = getattr(dataset, f"limits_{split_name}")
        split.lower_bound = limits[0]
        split.upper_bound = limits[1]

    # Sampling bounds of every split.
    for split_name, split in named_splits:
        sampling = getattr(dataset, f"sampling_{split_name}")
        split.lower_sampling_bound = sampling[0]
        split.upper_sampling_bound = sampling[1]

    print(f"New val split date: {val_date_str}")
    print(f"New test split date: {test_date_str}")


def evaluate_rolling_model(model_config, train, val, test, dataset, same_model=False, model_name="set-seq"):
    """Evaluate one model per checkpoint on that checkpoint's validation window.

    For every entry of ``model_config["checkpoint_paths"]`` the validation
    date is parsed from the file name (the text after the last ``_`` with the
    5-character ``.ckpt`` suffix removed), the splits are moved there with
    ``update_data`` and the model is scored on ``val``.

    With ``same_model=True`` only the first checkpoint is loaded (once) and
    re-used for every window; otherwise the checkpoint belonging to the
    window is loaded each time. ``model_config`` is mutated: its
    ``"checkpoint_path"`` entry is overwritten on every iteration.

    Returns a list with one ``get_metrics`` result per checkpoint, labelled
    ``f"{model_name}-val-{val_date}"``.
    """
    checkpoint_paths = model_config["checkpoint_paths"]
    results = []
    for idx, ckpt in enumerate(checkpoint_paths):
        if not same_model:
            model_config["checkpoint_path"] = ckpt
            model = load_model_corelogic(**model_config)
            print(f"Evaluating model for checkpoint {idx}")
        else:
            # Frozen-model baseline: always the first checkpoint, loaded once.
            model_config["checkpoint_path"] = checkpoint_paths[0]
            if idx == 0:
                model = load_model_corelogic(**model_config)

        # ".../..._val_2005-06.ckpt" -> "2005-06"
        val_date = ckpt.rsplit("_", 1)[-1][:-5]
        update_data(train, val, test, dataset, val_date)

        y_pred, y_true = evaluate_model(
            model,
            'set-seq',
            val,
            test_set=None,
            batch_size=1,
            fix_seed=True,
            all_units_in_batch_dim=False
        )

        # Pair the prediction at index t with the target at index t+1 along
        # the third axis; the small offset keeps the predictions off zero.
        tol = 1e-6
        y_pred = y_pred[:, :, :-1, :].detach().cpu().numpy() + tol
        y_true = y_true[:, :, 1:, :].detach().cpu().numpy()

        results.append(
            get_metrics(y_true, y_pred, name=f"{model_name}-val-{val_date}")
        )
    return results
        
import matplotlib.pyplot as plt
import numpy as np

def plot_metrics(metric_lists, labels, output_pdf_path, title="Model Performance Metrics Over Time", start_year=2001):
    """
    Plots multiple metric lists (each containing dictionaries for each year).

    Three panels are drawn side by side: 'Avg. AUC', 'Xentropy' and the
    '30dd-60dd' AUC, each against the year.

    Parameters
    ----------
    metric_lists : list of list of dict
        A list where each element is one "metric_list", i.e. the output of `evaluate_rolling_model`.
        Example structure:
            metric_lists[i] = [
               {'Avg. AUC': ..., 'Xentropy': ..., '30dd-60dd': ...},
               ...
            ]
        The lists may differ in length; every list starts at `start_year`.
    labels : list of str
        The labels corresponding to each metric list (used in the legend).
        Length must match len(metric_lists).
    output_pdf_path : str
        The path (including filename) where the PDF will be saved.
    title : str
        The figure title.
    start_year : int
        The year of the first entry of every metric list.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was saved.

    Raises
    ------
    ValueError
        If `metric_lists` is empty or its length differs from `labels`.
    """

    # --- Safety checks ---
    if len(metric_lists) == 0:
        raise ValueError("metric_lists is empty. Provide at least one list of metrics.")
    if len(metric_lists) != len(labels):
        raise ValueError("Length of `labels` must match length of `metric_lists`.")

    # --- The x-axis spans the longest metric list; shorter lists use a prefix of it ---
    num_years = max(len(ml) for ml in metric_lists)
    years = [int(start_year + i) for i in range(num_years)]

    # --- Prepare figure ---
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    fig.suptitle(title, fontsize=14)

    cmap = plt.get_cmap('tab10')

    # (metric key, y-axis label) for each panel, left to right.
    panels = (
        ('Avg. AUC', 'Average AUC'),
        ('Xentropy', 'Cross-Entropy Loss'),
        ('30dd-60dd', r'AUC 30dd $\to$ 60dd'),
    )

    lines_for_legend = []
    labels_for_legend = []

    # --- Plot each metric list in turn ---
    for i, (ml, lbl) in enumerate(zip(metric_lists, labels)):
        # Cycle through the palette: an integer index past the end of a
        # listed colormap would otherwise map every extra curve to one colour.
        c = cmap(i % cmap.N)
        xs = years[:len(ml)]

        for panel_idx, (ax, (key, _)) in enumerate(zip(axes, panels)):
            # Only the first panel carries the label; one handle per metric
            # list is enough for the shared legend.
            extra = {"label": lbl} if panel_idx == 0 else {}
            line = ax.plot(
                xs, [m[key] for m in ml],
                marker='o', linestyle='-', color=c, linewidth=2, markersize=5, **extra
            )[0]
            if panel_idx == 0:
                lines_for_legend.append(line)
                labels_for_legend.append(lbl)

    # --- Customize subplots ---
    for ax, (_, ylabel) in zip(axes, panels):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Year')
        ax.grid(True, alpha=0.3)
        ax.set_xticks(years)
        ax.tick_params(axis='x', rotation=45)

    # --- Create a legend using the collected lines ---
    fig.legend(lines_for_legend, labels_for_legend, loc='lower center', ncol=len(metric_lists),
               bbox_to_anchor=(0.5, -0.02))

    # Adjust layout (leave space at bottom for legend)
    fig.tight_layout(rect=[0, 0.05, 1, 0.93])

    # --- Save figure as PDF ---
    fig.savefig(output_pdf_path, format='pdf', bbox_inches='tight', pad_inches=0.1)
    print(f"Saved plot to {output_pdf_path}")

    return fig

import matplotlib.pyplot as plt

def plot_c_metrics(metric_lists, labels, output_pdf_path, title="Model Performance Over Time", start_year=2001):
    """
    Plots the c-based metrics from multiple metric lists (each containing dictionaries for each year).
    Specifically plots, in two side-by-side panels:
      - "c-30dd"
      - "c-paid-off"
    The "c-c" metric is intentionally not plotted and is not required to be present.

    Parameters
    ----------
    metric_lists : list of list of dict
        A list where each element is one "metric_list" (the output of `evaluate_rolling_model`).
        Example structure:
            metric_lists[i] = [
               {'c-30dd': ..., 'c-paid-off': ...},
               ...
            ]
        The lists may differ in length; every list starts at `start_year`.
    labels : list of str
        The labels corresponding to each metric list (used in the legend).
        Length must match len(metric_lists).
    output_pdf_path : str
        The path (including filename) where the PDF will be saved.
    title : str
        The figure title.
    start_year : int
        The first year in your rolling-window metric list.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was saved.

    Raises
    ------
    ValueError
        If `metric_lists` is empty or its length differs from `labels`.
    """

    # --- Safety checks ---
    if len(metric_lists) == 0:
        raise ValueError("metric_lists is empty. Provide at least one list of metrics.")
    if len(metric_lists) != len(labels):
        raise ValueError("Length of `labels` must match length of `metric_lists`.")

    # --- The x-axis spans the longest metric list; shorter lists use a prefix of it ---
    num_years = max(len(ml) for ml in metric_lists)
    years = [int(start_year + i) for i in range(num_years)]

    # --- Prepare figure ---
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    fig.suptitle(title, fontsize=14)

    cmap = plt.get_cmap('tab10')

    # (metric key, y-axis label) for each panel, left to right.
    panels = (
        ("c-30dd", r'AUC C $\to$ 30dd'),
        ("c-paid-off", r'AUC C $\to$ Paid-Off'),
    )

    lines_for_legend = []
    labels_for_legend = []

    # --- Plot each metric list in turn ---
    for i, (ml, lbl) in enumerate(zip(metric_lists, labels)):
        # Cycle through the palette: an integer index past the end of a
        # listed colormap would otherwise map every extra curve to one colour.
        c = cmap(i % cmap.N)
        xs = years[:len(ml)]

        for panel_idx, (ax, (key, _)) in enumerate(zip(axes, panels)):
            # Only the first panel carries the label; one handle per metric
            # list is enough for the shared legend.
            extra = {"label": lbl} if panel_idx == 0 else {}
            line = ax.plot(
                xs, [m[key] for m in ml],
                marker='o', linestyle='-', color=c, linewidth=2, markersize=5, **extra
            )[0]
            if panel_idx == 0:
                lines_for_legend.append(line)
                labels_for_legend.append(lbl)

    # --- Customize subplots ---
    for ax, (_, ylabel) in zip(axes, panels):
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Year')
        ax.grid(True, alpha=0.3)
        ax.set_xticks(years)
        ax.tick_params(axis='x', rotation=45)

    # --- Create a legend ---
    fig.legend(lines_for_legend, labels_for_legend, loc='lower center', ncol=len(metric_lists),
               bbox_to_anchor=(0.5, -0.02))

    # Adjust layout to make space for the legend
    fig.tight_layout(rect=[0, 0.05, 1, 0.93])

    # --- Save figure as PDF ---
    fig.savefig(output_pdf_path, format='pdf', bbox_inches='tight', pad_inches=0.1)
    print(f"Saved plot to {output_pdf_path}")

    return fig

import matplotlib.pyplot as plt
import os

def plot_and_save_g_matrix(G, title, filepath):
    """
    Save a heatmap of a gating matrix to ``filepath``.

    G: (P, P) or (B, P, P) gating matrix (torch tensor).
       For a 3-D input only the first sample in the batch is plotted.
    title: title drawn above the heatmap.
    filepath: destination passed to ``Figure.savefig``.

    The figure is always closed, even when saving fails, so repeated calls
    do not accumulate open figures.
    """
    if G.dim() == 3:
        G = G[0]  # (P, P)
    G_np = G.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        image = ax.imshow(G_np, cmap="viridis", aspect="auto")
        fig.colorbar(image, ax=ax)
        ax.set_title(title)
        fig.savefig(filepath, bbox_inches='tight')
    finally:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

def partial_forward_set_encoder(x, set_encoder):
    """
    Reproduce only the axis re-ordering at the start of the SetEncoder forward.

    The 4-D input has its axes re-ordered from (0, 1, 2, 3) to (0, 2, 3, 1),
    i.e. axis 1 is moved to the end (per the original notes this yields
    [B, P, T, input_feature_dim]).

    The ``m_1`` embedding step is currently disabled, so ``set_encoder`` is
    not used and both returned values are the very same tensor.

    Returns:
      (x_in, x_1) with ``x_1 is x_in``.
    """
    # One permute is equivalent to transpose(1, 2) followed by transpose(2, 3).
    reordered = x.permute(0, 2, 3, 1)
    return reordered, reordered

import torch
import torch.nn.functional as F

def compute_gating_matrix_and_output(x, gate_net, proj_token_dim, V, _m1_forward, K, t_g=0):
    """
    Re-run the gated-selection computation (mirrors M2.gated_selection_forward).

    Steps, in order:
      1. ``_m1_forward(x)`` embeds the 4-D input.
      2. ``K`` maps the embedding to keys; only index ``t_g`` of the third
         axis is kept.
      3. Every ordered pair (i, j) along the second axis is formed by
         concatenating key i with key j and scored by ``gate_net`` (which
         must return a 2-tuple whose first entry has a trailing axis of size 1).
      4. A softmax over j turns the scores into the gating matrix G.
      5. ``V`` maps the embedding to values, which are mixed across the
         second axis with G and passed through ``proj_token_dim``.

    Returns:
        G: (B, P, P) gating matrix, each row sums to 1
        x_gated: the projected, gated output
    """
    # Unpacking also enforces that x is 4-D.
    _, num_units, _, _ = x.shape

    embedded = _m1_forward(x)

    # Keys at the single reference index t_g: one vector per unit.
    anchor_keys = K(embedded)[:, :, t_g, :]

    # Broadcast to all (i, j) pairs and concatenate [key_i, key_j].
    key_i = anchor_keys.unsqueeze(2).expand(-1, -1, num_units, -1)
    key_j = anchor_keys.unsqueeze(1).expand(-1, num_units, -1, -1)
    pair_features = torch.cat([key_i, key_j], dim=-1)

    gate_logits, _ = gate_net(pair_features)
    G = F.softmax(gate_logits.squeeze(-1), dim=-1)

    # Values are mixed across units j with weights G[i, j].
    values = V(embedded)
    mixed = torch.einsum("b i j, b j t d -> b i t d", G, values)

    return G, proj_token_dim(mixed)


import torch
import os
import matplotlib.pyplot as plt

def partial_forward_block_set_encoder(x, model, block_idx):
    """
    Compute the representation that enters block ``block_idx`` of ``model``.

    1) ``model.encoder(x)`` produces the initial representation.
    2) It is passed through ``model.model.layers[0 .. block_idx-1]``, each
       called with ``nr_units=2500``.
    3) The result gets a new leading axis of size 1.

    The ``m_1`` step of the block's set encoder is currently disabled, so the
    two returned values are the same tensor. The set encoder of block
    ``block_idx`` is still looked up, which means an invalid ``block_idx``
    (or a block without ``.layer.set_encoder``) raises here.

    NOTE(review): whether the target set encoder expects the representation
    as-is or transposed was left open by the original author -- confirm
    against the block implementation before enabling ``m_1``.

    Returns:
      (x_in, x_1) with ``x_1 is x_in``.
    """
    hidden, _ = model.encoder(x)

    blocks = model.model.layers
    for idx in range(block_idx):
        hidden, _ = blocks[idx](hidden, nr_units=2500)

    # Looked up only to validate that the requested block exists; the
    # embedding call below is disabled.
    set_enc = blocks[block_idx].layer.set_encoder
    # x_1, _ = set_enc.m_1(x_in)

    block_input = hidden.unsqueeze(0)
    return block_input, block_input

def get_loan_sort_indices(
    x_sample: torch.Tensor,
    zip_feature_indices=(50, 51, 52, 53),
    prime_feature_index=25,
    time_idx=0
):
    """
    Order the loans of one sample by (zip category, prime category).

    x_sample shape: [B, F, P, T]
      B: batch
      F: features
      P: loans
      T: timesteps

    Only the first batch element (``x_sample[0]``) and the single time slice
    ``time_idx`` are inspected.

    Categories:
      zip:   index k of the largest entry among ``zip_feature_indices``;
             if that largest entry is below 0.5 (no one-hot flag set) the
             loan gets the "missing" category ``len(zip_feature_indices)``,
             which always sorts after every real zip category.
      prime: 0 if the prime feature is 1 (within 1e-5), 1 if it is 0
             (within 1e-5), 2 otherwise (missing/other).

    Loans are sorted ascending by (zip, prime); ties keep their original
    order (stable sort).

    Parameters:
      zip_feature_indices: sequence of feature rows holding the one-hot zip flags.
      prime_feature_index: feature row holding the prime flag.
      time_idx: time step used for the lookup.

    Returns:
      sorted_indices: a CPU torch.LongTensor of shape [P],
        representing the order in which to rearrange the gating matrix.
    """
    # One transfer to the CPU instead of one synchronising `.item()` per loan.
    x_t = x_sample[0][:, :, time_idx].detach().cpu()  # (F, P)

    zip_rows = list(zip_feature_indices)
    zip_codes = x_t[zip_rows, :]  # (len(zip_rows), P)
    # Double precision reproduces a Python-float comparison exactly.
    prime_vals = x_t[prime_feature_index, :].double()  # (P,)

    # Zip category: position of the first maximal flag, or "missing".
    zip_cat = zip_codes.argmax(dim=0)
    no_flag_set = zip_codes.max(dim=0).values < 0.5
    zip_cat = zip_cat.masked_fill(no_flag_set, len(zip_rows))

    # Prime category: 2 (missing/other) unless the value is ~0 or ~1.
    prime_cat = torch.full_like(zip_cat, 2)
    prime_cat = prime_cat.masked_fill((prime_vals - 0.0).abs() < 1e-5, 1)
    prime_cat = prime_cat.masked_fill((prime_vals - 1.0).abs() < 1e-5, 0)

    # prime_cat is in {0, 1, 2}, so this key orders lexicographically by
    # (zip_cat, prime_cat); the stable sort preserves the loan order on ties.
    sort_key = zip_cat * 3 + prime_cat
    return torch.sort(sort_key, stable=True).indices


def sort_gating_matrix(G, sort_idx):
    """
    Reorder a gating matrix along both loan dimensions with one permutation.

    Rows and columns are permuted with the same ``sort_idx`` so that entry
    (i, j) of the result is entry (sort_idx[i], sort_idx[j]) of the input.

    Args:
        G: tensor of shape (P, P) or (B, P, P).
        sort_idx: index of shape (P,) giving the new loan order.

    Returns:
        Tensor of shape (B, P, P). A 2-D input gets a leading batch
        dimension of size 1; a 3-D input keeps all of its batch elements
        (every element is permuted, not only the first one).
    """
    if G.dim() == 2:
        # Add the batch dimension up front so both layouts share one code path.
        G = G.unsqueeze(0)
    # Permute rows first, then columns, independently of the batch size.
    rows_sorted = G[..., sort_idx, :]
    return rows_sorted[..., :, sort_idx]

def compute_boundaries(categories_sorted):
    """
    Locate the positions where the zip or prime category changes.

    categories_sorted: list of (zip_cat, prime_cat, original_index) tuples,
    expected to be ordered by (zip_cat, prime_cat) already.

    Returns:
      zip_boundaries: positions i where zip_cat differs from position i-1
      prime_boundaries: positions i where prime_cat differs from position i-1
        while zip_cat stays the same (a zip change is reported only as a
        zip boundary, never additionally as a prime boundary)
    """
    zip_boundaries, prime_boundaries = [], []

    # Compare each entry with its predecessor; position 0 can never be a boundary.
    neighbours = zip(categories_sorted, categories_sorted[1:])
    for pos, (before, after) in enumerate(neighbours, start=1):
        prev_zip, prev_prime, _ = before
        this_zip, this_prime, _ = after
        if this_zip != prev_zip:
            zip_boundaries.append(pos)
        elif this_prime != prev_prime:
            prime_boundaries.append(pos)

    return zip_boundaries, prime_boundaries



def _categorize_loans_for_gating(x_fp, zip_feature_indices, prime_feature_index):
    """
    Assign each loan a (zip_cat, prime_cat) pair and return the loans sorted by it.

    Args:
        x_fp: tensor indexed as [feature, loan] (one time slice of one sample).
        zip_feature_indices: feature rows holding the zip-code one-hot encoding.
        prime_feature_index: feature row holding the prime flag.

    Returns:
        List of (zip_cat, prime_cat, loan_index) tuples sorted ascending by
        (zip_cat, prime_cat). ``sorted`` is stable, so loans with equal
        categories keep their original relative order.
    """
    zip_codes = x_fp[zip_feature_indices, :]   # (n_zip, P)
    prime_vals = x_fp[prime_feature_index, :]  # (P,)
    # Loans with no active one-hot entry go into one extra category after the real ones.
    missing_zip_cat = len(zip_feature_indices)

    categories = []
    for p in range(x_fp.shape[1]):
        zip_one_hot = zip_codes[:, p]
        if zip_one_hot.max() < 0.5:
            zip_cat = missing_zip_cat
        else:
            zip_cat = torch.argmax(zip_one_hot).item()

        # prime_cat: 0 => prime == 1, 1 => prime == 0, 2 => anything else
        prime_val = prime_vals[p].item()
        if abs(prime_val - 1.0) < 1e-5:
            prime_cat = 0
        elif abs(prime_val - 0.0) < 1e-5:
            prime_cat = 1
        else:
            prime_cat = 2

        categories.append((zip_cat, prime_cat, p))

    return sorted(categories, key=lambda c: (c[0], c[1]))


def visualize_gated_selection(
    train, 
    val, 
    test, 
    dataset, 
    model_config,
    layer=0  # default: top-level encoder
):
    """
    Plot the gating matrix of one validation sample, sorted by loan category.

    Loads a model checkpoint, takes the first sample from ``val``, extracts the
    gating matrix G from the requested layer (0 => top-level encoder,
    k >= 1 => set encoder of residual block k-1), reorders it by zip-code
    category (features 50-53) and prime flag (feature 25), and saves a
    log-scaled heatmap with category boundary lines to a PDF.

    Args:
        train, val, test, dataset: dataset objects forwarded to ``update_data``.
        model_config: keyword arguments for ``load_model_corelogic``; must
            contain "checkpoint_paths". The key "checkpoint_path" is
            overwritten in place with the selected checkpoint.
        layer: which set encoder to inspect.

    Returns:
        None. Returns early (after printing a message) when ``val`` is empty or
        when the selected M2 module is not of architecture "gated_selection".
    """
    zip_feature_indices = [50, 51, 52, 53]
    prime_feature_index = 25

    # 1) Load checkpoint
    # NOTE: the checkpoint index is hard-coded; the "_8" suffix of the output
    # file name below is hard-coded as well.
    i_ckpt = 7
    checkpoint_path = model_config["checkpoint_paths"][i_ckpt]
    model_config["checkpoint_path"] = checkpoint_path
    model = load_model_corelogic(**model_config)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The date is the last "_"-separated token of the path minus its 5-character suffix.
    val_date = checkpoint_path.split("_")[-1][:-5]
    update_data(train, val, test, dataset, val_date)

    if len(val) == 0:
        print("Val is empty. Exiting.")
        return

    # 2) Grab a single sample (only the inputs are needed)
    x_sample, _, _ = val[0]
    x_sample = x_sample.to(device)
    if x_sample.dim() == 3:
        x_sample = x_sample.unsqueeze(0)  # add batch dim => (B, F, P, T)

    # 3) Identify which M2 we want to use and compute its input x_1
    if layer == 0:
        set_encoder = model.encoder[0]
        _, x_1 = partial_forward_set_encoder(x_sample, set_encoder)
        layer_name = "encoder0"
    else:
        block_idx = layer - 1
        set_encoder = model.model.layers[block_idx].layer.set_encoder
        _, x_1 = partial_forward_block_set_encoder(x_sample, model, block_idx)
        layer_name = f"layer{block_idx+1}"
    m2 = set_encoder.m_2

    # Check architecture
    if getattr(m2, "architecture", None) != "gated_selection":
        print(f"M2 at layer={layer} is not 'gated_selection'. Nothing to visualize.")
        return

    # 4) Compute gating matrix G => (B, P, P)
    G, _ = compute_gating_matrix_and_output(
        x_1,               # (B, P, T, D_in)
        m2.gate_net,       # gate_net
        m2.proj_token_dim, # final projection
        m2.V,
        m2._m1_forward,
        m2.K,
        t_g=0
    )

    # 5) Categorise the loans once (first batch element, first time step) and
    #    derive both the permutation and the boundary lines from that single
    #    sorted list, so the two can never disagree. The slice is moved to the
    #    CPU first so the per-loan .item() calls do not each sync with the GPU.
    x_fp = x_sample[0, :, :, 0].detach().cpu()  # (F, P)
    categories_sorted = _categorize_loans_for_gating(
        x_fp, zip_feature_indices, prime_feature_index
    )
    sort_idx = torch.LongTensor([c[2] for c in categories_sorted]).to(device)
    zip_boundaries, prime_boundaries = compute_boundaries(categories_sorted)

    # 6) Sort G
    G_sorted = sort_gating_matrix(G, sort_idx)  # => (B, P, P)

    # 7) Plot
    save_dir = f"{BASE_PATH}/scripts/notebooks/plot"
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"G_{layer_name}_sorted_8.pdf")
    from matplotlib.colors import LogNorm

    plt.figure(figsize=(12,10))
    eps = 1e-10
    # Small offset keeps exact zeros representable on the logarithmic colour scale.
    G_plot = G_sorted[0].detach().cpu().numpy() + eps  # (P, P)
    # Clip the colour range to the 0.1th .. 99.9th percentile to suppress outliers.
    lower_bound = np.percentile(G_plot, 0.1)
    upper_bound = np.percentile(G_plot, 99.9)

    plt.imshow(
        G_plot,
        cmap="viridis",
        aspect="auto",
        norm=LogNorm(vmin=lower_bound, vmax=upper_bound)
    )

    # Boundary i sits between cells i-1 and i, hence the -0.5 offset.
    for zb in zip_boundaries:
        plt.axhline(zb - 0.5, color="white", linewidth=0.8)
        plt.axvline(zb - 0.5, color="white", linewidth=0.8)

    for pb in prime_boundaries:
        plt.axhline(pb - 0.5, color="red", linewidth=0.8)
        plt.axvline(pb - 0.5, color="red", linewidth=0.8)

    plt.colorbar()
    plt.title("Gating Matrix (Sorted by zip_cat, prime_cat)")
    plt.savefig(out_path, bbox_inches='tight', dpi=300)
    plt.close()

    print(f"Saved sorted gating matrix to {out_path}")




def get_set_var(x, model, layer=0, nr_units=1000):
    """
    Extract the set variable (latent factors) from the specified layer.
    
    For layer == 0:
        * Rearrange x from [B, F, P, T] to [B, P, T, F] with two transposes.
        * Apply the top-level encoder's m_2 directly to the rearranged input
          (m_1 is not applied).
    
    For layer >= 1:
        * Feed x through the encoder's forward method.
        * Pass the resulting representation through the first (layer - 1) residual blocks.
        * Add a leading dimension and apply m_2 of the set encoder belonging to
          residual block (layer - 1).
    
    Args:
        x (Tensor): Input tensor of shape [B, nr_features, nr_loans, nr_timesteps].
        model: The model instance.
        layer (int): Specifies which set module to use (0 => model.encoder[0],
                     k >= 1 => model.model.layers[k - 1]).
        nr_units (int): Forwarded as the ``nr_units`` keyword argument to every
                        residual block that is run before the target block.
    
    Returns:
        Tensor: The computed set variable.

    Raises:
        ValueError: If ``layer`` is negative or refers to a residual block the
                    model does not have. The check happens before any forward
                    computation is done.
    """
    if layer < 0:
        # A negative value would otherwise silently wrap around to a block
        # counted from the end of the model.
        raise ValueError(f"layer must be non-negative, got {layer}.")

    if layer == 0:
        x_in = torch.transpose(x, 1, 2)
        x_in = torch.transpose(x_in, 2, 3)
        return model.encoder[0].m_2(x_in)

    blocks = model.model.layers
    block_idx = layer - 1
    # Validate up front so a bad request fails fast with a clear message
    # instead of after (or in the middle of) the forward passes.
    if block_idx >= len(blocks):
        raise ValueError(f"Requested layer {layer} but the model has only "
                         f"{len(blocks)} residual block(s).")

    rep, _ = model.encoder(x)

    # Run the residual blocks that precede the target block (none for layer == 1).
    kwargs = {"nr_units": nr_units}
    for i in range(block_idx):
        rep, _ = blocks[i](rep, **kwargs)

    set_encoder = blocks[block_idx].layer.set_encoder
    return set_encoder.m_2(rep.unsqueeze(0))


def main():
    """
    Evaluate (or load) rolling-checkpoint metrics for several models and plot them.

    Behaviour is controlled by three local toggles:
      * visualize_gating: additionally save a sorted gating-matrix heatmap.
      * plot_metrics_bool: produce the metric comparison PDFs.
      * load_metric_files: read previously saved metrics from .npy files
        instead of evaluating every model on the fly (and saving the result).

    The same set of experiments is used for loading, evaluating and plotting,
    so every metric list that is plotted is guaranteed to be defined in both
    modes.
    """
    # Example dataset config
    dataset_config = {
        "path_origination": f"{CORELOGIC_DATA_PATH}/filtered_origination_data_top_4_zips.csv",
        "path_performance": f"{CORELOGIC_DATA_PATH}/filtered_performance_data_top_4_zips.csv",
        "normalize_data": True,
        "database_size": 300000,
        "start_year": 1988,
        "end_year": 2023,
    }

    config = {
        "_name_": "corelogic_loan_dataset",
        "dataset_config": dataset_config,
        "val_split": 0.3,
        "test_split": 0.1,
        "val_split_date": "2009-06",
        "test_split_date": "2009-12",
        "load_data": True,
        "save_data": False,
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_55_test.npz",
        "max_to_sample": 4500,
        "nr_sampling_timesteps": 50,
        "nr_loans_to_sample": 2500,
        "steps_per_epoch": 70,
        "sample_random_loan_index": True,
        "sample_random_time_index": True,
        "eval_mode": True,
        "eval_seed": 3000
    }

    # ---- Toggles -----------------------------------------------------------
    visualize_gating = False
    plot_metrics_bool = True
    # True: load metrics from the .npy files below.
    # False: evaluate the models on the fly and save the metrics to those files.
    load_metric_files = False

    yr = 2002
    plot_dir = f"{BASE_PATH}/scripts/notebooks/plot"

    # Experiments to compare, in plotting order:
    #   model-config name -> (plot label, metric file).
    # Earlier comparisons followed the same pattern with other config names
    # (e.g. "rolled_model_config", "rolled_nn_model_config",
    # "rolled_linear_model_config", "rolled_gated_selection_model_config",
    # "rolled_model_half_life_24", and the "_55" / "_58" variants).
    experiments = {
        "gated_selection_dmod32": (
            "Gated Selection (KL 1)",
            f"{plot_dir}/rolling_model_metrics_{yr}_gated_selection_dmod32_mar13.npy",
        ),
        "debug_gated_selection_dmod32": (
            "Set var 0 (KL 1)",
            f"{plot_dir}/rolling_model_metrics_{yr}_debug_gated_selection_dmod32_mar13.npy",
        ),
        "gated_selection_dmod32_kl_50": (
            "Gated Selection (KL 50)",
            f"{plot_dir}/rolling_model_metrics_{yr}_gated_selection_dmod32_mar13_kl_50.npy",
        ),
        "debug_gated_selection_dmod32_kl_50": (
            "Set var 0 (KL 50)",
            f"{plot_dir}/rolling_model_metrics_{yr}_debug_gated_selection_dmod32_mar13_kl_50.npy",
        ),
    }
    # Order in which the models are evaluated. It is kept separate from the
    # plotting order so the sequence of evaluations stays exactly as before.
    evaluation_order = [
        "gated_selection_dmod32",
        "debug_gated_selection_dmod32_kl_50",
        "gated_selection_dmod32_kl_50",
        "debug_gated_selection_dmod32",
    ]

    # Where to save the PDFs
    output_pdf_path = f"{plot_dir}/all_metrics_{yr}_compare_models_dmod32_mar13.pdf"
    output_pdf_path_c_metrics = f"{plot_dir}/all_metrics_{yr}_compare_models_c_metrics_dmod32_mar13.pdf"

    if visualize_gating:
        train, val, test, dataset = get_dataset(config)
        rolled_gated_selection_model_config_55 = get_model_config("rolled_gated_selection_model_config_55_visualize_layer")
        nr_layers = 1
        for layer in range(nr_layers):
            visualize_gated_selection(train, val, test, dataset, rolled_gated_selection_model_config_55, layer=layer)

    if not plot_metrics_bool:
        return

    metrics = {}
    if load_metric_files:
        print("Loading metrics from files...")
        for name, (_, metric_file) in experiments.items():
            # The metric lists are stored as pickled object arrays.
            metrics[name] = np.load(metric_file, allow_pickle=True)
    else:
        print("Loading data...")
        train, val, test, dataset = get_dataset(config)
        print("Data loaded")
        print("Evaluating rolling models on the fly...")
        # Make sure the target directory exists before the first save.
        os.makedirs(plot_dir, exist_ok=True)
        for name in evaluation_order:
            _, metric_file = experiments[name]
            model_config = get_model_config(name)
            metrics[name] = evaluate_rolling_model(model_config, train, val, test, dataset)
            # Save for future re-use with load_metric_files = True.
            np.save(metric_file, metrics[name])

    # Prepare for plotting
    metric_lists = [metrics[name] for name in experiments]
    labels = [label for label, _ in experiments.values()]
    start_year = 2003
    print("Creating plot...")
    plot_metrics(metric_lists, labels, output_pdf_path, title="Model Performance Metrics Over Time", start_year=start_year)
    plot_c_metrics(metric_lists, labels, output_pdf_path_c_metrics, title="Model Performance Metrics Over Time", start_year=start_year)
    print("Plot created.")

# Run the evaluation/plotting pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()