import networkx as nx
import matplotlib.pyplot as plt
import random
from networkx.algorithms.chains import chain_decomposition
import numpy as np
import re
from sys import argv


def rainbow_colors(n):
    """Return ``n`` RGBA colours sampled evenly from matplotlib's "hsv" colormap.

    Colour ``k`` uses hue ``k / n`` for ``k`` in ``0 .. n-1`` (so the value
    1.0 is never sampled). ``n == 0`` yields an empty list.
    """
    return [plt.cm.hsv(step / n) for step in range(n)]

# Step 1: generate random 2-connected graph
def create_random_biconnected_graph(n):
    """Return a random 2-connected simple graph on the nodes ``0 .. n-1``.

    Random node pairs (drawn with the ``random`` module) are joined until the
    graph is 2-connected, so the result is reproducible via ``random.seed``.

    Raises:
        ValueError: if ``n < 3``. No simple graph on fewer than 3 nodes has
            node connectivity 2, so the search loop would never terminate.
    """
    if n < 3:
        raise ValueError(
            f"A 2-connected simple graph needs at least 3 nodes, got n={n}."
        )
    G = nx.Graph()
    G.add_nodes_from(range(n))
    # For n >= 3, "connected with node connectivity >= 2" is exactly
    # nx.is_biconnected. That check is a single DFS instead of a max-flow
    # connectivity computation on every iteration. The loop condition draws
    # no random numbers, so the generated graph is unchanged.
    while not nx.is_biconnected(G):
        i, j = random.sample(range(n), 2)
        G.add_edge(i, j)
    return G

# Step 2: Construct minimal biconnected subgraph
def construct_minimal_biconnected_subgraph(G, verbose):
    """Greedily thin ``G`` down to a minimally 2-connected spanning subgraph.

    Edges are tried in order of decreasing ``"length"`` attribute. Each edge
    is removed, and put back with all of its original attributes if the
    remaining graph is no longer 2-connected.

    Args:
        G: undirected graph; every edge must carry a numeric ``"length"``.
        verbose: unused.

    Returns:
        A new graph. ``G`` itself is not modified.
    """
    G_minimal = G.copy()

    def still_2_connected(H):
        # Equivalent to nx.node_connectivity(H) >= 2, but is_biconnected is a
        # single DFS rather than a max-flow computation per candidate edge.
        # The size guard keeps tiny graphs (e.g. K2, which is_biconnected
        # accepts) rejected exactly as node_connectivity >= 2 rejects them.
        return H.number_of_nodes() >= 3 and nx.is_biconnected(H)

    for u, v in sorted(G.edges(), key=lambda e: G.edges[e]["length"], reverse=True):
        G_minimal.remove_edge(u, v)
        if not still_2_connected(G_minimal):
            # Restore the edge together with its data. A bare add_edge(u, v)
            # would silently drop attributes such as "length".
            G_minimal.add_edges_from([(u, v, G.edges[u, v])])
    return G_minimal

# Step 3: Ear decomposition
def ear_decomposition(G_minimal, verbose):
    """Split a graph into ears using networkx's chain decomposition.

    Args:
        G_minimal: undirected graph to decompose (expected to be 2-connected).
        verbose: unused.

    Returns:
        ``(chains, chains_nodes)``:

        - ``chains``: the chains exactly as yielded by ``chain_decomposition``,
          each a list of ``(u, v)`` edges.
        - ``chains_nodes``: the same chains as node sequences
          ``[v0, v1, ..., vk]``.

    Raises:
        ValueError: if no chain is found, or if some chain does not share
            exactly two nodes with the nodes seen so far. The seen set starts
            from the first two nodes of the first chain.
    """
    chains = list(chain_decomposition(G_minimal))
    if not chains:
        # An acyclic (or empty) graph has no chains. Fail clearly instead of
        # hitting an IndexError on chains_nodes[0] below.
        raise ValueError(
            "Ear decomposition failed: no chains found (is the graph 2-connected?)."
        )
    chains_nodes = [[u for u, _ in ch] + [ch[-1][1]] for ch in chains]
    seen = {chains_nodes[0][0], chains_nodes[0][1]}
    for chain in chains_nodes:
        chain_set = set(chain)
        # Each ear must attach to the already-built structure at exactly its
        # two endpoints.
        if len(seen & chain_set) != 2:
            raise ValueError("Ear decomposition failed: each ear must connect to the first ear at two nodes.")
        seen |= chain_set
    return chains, chains_nodes

