# Things we need to do
# 1. Define the transition structure.
# The L1 loss is simple: we just need to estimate 8 numbers
# (t1->t2, t2->t2, t1->t1, t1->t3, t2->t3, t2->t1; NOTE: only six are listed)
# for x = 1 or 0. This can be done with an empirical average.
# The L2 loss is more complicated, as we also need to estimate the path
# variable. The L2 loss assumes full knowledge, so it is the loss between the
# true and the observed transition probabilities.
# L3: this is where it gets more complex (it also depends on the sample size).
# Sample 10000 paths (approximately infinite), estimate the transition
# probabilities from a subset of them, and use these estimates to compute the
# loss. Bootstrap over different subsets.
import copy


"""Plot the distribution of defaults for a given loan pool size, lookback
horizon and time step.
"""
import sys
import os
# Repository root, taken from the environment (empty when unset). The raw value
# is put on sys.path so the project packages imported below can be found; the
# exported constant drops a single trailing "/" so f"{BASE_PATH}/..." paths
# do not contain "//".
_base_path_env = os.environ.get("BASE_PATH", "")
sys.path.append(_base_path_env)
BASE_PATH = _base_path_env[:-1] if _base_path_env.endswith('/') else _base_path_env
from src.tasks.metrics import next_step_log_loss
from src.tasks.metrics import next_step_log_loss_GT
from src.dataloaders.dataloader_mortgage import PathGenerator, MortgageDataset, SyntheticTimeSeries, Split
import numpy as np
import torch
import copy
import matplotlib.pyplot as plt
import torch.nn.functional as F


from hydra import compose, initialize

from scripts.notebooks.multi_class_auc import get_y_true, get_true_probs, get_kalman_predictions, get_model_predictions, plot_kalman_model_comparison


def hydraload(config, path, dataset):
    """Restore a ``SequenceLightningModule`` from ``path`` using a Hydra config.

    Args:
        config: Composed Hydra config; it is post-processed with
            ``utils.train.process_config`` and printed before use.
        path: Checkpoint file; also stored in
            ``config.train.pretrained_model_path``.
        dataset: Value written to ``config.dataset.data_dir``.

    Returns:
        The module returned by ``SequenceLightningModule.load_from_checkpoint``.
    """
    # Local imports keep the training module out of this script's import time.
    from train import SequenceLightningModule
    import src.utils as utils

    config = utils.train.process_config(config)
    config.train.pretrained_model_path = path
    utils.train.print_config(config, resolve=True)
    config.dataset.data_dir = dataset

    # ``load_from_checkpoint`` is a classmethod that constructs the module
    # itself; the previous extra ``SequenceLightningModule(config)`` instance
    # was discarded immediately and only cost time and memory.
    return SequenceLightningModule.load_from_checkpoint(
        config.train.pretrained_model_path,
        config=config,
        strict=config.train.pretrained_model_strict_load,
    )

def hydraload_corelogic(config, path, dataset):
    """Restore a ``SequenceLightningModule`` for the CoreLogic data from ``path``.

    Identical to ``hydraload`` except that ``dataset`` is written to
    ``config.dataset.data_path`` (e.g.
    ``f"{BASE_PATH}/data/corelogic/loan_data_top1_zip_52.npz"``).

    Args:
        config: Composed Hydra config; it is post-processed with
            ``utils.train.process_config`` and printed before use.
        path: Checkpoint file; also stored in
            ``config.train.pretrained_model_path``.
        dataset: Value written to ``config.dataset.data_path``.

    Returns:
        The module returned by ``SequenceLightningModule.load_from_checkpoint``.
    """
    # Local imports keep the training module out of this script's import time.
    from train import SequenceLightningModule
    import src.utils as utils

    config = utils.train.process_config(config)
    config.train.pretrained_model_path = path
    utils.train.print_config(config, resolve=True)
    config.dataset.data_path = dataset

    # ``load_from_checkpoint`` is a classmethod that constructs the module
    # itself; building a separate instance first was pure overhead.
    return SequenceLightningModule.load_from_checkpoint(
        config.train.pretrained_model_path,
        config=config,
        strict=config.train.pretrained_model_strict_load,
    )

