import os
from xml.etree.ElementInclude import include
import pickle5 as pickle
import argparse
import networkx as nx
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import ray

import misc
from trainers import get_trainer_class


def main():
    """Restore an RLlib trainer from a checkpoint and plot its LTC wiring graph.

    Command-line arguments:
        checkpoint: checkpoint file, or a directory from which
            ``misc.get_latest_checkpoint`` selects one.
        --run: trainer name passed to ``get_trainer_class`` (default 'PPO').
        --num-workers / --num-gpus: override the values stored in the
            checkpoint's ``params.pkl``.
        --temp-dir: Ray temp directory (``~`` is expanded).
        --out: file to save the figure to; when omitted the figure is shown
            interactively instead of crashing on ``savefig(None)``.
    """
    # parse argument
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'checkpoint',
        type=str,
        help='Checkpoint from which to roll out.')
    parser.add_argument(
        '--run',
        type=str,
        default='PPO',
        help='The algorithm or model to train')
    parser.add_argument(
        '--num-workers',
        default=0,
        type=int,
        help='Number of workers.')
    parser.add_argument(
        '--num-gpus',
        default=0,
        type=int,
        help='Number of GPUs.')
    parser.add_argument(
        '--temp-dir',
        default='~/tmp',
        type=str,
        help='Directory for temporary files generated by ray.')
    parser.add_argument(
        '--out',
        type=str,
        default=None,
        help='Path to saving plots. If omitted, the plot is shown instead.')
    args = parser.parse_args()

    # setup
    if os.path.isdir(args.checkpoint):
        args.checkpoint = misc.get_latest_checkpoint(args.checkpoint)
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, 'params.pkl')
    if not os.path.exists(config_path):
        # Fall back to the directory one level above the checkpoint's folder.
        config_path = os.path.join(config_dir, '../params.pkl')
    if not os.path.exists(config_path):
        parser.error(
            'could not find params.pkl next to or above {}'.format(
                args.checkpoint))

    # NOTE(security): unpickling runs arbitrary code; only load checkpoints
    # from trusted sources.
    with open(config_path, 'rb') as f:
        config = pickle.load(f)

    config['num_workers'] = args.num_workers
    config['num_gpus'] = args.num_gpus

    misc.register_custom_env(config['env'])
    misc.register_custom_model(config['model'])

    args.temp_dir = os.path.abspath(os.path.expanduser(args.temp_dir))
    ray.init(
        local_mode=True,
        _temp_dir=args.temp_dir,
        include_dashboard=False)

    try:
        # load from checkpoint
        cls = get_trainer_class(args.run)
        agent = cls(env=config['env'], config=config)
        if args.checkpoint:
            agent.restore(args.checkpoint)

        # plot graph: move the recurrent cell to CPU and freeze it so its
        # parameters can be converted to NumPy for plotting.
        ltc = agent.get_policy().model._rnn.to('cpu')
        for param in ltc.parameters():
            param.requires_grad = False
        legend_patches = draw_graph(ltc)
        # Add the legend before tight_layout so the layout accounts for it.
        plt.legend(handles=legend_patches)
        plt.tight_layout()
        if args.out is None:
            plt.show()
        else:
            out_path = os.path.abspath(os.path.expanduser(args.out))
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            plt.savefig(out_path)
        plt.close()
    finally:
        # Release the Ray resources started by ray.init above.
        ray.shutdown()


def _as_numpy(value):
    """Return ``value.numpy()``, detaching/moving torch tensors to CPU first."""
    if hasattr(value, 'detach'):
        # torch: .numpy() raises on tensors that require grad or live on a GPU.
        value = value.detach().cpu()
    return value.numpy()


def get_graph(ltc, include_sensory_neurons=False):
    """Build a directed graph of the synapses in an LTC cell's wiring.

    Args:
        ltc: LTC cell exposing ``state_size``, ``sensory_size``, ``_wiring``
            (with ``get_type_of_neuron``, ``adjacency_matrix`` and
            ``sensory_adjacency_matrix``) and a ``_params`` mapping holding
            ``"erev"`` and ``"sensory_erev"`` (presumably synaptic reversal
            potentials).
        include_sensory_neurons: also add ``sensory_<i>`` nodes and their
            synapses onto the internal neurons.

    Returns:
        networkx.DiGraph with nodes ``neuron_<i>`` (attribute ``neuron_type``
        from the wiring) and, optionally, ``sensory_<i>``
        (``neuron_type="sensory"``). Every edge carries ``polarity``:
        ``"excitatory"`` when the corresponding erev entry is >= 0, else
        ``"inhibitory"``.
    """
    DG = nx.DiGraph()
    for i in range(ltc.state_size):
        neuron_type = ltc._wiring.get_type_of_neuron(i)
        DG.add_node("neuron_{:d}".format(i), neuron_type=neuron_type)
    if include_sensory_neurons:
        for i in range(ltc.sensory_size):
            DG.add_node("sensory_{:d}".format(i), neuron_type="sensory")

    erev = _as_numpy(ltc._params["erev"])

    if include_sensory_neurons:
        # Only convert the sensory parameters when they are actually used.
        sensory_erev = _as_numpy(ltc._params["sensory_erev"])
        for src in range(ltc.sensory_size):
            for dest in range(ltc.state_size):
                if ltc._wiring.sensory_adjacency_matrix[src, dest] != 0:
                    polarity = (
                        "excitatory" if sensory_erev[src, dest] >= 0.0 else "inhibitory"
                    )
                    DG.add_edge(
                        "sensory_{:d}".format(src),
                        "neuron_{:d}".format(dest),
                        polarity=polarity,
                    )

    for src in range(ltc.state_size):
        for dest in range(ltc.state_size):
            if ltc._wiring.adjacency_matrix[src, dest] != 0:
                polarity = "excitatory" if erev[src, dest] >= 0.0 else "inhibitory"
                DG.add_edge(
                    "neuron_{:d}".format(src),
                    "neuron_{:d}".format(dest),
                    polarity=polarity,
                )
    return DG