# Step 4: Identify x and y^i
def identify_x_yi(G_minimal, chains_nodes, verbose):
    """Pick the special vertex ``x`` and one ``y`` vertex per ear.

    - ``x``: the first node of the first ear (``chains_nodes[0]``) whose
      degree in ``G_minimal`` is 2.
    - ``y0``: the first other degree-2 node on that first ear, or ``None``
      if there is none.
    - for every later ear: the first interior node (endpoints excluded)
      of degree 2.

    Returns:
        ``(x, [y0, y1, ...])`` with one ``y`` per entry of ``chains_nodes``.

    Raises:
        StopIteration: if ``x`` or some later ear's ``y`` cannot be found.
    """
    deg = dict(G_minimal.degree())

    def has_degree_two(node):
        return deg[node] == 2

    first_ear = chains_nodes[0]
    x = next(filter(has_degree_two, first_ear))
    y0 = next((node for node in first_ear if node != x and has_degree_two(node)), None)

    ys_full = [y0]
    for ear in chains_nodes[1:]:
        interior = ear[1:-1]
        ys_full.append(next(filter(has_degree_two, interior)))
    return x, ys_full

# Step 5: Construct Eulerian graph GG
def construct_eulerian_graph(G_minimal, chains_nodes, ys_full, verbose):
    """Build the Eulerian multigraph GG by processing the ears last-to-first.

    Each ear is paired with its ``y`` vertex from ``ys_full`` and handled
    in three steps:

    1. Add every ear edge to GG, copying its ``"color"`` from ``G_minimal``
       (opaque black if absent).
    2. Walk the interior vertices ``ear[1] .. ear[-2]`` in order. Whenever
       ``ear[j]`` currently has odd degree in GG, add a second copy of the
       edge ``(ear[j], ear[j+1])``.
    3. If the ear's last edge ended up doubled, delete both copies of it.
       Otherwise, if the edge leaving ``y^i`` (``ear[j_y], ear[j_y+1]``) is
       doubled, delete both copies of that edge. Each deleted pair is
       recorded in ``removed_edges``.

    Args:
        G_minimal: undirected graph containing every ear edge.
        chains_nodes: ears as node sequences.
        ys_full: one ``y`` vertex per ear, aligned with ``chains_nodes``.
            Each must appear in its ear (``ear.index`` raises otherwise).
        verbose: print per-ear progress.

    Returns:
        ``(GG, removed_edges)``: the ``nx.MultiGraph`` and the list of
        ``(u, v)`` pairs whose two copies were deleted.

    Raises:
        ValueError: if a node of GG has odd degree, or GG is otherwise not
            Eulerian (e.g. disconnected).
    """
    GG = nx.MultiGraph()
    GG.add_nodes_from(G_minimal.nodes)
    removed_edges = []

    reversed_ears = chains_nodes[::-1]
    reversed_ys = ys_full[::-1]

    for i, ear in enumerate(reversed_ears):
        for u, v in zip(ear[:-1], ear[1:]):
            GG.add_edge(u, v, color=G_minimal.edges[u, v].get('color', (0, 0, 0, 1)))
        y_i = reversed_ys[i]
        j_y = ear.index(y_i)
        if verbose:
            print(f"Processing ear {i}: {ear}, y_i = {y_i}; index_y = {j_y}")
        # Fix parities along the ear: double the outgoing edge of every
        # interior vertex whose current degree is odd.
        for j in range(1, len(ear) - 1):
            u, v = ear[j], ear[j + 1]
            if GG.degree(u) % 2 == 1:
                if verbose:
                    print(f"Doubled edge: ({u}, {v}), GG.adj[{u}]={list(GG.adj[u].keys())}, GG.degree({u})={GG.degree(u)}")
                GG.add_edge(u, v,  color=GG.edges[u, v, 0].get('color', (0, 0, 0, 1)))
        u_last, v_last = ear[-2], ear[-1]
        # Dropping both copies of a doubled edge keeps every degree even.
        # Each 2-tuple given to remove_edges_from removes one parallel copy.
        if GG.number_of_edges(u_last, v_last) == 2:
            GG.remove_edges_from([(u_last, v_last)] * 2)
            removed_edges.append((u_last, v_last))
            if verbose:
                print(f"Removed edge: ({u_last}, {v_last})")
        else:
            if GG.number_of_edges(ear[j_y], ear[j_y+1]) == 2:
                GG.remove_edges_from([(ear[j_y], ear[j_y+1])] * 2)
                removed_edges.append((ear[j_y], ear[j_y+1]))
                if verbose:
                    print(f"Removed edge: ({ear[j_y]}, {ear[j_y+1]})")

    # Check parity first. nx.is_eulerian already returns False on any odd
    # degree, so running it first made this more specific error unreachable.
    for node in GG.nodes:
        if GG.degree(node) % 2 != 0:
            raise ValueError(f"Node {node} has odd degree {GG.degree(node)} in GG.")

    # With all degrees even, the remaining failure mode is disconnection.
    if not nx.is_eulerian(GG):
        raise ValueError("Constructed graph GG is not Eulerian.")

    return GG, removed_edges

