# -*- coding: utf-8 -*-

#@title Planted Trip Planning

import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.optimize import linprog
from collections import defaultdict
import time

# -----------------------------
# Hypergraph construction
# -----------------------------

def build_types(R, m_per_type, alpha, rng):
    """
    Create R activity types of m_per_type consecutive vertices each and plant a core.

    Type r owns vertices r*m_per_type .. (r+1)*m_per_type - 1. Its core A_r^* is
    a uniformly random subset of size max(1, round(alpha * m_per_type)), drawn
    with ``rng.sample`` one type at a time in order r = 0..R-1.

    Returns:
      - type_indices: list of R vertex lists, one per type
      - core_sets: list of R sets, the planted core A_r^* of each type
    """
    # Same core size for every type; never fewer than one core vertex.
    n_core = max(1, int(round(alpha * m_per_type)))
    type_indices = [
        list(range(r * m_per_type, (r + 1) * m_per_type)) for r in range(R)
    ]
    # Draws happen in type order, so the RNG stream is consumed deterministically.
    core_sets = [set(rng.sample(block, n_core)) for block in type_indices]
    return type_indices, core_sets

def sample_hyperedges(type_indices, core_sets, p, N, rng):
    """
    Sample N hyperedges i.i.d. from the generative model:
    for each type r, with probability p choose a uniform activity from A_r^*,
    otherwise choose uniformly from A_r minus A_r^* (falling back to A_r^* when
    that complement is empty, e.g. alpha ~ 1).

    The per-type candidate pools are built once up front. The original
    rebuilt them on every draw. The sequence of ``rng`` calls, and therefore
    the output for a given seed, is unchanged.

    Returns:
      - edges: list of N sampled hyperedges, each a tuple of vertex indices (one per type)
    """
    pools = []
    for r in range(len(type_indices)):
        core_r = core_sets[r]
        core_list = list(core_r)
        complement = [v for v in type_indices[r] if v not in core_r]
        # if complement is empty (e.g. alpha ~ 1), fall back to core
        pools.append((core_list, complement if complement else core_list))

    edges = []
    for _ in range(N):
        e = []
        for core_list, other in pools:
            # rng.random() is always drawn first, exactly as in the per-draw version.
            if rng.random() < p and len(core_list) > 0:
                e.append(rng.choice(core_list))
            else:
                e.append(rng.choice(other))
        edges.append(tuple(e))
    return edges

# -----------------------------
# LP construction and solving
# -----------------------------

def build_lp(n_vertices, edges, tau):
    """
    Assemble the dense covering LP for scipy's ``linprog``.

    Variables z = [x_0..x_{n-1}, y_0..y_{N-1}], n = n_vertices, N = len(edges):
        minimize   sum_v x_v
        subject to y_e - x_v <= 0          for every hyperedge e and vertex v in e
                   -sum_e y_e <= -tau * N  (i.e. (1/N) sum_e y_e >= tau)
                   0 <= z <= 1
    Linking rows appear in edge-major order (edge 0's vertices first); the
    coverage row is last.

    Returns (c, A_ub, b_ub, bounds).
    """
    n_edges = len(edges)
    n_cols = n_vertices + n_edges

    c = np.concatenate([np.ones(n_vertices), np.zeros(n_edges)])

    # One (edge index, vertex) pair per linking constraint.
    incidences = [(ei, v) for ei, e in enumerate(edges) for v in e]
    n_link = len(incidences)

    A_ub = np.zeros((n_link + 1, n_cols))
    if incidences:
        edge_ids, vert_ids = np.array(incidences).T
        row_ids = np.arange(n_link)
        A_ub[row_ids, vert_ids] = -1.0               # -x_v
        A_ub[row_ids, n_vertices + edge_ids] = 1.0   # +y_e
    # Coverage row: -sum_e y_e
    A_ub[-1, n_vertices:] = -1.0

    b_ub = np.zeros(n_link + 1)
    b_ub[-1] = -tau * n_edges

    bounds = [(0.0, 1.0)] * n_cols
    return c, A_ub, b_ub, bounds

def solve_lp_get_y(n_vertices, edges, tau):
    """
    Solve the covering LP produced by ``build_lp`` with HiGHS.

    Returns:
      (y, res) with y = res.x[n_vertices:] (one score per hyperedge) when the
      solver succeeds, otherwise (None, res) so the caller can inspect the status.
    """
    lp_c, lp_A, lp_b, lp_bounds = build_lp(n_vertices, edges, tau)
    res = linprog(lp_c, A_ub=lp_A, b_ub=lp_b, bounds=lp_bounds,
                  method='highs')
    if res.success:
        # The solution vector is laid out as [x_0..x_{n-1}, y_0..y_{N-1}].
        return res.x[n_vertices:], res
    return None, res

# -----------------------------
# Rounding scheme for varying tau'
# -----------------------------

def rounding_for_tau_prime(edges, y, tau_prime, N):
    """
    Given LP hyperedge scores y and a threshold tau_prime in [0,1],
    greedily select the top K hyperedges by y, where K = ceil(tau_prime * N).

    Ties in y (very common for LP solutions, e.g. many y_e == 0) are broken by
    hyperedge index via a stable sort. This keeps the selection reproducible
    across NumPy versions and CPUs.

    Args:
      edges: sequence of hyperedges (tuples of vertex ids), aligned with y.
      y: LP scores, one per hyperedge (array-like).
      tau_prime: fraction of N to select.
      N: number of sampled hyperedges the fraction refers to.

    Returns:
      - selected_edges: list of selected hyperedges (tuples)
      - union_size: number of distinct vertices in the selected hyperedges
    """
    # Round before ceil so float noise (0.07 * 100 == 7.000000000000001)
    # does not inflate K by one.
    K = int(np.ceil(round(tau_prime * N, 9)))
    if K <= 0:
        return [], 0

    # sort edges by y in descending order; stable => ties keep index order
    order = np.argsort(-np.asarray(y, dtype=float), kind="stable")
    selected_edges = [edges[i] for i in order[:K]]

    # compute union size
    union_vertices = set().union(*selected_edges)
    return selected_edges, len(union_vertices)

# -----------------------------
# Main experiment
# -----------------------------

