#
# Simulation with perturbed nuisance functions
#

import numpy as np
import pandas as pd
import itertools
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

def set_size(width, fraction=1, subplots=(3, 3)):
    """Return a figure size in inches matched to a document's text width.

    Sizing the figure from the document width means it can be included
    in LaTeX without being rescaled.

    Parameters
    ----------
    width : float or str
        Document width in points, or the name of a predefined document
        type (``'thesis'`` or ``'beamer'``).
    fraction : float, optional
        Fraction of the document width the figure should occupy.
    subplots : array-like, optional
        Number of subplot rows and columns, as ``(rows, columns)``. Only
        the ratio rows / columns is used, to scale the figure height.

    Returns
    -------
    fig_dim : tuple
        ``(width, height)`` of the figure in inches.
    """
    # Widths (in points) of the named document types.
    preset_widths_pt = {'thesis': 426.79135, 'beamer': 307.28987}

    # Any value that is not a known preset name is used as a width in points.
    is_preset = isinstance(width, str) and width in preset_widths_pt
    width_pt = preset_widths_pt[width] if is_preset else width

    # Conversion factor from points to inches.
    inches_per_pt = 1 / 72.27

    # Golden ratio, used to pick an aesthetically pleasing height.
    # https://disq.us/p/2940ij3
    golden_ratio = (5**.5 - 1) / 2

    n_rows = subplots[0]
    n_cols = subplots[1]

    # Width first (points, then inches); the height follows from it.
    fig_width_pt = width_pt * fraction
    fig_width_in = fig_width_pt * inches_per_pt
    fig_height_in = fig_width_in * golden_ratio * (n_rows / n_cols)

    return (fig_width_in, fig_height_in)