# Step 6: Orient edges by ears
def orient_edges_by_ears(GG, chains_nodes, x, ys_full, removed_edges, verbose):
    """Orient the edges of GG ear by ear into a simple ``nx.DiGraph``.

    Ears are visited last-to-first, each paired with its ``y`` vertex from
    ``ys_full``.

    - If the ear's final edge appears in ``removed_edges`` (either
      orientation), the ear is directed from its first node to its last.
    - Otherwise it is directed from both ends toward ``y^i``.

    An edge present exactly twice in GG is oriented both ways. Pairs absent
    from GG (e.g. removed ones) are skipped. ``x`` and ``verbose`` are
    unused.

    Raises:
        ValueError: if any node ends up with in-degree greater than 2.
    """
    DG = nx.DiGraph()
    DG.add_nodes_from(GG.nodes)

    def orient(tail, head):
        # A DiGraph stores at most one arc per ordered pair, so one add
        # suffices regardless of how many copies GG holds.
        multiplicity = GG.number_of_edges(tail, head)
        if multiplicity > 0:
            DG.add_edge(tail, head)
        if multiplicity == 2:
            DG.add_edge(head, tail)

    ears = chains_nodes[::-1]
    ys = ys_full[::-1]

    for idx, ear in enumerate(ears):
        y_i = ys[idx]
        last_edge = (ear[-2], ear[-1])

        if last_edge in removed_edges or last_edge[::-1] in removed_edges:
            # Case (a): the ear was cut at its end; run it start -> end.
            arcs = list(zip(ear, ear[1:]))
        else:
            # Cases (b)/(c): both halves flow into y^i.
            k = ear.index(y_i)
            from_start = ear[:k + 1]
            from_end = ear[k:][::-1]
            arcs = list(zip(from_start, from_start[1:]))
            arcs += zip(from_end, from_end[1:])

        for tail, head in arcs:
            orient(tail, head)

    for node in DG.nodes:
        indeg = DG.in_degree(node)
        if indeg > 2:
            raise ValueError(f"Node {node} has in-degree {indeg}, which is greater than 2.")
    return DG

# Step 7: Contract digraph DG
def contract_digraph(DG, x, verbose):
    """Contract every in-degree-2 vertex of DG (other than x) into an edge.

    GG_c starts as an undirected multigraph copy of DG. Nodes are visited in
    ``DG.nodes`` order, and in-degrees are read from DG itself. For each
    vertex ``w != x`` with exactly two incoming arcs ``u -> w <- v``:

    - a new edge ``u–v`` is added to GG_c;
    - one ``u–w`` and one ``v–w`` edge are removed from GG_c;
    - ``w`` is appended to ``contraction_map[(min(u, v), max(u, v))]``.

    Each contracted vertex then gets a distinct rainbow colour, written to
    its two DG arcs and to a distinct parallel ``u–v`` edge of GG_c.

    Args:
        DG: directed graph. The ``"color"`` attribute of contracted arcs is
            set in place.
        x: vertex that is never contracted.
        verbose: print each contraction.

    Returns:
        ``(GG_c, contraction_map)``.
    """
    GG_c = nx.MultiGraph()
    GG_c.add_nodes_from(DG.nodes)
    GG_c.add_edges_from(DG.edges)
    contraction_map = {}  # (u, v) in GG_c -> [w, ...] with (u, w), (v, w) in DG

    for w in DG.nodes:
        if w == x:
            continue
        in_edges = list(DG.in_edges(w))
        if len(in_edges) != 2:
            continue
        (u, _), (v, _) = in_edges
        u, v = min(u, v), max(u, v)
        GG_c.add_edge(u, v)
        contraction_map.setdefault((u, v), []).append(w)
        GG_c.remove_edges_from([(u, w), (v, w)])
        if verbose:
            print(f"Contracted edge: ({u} -> {w} <- {v}) to ({u}, {v})")

    n_cont = sum(len(ws) for ws in contraction_map.values())
    rainbow = rainbow_colors(n_cont)
    i = 0
    for (u, v), ws in contraction_map.items():
        # Several vertices may contract onto the same pair, and the pair may
        # also carry original DG edges. Give each contracted vertex its own
        # parallel key (the most recently added ones) instead of recolouring
        # key 0 every time. Parallel edges are structurally interchangeable,
        # so this only affects which copy is drawn coloured.
        keys = list(GG_c.adj[u][v])[-len(ws):]
        for w, key in zip(ws, keys):  # arcs (u, w), (v, w) in DG
            DG.adj[u][w]["color"] = rainbow[i]
            DG.adj[v][w]["color"] = rainbow[i]
            GG_c.adj[u][v][key]["color"] = rainbow[i]
            i += 1
    return GG_c, contraction_map

