import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

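# #### Visualization Description ####
# Nodes: size = weighted degree (sum of |weight| over incoming and outgoing
#        edges); red = critical vitals (red_vars), blue = all other variables.
# Edges: width = interaction strength (edge weight); self-loops are dropped
#        when REMOVE_SELF_LOOPS is set.
#        - Weak edges (low alpha): background noise.
#        - Strong edges (high alpha): the top TOP_EDGE_RATIO fraction by
#          weight, forming the main structural backbone.
# Layout: one fixed spring layout shared by all subplots, so each variable
#         sits at the same position in every panel for easy comparison.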

# Font configuration: prefer sans-serif fonts with broad Unicode/CJK coverage,
# and render minus signs as an ASCII hyphen (axes.unicode_minus=False), the
# usual workaround when such fonts lack the U+2212 minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['Arial Unicode MS', 'SimHei'],
    'axes.unicode_minus': False,
})

# Variable names in index order: the integer endpoints of every edge tuple in
# csv_edges are positions in this list.
var_names = [
    # 0-7: the critical vitals listed in red_vars (drawn red)
    'HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'EtCO2',
    # 8-33: every other variable (drawn blue)
    'BaseExcess', 'HCO3', 'FiO2', 'pH', 'PaCO2', 'SaO2',                # 8-13
    'AST', 'BUN', 'Alkalinephos', 'Calcium', 'Chloride', 'Creatinine',  # 14-19
    'Bilirubin_direct', 'Glucose', 'Lactate', 'Magnesium',              # 20-23
    'Phosphate', 'Potassium', 'Bilirubin_total', 'TroponinI',           # 24-27
    'Hct', 'Hgb', 'PTT', 'WBC', 'Fibrinogen', 'Platelets',              # 28-33
]

# Variables highlighted as critical vitals (red nodes); all others are blue.
red_vars = {'HR', 'O2Sat', 'Temp', 'SBP', 'MAP', 'DBP', 'Resp', 'EtCO2'}

# Node fill colours by group.
colors = {
    'red': '#FF6B6B',   # soft red: critical vitals
    'blue': '#4D96FF',  # soft blue: everything else
}

# Fill colour per variable, aligned index-for-index with var_names.
node_colors = [
    colors['red'] if name in red_vars else colors['blue']
    for name in var_names
]

# Edge data for each setting, as (source, target, weight) tuples.
# - source / target are indices into var_names (plot_graph maps them with
#   var_names[i]); weight drives edge width, the strong/weak edge split, and
#   node weighted degree.
# - Each key is used as the subplot title of the combined figure and is
#   embedded verbatim in the per-setting output file names.
# - Each list appears sorted by descending weight (plot_graph re-sorts anyway).
# - Self-loops such as (22, 22, ...) are dropped when REMOVE_SELF_LOOPS is set;
#   in 'wo ITG', index 13 (SaO2) appears only in the self-loop (13, 13).
# NOTE(review): the meaning of the setting names ('wo ITG', 'tuned', 'LB') and
# the origin of the weights are not documented here — presumably exported
# attention scores (outputs are named attention_graph_*); confirm.
csv_edges = {
    'wo ITG': [
        (32, 22, 0.038898),
        (4, 31, 0.038826),
        (31, 0, 0.038218),
        (17, 0, 0.038214),
        (22, 22, 0.037749),
        (31, 31, 0.037430),
        (2, 0, 0.037208),
        (19, 22, 0.037160),
        (24, 3, 0.037125),
        (18, 0, 0.036962),
        (15, 4, 0.036892),
        (3, 3, 0.036877),
        (17, 31, 0.036849),
        (10, 0, 0.036654),
        (20, 4, 0.036394),
        (3, 0, 0.036191),
        (0, 0, 0.036162),
        (0, 22, 0.036094),
        (28, 4, 0.036066),
        (3, 4, 0.035941),
        (4, 22, 0.035855),
        (12, 31, 0.035835),
        (32, 0, 0.035825),
        (18, 31, 0.035786),
        (28, 0, 0.035706),
        (10, 4, 0.035630),
        (11, 31, 0.035605),
        (5, 4, 0.035492),
        (4, 4, 0.035340),
        (31, 3, 0.035272),
        (3, 31, 0.035200),
        (13, 13, 0.035149),
        (19, 0, 0.035129),
        (21, 31, 0.035120)
    ],
    'tuned': [
        (0, 0, 0.041481),
        (14, 0, 0.039884),
        (27, 4, 0.039362),
        (17, 31, 0.038907),
        (3, 3, 0.038391),
        (5, 31, 0.038227),
        (22, 0, 0.037891),
        (29, 0, 0.037617),
        (5, 22, 0.037461),
        (25, 0, 0.037301),
        (14, 4, 0.037086),
        (10, 31, 0.036872),
        (30, 3, 0.036848),
        (33, 4, 0.036803),
        (9, 3, 0.036722),
        (2, 3, 0.036648),
        (3, 31, 0.036627),
        (13, 3, 0.036558),
        (4, 4, 0.036546),
        (5, 0, 0.036385),
        (13, 22, 0.036355),
        (19, 31, 0.036270),
        (25, 3, 0.036231),
        (4, 0, 0.036219),
        (11, 22, 0.036217),
        (20, 22, 0.036104),
        (12, 3, 0.036071),
        (30, 22, 0.035984),
        (33, 0, 0.035929),
        (12, 31, 0.035849),
        (16, 3, 0.035843),
        (9, 0, 0.035824),
        (4, 3, 0.035752),
        (25, 4, 0.035638),
        (12, 0, 0.035638),
        (13, 4, 0.035629),
        (33, 31, 0.035592),
        (32, 3, 0.035504),
        (5, 5, 0.035461),
        (3, 22, 0.035458),
        (23, 4, 0.035429),
        (1, 31, 0.035400),
        (3, 0, 0.035399),
        (15, 31, 0.035397),
        (21, 4, 0.035374),
        (26, 4, 0.035349),
        (14, 31, 0.035344),
        (22, 31, 0.035329),
        (5, 4, 0.035282),
        (5, 3, 0.035242),
        (0, 31, 0.035206),
        (18, 22, 0.035084),
        (0, 22, 0.035081),
        (20, 4, 0.035020)
    ],
    'LB': [
        (3, 3, 0.040041),
        (15, 4, 0.039909),
        (4, 4, 0.039480),
        (28, 22, 0.038725),
        (0, 31, 0.038231),
        (32, 4, 0.038030),
        (25, 0, 0.037958),
        (16, 4, 0.037693),
        (21, 0, 0.037690),
        (22, 22, 0.037655),
        (21, 4, 0.037484),
        (10, 31, 0.037475),
        (15, 0, 0.037417),
        (7, 4, 0.037303),
        (10, 22, 0.037271),
        (1, 22, 0.037195),
        (0, 0, 0.037188),
        (26, 31, 0.037179),
        (5, 3, 0.037082),
        (2, 22, 0.037024),
        (31, 4, 0.036988),
        (31, 31, 0.036975),
        (16, 22, 0.036857),
        (23, 3, 0.036783),
        (18, 31, 0.036720),
        (1, 3, 0.036696),
        (12, 3, 0.036666),
        (9, 22, 0.036570),
        (7, 31, 0.036503),
        (7, 3, 0.036446),
        (29, 0, 0.036438),
        (28, 4, 0.036391),
        (12, 0, 0.036322),
        (22, 3, 0.036317),
        (19, 0, 0.036295),
        (26, 4, 0.036250),
        (17, 3, 0.036230),
        (19, 31, 0.036218),
        (23, 22, 0.036186),
        (11, 3, 0.036074),
        (6, 4, 0.036061),
        (31, 22, 0.036057),
        (22, 31, 0.036048),
        (3, 22, 0.035981),
        (13, 4, 0.035960),
        (25, 22, 0.035951),
        (3, 4, 0.035893),
        (33, 31, 0.035882),
        (11, 31, 0.035851),
        (9, 31, 0.035825),
        (10, 3, 0.035781),
        (16, 3, 0.035754),
        (6, 3, 0.035749),
        (30, 3, 0.035725),
        (20, 3, 0.035722),
        (5, 31, 0.035684),
        (25, 3, 0.035649),
        (16, 31, 0.035635),
        (14, 14, 0.035593),
        (24, 31, 0.035525),
        (8, 22, 0.035506),
        (29, 22, 0.035500),
        (21, 31, 0.035495),
        (16, 0, 0.035477),
        (14, 22, 0.035438),
        (22, 4, 0.035377),
        (33, 0, 0.035340),
        (9, 0, 0.035318),
        (13, 22, 0.035304),
        (8, 0, 0.035280),
        (22, 0, 0.035256),
        (6, 0, 0.035254),
        (2, 0, 0.035239),
        (33, 3, 0.035228),
        (2, 31, 0.035165),
        (32, 22, 0.035162),
        (11, 22, 0.035101),
        (20, 0, 0.035028)
    ]
}

# =========================
# Hyperparameters
# =========================
# --- Graph content ---
TOP_EDGE_RATIO = 0.20      # fraction of edges (by |weight|, at least one) drawn as the strong layer
REMOVE_SELF_LOOPS = True   # drop i -> i edges before drawing
LABEL_TOPK = 10            # only the K highest weighted-degree nodes get a text label

# --- Edge appearance ---
WEAK_ALPHA = 0.06          # opacity of the weak (background) edge layer
STRONG_ALPHA = 0.28        # opacity of the strong (backbone) edge layer
WEAK_COLOR = "#000000"     # colour of weak edges
STRONG_COLOR = "#000000"   # colour of strong edges
ARROWS = False             # draw edge-direction arrowheads

# --- Reproducibility ---
SEED = 0                   # random seed passed to nx.spring_layout

# =========================
# 1) Layout: Compute once using full_G
# =========================
# A single layout shared by every panel keeps each variable at the same
# position across settings. NOTE(review): full_G contains nodes only (no
# edges), so spring_layout merely spreads the nodes apart and positions do not
# reflect any attention structure — presumably intended as a
# setting-independent layout; confirm.
full_G = nx.DiGraph()
full_G.add_nodes_from(var_names)  # same insertion order as var_names

pos = nx.spring_layout(full_G, seed=SEED, k=0.55)  # k controls spacing

# =========================
# 2) Plotting
# =========================
# One 10x10-inch panel per setting, derived from csv_edges rather than
# hard-coded, so the zip() over axes never silently drops a setting or leaves
# blank panels. squeeze=False plus [0] keeps `axes` a 1-D array of Axes even
# when there is a single panel.
n_panels = len(csv_edges)
fig, axes = plt.subplots(
    nrows=1, ncols=n_panels, figsize=(10 * n_panels, 10),
    constrained_layout=True, squeeze=False,
)
axes = axes[0]

def normalize(arr, lo, hi):
    """Linearly rescale ``arr`` so its minimum maps to ``lo`` and its maximum to ``hi``.

    Always returns a new float ndarray. An empty input yields an empty array;
    if all values are equal (spread below 1e-12), every element becomes the
    midpoint ``(lo + hi) / 2`` rather than dividing by zero.
    """
    values = np.array(arr, dtype=float)  # copy and cast to float
    if values.size == 0:
        return values

    smallest, largest = values.min(), values.max()
    spread = largest - smallest
    if abs(spread) < 1e-12:
        return np.full_like(values, (lo + hi) / 2.0)
    return lo + (values - smallest) / spread * (hi - lo)

def _build_graph(edge_list):
    """Return a DiGraph with edges ``var_names[i] -> var_names[j]`` from ``edge_list``.

    Nodes are added for every index that appears in an edge, then the edges
    with attribute ``weight=float(w)``. When ``REMOVE_SELF_LOOPS`` is set,
    self-loops are removed together with any node left without edges as a
    result; otherwise such a node would be drawn as an isolated dot with
    weighted degree 0, which also pulls the node-size scale's minimum to 0.
    """
    edge_list = list(edge_list)  # iterated twice below
    G = nx.DiGraph()
    active = {idx for i, j, _ in edge_list for idx in (i, j)}
    G.add_nodes_from(var_names[idx] for idx in active)
    for i, j, w in edge_list:
        G.add_edge(var_names[i], var_names[j], weight=float(w))

    if REMOVE_SELF_LOOPS:
        # Materialize the generators before mutating the graph they walk.
        G.remove_edges_from(list(nx.selfloop_edges(G)))
        G.remove_nodes_from(list(nx.isolates(G)))
    return G


def _weighted_degree(G):
    """Return ``{node: sum of |weight| over its incoming and outgoing edges}``."""
    wdeg = dict.fromkeys(G.nodes(), 0.0)
    for u, v, d in G.edges(data=True):
        w = abs(d.get("weight", 0.0))
        wdeg[u] += w  # outgoing contribution
        wdeg[v] += w  # incoming contribution
    return wdeg


def _draw_edge_layer(G, ax, edges, *, width_range, alpha, color, arrowsize, connectionstyle):
    """Draw ``(u, v, data)`` edges with widths scaled by |weight| into ``width_range``."""
    if not edges:
        return
    widths = normalize([abs(d.get("weight", 0.0)) for _, _, d in edges], *width_range)
    arrow_kwargs = {}
    if ARROWS:
        # These options only apply to FancyArrowPatch edges (arrows=True). With
        # arrows=False networkx draws a LineCollection, ignores them and warns,
        # so pass them only when they actually take effect.
        arrow_kwargs = dict(arrowstyle='->', arrowsize=arrowsize, connectionstyle=connectionstyle)
    nx.draw_networkx_edges(
        G, pos, ax=ax,
        edgelist=[(u, v) for u, v, _ in edges],
        width=widths,
        edge_color=color,
        alpha=alpha,
        arrows=ARROWS,
        **arrow_kwargs,
    )


def plot_graph(edge_list, ax, title=None):
    """Draw one attention graph on ``ax`` using the shared layout ``pos``.

    Args:
        edge_list: iterable of ``(source_idx, target_idx, weight)`` tuples whose
            indices refer to positions in ``var_names``.
        ax: matplotlib Axes to draw on.
        title: optional panel title; nothing is drawn when falsy.

    Nodes are coloured from ``node_colors`` and sized by weighted degree
    (scaled to 450-1400). Edges are ranked by |weight|: the top
    ``TOP_EDGE_RATIO`` fraction (at least one edge) forms the strong layer,
    drawn on top of the faint weak layer. Only the ``LABEL_TOPK`` nodes with
    the highest weighted degree are labelled. The axis frame is always hidden.
    """
    G = _build_graph(edge_list)

    if G.number_of_nodes():
        wdeg = _weighted_degree(G)
        nodes = list(G.nodes())

        # ---- Nodes: size = weighted degree, colour = variable group ----
        color_of = dict(zip(var_names, node_colors))
        nx.draw_networkx_nodes(
            G, pos, ax=ax,
            nodelist=nodes,
            node_color=[color_of[n] for n in nodes],
            node_size=normalize([wdeg[n] for n in nodes], lo=450, hi=1400),
            alpha=0.88,
            linewidths=1.6,
            edgecolors="#222222",
        )

        # ---- Edges: weak background layer first, strong backbone on top ----
        edges_sorted = sorted(
            G.edges(data=True),
            key=lambda e: abs(e[2].get("weight", 0.0)),
            reverse=True,
        )
        if edges_sorted:
            topk = max(1, int(len(edges_sorted) * TOP_EDGE_RATIO))
            _draw_edge_layer(
                G, ax, edges_sorted[topk:],
                width_range=(0.4, 1.0), alpha=WEAK_ALPHA, color=WEAK_COLOR,
                arrowsize=10, connectionstyle="arc3,rad=0.08",
            )
            _draw_edge_layer(
                G, ax, edges_sorted[:topk],
                width_range=(1.4, 3.2), alpha=STRONG_ALPHA, color=STRONG_COLOR,
                arrowsize=12, connectionstyle="arc3,rad=0.10",
            )

        # ---- Labels: only the LABEL_TOPK highest weighted-degree nodes ----
        top_nodes = sorted(nodes, key=wdeg.get, reverse=True)[:LABEL_TOPK]
        nx.draw_networkx_labels(
            G, pos,
            labels={n: n for n in top_nodes},
            font_size=11,
            font_weight='bold',
            font_color="#1f1f1f",
            ax=ax,
        )

    if title:
        ax.set_title(title, fontsize=22, pad=18, fontweight='bold', color='#333333')
    ax.set_axis_off()

# Plot combined figure: one titled panel per setting in csv_edges.
# zip() stops at the shorter input, so a mismatch between the number of axes
# and settings would silently drop settings or leave empty panels; fail loudly.
if len(axes) != len(csv_edges):
    raise ValueError(
        f"Expected one subplot per setting: {len(csv_edges)} settings "
        f"but {len(axes)} axes"
    )
for ax, (csv_name, edge_list) in zip(axes, csv_edges.items()):
    plot_graph(edge_list, ax, title=csv_name)

# Save combined (explicitly the combined figure, not pyplot's "current" one)
fig.savefig('attention_graphs_comparison_icml_clean.png', dpi=300, bbox_inches='tight')
fig.savefig('attention_graphs_comparison_icml_clean.pdf', bbox_inches='tight')
plt.show()
# Release the combined figure before the per-setting figures are created.
plt.close(fig)

# Generate one standalone, untitled figure per setting, saved as PNG (300 dpi)
# and PDF. The csv_edges key is used verbatim in the file name, so a key with
# a space (e.g. 'wo ITG') yields a file name containing a space.
for csv_name, edge_list in csv_edges.items():
    fig, ax = plt.subplots(figsize=(10, 10), constrained_layout=True)
    plot_graph(edge_list, ax, title=None)  # standalone panel: no caption
    for out_path, extra_kwargs in (
        (f'attention_graph_{csv_name}_icml_clean.png', {'dpi': 300}),
        (f'attention_graph_{csv_name}_icml_clean.pdf', {}),
    ):
        plt.savefig(out_path, bbox_inches='tight', **extra_kwargs)
    plt.close(fig)