from pathlib import Path
from typing import Literal

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

def plot_density(
    df: pd.DataFrame,
    y: Literal["posterior", "posterior_distance", "posterior_distance_delta", "delta"],
    x: Literal["prior", "prior_distance"],
    hue: Literal["model", "reasoning_mode", "domain", "prompt"],
) -> None:
    """Draw a KDE joint plot of ``y`` against ``x``, split by ``hue``, and save it as a PDF.

    The input frame is not modified; derived columns are added to a copy:

    * ``posterior`` = ``prior + delta``
    * ``prior_distance`` / ``posterior_distance`` = absolute difference between
      the prior / posterior and ``1 - gt``
    * ``posterior_distance_delta`` = ``posterior_distance - prior_distance``

    The figure is written to
    ``data/figures/per_traj/density_{y}_{x}_{hue}.pdf`` (the directory is
    created if missing) and then closed, so repeated calls do not accumulate
    open figures.

    Args:
        df: Data with at least the columns ``prior``, ``delta``, ``gt`` and
            the column named by ``hue``.
        y: Column plotted on the vertical axis.
        x: Column plotted on the horizontal axis.
        hue: Column used to colour the density contours.
    """
    df = df.copy()
    df["posterior"] = df["prior"] + df["delta"]
    # NOTE(review): distances are measured to ``1 - gt`` rather than ``gt``;
    # confirm this matches how ``gt`` is encoded in the data.
    reference = 1 - df["gt"]
    df["posterior_distance"] = (df["posterior"] - reference).abs()
    df["prior_distance"] = (df["prior"] - reference).abs()
    df["posterior_distance_delta"] = df["posterior_distance"] - df["prior_distance"]

    # Columns without a dedicated label fall back to their capitalized name.
    x_labels = {"prior_distance": "Prior distance to ground truth"}
    y_labels = {
        "posterior_distance": "Posterior distance to ground truth",
        "posterior_distance_delta": "Posterior distance - prior distance",
    }

    jp = sns.jointplot(
        data=df,
        x=x,
        y=y,
        hue=hue,
        kind="kde",
        bw_adjust=2,
    )
    # Work on the joint axes explicitly instead of pyplot's implicit
    # "current axes" state.
    ax = jp.ax_joint
    fig = ax.figure

    try:
        sns.despine(fig=fig, left=True, bottom=True)

        ax.set_xlim((-0.02, 1.02))
        # Delta-like quantities are signed, so they get a symmetric range
        # around zero; everything else uses the same range as the x axis.
        if "delta" in y:
            ax.set_ylim((-0.75, 0.75))
        else:
            ax.set_ylim((-0.02, 1.02))

        ax.set_xlabel(x_labels.get(x, x.capitalize()))
        ax.set_ylabel(y_labels.get(y, y.capitalize()))

        out_path = Path("data/figures/per_traj") / f"density_{y}_{x}_{hue}.pdf"
        # savefig does not create missing directories.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Always release the figure; pyplot otherwise keeps every figure
        # alive until the process exits.
        plt.close(fig)
    


if __name__ == "__main__":
    sns.set_theme(style="whitegrid")

    df = pd.read_csv("data/tmp/causal-attribution-data-per-trajectory.csv")

    # Keep only rows whose domain is not "CMV" (CMV doesn't have ground truths)
    not_cmv = df["domain"] != "CMV"
    df = df[not_cmv]

    # (y, x) column pairs to plot; each pair is drawn once per hue column,
    # in the order listed.
    axis_pairs = (
        ("posterior_distance", "prior_distance"),
        ("delta", "prior"),
        ("posterior_distance_delta", "prior_distance"),
    )
    hue_columns = ("reasoning_mode", "domain", "model", "prompt")

    for y_col, x_col in axis_pairs:
        for hue_col in hue_columns:
            plot_density(df, y_col, x_col, hue_col)