# Step 8: Find Eulerian cycle on contracted graph
def find_eulerian_cycle_on_contracted_graph(GG, GG_c, DG, contraction_map, x, verbose):
    """Expand an Eulerian circuit of the contracted graph back into DG arcs.

    Steps:

    1. Drop isolated nodes from GG_c (in place).
    2. Take an Eulerian circuit of GG_c from ``x``, and tag each traversed
       GG_c edge with a ``"color_2"`` recording its position on the circuit.
    3. Walk the circuit in traversal order:

       - an edge ``u–v`` that still has a pending contracted vertex ``w`` in
         ``contraction_map`` expands into the DG arcs between ``u, w`` and
         between ``v, w``;
       - any other edge maps to the DG arc between ``u`` and ``v``.

       Every DG arc is consumed at most once.

    Args:
        GG: graph whose nodes seed the returned ``GG_colored``.
        GG_c: contracted multigraph (modified in place, as described above).
        DG: directed graph whose arcs are consumed. A copy is used, so DG
            itself is untouched.
        contraction_map: ``(u, v) -> [w, ...]``. It is read but not modified.
        x: start vertex of the circuit.
        verbose: print the expanded walk, e.g. ``x -> a <- b``.

    Returns:
        ``(GG_colored, J)``:

        - ``J``: the list of expanded DG arcs in walk order.
        - ``GG_colored``: an ``nx.MultiDiGraph`` holding those arcs, with a
          rainbow ``"color"`` along ``J``.

    Raises:
        ValueError: if GG_c is not Eulerian.
        networkx.NetworkXError: if a required DG arc is missing or was
            already consumed.
    """
    DG_for_check = DG.copy()
    GG_c.remove_nodes_from([n for n, d in GG_c.degree() if d == 0])
    if not nx.is_eulerian(GG_c):
        raise ValueError("Graph GG_c is not Eulerian.")
    J_c = list(nx.eulerian_circuit(GG_c, source=x))

    rainbow = rainbow_colors(len(J_c))
    for i, (u, v) in enumerate(J_c):
        parallel = GG_c.adj[u][v]
        # Smallest still-untagged parallel key. Unlike indexing key 0
        # directly, this does not assume key 0 still exists.
        k = min(key for key, data in parallel.items() if "color_2" not in data)
        parallel[k]["color_2"] = rainbow[i]

    def take_arc(a, b):
        """Remove the unused DG arc between a and b (either direction) and return it."""
        arc = (a, b) if DG_for_check.has_edge(a, b) else (b, a)
        DG_for_check.remove_edge(*arc)
        return arc

    # Consume contracted vertices from a private copy, so the caller's map
    # is left intact.
    pending = {pair: list(ws) for pair, ws in contraction_map.items()}

    J = []
    if verbose:
        print(J_c[0][0], end="")
    for u, v in J_c:
        # J_c is traversal-ordered: each edge starts where the previous one
        # ended. Keep (u, v) in that orientation and only normalise the pair
        # for the map lookup.
        pair = (u, v) if (u, v) in pending else (v, u)
        queue = pending.get(pair)
        if queue:
            w = queue.pop(0)
            arc_uw = take_arc(u, w)
            arc_vw = take_arc(v, w)
            J.append(arc_uw)
            J.append(arc_vw)
            if verbose:
                direction_uw = "->" if arc_uw == (u, w) else "<-"
                direction_wv = "<-" if arc_vw == (v, w) else "->"
                print(f" {direction_uw} {w} {direction_wv} {v}", end="")
        else:
            # Either an original DG edge, or a pair whose contracted vertices
            # have all been expanded already.
            arc = take_arc(u, v)
            J.append(arc)
            if verbose:
                direction_uv = "->" if arc == (u, v) else "<-"
                print(f" {direction_uv} {v}", end="")
    if verbose:
        print()

    GG_colored = nx.MultiDiGraph()
    GG_colored.add_nodes_from(GG.nodes)
    palette = rainbow_colors(len(J))
    for i, (a, b) in enumerate(J):
        GG_colored.add_edge(a, b, color=palette[i])

    return GG_colored, J

