import numpy as np
import pydot 
import io
import matplotlib.image as mpimg
from matplotlib.cm import coolwarm_r
from matplotlib.colors import rgb2hex
from .utils import *


def rules(tree, pred_dims=None, sf=3, dims_as_indices=True, out_name=None): 
    """
    Represent tree as a rule set with pred_dims as the consequent. Formatted as valid Python code.

    The output is a function definition ``def <tree.name>(x):`` with nested ``if``/``else`` statements.
    Each decision node becomes ``if <dim> < <threshold>:``. Samples satisfying the test go to the left
    child and the rest to the right child. Each leaf becomes a ``return`` of its mean along pred_dims,
    with a trailing comment giving the sample count and the standard deviations. A missing (None)
    child becomes ``return None``.

    Args:
        tree: Tree to convert. Must provide .root, .name and .space (with .idxify and .dim_names).
        pred_dims: Dimension(s) whose leaf means are returned; resolved via tree.space.idxify.
            If this resolves to something falsy, leaves emit a bare ``return`` with the sample count.
        sf: Passed to round_sf when formatting thresholds, means and standard deviations.
        dims_as_indices: If True, refer to dimensions as ``x[i]`` and add the dimension name as a comment.
            If False, use the dimension name itself. Note that the generated function does not define
            those names, so that form may not run as-is.
        out_name: If given, the rules are also written to this path (UTF-8, one line per rule).

    Returns:
        The rules as a single newline-joined string (with no trailing newline).
    """
    pred_dims = tree.space.idxify(pred_dims)
    out = [f"def {tree.name}(x):"]

    def _emit(node, level=1):
        # Each nesting level of the generated code is indented by four spaces.
        pad = "    " * level
        if node is None:
            out.append(f"{pad}return None")
            return
        if node.split_dim is None:
            # Leaf: return the prediction (or nothing) plus a summary comment.
            if not pred_dims:
                out.append(f"{pad}return # n={node.num_samples}")
                return
            mean_txt = round_sf(node.mean[pred_dims], sf)
            std_txt = round_sf(np.sqrt(np.diag(node.cov)[pred_dims]), sf)
            out.append(f"{pad}return {mean_txt} # n={node.num_samples}, std={std_txt}")
            return
        # Decision node: emit the test, then both branches one level deeper.
        name = tree.space.dim_names[node.split_dim]
        if dims_as_indices:
            lhs, note = f"x[{node.split_dim}]", f" # {name}"
        else:
            lhs, note = name, ""
        out.append(f"{pad}if {lhs} < {round_sf(node.split_threshold, sf)}:{note}")
        _emit(node.left, level + 1)
        out.append(f"{pad}else:")
        _emit(node.right, level + 1)

    _emit(tree.root)
    if out_name is not None:  # Optionally persist the rules to disk.
        with open(out_name, "w", encoding="utf-8") as fh:
            fh.writelines(line + "\n" for line in out)
    return "\n".join(out)