def load_model_corelogic(experiment, checkpoint_path, data_path, **kwargs):
    """Compose the Hydra config for ``experiment`` and load a CoreLogic model.

    Hydra's global state can be initialised only once per process; on later
    calls the existing initialisation (and its config directory) is reused.

    Args:
        experiment: Value for the ``experiment=`` Hydra override.
        checkpoint_path: Checkpoint file to restore.
        data_path: Forwarded to ``hydraload_corelogic`` as the dataset path.
        **kwargs: Accepted for backward compatibility and ignored.

    Returns:
        The model returned by ``hydraload_corelogic``.
    """
    from hydra.core.global_hydra import GlobalHydra

    # Check explicitly instead of a bare ``except:``, which also hid genuine
    # errors (and even KeyboardInterrupt) behind "Already initialized".
    if GlobalHydra.instance().is_initialized():
        print("Already initialized")
    else:
        initialize(version_base=None, config_path="./../../../configs/")
    cfg = compose(config_name="config.yaml",
                  overrides=["experiment=" + experiment])
    cfg.dataset.load_data = True  # NOTE(review): original author asked "Why true" — confirm
    cfg.dataset.dataset_config.database_size = 1000
    return hydraload_corelogic(cfg, checkpoint_path, data_path)

def load_model(experiment, checkpoint_path, data_path):
    """Compose the Hydra config for ``experiment`` and load a model checkpoint.

    Hydra's global state can be initialised only once per process; on later
    calls the existing initialisation (and its config directory) is reused.

    Args:
        experiment: Value for the ``experiment=`` Hydra override
            (e.g. ``"timeseries/ts_lc_mha"``).
        checkpoint_path: Checkpoint file to restore.
        data_path: Forwarded to ``hydraload`` as the dataset directory.

    Returns:
        The model returned by ``hydraload``.
    """
    from hydra.core.global_hydra import GlobalHydra

    # Check explicitly instead of a bare ``except:``, which also hid genuine
    # errors (and even KeyboardInterrupt) behind "Already initialized".
    if GlobalHydra.instance().is_initialized():
        print("Already initialized")
    else:
        initialize(version_base=None, config_path="./../../../configs/")
    cfg = compose(config_name="config.yaml",
                  overrides=["experiment=" + experiment])
    cfg.dataset.load_saved_data = True
    cfg.dataset.num_seq = 5
    return hydraload(cfg, checkpoint_path, data_path)

def get_optimal_loss(train_set):
    """Return the mean ground-truth next-step log loss over ``train_set`` as a float.

    Each sample is an ``(x, y, M, I)`` tuple where ``M`` holds the true
    transition probabilities; ``M`` is scored against the labels ``y`` with
    ``next_step_log_loss_GT``.
    """
    sample_count = len(train_set)
    total = 0
    for _x, labels, true_probs, _idx in train_set:
        # Prepend a batch axis (the original note says this makes them 4-D).
        labels_batched = labels.unsqueeze(0)
        probs_batched = true_probs.unsqueeze(0)
        total += next_step_log_loss_GT(probs_batched, labels_batched, probs_batched)
    mean_loss = total / sample_count
    return mean_loss.item()


def get_no_path_dependency_loss(train_set, validation_set):
    """Score the path-independent (marginal) baseline on ``validation_set``.

    For every sample ``i``, ``validation_set.marginal_probabilities[i]`` is used
    as the prediction and scored against ``validation_set.Y[i]`` with
    ``next_step_log_loss_GT``.

    Args:
        train_set: Unused; kept so existing callers keep working.
        validation_set: Object with ``marginal_probabilities`` and ``Y``
            tensors, indexed by sample along axis 0 (4-D, per the indexing).

    Returns:
        ``(mean_loss, all_y_pred)``: the summed loss divided by the number of
        samples, and a list with one NumPy array of marginal probabilities per
        sample.
    """
    print("Computing no path dependency loss")
    print("Validation set size:", len(validation_set))

    n_samples = len(validation_set)
    total_loss = 0
    all_y_pred = []
    for i in range(n_samples):
        marginal_probabilities = validation_set.marginal_probabilities[i, :, :, :].cpu()
        y = validation_set.Y[i, :, :, :].cpu()

        # ``torch.tensor(existing_tensor)`` warns on every call;
        # ``clone().detach()`` is the documented equivalent (independent copy,
        # no autograd history). Two copies are kept, as before, so the two
        # loss arguments never alias each other or the dataset storage.
        pred = marginal_probabilities.clone().detach().unsqueeze(0)
        reference = marginal_probabilities.clone().detach().unsqueeze(0)
        total_loss += next_step_log_loss_GT(pred, y.unsqueeze(0), reference)
        all_y_pred.append(np.array(marginal_probabilities))

    return total_loss / n_samples, all_y_pred