def draw_graph(
    ltc,
    layout='circular',
    neuron_colors=None,
    synapse_colors=None,
    draw_labels=False,
    draw_sensory_neurons=False,
):
    """Draw the wiring graph of ``ltc`` onto the current matplotlib axes.

    Args:
        ltc: LTC cell accepted by ``get_graph``.
        layout: one of 'kamada', 'circular', 'random', 'shell', 'spring',
            'spectral', 'spiral' (mapped to the matching networkx layout).
        neuron_colors: optional dict mapping neuron type to colour; missing
            types fall back to the defaults. The caller's dict is not modified.
        synapse_colors: a single colour for all synapses, or a dict with
            'excitatory' / 'inhibitory' entries (missing entries use defaults).
        draw_labels: also draw node names.
        draw_sensory_neurons: include sensory neurons and their synapses.

    Returns:
        List of ``matplotlib.patches.Patch`` objects, one per neuron colour
        entry, suitable for ``plt.legend(handles=...)``.

    Raises:
        ValueError: if ``layout`` is not a supported layout name.
    """
    default_synapse_colors = {"excitatory": "tab:green", "inhibitory": "tab:red"}
    if isinstance(synapse_colors, str):
        synapse_colors = {
            "excitatory": synapse_colors,
            "inhibitory": synapse_colors,
        }
    elif synapse_colors is None:
        synapse_colors = default_synapse_colors
    else:
        # Fill in any polarity the caller did not specify.
        synapse_colors = {**default_synapse_colors, **synapse_colors}

    default_colors = {
        "inter": "tab:blue",
        "motor": "tab:orange",
        "command": "tab:olive",
    }
    if draw_sensory_neurons:
        default_colors["sensory"] = "tab:purple"
    # Merge defaults into a *copy* of the user dict (user keys keep their
    # order and precedence) so the caller's dict is never mutated.
    merged_colors = dict(neuron_colors or {})
    for k, v in default_colors.items():
        merged_colors.setdefault(k, v)

    legend_patches = [
        mpatches.Patch(color=color, label="{}{} neurons".format(k[0].upper(), k[1:]))
        for k, color in merged_colors.items()
    ]

    G = get_graph(ltc, draw_sensory_neurons)
    layouts = {
        "kamada": nx.kamada_kawai_layout,
        "circular": nx.circular_layout,
        "random": nx.random_layout,
        "shell": nx.shell_layout,
        "spring": nx.spring_layout,
        "spectral": nx.spectral_layout,
        "spiral": nx.spiral_layout,
    }
    if layout not in layouts:
        raise ValueError(
            "Unknown layout '{}', use one of {}".format(
                layout, ", ".join("'{}'".format(name) for name in layouts)
            )
        )
    pos = layouts[layout](G)

    # Group nodes by colour so each colour needs a single draw call.
    node_groups = {}
    for i in range(ltc.state_size):
        node_name = "neuron_{:d}".format(i)
        neuron_type = G.nodes[node_name]["neuron_type"]
        color = merged_colors.get(neuron_type, "tab:blue")
        node_groups.setdefault(color, []).append(node_name)
    if draw_sensory_neurons:
        # "sensory" is always present here (added to the defaults above).
        sensory_color = merged_colors["sensory"]
        for i in range(ltc.sensory_size):
            node_groups.setdefault(sensory_color, []).append(
                "sensory_{:d}".format(i))
    for color, nodelist in node_groups.items():
        nx.draw_networkx_nodes(G, pos, nodelist=nodelist, node_color=color)

    # Optional: draw labels
    if draw_labels:
        nx.draw_networkx_labels(G, pos)

    # Draw edges, one call per synapse polarity.
    edge_groups = {}
    for node1, node2, data in G.edges(data=True):
        edge_groups.setdefault(data["polarity"], []).append((node1, node2))
    for polarity, edgelist in edge_groups.items():
        nx.draw_networkx_edges(
            G, pos, edgelist=edgelist, edge_color=synapse_colors[polarity])

    return legend_patches


if __name__ == '__main__':
    # Run the CLI only when executed as a script, so the plotting helpers
    # (get_graph / draw_graph) can be imported without side effects.
    main()
