import numpy as np
import math
import argparse
from pyscipopt import Model, quicksum, SCIP_PARAMSETTING
from BTSP_2approx import bottleneck_tsp_2approx
import networkx as nx

def compute_D(filename, compute_shortest_path=True):
    """Build a pairwise delay matrix from a GML network file.

    Each edge must carry a ``distance`` attribute. It is turned into a delay
    with ``distance * 0.0085 + 4`` and written back onto the edge under the
    ``delay`` attribute. Rows and columns follow the sorted node order.

    Args:
        filename: Path of the GML file, handed to ``networkx.read_gml``.
        compute_shortest_path: When True (default) every entry becomes the
            delay-weighted shortest-path length between the two nodes. When
            False the direct-edge matrix is returned, with ``np.inf`` for
            node pairs that share no edge.

    Returns:
        An ``(n, n)`` float array of delays.

    Raises:
        ValueError: If the two directions of an edge give different delays.
        networkx.NetworkXNoPath: If ``compute_shortest_path`` is True and
            some node cannot be reached from another.
    """
    G = nx.read_gml(filename)
    nodes = sorted(G.nodes())
    n = len(nodes)
    node_index = {node: i for i, node in enumerate(nodes)}

    d_matrix = np.full((n, n), np.inf)
    np.fill_diagonal(d_matrix, 0.0)

    for u, v, data in G.edges(data=True):
        i = node_index[u]
        j = node_index[v]
        d_value = (data["distance"] * 0.0085) + 4
        # `data` is the edge's own attribute dict, so this stores the delay
        # on the graph for the weighted shortest-path search below.
        data["delay"] = d_value
        d_matrix[i, j] = d_value
        if d_matrix[j, i] == np.inf:
            # Mirror the value so the matrix is symmetric.
            d_matrix[j, i] = d_value
        elif d_matrix[j, i] != d_value:
            raise ValueError(f"Edge ({u}, {v}) has inconsistent distances: {d_matrix[j, i]} vs {d_value}")

    if not compute_shortest_path:
        return d_matrix

    # One Dijkstra run per source yields the distances to every target,
    # instead of one run per (source, target) pair.
    d_matrix_sp = np.zeros_like(d_matrix)
    for i, source in enumerate(nodes):
        lengths = nx.single_source_dijkstra_path_length(G, source, weight='delay')
        for j, target in enumerate(nodes):
            if target not in lengths:
                raise nx.NetworkXNoPath(f"No path between {source} and {target}.")
            d_matrix_sp[i, j] = lengths[target]
    return np.minimum(d_matrix, d_matrix_sp)