def check_causality(model, dataset, t=20):
    """Print, per sample, whether the prediction at step ``t`` uses future inputs.

    For each ``(x, y, M, I)`` sample, the squared error between the softmaxed
    prediction and ``y`` at time step ``t`` is back-propagated to a fresh copy
    of ``x``. Any gradient mass on input time steps after ``t`` (last axis of
    ``x``) means the prediction depends on the future.

    Args:
        model: Called as ``model((x_batch, {}))``; element 0 of the result is
            treated as logits.
        dataset: Iterable of ``(x, y, M, I)`` tuples. NOTE(review): the last
            axis of ``x`` and axis 1 of ``y`` are assumed to be time — confirm
            against the dataloader.
        t: Time step whose prediction is checked.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for sample in dataset:
        x, y, M, I = sample
        # Copy first, then mark the copy as a leaf requiring grad. The former
        # ``x.requires_grad_(True).to(device)`` yields a non-leaf on CUDA
        # (``x.grad`` stays None) and flipped ``requires_grad`` in place on
        # the tensor handed out by the dataset.
        x = x.detach().clone().to(device).requires_grad_(True)
        y = y.to(device)
        y_pred = model((x.unsqueeze(0), {}))[0]
        y_pred = y_pred.squeeze(1).to(device)
        y_pred = F.softmax(y_pred, dim=2).to(device)

        # Squared error of the prediction at time step t.
        err = y_pred[:, t, :] - y[:, t, :]
        loss = (err * err).mean()

        model.zero_grad()  # start from clean parameter gradients
        loss.backward()

        # Sum gradient magnitudes over *all* input steps t' > t (not only
        # t + 1), so longer-range leakage is detected as well.
        future_timestep_grads = x.grad[:, :, t + 1:].abs().sum().item()
        print(f"Future timestep gradients at time step {t}: {future_timestep_grads}")

        if future_timestep_grads > 1e-5:  # threshold for "non-zero" gradients
            print(f"Model is NOT causal! Detected gradient influence at future time steps after {t}")
        else:
            print(f"Model is causal up to time step {t}")

def sanity_check_batching(model, dataset, nr_samples_for_inference=20, pool_size=1000):
    """Print how much batched and chunk-by-chunk predictions differ.

    Only the first sample of ``dataset`` is used. The first ``pool_size``
    entries along axis 1 of its ``x`` are split into
    ``int(pool_size / nr_samples_for_inference)`` consecutive chunks. The
    model is run on each chunk separately and on all chunks stacked as one
    batch, and the absolute and relative differences are printed.

    Args:
        model: Called as ``model((batch, {}))``; element 0 of the result is
            compared.
        dataset: Iterable of ``(x, y, M, I)`` tuples.
        nr_samples_for_inference: Axis-1 entries per chunk (default 20).
        pool_size: Axis-1 entries to use in total (default 1000, the value
            previously hard-coded).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    # Force deterministic cuDNN kernels for the comparison, but restore the
    # caller's setting afterwards so the global flag does not leak into the
    # rest of the run.
    previous_deterministic = torch.backends.cudnn.deterministic
    torch.backends.cudnn.deterministic = True
    try:
        with torch.no_grad():
            for x, _y, _M, _I in dataset:
                x = x.to(device)
                num_chunks = int(pool_size / nr_samples_for_inference)
                batched_x = torch.zeros((num_chunks, x.shape[0], nr_samples_for_inference, x.shape[2]))

                y_pred_no_batch = []
                for i in range(num_chunks):
                    start = i * nr_samples_for_inference
                    batched_x[i] = x[:, start:start + nr_samples_for_inference, :]
                    y_pred_no_batch.append(model((batched_x[i].unsqueeze(0).to(device), {}))[0])
                y_pred_batched = model((batched_x.to(device), {}))[0]
                y_pred_no_batch = torch.cat(y_pred_no_batch, dim=0)

                abs_diff = torch.abs(y_pred_no_batch - y_pred_batched).sum()
                print("Difference in predictions between batched and non-batched")
                print(abs_diff)
                print("Shape: ", y_pred_no_batch.shape)
                print("Relative Error: ", abs_diff / torch.abs(y_pred_no_batch).sum())
                break  # one sample is enough for the check
    finally:
        torch.backends.cudnn.deterministic = previous_deterministic


