import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

def set_size(width, fraction=1, subplots=(3, 3)):
    """Set figure dimensions to avoid scaling in LaTeX.

    The figure width is ``width * fraction`` converted from TeX points to
    inches; the height is the width times the golden ratio, scaled by the
    rows/columns ratio given in ``subplots``.

    Parameters
    ----------
    width: float or string
            Document width in points, or the name of a predefined document
            type ('thesis' or 'beamer').
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy.
    subplots: array-like, optional
            The number of rows and columns of subplots.

    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches, as ``(width, height)``.

    Raises
    ------
    ValueError
            If ``width`` is a string that is not a known document type.
    """
    # Text widths (in TeX points) of the supported document presets.
    presets = {
        'thesis': 426.79135,
        'beamer': 307.28987,
    }
    if isinstance(width, str):
        # Fail loudly on an unknown preset instead of letting the string leak
        # into the arithmetic below and die with an unrelated TypeError.
        try:
            width_pt = presets[width]
        except KeyError:
            raise ValueError(
                f"Unknown document type {width!r}; expected a width in points "
                f"or one of {sorted(presets)}"
            ) from None
    else:
        width_pt = width

    # Width of figure (in pts)
    fig_width_pt = width_pt * fraction
    # TeX points per inch: 1 in = 72.27 pt
    inches_per_pt = 1 / 72.27

    # Golden ratio to set aesthetic figure height
    # https://disq.us/p/2940ij3
    golden_ratio = (5**.5 - 1) / 2

    # Figure width in inches
    fig_width_in = fig_width_pt * inches_per_pt
    # Figure height in inches, stretched/squashed by the rows:cols ratio
    fig_height_in = fig_width_in * golden_ratio * (subplots[0] / subplots[1])

    return (fig_width_in, fig_height_in)

def plot_updated_dags(width=396, filename="dags.pdf"):
    """Draw the two causal DAGs side by side and save them to ``filename``.

    Both panels use the same nodes, labels and layout. The right panel
    ("Assumption 2.4") adds the edge Y_1 -> R_0 to the left panel
    ("Assumption 2.3"); neither contains the edge A -> Y_0.

    Parameters
    ----------
    width: float or string, optional
            Document width forwarded to ``set_size`` (points, or a preset
            name such as 'thesis'). Defaults to 396, the value the script
            previously supplied through a module-level global.
    filename: str, optional
            Output path for the saved figure. Defaults to "dags.pdf".
    """
    # 1. Define Edges (Removed A -> Y_0)
    dag1_edges = [
        ('X', 'A'),
        ('X', 'R_0'), ('A', 'R_0'),
        ('X', 'Y_0'),                # ('A', 'Y_0') is gone
        ('X', 'Y_1'), ('A', 'Y_1'), ('Y_0', 'Y_1')
    ]

    dag2_edges = [
        ('X', 'A'),
        ('X', 'R_0'), ('A', 'R_0'), ('Y_1', 'R_0'), # Extra link for DAG 2
        ('X', 'Y_0'),                # ('A', 'Y_0') is gone
        ('X', 'Y_1'), ('A', 'Y_1'), ('Y_0', 'Y_1')
    ]

    # 2. Labels (rendered with matplotlib mathtext)
    custom_labels = {
        'X': r'$X$', 'A': r'$A$',
        'R_0': r'$R_0$', 'Y_0': r'$Y_0$', 'Y_1': r'$Y_1$'
    }

    # 3. Layout: fixed coordinates so both panels line up node-for-node
    pos = {
        'X': (0, 1),   'A': (0, 0),
        'Y_0': (1, 0.5),
        'Y_1': (2, 0.5),
        'R_0': (3, 0.5)
    }

    # One node size shared by node and edge drawing: networkx uses the
    # edge-side node_size to shorten arrows so their heads stop at the node
    # border rather than being hidden underneath it. 300 is networkx's default.
    node_size = 300

    fig = plt.figure(figsize=set_size(width, subplots=(1, 2)))
    axd = fig.subplot_mosaic([['dag1', 'dag2']])

    def draw_graph(edges, ax, title):
        """Draw the DAG given by ``edges`` onto ``ax`` with ``title``."""
        G = nx.DiGraph()
        G.add_edges_from(edges)

        # Draw Nodes as white circles with a black outline
        nx.draw_networkx_nodes(G, pos, ax=ax,
                               node_size=node_size,
                               node_color='white', edgecolors='black')

        # Draw Labels
        nx.draw_networkx_labels(G, pos, labels=custom_labels, ax=ax, font_size=10)

        # Draw Edges (same node_size so arrowheads clear the nodes)
        nx.draw_networkx_edges(G, pos, ax=ax,
                               node_size=node_size,
                               edge_color='black',
                               width=1,
                               arrows=True,
                               arrowstyle='-|>')

        ax.set_title(title)
        ax.margins(0.2)
        ax.axis('off')

    # 4. Plot
    draw_graph(dag1_edges, axd['dag1'], "Assumption 2.3")
    draw_graph(dag2_edges, axd['dag2'], "Assumption 2.4")

    fig.savefig(filename, bbox_inches='tight')
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)

if __name__ == "__main__":
    # Global seaborn styling, applied before any figure is created.
    sns.set_theme(font_scale=1, palette="pastel", style="whitegrid")

    # Target document text width, in points.
    width = 396
    plot_updated_dags()