import sys
import numpy as np
import torch
import os 
BASE_PATH = os.environ.get("BASE_PATH", "")
sys.path.append(BASE_PATH)
from scripts.notebooks.true_loss_level.get_corelogic import get_model_config
from scripts.notebooks.true_loss_level.get_corelogic import load_model_corelogic
from scripts.notebooks.true_loss_level.get_corelogic import get_dataset
from scripts.notebooks.true_loss_level.get_corelogic import evaluate_model_cl
from scripts.notebooks.true_loss_level.get_corelogic import make_prediction
from scripts.notebooks.true_loss_level.get_corelogic import get_metrics_cl
from scripts.notebooks.true_loss_level.get_corelogic import zero_out_time_range
from torch.utils.data import DataLoader, ConcatDataset

# Seed handed to torch.manual_seed and np.random.seed by evaluate_model_cl
# when it is called with fix_seed=True.
SEED_NR = 42
def evaluate_model_cl(
    model,         # logistic model, set-seq model, or np.ndarray (transition matrix)
    model_name, 
    val_set, 
    test_set=None, 
    batch_size=1,
    fix_seed=True,
    all_units_in_batch_dim=False
):
    """
    Evaluate ``model`` on ``val_set`` (optionally concatenated with ``test_set``).

    NOTE: this definition shadows the ``evaluate_model_cl`` that is imported
    from ``get_corelogic`` at the top of this file.

    Args:
        model: Model object forwarded unchanged to ``make_prediction``.
        model_name: Model identifier forwarded to ``make_prediction``. When it
            equals ``"set-seq"`` each batch is moved to CUDA if available
            (CPU otherwise) before prediction.
        val_set: Dataset whose items unpack as ``(x, y, valid_indices)``.
        test_set: Optional second dataset; when given it is appended to
            ``val_set`` via ``ConcatDataset``.
        batch_size: Batch size of the evaluation ``DataLoader``.
        fix_seed: When True, seed torch and NumPy with ``SEED_NR`` first.
        all_units_in_batch_dim: Forwarded unchanged to ``make_prediction``.

    Returns:
        Tuple ``(y_pred, y_true, x)`` of tensors, each the per-batch results
        concatenated along dimension 0.

    Raises:
        ValueError: If a batch's prediction shape differs from its label shape.
        RuntimeError: If the dataloader yields no batches at all.
    """
    if fix_seed:
        torch.manual_seed(SEED_NR)
        np.random.seed(SEED_NR)

    # Merge val/test sets if needed
    if test_set is not None:
        combined_dataset = ConcatDataset([val_set, test_set])
    else:
        combined_dataset = val_set

    # NOTE: drop_last=True discards a trailing partial batch, so with
    # batch_size > 1 the final few samples may not be evaluated.
    dataloader = DataLoader(
        combined_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        drop_last=True
    )

    # The target device does not change between batches, so resolve it once.
    device = None
    if model_name == "set-seq":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_y_pred = []
    all_y_true = []
    all_x = []

    for x, y, valid_indices in dataloader:
        if device is not None:
            x = x.to(device)
            y = y.to(device)

        # make_prediction handles any model-specific conversions
        y_pred_torch = make_prediction(
            model=model,
            model_name=model_name,
            x_batch=x,
            y_batch=y,
            all_units_in_batch_dim=all_units_in_batch_dim
        )

        # Only the first entry of each bound is used for the whole batch.
        # NOTE(review): presumably exact only when batch_size == 1 or when all
        # samples in a batch share the same bounds -- confirm.
        start_time = valid_indices[0][0]
        end_time = valid_indices[1][0]

        # Fail loudly instead of dropping into an interactive debugger and
        # then continuing with mismatched tensors.
        if y.shape != y_pred_torch.shape:
            raise ValueError(
                f"Shapes of y and y_pred_torch do not match: {y.shape} vs {y_pred_torch.shape}"
            )

        # Original author's note: make y_pred_torch and y be in state 7
        # outside of the sampling bounds.
        y_pred_torch, y = zero_out_time_range(y_pred_torch, y, start_time, end_time)

        # Original author's note: y_pred_torch has shape
        # [B, n_loans, n_time_steps, n_classes]
        all_y_pred.append(y_pred_torch)
        all_y_true.append(y)
        all_x.append(x)

    if not all_y_pred:
        # torch.cat on an empty list fails with a far less helpful message.
        raise RuntimeError(
            "evaluate_model_cl: the dataloader produced no batches "
            f"(batch_size={batch_size}, drop_last=True); nothing to concatenate."
        )

    # Concatenate along the batch dimension
    y_pred_concat = torch.cat(all_y_pred, dim=0)
    y_true_concat = torch.cat(all_y_true, dim=0)
    x_concat = torch.cat(all_x, dim=0)

    return y_pred_concat, y_true_concat, x_concat