def _select_device():
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _resolve_chunk_size(requested, pool_size):
    """Turn a requested chunk size into an integer that divides ``pool_size``.

    -1 means "use the whole pool". If ``pool_size`` is not a multiple of the
    request, the number of chunks ``pool_size // requested`` is kept and the
    chunk size becomes ``pool_size // num_chunks`` when that divides evenly.

    Raises:
        ValueError: if the request is outside ``[1, pool_size]`` (and not -1),
            or no evenly dividing chunk size exists.
    """
    n = int(requested)
    if n == -1:
        return pool_size
    if not 1 <= n <= pool_size:
        raise ValueError(
            f"nr_samples_for_inference must be -1 or in [1, {pool_size}], got {n}")
    if pool_size % n != 0:
        # The old code computed ``L / (L // n)`` as a float, which crashed
        # in ``torch.zeros``/slicing even when the value was integral.
        num_chunks = pool_size // n
        if pool_size % num_chunks != 0:
            raise ValueError(
                f"Cannot split {pool_size} samples into equal chunks of about {n}")
        n = pool_size // num_chunks
    return n


def _batch_samples(x, num_chunks, chunk_size):
    """Stack ``num_chunks`` consecutive chunks of axis 1 of ``x`` into one batch.

    Returns a tensor of the default dtype/device with shape
    ``(num_chunks, x.shape[0], chunk_size, x.shape[2])``.
    """
    batched_x = torch.zeros((num_chunks, x.shape[0], chunk_size, x.shape[2]))
    for i in range(num_chunks):
        start = i * chunk_size
        try:
            batched_x[i] = x[:, start:start + chunk_size, :]
        except RuntimeError as err:
            # Previously this dropped into ``breakpoint()``; fail loudly instead.
            raise RuntimeError(f"Error when batching: {err}") from err
    return batched_x


def _run_model(model, batched_x, device):
    """Run ``model`` without autograd and return its prediction tensor.

    ``model((batch, {}))`` is tried first. If it raises, the traceback is
    printed and ``model(batch, nr_units=1000)`` is used instead (presumably
    for models with that calling convention). For tuple outputs, element 0
    is returned.
    """
    import traceback

    with torch.no_grad():
        try:
            model_output = model((batched_x.to(device), {}))
        except Exception as e:
            print(f"Error when calling model: {e}")
            print("Error traceback:")
            print(traceback.format_exc())
            kwargs = {"nr_units": 1000}
            model_output = model(batched_x.to(device), **kwargs)
    return model_output[0] if isinstance(model_output, tuple) else model_output


def evaluate_model(model, dataset, nr_samples_for_inference_vec, check_grads=False, check_batching=True):
    """Run ``model`` on every sample of ``dataset`` for several chunk sizes.

    If an entry of ``nr_samples_for_inference_vec`` is -1, all samples in the
    dataset are used for inference at once. Otherwise only that many are used
    at a time, corresponding to the partially observed setting.

    For each chunk size ``n``, the ``L = dataset.num_data_per_time_step()``
    entries along axis 1 of every ``x`` are split into ``L // n`` consecutive
    chunks. These are stacked into one batch and passed through the model.
    When ``L`` is not a multiple of ``n``, ``n`` becomes ``L // (L // n)`` if
    that divides ``L`` evenly; otherwise a ``ValueError`` is raised.

    Args:
        model: Module called as ``model((batch, {}))``, with
            ``model(batch, nr_units=1000)`` as a fallback.
        dataset: Iterable of ``(x, y, M, I)`` samples that also provides
            ``num_data_per_time_step()``.
        nr_samples_for_inference_vec: Chunk sizes to evaluate.
        check_grads: Run ``check_causality`` first.
        check_batching: Run ``sanity_check_batching`` first.

    Returns:
        One list per chunk size, holding one NumPy array of softmax
        probabilities per sample. The output is flattened to
        ``(-1, timesteps, states)``, assuming the model returns
        ``(batch, n, timesteps, states)``.

    Raises:
        ValueError: for an invalid or unsplittable chunk size.
    """
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    device = _select_device()
    print(f"Using device: {device}")
    model.to(device)
    if check_grads:
        check_causality(model, dataset)
    if check_batching:
        sanity_check_batching(model, dataset)

    model.eval()
    evaluations = []
    for requested in nr_samples_for_inference_vec:
        pool_size = dataset.num_data_per_time_step()
        chunk_size = _resolve_chunk_size(requested, pool_size)
        num_chunks = pool_size // chunk_size

        all_model_preds = []
        for x, _y, _M, _I in dataset:
            batched_x = _batch_samples(x.to(device), num_chunks, chunk_size)
            y_pred = _run_model(model, batched_x, device)
            # Normalise over the state axis, then flatten the chunks.
            y_pred = F.softmax(y_pred, dim=3)
            y_pred = y_pred.reshape(-1, y_pred.shape[2], y_pred.shape[3])
            all_model_preds.append(y_pred.cpu().detach().numpy())
        evaluations.append(all_model_preds)
    return evaluations

