import networkx
import pydot
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
import cvxpy as cp
import numpy as np
from numpy import random as rand


class TreExplore(networkx.DiGraph):
    """Search tree over action sequences in which sequences declared equal are merged.

    Starting from root node ``0``, the tree is grown level by level: every
    action in ``range(action_number)`` is applied to every frontier node, for
    at most ``tree_depth`` levels.

    ``equalities`` is a list of pairs ``(seq_a, seq_b)`` of action sequences.
    They are flattened into ``all_lists``, so pair ``k`` sits at indices
    ``2k`` and ``2k + 1``. While expanding, each node tracks how far each
    sequence has been matched. When one side of a pair is completed after its
    partner side has already been completed, the new edge is pointed at the
    node where the partner finished instead of a fresh node.

    Per-node bookkeeping consists of parallel lists stored as node attributes:

    * ``current_sequence[i]``: prefix of a target sequence matched so far.
    * ``sequence_index[i]``: index in ``all_lists`` of that target sequence.
    * ``clone[i]``: node holding the progress entry of the partner sequence.
    * ``clone_index[i]``: position of that partner entry in ``clone[i]``'s lists.

    After :meth:`find_opt`, every edge carries a flow value stored under the
    objective's name.
    """

    def __init__(self, equalities, action_number, tree_depth=5, verbose=True):
        """Create the graph containing only the root node ``0``.

        Args:
            equalities: Pairs ``(seq_a, seq_b)`` of action lists whose two
                sides should lead to the same node.
            action_number: Number of actions; actions are
                ``0 .. action_number - 1``.
            tree_depth: Maximum number of levels built by :meth:`expand_length`.
            verbose: Print progress and solver output when True.
        """
        super(TreExplore, self).__init__()

        if verbose:
            print(f"Constructing graph from action equalities {equalities}")

        self.equalities = equalities
        self.action_number = action_number
        self.tree_depth = tree_depth
        self.verbose = verbose
        # Flattened sequences: pair k occupies indices 2k and 2k + 1.
        self.all_lists = []
        for t in equalities:
            self.all_lists.extend((t[0], t[1]))
        self.current_depth = 0
        # Objective names solved by find_opt, in call order.
        self.optimized = []
        # Edge -> column of the LP variable; rebuilt by find_opt.
        self.edge_index = {}
        # Position of the trajectory sampler (see reset_traj / sample_next).
        self.current_node = 0
        self.add_node(0)

    def index_matching_seq(self, index):
        """Return the ``all_lists`` index of the other side of ``index``'s pair."""
        return index + 1 if index % 2 == 0 else index - 1

    def matching_seq(self, index):
        """Return the partner sequence of ``all_lists[index]``."""
        return self.all_lists[self.index_matching_seq(index)]

    def first_to_finish(self, node, index):
        """Tell whether entry ``index`` of ``node`` would finish before its partner.

        Returns True when the partner entry (found through ``clone`` /
        ``clone_index``) has not yet matched its whole target sequence.
        """
        sequ_to_match_index = self.nodes[node]["sequence_index"][index]
        sequ_to_match_clone = self.matching_seq(sequ_to_match_index)
        clone = self.nodes[node]["clone"][index]
        clone_index = self.nodes[node]["clone_index"][index]
        current_seq_clone = self.nodes[clone]["current_sequence"][clone_index]
        return len(current_seq_clone) < len(sequ_to_match_clone)

    def seq_is_first(self, index):
        """Tell whether ``all_lists[index]`` is the shorter side of its pair.

        Ties go to the even (first-listed) side.
        """
        if index % 2 == 0:
            return len(self.all_lists[index]) <= len(self.all_lists[index + 1])
        else:
            return len(self.all_lists[index]) < len(self.all_lists[index - 1])

    def add_node(self, node_for_adding, **attr):
        """Add a node and start tracking every sequence from scratch at it.

        Overrides :meth:`networkx.DiGraph.add_node`. The node gets one entry
        per sequence in ``all_lists``. Each entry has an empty progress prefix,
        and its partner is the other side of the same pair on this very node.
        """
        super(TreExplore, self).add_node(node_for_adding, **attr)
        data = self.nodes[node_for_adding]
        n_lists = len(self.all_lists)
        data["current_sequence"] = [[] for _ in range(n_lists)]
        data["sequence_index"] = list(range(n_lists))
        data["clone"] = [node_for_adding] * n_lists
        data["clone_index"] = [self.index_matching_seq(i) for i in range(n_lists)]

    def add_attributes(self, node, attributes):
        """Append progressed entries to ``node`` and re-link their partners.

        ``attributes`` holds parallel lists under ``current_sequence``,
        ``sequence_index``, ``clone`` and ``clone_index``. Each appended entry
        keeps its partner pointer. The partner entry is updated to point back
        at the entry's new position in ``node``, so the pair stays linked.
        """
        current_index = len(self.nodes[node]["clone"])
        for (cur_seq, seq_i, clone, clone_i) in zip(
            attributes["current_sequence"],
            attributes["sequence_index"],
            attributes["clone"],
            attributes["clone_index"],
        ):
            self.nodes[node]["current_sequence"].append(cur_seq)
            self.nodes[node]["sequence_index"].append(seq_i)
            self.nodes[node]["clone"].append(clone)
            self.nodes[node]["clone_index"].append(clone_i)

            # Point the partner entry back at the entry just appended.
            self.nodes[clone]["clone"][clone_i] = node
            self.nodes[clone]["clone_index"][clone_i] = current_index
            current_index += 1

    def add_double_attributes(self, node, attributes):
        """Append pairs whose two sides advanced together onto ``node``.

        ``attributes`` holds parallel lists under ``current_sequence`` and
        ``sequence_index`` for one side, and ``clone_sequence`` for the other
        side's prefix. ``clone_index`` holds the other side's index into
        ``all_lists`` (not a list position). Both entries are appended to
        ``node`` and point at each other.
        """
        current_index = len(self.nodes[node]["clone"])
        for (cur_seq, seq_i, clone, clone_i) in zip(
            attributes["current_sequence"],
            attributes["sequence_index"],
            attributes["clone_sequence"],
            attributes["clone_index"],
        ):
            self.nodes[node]["current_sequence"].append(cur_seq)
            self.nodes[node]["current_sequence"].append(clone)
            self.nodes[node]["sequence_index"].append(seq_i)
            self.nodes[node]["sequence_index"].append(clone_i)
            self.nodes[node]["clone"].append(node)
            self.nodes[node]["clone"].append(node)
            # Entries at current_index and current_index + 1 reference each other.
            self.nodes[node]["clone_index"].append(current_index + 1)
            self.nodes[node]["clone_index"].append(current_index)
            current_index += 2

    def expand(self, node, action):
        """Add the edge leaving ``node`` through ``action``.

        Every tracked prefix at ``node`` whose target sequence continues with
        ``action`` is advanced by one step and carried over to the target node.
        If such a prefix thereby completes its sequence after its partner has
        already finished, the edge goes to the node of that partner
        (``real_clone``) instead of a new node. Pairs whose two sides both
        live on ``node`` and both continue with ``action`` are carried over
        together ("doubles").

        Args:
            node: Node to expand.
            action: Action labelling the new edge.
        """
        is_new_node = True
        real_clone = None  # target node when a sequence finishes second
        node_attributes = {
            "current_sequence": [],
            "sequence_index": [],
            "clone": [],
            "clone_index": [],
        }

        for i, l in enumerate(self.nodes[node]["current_sequence"]):
            new_l = l + [action]
            sequ_to_match_index = self.nodes[node]["sequence_index"][i]
            sequ_to_match = self.all_lists[sequ_to_match_index]

            if len(l) >= len(sequ_to_match):
                # Already fully matched: nothing left to advance.
                pass

            elif sequ_to_match[len(l)] == action:

                clone_node = self.nodes[node]["clone"][i]
                is_first = self.first_to_finish(node, i)

                finished = len(new_l) == len(sequ_to_match)
                if finished:
                    is_new_node = is_new_node and is_first
                    if not is_first:
                        # Partner already finished there: merge into its node.
                        real_clone = clone_node

                node_attributes["current_sequence"].append(new_l)
                node_attributes["sequence_index"].append(sequ_to_match_index)
                node_attributes["clone"].append(self.nodes[node]["clone"][i])
                node_attributes["clone_index"].append(
                    self.nodes[node]["clone_index"][i]
                )

        # Border case: the partner lives on this same node and also advances
        # with `action`, so both sides must move to the target together.
        node_attributes_unique = {
            "current_sequence": [],
            "sequence_index": [],
            "clone": [],
            "clone_index": [],
        }

        node_attributes_double = {
            "current_sequence": [],
            "clone_sequence": [],
            "sequence_index": [],
            "clone_index": [],
        }

        for (cur_seq, seq_i, clone, clone_i) in zip(
                node_attributes["current_sequence"],
                node_attributes["sequence_index"],
                node_attributes["clone"],
                node_attributes["clone_index"],
        ):
            current_seq_clone = self.nodes[clone]["current_sequence"][clone_i]
            sequ_to_match_index = self.nodes[clone]["sequence_index"][clone_i]
            sequ_to_match_clone = self.all_lists[sequ_to_match_index]
            if (
                clone == node
                and len(current_seq_clone) < len(sequ_to_match_clone)
                and sequ_to_match_clone[len(current_seq_clone)] == action
            ):
                # Double found; skip it if the pair was already recorded
                # from the other side.
                is_there = False
                for (cur_seq_d, seq_i_d, clone_d, clone_i_d) in zip(
                        node_attributes_double["current_sequence"],
                        node_attributes_double["sequence_index"],
                        node_attributes_double["clone_sequence"],
                        node_attributes_double["clone_index"],
                ):
                    is_there_aux = cur_seq == clone_d and seq_i_d == sequ_to_match_index
                    is_there = is_there or is_there_aux
                if not is_there:
                    node_attributes_double["current_sequence"].append(cur_seq)
                    node_attributes_double["clone_sequence"].append(current_seq_clone + [action])
                    node_attributes_double["sequence_index"].append(seq_i)
                    node_attributes_double["clone_index"].append(sequ_to_match_index)

            else:
                node_attributes_unique["current_sequence"].append(cur_seq)
                node_attributes_unique["sequence_index"].append(seq_i)
                node_attributes_unique["clone"].append(clone)
                node_attributes_unique["clone_index"].append(clone_i)

        if is_new_node:
            # Node ids are consecutive integers, so the next id is len(nodes).
            current_node_number = len(self.nodes)
            self.add_node(current_node_number)
        else:
            current_node_number = real_clone
        self.add_attributes(current_node_number, node_attributes_unique)
        self.add_double_attributes(current_node_number, node_attributes_double)
        # NOTE(review): this is a DiGraph, so if two actions from `node` lead to
        # the same target, the later add_edge overwrites the earlier "action".
        self.add_edge(node, current_node_number, action=action)

    def expand_length(self):
        """Grow the tree level by level until ``tree_depth`` or no new nodes.

        Relies on node ids being consecutive integers, so the nodes created
        during one level are ``range(n_nodes, len(self.nodes))``.
        """
        new_nodes = [0]
        while self.current_depth < self.tree_depth and len(new_nodes) > 0:
            n_nodes = len(self.nodes)
            for node in new_nodes:
                for action in range(self.action_number):
                    self.expand(node, action)

            new_nodes = list(range(n_nodes, len(self.nodes)))
            self.current_depth += 1

    def draw(self, attribute="action", objective=None, objective_draw=True):
        """Plot the graph with matplotlib, labelling edges by ``attribute``.

        Args:
            attribute: Edge attribute shown as the edge label.
            objective: Edge attribute used for line widths; defaults to the
                most recent objective solved by :meth:`find_opt`.
            objective_draw: When True, scale edge widths by ``objective``
                (if every edge has it); otherwise draw uniform edges.
        """
        pos = graphviz_layout(self, prog="dot")
        if objective is None and self.optimized:
            objective = self.optimized[-1]
        weights = None
        if objective is not None:
            edge_weights = networkx.get_edge_attributes(self, objective)
            # Use widths only when every edge has one, so they line up with edges.
            if edge_weights and len(edge_weights) == self.number_of_edges():
                weights = list(edge_weights.values())
        if objective_draw:
            draw_kwargs = {} if weights is None else {"width": weights}
            networkx.draw(
                self,
                pos,
                with_labels=True,
                connectionstyle="arc2, rad = 0.4",
                **draw_kwargs,
            )
        else:
            networkx.draw(
                self, pos, with_labels=True, connectionstyle="arc3, rad = 0.1"
            )
        edge_labels = networkx.get_edge_attributes(self, attribute)
        networkx.draw_networkx_edge_labels(
            self, pos, edge_labels=edge_labels, label_pos=0.7, font_size=7
        )
        plt.show()

    def _add_edge_attributes(self, attribute, values):
        """Store ``values[k]`` on the k-th edge (in ``self.edges`` order)."""
        for i, edge in enumerate(self.edges):
            self.edges[edge][attribute] = values[i]

    def find_opt(self, objective="node_balance"):
        """Compute a maximum-entropy unit flow from the root and store it on edges.

        One unit of flow leaves node 0; every non-leaf node conserves flow.
        The entropy of ``P @ x`` is maximised, where ``P`` depends on
        ``objective``:

        * ``"node_balance"``: inflow of every non-root node, divided by
          ``tree_depth``.
        * ``"leaf_balance"``: inflow of every leaf.
        * ``"transition_balance"``: every edge flow, divided by ``tree_depth``.

        Edges pointing to a node with an id not greater than their source are
        removed first. The result is stored as edge attribute ``objective``.

        Raises:
            AssertionError: If ``objective`` is not one of the names above.
            RuntimeError: If the solver returns no solution.
        """
        assert objective in [
            "node_balance",
            "leaf_balance",
            "transition_balance",
        ], f"Objective {objective} not recognised"

        # HACK: expand() can create self-loops and edges back to older nodes.
        # Keeping only id-increasing edges makes the graph acyclic for the LP.
        self.remove_edges_from([edge for edge in self.edges if edge[1] <= edge[0]])

        edges = list(self.edges)
        self.edge_index = {edge: k for k, edge in enumerate(edges)}
        n_nodes, n_edges = len(self.nodes), len(edges)

        x = cp.Variable(shape=n_edges)

        # Flow conservation at every non-leaf node; node ids are row indices.
        constraint_matrix = np.zeros((n_nodes, n_edges))
        for node in self.nodes:
            if len(self.out_edges(node)) > 0:
                for in_edge in self.in_edges(node):
                    constraint_matrix[node, self.edge_index[in_edge]] = -1
                for out_edge in self.out_edges(node):
                    constraint_matrix[node, self.edge_index[out_edge]] = 1

        b = np.zeros(n_nodes)
        b[0] = 1  # one unit of flow leaves the root

        constraints = [constraint_matrix @ x == b, x >= 0, x <= 1]

        if objective == "transition_balance":
            P = np.identity(n_edges)
        else:
            P = np.zeros((n_nodes, n_edges))
            for node in self.nodes:
                is_root = len(self.in_edges(node)) == 0
                is_leaf = len(self.out_edges(node)) == 0
                if not is_root and (objective == "node_balance" or is_leaf):
                    for in_edge in self.in_edges(node):
                        P[node, self.edge_index[in_edge]] = 1

        if objective == "node_balance" or objective == "transition_balance":
            P = P / self.tree_depth

        obj = cp.Maximize(cp.sum(cp.entr(P @ x)))
        prob = cp.Problem(obj, constraints)
        prob.solve(verbose=self.verbose)
        if x.value is None:
            raise RuntimeError(
                f"Optimisation for {objective} failed with status {prob.status}"
            )
        if self.verbose:
            print(f"Found optimum for {objective} to depth {self.tree_depth}: {x.value}")
        self._add_edge_attributes(objective, x.value)
        self.optimized.append(objective)

    def _successor_by_action(self, node, action):
        """Return the child of ``node`` reached through ``action``, or None."""
        for next_node in self.successors(node):
            if self.edges[node, next_node]["action"] == action:
                return next_node
        return None

    def node_from_actions(self, actions):
        """Follow ``actions`` from the root and return the node reached.

        Raises:
            AssertionError: If there are more actions than ``tree_depth``.
            ValueError: If some action has no matching edge along the way.
        """
        node = 0
        assert (
            len(actions) <= self.tree_depth
        ), f"Too many actions ({len(actions)}) for depth of tree {self.tree_depth}"
        for a in actions:
            next_node = self._successor_by_action(node, a)
            if next_node is None:
                raise ValueError(f"No edge with action {a} leaves node {node}")
            node = next_node
        return node

    def flow_from_node(self, node, action, objective="node_balance"):
        """Return the ``objective`` flow on the edge leaving ``node`` via ``action``.

        Raises:
            ValueError: If no edge leaving ``node`` is labelled ``action``.
        """
        next_node = self._successor_by_action(node, action)
        if next_node is None:
            raise ValueError(f"No edge with action {action} leaves node {node}")
        return self.edges[node, next_node][objective]

    def flow_to_node(self, node, objective="node_balance"):
        """Return the total ``objective`` flow entering ``node`` (0 for the root)."""
        return sum(self.edges[in_edge][objective] for in_edge in self.in_edges(node))

    def action_probabilities(self, node, objective="node_balance"):
        """Return a distribution over actions proportional to outgoing flow.

        Actions with no edge from ``node`` get probability 0.

        Raises:
            ValueError: If the outgoing flow of ``node`` is not positive
                (e.g. ``node`` is a leaf).
        """
        p_actions = np.zeros(self.action_number)
        for _, _, data in self.out_edges(node, data=True):
            p_actions[data["action"]] = data[objective]
        # Solver output can contain tiny negative values; they are not valid
        # probabilities, so clip them to zero.
        p_actions = np.clip(p_actions, 0.0, None)
        total = np.sum(p_actions)
        if total <= 0:
            raise ValueError(
                f"No positive {objective} flow leaves node {node}"
            )
        return p_actions / total

    def reset_traj(self):
        """Move the trajectory sampler back to the root node."""
        self.current_node = 0

    def sample_next(self, objective="node_balance"):
        """Sample an action from the current node and follow its edge.

        Args:
            objective: Edge flow attribute used as sampling weights.

        Returns:
            The sampled action.
        """
        p = self.action_probabilities(self.current_node, objective=objective)
        action = rand.choice(np.arange(self.action_number), p=p)
        next_node = self._successor_by_action(self.current_node, action)
        if next_node is not None:
            self.current_node = next_node
        return action