def solve_BCD_with_scip(D: np.ndarray, S_list: list):
    """Solve the BCD assignment problem as a mixed-integer program with SCIP.

    Model:
      * ``x[i, u]`` (binary) is 1 if virtual node ``i`` is assigned to
        physical node ``u``; every virtual node gets exactly one physical
        node and vice versa.
      * ``y[i, s][u, v]`` (binary) linearises ``x[i, u] * x[(i + s) % n, v]``.
      * ``M[l]`` (continuous, >= 0) bounds ``D[u, v]`` over all pairs that
        are linked by a skip ``s`` belonging to ``S_list[l]``.
      * The objective minimises ``sum(M[l])``.

    Args:
        D: Square matrix; ``D[u, v]`` is the cost between physical nodes.
        S_list: List of skip sets; ``S_list[l]`` holds the skips of set ``l``.

    Returns:
        ``(pi, objval)`` where ``pi[i]`` is the physical node assigned to
        virtual node ``i`` (the convention ``compute_BCD_cost`` reads), or
        ``(None, None)`` when no solution values can be read from the model.
    """
    n = D.shape[0]
    tau = len(S_list)

    model = Model("BCD_MIP")
    model.setParam("display/verblevel", 1)

    model.setPresolve(SCIP_PARAMSETTING.OFF)

    physical_nodes = list(range(n))
    virtual_nodes = list(range(n))

    # The original code drew a random permutation here without using it.
    # The draw is kept so the global NumPy random stream seen by the caller
    # after this function returns is unchanged.
    np.random.permutation(n)

    # x[i,u] = 1 if virtual node i is assigned to physical node u
    x = {(i, u): model.addVar(vtype="B", name=f"x_{i}_{u}")
         for i in virtual_nodes for u in physical_nodes}

    # M_l = maximum distance for skip set l
    M = {l: model.addVar(vtype="C", lb=0.0, name=f"M_{l}") for l in range(tau)}

    # A skip may occur in several sets; create its y variables and linking
    # constraints only once.
    unique_skips = list(dict.fromkeys(s for S_l in S_list for s in S_l))

    # y[i,s][u,v] = x[i,u] * x[(i+s) % n, v]
    y = {(i, s): {(u, v): model.addVar(vtype="B", name=f"y_{i}_{s}_{u}_{v}")
                  for u in physical_nodes for v in physical_nodes if u != v}
         for i in virtual_nodes for s in unique_skips}

    for i in virtual_nodes:
        model.addCons(quicksum(x[i, u] for u in physical_nodes) == 1, f"assign_{i}")

    for u in physical_nodes:
        model.addCons(quicksum(x[i, u] for i in virtual_nodes) == 1, f"position_{u}")

    # Standard product linearisation: y == x[i,u] AND x[j,v].
    for s in unique_skips:
        for i in virtual_nodes:
            j = (i + s) % n
            for (u, v), y_var in y[i, s].items():
                model.addCons(y_var <= x[i, u], f"y_leq_x_{i}_{s}_{u}_{v}")
                model.addCons(y_var <= x[j, v], f"y_leq_x_{j}_{s}_{u}_{v}")
                model.addCons(y_var >= x[i, u] + x[j, v] - 1, f"y_geq_x_{i}_{s}_{u}_{v}")

    # Each M[l] must dominate only the pairs induced by its own skip set.
    # (Previously a stale loop variable made every bound apply to the last M.)
    for l, S_l in enumerate(S_list):
        for s in S_l:
            for i in virtual_nodes:
                for (u, v), y_var in y[i, s].items():
                    model.addCons(M[l] >= float(D[u, v]) * y_var, f"M_leq_D_{l}_{i}_{s}_{u}_{v}")

    model.setObjective(quicksum(M[l] for l in range(tau)), "minimize")

    model.setParam("limits/time", 1800)
    model.optimize()

    pi_opt = [-1] * n
    try:
        for i in virtual_nodes:
            for u in physical_nodes:
                if model.getVal(x[i, u]) > 0.5:
                    # position (virtual node) -> physical node
                    pi_opt[i] = u
        objval = model.getObjVal()
    except Exception:
        # Reading values fails when the solver holds no solution
        # (e.g. the time limit was hit before one was found).
        return None, None
    return pi_opt, objval
 
def compute_BCD_cost(pi: list, D: np.ndarray, S_list: list) -> float:
    """Return the BCD cost of the ordering ``pi``.

    For every skip set in ``S_list`` the largest ``D[pi[i], pi[(i + d') % n]]``
    over all positions ``i`` and all skips ``d`` of the set is taken, where
    ``d' = min(d, n - d)`` and ``n = len(pi)``. The per-set maxima (starting
    from 0.0) are summed.
    """
    size = len(pi)
    cost_sum = 0.0
    for skip_set in S_list:
        # Fold each skip to the shorter of its two directions around the ring.
        offsets = [min(step, size - step) for step in skip_set]
        worst = 0.0
        for pos, src in enumerate(pi):
            for off in offsets:
                dst = pi[(pos + off) % size]
                worst = max(worst, D[src, dst])
        cost_sum += worst
    return cost_sum

def solve_BCD_with_greedy(D: np.ndarray, S_list:list) -> list:
    """Build an ordering greedily, one node at a time.

    The first node is drawn with ``np.random.choice``. After that, each
    step appends the not-yet-placed node whose addition gives the smallest
    ``compute_partial_BCD_cost_mod`` value; on ties the first candidate
    encountered while iterating the set of remaining nodes wins.
    """
    n = D.shape[0]
    unplaced = set(range(n))

    start = np.random.choice(list(unplaced))
    unplaced.remove(start)
    order = [start]

    while len(order) < n:
        chosen, chosen_cost = None, float('inf')
        for node in unplaced:
            trial_cost = compute_partial_BCD_cost_mod(order + [node], D, S_list)
            if trial_cost < chosen_cost:
                chosen, chosen_cost = node, trial_cost

        order.append(chosen)
        unplaced.remove(chosen)

    return order

def compute_partial_BCD_cost_mod(pi_partial: list, D: np.ndarray, S_list: list) -> float:
    """Return the BCD cost restricted to the positions filled so far.

    Works like the full cost, with ``n = D.shape[0]`` as the ring size, but a
    pair ``(i, (i + d') % n)`` only counts when the second position is already
    occupied, i.e. its index is below ``len(pi_partial)``.
    """
    placed = len(pi_partial)
    ring_size = D.shape[0]
    cost_sum = 0.0

    for skip_set in S_list:
        offsets = [min(step, ring_size - step) for step in skip_set]
        worst = 0.0
        for pos, src in enumerate(pi_partial):
            for off in offsets:
                nxt = (pos + off) % ring_size
                if nxt >= placed:
                    # Partner position has no node yet.
                    continue
                worst = max(worst, D[src, pi_partial[nxt]])
        cost_sum += worst

    return cost_sum

