# 1. Load model
# 2. Get the data we think the set-variable will capture
# 3. For the samples, plot the set variable (normalized) and look for correlation with the data
import sys
import os
# Repository root, read from the BASE_PATH environment variable (empty string when unset).
BASE_PATH = os.environ.get("BASE_PATH", "")
# Strip one trailing slash so the f"{BASE_PATH}/..." paths built below do not contain "//".
if BASE_PATH and BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
# Put the repository root on the import path so the `scripts.*` imports below resolve.
sys.path.append(BASE_PATH)
from scripts.notebooks.true_loss_level.get_transition_probabilities import load_model_corelogic, load_model
from scripts.notebooks.true_loss_level.get_corelogic import get_dataset
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.stats import pearsonr
from matplotlib.lines import Line2D


# Global matplotlib defaults applied to every figure created by this script.
mpl.rcParams.update({
    "figure.dpi": 300,         # High-res figure for screen
    "savefig.dpi": 300,        # High-res output for saved files
    "font.size": 14,           # Base font size
    "axes.labelsize": 14,      # Axis label font size
    "xtick.labelsize": 10,     # X-tick label font size
    "ytick.labelsize": 10,     # Y-tick label font size
    "legend.fontsize": 10,     # Legend font size
    "axes.grid": True,         # Turn on grid by default
    "grid.linestyle": "-",    # Solid grid lines ("--" would be dashed)
    "grid.color": "gray",      # Gray grid lines
    "grid.alpha": 0.7,         # Slightly transparent
    "lines.linewidth": 2,      # Thicker lines
    "lines.markersize": 6,     # Medium marker size
    "figure.facecolor": "white",
    "axes.facecolor": "white",
    # If you want LaTeX text, set "text.usetex": True and ensure a LaTeX environment.
})

def get_model_config(name):
    """Return the experiment/checkpoint/data configuration registered under ``name``.

    Every entry is a dict with the keys ``"experiment"``, ``"checkpoint_path"``
    and ``"data_path"``.  Both paths are built from the module-level
    ``BASE_PATH``.  A fresh set of dicts is built on every call.

    Args:
        name (str): Registry key, e.g. ``"top4-jan29-50"`` or ``"synthetic_task"``.

    Returns:
        dict: The configuration registered for ``name``.

    Raises:
        AssertionError: If ``name`` is not a registered key.  Raised explicitly
            rather than via ``assert False`` so the check is not stripped when
            Python runs with ``-O`` (which would silently return ``None``).
    """
    ckpt_root = f"{BASE_PATH}/outputs/outputs"
    corelogic_data = f"{BASE_PATH}/data/corelogic/loan_data_60000_unemployment_curr_balance_cir_spi.npz"
    mortgage_data = f"{BASE_PATH}/data/mortgage_new2913734.json"

    # NOTE: every CoreLogic entry below used to carry the same commented-out
    # alternative checkpoint: outputs/outputs/2025-01-20/21-34-01/step_6800.ckpt
    #
    # Other checkpoints noted by the author:
    #   2025-02-01/11-24-46/step_1160.ckpt  (only set variable)
    #   2025-01-31/15-21-07/step_1200.ckpt
    #
    # NOTE(review): some registry keys do not match the "experiment" they map
    # to (e.g. "top41lz-jan29-50" -> set_corelogic_top4); the mapping is kept
    # exactly as it was -- confirm which side is intended.
    model_config_dict = {
        # top 4 zip, 700 loans at a time
        "top4-jan29-50": {
            "experiment": "timeseries/set_corelogic_top4",
            "checkpoint_path": f"{ckpt_root}/2025-01-29/12-12-43/step_680.ckpt",
            "data_path": corelogic_data,
        },
        # top 4 zip, 1 loan at a time bz 700 outputs/2025-01-30/01-17-46/step_760.ckpt
        "top4-jan29-52": {
            "experiment": "timeseries/set_corelogic_top4_52",
            "checkpoint_path": f"{ckpt_root}/2025-01-29/17-40-51/step_720.ckpt",
            "data_path": corelogic_data,
        },
        "top41lz-jan29-50": {
            "experiment": "timeseries/set_corelogic_top4",
            "checkpoint_path": f"{ckpt_root}/2025-01-30/01-17-46/step_760.ckpt",
            "data_path": corelogic_data,
        },
        "top4-jan30-50": {
            "experiment": "timeseries/set_corelogic_top4",
            "checkpoint_path": f"{ckpt_root}/2025-01-30/12-22-08/step_920.ckpt",
            "data_path": corelogic_data,
        },
        "top41lz-jan30-50": {
            "experiment": "timeseries/set_corelogic_top4_1lz",
            "checkpoint_path": f"{ckpt_root}/2025-01-30/12-52-51/step_920.ckpt",
            "data_path": corelogic_data,
        },
        "top41lz-jan30-52": {
            "experiment": "timeseries/set_corelogic_top4_52",
            "checkpoint_path": f"{ckpt_root}/2025-01-30/13-00-30/step_1080.ckpt",
            "data_path": corelogic_data,
        },
        "test-jan30-50": {
            "experiment": "timeseries/set_corelogic_top4_exp",
            "checkpoint_path": f"{ckpt_root}/2025-02-01/18-59-36/step_1200.ckpt",
            "data_path": corelogic_data,
        },
        # logunif 0.08, 20 samples
        "logunif008-best-set": {
            "experiment": "timeseries/ts_lc_set",
            "checkpoint_path": f"{ckpt_root}/2025-01-16/19-14-21/step_6600.ckpt",
            "data_path": mortgage_data,
        },
        "synthetic_task": {
            "experiment": "timeseries/synthetics/synthetics_set_seq_variable_input.yaml",
            "checkpoint_path": f"{ckpt_root}/2025-05-19/19-31-53/last.ckpt",
            "data_path": mortgage_data,
        },
    }

    if name not in model_config_dict:
        raise AssertionError(f"Model config {name} not found")
    return model_config_dict[name]

def get_set_var1(x, model, layer=0):
    """Compute the set variable of the first encoder block via ``m_1`` then ``m_2``.

    The input is rearranged from ``[B, nr_features, nr_loans, nr_timesteps]``
    to ``[B, nr_loans, nr_timesteps, nr_features]`` before being passed to
    ``model.encoder[0].m_1``; the first element of its output is then passed
    to ``model.encoder[0].m_2``.

    Args:
        x (Tensor): Input of shape ``[B, nr_features, nr_loans, nr_timesteps]``.
        model: Model exposing ``encoder[0].m_1`` and ``encoder[0].m_2``.
        layer (int): Only ``0`` is supported.

    Returns:
        Tensor: Output of ``m_2``.  When its last dimension has size 1 it is
        expanded (as a view, no copy) to size 2 so callers can always index
        two components.

    Raises:
        NotImplementedError: If ``layer`` is not 0.  This branch used to fall
            through and return ``None`` silently.
    """
    if layer != 0:
        raise NotImplementedError(
            f"get_set_var1 only supports layer=0 (got layer={layer})"
        )

    first_block = model.encoder[0]

    # [B, F, L, T] -> [B, L, F, T] -> [B, L, T, F]
    x_in = torch.transpose(x, 1, 2)
    x_in = torch.transpose(x_in, 2, 3)

    x_1, _ = first_block.m_1(x_in)
    x_2 = first_block.m_2(x_1)
    if x_2.shape[-1] == 1:
        # Duplicate the single component; expand() with three sizes requires a 3-D x_2.
        x_2 = x_2.expand(-1, -1, 2)
    return x_2