def run_experiment(R=5, m_per_type=10, tau=0.8,
                   alphas=(0.2, 0.4, 0.6),
                   N=1000, trials=20):
    """
    LP-rounding bicriteria sweep over planted core sizes.

    For each alpha and each trial:
      1. Plant cores with ``build_types``.
      2. Sample N hyperedges with per-type core probability p = tau**(1/R), so a
         hyperedge lies entirely in the core with probability tau.
      3. Solve the LP with coverage tau.
      4. For each tau' on a 20-point grid over [max(0, 2*tau - 1), 0.99], record
           (# vertices in union of the top ceil(tau' * N) hyperedges by y_e)
           / (total planted core size sum_r |A_r^*|).
         The denominator equals alpha * m * R whenever alpha * m_per_type is a
         positive integer. ``build_types`` always plants at least one core
         vertex per type, so the denominator is taken from the cores it
         actually built.
    A trial whose LP fails records NaN for every tau'; NaNs are excluded from
    the statistics.

    Returns a dict alpha -> { 'tau_primes': [...], 'means': [...], 'stds': [...] }
    (stds are population standard deviations, ddof=0).
    """
    # compute p so that p^R = tau
    p = tau ** (1.0 / R)
    print(f"R={R}, m={m_per_type}, tau={tau}, p={p:.4f}, N={N}")

    # tau' range: from max(0, 2*tau - 1) up to 0.99
    tau_prime_min = max(0.0, 2*tau - 1.0)
    tau_prime_max = 0.99
    tau_primes = np.linspace(tau_prime_min, tau_prime_max, 20)
    print(f"tau' range: [{tau_prime_min:.3f}, {tau_prime_max:.3f}]")

    n_vertices = R * m_per_type
    results = {}

    for alpha in alphas:
        print(f"Running alpha={alpha}")
        ratios_per_tauprime = {tp: [] for tp in tau_primes}

        for t in range(trials):
            # Seed depends on trial and alpha so every (alpha, trial) is reproducible.
            rng = random.Random(1000 * t + int(10000 * alpha))

            # build types and core subsets
            type_indices, core_sets = build_types(R, m_per_type, alpha, rng)

            # sample N hyperedges i.i.d.
            edges = sample_hyperedges(type_indices, core_sets, p, N, rng)

            # solve LP with fixed tau
            y, res = solve_lp_get_y(n_vertices, edges, tau)
            if y is None:
                # if LP fails, record NaNs for this trial
                print(f"LP failed for alpha={alpha}, trial={t}, status={res.status}")
                for tp in tau_primes:
                    ratios_per_tauprime[tp].append(np.nan)
                continue

            # Normalize by the planted core size actually built for this trial.
            core_union_size = sum(len(core) for core in core_sets)
            for tp in tau_primes:
                _, union_size = rounding_for_tau_prime(edges, y, tp, N)
                ratio = union_size / core_union_size if core_union_size > 0 else np.nan
                ratios_per_tauprime[tp].append(ratio)

        # aggregate statistics over trials (ignoring failed trials)
        means = []
        stds = []
        for tp in tau_primes:
            arr = np.array([v for v in ratios_per_tauprime[tp] if not np.isnan(v)])
            if arr.size == 0:
                means.append(np.nan)
                stds.append(np.nan)
            else:
                means.append(arr.mean())
                stds.append(arr.std(ddof=0))
        results[alpha] = {
            'tau_primes': tau_primes,
            'means': means,
            'stds': stds,
        }

    return results