# Step 9: Lift edges to get Hamiltonian cycle in G^2
def lift_edges(J, x, GG_colored, verbose):
    """Shortcut the arc walk ``J`` and trace it as a closed tour from ``x``.

    ``J`` is scanned left to right:

    - two consecutive arcs with the same tail (``a -> b``, ``a -> d``) are
      replaced by the single undirected shortcut ``b–d``;
    - any other arc ``a -> b`` is kept as the undirected edge ``a–b``;
    - a trailing unpaired arc is kept.

    An Eulerian circuit of the resulting simple graph, starting at ``x``,
    is the returned tour. For each tour edge that is a shortcut, its two
    source arcs in ``GG_colored`` get a shared ``"color_2"``. Repeated
    occurrences of the same arc use successive parallel keys.

    Returns:
        ``(H, H_list)``:

        - ``H``: an ``nx.DiGraph`` of the tour edges, with a ``"color"``
          attribute (``"black"`` for non-shortcut edges).
        - ``H_list``: the tour as a list of ``(u, v)`` pairs.

    ``verbose`` is unused.
    """
    walk = nx.Graph()
    H = nx.DiGraph()
    shortcuts = {}  # (b, d) -> [(a, b), (a, d)], the arcs the shortcut replaces

    idx, last = 0, len(J) - 1
    while idx < last:
        (a, b), (c, d) = J[idx], J[idx + 1]
        if a != c:
            walk.add_edge(a, b)
            idx += 1
        else:
            walk.add_edge(b, d)
            shortcuts[b, d] = [(a, b), (c, d)]
            idx += 2

    if idx == last:
        walk.add_edge(*J[-1])
    H_list = list(nx.eulerian_circuit(walk, source=x))

    palette = rainbow_colors(len(shortcuts))
    uses = {arc: 0 for arcs in shortcuts.values() for arc in arcs}
    shade = 0
    for u, v in H_list:
        colour = "black"
        pair = (u, v) if (u, v) in shortcuts else (v, u)
        if pair in shortcuts:
            for arc in shortcuts[pair]:
                colour = palette[shade]
                GG_colored[arc[0]][arc[1]][uses[arc]]["color_2"] = colour
                uses[arc] += 1
            shade += 1
        H.add_edge(u, v, color=colour)

    return H, H_list


# Step final: Visualization
def _draw_keyed_multiedges(G, pos, ax, color_attr):
    """Draw each keyed edge of multigraph G as its own arc, coloured by ``color_attr``.

    The arc bend is ``0.3 * (key - 0.5)``, so parallel edges (different keys)
    fan out. A missing colour falls back to opaque black.
    """
    offset = 0.3
    for u, v, k, data in G.edges(keys=True, data=True):
        rad = offset * (k - 0.5)
        nx.draw_networkx_edges(
            G, pos, edgelist=[(u, v)],
            connectionstyle=f'arc3,rad={rad}',
            edge_color=data.get(color_attr, (0, 0, 0, 1)), ax=ax, width=2
        )


def _draw_directed_edges(G, pos, ax):
    """Draw a digraph's arcs with arrowheads, coloured by their ``"color"`` attribute.

    Antiparallel pairs (``u -> v`` and ``v -> u``) are both drawn with a 0.2
    arc bend so they do not overlap. Single arcs are drawn straight.
    """
    drawn = set()
    for u, v in G.edges():
        if (u, v) in drawn:
            continue
        if G.has_edge(v, u):
            arcs, rad = [(u, v), (v, u)], 0.2
        else:
            arcs, rad = [(u, v)], 0
        for a, b in arcs:
            nx.draw_networkx_edges(
                G, pos, edgelist=[(a, b)],
                connectionstyle=f"arc3,rad={rad}",
                edge_color=G.edges[a, b].get('color', (0, 0, 0, 1)), arrows=True,
                arrowstyle='-|>', width=2, ax=ax
            )
        drawn.update(arcs)