def get_set_var(x, model, layer=0, nr_units=1000):
    """Extract the set variable (latent factors) from the specified layer.

    Only the ``m_2`` module of the selected set encoder is applied; ``m_1`` is
    deliberately bypassed.

    For ``layer == 0``:
        * ``x`` is transposed to ``[B, nr_loans, nr_timesteps, nr_features]``.
        * ``model.encoder[0].m_2`` is applied directly to the transposed input.

    For ``layer >= 1``:
        * ``x`` is fed through ``model.encoder`` (its full forward pass).
        * The result is passed through the first ``layer - 1`` blocks of
          ``model.model.layers`` (each called with ``nr_units=nr_units``).
        * ``m_2`` of the set encoder of block ``layer - 1`` is applied to that
          representation with a leading dimension added.

    Args:
        x (Tensor): Input tensor of shape [B, nr_features, nr_loans, nr_timesteps].
        model: The model instance.
        layer (int): Which set module to use (0 = top-level encoder,
            ``k >= 1`` = block ``k - 1`` of ``model.model.layers``).
        nr_units (int): Forwarded as the ``nr_units`` keyword argument to every
            block that is run before the target block.

    Returns:
        Tensor: The computed set variable.

    Raises:
        ValueError: If ``layer`` refers to a block the model does not have.
            The check runs before any forward pass, so an out-of-range layer
            is reported consistently instead of surfacing as an indexing error
            from the block loop.
    """
    if layer == 0:
        # [B, F, L, T] -> [B, L, F, T] -> [B, L, T, F]
        x_in = torch.transpose(x, 1, 2)
        x_in = torch.transpose(x_in, 2, 3)
        return model.encoder[0].m_2(x_in)

    blocks = model.model.layers
    target_idx = layer - 1
    if target_idx >= len(blocks):
        raise ValueError(f"Requested layer {layer} but the model has only "
                         f"{len(model.model.layers)} residual block(s).")

    # Run the complete encoder forward pass first.
    rep, _ = model.encoder(x)

    # Pass the representation through the blocks that precede the target one
    # (none when layer == 1).
    for i in range(target_idx):
        rep, _ = blocks[i](rep, nr_units=nr_units)

    # Apply only m_2 of the target block's set encoder.
    set_encoder = blocks[target_idx].layer.set_encoder
    return set_encoder.m_2(rep.unsqueeze(0))

def get_first_nonzero_row(lagged_foreclosure):
    """
    Return the first row of ``lagged_foreclosure[0, :, :]`` that contains no zeros.

    Only the first entry along the leading dimension is inspected.

    Returns:
        torch.Tensor or None: The first row whose elements are all nonzero,
        or None when every row contains at least one zero.
    """
    plane = lagged_foreclosure[0, :, :]

    # True for every row that is nonzero in all of its columns.
    fully_nonzero = (plane != 0).all(dim=1)

    # Indices (along the row dimension) of the qualifying rows, in ascending order.
    candidates = torch.nonzero(fully_nonzero, as_tuple=True)[0]
    if candidates.numel() == 0:
        return None

    return plane[candidates[0].item()]

def plot_set_var(train, val, test, model):
    """Plot the layer-0 set variable against lagged rates and print correlations.

    For every batch (batch size 1) of ``test`` the function:
      * runs the model and extracts the set variable with ``get_set_var``;
      * plots the z-scored lagged foreclosure / prepayment series, the first
        two set-variable components and the "ground truth" lagged prepayment
        rate, overwriting ``./visualizations/set_var_corelogic.png``;
      * prints Pearson correlations between these series.
    After the loop the per-batch correlations are averaged and printed.

    Side effects: sets ``test.steps_per_epoch`` and ``test.return_start_time``,
    moves ``model`` to the selected device, puts it in eval mode and resets
    ``model._state`` before every batch.

    Args:
        train: Unused.
        val: Unused.
        test: Dataset exposing ``X`` and yielding ``(x, y, valid_indices,
            start_time)`` once ``return_start_time`` is set.  ``x`` is expected
            as ``[B, nr_features, nr_loans, nr_timesteps]`` with the last two
            feature channels holding the lagged foreclosure and lagged
            prepayment rates (per the slicing below).
        model: The model to inspect.

    NOTE(review): ``get_first_nonzero_row`` may return ``None``; that case is
    not handled here and raises ``AttributeError`` as before.
    """
    test.steps_per_epoch = test.X.shape[-1]
    test.return_start_time = True
    dataloader = DataLoader(test, batch_size=1, shuffle=False, drop_last=True)

    # Loop-invariant setup: pick the device and prepare the model once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Experimental switch: treat every loan as its own batch element.
    all_units_in_batch_dim = False

    # Per-batch correlations, averaged after the loop.
    corrs_set0_foreclosure = []
    corrs_set0_prepayment = []
    corrs_set1_foreclosure = []
    corrs_set1_prepayment = []
    corrs_set0_prepay_gt = []
    corrs_set1_prepay_gt = []
    corrs_prepay_prepay_gt = []

    # Make sure the output directory exists before the first savefig call.
    os.makedirs("./visualizations", exist_ok=True)

    for x, y, _valid_indices, _start_time in dataloader:
        x = x.to(device)
        y = y.to(device)
        # Split off the last two feature channels (lagged foreclosure / prepayment).
        x, lagged_foreclosure, lagged_prepayment = x[:, :-2, :, :], x[:, -2, :, :], x[:, -1, :, :]
        if all_units_in_batch_dim:
            # Reshape to [B*n_loans, n_features, 1, n_time_steps]
            x = x.permute(0, 2, 1, 3).reshape(-1, x.shape[1], 1, x.shape[3])

        # Reset the model state so batches do not influence each other.
        model._state = None

        with torch.no_grad():
            y_pred = model((x, {}))[0]
            set_var = get_set_var(x, model)

            # The predictions are not used further; the view() doubles as a
            # check that the model output matches the target shape.
            y_pred = F.softmax(y_pred, dim=3)
            y_pred = y_pred.view(y.shape)
            if all_units_in_batch_dim:
                # Reshape back to [B, n_loans, n_time_steps, n_classes]
                y_pred = y_pred.reshape(-1, y.shape[1], y.shape[2], y.shape[3])

        # "Ground truth" lagged prepayment rate: feature channel 0 summed over
        # loans, divided by the number of loans, dropping the last time step.
        lagged_prepayment_gt = torch.sum(x[0, 0, :, :-1], dim=0) / x.shape[2]
        lagged_prepay_norm = (lagged_prepayment_gt - torch.mean(lagged_prepayment_gt)) / torch.std(lagged_prepayment_gt)
        lagged_prepay_norm = lagged_prepay_norm.cpu().numpy().flatten()
        prepay_gt = lagged_prepayment_gt.cpu().numpy().flatten()

        # Drop the first time step of the lagged series to align with the above.
        lagged_foreclosure = lagged_foreclosure[:, :, 1:]
        lagged_prepayment = lagged_prepayment[:, :, 1:]
        l_fc = get_first_nonzero_row(lagged_foreclosure).cpu().numpy().flatten()
        l_pr = get_first_nonzero_row(lagged_prepayment).cpu().numpy().flatten()
        norm_l_fc = (l_fc - np.mean(l_fc)) / np.std(l_fc)
        norm_l_pr = (l_pr - np.mean(l_pr)) / np.std(l_pr)

        plt.plot(norm_l_fc, label="Lagged Foreclosure Rate")
        plt.plot(norm_l_pr, label="Lagged Prepayment Rate")

        set_0 = set_var[0, 1:, 0]
        print(f"Set Variable Std: {torch.std(set_0, dim=0)}")
        set_1 = set_var[0, 1:, 1]
        norm_0 = (set_0 - torch.mean(set_0)) / torch.std(set_0)
        norm_1 = (set_1 - torch.mean(set_1)) / torch.std(set_1)
        plt.plot(norm_0.cpu().numpy().flatten(), label="Set Variable 0")
        plt.plot(norm_1.cpu().numpy().flatten(), label="Set Variable 1")
        plt.plot(lagged_prepay_norm, label="Ground Truth Lagged Prepayment Rate")
        plt.legend()
        plt.savefig("./visualizations/set_var_corelogic.png")
        plt.close()

        set_0_np = set_0.cpu().numpy().flatten()
        set_1_np = set_1.cpu().numpy().flatten()
        corr_0_foreclosure = np.corrcoef(set_0_np, l_fc)[0][1]
        corr_0_prepayment = np.corrcoef(set_0_np, l_pr)[0][1]
        corr_1_foreclosure = np.corrcoef(set_1_np, l_fc)[0][1]
        corr_1_prepayment = np.corrcoef(set_1_np, l_pr)[0][1]
        corr_0_prepay_gt = np.corrcoef(set_0_np, prepay_gt)[0][1]
        corr_1_prepay_gt = np.corrcoef(set_1_np, prepay_gt)[0][1]
        corr_prepay_prepay_gt = np.corrcoef(l_pr, prepay_gt)[0][1]

        print(f"Correlation between set variable 0 and lagged foreclosure rate: {corr_0_foreclosure}")
        print(f"Correlation between set variable 0 and lagged prepayment rate: {corr_0_prepayment}")
        print(f"Correlation between set variable 0 and ground truth lagged prepay rate: {corr_0_prepay_gt}")
        print(f"Correlation between lagged prepayment rate and ground truth lagged prepayment rate: {corr_prepay_prepay_gt}")

        corrs_set0_foreclosure.append(corr_0_foreclosure)
        corrs_set0_prepayment.append(corr_0_prepayment)
        corrs_set1_foreclosure.append(corr_1_foreclosure)
        corrs_set1_prepayment.append(corr_1_prepayment)
        corrs_set0_prepay_gt.append(corr_0_prepay_gt)
        corrs_set1_prepay_gt.append(corr_1_prepay_gt)
        corrs_prepay_prepay_gt.append(corr_prepay_prepay_gt)

    print(f"Average correlation between set variable 0 and lagged foreclosure rate: {np.mean(corrs_set0_foreclosure)}")
    print(f"Average correlation between set variable 0 and lagged prepayment rate: {np.mean(corrs_set0_prepayment)}")
    print(f"Average correlation between set variable 1 and lagged foreclosure rate: {np.mean(corrs_set1_foreclosure)}")
    print(f"Average correlation between set variable 1 and lagged prepayment rate: {np.mean(corrs_set1_prepayment)}")
    print(f"Average correlation between set variable 0 and ground truth lagged prepay rate: {np.mean(corrs_set0_prepay_gt)}")
    print(f"Average correlation between set variable 1 and ground truth lagged prepay rate: {np.mean(corrs_set1_prepay_gt)}")
    print(f"Average correlation between lagged prepay rate and ground truth lagged prepay rate: {np.mean(corrs_prepay_prepay_gt)}")

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