def plot_results(results, tau, filename='bicriteria_plot.png'):
    """
    Plot the mean +/- std size ratio against tau' for every alpha.

    ``results`` is the dict returned by ``run_experiment``
    (alpha -> {'tau_primes', 'means', 'stds'}). The figure is saved to
    ``filename`` at 200 dpi, shown, and the path is printed.
    """
    plt.figure(figsize=(8, 5))
    for alpha, series in results.items():
        plt.errorbar(
            series['tau_primes'],
            series['means'],
            yerr=series['stds'],
            marker='o',
            capsize=4,
            label=f'alpha={alpha}',
        )
    # Axis text is rendered with matplotlib mathtext.
    plt.xlabel(r"$\tau'$ (fraction of sampled hyperedges selected)")
    plt.ylabel(r"Selected vertices / $(\alpha m R)$")
    plt.title(rf"LP rounding bicriteria tradeoff (fixed $\tau={tau}$)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(filename, dpi=200)
    plt.show()
    print(f"Saved plot to {filename}")

if __name__ == "__main__":
    # Configuration for the LP-rounding bicriteria sweep.
    # NOTE: in this notebook export these are module-level globals; later cells
    # reassign several of the same names (R, tau, N, results, ...).
    # Number of activity types:
    R = 5
    # Activities per type:
    m_per_type = 10
    # LP coverage target (also the planted probability that a hyperedge is all-core):
    tau = 0.8
    # Core fractions to sweep:
    alphas = (0.2, 0.4, 0.6, 0.8)
    # Hyperedges sampled per trial:
    N = 1000
    # Independent trials per alpha:
    trials = 20

    start = time.time()
    results = run_experiment(R=R, m_per_type=m_per_type, tau=tau,
                             alphas=alphas, N=N, trials=trials)
    print(f"Experiment took {time.time() - start:.2f} seconds.")
    plot_results(results, tau=tau, filename='bicriteria_lp_rounding.png')

#@title Planted Trip Planning (Split Conformal)

import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.optimize import linprog
import time

# -----------------------------
# Hypergraph construction
# -----------------------------

def build_types(R, m_per_type, alpha, rng):
    """
    Partition vertices 0..R*m_per_type-1 into R consecutive types and plant a core in each.

    Type r owns [r*m_per_type, (r+1)*m_per_type). Its core A_r^* is drawn with
    ``rng.sample`` and has size max(1, round(alpha * m_per_type)).

    Returns (type_indices, core_sets): per-type vertex lists and per-type core sets.
    """
    core_size = max(1, int(round(alpha * m_per_type)))
    type_indices, core_sets = [], []
    for r in range(R):
        first = r * m_per_type
        verts = list(range(first, first + m_per_type))
        type_indices.append(verts)
        core_sets.append(set(rng.sample(verts, core_size)))
    return type_indices, core_sets

def sample_hyperedges(type_indices, core_sets, p, N, rng):
    """
    Sample N hyperedges i.i.d. from the generative model.

    For each type r: with probability p pick a uniform vertex of the core
    A_r^*, otherwise a uniform vertex of A_r minus A_r^* (falling back to the
    core when that complement is empty). Per-type candidate pools are computed
    once instead of on every draw; the ``rng`` call sequence is unchanged.

    Returns a list of N tuples, one vertex per type.
    """
    pools = []
    for r in range(len(type_indices)):
        core_r = core_sets[r]
        core_list = list(core_r)
        complement = [v for v in type_indices[r] if v not in core_r]
        pools.append((core_list, complement if complement else core_list))

    edges = []
    for _ in range(N):
        e = []
        for core_list, other in pools:
            if rng.random() < p and len(core_list) > 0:
                e.append(rng.choice(core_list))
            else:
                e.append(rng.choice(other))
        edges.append(tuple(e))
    return edges

# -----------------------------
# LP Logic
# -----------------------------

def build_lp_matrices(n_vertices, edges):
    """
    Build the tau-independent parts of the covering LP.

    Variables z = [x_0..x_{n-1}, y_0..y_{N-1}] with N = len(edges):
        minimize   sum_v x_v
        subject to y_e - x_v <= 0      for every hyperedge e and vertex v in e
                   -sum_e y_e <= b     (last row; b = -tau * N is filled in by
                                        solve_lp_for_tau)
                   0 <= z <= 1

    Returns (c, A_ub, bounds). A_ub is a dense array whose final row is the
    coverage template.
    """
    n_edges = len(edges)
    n_cols = n_vertices + n_edges

    c = np.concatenate([np.ones(n_vertices), np.zeros(n_edges)])

    # Structural constraints, one row per (hyperedge, member vertex) in edge order.
    incidences = [(ei, v) for ei, e in enumerate(edges) for v in e]
    A_ub = np.zeros((len(incidences) + 1, n_cols))
    if incidences:
        edge_ids, vert_ids = np.array(incidences).T
        row_ids = np.arange(len(incidences))
        A_ub[row_ids, vert_ids] = -1.0
        A_ub[row_ids, n_vertices + edge_ids] = 1.0

    # Coverage constraint (template)
    A_ub[-1, n_vertices:] = -1.0

    bounds = [(0.0, 1.0)] * n_cols
    return c, A_ub, bounds

def solve_lp_for_tau(c, A_ub, bounds, n_edges, target_tau):
    """
    Solve the LP from ``build_lp_matrices`` for one coverage target.

    Only the right-hand side changes with tau: every row is 0 except the last
    (coverage) row, which is -target_tau * n_edges.

    Returns the indices of hyperedges with y_e > 1e-5 (the last n_edges
    entries of the solution), or None if HiGHS does not report success.
    """
    rhs = np.zeros(A_ub.shape[0])
    rhs[-1] = -target_tau * n_edges

    sol = linprog(c, A_ub=A_ub, b_ub=rhs, bounds=bounds, method='highs')
    if not sol.success:
        return None

    x_full = sol.x
    y_vals = x_full[len(x_full) - n_edges:]
    # Any positive LP mass counts as "selected" (tolerance for solver noise).
    return [idx for idx in range(len(y_vals)) if y_vals[idx] > 1e-5]

def calc_coverage(vertex_set, edges):
    """Return the fraction of hyperedges whose vertices all lie in ``vertex_set`` (0.0 if none)."""
    if not edges:
        return 0.0
    inside = sum(1 for e in edges if set(e).issubset(vertex_set))
    return inside / len(edges)

# -----------------------------
# Main experiment
# -----------------------------

def run_experiment(R=5, m_per_type=10, tau=0.8,
                   alphas=(0.2, 0.4, 0.6, 0.8),
                   trials=20):
    """
    Split-sample trip-planning experiment.

    For each alpha and trial:
      * Plant cores.
      * Sample N_total = 200 hyperedges (per-type core probability
        p = tau**(1/R)).
      * Use the first 100 as train and the remaining 100 as test.
    Phase 1 (train): solve the LP for every tau' in linspace(0.1, 1.0, 20). Each
      successful solve contributes the vertex union of hyperedges with
      y_e > 1e-5 as a candidate set. The full vertex set is always added as a
      fallback.
    Phase 2 (test): for each phi in linspace(0.6, 1.0, 20), take the smallest
      candidate whose test coverage is >= phi and record |candidate| / |core|.
      Here |core| = sum_r |A_r^*| as built by ``build_types``.

    Returns a dict alpha -> {'phis': ..., 'means': [...], 'stds': [...]}
    (stds are population standard deviations, ddof=0; NaNs excluded).
    """
    # Total samples to generate (100 Train + 100 Test)
    N_total = 200
    n_train = 100
    # n_test is implicit (remaining N_total - n_train)

    # Prob for core selection: p^R = tau, so a hyperedge is all-core w.p. tau
    p = tau ** (1.0 / R)
    print(f"R={R}, m={m_per_type}, planted tau={tau}, p={p:.4f}, Total N={N_total}")

    # 1. Training Sweep Grid (tau_prime): LP targets solved on Train
    train_tau_grid = np.linspace(0.1, 1.0, 20)

    # 2. Evaluation Grid (phi): target coverage on Test
    eval_phis = np.linspace(0.6, 1.0, 20)

    n_vertices = R * m_per_type
    results = {}

    for alpha in alphas:
        print(f"Running alpha={alpha}...")

        # Storage for this alpha: phi -> list of ratios
        ratios_per_phi = {phi: [] for phi in eval_phis}

        for t in range(trials):
            rng = random.Random(1000 * t + int(10000 * alpha))

            # Build World
            type_indices, core_sets = build_types(R, m_per_type, alpha, rng)
            # Core size as actually planted (build_types keeps >= 1 core vertex
            # per type, so this can exceed round(alpha * m) * R for small alpha).
            core_size = sum(len(core) for core in core_sets)

            # Sample Data (Train + Test)
            all_edges = sample_hyperedges(type_indices, core_sets, p, N_total, rng)
            edges_train = all_edges[:n_train]
            edges_test = all_edges[n_train:]

            # --- PHASE 1: Generate Candidates on Train ---
            c, A_ub, bounds = build_lp_matrices(n_vertices, edges_train)

            candidates = []  # list of vertex sets
            for tp in train_tau_grid:
                sel_idx = solve_lp_for_tau(c, A_ub, bounds, len(edges_train), tp)
                if sel_idx is not None:
                    # Candidate = union of vertices of the selected train hyperedges
                    v_set = set()
                    for idx in sel_idx:
                        v_set.update(edges_train[idx])
                    candidates.append(v_set)

            # Full vertex set as fallback: it covers every test hyperedge.
            candidates.append(set(range(n_vertices)))

            # --- PHASE 2: Evaluate on Test ---
            # Test coverage does not depend on phi, so compute it once per candidate.
            cand_stats = [(calc_coverage(cand, edges_test), len(cand)) for cand in candidates]
            for phi in eval_phis:
                feasible_sizes = [size for cov, size in cand_stats if cov >= phi]
                if feasible_sizes and core_size > 0:
                    ratios_per_phi[phi].append(min(feasible_sizes) / core_size)
                else:
                    # Only reachable when core_size == 0 (e.g. R == 0): the
                    # full-vertex-set fallback always reaches coverage 1.0.
                    ratios_per_phi[phi].append(np.nan)

        # Aggregate stats
        means = []
        stds = []
        for phi in eval_phis:
            arr = np.array([v for v in ratios_per_phi[phi] if not np.isnan(v)])
            if arr.size == 0:
                means.append(np.nan)
                stds.append(np.nan)
            else:
                means.append(arr.mean())
                stds.append(arr.std(ddof=0))

        results[alpha] = {
            'phis': eval_phis,
            'means': means,
            'stds': stds
        }

    return results

def plot_results(results, tau, filename='trip_planning_conformal.png'):
    """
    Plot the mean +/- std size ratio against target test coverage phi per alpha.

    Draws a dashed reference line at the planted ``tau``, saves the figure to
    ``filename`` (200 dpi), shows it, and prints the path. ``results`` is the
    dict from the split-conformal ``run_experiment`` (alpha -> {'phis',
    'means', 'stds'}).
    """
    plt.figure(figsize=(8, 5))
    for alpha, series in results.items():
        plt.errorbar(series['phis'], series['means'], yerr=series['stds'],
                     marker='o', capsize=4, label=f'alpha={alpha}')

    # Reference: the planted all-core probability.
    plt.axvline(x=tau, color='k', linestyle='--', alpha=0.5, label=f'Planted $\\tau={tau}$')

    plt.xlabel(r"Target Test Coverage $\phi$")
    plt.ylabel(r"Size Ratio ($|\hat{K}| / |K_{core}|$)")
    plt.title("Conformal Trip Planning")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(filename, dpi=200)
    plt.show()
    print(f"Saved plot to {filename}")

if __name__ == "__main__":
    # Configuration for the split-conformal trip-planning sweep.
    # NOTE: notebook-cell globals; they overwrite same-named globals from the
    # previous cell, so keep names as-is.
    R = 5
    m_per_type = 10
    tau = 0.8 # Planted probability that a sampled hyperedge lies entirely in the core (p**R)
    alphas = (0.2, 0.4, 0.6, 0.8)
    trials = 20

    start = time.time()
    results = run_experiment(R=R, m_per_type=m_per_type, tau=tau,
                             alphas=alphas, trials=trials)
    print(f"Experiment took {time.time() - start:.2f} seconds.")
    plot_results(results, tau=tau)

#@title  **** Path Routing ****

#@title Routing with a ByPass -- Conformal

import networkx as nx
import numpy as np
import random
import matplotlib.pyplot as plt
from scipy.optimize import linprog
from scipy.sparse import coo_matrix, vstack
from collections import Counter

# ==========================================
# 1. Graph Generation
# ==========================================
def create_city_bypass(N=6, bypass_len=20):
    """
    Build the "city grid + highway bypass" routing graph.

    The graph is a DiGraph with:
      * An N x N grid of nodes "r_c". Edges point down (r -> r+1) and right
        (c -> c+1) and are tagged type='city'.
      * "Source" -> "0_0" and "{N-1}_{N-1}" -> "Target", also type='city'.
      * A chain "Source" -> h_0 -> ... -> h_{bypass_len-1} -> "Target" of
        bypass_len + 1 edges tagged type='highway'.
    Every edge gets base_weight=1.0; that attribute is not read by the code in
    this file that is shown here.

    NOTE: node/edge insertion order determines the iteration order of
    G.edges(). generate_batch draws one random weight per city edge in that
    order, so reordering the add_* calls changes the sampled paths for a seed.

    Returns:
        (G, pos, S, T): the graph, a node -> (x, y) layout dict for plotting,
        and the source/target node names.
    """
    G = nx.DiGraph()
    pos = {}

    S, T = "Source", "Target"
    pos[S] = (-1, N)
    pos[T] = (N, -1)

    # Grid (row r is drawn at height N-1-r, so row 0 is at the top)
    for r in range(N):
        for c in range(N):
            node_id = f"{r}_{c}"
            G.add_node(node_id)
            pos[node_id] = (c, N-1-r)
            if r < N-1: G.add_edge(node_id, f"{r+1}_{c}", type='city', base_weight=1.0)
            if c < N-1: G.add_edge(node_id, f"{r}_{c+1}", type='city', base_weight=1.0)

    # Connect S/T to Grid
    G.add_edge(S, "0_0", type='city', base_weight=1.0)
    G.add_edge(f"{N-1}_{N-1}", T, type='city', base_weight=1.0)

    # Bypass
    prev = S
    for i in range(bypass_len):
        curr = f"h_{i}"
        G.add_node(curr)
        # Visual arc: interpolate S -> T and bulge outward by 2*sin(pi*alpha)
        alpha = (i + 1) / (bypass_len + 1)
        x = -1 + alpha * (N + 1) + 2 * np.sin(alpha * np.pi)
        y = N - alpha * (N + 1) + 2 * np.sin(alpha * np.pi)
        pos[curr] = (x, y)

        G.add_edge(prev, curr, type='highway', base_weight=1.0)
        prev = curr
    G.add_edge(prev, T, type='highway', base_weight=1.0)

    return G, pos, S, T

# ==========================================
# 2. Independent Path Sampling
# ==========================================
def generate_batch(G, S, T, n_paths=50, split=0.15, seed=None):
    """
    Sample a batch of S -> T paths, mixing highway and city traffic.

    * int(n_paths * split) paths are identical copies of the highway route.
    * The remaining paths are shortest city routes under fresh
      Uniform(0.1, 2.0) weights on every city edge. Highway edges get weight
      1000 so city routes avoid them.

    Every path is a list of edges, each canonicalized as ``tuple(sorted((u, v)))``.

    Args:
        G: graph whose edges carry a 'type' attribute ('city' / 'highway').
           NOTE: this function overwrites the 'weight' attribute of G's edges.
        S, T: source and target nodes.
        n_paths: requested number of paths.
        split: fraction of paths that take the highway.
        seed: if not None, reseeds the GLOBAL ``random`` and ``np.random`` generators.

    Returns:
        A shuffled list of paths. It may be shorter than n_paths when no
        highway route exists or a city sample has no S -> T path.
    """
    if seed is not None:
        random.seed(seed); np.random.seed(seed)

    def _canonical_edges(p_nodes):
        # Undirected canonical form so paths can be compared as edge sets.
        return [tuple(sorted((a, b))) for a, b in zip(p_nodes, p_nodes[1:])]

    paths = []

    # 1. Highway (Concentrated): int(n_paths * split) copies of one static route
    n_hwy = int(n_paths * split)

    h_edges = [(u, v) for u, v in G.edges() if G[u][v]['type'] == 'highway']
    try:
        H = G.edge_subgraph(h_edges)
        hwy_nodes = nx.shortest_path(H, S, T)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        # Best effort: without a highway route the batch is city-only.
        hwy_nodes = None
    if hwy_nodes is not None:
        hwy_path_edges = _canonical_edges(hwy_nodes)
        for _ in range(n_hwy):
            paths.append(hwy_path_edges)

    # 2. City (Diffuse): the remaining paths, each with FRESH random weights.
    n_city = n_paths - n_hwy

    for _ in range(n_city):
        # Weights are drawn in G.edges() order; keep that order for reproducibility.
        for u, v in G.edges():
            if G[u][v]['type'] == 'city':
                G[u][v]['weight'] = np.random.uniform(0.1, 2.0)
            else:
                G[u][v]['weight'] = 1000.0 # Block highway

        try:
            p_nodes = nx.shortest_path(G, S, T, weight='weight')
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            continue
        paths.append(_canonical_edges(p_nodes))

    # Shuffle to mix types
    random.shuffle(paths)
    return paths

# ==========================================
# 3. Solver
# ==========================================
def solve_lp_exact(paths, all_edges, edge_map, tau):
    """
    LP relaxation of "select few edges so that >= tau of the paths are fully covered".

    Variables: x_e in [0, 1] for each edge of ``all_edges`` (column edge_map[e])
    and z_p in [0, 1] for each path.
        minimize   sum_e x_e
        subject to sum_p z_p >= tau * len(paths)
                   z_p <= x_e   for every edge e of path p that is in edge_map
    A path containing an edge that is NOT in edge_map cannot be covered by any
    subset of all_edges, so its z_p is fixed to 0. Previously such a path had
    no linking rows and counted as covered at zero cost.

    Returns:
        np.ndarray of length len(all_edges) holding the optimal x values, or
        all zeros if the solver does not succeed (e.g. infeasible because
        tau * len(paths) exceeds the number of coverable paths).
    """
    n_e = len(all_edges)
    n_p = len(paths)
    c = np.concatenate([np.ones(n_e), np.zeros(n_p)])

    # Coverage: sum(z_p) >= tau * N_p   <=>   -sum(z_p) <= -tau * N_p
    target_count = tau * n_p
    row_cov = np.zeros(len(c)); row_cov[n_e:] = -1.0

    # Connectivity: z_p - x_e <= 0, one sparse row per (path, known edge)
    path_rows, path_cols, path_data = [], [], []
    z_bounds = []
    curr_row = 0
    for pid, p in enumerate(paths):
        coverable = True
        for e in p:
            if e in edge_map:
                path_rows.extend([curr_row, curr_row])
                path_cols.extend([n_e + pid, edge_map[e]])
                path_data.extend([1.0, -1.0])
                curr_row += 1
            else:
                coverable = False
        z_bounds.append((0, 1) if coverable else (0, 0))

    A_conn = coo_matrix((path_data, (path_rows, path_cols)), shape=(curr_row, n_e+n_p))
    A_ub = vstack([coo_matrix(row_cov), A_conn])
    b_ub = np.concatenate([[-target_count], np.zeros(curr_row)])

    bounds = [(0, 1)] * n_e + z_bounds
    res = linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=bounds, method='highs')
    if not res.success:
        return np.zeros(n_e)
    return res.x[:n_e]

