"""
Illustrate convergence of R² and explosion of variance along a causal chain.
"""
import CDExperimentSuite_DEV as CDES
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
import os

plt.rcParams["text.usetex"] = True
plt.rcParams["text.latex.preamble"] = r"\usepackage{bm}\usepackage{amssymb}"
plt.rcParams["axes.labelsize"] = 22
plt.rcParams["xtick.labelsize"] = 16
plt.rcParams["ytick.labelsize"] = 16
plt.rcParams["legend.fontsize"] = 16
plt.rcParams["lines.linewidth"] = 2
plt.rcParams["lines.markersize"] = 12


def _r2_per_variable(samples):
    """Return the R² of each variable when regressed on all other variables.

    Uses the identity ``R²_i = 1 - 1 / P_ii`` where ``P`` is the inverse of
    the correlation matrix of the variables.
    """
    # np.corrcoef treats rows as variables, hence the transpose
    # (assumes `samples` is laid out as (observations, variables) - TODO confirm).
    precision = np.linalg.inv(np.corrcoef(samples.T))
    # Take the diagonal *before* the reciprocal: only the diagonal is needed,
    # and (near-)zero off-diagonal entries are then never divided by.
    return 1 - 1 / np.diag(precision)


def _chain_positions(adjacency):
    """Count, per column, the powers ``B^k`` (k = 1..n-1) with a non-zero entry.

    Presumably ``adjacency[i, j] != 0`` encodes an edge i -> j, in which case
    the count is the position of node j in a chain (0 for the root) - verify
    against the data generator.
    """
    n_nodes = len(adjacency)
    # Integer accumulator; also yields all zeros for a single-node graph.
    positions = np.zeros(n_nodes, dtype=int)
    for k in range(1, n_nodes):
        k_step = np.linalg.matrix_power(adjacency, k)
        positions += np.count_nonzero(k_step, axis=0) != 0
    return positions


def var_explosion(opt, show_CEV=True):
    """Plot log-variance, R² and (optionally) the CEV fraction along a chain.

    The per-node statistics are generated with ``CDES.DataGenerator`` unless
    ``<base_dir>/<exp_name>.csv`` already exists and ``opt.overwrite`` is
    falsy, in which case they are read from that file. Freshly generated
    statistics are written to the CSV so that later runs can reuse them.
    The figure is saved as ``<base_dir>/<exp_name>.pdf``.

    Args:
        opt: options object providing ``base_dir``, ``exp_name`` and
            ``overwrite``; it is also passed on to ``DataGenerator.generate``.
        show_CEV: if true, additionally draw the cause-explained variance
            fraction on the right-hand axis.
    """
    out_stem = os.path.join(opt.base_dir, opt.exp_name)
    data_path = out_stem + ".csv"
    regenerate = opt.overwrite or not os.path.exists(data_path)

    if regenerate:
        datasets = CDES.DataGenerator().generate(opt)
        r2_vals, cev_vals, log_vars, positions = [], [], [], []
        for d in datasets:
            # Fraction of each variance exceeding sigma**2, clipped at zero
            # (presumably d.sigma holds the noise standard deviations).
            cev_vals.extend(np.maximum(0, d.vars - d.sigma**2) / d.vars)
            r2_vals.extend(_r2_per_variable(d.data))
            log_vars.extend(np.log(d.vars))
            positions.extend(_chain_positions(d.B_true))
        df = pd.DataFrame(
            {
                "R2": r2_vals,
                "CEV": cev_vals,
                "Variance": log_vars,
                "Position": positions,
            }
        )
    else:
        df = pd.read_csv(data_path)

    CDES.utils.create_folder(opt.base_dir, overwrite=False)
    if regenerate:
        # Persist the statistics; without this the read_csv branch above
        # could never find anything to load.
        df.to_csv(data_path, index=False)

    variance_color, r2_color = "#1F77B4", "#FF7F0E"  # mpl tab10 colors

    # Left axis: log-variance per chain position.
    _, ax1 = plt.subplots(figsize=(9, 4.5))
    sns.lineplot(
        data=df,
        x="Position",
        y="Variance",
        ax=ax1,
        color=variance_color,
        label="Variance",
    )
    ax1.set_ylabel("log(Variance)")
    ax1.tick_params(axis="y", labelcolor=variance_color)
    ax1.set_xlabel("Position in causal chain")

    # Right axis: R² and, optionally, the CEV fraction.
    ax2 = ax1.twinx()
    sns.lineplot(
        data=df, x="Position", y="R2", ax=ax2, color=r2_color, label=r"$R^2$"
    )
    right_label = r"$R^2$, CEV-fraction" if show_CEV else r"$R^2$"
    ax2.set_ylabel(right_label, labelpad=10)
    ax2.set_xlabel("Position in causal chain")
    ax2.tick_params(axis="y", labelcolor=r2_color)

    if show_CEV:
        sns.lineplot(
            data=df,
            x="Position",
            y="CEV",
            ax=ax2,
            color="#ff5f2e",
            linestyle="--",
            label="cause-explained variance fraction",
        )

    # Create joint legend manually: collect the handles of both axes, drop
    # the per-axis legends seaborn added, then draw one combined legend.
    h1, _ = ax1.get_legend_handles_labels()
    h2, _ = ax2.get_legend_handles_labels()
    ax1.get_legend().remove()
    ax2.get_legend().remove()

    legend_labels = [r"$R^2$"]
    if show_CEV:
        legend_labels.append("CEV-fraction")
    legend_labels.append("log(Variance)")
    plt.legend(h2 + h1, legend_labels)

    # Positions are integers, so only place ticks on integer values.
    for axis in (ax1, ax2):
        axis.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

    plt.tight_layout()
    plt.savefig(out_stem + ".pdf")
    plt.close("all")


if __name__ == "__main__":
    # Build the experiment options in one step by handing every setting to
    # Options as a keyword argument.
    opt = CDES.utils.Options(
        # With overwrite=True, var_explosion regenerates the data even when a
        # CSV from an earlier run is present.
        overwrite=True,
        base_dir="src/results/ChainRsb",
        exp_name="Chain",
        # ---
        MEC=False,
        thres=0.0,
        thres_type="standard",
        vsb_function=CDES.utils.var_sortability,
        R2sb_function=CDES.utils.r2_sortability,
        CEVsb_function=CDES.utils.cev_sortability,
        scaler=CDES.Scalers.Identity(),
        # ---
        n_repetitions=10,
        edges=[2],
        graphs=["chain"],
        edge_weights=[(0.5, 2)],
        edge_types=["fixed"],
        noise_distributions=[
            CDES.utils.NoiseDistribution("gauss", "uniform", (0.5, 2.0)),
        ],
        n_nodes=[20],
        n_obs=[1000],
    )
    CDES.utils.snapshot(opt)
    var_explosion(opt)