def compute_bottleneck_value(pi: list, D: np.ndarray) -> float:
    """Return the largest ``D`` entry between cyclically consecutive nodes of ``pi``.

    The last node is paired with the first. The running maximum starts at
    0.0, which is also the result for an empty ``pi``.
    """
    size = len(pi)
    worst = 0.0
    for pos, here in enumerate(pi):
        there = pi[(pos + 1) % size]
        worst = max(worst, D[here, there])
    return worst

def is_coprime(a: int, b: int) -> bool:
    """Return True exactly when ``math.gcd(a, b)`` equals 1."""
    common_divisor = math.gcd(a, b)
    return common_divisor == 1

def maximum_skip_sum_reduction(n: int, S_list: list, pi: list, D: np.ndarray) -> tuple:
    """Second stage of BTSP-MSR: pick the best multiplier for the skips.

    For every ``p`` in ``1 .. n-1`` that is coprime to ``n``, all skips are
    multiplied by ``p`` modulo ``n`` and the ordering ``pi`` is scored with
    ``compute_BCD_cost``. The multiplier with the smallest cost wins (the
    smallest ``p`` on ties) and is applied to ``pi`` itself.

    Args:
        n: Number of nodes.
        S_list: List of skip sets.
        pi: Ordering produced by the first stage.
        D: Cost matrix indexed by the entries of ``pi``.

    Returns:
        ``(best_p, pi_star)`` with ``pi_star[i] = pi[(best_p * i) % n]``.
        When no multiplier yields a cost below infinity (e.g. ``n <= 1``),
        the identity ``best_p = 1`` is returned instead of failing.
    """
    # p = 1 keeps pi unchanged and is coprime to every n, so it is a safe
    # fallback. It is also the first candidate in the scan, so results are
    # unchanged whenever some candidate has a finite cost.
    best_p = 1
    best_cost = float('inf')

    for p in range(1, n):
        if not is_coprime(p, n):
            # Only units modulo n permute the positions 0..n-1.
            continue

        S_list_p = [[(d * p) % n for d in S] for S in S_list]
        cost_p = compute_BCD_cost(pi, D, S_list_p)

        if cost_p < best_cost:
            best_cost = cost_p
            best_p = p

    pi_star = [pi[(best_p * i) % n] for i in range(n)]
    return best_p, pi_star

def generate_S_list(n: int, nw_topology: str) -> list:
    """Return the list of skip sets for a logical topology on ``n`` nodes.

    Let ``P`` be the powers of two strictly below ``n`` in increasing order.

    * ``"ring"``         -> ``[[1]]``
    * ``"exponential"``  -> ``[P]`` (one set holding every power)
    * ``"one_peer_exp"`` -> ``[[p] for p in P]`` (one set per power)
    * ``"sparse_exp"``   -> ``[[1, P[-2]]]`` (1 plus the second-largest power)

    Args:
        n: Number of nodes.
        nw_topology: One of the four names above.

    Returns:
        A list of lists of integer skips.

    Raises:
        ValueError: If ``nw_topology`` is unknown, if a power-of-two topology
            is requested with ``n < 2``, or if ``"sparse_exp"`` is requested
            when fewer than two powers of two lie below ``n``.
    """
    if nw_topology == "ring":
        return [[1]]
    if nw_topology not in ("exponential", "one_peer_exp", "sparse_exp"):
        raise ValueError(f"Unknown nw_topology: {nw_topology}")
    if n < 2:
        raise ValueError(f"nw_topology {nw_topology!r} requires n >= 2, got n={n}")

    # (n - 1).bit_length() - 1 == floor(log2(n - 1)), computed exactly on
    # integers rather than through a floating-point logarithm.
    powers = [1 << k for k in range((n - 1).bit_length())]

    if nw_topology == "exponential":
        return [powers]
    if nw_topology == "one_peer_exp":
        return [[p] for p in powers]

    # sparse_exp
    if len(powers) < 2:
        raise ValueError(f"nw_topology 'sparse_exp' requires n >= 3, got n={n}")
    return [[1, powers[-2]]]