def calc_coverage(edge_set, paths):
    """Return the fraction of ``paths`` whose edges all belong to ``edge_set`` (0.0 for no paths)."""
    if not paths:
        return 0.0
    covered = 0
    for path in paths:
        if set(path).issubset(edge_set):
            covered += 1
    return covered / len(paths)

# ==========================================
# 4. Experiment Execution
# ==========================================
# Single train/test routing comparison: LP (solved on train) vs. greedy
# edge-frequency ranking, both evaluated on held-out test paths.
# NOTE: notebook-cell globals (G, pos, S, T, train_paths, test_paths, viz_data, ...)
# are read by later cells; keep names as-is.
N = 6
G, pos, S, T = create_city_bypass(N=N, bypass_len=20)

# Generate 50 Train, 50 Test
print("Generating Train Set...")
train_paths = generate_batch(G, S, T, n_paths=50, split=0.15, seed=42)
print("Generating Test Set...")
test_paths = generate_batch(G, S, T, n_paths=50, split=0.15, seed=101)

# Universe of edges (from Train only): test paths that use an edge never seen
# in train can never be covered by the LP selection below.
# NOTE(review): list(set(...)) over string tuples depends on Python's hash
# randomization, so LP column order (and tie-breaking among equal-cost LP
# optima) can vary between sessions unless PYTHONHASHSEED is fixed.
all_edges = list(set(e for p in train_paths for e in p))
edge_map = {e: i for i, e in enumerate(all_edges)}