def plot_graphs_side_by_side_colored_and_multiedges(Gs, titles, pos, x=None, ys=None, chains=None, color_from_step=2, filename="procedure.pdf"):
    """Draw each graph of ``Gs`` in its own panel of one figure and save it.

    How a panel's edges are drawn is selected by its exact title string:

    - plain gray edges for steps 1-2;
    - one colour per chain for the ear-decomposition panel;
    - keyed multi-edges coloured by ``"color"`` (steps 5, 7, 8) or by
      ``"color_2"`` (steps 7.5, 8.5);
    - arrows for the directed panels (steps 6, 9).

    Any other title draws nodes and labels only.

    From panel index ``color_from_step`` onward, node ``x`` is red and nodes
    in ``ys`` are cyan. All other nodes are light gray.

    Args:
        Gs: graphs to draw, one per panel.
        titles: panel titles, aligned with ``Gs``.
        pos: node -> position mapping shared by all panels.
        x, ys, chains: optional highlighting data (see above).
        color_from_step: first panel index that highlights ``x``/``ys``.
        filename: output path passed to ``plt.savefig``
            (default ``"procedure.pdf"``).
    """
    # squeeze=False keeps `axes` indexable even for a single panel.
    fig, axes = plt.subplots(1, len(Gs), figsize=(4 * len(Gs), 4), squeeze=False)
    axes = axes[0]
    # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colors = plt.get_cmap('jet', len(chains) if chains else 1)

    for i, (G, title) in enumerate(zip(Gs, titles)):
        ax = axes[i]
        highlight = i >= color_from_step
        node_colors = []
        for node in G.nodes():
            if highlight and node == x:
                node_colors.append('red')
            elif highlight and ys and node in ys:
                node_colors.append('cyan')
            else:
                node_colors.append('lightgray')

        nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, edgecolors='k', node_size=500)
        nx.draw_networkx_labels(G, pos, ax=ax)

        if title in {"Step 1: Original Graph", "Step 2: Minimal 2-Connected"}:
            nx.draw_networkx_edges(G, pos, edge_color='gray', width=1.5, ax=ax)
        elif title == "Step 3–4: Ear Decomposition with $x, y^i$" and chains:
            for j, chain in enumerate(chains):
                edge_list = list(nx.utils.pairwise([u for u, _ in chain] + [chain[-1][1]]))
                nx.draw_networkx_edges(G, pos, edgelist=edge_list, edge_color=[colors(j)], width=2.5, ax=ax)
        elif title in {"Step 5: Eulerian Graph $G_G$", "Step 7: Contracted Oriented Graph $GG_c$", "Step 8: Eulerian Cycle on $G_G$"}:
            _draw_keyed_multiedges(G, pos, ax, 'color')
        elif title in {"Step 7.5: Eulerian Cycle on $GG_c$", "Step 8.5"}:
            _draw_keyed_multiedges(G, pos, ax, 'color_2')
        elif title in {"Step 6: Oriented Graph $\\vec{G}_G$", "Step 9: Hamiltonian Cycle on $G^2$"}:
            _draw_directed_edges(G, pos, ax)
        # Any other title: nodes and labels only.

        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(filename, dpi=400)