def main(nw_topology="ring", filename=None, D=None, verbose=True, seed=42):
    """Run one BCD experiment and return the resulting costs.

    The command line is always parsed; ``nw_topology``, ``filename`` and
    ``seed`` only supply the defaults of the corresponding options. The cost
    matrix is loaded from ``--physical_nw_file`` (a ``.gml`` file goes through
    ``compute_D``, anything else through ``np.load``). When no file is given,
    the ``D`` argument is used instead.

    The ordering is chosen by ``--greedy_opt``, ``--no_opt`` (identity
    ordering), ``--milp_opt`` (SCIP, falling back to a random ordering when
    no solution is returned) or, by default, BTSP followed by the
    skip-multiplier search. The first flag in that order wins.

    Args:
        nw_topology: Default for ``--nw_topology``.
        filename: Default for ``--physical_nw_file``.
        D: Cost matrix used only when no file is given.
        verbose: Print the configuration and the costs.
        seed: Default for ``--seed``.

    Returns:
        ``(cost_initial, cost_pi, cost_final)``: the mean cost of five random
        orderings, the cost of the chosen ordering and the cost after the
        multiplier search, each divided by ``len(S_list)``.
    """
    parser = argparse.ArgumentParser(description="minimizing BCD")
    parser.add_argument("--nw_topology", choices=["ring", "exponential", "one_peer_exp", "sparse_exp"],
                        default=nw_topology, help="NW topology")
    parser.add_argument("--seed", type=int, default=seed,
                        help="random seed")
    parser.add_argument("--greedy_opt", action="store_true",
                        help="set True to use greedy optimization")
    parser.add_argument("--no_opt", action="store_true",
                        help="set True to random optimization")
    parser.add_argument("--physical_nw_file", type=str, default=filename,
                        help="distance matrix file path")
    parser.add_argument("--milp_opt", action="store_true",
                        help="set True to use SCIP")

    args = parser.parse_args()
    np.random.seed(args.seed)

    if args.physical_nw_file is None:
        # No file: fall back to the matrix passed by the caller, if any.
        if D is None:
            parser.error("no physical network given: pass --physical_nw_file "
                         "or call main() with filename= or D=")
        D = np.asarray(D)
    elif args.physical_nw_file.endswith('.gml'):
        D = compute_D(args.physical_nw_file)
    else:
        D = np.load(args.physical_nw_file)
    n = D.shape[0]

    S_list = generate_S_list(n, args.nw_topology)

    pi_initial = list(range(n))
    np.random.shuffle(pi_initial)

    use_btsp_msr = not (args.greedy_opt or args.no_opt or args.milp_opt)

    if args.greedy_opt:
        pi = solve_BCD_with_greedy(D, S_list)
    elif args.no_opt:
        pi = list(range(n))
    elif args.milp_opt:
        pi, _ = solve_BCD_with_scip(D, S_list)
        if pi is None:
            pi = pi_initial
    else: # BTSP-MSR
        pi, _, _ = bottleneck_tsp_2approx(D)
        # Drop the last tour entry before using the tour as an ordering.
        pi = pi[:-1]

    if use_btsp_msr:
        best_p, pi_star = maximum_skip_sum_reduction(n, S_list, pi, D)
    else:
        best_p = 1
        pi_star = pi.copy()

    if verbose:
        print(f"n                = {n}")
        print(f"physical_nw_file      = {args.physical_nw_file}")
        print(f"nw_topology             = {args.nw_topology}")
        print(f"S_list           = {S_list}")
        print(f"pi_1 = {pi}")
        print(f"optimal p        = {best_p}")
        print(f"pi_star         = {pi_star}")
    # Report the topology that was actually used, not the function default.
    print(f"nw_topology: {args.nw_topology}, p={best_p}")

    cost_pi = compute_BCD_cost(pi, D, S_list) / len(S_list)
    cost_final = compute_BCD_cost(pi_star, D, S_list) / len(S_list)
    cost_initial = []
    for _ in range(5): # compute average BCD cost with random permutations
        pi_random = pi_initial.copy()
        np.random.shuffle(pi_random)
        cost_random = compute_BCD_cost(pi_random, D, S_list) / len(S_list)
        cost_initial.append(cost_random)
    cost_initial = np.mean(cost_initial)
    if verbose:
        print(f"BCD cost (pi_initial avg) = {cost_initial:.4f}")
        print(f"BCD cost (pi)         = {cost_pi:.4f}")
        print(f"BCD cost (pi_star)   = {cost_final:.4f} (p = {best_p})")

    return cost_initial, cost_pi, cost_final

if __name__ == "__main__":
    # Script entry point: run one experiment with verbose output and keep
    # the three returned costs as module-level names.
    (cost_initial, cost_pi, cost_final) = main(verbose=True)