if __name__ == "__main__":
    # Demo: [0, 1] is declared equal to [1, 0], and [2] equal to [2, 2].
    #
    # Other equality sets that have been tried (swap one in for `equalities`):
    #   [([0], [])]
    #   [([0], [0, 0])]
    #   [([0, 0], []), ([1, 1], [])]
    #   [([0, 0], [1, 1])]
    #   [([1, 0], []), ([], [0, 1])]
    #   [([0, 1], [1, 0]), ([0, 2], [2, 0]), ([1, 2], [2, 1])]
    #   # The two below use action 3, so they need action_number >= 4:
    #   [([0, 1], [1, 0]), ([0, 2], [2, 0]), ([1, 2], [2, 1]),
    #    ([0, 3], [3, 0]), ([3, 1], [1, 3]), ([3, 2], [2, 3])]
    #   [([1, 0], []), ([0, 2], [2, 0]), ([1, 2], [2, 1]),
    #    ([0, 3], [3, 0]), ([3, 1], [1, 3]), ([3, 2], []), ([], [2, 3]),
    #    ([0, 1], [])]
    equalities = [([0, 1], [1, 0]), ([2], [2, 2])]

    explorer = TreExplore(equalities, action_number=3, tree_depth=3)
    explorer.expand_length()
    explorer.find_opt()  # alternative: explorer.find_opt(objective="transition_balance")
    explorer.draw(objective_draw=False)