# Greedy Ranking (Train Frequency): edges ordered by how many train paths use them
train_counts = Counter(e for p in train_paths for e in p)
greedy_order = [e for e, _ in train_counts.most_common()]

taus = [0.6, 0.8, 0.99]
lp_covs, gr_covs = [], []
viz_data = {}

print(f"\n{'Tau':<5} | {'Budget':<6} | {'LP(Test)':<8} | {'Gr(Test)':<8}")
print("-" * 40)

for tau in taus:
    # 1. Solve LP on Train for this Tau
    x_vals = solve_lp_exact(train_paths, all_edges, edge_map, tau)

    # 2. Threshold x > 0.001 to get edges (any edge with non-negligible LP mass)
    lp_edge_indices = [i for i, v in enumerate(x_vals) if v > 0.001]
    lp_edge_set = set(all_edges[i] for i in lp_edge_indices)
    budget = len(lp_edge_set)

    # 3. Greedy Edges (Top budget): same edge budget as the LP for a fair comparison
    gr_edge_set = set(greedy_order[:budget])

    # 4. Evaluate on Test
    l_cov = calc_coverage(lp_edge_set, test_paths)
    g_cov = calc_coverage(gr_edge_set, test_paths)

    lp_covs.append(l_cov)
    gr_covs.append(g_cov)

    print(f"{tau:<5.2f} | {budget:<6} | {l_cov:<8.2f} | {g_cov:<8.2f}")

    # viz_data is only filled for tau == 0.8; the map panels later in this
    # notebook index viz_data['gr'] / viz_data['lp'] and fail if 0.8 is
    # removed from taus.
    if tau == 0.8:
        viz_data = {'lp': lp_edge_set, 'gr': gr_edge_set, 'lp_cov': l_cov, 'gr_cov': g_cov}

# ==========================================
# 5. Visualization
# ==========================================
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# A. Efficiency Curve: achieved test coverage vs. LP target tau (diagonal = ideal)
axes[0].plot(taus, lp_covs, 'o-', label='LP (Conformal)', color='blue')
axes[0].plot(taus, gr_covs, 's-', label='Greedy', color='firebrick')
axes[0].plot([0,1], [0,1], '--', color='gray', alpha=0.5)
axes[0].set_xlabel(r'Target Coverage $\tau$')
axes[0].set_ylabel('Test Coverage')
axes[0].set_title('A. Test Set Efficiency')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

def draw_on_ax(ax, edges, title, color):
    """
    Render the routing map on ``ax`` with ``edges`` highlighted.

    Uses the module-level graph ``G``, layout ``pos`` and endpoints ``S``/``T``.
    All edges of G are drawn thin and light gray. Each (u, v) in ``edges``
    whose endpoints are in ``pos`` is drawn thick in ``color``. S and T get
    black '^' / '*' markers. Axes are hidden.
    """
    def _xy(a, b):
        return [pos[a][0], pos[b][0]], [pos[a][1], pos[b][1]]

    # Background
    for src, dst in G.edges():
        a, b = str(src), str(dst)
        if a not in pos or b not in pos:
            continue
        xs, ys = _xy(a, b)
        ax.plot(xs, ys, c='lightgray', lw=0.5, zorder=0)

    # Selected
    for a, b in edges:
        if a in pos and b in pos:
            xs, ys = _xy(a, b)
            ax.plot(xs, ys, c=color, lw=2.5, zorder=5)

    # Landmarks
    for node, marker in ((S, '^'), (T, '*')):
        ax.scatter(*pos[node], c='k', marker=marker, s=100, zorder=10)
    ax.set_title(title, fontsize=11)
    ax.axis('off')

# B. Greedy Map (edge set chosen at tau == 0.8 in the loop above)
draw_on_ax(axes[1], viz_data['gr'],
           f"B. Greedy (Test Cov={viz_data['gr_cov']:.2f})\nPrioritizes Highway (Trap)", 'firebrick')