def get_first_nonzero_row_3d(tensor_3d: torch.Tensor):
    """
    Return the first entry along dim 0 whose values are all nonzero.

    Given a tensor of shape [num_loans, time_steps, embedding_dim], find the
    first loan ``i`` for which every element of ``tensor_3d[i]`` is nonzero.

    Works for non-contiguous inputs (e.g. the result of ``permute`` or
    ``transpose``) because the flattening uses ``reshape`` rather than
    ``view``.

    Args:
        tensor_3d: Tensor whose first dimension indexes the loans.

    Returns:
        The slice ``tensor_3d[i]`` (shape [time_steps, embedding_dim]) for the
        first qualifying ``i``, or None if no such entry exists.

    Example:
        If tensor_3d has shape [10, 50, 2], each loan i in [0..9] is checked.
        The first i with all of tensor_3d[i, :, :] != 0 yields tensor_3d[i].
    """
    # Collapse everything but the loan dimension: [num_loans, time_steps * embedding_dim].
    flat = tensor_3d.reshape(tensor_3d.shape[0], -1)

    # mask[i] is True when loan i contains no zero at all.
    mask = (flat != 0).all(dim=1)

    idx_nonzero = torch.where(mask)[0]
    if idx_nonzero.numel() == 0:
        return None  # No fully nonzero entry found
    return tensor_3d[idx_nonzero[0].item()]