def save_model_transition_probabilities(train_set,config, nr_samples_for_inference_vec, model_config, name, pretext):
    """Load a trained model, predict on ``train_set`` and save the predictions.

    One ``.npy`` file is written under ``pretext + "data/"`` for every entry of
    ``nr_samples_for_inference_vec``. The file name encodes the generator's
    ``mu``/``alpha``/``beta``, the chunk size and ``name``.
    """
    print("Loading model")
    model = load_model(
        model_config["experiment"],
        model_config["checkpoint_path"],
        model_config["data_path"],
    )
    print("Model loaded")
    model._state = None
    predictions_per_n = evaluate_model(
        model,
        train_set,
        nr_samples_for_inference_vec=nr_samples_for_inference_vec,
        check_grads=False,
    )
    for n_infer, preds in zip(nr_samples_for_inference_vec, predictions_per_n):
        gen = config["generator"]
        alpha = gen["alpha"]
        beta = gen["beta"]
        mu = gen["mu"]
        out_file = pretext + "data/all_set_seq_model_preds_mu_{}_alpha_{}_beta_{}_n_{}_{}.npy".format(
            mu, alpha, beta, n_infer, name)
        np.save(out_file, np.array(preds))
    
    
    
    


def save_true_y_and_y_probs(train_set, config, pretext):
    """Save ``train_set.Y`` (true labels) and ``train_set.transition_probabilities``.

    Both arrays go under ``pretext + "data/"`` with the generator's
    ``mu``/``alpha``/``beta`` in the file name. The labels are written first.
    """
    gen = config["generator"]
    alpha, beta, mu = gen["alpha"], gen["beta"], gen["mu"]

    label_path = pretext + "data/y_label_mu_{}_alpha_{}_beta_{}.npy".format(mu, alpha, beta)
    np.save(label_path, np.array(train_set.Y.cpu()))

    probs_path = pretext + "data/y_true_probs_mu_{}_alpha_{}_beta_{}.npy".format(mu, alpha, beta)
    np.save(probs_path, np.array(train_set.transition_probabilities.cpu()))

    print("Saved True y and y_probs")


def save_kalman_filter_probabilities(train_set, config, pretext):
    """Save the partially observed transition probabilities per observation count.

    If ``config["n_obs_partial_hawkes"]`` is empty, nothing is computed.
    Otherwise ``train_set.get_partial_transition_probs()`` is called once, and
    slice ``idx`` of its last axis is saved for the ``idx``-th observation
    count. File names get a ``_kalman`` suffix when
    ``config["partial_obs_method"]`` is ``"partially_observed_hawkes_kalman"``.
    "Saved" is printed in every case.
    """
    gen = config["generator"]
    alpha, beta, mu = gen["alpha"], gen["beta"], gen["mu"]

    n_obs_values = config["n_obs_partial_hawkes"]
    if len(n_obs_values) > 0:
        partial_probs, _ = train_set.get_partial_transition_probs()
        use_kalman = config["partial_obs_method"] == "partially_observed_hawkes_kalman"
        template = (
            "data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}_kalman.npy"
            if use_kalman
            else "data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}.npy"
        )
        for idx, n_obs in enumerate(n_obs_values):
            out_path = pretext + template.format(n_obs, mu, alpha, beta)
            np.save(out_path, np.array(partial_probs[:, :, :, :, idx].cpu().detach()))
    print("Saved")