# ---------- Main function ----------
def find_hamiltonian_cycle(G=None, pos=None, vis=False, verbose=False, n_nodes=11):
    """Run the full pipeline (steps 1-9) and return the tour found in G^2.

    Args:
        G: 2-connected graph. When ``None``, a random one on ``n_nodes``
            vertices is generated. Edges without a ``"length"`` attribute get
            the Euclidean distance between their endpoints in ``pos``
            (this modifies ``G`` in place).
        pos: node -> 2D position. When ``None``, ``nx.spring_layout(G,
            seed=42)`` is used.
        vis: also print the tour and save a figure of every stage.
        verbose: print intermediate results.
        n_nodes: size of the random graph used when ``G`` is ``None``.

    Returns:
        The tour as a list of ``(u, v)`` pairs. Returns ``None`` when the
        instance is skipped by the filters below: at most 2 chains, no ear
        longer than 3 edges, or no parallel edge in GG.

    Raises:
        ValueError: if a supplied ``G`` is not 2-connected.
    """
    if G is None:
        G = create_random_biconnected_graph(n_nodes)
    elif nx.node_connectivity(G) < 2:
        # Explicit check rather than `assert`, which is stripped under -O.
        raise ValueError("G must be 2-connected (node connectivity >= 2).")
    if verbose:
        print("G.edges() =", G.edges())
    if pos is None:
        pos = nx.spring_layout(G, seed=42)
    for i, j in G.edges():
        # Fill only missing lengths. Stopping at the first annotated edge
        # left partially annotated graphs with edges lacking "length".
        # asarray lets `pos` hold plain tuples as well as numpy arrays.
        if "length" not in G.adj[i][j]:
            G.adj[i][j]["length"] = np.linalg.norm(np.asarray(pos[i]) - np.asarray(pos[j]))

    # Step 2
    G_minimal = construct_minimal_biconnected_subgraph(G, verbose)
    if verbose:
        print("G_minimal.edges() =", G_minimal.edges())

    # Step 3
    chains, chains_nodes = ear_decomposition(G_minimal, verbose)
    if verbose:
        print("chains =", "\n".join(str(chain) for chain in chains))
        print(f"G has {len(chains)} chains")
    # Instance filters: only continue on "interesting" decompositions.
    # Return instead of calling exit(), which killed the whole interpreter
    # when used as a library.
    if len(chains) <= 2:
        return None
    if max(len(chain) for chain in chains[1:]) <= 3:
        return None
    G_minimal_colored = G_minimal.copy()
    # pyplot.get_cmap: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    colors = plt.get_cmap('tab10', len(chains))
    for i, chain in enumerate(chains):
        for u, v in chain:
            G_minimal_colored.edges[u, v]['color'] = colors(i)

    # Step 4: x, y^i
    x, ys_full = identify_x_yi(G_minimal, chains_nodes, verbose)
    if verbose:
        print(f"x={x}, ys_full={ys_full}")

    # Step 5: Construct Eulerian graph GG
    GG, removed_edges = construct_eulerian_graph(G_minimal_colored, chains_nodes, ys_full, verbose)
    if not any(len(GG[u][v]) > 1 for u, v in GG.edges()):
        return None
    if verbose:
        print(f"GG.edges() = {GG.edges()}")

    # Step 6: Orient edges by ears
    DG = orient_edges_by_ears(GG, chains_nodes, x, ys_full, removed_edges, verbose)
    if verbose:
        print(f"DG.edges() = {DG.edges()}")

    # Step 7: Contract digraph DG
    GG_c, contraction_map = contract_digraph(DG, x, verbose)
    if verbose:
        print(f"GG_c.edges() = {GG_c.edges()}")

    # Step 8: Find Eulerian cycle on contracted graph
    GG_colored, J = find_eulerian_cycle_on_contracted_graph(GG, GG_c, DG, contraction_map, x, verbose)
    if verbose:
        print(f"Eulerian cycle on GG: {GG_colored.edges()}")

    # Step 9: Lift edges to get Hamiltonian cycle in G^2
    H, H_list = lift_edges(J, x, GG_colored, verbose)

    if vis:
        print("tour:", H_list)
        plot_graphs_side_by_side_colored_and_multiedges(
            [
                G,
                G_minimal,
                G_minimal,
                GG,
                DG,
                GG_c,
                GG_c,
                GG_colored,
                GG_colored,
                H,
            ],
            [
                "Step 1: Original Graph",
                "Step 2: Minimal 2-Connected",
                "Step 3–4: Ear Decomposition with $x, y^i$",
                "Step 5: Eulerian Graph $G_G$",
                "Step 6: Oriented Graph $\\vec{G}_G$",
                "Step 7: Contracted Oriented Graph $GG_c$",
                "Step 7.5: Eulerian Cycle on $GG_c$",
                "Step 8: Eulerian Cycle on $G_G$",
                "Step 8.5",
                "Step 9: Hamiltonian Cycle on $G^2$",
            ],
            pos, x=x, ys=ys_full, chains=chains, color_from_step=2
        )
    # Return the tour in both modes; previously vis=True returned None.
    return H_list

if __name__ == "__main__":
    # Usage: python <script> SEED. The integer seed makes the random
    # instance (and thus the saved figure) reproducible.
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} SEED")
    random.seed(int(argv[1]))
    find_hamiltonian_cycle(vis=True, verbose=False)