#from src.aico import AICO
def compute_aico(model, test):
    """
    Evaluate the set-seq model on ``test`` and flatten the results to 2-D.

    Runs ``evaluate_model_cl`` with ``model_name='set-seq'``, aligns the
    predictions with next-step labels, flattens inputs, labels and predictions
    to one row per (sample, loan, time-step), computes metrics with
    ``get_metrics_cl`` and prints the average AUC.

    Args:
        model: Set-seq model forwarded to ``evaluate_model_cl``.
        test: Dataset forwarded to ``evaluate_model_cl`` as ``val_set``.

    Returns:
        Tuple ``(x, y_true, y_pred, metrics)``: three 2-D NumPy arrays with a
        matching number of rows, plus the object returned by
        ``get_metrics_cl``. (Previously these were only reachable through a
        debugger breakpoint.)
    """
    # =========== Evaluate Set-Seq ============
    y_pred_seq, y_true_seq, x = evaluate_model_cl(
        model=model,
        model_name='set-seq',
        val_set=test,
        test_set=None,
        batch_size=1,
        fix_seed=False,  # seeding is left to the caller
        all_units_in_batch_dim=True
    )

    # Original author's shape notes for the evaluation output:
    #   x:            (nr_samples, nr_features, nr_loans, nr_timesteps) = (21, 52, 2500, 50)
    #   y after the time slicing below:
    #                 (nr_samples, nr_loans, nr_timesteps, nr_classes)  = (21, 2500, 49, 8)

    # Slice off last time-step from predictions, first from labels
    y_pred = y_pred_seq[:, :, :-1, :].cpu().numpy()
    y_true = y_true_seq[:, :, 1:, :].cpu().numpy()

    # Drop the last index of dim 3 to match the predictions, then move dim 1
    # to the end. permute(0, 2, 3, 1) is the composition of the former
    # permute(0, 2, 1, 3) followed by permute(0, 1, 3, 2).
    x = x[:, :, :, :-1].permute(0, 2, 3, 1).cpu().numpy()

    # Stack the first three dimensions together, keeping the last one as columns
    x = x.reshape(-1, x.shape[-1])
    y_true = y_true.reshape(-1, y_true.shape[-1])
    y_pred = y_pred.reshape(-1, y_pred.shape[-1])

    metrics_set_seq_test = get_metrics_cl(y_true, y_pred, name='set-seq')
    average_auc = metrics_set_seq_test['Avg. AUC']
    print(f"Average AUC: {average_auc}")

    return x, y_true, y_pred, metrics_set_seq_test


def main():
    """
    Assemble the dataset and loader configuration, load the
    ``may_14_set_seq`` model, fetch the dataset splits and run
    ``compute_aico`` on the test split.
    """
    # Columns listed under "columns_to_normalize_origination".
    origination_columns = [
        "fico_score_at_origination",
        "original_balance",
        "initial_interest_rate",
        "original_ltv",
    ]

    # Columns listed under "columns_to_normalize_performance".
    performance_columns = [
        "current_balance",
        "current_interest_rate",
        "scheduled_monthly_pi",
        "scheduled_principal",
        "mba_days_delinquent",
    ]

    # The trailing numbers are carried over from the original author's
    # annotations (original summary: "Total 58, Total 55").
    # NOTE(review): presumably the encoded width of each feature -- confirm.
    features = [
        "current_state",               # 8
        "fico_score_at_origination",   # 1
        "original_balance",            # 1
        "initial_interest_rate",       # 1
        "original_ltv",                # 1
        "unemployment_rate",           # 1
        "national_mortgage_rate",      # 1
        "current_balance",             # 2
        "current_interest_rate",       # 2
        "scheduled_monthly_pi",        # 2
        "scheduled_principal",         # 2
        "mba_days_delinquent",         # 2
        "inferred_collateral_type",    # 2
        "convertible_flag",            # 2
        "pool_insurance_flag",         # 2
        "io_flag",                     # 2
        "prepay_penalty_flag",         # 2
        "negative_amortization_flag",  # 2
        "buydown_flag",                # 2
        "loan_age",                    # 4
        "original_term",               # 2
        "times_30dd",                  # 2
        "times_60dd",                  # 1
        "times_90dd",                  # 1
        "times_current",               # 1
        "times_foreclosure",           # 1
        "zip-code",                    # 5
        "lagged_foreclosure_rate",     # 1
        "lagged_prepayment_rate",      # 2
    ]

    dataset_config = {
        "path_origination": "/share/data/llm_mortgages/data/filtered_origination_data_top_4_zips.csv",
        "path_performance": "/share/data/llm_mortgages/data/filtered_performance_data_top_4_zips.csv",
        "normalize_data": True,
        "database_size": 300000,
        "start_year": 1988,
        "end_year": 2023,
        "columns_to_normalize_origination": origination_columns,
        "columns_to_normalize_performance": performance_columns,
        "feature_set": features,
        "nr_classes": 8,
        "verbose": True,
    }

    config = {
        "_name_": "corelogic_loan_dataset",
        "dataset_config": dataset_config,
        "val_split": 0.1,
        "test_split": 0.3,
        "val_split_date": "2009-06",
        "test_split_date": "2009-12",
        "load_data": True,
        "save_data": False,
        # NOTE(review): the original author questioned whether the "_52"
        # file is the correct data set -- confirm.
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_top4_zip_52.npz",
        "max_to_sample": 4500,
        "nr_sampling_timesteps": 50,
        "nr_loans_to_sample": 2500,
        "steps_per_epoch": 70,
        "sample_random_loan_index": True,
        "sample_random_time_index": True,
        "eval_mode": True,
        "eval_seed": 3000,
    }

    # Model first, then data, then the evaluation itself.
    model = load_model_corelogic(**get_model_config(name="may_14_set_seq"))
    _train_split, _val_split, test_split = get_dataset(config)

    compute_aico(model, test_split)

# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()