def get_mean_set_vars_corelogic(test_loader, EMB_DIM, T, model, get_set_var_fn, model_layer, nr_units=2500):
    """Average set variables and foreclosure/prepayment rates over a global timeline.

    For every batch the set variable is extended by four extra columns and the
    resulting rows are accumulated at their global time index
    (``start_time + local step``).  The column order is:

        set-variable columns, lagged foreclosure rate, lagged prepayment rate,
        observed foreclosure rate, observed prepayment rate

    so ``EMB_DIM`` must equal the set-variable width plus 4.

    Args:
        test_loader: Iterable yielding ``(x, y, valid_indices, start_times)``
            with batch size 1 (``start_times.item()`` is called).  ``x`` is
            expected as ``[1, nr_features, nr_loans, nr_timesteps]`` with the
            last two feature channels holding the lagged rates.
        EMB_DIM (int): Number of rows of the returned array.
        T (int): Length of the global timeline; steps at or beyond ``T`` are dropped.
        model: The model; moved to the selected device and put in eval mode.
        get_set_var_fn: Called as ``get_set_var_fn(x, model, model_layer,
            nr_units=nr_units)``.  Element ``[0]`` of its result must be 2-D
            ``[time_steps, set_dim]`` (it is concatenated along dim 1 with
            ``[time_steps - 1, 1]`` columns).
        model_layer (int): Forwarded to ``get_set_var_fn``.
        nr_units (int): Forwarded to ``get_set_var_fn`` (previously hard-coded
            to 2500, which remains the default).

    Returns:
        np.ndarray: Array of shape ``(EMB_DIM, T)`` with the per-time-step
        means; time steps that never received data are 0.

    Batches whose lagged channels contain no fully nonzero row are skipped.
    NOTE(review): time steps with zero active loans yield inf/NaN observed
    rates, which propagate into the averages -- unchanged from before.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    sums = np.zeros((EMB_DIM, T), dtype=np.float32)
    counts = np.zeros((EMB_DIM, T), dtype=np.float32)

    with torch.no_grad():
        for x, _y, _valid_indices, start_times in test_loader:
            # Batch size is 1, so start_times holds a single value.
            start_time = start_times.item()
            x = x.to(device)
            # Split off the last two feature channels (lagged foreclosure / prepayment).
            x, lagged_foreclosure, lagged_prepayment = x[:, :-2, :, :], x[:, -2, :, :], x[:, -1, :, :]

            # Drop the batch dimension and continue on the CPU.
            set_var_loans = get_set_var_fn(x, model, model_layer, nr_units=nr_units).cpu()[0]

            lagged_foreclosure = get_first_nonzero_row(lagged_foreclosure)
            lagged_prepayment = get_first_nonzero_row(lagged_prepayment)
            if lagged_foreclosure is None or lagged_prepayment is None:
                # Skip if no fully nonzero row
                continue
            lagged_foreclosure = lagged_foreclosure.cpu()
            lagged_prepayment = lagged_prepayment.cpu()

            # Observed rates per time step.  NOTE(review): assumes the first 7
            # feature channels are state indicators, with channel 5 marking
            # foreclosure and channel 0 prepayment -- confirm against the dataset.
            active_loans = torch.sum(torch.max(x[0, :7, :, :], dim=0)[0], dim=0)
            foreclosed_loans = torch.sum(x[0, 5, :, :], dim=0)
            prepaid_loans = torch.sum(x[0, 0, :, :], dim=0)
            observed_foreclosure_rate = foreclosed_loans / active_loans
            observed_prepayment_rate = prepaid_loans / active_loans

            # Set variable and observed rates drop their last step, the lagged
            # series drop their first step, so all columns have equal length.
            rows = torch.cat([
                set_var_loans[:-1],
                lagged_foreclosure[1:].unsqueeze(-1),
                lagged_prepayment[1:].unsqueeze(-1),
                observed_foreclosure_rate[:-1].unsqueeze(-1).cpu(),
                observed_prepayment_rate[:-1].unsqueeze(-1).cpu(),
            ], dim=1).numpy()

            # Accumulate into sums/counts at the global time index.
            for local_t, row in enumerate(rows):
                global_t = start_time + local_t
                if global_t >= T:
                    break
                sums[:, global_t] += row
                counts[:, global_t] += 1

    # Avoid division by zero for time steps that never received data.
    counts = np.maximum(counts, 1e-6)
    return sums / counts


def plot_set_vars(
    test_loader,
    model,
    get_set_var_fn,
    save_path="./visualizations/set_var_over_test_may19.pdf",
    model_layer=0
):
    """
    Plot the time-averaged set variable against the lagged foreclosure rate.

    The averages come from ``get_mean_set_vars_corelogic`` (rows 0 and 1 are
    the set variables, row 2 the lagged and row 4 the observed foreclosure
    rate).  The timeline is trimmed to the span in which row 0 is nonzero.

    Then, for the averaged set variables:
      - The normalized set variables (rows 0 and 1) are rotated to find the best positive
        correlation with each standardized foreclosure rate (rows 2 and 4).
      - The rotated combination for the lagged rate is plotted together with that rate.
      - The x-axis is converted from a global month index (starting at January 1988)
        to yearly ticks.

    Args:
        test_loader: Data loader for the test set.
        model: The model instance.
        get_set_var_fn: A function that extracts set variables given an input.
        save_path (str): Where to save the final plot (PDF by default).  For a
            path-like value the parent directory is created if missing.
        model_layer (int): The layer to use when extracting set variables.

    Raises:
        ValueError: If no rotation of the set variables correlates positively
            with the lagged foreclosure rate (e.g. a constant series), since
            there is then nothing to plot.
    """
    # Number of months in the entire dataset.
    T = 420
    # Rows of the averaged array: 2 set variables + 4 rate series.
    EMB_DIM = 6

    averages = get_mean_set_vars_corelogic(
        test_loader, EMB_DIM, T, model, get_set_var_fn, model_layer
    )

    # Trim to the span between the first and last nonzero value of row 0.
    first_nonzero = np.argmax(np.abs(averages) > 0, axis=1)  # shape: (EMB_DIM,)
    start_time = first_nonzero[0]
    last_nonzero = start_time + np.nonzero(averages[0, start_time:])[0][-1]
    averages = averages[:, start_time : last_nonzero + 1]

    def standardize(row):
        """Z-score ``row``; a constant row is returned unchanged."""
        spread = np.std(row)
        return (row - np.mean(row)) / spread if spread > 0 else row

    # Row 2 is the lagged, row 4 the observed foreclosure rate.
    standardized_foreclosure_lagged = standardize(averages[2])
    standardized_foreclosure_obs = standardize(averages[4])

    # Correlation with the unrotated row 0, for reference only.
    corr_lagged = np.corrcoef(averages[0], standardized_foreclosure_lagged)[0, 1]
    print("Correlation between Set Var 0 and Lagged Foreclosure Rate:", corr_lagged)
    corr_obs = np.corrcoef(averages[0], standardized_foreclosure_obs)[0, 1]
    print("Correlation between Set Var 0 and Observed Foreclosure Rate:", corr_obs)

    # Normalize set variables 0 and 1.
    set_0_norm = (averages[0] - np.mean(averages[0])) / np.std(averages[0])
    set_1_norm = (averages[1] - np.mean(averages[1])) / np.std(averages[1])

    def best_rotation_single(pred, true):
        """
        Given a 2xT matrix 'pred' (rows 0,1 are the two set variables)
        and a target vector 'true' (length T), find the angle theta that
        yields the best (positive) Pearson correlation:
           comb = cos(theta)*pred[0] + sin(theta)*pred[1].

        Returns (theta, correlation, comb); comb is None when no angle gives
        a strictly positive correlation.
        """
        thetas = np.linspace(0, 2 * np.pi, 361)  # 1 degree resolution
        best_corr = 0.0
        best_theta = 0.0
        best_comb = None
        for theta in thetas:
            comb = np.cos(theta) * pred[0] + np.sin(theta) * pred[1]
            r = pearsonr(comb, true)[0]
            if r > best_corr:
                best_corr = r
                best_theta = theta
                best_comb = comb
        return best_theta, best_corr, best_comb

    pred_mat = np.vstack([set_0_norm, set_1_norm])

    # Best rotation for the lagged and the observed foreclosure rate.  Only the
    # lagged one is plotted; the observed one is reported via its correlation.
    _, best_corr_lagged, rotated_lagged = best_rotation_single(pred_mat, standardized_foreclosure_lagged)
    _, best_corr_obs, _rotated_obs = best_rotation_single(pred_mat, standardized_foreclosure_obs)

    print("Best positive correlation (Lagged Foreclosure):", best_corr_lagged)
    print("Best positive correlation (Observed Foreclosure):", best_corr_obs)

    if rotated_lagged is None:
        raise ValueError(
            "No rotation of the set variables correlates positively with the "
            "lagged foreclosure rate; nothing to plot."
        )

    # Time axis: convert the global month index to fractional years.
    L = averages.shape[1]
    x_months = np.arange(start_time, start_time + L)
    x_years = 1988 + x_months / 12.0

    # One integer tick per year covering the plotted span.
    min_year = int(np.floor(x_years[0]))
    max_year = int(np.ceil(x_years[-1]))
    year_ticks = np.arange(min_year, max_year + 1)

    # -------------------------------------------------------------
    # Create and style the figure
    # -------------------------------------------------------------
    _, ax = plt.subplots(figsize=(6, 4))  # Typical single-column figure

    ax.plot(
        x_years,
        rotated_lagged,
        label=f"Set Variable, r={best_corr_lagged:.2f}",
        color="C0",
    )
    ax.plot(
        x_years,
        standardized_foreclosure_lagged,
        label="Lagged Foreclosure Rate",
        color="C1",
    )

    ax.set_title("Set Variable vs. Foreclosure Rate", fontsize=16)
    ax.set_xlabel("Year", fontsize=14)
    ax.set_ylabel("Normalized Value", fontsize=14)

    ax.set_xticks(year_ticks)
    ax.set_xticklabels([str(yr) for yr in year_ticks], rotation=45)

    ax.legend(loc="best")

    # Minor ticks on; only major grid lines are drawn (see rcParams).
    ax.minorticks_on()

    plt.tight_layout()

    # -------------------------------------------------------------
    # Save the figure (creating the target directory for path-like targets)
    # -------------------------------------------------------------
    if isinstance(save_path, (str, os.PathLike)):
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()


def get_config(task):
    """Return the dataset configuration dict for ``task``.

    ``task == "corelogic"`` yields the CoreLogic loan-dataset configuration;
    any other value yields the synthetic time-series configuration.  Paths
    are built from the module-level ``BASE_PATH``.
    """
    if task != "corelogic":
        # Synthetic time-series task (the fallback for every non-CoreLogic name).
        return {
            "_name_": "timeseries_synthetics",
            "num_states": 3,  # total number of states
            "num_terminal_states": 1,  # number of terminal states
            # If false, neither the macro variable nor the loan specific features are included
            "use_feature": True,
            "simulation_steps": 100,  # length of each sequence
            "loan_pool_size": 1000,  # pool size
            "load_saved_data": True,
            "saved_data_directory": f"{BASE_PATH}/data/mortgage_new2/",
            "save_data": True,
            "num_seq": 20,  # number of sequences
            "val_split": 0.1,  # fraction of samples in the validation split
            "test_split": 0.1,  # fraction of samples in the test split
            "dataset_name": "timeseries_synthetics",
            "nr_steps": 10,  # number of different starting points
            "forecasting": False,
            "n_obs_partial_hawkes": [1, 2, 5, 10, 50, 100, 200, 500, 1000],
            # alternative: partially_observed_hawkes_kalman
            "partial_obs_method": "partially_observed_hawkes",
            "use_random_input_size": False,
            "random_input_size_options": [2, 5, 10, 50, 100, 200, 500, 1000],
            "random_input_size_probabilities": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3],
            "forecasting_horizon": 1,
            "lookback_horizon": 1,
            "generator": {
                "level": "supereasy_2d",
                "path_dependency_dimension": 2,
                "h_kappa": 100,  # controls the path dependency
                "mu": 0.001,
                "alpha": 0.004,
                "beta": 0.5,
                "debug": False,  # if true the dynamics are simplified
                "hard": True,  # only used if debug=True. If False, the dynamics are deterministic
            },
        }

    origination_columns = [
        "fico_score_at_origination",
        "original_balance",
        "initial_interest_rate",
        "original_ltv",
    ]
    performance_columns = [
        "current_balance",
        "current_interest_rate",
        "scheduled_monthly_pi",
        "scheduled_principal",
        "mba_days_delinquent",
    ]
    # Trailing numbers are the original author's per-feature annotations
    # (apparently input widths, totalling 52 or 54 -- unverified).
    feature_set = [
        "current_state",  # 8
        "fico_score_at_origination",  # 1
        "original_balance",  # 1
        "initial_interest_rate",  # 1
        "original_ltv",  # 1
        "unemployment_rate",  # 1
        "national_mortgage_rate",  # 1
        "current_balance",  # 2
        "current_interest_rate",  # 2
        "scheduled_monthly_pi",  # 2
        "scheduled_principal",  # 2
        "mba_days_delinquent",  # 2
        "inferred_collateral_type",  # 2
        "convertible_flag",  # 2
        "pool_insurance_flag",  # 2
        "io_flag",  # 2
        "prepay_penalty_flag",  # 2
        "negative_amortization_flag",  # 2
        "buydown_flag",  # 2
        "loan_age",  # 4
        "original_term",  # 2
        "times_30dd",  # 2
        "times_60dd",  # 1
        "times_90dd",  # 1
        "times_current",  # 1
        "times_foreclosure",  # 1
        "lagged_foreclosure_rate",  # 1 or 2
        "lagged_prepayment_rate",  # 1 or 2
    ]

    dataset_config = {
        "path_origination": "/share/data/llm_mortgages/data/filtered_origination_data_top_4_zips.csv",
        "path_performance": "/share/data/llm_mortgages/data/filtered_performance_data_top_4_zips.csv",
        "normalize_data": True,
        "database_size": 300000,
        "start_year": 1988,  # correct start year
        "end_year": 2023,  # correct end year
        "columns_to_normalize_origination": origination_columns,
        "columns_to_normalize_performance": performance_columns,
        "feature_set": feature_set,
        "nr_classes": 8,
        "verbose": True,
    }

    return {
        "_name_": "corelogic_loan_dataset",
        "dataset_config": dataset_config,
        "val_split": 0.1,
        "test_split": 0.35,
        "val_split_date": "2009-06",
        "test_split_date": "2009-12",
        "load_data": True,
        "save_data": False,
        # Author's open question: is this the correct data file (the "52" variant)?
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_52.npz",
        "max_to_sample": 4500,  # Total nr loans
        "nr_sampling_timesteps": 50,
        "nr_loans_to_sample": 2500,  # earlier values tried: 500, 6000
        "steps_per_epoch": 50,  # earlier values tried: 200, 20
        "sample_random_loan_index": True,
        "sample_random_time_index": True,
        "eval_mode": False,
    }
def get_second_set_var(hidden_var):
    """
    Return the first row of ``hidden_var`` that differs from row 0.

    Rows 1, 2, ... are compared element-wise against row 0 and the first one
    that is not identical is returned.

    Args:
        hidden_var (Tensor): 2D tensor indexed as [row, column].

    Returns:
        Tensor: The first row (1D) that is not element-wise equal to row 0.

    Raises:
        IndexError: If no row differs from row 0 (this includes a
            single-row input).
    """
    reference = hidden_var[0, :]
    # Bounded scan: never index past the last row.
    for row in range(1, hidden_var.shape[0]):
        candidate = hidden_var[row, :]
        if not torch.all(candidate == reference):
            return candidate
    raise IndexError(
        f"no row differs from row 0 among {hidden_var.shape[0]} rows; "
        "a second distinct latent factor cannot be selected"
    )


def normalize(x):
    """
    Standardize an array to zero mean and unit standard deviation.

    When the standard deviation is exactly zero the division is skipped and
    only the mean is subtracted.
    """
    centered = x - np.mean(x)
    spread = np.std(x)
    return centered if spread == 0 else centered / spread

def _best_rotation_for_target(pred, target, thetas, second_row):
    """Scan `thetas` and return (theta, r, comb) with the largest signed Pearson r against `target`."""
    best_theta, best_r, best_comb = None, None, None
    best_score = -np.inf
    for theta in thetas:
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        if second_row:
            # Second row of the 2D rotation applied to the two channels.
            comb = -sin_t * pred[0, :] + cos_t * pred[1, :]
        else:
            # First row of the 2D rotation applied to the two channels.
            comb = cos_t * pred[0, :] + sin_t * pred[1, :]
        r, _ = pearsonr(comb, target)
        # A NaN correlation never compares greater, so such angles are skipped.
        if r > best_score:
            best_score = r
            best_theta, best_r, best_comb = theta, r, comb
    return best_theta, best_r, best_comb


def best_rotation_separate(pred, true1, true2, n_angles=5):
    """
    Optimizes the rotation for each predicted latent factor separately.

    For the first latent factor, the candidates are
    comb1 = cos(theta)*pred[0,:] + sin(theta)*pred[1,:], scored against true1.
    For the second latent factor, the candidates are
    comb2 = -sin(theta)*pred[0,:] + cos(theta)*pred[1,:], scored against true2.
    The two searches are independent, so the returned angles may differ.

    The *signed* Pearson correlation is maximized. Rotating by an extra pi
    negates a candidate, so whenever the grid contains both theta and
    theta + pi (true for the default) this selects the largest |r| with the
    sign chosen to be positive.

    The grid is np.linspace(0, 2*pi, n_angles), endpoints included. With the
    default of 5 only quarter turns are tried, so each candidate is
    essentially +/- one of the two predicted channels. Pass n_angles=361 for
    a 1-degree grid.

    Args:
        pred (np.ndarray): Predicted latent factors; rows 0 and 1 are the two channels.
        true1 (np.ndarray): True latent factor 1, same length as a row of pred.
        true2 (np.ndarray): True latent factor 2, same length as a row of pred.
        n_angles (int): Number of grid points on [0, 2*pi]. Defaults to 5.

    Returns:
        best_theta1 (float): The best rotation angle (radians) for latent factor 1.
        best_r1 (float): Pearson correlation (for best_theta1) for latent factor 1.
        best_comb1 (np.ndarray): Rotated prediction for latent factor 1.
        best_theta2 (float): The best rotation angle (radians) for latent factor 2.
        best_r2 (float): Pearson correlation (for best_theta2) for latent factor 2.
        best_comb2 (np.ndarray): Rotated prediction for latent factor 2.

        The three entries of a factor are None if no angle produced a
        comparable correlation (e.g. pearsonr returned NaN for every angle).
    """
    thetas = np.linspace(0, 2 * np.pi, n_angles, endpoint=True)

    best_theta1, best_r1, best_comb1 = _best_rotation_for_target(
        pred, true1, thetas, second_row=False
    )
    best_theta2, best_r2, best_comb2 = _best_rotation_for_target(
        pred, true2, thetas, second_row=True
    )

    return best_theta1, best_r1, best_comb1, best_theta2, best_r2, best_comb2

def plot_hidden_var_synthetic(X, hidden_var, model, layer=0):
    """
    Plots, for every sample, the normalized predicted set variables, the normalized
    true latent factors and the independently rotated predictions, and saves one PNG
    per sample.

    The predicted set variable is obtained from get_set_var(x, model, layer=layer).
    True latent factor 1 is the first unit of hidden_var; true latent factor 2 is the
    first unit that differs from it (see get_second_set_var). All series are
    z-score normalized before the rotation search.

    Args:
        X (Tensor): Input tensor of shape [B, nr_features, nr_loans, nr_timesteps].
        hidden_var (Tensor): True latent factors of shape [B, nr_units, nr_timesteps].
        model: The model instance. It is moved to the available device, put in
            eval mode, and its `_state` attribute is reset to None.
        layer (int): Forwarded to get_set_var as `layer`.
    """
    # Unpacking four values also checks that X is 4-dimensional.
    nr_samples, _, _, _ = X.shape
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model._state = None
    model.to(device)
    model.eval()

    with torch.no_grad():
        for i in range(nr_samples):
            x = X[i:i+1, :, :, :].to(device)
            # Predicted set variable for this sample (batch dimension dropped).
            pred_set_var = get_set_var(x, model, layer=layer)[0, :, :]

            # True latent factors.
            true_set_var1 = hidden_var[i, 0, :]                       # shape: (T,)
            true_set_var2 = get_second_set_var(hidden_var[i, :, :])     # shape: (T,)

            # Move to CPU and convert to NumPy arrays. The transpose makes rows 0 and 1
            # the two set-variable channels used below -- presumably get_set_var
            # returns (1, T, 2); TODO confirm.
            pred_set_var = pred_set_var.cpu().numpy().T
            true_set_var1 = true_set_var1.cpu().numpy()
            true_set_var2 = true_set_var2.cpu().numpy()

            # Normalize the predicted set variables and true latent factors.
            pred_norm = np.zeros_like(pred_set_var)
            pred_norm[0, :] = normalize(pred_set_var[0, :])
            pred_norm[1, :] = normalize(pred_set_var[1, :])
            true_set_var1_norm = normalize(true_set_var1)
            true_set_var2_norm = normalize(true_set_var2)

            # Perform independent rotation searches.
            (best_theta1, r1, rotated_pred1,
             best_theta2, r2, rotated_pred2) = best_rotation_separate(
                pred_norm, true_set_var1_norm, true_set_var2_norm
            )

            print(f"Sample {i} (Layer {layer}):")
            print(f"  Best theta1 for latent factor 1: {best_theta1:.2f} rad, correlation: {r1:.2f}")
            print(f"  Best theta2 for latent factor 2: {best_theta2:.2f} rad, correlation: {r2:.2f}")

            plt.figure(figsize=(12, 10))

            # Panel 1: normalized original predicted set variables.
            plt.subplot(3, 1, 1)
            plt.plot(pred_norm[0, :], label="Normalized Pred Set Var 1")
            plt.plot(pred_norm[1, :], label="Normalized Pred Set Var 2")
            plt.title("Normalized Predicted Set Variables")
            plt.legend()

            # Panel 2: normalized true latent factors.
            plt.subplot(3, 1, 2)
            plt.plot(true_set_var1_norm, label="Normalized True Latent Factor 1")
            plt.plot(true_set_var2_norm, label="Normalized True Latent Factor 2")
            plt.title("Normalized True Latent Factors")
            plt.legend()

            # Panel 3: rotated predictions.
            plt.subplot(3, 1, 3)
            plt.plot(rotated_pred1, label=f"Rotated Pred 1 (theta={best_theta1:.2f}, r={r1:.2f})")
            plt.plot(rotated_pred2, label=f"Rotated Pred 2 (theta={best_theta2:.2f}, r={r2:.2f})")
            plt.title("Independently Rotated Predictions (Normalized)")
            plt.legend()

            plt.suptitle(f"Latent Factors and Rotated Predictions for Sample {i}, Layer {layer}")
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            # Use the sample index in the file name so each sample gets its own file
            # instead of every iteration overwriting "sample_0".
            plt.savefig(f"{BASE_PATH}/scripts/notebooks/true_loss_level/visualizations/set_var_synthetic_sample_{i}_layer_{layer}.png")
            plt.close()
            
            # Optionally, pause here to inspect one sample interactively.
            

def _correlation_table_to_latex(correlation_table, nr_obs_list, layer_range):
    """Render the table as a LaTeX string, bolding each column's best value and underlining the second best."""
    ncols = len(nr_obs_list)

    # Best and second-best value per column (NaN when the whole column is NaN).
    best_vals = np.empty(ncols)
    second_vals = np.empty(ncols)
    for j in range(ncols):
        col = correlation_table[:, j]
        if np.all(np.isnan(col)):
            best_vals[j] = np.nan
            second_vals[j] = np.nan
            continue
        best_vals[j] = np.nanmax(col)
        # Mask every occurrence of the best value so the next maximum is the runner-up.
        runner_up = col.copy()
        runner_up[np.abs(col - best_vals[j]) < 1e-6] = -np.inf
        second_vals[j] = np.nanmax(runner_up)

    parts = [
        "\\begin{table}[ht]\n",
        "\\centering\n",
        ("\\caption{Average correlation between the set variable at layer \\textit{i} "
         "and the true latent variable. Best values in each column are in bold, "
         "with the second best underlined.}\n"),
        "\\label{tab:interpretability}\n",
        "\\begin{tabular}{l" + "c" * ncols + "}\n",
        "\\toprule\n",
        # Header row.
        "Layer / n\\_obs" + "".join(f" & {n_obs}" for n_obs in nr_obs_list) + " \\\\\n",
        "\\midrule\n",
    ]

    # Data rows.
    for li, layer in enumerate(layer_range):
        cells = []
        for oi in range(ncols):
            corr_val = correlation_table[li, oi]
            if np.isnan(corr_val):
                cells.append("-")
                continue
            formatted_val = f"{corr_val:.3f}"
            if np.abs(corr_val - best_vals[oi]) < 1e-6:
                cells.append(f"\\textbf{{{formatted_val}}}")
            elif np.abs(corr_val - second_vals[oi]) < 1e-6:
                cells.append(f"\\underline{{{formatted_val}}}")
            else:
                cells.append(formatted_val)
        parts.append(f"{layer}" + "".join(f" & {cell}" for cell in cells) + " \\\\\n")

    parts += ["\\bottomrule\n", "\\end{tabular}\n", "\\end{table}"]
    return "".join(parts)