# C. LP Map (same budget, LP-selected edges)
draw_on_ax(axes[2], viz_data['lp'],
           f"C. LP (Test Cov={viz_data['lp_cov']:.2f})\nSelects Grid (Efficient)", 'blue')

plt.tight_layout()
plt.show()

#@title Heatmap

# ==========================================
# 1. Heatmap Visualization
# ==========================================
# Count edges in test set
# Number of test paths that use each (canonical, sorted-tuple) edge.
test_counts = Counter(e for p in test_paths for e in p)
# NOTE: fails if test_paths is empty (zip(*) of no items cannot unpack into two names).
edges, counts = zip(*test_counts.items())
max_count = max(counts)

plt.figure(figsize=(10, 8))
ax = plt.gca()

# A. Draw Background Graph (Faint Gray) to show structure
nx.draw_networkx_edges(G, pos, ax=ax, edge_color='#e0e0e0', width=1.0, arrows=False)

# B. Draw Active Edges Colored by Count
# Width scales with count for better visibility (2.0 .. 5.0)
widths = [2.0 + 3.0 * (c / max_count) for c in counts]

# NOTE(review): edges are sorted tuples, so some may be reversed relative to
# G's directed edges; this assumes draw_networkx_edges (arrows=False) only
# needs both endpoints to be in pos — confirm for the installed networkx.
edges_drawn = nx.draw_networkx_edges(
    G, pos,
    edgelist=edges,
    edge_color=counts,
    edge_cmap=plt.cm.Reds,
    edge_vmin=0,
    edge_vmax=max_count,
    width=widths,
    arrows=False,
    ax=ax
)

# C. Draw Landmarks
ax.scatter(*pos[S], c='black', marker='^', s=150, zorder=10, label='Source')
ax.scatter(*pos[T], c='black', marker='*', s=150, zorder=10, label='Target')
ax.text(pos[S][0], pos[S][1]+0.5, "Source", ha='center', fontsize=12, fontweight='bold')
ax.text(pos[T][0], pos[T][1]-0.5, "Target", ha='center', fontsize=12, fontweight='bold')

# D. Colorbar (same Reds colormap and 0..max_count range as the edges)
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=max_count))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Number of Trajectories', rotation=270, labelpad=15, fontsize=12)

# Styling
ax.set_title(f"Test Traffic Intensity (N={len(test_paths)} paths)", fontsize=14)
ax.axis('off')
# Zoom out slightly to fit labels
ax.set_xlim(-2, N+1)
ax.set_ylim(-2, N+1)

plt.tight_layout()
plt.show()

#@title Reverse Greedy (Synthetic Experiment)

# ==========================================
# 2. Experiment Execution (Monte Carlo)
# ==========================================

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import random
from collections import Counter
from scipy.optimize import linprog

# --------------------------
# Helper Functions (Redefined for Safety)
# --------------------------
def calc_coverage(selected_edges_set, paths):
    """
    Return the fraction of ``paths`` whose every edge is a member of ``selected_edges_set``.

    An empty path counts as covered; an empty list of paths yields 0.0.
    """
    if not paths:
        return 0.0
    hits = sum(
        1 for path in paths
        if all(edge in selected_edges_set for edge in path)
    )
    return hits / len(paths)

def get_edges_for_target_coverage(sorted_edges, paths_to_cover, target_tau):
    """
    Count how many edges of a ranked list are needed to reach a coverage target.

    Edges are added in the order of ``sorted_edges``. After each addition the
    fraction of ``paths_to_cover`` whose edges are all selected is checked.
    The function returns the 1-based position at which that fraction first
    reaches ``target_tau``.

    Coverage is maintained incrementally: a per-path count of still-missing
    distinct edges replaces a full recount after every addition. The original
    recount was O(E * P * L).

    Returns:
        0 if target_tau <= 0.
        The number of edges consumed when the target is first met.
        len(sorted_edges) if it is never met, or if there are no paths.
    """
    if target_tau <= 0:
        return 0

    n_paths = len(paths_to_cover)
    missing = []        # distinct edges of each path not yet selected
    paths_by_edge = {}  # edge -> indices of paths that contain it
    covered = 0
    for idx, path in enumerate(paths_to_cover):
        distinct = set(path)
        missing.append(len(distinct))
        if not distinct:
            covered += 1  # an empty path is trivially covered
        for e in distinct:
            paths_by_edge.setdefault(e, []).append(idx)

    selected = set()
    for i, edge in enumerate(sorted_edges):
        if edge not in selected:
            selected.add(edge)
            for idx in paths_by_edge.get(edge, ()):
                missing[idx] -= 1
                if missing[idx] == 0:
                    covered += 1
        current_cov = covered / n_paths if n_paths else 0.0
        if current_cov >= target_tau:
            return i + 1
    return len(sorted_edges)

# --------------------------
# Setup
# --------------------------
# Monte Carlo comparison of three ways to pick an edge set that reaches a
# target TEST coverage tau: forward greedy, reverse greedy, and nested LP
# (LP solved on train, then pruned against test).
# NOTE: notebook-cell globals (G, pos, results_*, taus, ...) are read by later cells.
N = 6
G, pos, S, T = create_city_bypass(N=N, bypass_len=20)
# Canonical (sorted) edge tuples: the same form generate_batch uses for path
# edges, so every sampled path edge is a key of edge_map.
all_edges_univ = [tuple(sorted((u, v))) for u, v in G.edges()]
edge_map = {e: i for i, e in enumerate(all_edges_univ)}

# Experiment Parameters
num_runs = 10
# Target TEST coverage levels evaluated for every method.
taus = [0.2, 0.4, 0.6, 0.8, 0.9, 1.0]
# LP coverage targets solved on TRAIN, in ascending order.
tau_ps = np.arange(0.1, 1.01, 0.1)

# Storage: [Run, Tau_Index] -> number of edges each method ends up using
results_greedy = np.zeros((num_runs, len(taus)))
results_rev_greedy = np.zeros((num_runs, len(taus))) # New storage
results_lp = np.zeros((num_runs, len(taus)))

print(f"Starting {num_runs} independent runs...")