def get_dataset(config):
    """Build a ``MortgageDataset`` from a deep copy of ``config`` and return its training split."""
    data_module = MortgageDataset(**copy.deepcopy(config))  # deep copy: caller's config stays untouched
    data_module.setup()
    return data_module.dataset_train

def get_configs(names=None):
    """Return the model configurations to evaluate, keyed by display name.

    Each configuration is a dict with ``experiment`` (Hydra experiment
    override), ``checkpoint_path`` and ``data_path``. The display name is also
    used in the file names written by ``save_model_transition_probabilities``.

    Args:
        names: Optional iterable of catalogue keys to return. Defaults to the
            three "MHA-Seq Dirac" models (the previous fixed selection).

    Returns:
        A new dict mapping each selected name to a fresh config dict, in the
        order given by ``names``.

    Raises:
        KeyError: if a requested name is not in the catalogue.
    """
    data_path = f"{BASE_PATH}/data/mortgage_new2913734.json"
    ckpt_root = f"{BASE_PATH}/outputs/outputs"

    def entry(experiment, checkpoint):
        # All checkpoints below were annotated "20 samples" in the original notes.
        return {
            "experiment": experiment,
            "checkpoint_path": f"{ckpt_root}/{checkpoint}",
            "data_path": data_path,
        }

    catalog = {
        "MHA-Seq 2025-01-15 13-33-01": entry("timeseries/ts_lc_mha", "2025-01-15/13-33-01/step_7600.ckpt"),
        "MHA-Seq best overall": entry("timeseries/ts_lc_mha", "2025-01-15/14-16-30/step_7600.ckpt"),
        "Large MHA-Seq best avg": entry("timeseries/ts_lc_best_large_mha", "2025-01-15/15-13-09/step_7600.ckpt"),
        "Large Set-Seq best avg": entry("timeseries/ts_lc_best_large_set", "2025-01-15/18-59-10/step_7600.ckpt"),
        "Set-Seq LogUnif 0.08": entry("timeseries/ts_lc_set", "2025-01-16/19-14-21/step_6600.ckpt"),
        "MHA-Seq LogUnif 0.08": entry("timeseries/ts_lc_mha", "2025-01-16/12-18-23/last.ckpt"),
        "MHA-Seq LogUnif 0.2": entry("timeseries/ts_lc_mha", "2025-01-16/12-05-36/step_7000.ckpt"),
        # NOTE(review): originally commented "logunif 0.0" although named dirac_1000.
        "MHA-Seq Dirac 1000": entry("timeseries/ts_lc_mha", "2025-01-16/01-18-21/last.ckpt"),
        "MHA-Seq Dirac 100": entry("timeseries/ts_lc_mha", "2025-01-16/01-31-59/last.ckpt"),
        "MHA-Seq Dirac 10": entry("timeseries/ts_lc_mha", "2025-01-16/02-07-55/last.ckpt"),
    }

    if names is None:
        names = ("MHA-Seq Dirac 10", "MHA-Seq Dirac 100", "MHA-Seq Dirac 1000")
    return {name: catalog[name] for name in names}