def diagram(tree, pred_dims=None, colour_dim=None, cmap_lims=None,
            show_decision_node_preds=False, show_num_samples=False, show_std_rng=False, show_impurity=False,
            sf=3, out_as="svg", out_name=None, size=None):
    """
    Represent tree as a pydot diagram with pred_dims as the consequent.

    Leaf nodes are numbered (1-based, in the order of tree.leaves) and labelled with the mean of each
    pred_dim, one per line. Decision nodes are labelled with the question "<dim>≥<threshold>?". The edge
    to the left child is labelled "No" and the edge to the right child "Yes". A missing (None) child is
    drawn as a node labelled "None".

    Args:
        tree: Tree to draw. Must provide .root, .leaves, .name and .space (and .gather if colour_dim is set).
        pred_dims: Dimension(s) whose node means are printed in labels; resolved via tree.space.idxify.
        colour_dim: If given, leaves are filled by their mean along this dimension using matplotlib's
            coolwarm_r colormap. Decision nodes are filled this way too if show_decision_node_preds
            is set, and gray otherwise. Without colour_dim, leaves are white and decision nodes gray.
        cmap_lims: (min, max) for the colour scale. Defaults to the range of the leaf means along colour_dim.
        show_decision_node_preds: Also print predictions (and apply colour) on decision nodes.
        show_num_samples: Append "n=<num_samples>" to labels that show predictions.
        show_std_rng: Append "(s=<std>,r=<value from node.hr_min>)" after each prediction.
        show_impurity: Append np.dot(node.var_sum[pred_dims], tree.space.global_var_scale[pred_dims]),
            formatted as "%.2E".
        sf: Passed to round_sf when formatting numbers.
        out_as: "png" or "svg" to write "<out_name or tree.name>.<ext>", or "plt" to return an image array.
        out_name: Output file name without extension. Defaults to tree.name.
        size: Optional (width, height), passed to graph.set_size as "w,h!".

    Returns:
        For out_as == "plt", the PNG rendering decoded by matplotlib.image.imread. Otherwise None.

    Raises:
        ValueError: If out_as is not "png", "svg" or "plt".
    """
    if pred_dims is not None: pred_dims = tree.space.idxify(pred_dims)
    # np.size avoids the ambiguous truth value of a multi-element numpy array; same result as
    # truthiness for None and for (empty or non-empty) lists.
    has_preds = pred_dims is not None and np.size(pred_dims) > 0
    if colour_dim is not None:
        colour_dim = tree.space.idxify(colour_dim)
        leaf_means = tree.gather(("mean", colour_dim))
        if cmap_lims is None: cmap_lims = (min(leaf_means), max(leaf_means))
        colour = lambda node: rgb2hex(_values_to_colours([node.mean[colour_dim]], (coolwarm_r, "coolwarm_r"), cmap_lims)[0])
    # 1-based leaf numbers looked up by identity. Built once, instead of calling tree.leaves.index per leaf (O(n^2)).
    leaf_nums = {id(leaf): i for i, leaf in enumerate(tree.leaves, start=1)}

    def _label(node):
        """Return the DOT label text for a non-None node."""
        if node.split_dim is not None:
            split = f'{tree.space.dim_names[node.split_dim]}≥{round_sf(node.split_threshold, sf)}?'
            if not show_decision_node_preds: return split
        label = ""
        if node.split_dim is None: label += f'({leaf_nums[id(node)]}) '
        if has_preds:
            means = node.mean[pred_dims]
            stds = np.sqrt(np.diag(node.cov)[pred_dims])
            # NOTE(review): "r" is read from node.hr_min; if a range was intended this may need to be
            # hr_max - hr_min — confirm.
            rngs = node.hr_min[pred_dims]
            for d, (mean, std, rng) in enumerate(zip(means, stds, rngs)):
                if d > 0: label += '\n'  # One prediction per line; they previously ran together.
                label += f'{tree.space.dim_names[pred_dims[d]]}={round_sf(mean, sf)}'
                if show_std_rng: label += f' (s={round_sf(std, sf)},r={round_sf(rng, sf)})'
        if show_num_samples: label += f'\nn={node.num_samples}'
        if has_preds and show_impurity:
            imp = f"{np.dot(node.var_sum[pred_dims], tree.space.global_var_scale[pred_dims]):.2E}"
            label += f'\nimpurity: {imp}'
        if node.split_dim is not None: label += f'\n-----\n{split}'
        return label

    statements = ['digraph Tree {nodesep=0.2; ranksep=0.2; node [shape=box];']

    def _recurse(node, n, n_parent, dir_label):
        """Append DOT statements for the subtree rooted at node (given ID n); return the next unused ID."""
        n_here = n
        if node is None:
            statements.append(f'{n_here} [label="None"];')
        else:
            if node.split_dim is None:
                c = colour(node) if colour_dim is not None else "white"
            else:
                c = colour(node) if (colour_dim is not None and show_decision_node_preds) else "gray"
            statements.append(f'{n_here} [style=filled, fontname="sans-serif", fillcolor="{c}", label="{_label(node)}"];')
        if n_here > 0:  # Make edge from parent. None children get one too, so they are not left dangling.
            statements.append(f'{n_parent} -> {n_here} [fontname="sans-serif", label=" {dir_label} "];')  # Spaces add padding.
        n = n_here + 1  # Consume this ID so the next node (e.g. a sibling of a None child) cannot reuse it.
        if node is not None and node.split_dim is not None:  # Recurse to children.
            n = _recurse(node.left, n, n_here, "No")
            n = _recurse(node.right, n, n_here, "Yes")
        return n

    # Create and save pydot graph.
    _recurse(tree.root, 0, 0, None)
    statements.append('}')
    (graph,) = pydot.graph_from_dot_data("".join(statements))
    if size is not None: graph.set_size(f"{size[0]},{size[1]}!")
    if out_as == "png":   graph.write_png(f"{out_name if out_name is not None else tree.name}.png")
    elif out_as == "svg": graph.write_svg(f"{out_name if out_name is not None else tree.name}.svg")
    elif out_as == "plt":  # https://stackoverflow.com/a/18522941
        return mpimg.imread(io.BytesIO(graph.create_png()))
    else: raise ValueError(f"out_as must be 'png', 'svg' or 'plt', got {out_as!r}")

def _values_to_colours(values, cmap, cmap_lims):
    # Compute fill colour.
    if cmap_lims is None: mn, mx = np.min(values), np.max(values)
    else: mn, mx = cmap_lims
    if mx == mn: colours = [cmap[0](0.5) for _ in values] # Default to midpoint.
    else: colours = [cmap[0](v) for v in (np.array(values) - mn) / (mx - mn)]
    return colours
    