if __name__ == "__main__":
    # Script entry point. Either runs the simulation (and stores its results
    # as CSV) or reloads a previous run, then prints a summary for the corner
    # cases of the perturbation grid and saves a grid of heatmaps as PDF.

    # Set to False to skip the simulation and reuse the CSV of an earlier run.
    run_simulations = True
    # Number of units per simulated dataset; also part of the file names.
    n = 1000

    if run_simulations:
        # The generator is seeded once; every draw below consumes from this
        # single stream, so the order of the random calls must not change.
        np.random.seed(6)

        # Effect added to mu10 to obtain mu11; the estimates are scored
        # against this value (presumably the true ATT).
        att = 5
        n_simulations = 1000
        # Perturbation weights: 1 keeps the true nuisance value, 0 replaces
        # it entirely by a random draw.
        alpha_mus = np.linspace(0, 1, 5)
        alpha_pis = np.linspace(0, 1, 5)
        alpha_gammas = np.linspace(0, 1, 5)

        # One dict per (replication, alpha_mu, alpha_pi, alpha_gamma). Rows
        # are collected in a list and turned into a DataFrame once at the
        # end; concatenating onto a growing DataFrame inside the loop would
        # copy all previous rows on every iteration.
        records = []

        # `seed` is only a replication index; it does not reseed the RNG.
        for seed in tqdm.tqdm(range(n_simulations)):

            #
            # Generate data
            #

            baseline_effect = np.random.normal(2, 0.5)
            mu00 = np.random.normal(5, 1)
            mu01 = mu00
            mu10 = mu00 + baseline_effect
            mu11 = mu10 + att
            pi = np.random.uniform(0.1, 0.9)
            gamma1 = np.random.uniform(0.1, 0.9)
            gamma0 = np.random.uniform(0.1, 0.9)

            A = np.random.binomial(1, pi, n)
            Y0 = mu00 + np.random.normal(0, 1, n)
            Y11 = mu11 + np.random.normal(0, 1, n)
            Y10 = mu10 + np.random.normal(0, 1, n)
            Y1 = np.where(A == 1, Y11, Y10)
            R00 = np.random.binomial(1, gamma0, n)
            R01 = np.random.binomial(1, gamma1, n)
            R0 = np.where(A == 1, R01, R00)
            # Y0 is only kept where R0 == 1; other entries are set to 0.
            Y0_obs = np.where(R0 == 1, Y0, 0)

            # Number of units with A == 1; constant across the alpha grid.
            n_treated = A.sum()

            for alpha_mu, alpha_pi, alpha_gamma in itertools.product(alpha_mus, alpha_pis, alpha_gammas):

                #
                # Perturb nuisance functions
                #

                mu00_hat = mu00 * alpha_mu + np.random.normal(0, 1) * (1 - alpha_mu)
                mu01_hat = mu01 * alpha_mu + np.random.normal(0, 1) * (1 - alpha_mu)
                mu10_hat = mu10 * alpha_mu + np.random.normal(0, 1) * (1 - alpha_mu)
                # mu11_hat is not used by the estimator below, but its draw is
                # kept so the random stream (and thus every later number)
                # stays the same.
                mu11_hat = mu11 * alpha_mu + np.random.normal(0, 1) * (1 - alpha_mu)
                eta_hat00 = mu00_hat
                pi_hat = pi * alpha_pi + np.random.uniform(0.1, 0.9) * (1 - alpha_pi)
                gamma0_hat = gamma0 * alpha_gamma + np.random.uniform(0.1, 0.9) * (1 - alpha_gamma)
                gamma1_hat = gamma1 * alpha_gamma + np.random.uniform(0.1, 0.9) * (1 - alpha_gamma)

                #
                # Influence function
                #

                # Contribution of the units with A == 1.
                treated_term = (A / n_treated) * (
                    Y1
                    - (mu01_hat + R0 * (Y0_obs - mu01_hat) / gamma1_hat)
                    - mu10_hat
                    + eta_hat00
                )
                # Contribution of the units with A == 0, weighted by
                # pi_hat / (1 - pi_hat).
                control_term = ((1 - A) * pi_hat / ((1 - pi_hat) * n_treated)) * (
                    Y1
                    - mu10_hat
                    - mu00_hat
                    + eta_hat00
                    - R0 * (Y0_obs - mu00_hat) / gamma0_hat
                )
                influence_function = treated_term - control_term

                att_estimate = influence_function.sum()
                delta_method_var = np.sum((influence_function - A * att_estimate / n_treated) ** 2)
                half_width = 1.96 * np.sqrt(delta_method_var)
                conf_inf = (att_estimate - half_width, att_estimate + half_width)
                coverage = bool(conf_inf[0] <= att <= conf_inf[1])
                ci_width = conf_inf[1] - conf_inf[0]

                records.append({
                    "seed": seed,
                    "alpha_mu": alpha_mu,
                    "alpha_pi": alpha_pi,
                    "alpha_gamma": alpha_gamma,
                    "att_estimate": att_estimate,
                    "absolute_bias": np.abs(att_estimate - att),
                    "squared_error": (att_estimate - att) ** 2,
                    "coverage": coverage,
                    "width": ci_width,
                })

        estimates = pd.DataFrame(records)
        estimates.to_csv(f"implicit_simulation_results_{n}.csv", index=False)

    else:
        estimates = pd.read_csv(f"implicit_simulation_results_{n}.csv")

    #
    # Analysis
    #

    metrics = ["absolute_bias", "squared_error", "coverage", "width"]
    alpha_columns = ["alpha_mu", "alpha_pi", "alpha_gamma"]

    # Summary restricted to the corners of the grid (every alpha is 0 or 1).
    is_extreme = (
        estimates["alpha_mu"].isin([0, 1])
        & estimates["alpha_pi"].isin([0, 1])
        & estimates["alpha_gamma"].isin([0, 1])
    )
    estimates_extreme = estimates[is_extreme]
    print(estimates_extreme.groupby(alpha_columns)[metrics].mean().reset_index())

    #
    # Plotting
    #

    # Matrix of heatmaps: one column per metric (bias, RMSE, coverage, CI
    # width) and one row per value of alpha_pi.
    sns.set_theme(style="whitegrid", palette="pastel", font_scale=0.5)
    # Document width in points, passed to set_size.
    text_width_pt = 396

    # Global min/max of every metric over the whole grid, so that all panels
    # of one column share a colour scale. The squared error is shown on the
    # root scale (RMSE), so its limits are computed on that scale too.
    global_limits = {}

    grouped = estimates.groupby(alpha_columns)[metrics].mean().reset_index()

    for metric in metrics:
        if metric == "squared_error":
            vals = np.sqrt(grouped[metric])
        else:
            vals = grouped[metric]
        global_limits[metric] = (vals.min(), vals.max())

    vmin_bias, vmax_bias = global_limits["absolute_bias"]
    vmin_mse, vmax_mse = global_limits["squared_error"]
    vmin_cov, vmax_cov = global_limits["coverage"]
    vmin_width, vmax_width = global_limits["width"]

    # Five rows of panels, an empty spacer row ('.') that pushes the
    # colorbars down, and a thin row holding one colorbar per column.
    # height_ratios: [plots (1), ..., spacer (0.3), cbars (0.08)]
    fig = plt.figure(figsize=set_size(text_width_pt, subplots=(5.5, 4)))
    axd = fig.subplot_mosaic(
        [['absolute_bias0', 'squared_error0', 'coverage0', 'width0'],
         ['absolute_bias25', 'squared_error25', 'coverage25', 'width25'],
         ['absolute_bias50', 'squared_error50', 'coverage50', 'width50'],
         ['absolute_bias75', 'squared_error75', 'coverage75', 'width75'],
         ['absolute_bias100', 'squared_error100', 'coverage100', 'width100'],
         ['.', '.', '.', '.'],  # Spacer row
         ['cbar_bias', 'cbar_mse', 'cbar_cov', 'cbar_width']],
        gridspec_kw={'wspace': 0.1, 'hspace': 0.1, 'height_ratios': [1, 1, 1, 1, 1, 0.3, 0.08]})

    alphas_pi = [0, 0.25, 0.5, 0.75, 1]
    tick_labels = ['0', '0.25', '0.5', '0.75', '1']

    # Colormaps are defined once so heatmaps and colorbars always agree.
    cmap_error = sns.cm.rocket_r
    cmap_coverage = plt.cm.viridis

    for i, alpha_pi in enumerate(alphas_pi):
        # Suffix of the mosaic keys of this row (0, 25, 50, 75, 100).
        row_key = int(alpha_pi * 100)
        is_last_row = i == len(alphas_pi) - 1

        subset = estimates[estimates["alpha_pi"] == alpha_pi]
        pivot_bias = subset.pivot_table(index="alpha_mu", columns="alpha_gamma", values="absolute_bias", aggfunc="mean")
        pivot_mse = subset.pivot_table(index="alpha_mu", columns="alpha_gamma", values="squared_error", aggfunc="mean")
        pivot_coverage = subset.pivot_table(index="alpha_mu", columns="alpha_gamma", values="coverage", aggfunc="mean")
        pivot_width = subset.pivot_table(index="alpha_mu", columns="alpha_gamma", values="width", aggfunc="mean")

        # The mean squared error is plotted and annotated as its root (RMSE).
        pivot_rmse = pivot_mse.apply(np.sqrt)

        sns.heatmap(pivot_bias, ax=axd[f'absolute_bias{row_key}'], annot=pivot_bias.round(2),
                    cbar=False, cmap=cmap_error, vmin=vmin_bias, vmax=vmax_bias)

        sns.heatmap(pivot_rmse, ax=axd[f'squared_error{row_key}'], annot=pivot_rmse.round(2),
                    cbar=False, cmap=cmap_error, vmin=vmin_mse, vmax=vmax_mse)

        sns.heatmap(pivot_coverage, ax=axd[f'coverage{row_key}'], annot=pivot_coverage.round(3),
                    cbar=False, cmap=cmap_coverage, vmin=vmin_cov, vmax=vmax_cov)

        sns.heatmap(pivot_width, ax=axd[f'width{row_key}'], annot=pivot_width.round(2),
                    cbar=False, cmap=cmap_error, vmin=vmin_width, vmax=vmax_width)

        # Only the bottom row shows x tick labels and only the left-most
        # column shows y tick labels; all other panels are left bare.
        for metric in metrics:
            ax = axd[f'{metric}{row_key}']
            if is_last_row:
                ax.set_xticklabels(tick_labels, rotation=0)
                ax.set_xlabel('$\\alpha_{\\gamma}$')
            else:
                ax.set_xticklabels([])
                ax.set_xlabel('')
            if metric == "absolute_bias":
                ax.set_yticklabels(tick_labels, rotation=0)
                ax.set_ylabel('$\\alpha_{\\mu}$')
            else:
                ax.set_yticklabels([])
                ax.set_ylabel('')
            ax.tick_params(pad=-3)

    # Column titles go on the top row only.
    axd['absolute_bias0'].set_title('Bias')
    axd['squared_error0'].set_title('RMSE')
    axd['coverage0'].set_title('Coverage')
    axd['width0'].set_title('CI Width')

    # Label every row with its alpha_pi value, to the right of the last column.
    for alpha_pi in alphas_pi:
        ax = axd[f'width{int(alpha_pi*100)}']
        ax.annotate(f'$\\alpha_{{\\pi}}$ = {alpha_pi}', xy=(1.03, 0.5), xycoords='axes fraction', rotation=-90, va='center')

    # One horizontal colorbar per column, built from the same limits and
    # colormap as the heatmaps above it.
    cbar_configs = [
        ('cbar_bias', global_limits["absolute_bias"], cmap_error),
        ('cbar_mse', global_limits["squared_error"], cmap_error),
        ('cbar_cov', global_limits["coverage"], cmap_coverage),
        ('cbar_width', global_limits["width"], cmap_error)
    ]

    for ax_key, (vmin, vmax), cmap in cbar_configs:
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        cb = plt.colorbar(sm, cax=axd[ax_key], orientation='horizontal')
        cb.outline.set_visible(False)
        cb.ax.tick_params(labelsize=6)

    fig.savefig(f"implicit_simulation_results_{n}.pdf", bbox_inches='tight')