def main():
    """Evaluate the configured models on the synthetic dataset and plot comparisons.

    Steps:
      1. Load the synthetic dataset described by ``config``.
      2. Save the partially observed / Kalman baseline probabilities and the
         true labels and transition probabilities under ``pretext + "data/"``.
      3. Evaluate every model from ``get_configs()`` for each entry of
         ``nr_samples_for_inference_vec`` and save the predictions.
      4. Plot Kalman vs. partially observed Hawkes, then the models vs. the
         Kalman filter.
    """
    nr_partial_obs_hawkes = [1, 2, 5, 10, 50, 100, 200, 500, 1000]
    config = {
        "_name_": "timeseries_synthetics",
        "num_states": 3,  # total number of states
        "num_terminal_states": 1,  # number of terminal states
        "use_feature": True,  # if False, neither the macro variable nor the loan-specific features are included
        "simulation_steps": 100,  # length of each sequence
        "loan_pool_size": 1000,  # pool size
        "load_saved_data": True,  # load previously saved data
        "saved_data_directory": f"{BASE_PATH}/data/mortgage_new2/",
        "save_data": True,
        "num_seq": 20,  # number of sequences
        "val_split": 0.1,  # fraction of samples in the validation split
        "test_split": 0.1,  # fraction of samples in the test split
        "dataset_name": "timeseries_synthetics",
        "nr_steps": 10,  # number of different starting points
        "forecasting": False,
        "n_obs_partial_hawkes": nr_partial_obs_hawkes,
        "partial_obs_method": "partially_observed_hawkes",  # alternative: "partially_observed_hawkes_kalman"
        "use_random_input_size": False,
        "random_input_size_options": [2, 5, 10, 50, 100, 200, 500, 1000],
        "random_input_size_probabilities": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3],
        "forecasting_horizon": 1,
        "lookback_horizon": 1,
        "generator": {
            "level": "supereasy_2d",
            "path_dependency_dimension": 2,
            "h_kappa": 100,  # controls the path dependency
            "mu": 0.001,
            "alpha": 0.004,
            "beta": 0.5,
            "debug": False,  # if True the dynamics are simplified
            "hard": True,  # only used if debug=True; if False, the dynamics are deterministic
        },
    }
    model_configs = get_configs()

    train_set = get_dataset(copy.deepcopy(config))

    pretext = f"{BASE_PATH}/scripts/notebooks/"
    save_kalman_filter_probabilities(train_set, copy.deepcopy(config), pretext)
    save_true_y_and_y_probs(train_set, copy.deepcopy(config), pretext)

    nr_samples_for_inference_vec = [1, 2, 5, 10, 50, 100, 200, 500, 1000]
    for name, model_config in model_configs.items():
        print("Evaluating model: ", name)
        save_model_transition_probabilities(
            train_set, copy.deepcopy(config), nr_samples_for_inference_vec, model_config, name, pretext)

    # Derive the plotting parameters from the generator config so the files
    # read back are always the ones written above (they were duplicated by hand).
    generator = config["generator"]
    params = {
        "alpha": generator["alpha"],
        "beta": generator["beta"],
        "mu": generator["mu"],
    }

    compare_kalman_with_baseline = True
    plot_results = True

    if compare_kalman_with_baseline:
        y_true = get_y_true(params, pretext)
        true_probs = get_true_probs(params, pretext)
        kalman_preds = get_kalman_predictions(params, nr_partial_obs_hawkes, pretext)
        partially_observed_probs = get_kalman_predictions(params, nr_partial_obs_hawkes, pretext, kalman=False)
        model_dict = {
            "Kalman Filter": kalman_preds,
            "Partially Observed Hawkes": partially_observed_probs
        }
        save_path = pretext+"data/"+f"compare_kalman_baseline_mu_{params['mu']}_alpha_{params['alpha']}_beta_{params['beta']}.pdf"
        plot_kalman_model_comparison(
            y_true,
            true_probs,
            model_dict,
            nr_samples_for_inference_vec,
            save_path)

    if plot_results:
        # NOTE(review): changes the process working directory; the reason is
        # not visible here and it affects any relative ``pretext``.
        os.chdir('../')
        save_path = pretext+"data/"+f"compare_models_mu_{params['mu']}_alpha_{params['alpha']}_beta_{params['beta']}.pdf"

        y_true = get_y_true(params, pretext)
        true_probs = get_true_probs(params, pretext)
        kalman_preds = get_kalman_predictions(params, nr_partial_obs_hawkes, pretext)
        model_dict = {
            name: get_model_predictions(name, params, nr_samples_for_inference_vec, pretext)
            for name in model_configs
        }
        model_dict["Kalman Filter"] = kalman_preds
        plot_kalman_model_comparison(
            y_true,
            true_probs,
            model_dict,
            nr_samples_for_inference_vec,
            save_path)
        


if __name__ == "__main__":
    import cProfile
    import pstats

    # Create a profiler
    profiler = cProfile.Profile()
    profiler.enable()
    
    # Run your main function
    main()  
    
    profiler.disable()
    
    # Sort and display the stats
    stats = pstats.Stats(profiler)
    stats.sort_stats("cumtime").print_stats(30)  # Sort by cumulative time, show top 10 functions
