"""Plot the distribution of the number of loan defaults at several time steps
of a synthetic mortgage dataset, for a given loan pool size and lookback
horizon.
"""
import sys
import os
# Repository root used to build import, data and plot paths. Falls back to the
# current working directory when BASE_PATH is unset or empty; otherwise paths
# of the form f"{BASE_PATH}/data/..." would resolve to the filesystem root
# (e.g. "/data/"). NOTE: assumes the script is launched from the repo root in
# that case.
BASE_PATH = os.environ.get("BASE_PATH") or os.getcwd()
# Drop one trailing slash so the joined paths below don't contain "//".
if BASE_PATH.endswith('/'):
    BASE_PATH = BASE_PATH[:-1]
sys.path.append(BASE_PATH)
sys.path.append(f"{BASE_PATH}/data/")
sys.path.append(f"{BASE_PATH}/data/mortgage_new2/")
from src.dataloaders.dataloader_mortgage import PathGenerator, MortgageDataset, Mortgage, Split
import numpy as np
import torch
import copy
import matplotlib.pyplot as plt

# Enlarge fonts for every figure produced by this script: plot titles,
# axis labels, and x/y tick labels.
plt.rcParams.update({
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

def mortgage_2d_path_dependency_test(config, plot_path=None):
    """Sanity-check the MortgageDataset training split and plot default counts.

    Builds a ``MortgageDataset`` from ``config``, asserts the shapes of every
    training sample, counts how many loans of each sample's pool are in the
    terminal state at every time step, and saves a 1x4 row of histograms of
    those counts at four time points spread across the sequence.

    Args:
        config: Keyword arguments for ``MortgageDataset``; must contain at
            least ``num_samples``, ``loan_pool_size`` and ``num_states``.
            A deep copy is handed to the dataset.
        plot_path: Output PNG path. Defaults to
            ``{BASE_PATH}/scripts/notebooks/plot/defaults_distribution_subplot_1.png``.
            Missing parent directories are created.
    """
    if plot_path is None:
        plot_path = f"{BASE_PATH}/scripts/notebooks/plot/defaults_distribution_subplot_1.png"

    dataset = MortgageDataset(**copy.deepcopy(config))
    dataset.setup()
    train_set = dataset.dataset_train

    defaults_at_time_t = _count_terminal_states(train_set, config)
    _plot_default_histograms(defaults_at_time_t, config, plot_path)


def _count_terminal_states(train_set, config, terminal_state=2):
    """Return a ``(num_samples - 1, len(train_set))`` tensor of default counts.

    Entry ``[t, k]`` is the number of loans in sample ``k`` whose ``y`` entry
    for ``terminal_state`` at step ``t`` equals 1. The ``x`` and ``y`` shapes
    of every sample are asserted along the way.
    """
    n_steps = config["num_samples"] - 1
    pool_size = config["loan_pool_size"]
    expected_y_shape = (pool_size, n_steps, config["num_states"])
    expected_x_shape = (config["num_states"] * pool_size + 1 + pool_size, n_steps)

    counts = torch.zeros(n_steps, len(train_set))
    for k, (x, y, _M, _I) in enumerate(train_set):
        assert y.shape == expected_y_shape, (
            f"unexpected y shape {tuple(y.shape)}, expected {expected_y_shape}"
        )
        assert x.shape == expected_x_shape, (
            f"unexpected x shape {tuple(x.shape)}, expected {expected_x_shape}"
        )
        # y is indexed (loan, step, state): summing over the loan axis gives
        # the per-step number of loans flagged in the terminal state.
        # as_tensor accepts either a torch tensor or a numpy array here.
        counts[:, k] = torch.as_tensor((y[:, :, terminal_state] == 1).sum(0))
    return counts


def _plot_default_histograms(defaults_at_time_t, config, plot_path):
    """Save a 1x4 row of histograms of the default counts at four time points."""
    n_steps = config["num_samples"] - 1
    pool_size = config["loan_pool_size"]
    # First step, roughly one third and two thirds through, and the last step.
    time_points = [0, n_steps // 3, 2 * n_steps // 3, n_steps - 1]
    # One unit-wide bin centred on each integer 0..pool_size. Edges placed on
    # the integers would merge pool_size - 1 and pool_size into the final
    # (closed) bin.
    bin_edges = np.arange(-0.5, pool_size + 1.5)

    fig, axs = plt.subplots(1, len(time_points), figsize=(20, 5))
    try:
        for ax, time_point in zip(axs, time_points):
            ax.hist(defaults_at_time_t[time_point, :].numpy(), bins=bin_edges, density=True)
            ax.set_title(f"Defaults at time t={time_point + 1}")
            ax.set_xlabel("Number of Defaults")
            ax.set_ylabel("Frequency")
        # Adjust layout to prevent overlapping labels between subplots.
        fig.tight_layout()
        os.makedirs(os.path.dirname(plot_path) or ".", exist_ok=True)
        fig.savefig(plot_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)





def main():
    """Run the default-distribution sanity check on a small synthetic pool.

    The configuration requests freshly generated data (``load_saved_data`` is
    False) with ``save_data`` enabled and ``saved_data_directory`` set to
    ``{BASE_PATH}/data/mortgage_new2/``.
    """
    config = {
        "_name_": "timeseries_synthetics",
        "num_states": 3,  # total number of states
        "num_terminal_states": 1,  # number of terminal states
        "use_feature": True,  # If false, will not include the macro variable as a feature, and also not include the loan specific features
        "num_samples": 100,  # length of each sequence
        "loan_pool_size": 30,  # pool size
        "load_saved_data": False,
        "saved_data_directory": f"{BASE_PATH}/data/mortgage_new2/",
        "save_data": True,
        "num_seq": 10,  # number of sequences
        "val_split": 0.1,  # fraction of samples in the validation split
        "test_split": 0.1,  # fraction of samples in the test split
        "dataset_name": "timeseries_synthetics",
        "nr_steps": 10,  # number of different starting points
        "seed": 42,  # For validation split
        "forecasting": False,
        "forecasting_horizon": 1,
        "lookback_horizon": 10,
        "generator": {
            "level": "supereasy_2d_no_loan_specific_feature",
            "path_dependency_dimension": 1,
            "h_kappa": 0,
            "h_look_back": 2,
            "k_bias": 0.65,  # 0.75
            "a_bias": 0.03,
            "b_bias": 0.3,
            "k_scale": 0.22,
            "a_scale": -0.03,
            "b_scale": 0.3,
            "debug": False,  # if true the dynamics are simplified
            "hard": True,  # only used if debug=True. If False, the dynamics are deterministic
            "h": 150,
        },
    }
    mortgage_2d_path_dependency_test(copy.deepcopy(config))

# Only regenerate data and write plots when run as a script, not on import.
if __name__ == "__main__":
    main()
 