def create_correlation_table(nr_obs_list, layer_range, X, hidden_var, model, n_draws=10):
    """
    Creates a correlation table with rows corresponding to layers (from layer_range)
    and columns corresponding to the number of observations (from nr_obs_list) used
    when subsampling the dimension=2 (nr_units) of X.

    For each combination of layer and n_obs:
      - Randomly subsample n_obs units (without replacement, indices sorted); a fresh
        subsample is drawn for every sample of every draw.
      - Compute the predicted set variable (using get_set_var) from the subsampled X.
      - Subsample the corresponding rows of hidden_var and use:
            true1 = hidden_sub[0, :]    (i.e. the first sampled unit)
            true2 = get_second_set_var(hidden_sub)
      - Both the predicted and true latent factors are z-score normalized.
      - Run best_rotation_separate and keep the absolute value of whichever of the
        two returned correlations has the larger magnitude.
      - Average this value over the samples in the batch, then over n_draws.

    Columns whose n_obs exceeds the number of available units are filled with NaN
    and rendered as "-" in the LaTeX table.

    The LaTeX table is written to
    "{BASE_PATH}/scripts/notebooks/true_loss_level/visualizations/correlation_table_synthetic.tex".

    Args:
        nr_obs_list (list of int): List of numbers of observations to sample (e.g., [1,2,5,10,...]).
        layer_range (list of int): List of layer indices (e.g., [0,1,2,3,4,5]).
        X (Tensor): Input tensor of shape [B, nr_features, nr_units, nr_timesteps].
        hidden_var (Tensor): True latent factors of shape [B, nr_units, nr_timesteps].
        model: The model instance; moved to the available device and put in eval mode.
        n_draws (int): Number of random subsampling draws to average over.

    Returns:
        correlation_table (np.ndarray): Array of shape (len(layer_range), len(nr_obs_list)) containing the average correlation.
        latex_str (str): A raw LaTeX table (as a string) that you can copy into Overleaf.
    """
    # Local names deliberately avoid `F`, which is the module alias for torch.nn.functional.
    num_samples, _, total_units, _ = X.shape
    correlation_table = np.zeros((len(layer_range), len(nr_obs_list)))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Each layer is a row of the table, each n_obs a column.
        for li, layer in enumerate(layer_range):
            for oi, n_obs in enumerate(nr_obs_list):
                # If there are not enough units to subsample, record NaN.
                if n_obs > total_units:
                    correlation_table[li, oi] = np.nan
                    continue

                draw_corrs = []  # Per-draw batch averages.
                for _ in range(n_draws):
                    sample_corrs = []  # Per-sample correlations within this draw.
                    for i in range(num_samples):
                        # Randomly select n_obs unit indices, sorted for consistency.
                        indices = np.sort(
                            np.random.choice(total_units, size=n_obs, replace=False)
                        )

                        # Subsample along dimension=2 for X and dimension=1 for hidden_var.
                        # The input is moved to the model's device, matching the other
                        # plotting functions in this file.
                        x_sub = X[i:i+1, :, indices, :].to(device)
                        h_sub = hidden_var[i, indices, :]

                        # Predicted set variable for the subsample, batch dimension dropped.
                        pred_set_var = get_set_var(
                            x_sub, model, layer=layer, nr_units=len(indices)
                        )[0, :, :]

                        # First sampled unit is latent factor 1; the first differing
                        # unit is latent factor 2.
                        true1 = h_sub[0, :]
                        true2 = get_second_set_var(h_sub)

                        # Convert to NumPy arrays (if not already). The transpose makes
                        # rows 0 and 1 the two predicted channels -- presumably
                        # get_set_var returns (1, T, 2); TODO confirm.
                        pred_np = pred_set_var.cpu().numpy().T if torch.is_tensor(pred_set_var) else pred_set_var
                        true1_np = true1.cpu().numpy() if torch.is_tensor(true1) else true1
                        true2_np = true2.cpu().numpy() if torch.is_tensor(true2) else true2

                        # Normalize each predicted latent factor and each true latent factor.
                        pred_norm = np.zeros_like(pred_np)
                        pred_norm[0, :] = normalize(pred_np[0, :])
                        pred_norm[1, :] = normalize(pred_np[1, :])
                        true1_norm = normalize(true1_np)
                        true2_norm = normalize(true2_np)

                        _, r1, _, _, r2, _ = best_rotation_separate(pred_norm, true1_norm, true2_norm)

                        # Keep the correlation of larger magnitude.
                        best_corr = r1 if abs(r1) >= abs(r2) else r2
                        sample_corrs.append(abs(best_corr))

                    # Average correlation over the batch samples for this draw.
                    draw_corrs.append(np.mean(sample_corrs))
                # Average over the n_draws for this combination of layer and n_obs.
                correlation_table[li, oi] = np.mean(draw_corrs)

    latex_str = _correlation_table_to_latex(correlation_table, nr_obs_list, layer_range)

    # Save the LaTeX table to a file.
    with open(f"{BASE_PATH}/scripts/notebooks/true_loss_level/visualizations/correlation_table_synthetic.tex", "w") as f:
        f.write(latex_str)

    return correlation_table, latex_str


