"""
Code modified from https://github.com/xunzheng/notears/blob/master/notears/utils.py
"""
import numpy as np
import igraph as ig
import random

# For plot_graphs function
import matplotlib.pyplot as plt
import pydot
from PIL import Image
from io import BytesIO


def set_random_seed(seed):
    """Seed Python's ``random`` module and then NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def is_dag(W):
    """Return whether the directed graph encoded by adjacency matrix ``W`` is acyclic.

    ``W`` must provide ``.tolist()`` (e.g. a NumPy array). The matrix is handed
    to igraph's ``Graph.Weighted_Adjacency`` with its default arguments; per
    igraph's documentation that builds a directed graph with one edge per
    nonzero entry (so any nonzero diagonal entry is a self-loop, i.e. a cycle).
    """
    return ig.Graph.Weighted_Adjacency(W.tolist()).is_dag()


# One entry per PAG edge mark: (mark, pydot edge attributes, symmetric).
# Symmetric marks are drawn only for x < y so that an edge listed in both
# orientations is rendered once.
_PAG_EDGE_STYLES = (
    ('->', {'arrowtail': 'none', 'arrowhead': 'normal'}, False),
    ('<->', {'dir': 'both'}, True),
    ('--', {'dir': 'none'}, True),
    ('⚬--', {'arrowtail': 'odot', 'arrowhead': 'none', 'dir': 'both'}, False),
    ('⚬->', {'arrowtail': 'odot', 'arrowhead': 'normal', 'dir': 'both'}, False),
    ('⚬-⚬', {'arrowtail': 'odot', 'arrowhead': 'odot', 'dir': 'both'}, True),
)


def _draw_pag_edges(dot_graph, edges_dict, colored_edges=None):
    """Add every (x, y) pair of ``edges_dict`` to ``dot_graph`` using its mark's style.

    ``edges_dict`` maps an edge mark (see ``_PAG_EDGE_STYLES``) to (x, y) node-name
    pairs; missing marks are skipped. ``colored_edges`` maps (x, y) to a color;
    edges not listed there are drawn black.
    """
    if colored_edges is None:
        colored_edges = {}
    for mark, style, symmetric in _PAG_EDGE_STYLES:
        for x, y in edges_dict.get(mark, []):
            if symmetric and not (x < y):
                continue  # the (y, x) copy of this edge is drawn instead
            dot_graph.add_edge(pydot.Edge(x, y, **style, color=colored_edges.get((x, y), 'black')))


def _add_x_nodes(dot_graph, nodenum):
    """Add circle-shaped variable nodes X0 .. X{nodenum-1}."""
    for i in range(nodenum):
        dot_graph.add_node(pydot.Node(f'X{i}', shape='circle'))


def _add_interv_nodes(dot_graph, num_of_intervs):
    """Add red square intervention nodes I1 .. I{num_of_intervs} (1-based)."""
    for i in range(1, 1 + num_of_intervs):
        dot_graph.add_node(pydot.Node(f'I{i}', shape='square', color='red', fontcolor='red'))


def _show_dot(ax, dot_graph, title):
    """Render ``dot_graph`` to PNG with Graphviz and display it on ``ax`` with ``title``."""
    ax.imshow(Image.open(BytesIO(dot_graph.create_png())))
    ax.set_title(title)


def plot_graphs(interv_targets, selection_parents, dag,
                pag_edges_from_observational, pag_edges_from_interventional, savedir):
    """Draw the abstract graph and two PAGs side by side in a single figure.

    Panels (rendered with pydot/Graphviz, shown with matplotlib):
      1. the abstract graph: an X{u} -> X{v} edge for every nonzero ``dag[u, v]``,
         blue selection nodes S{i} with edges from ``selection_parents[i]``, and red
         intervention nodes I{k+1} with edges into ``interv_targets[k]``;
      2. the PAG from observational data;
      3. the PAG from interventional data, with I{k+1} -> X{t} edges colored red
         and '->' edges between two X nodes colored green.

    Args:
        interv_targets: sequence whose k-th item lists the X indices targeted by I{k+1}.
        selection_parents: sequence whose i-th item lists the X indices pointing into S{i}.
        dag: square adjacency matrix over the X variables (indexed as ``dag[u, v]``).
        pag_edges_from_observational: dict mapping an edge mark ('->', '<->', '--',
            '⚬--', '⚬->', '⚬-⚬') to (x, y) pairs of X indices; names become f'X{x}'.
        pag_edges_from_interventional: same kind of dict, but (x, y) are already node
            names (e.g. 'X0', 'I1').
        savedir: directory to write ``graph.pdf`` into, or None to skip saving.

    Returns:
        The matplotlib Figure, so callers can show or close it.
    """
    nodenum = len(dag)
    fig, axs = plt.subplots(1, 3, figsize=(15, 8))
    for ax in axs.flatten():
        ax.axis('off')

    # Panel 1: abstract graph (though incorrect)
    graph = pydot.Dot(graph_type='digraph')
    _add_x_nodes(graph, nodenum)
    for i in range(len(selection_parents)):
        graph.add_node(pydot.Node(f'S{i}', shape='diamond', color='blue', fontcolor='blue'))
    _add_interv_nodes(graph, len(interv_targets))
    for u, v in np.argwhere(dag).tolist():
        graph.add_edge(pydot.Edge(f'X{u}', f'X{v}', arrowtail='none', arrowhead='normal'))
    for i, sparents in enumerate(selection_parents):
        for sp in sparents:
            graph.add_edge(pydot.Edge(f'X{sp}', f'S{i}', arrowtail='none', arrowhead='normal', color='blue'))
    for k, targets in enumerate(interv_targets):
        for t in targets:
            graph.add_edge(pydot.Edge(f'I{k + 1}', f'X{t}', arrowtail='none', arrowhead='normal', color='red'))
    _show_dot(axs[0], graph, 'Abstract graph (incorrect)')

    # Panel 2: PAG of G (observational)
    graph = pydot.Dot(graph_type='digraph')
    _add_x_nodes(graph, nodenum)
    obs_edges = {mark: [(f'X{x}', f'X{y}') for x, y in pairs]
                 for mark, pairs in pag_edges_from_observational.items()}
    _draw_pag_edges(graph, obs_edges)
    _show_dot(axs[1], graph, 'PAG of G (observational); sure no latents')

    # Panel 3: PAG of Gtwin (interventional)
    graph = pydot.Dot(graph_type='digraph')
    # Include X0 like the other panels; otherwise it falls back to pydot's default
    # (non-circle) styling or disappears when it has no edges.
    _add_x_nodes(graph, nodenum)
    _add_interv_nodes(graph, len(interv_targets))
    colored_edges = (
        {(f'I{k + 1}', f'X{t}'): 'red' for k, targets in enumerate(interv_targets) for t in targets}
        # .get: a PAG without any '->' edge must not raise KeyError
        | {(x, y): 'green' for x, y in pag_edges_from_interventional.get('->', [])
           if 'X' in x and 'X' in y}
    )
    _draw_pag_edges(graph, pag_edges_from_interventional, colored_edges=colored_edges)
    _show_dot(axs[2], graph, 'PAG of Gtwin (interventional); sure I targets')

    fig.tight_layout()
    if savedir is not None:
        fig.savefig(f'{savedir}/graph.pdf')
    return fig