for run_idx in range(num_runs):
    # Generate new data for this run with distinct seeds
    train_seed = 42 + run_idx * 100
    test_seed = 101 + run_idx * 100

    train_paths = generate_batch(G, S, T, n_paths=50, split=0.15, seed=train_seed)
    test_paths = generate_batch(G, S, T, n_paths=50, split=0.15, seed=test_seed)

    # 1. Greedy Order (Static for this run)
    train_counts = Counter(e for p in train_paths for e in p)
    # Sort all universe edges by train frequency, descending. list.sort is
    # stable, so ties keep G.edges() order.
    greedy_scored = [(e, train_counts[e]) for e in all_edges_univ]
    greedy_scored.sort(key=lambda x: x[1], reverse=True)
    greedy_order = [e for e, c in greedy_scored]

    # 2. Pre-calculate LP Subgraphs for all tau_p
    lp_cache = []

    for tp in tau_ps:
        x_vals = solve_lp_exact(train_paths, all_edges_univ, edge_map, tp)

        # Identify non-zero edges (keep x value for later pruning order)
        subgraph_edges = []
        for i, val in enumerate(x_vals):
            if val > 1e-5:
                subgraph_edges.append((all_edges_univ[i], val))

        # Calculate max possible Test coverage for this subgraph
        full_set = {e for e, w in subgraph_edges}
        test_cov = calc_coverage(full_set, test_paths)

        lp_cache.append({
            'tau_p': tp,
            'edges_with_weights': subgraph_edges,
            'test_coverage': test_cov
        })

    # 3. Eval Loop over Taus
    for t_idx, tau in enumerate(taus):

        # --- A. Forward Greedy ---
        # Fewest top-ranked edges whose TEST coverage reaches tau.
        n_greedy = get_edges_for_target_coverage(greedy_order, test_paths, tau)
        results_greedy[run_idx, t_idx] = n_greedy

        # --- B. Reverse Greedy (Adaptive on TRAIN stats) ---
        # Strategy: Start with ALL edges. Iteratively remove the edge that is
        # least useful for the *currently covered* paths.
        # NOTE(review): the removal ORDER comes from TRAIN counts, but whether a
        # removal is accepted is decided by TEST coverage (see below).

        active_rev = set(all_edges_univ)
        # Edges whose removal would drop TEST coverage below tau; never retried.
        locked_edges = set()

        # Helper: Convert paths to sets for fast checking (Local to this loop)
        train_path_sets = [set(p) for p in train_paths]
        test_path_sets = [set(p) for p in test_paths]
        total_test_paths = len(test_path_sets)

        while True:
            # 1. Identify which TRAIN paths are currently covered
            train_covered_indices = [i for i, p_set in enumerate(train_path_sets) if p_set.issubset(active_rev)]

            # 2. Count edge frequencies ONLY in covered TRAIN paths
            counts = Counter()
            for i in train_covered_indices:
                for e in train_paths[i]: # Use original list for iteration
                    if e in active_rev:
                        counts[e] += 1

            # 3. Sort candidates by utility (count asc)
            # NOTE(review): candidates is built from a set of string tuples, so its
            # initial order (and tie-breaking in the stable sort) can vary between
            # interpreter sessions under hash randomization unless PYTHONHASHSEED is fixed.
            candidates = [e for e in active_rev if e not in locked_edges]
            if not candidates:
                break

            candidates.sort(key=lambda e: counts[e])

            edge_removed_in_this_pass = False

            for edge_to_remove in candidates:
                # Try remove
                active_rev.remove(edge_to_remove)

                # Check Constraint on TEST data
                ncov = 0
                for p_set in test_path_sets:
                    if p_set.issubset(active_rev):
                        ncov += 1

                if (ncov / total_test_paths) >= tau:
                    edge_removed_in_this_pass = True

                    # Optimization: If we removed an edge that was supporting a TRAIN path,
                    # we must Recalculate to update the heuristic.
                    if counts[edge_to_remove] > 0:
                        break
                    # If count was 0, it wasn't supporting valid train paths anyway. Continue.
                else:
                    # Put back
                    active_rev.add(edge_to_remove)
                    locked_edges.add(edge_to_remove)

            # Stop once a full pass removes nothing.
            if not edge_removed_in_this_pass:
                break

        results_rev_greedy[run_idx, t_idx] = len(active_rev)

        # --- C. LP Strategy ---
        # Find smallest tau_p whose subgraph covers at least tau of test paths
        # (lp_cache is in ascending tau_p order, so the first match is the smallest).
        selected_candidate = None
        for candidate in lp_cache:
            if candidate['test_coverage'] >= tau:
                selected_candidate = candidate
                break

        if selected_candidate is None:
            # Fallback: Start with max LP subgraph and fill gaps with Greedy order
            best_lp = lp_cache[-1]
            active_set = {e for e, w in best_lp['edges_with_weights']}
            base_count = len(active_set)

            # Fill
            needed = 0
            for e in greedy_order:
                if calc_coverage(active_set, test_paths) >= tau:
                    break
                if e not in active_set:
                    active_set.add(e)
                    needed += 1
            n_lp = base_count + needed
        else:
            # Pruning Strategy
            edges_to_process = list(selected_candidate['edges_with_weights'])
            # Sort by x values ASCENDING (remove small weights first)
            edges_to_process.sort(key=lambda x: x[1])

            active_set = {e for e, w in edges_to_process}

            for edge, weight in edges_to_process:
                # Temporarily remove
                active_set.remove(edge)
                # If coverage drops below target, put it back
                if calc_coverage(active_set, test_paths) < tau:
                    active_set.add(edge)

            n_lp = len(active_set)

        results_lp[run_idx, t_idx] = n_lp

    print(f"Run {run_idx+1} complete.")

# ==========================================
# 3. Visualization (Mean + Error Bars)
# ==========================================
# Mean and population std (np.std default ddof=0) over runs, per target tau.
mu_greedy = np.mean(results_greedy, axis=0)
std_greedy = np.std(results_greedy, axis=0)

mu_rev = np.mean(results_rev_greedy, axis=0)
std_rev = np.std(results_rev_greedy, axis=0)

mu_lp = np.mean(results_lp, axis=0)
std_lp = np.std(results_lp, axis=0)

fig, ax = plt.subplots(figsize=(10, 6))

# Plot LP
ax.errorbar(taus, mu_lp, yerr=std_lp, fmt='o-', linewidth=2.5, markersize=8,
            label='Nested LP', color='royalblue', capsize=5)

# Plot Forward Greedy
ax.errorbar(taus, mu_greedy, yerr=std_greedy, fmt='s--', linewidth=2.0, markersize=8,
            label='Forward Greedy', color='firebrick', capsize=5, alpha=0.8)

# Plot Reverse Greedy
ax.errorbar(taus, mu_rev, yerr=std_rev, fmt='^-', linewidth=2.0, markersize=8,
            label='Reverse Greedy', color='forestgreen', capsize=5, alpha=0.8)

# Formatting (x axis is the same target-test-coverage grid `taus` used above)
ax.set_xlabel(r'Target Test Coverage ($\phi$)', fontsize=12)
ax.set_ylabel('Number of Edges Required', fontsize=12)
ax.set_title(f'Efficiency: Edges Required for Target Coverage ({num_runs} Runs)', fontsize=14)

ax.set_xticks(taus)
ax.set_xlim(min(taus) - 0.05, max(taus) + 0.05)

ax.grid(True, linestyle='--', alpha=0.6)
ax.legend(fontsize=11)

plt.tight_layout()
plt.show()

#@title Reverse Greedy Visualization

# ==========================================
# 4. Comparative Visualization (Fixed Budget)
# ==========================================
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from scipy.optimize import linprog

# 1. Setup Data (Single Shot)
N = 6  # size parameter forwarded to create_city_bypass
G, pos, S, T = create_city_bypass(N=N, bypass_len=20)