def plot_different_examples(nr_obs_list, layer_range, hidden_var, X, model):
    """
    Creates a figure with three horizontally arranged subplots. Each subplot corresponds
    to a different number of observed units (from nr_obs_list, which must have length==3).

    Only the first sample of the batch is used. For each subplot:
      - n_obs units are drawn at random (without replacement) and X and hidden_var are
        subsampled along the unit dimension.
      - The true latent factor is the first sampled unit:
            true1 = hidden_sub[0, :]
      - The second true latent factor is computed by get_second_set_var (used only for rotation).
      - For each of the two layers in layer_range, the predicted set variable is computed
        via get_set_var (with nr_units=n_obs), normalized, and passed to
        best_rotation_separate; the rotated channel whose correlation has the larger
        magnitude is kept.
      - The subplot shows the normalized true latent factor 1 (blue) and the selected
        prediction for layer_range[0] (green) and layer_range[1] (red).
      - A local legend (upper right) shows the two correlation values.
    A global legend below the figure names the three curves.

    The figure is saved as a PDF to
    "{BASE_PATH}/scripts/notebooks/true_loss_level/visualizations/plot_different_examples_synthetic.pdf".

    Args:
        nr_obs_list (list of int): List of three numbers of observed units (e.g. [5, 50, 200]).
        layer_range (list of int): List with two layer indices (e.g. [3, 5]).
        hidden_var (Tensor): True hidden variables of shape [B, nr_units, nr_timesteps].
        X (Tensor): Input tensor of shape [B, nr_features, nr_units, nr_timesteps].
        model: The model instance; moved to the available device and put in eval mode.
    """
    # Assert that inputs have the expected lengths.
    assert len(nr_obs_list) == 3, "nr_obs_list must have exactly 3 elements."
    assert len(layer_range) == 2, "layer_range must have exactly 2 elements."

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: use eval mode like the other analysis functions in this file.
    model.eval()

    # For illustration, use the first sample in the batch.
    sample_index = 0
    # Local names deliberately avoid `F`, which is the module alias for torch.nn.functional.
    _, _, total_units, T = X.shape

    # Prepare the figure with 1 row and 3 columns.
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    # Define colors.
    color_true1 = 'blue'
    color_pred_layer0 = 'green'
    color_pred_layer1 = 'red'
    layer_colors = (color_pred_layer0, color_pred_layer1)

    # Time axis shared by every subplot.
    t = np.arange(T)

    with torch.no_grad():
        for i, n_obs in enumerate(nr_obs_list):
            ax = axes[i]

            # Randomly subsample n_obs units from total_units.
            indices = np.random.choice(total_units, size=n_obs, replace=False)

            # Subsample the input X and hidden_var for sample_index; the input is
            # moved to the model's device once, outside the layer loop.
            x_sub = X[sample_index:sample_index+1, :, indices, :].to(device)
            h_sub = hidden_var[sample_index, indices, :]

            # True latent factor: the first sampled unit; second factor for rotation only.
            true1 = h_sub[0, :]
            true2 = get_second_set_var(h_sub)

            # Convert to NumPy and normalize.
            true1_np = true1.cpu().numpy() if torch.is_tensor(true1) else true1
            true2_np = true2.cpu().numpy() if torch.is_tensor(true2) else true2
            true1_norm = normalize(true1_np)
            true2_norm = normalize(true2_np)

            # Best rotated prediction and its correlation, keyed by layer.
            pred_best = {}
            for layer in layer_range:
                pred_set_var = get_set_var(x_sub, model, layer=layer, nr_units=n_obs)
                # Drop the batch dimension and transpose so rows 0 and 1 are the two
                # predicted channels -- presumably get_set_var returns (1, T, 2); TODO confirm.
                pred_np = pred_set_var[0, :, :].cpu().numpy().T

                # Normalize each channel of the prediction.
                pred_norm = np.zeros_like(pred_np)
                pred_norm[0, :] = normalize(pred_np[0, :])
                pred_norm[1, :] = normalize(pred_np[1, :])

                _, r1, rotated_pred1, _, r2, rotated_pred2 = best_rotation_separate(
                    pred_norm, true1_norm, true2_norm
                )
                # Select the channel with the higher absolute correlation.
                if abs(r1) >= abs(r2):
                    pred_best[layer] = (rotated_pred1, r1)
                else:
                    pred_best[layer] = (rotated_pred2, r2)

            # Plot the true latent factor, then one predicted curve per layer.
            ax.plot(t, true1_norm, color=color_true1, lw=3)
            for layer, color in zip(layer_range, layer_colors):
                ax.plot(t, pred_best[layer][0], color=color, lw=3)
            corr0 = pred_best[layer_range[0]][1]
            corr1 = pred_best[layer_range[1]][1]

            # Enable grid lines.
            ax.grid(True, linestyle='-', color='gray', alpha=0.5)

            # Set subplot title and axis labels.
            ax.set_title(f"Number of Observed Units = {n_obs}", fontsize=18)
            ax.set_xlabel("Time", fontsize=16)
            if i == 0:
                ax.set_ylabel("Normalized Value", fontsize=16)
            ax.tick_params(axis='both', labelsize=14)

            # Local legend for the predicted curves showing only the correlation values.
            local_handles = [Line2D([0], [0], color=color, lw=3) for color in layer_colors]
            local_labels = [f"r={corr0:.2f}", f"r={corr1:.2f}"]
            ax.legend(local_handles, local_labels, fontsize=14, loc='upper right')

    # Create a global legend below the figure.
    global_handles = [
        Line2D([0], [0], color=color_true1, lw=3),
        Line2D([0], [0], color=color_pred_layer0, lw=3),
        Line2D([0], [0], color=color_pred_layer1, lw=3)
    ]
    global_labels = [
        "True Latent Factor 1",
        f"Set Variable 1 at Layer {layer_range[0]}",
        f"Set Variable 1 at Layer {layer_range[1]}"
    ]
    fig.legend(global_handles, global_labels, loc='lower center', ncol=3, fontsize=16)

    fig.tight_layout(rect=[0, 0.08, 1, 0.95])

    # Save the figure as a high-quality PDF.
    fig.savefig(f"{BASE_PATH}/scripts/notebooks/true_loss_level/visualizations/plot_different_examples_synthetic.pdf",
                dpi=300, format='pdf', bbox_inches='tight')
    plt.close(fig)


def main():
    """
    Script entry point: load the data and model for the selected task, then run
    that task's plots.

    The task is chosen by editing the `task` assignment below. Each stage of a
    task is followed by a `breakpoint()` call, which drops into the debugger
    before the next stage starts.
    """
    task = 'synthetic'
    #task = 'corelogic'
    assert task in ['corelogic', 'synthetic']
    config = get_config(task)

    # The name selects the model configuration.
    if task == 'corelogic':
        dataset, date, nr_features = "test", "jan30", 50
        name = f"{dataset}-{date}-{nr_features}"
    elif task == 'synthetic':
        #name = 'logunif008-best-set'
        name = 'synthetic_task'

    train, val, test = get_dataset(config, task)
    print("Data Loaded")

    model_config = get_model_config(name=name)
    loader_fn = load_model_corelogic if task == 'corelogic' else load_model
    model = loader_fn(**model_config)
    print("Model loaded")

    if task == 'synthetic':
        # Stage 1: example curves for three pool sizes and two layers.
        nr_obs_list = [50, 200, 1000]
        layer_range = [3, 5]
        plot_different_examples(nr_obs_list, layer_range, test.hidden_path_var, test.X, model)
        breakpoint()

        # Stage 2: correlation table over more pool sizes and every layer.
        nr_obs_list = [20, 50, 100, 200, 500, 1000]
        layer_range = [0, 1, 2, 3, 4, 5]
        create_correlation_table(nr_obs_list, layer_range, test.X, test.hidden_path_var, model, n_draws=100)
        breakpoint()

        # Stage 3: per-sample plots on the training split.
        plot_hidden_var_synthetic(train.X, train.hidden_path_var, model, layer=3)

    if task == 'corelogic':
        test.steps_per_epoch = 100
        train.steps_per_epoch = 300
        for split in (test, train, val):
            split.return_start_time = True

        # Iterate train, val and test back to back, one item at a time.
        combined_dataset = ConcatDataset([train, val, test])
        dataloader = DataLoader(combined_dataset, batch_size=1, shuffle=False, drop_last=True)

        layer = 1
        plot_set_vars(
            dataloader,
            model,
            get_set_var,
            save_path=f"./visualizations/corelogic_set_var_over_test_{layer}.pdf",
            model_layer=layer,
        )
        breakpoint()
        test.return_start_time = False
        plot_set_var(train, val, test, model)

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

    