# Canonical undirected edge list (each edge as a sorted node pair) and the
# edge -> position-in-list lookup built from it.
all_edges_univ = [tuple(sorted((a, b))) for a, b in G.edges()]
edge_map = dict(zip(all_edges_univ, range(len(all_edges_univ))))

# Fixed train / test batches: identical generator settings, seeds 42 and 101.
train_paths, test_paths = (
    generate_batch(G, S, T, n_paths=50, split=0.15, seed=seed)
    for seed in (42, 101)
)

# Each path viewed as a hyperedge, i.e. the set of its elements (edges).
train_hyperedges = list(map(set, train_paths))
test_hyperedges = list(map(set, test_paths))

target_phi = 0.75

# ---------------------------------------------------------
# A. LP Solution Targeting Phi = 0.75
# ---------------------------------------------------------
print(f"Finding LP solution for coverage >= {target_phi}...")

# 1. Pre-calculate LP candidates.
#    One LP per training threshold tau_p = 0.1, 0.2, ..., 1.0. An edge belongs to
#    a candidate's support when its LP value exceeds 1e-5.
tau_ps = np.arange(0.1, 1.01, 0.1)
lp_candidates = []
for tp in tau_ps:
    x_vals = solve_lp_exact(train_paths, all_edges_univ, edge_map, tp)
    edges_w = [
        (all_edges_univ[idx], x_val)
        for idx, x_val in enumerate(x_vals)
        if x_val > 1e-5
    ]
    # Test coverage reachable if the whole support is kept (used to pick tau_p).
    max_cov = calc_coverage({edge for edge, _ in edges_w}, test_paths)
    lp_candidates.append({'tp': tp, 'edges': edges_w, 'max_cov': max_cov})

# 2. Select the first candidate (smallest tau_p) whose support reaches target_phi.
selected = next(
    (cand for cand in lp_candidates if cand['max_cov'] >= target_phi), None
)

# 3. Prune to exact coverage requirement.
#    Edges are tried lightest-first. Each edge is dropped unless removing it
#    pushes coverage of the TEST paths below target_phi.
if not selected:
    # Fallback (use all)
    lp_active = set(all_edges_univ)
else:
    edges_to_process = sorted(selected['edges'], key=lambda pair: pair[1])  # weight ASC
    lp_active = {edge for edge, _ in edges_to_process}
    for edge, _ in edges_to_process:
        lp_active.remove(edge)
        if calc_coverage(lp_active, test_paths) < target_phi:
            lp_active.add(edge)

# The LP's final edge count becomes the shared budget B for the comparison.
B = len(lp_active)
lp_final_cov = calc_coverage(lp_active, test_paths)
print(f"LP Selected {B} edges. Test Coverage: {lp_final_cov:.2f}")

# ---------------------------------------------------------
# B. Reverse Greedy Targeting Budget B (Fixed Logic)
# ---------------------------------------------------------
print(f"Running Reverse Greedy to target budget B={B}...")

# Start from the full edge universe and peel off one edge per iteration until
# exactly B edges remain. This matches the LP's edge count, so both panels are
# compared at the same budget.
rev_active = set(all_edges_univ)
locked_edges = set() # Not strictly used here since we target budget, not coverage, but good practice

# Convert train paths to sets for fast subset checks
train_path_sets = [set(p) for p in train_paths]

# Stop at exactly B edges. The previous condition (`> B+1`) stopped one step
# early and left B+1 edges, while the plot labels the result "Budget=B".
while len(rev_active) > B:
    # 1. Identify currently covered TRAIN paths: those whose edges are all active.
    #    (Utility is computed on training data, not test data.)
    train_covered_indices = [i for i, p_set in enumerate(train_path_sets) if p_set.issubset(rev_active)]

    # 2. Count how often each active edge appears in covered TRAIN paths.
    counts = Counter()
    for i in train_covered_indices:
        for e in train_paths[i]:
            if e in rev_active:
                counts[e] += 1

    # 3./4. Remove the active edge with the lowest train utility. Edges with
    #    count 0 (useless for current train coverage) go first. Ties are broken
    #    by the edge's position in the universe (edge_map), so the choice does not
    #    depend on set iteration order. That order can vary between runs, e.g.
    #    for str node labels under hash randomization.
    edge_to_remove = min(rev_active, key=lambda e: (counts[e], edge_map[e]))
    rev_active.remove(edge_to_remove)

rev_final_cov = calc_coverage(rev_active, test_paths)
print(f"Reverse Greedy reduced to {len(rev_active)} edges. Test Coverage: {rev_final_cov:.2f}")

# ---------------------------------------------------------
# C. Plotting
# ---------------------------------------------------------
def draw_on_ax(ax, edges, title, color):
    """Draw the city graph on ``ax`` with ``edges`` highlighted.

    The full graph ``G`` is drawn as a thin light-gray background, and every
    edge in ``edges`` is overlaid in ``color``. The source ``S`` (triangle) and
    target ``T`` (star) are marked in black. The function relies on the
    module-level ``G``, ``pos``, ``S`` and ``T``.

    Node positions are looked up as ``pos[node]`` and, failing that, as
    ``pos[str(node)]``. Segments or landmarks whose position cannot be found
    are skipped rather than raising.

    Args:
        ax: Matplotlib Axes to draw on.
        edges: Iterable of ``(u, v)`` node pairs to highlight.
        title: Title for the axes.
        color: Matplotlib color used for the highlighted edges.
    """
    def _node_xy(node):
        # Layout may be keyed by the node itself or by its string form.
        return pos[node] if node in pos else pos.get(str(node))

    def _segment(u, v, **style):
        # Draw u-v as a straight segment when both endpoints have positions.
        p_u, p_v = _node_xy(u), _node_xy(v)
        if p_u is not None and p_v is not None:
            ax.plot([p_u[0], p_v[0]], [p_u[1], p_v[1]], **style)

    # Background
    for u, v in G.edges():
        _segment(u, v, c='lightgray', lw=0.5, zorder=0)
    # Selected
    for u, v in edges:
        _segment(u, v, c=color, lw=2.5, zorder=5)
    # Landmarks: use the same lookup as edges so str-keyed layouts work too.
    for node, marker in ((S, '^'), (T, '*')):
        xy = _node_xy(node)
        if xy is not None:
            ax.scatter(*xy, c='k', marker=marker, s=100, zorder=10)
    ax.set_title(title, fontsize=11)
    ax.axis('off')

# Side-by-side comparison at the shared budget B:
# left = Reverse Greedy, right = LP solution.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

panels = [
    (rev_active,
     f"Reverse Greedy (Budget={B})\nTest Cov={rev_final_cov:.2f}",
     'forestgreen'),
    (lp_active,
     f"LP Solution (Budget={B})\nTest Cov={lp_final_cov:.2f}",
     'blue'),
]
for panel_ax, (panel_edges, panel_title, panel_color) in zip(axes, panels):
    draw_on_ax(panel_ax, panel_edges, panel_title, panel_color)

plt.tight_layout()
plt.show()