import numpy as np
import gurobipy as gp
from gurobipy import GRB


# Retrieve variable from MIP model
def get_var(m, name, size):
    """Fetch the variables ``name[i,j]`` (or ``name[i,j,k]``) from model ``m``.

    Args:
        m: object exposing ``getVarByName`` (a Gurobi model).
        name: base name the variables were created with.
        size: ``(n0, n1)`` for a 2-D family or ``(n0, n1, n2)`` for a 3-D one;
            any length other than 3 is treated as 2-D.

    Returns:
        Dict keyed by index tuples ``(i, j)`` or ``(i, j, k)`` mapping to
        whatever ``m.getVarByName`` returns for the corresponding name.
    """
    three_dim = len(size) == 3
    found = {}
    for i in range(size[0]):
        for j in range(size[1]):
            if not three_dim:
                found[i, j] = m.getVarByName(f"{name}[{i},{j}]")
                continue
            for k in range(size[2]):
                found[i, j, k] = m.getVarByName(f"{name}[{i},{j},{k}]")
    return found


# Retrieve variable value from MIP model
def get_var_value(m, name, size):
    """Read the (rounded) solution values of ``name[i,j]`` / ``name[i,j,k]``.

    Each value is taken from the variable's ``Xn`` attribute (the value in the
    solution currently selected on the Gurobi model) and rounded to the
    nearest integer.

    Args:
        m: object exposing ``getVarByName`` (a solved Gurobi model).
        name: base name the variables were created with.
        size: ``(n0, n1)`` or ``(n0, n1, n2)``; any length other than 3 is
            treated as 2-D.

    Returns:
        Nested Python lists of ints with shape ``size``.
    """

    def _rounded(label):
        return int(np.rint(m.getVarByName(label).Xn))

    three_dim = len(size) == 3
    table = []
    for i in range(size[0]):
        row = []
        for j in range(size[1]):
            if three_dim:
                row.append([_rounded(f"{name}[{i},{j},{k}]") for k in range(size[2])])
            else:
                row.append(_rounded(f"{name}[{i},{j}]"))
        table.append(row)
    return table


# formulate the graph space with node number ranging from N0 to N
def MIP_graph(N, N0=None):
    """Create a Gurobi model whose feasible points encode directed graphs.

    The model has room for ``N`` nodes. Node ``v`` exists iff ``A[v, v] == 1``.
    At least ``N0`` nodes must exist, and the existing nodes are the
    lowest-indexed ones. Besides the adjacency matrix ``A``, the model adds:

    - shortest-distance variables ``dis``;
    - shortest-path membership variables ``z``;
    - reachability variables ``r``;
    - constraints tying these variables to ``A``.

    An unreachable pair ``(u, v)`` gets ``dis[u, v] == N``, which plays the
    role of infinity.

    Args:
        N: maximal number of nodes.
        N0: minimal number of nodes; ``None`` means ``N`` (all nodes exist).

    Returns:
        The updated ``gurobipy.Model``. Its variables are named ``"A"``,
        ``"dis"``, ``"z"`` and ``"r"``, so callers can fetch them by name.
    """

    model = gp.Model()

    # A[u, v]: if edge u->v exists; the diagonal A[v, v] marks whether node v exists
    A = model.addVars(N, N, vtype=GRB.BINARY, name="A")
    # dis[u, v]: the shortest distance from u to v (the upper bound N means "unreachable")
    dis = model.addVars(N, N, vtype=GRB.INTEGER, lb=0, ub=N, name="dis")
    # z[u, v, w]: if w appears in a shortest path from u to v
    z = model.addVars(N, N, N, vtype=GRB.BINARY, name="z")
    # r[u, v]: if u could reach v
    r = model.addVars(N, N, vtype=GRB.BINARY, name="r")

    if N0 is None:
        N0 = N

    # at least N0 nodes
    expr = 0
    for v in range(N):
        expr += A[v, v]
    model.addConstr(expr >= N0)

    # force nodes with smaller indexes to exist first (existing nodes form a prefix)
    for v in range(N - 1):
        model.addConstr(A[v, v] >= A[v + 1, v + 1])

    # edges/paths from u to v exist only when both node u and v exist;
    # a missing endpoint forces dis[u, v] to N ("infinity")
    for u in range(N):
        for v in range(N):
            if u != v:
                model.addConstr(2 * A[u, v] <= A[u, u] + A[v, v])
                model.addConstr(dis[u, v] >= N * (1 - A[u, u]))
                model.addConstr(dis[u, v] >= N * (1 - A[v, v]))
                model.addConstr(2 * r[u, v] <= A[u, u] + A[v, v])

    # initialize the shortest distance & path and reachability from any node to itself:
    # dis[v, v] = 0, v reaches itself, and v is the only node on its trivial path
    for v in range(N):
        dis[v, v].ub = 0
        r[v, v].lb = 1
        z[v, v, v].lb = 1
        for w in range(N):
            if w != v:
                z[v, v, w].ub = 0

    # initialize the shortest distance & path and reachability from u to v:
    # dis[u, v] == 1 iff edge u->v exists, an edge implies reachability,
    # and both endpoints always count as lying on any u->v shortest path
    for u in range(N):
        for v in range(N):
            if u != v:
                model.addConstr(dis[u, v] >= 2 - A[u, v])
                model.addConstr(dis[u, v] <= 1 + (N - 1) * (1 - A[u, v]))
                model.addConstr(r[u, v] >= A[u, v])
                z[u, v, u].lb = 1
                z[u, v, v].lb = 1

    # u can reach v if and only if the shortest distance from u to v is less than N, i.e., infinity
    for u in range(N):
        for v in range(N):
            if u != v:
                # dis[u, v] = N => r[u, v] = 0
                # r[u, v] = 1 => dis[u, v] < N
                model.addConstr(dis[u, v] + r[u, v] <= N)
                # dis[u, v] < N => r[u, v] = 1
                # r[u, v] = 0 => dis[u, v] = N
                model.addConstr(dis[u, v] + N * r[u, v] >= N)

    # update the shortest distance & path and reachability from u to v via w
    for u in range(N):
        for v in range(N):
            if u != v:
                # Number of nodes marked on u->v shortest paths:
                #   r = 0 (hence A = 0)  -> exactly the two endpoints
                #   r = 1, A = 1         -> exactly the two endpoints (direct edge)
                #   r = 1, A = 0         -> at least one intermediate node (3..N)
                model.addConstr(z.sum(u, v, "*") >= 2 + r[u, v] - A[u, v])
                model.addConstr(z.sum(u, v, "*") <= 2 + (N - 2) * (r[u, v] - A[u, v]))
                for w in range(N):
                    if w != u and w != v:
                        # u can reach v if u can reach w and w can reach v (transitivity)
                        model.addConstr(r[u, v] >= r[u, w] + r[w, v] - 1)
                        # w can appear in a shortest path from u to v only if
                        # u can reach w and w can reach v
                        model.addConstr(r[u, w] + r[w, v] >= 2 * z[u, v, w])
                        # When u reaches w and w reaches v (otherwise the big-M term relaxes it):
                        #   z[u, v, w] = 0 => dis[u, v] <  dis[u, w] + dis[w, v]
                        #   z[u, v, w] = 1 => dis[u, v] <= dis[u, w] + dis[w, v]
                        model.addConstr(
                            dis[u, v]
                            <= dis[u, w]
                            + dis[w, v]
                            - (1 - z[u, v, w])
                            + (N + 1) * (2 - r[u, w] - r[w, v])
                        )
                        # z[u, v, w] = 1 => dis[u, v] >= dis[u, w] + dis[w, v]; together with the
                        # constraint above, w lies on a shortest path and the distances add up exactly
                        model.addConstr(
                            dis[u, v]
                            >= dis[u, w] + dis[w, v] - 2 * N * (1 - z[u, v, w])
                        )

    model.update()
    return model


# formulate NAS cell with/without node or/and edge labels
def MIP_Cell(m, N, Ln=None, Le=None, inputs_idx=None, outputs_idx=None):
    """Add NAS-cell structure (DAG, inputs/outputs, labels) to a graph model.

    ``m`` must already contain the ``A``/``dis``/``z``/``r`` variables for
    ``N`` nodes, named as ``MIP_graph`` creates them. The function adds these
    restrictions:

    - edges only go from lower to higher node indices;
    - every non-input node is reachable from some input;
    - every non-output node reaches some output;
    - there is no edge or path between two inputs, or between two outputs.

    The model is modified in place and nothing is returned.

    Args:
        m: Gurobi model built by ``MIP_graph(N, ...)``.
        N: number of nodes.
        Ln: number of node labels, or ``None`` for unlabeled nodes. When
            given, a one-hot label matrix ``Fn`` (N x Ln) is added. Label 0 is
            reserved for input nodes and label ``Ln - 1`` for output nodes.
        Le: number of edge labels, or ``None`` for unlabeled edges. When given,
            a one-hot label tensor ``Fe`` (N x N x Le) is added. An edge carries
            exactly one label iff it exists.
        inputs_idx: indices of input nodes (default ``[0]``).
        outputs_idx: indices of output nodes (default ``[N - 1]``).
    """

    # Default one input node and one output node if not specified
    if inputs_idx is None:
        inputs_idx = [0]
    if outputs_idx is None:
        outputs_idx = [N-1]

    # retrieve variables in graph encoding
    A = get_var(m, "A", (N, N))
    dis = get_var(m, "dis", (N, N))
    z = get_var(m, "z", (N, N, N))
    r = get_var(m, "r", (N, N))

    # no edge (and therefore no path) from u to v if u > v: the graph is a DAG
    # whose topological order is the node index order
    for u in range(N):
        for v in range(u):
            A[u, v].ub = 0
            r[u, v].ub = 0
            dis[u, v].lb = N
            for w in range(N):
                if w != u and w != v:
                    z[u, v, w].ub = 0

    # any node v can be reached by at least one input
    for v in list(set(range(N)) - set(inputs_idx)):
        m.addConstr(gp.quicksum(r[input, v] for input in inputs_idx) >= 1)

    # any node u can reach to at least one output node
    for u in list(set(range(N)) - set(outputs_idx)):
        m.addConstr(gp.quicksum(r[u, output] for output in outputs_idx) >= 1)

    # no edge or path between input nodes (only the lower->higher direction needs
    # fixing; the reverse direction is already excluded by the DAG constraint above)
    if len(inputs_idx) > 1:
        ordered_inputs = sorted(inputs_idx)
        for idx, input1 in enumerate(ordered_inputs[:-1]):
            for input2 in ordered_inputs[idx+1:]:
                A[input1, input2].ub = 0
                r[input1, input2].ub = 0
                dis[input1, input2].lb = N
                for v in range(N):
                    if v != input1 and v != input2:
                        z[input1, input2, v].ub = 0

    # no edge or path between output nodes (same reasoning as for inputs)
    if len(outputs_idx) > 1:
        ordered_outputs = sorted(outputs_idx)
        for idx, output1 in enumerate(ordered_outputs[:-1]):
            for output2 in ordered_outputs[idx+1:]:
                A[output1, output2].ub = 0
                r[output1, output2].ub = 0
                dis[output1, output2].lb = N
                for v in range(N):
                    if v != output1 and v != output2:
                        z[output1, output2, v].ub = 0

    # encoding node labels
    if Ln is not None:
        # label matrix for nodes: Fn[v, l] = 1 iff node v has label l
        Fn = m.addVars(N, Ln, vtype=GRB.BINARY, name="Fn")

        # each node only has one label
        for v in range(N):
            m.addConstr(Fn.sum(v, "*") == 1)

        # label of input node must be ``IN'' (label 0), and only inputs may use it
        for input in inputs_idx:
            Fn[input, 0].lb = 1
        for v in list(set(range(N)) - set(inputs_idx)):
            Fn[v, 0].ub = 0

        # label of output node must be ``OUT'' (label Ln - 1), and only outputs may use it
        for output in outputs_idx:
            Fn[output, Ln - 1].lb = 1
        for v in list(set(range(N)) - set(outputs_idx)):
            Fn[v, Ln - 1].ub = 0

    # encoding edge labels
    if Le is not None:
        # label matrix for edges: Fe[u, v, l] = 1 iff edge u->v has label l
        Fe = m.addVars(N, N, Le, vtype=GRB.BINARY, name="Fe")

        # an edge has (exactly one) label if and only if it exists; self-loops get no label
        for u in range(N):
            for v in range(N):
                if u != v:
                    m.addConstr(Fe.sum(u, v, "*") == A[u, v])
                else:
                    for l in range(Le):
                        Fe[u, v, l].ub = 0

    m.update()


# formulate graph kernels
# m: Gurobi model
# G: a list of graphs as prior samples
# N: size of graph
# beta_t: LCB coefficient
# GraphGP: graph GP
def MIP_graph_kernel(m, G, N, beta_t, GraphGP):
    """Encode the graph-GP posterior and its LCB acquisition into MIP model ``m``.

    The candidate graph ``g`` is the one encoded by the variables ``A``,
    ``dis`` and ``r`` already in ``m``, plus ``Fn``/``Fe`` when labels are
    present. This function adds:

    - the kernel features of ``g``;
    - the covariance ``K`` between ``g`` and every prior sample in ``G``;
    - the self-covariance ``Kxx``;
    - the posterior mean and standard deviation.

    It then sets the objective ``mean - beta_t * std`` (minimized). Finally it
    excludes every prior sample of size ``N`` from the feasible set.

    Args:
        m: Gurobi model already containing the graph encoding for ``N`` nodes.
        G: non-empty list of prior sample graphs. ``G[0]`` determines whether
            node/edge labels are used and how many there are.
        N: number of nodes of the candidate graph.
        beta_t: LCB exploration coefficient.
        GraphGP: fitted graph GP. Its ``kernel`` hyper-parameters are read, and
            ``get_K_constants()`` supplies the Cholesky factor ``L`` of the
            training covariance, ``L^{-1} y`` and an exponent ``shift``.

    Raises:
        ValueError: if the kernel variance is zero, or if the ``"SP"`` kernel is
            requested but the samples carry no node labels.
    """
    # retrieve variables in graph encoding
    A = get_var(m, "A", (N, N))
    dis = get_var(m, "dis", (N, N))
    r = get_var(m, "r", (N, N))

    # number of samples
    T = len(G)
    # get number of node/edge labels (0 means "no labels")
    Ln = G[0].Ln if G[0].node_attr is not None else 0
    Le = G[0].Le if G[0].edge_attr is not None else 0

    # retrieve kernel information from graph GP
    kernel_type = GraphGP.kernel.kernel_type
    exp_option = GraphGP.kernel.exp_option
    alpha = GraphGP.kernel.alpha.numpy()
    beta = GraphGP.kernel.beta.numpy() if Ln else GraphGP.kernel.beta
    gamma = GraphGP.kernel.gamma.numpy() if Le else GraphGP.kernel.gamma
    variance = (
        GraphGP.kernel.variance.numpy() if exp_option else GraphGP.kernel.variance
    )

    # Validate the configuration before any variable/constraint is added, so a
    # bad configuration leaves ``m`` untouched.
    if variance == 0:
        raise ValueError("Variance is zero!")
    if kernel_type == "SP" and not Ln:
        # The labeled shortest-path kernel needs node labels; without them the
        # linking constraints below would force every r[u, v] to 0, which
        # contradicts r[v, v] = 1 and makes the model infeasible.
        raise ValueError("No node labels!")

    # retrieve variables for node labels if node label exists
    if Ln:
        Fn = get_var(m, "Fn", (N, Ln))
    # retrieve variables for edge labels if edge label exists
    if Le:
        Fe = get_var(m, "Fe", (N, N, Le))

    # Cholesky decomposition of KXX
    # KXX = L @ L^T
    # LinvY = L^{-1} @ Y
    L, LinvY, shift = GraphGP.get_K_constants()

    # K: the covariance between g and G[i]
    K = m.addMVar(T, vtype=GRB.CONTINUOUS, name="K")

    # d[u, v, s]: indicate if dis[u, v] = s (one-hot encoding of the distance)
    d = m.addVars(N, N, N + 1, vtype=GRB.BINARY, name="d")
    for u in range(N):
        for v in range(N):
            # sum to one constraints
            m.addConstr(d.sum(u, v, "*") == 1)
            # link d[u, v, s] with dis[u, v] (the s = 0 term contributes nothing)
            expr = 0
            for s in range(1, N + 1):
                expr += s * d[u, v, s]
            m.addConstr(dis[u, v] == expr)
            # dis[u, v] = N if and only if u cannot reach v
            m.addConstr(1 - d[u, v, N] == r[u, v])

    # shortest-path kernel without labels
    if kernel_type == "SSP":
        # D[s]: number of shortest paths with length s
        D = m.addVars(N, vtype=GRB.INTEGER, lb=0, ub=N * N, name="D")
        for s in range(N):
            # link D[s] with d[u, v, s]
            expr = 0
            for u in range(N):
                for v in range(N):
                    expr += d[u, v, s]
            m.addConstr(D[s] == expr)
        # ID[s, c]: indicate if D[s] = c (needed to express D[s]^2 linearly in Kxx)
        ID = m.addVars(N, N * N + 1, vtype=GRB.BINARY, name="ID")
        for s in range(N):
            # sum to one constraint
            m.addConstr(ID.sum(s, "*") == 1)
            # link ID[s, c] with D[s]
            expr = 0
            for c in range(N * N + 1):
                expr += c * ID[s, c]
            m.addConstr(D[s] == expr)
    # shortest-path kernel with node labels (Ln > 0 is guaranteed by the check above)
    elif kernel_type == "SP":
        # p[u, v, s, l1, l2]: indicate if u has label l1, v has label l2, and dis[u, v] = s
        p = m.addVars(N, N, N, Ln, Ln, vtype=GRB.BINARY, name="p")
        for u in range(N):
            for v in range(N):
                expr = 0
                for s in range(N):
                    for l1 in range(Ln):
                        for l2 in range(Ln):
                            # link p[u, v, s, l1, l2] with Fn[u, l1], Fn[v, l2] and
                            # d[u, v, s]: p is the logical AND of the three binaries
                            m.addConstr(
                                3 * p[u, v, s, l1, l2]
                                <= Fn[u, l1] + Fn[v, l2] + d[u, v, s]
                            )
                            m.addConstr(
                                p[u, v, s, l1, l2]
                                >= Fn[u, l1] + Fn[v, l2] + d[u, v, s] - 2
                            )
                            expr += p[u, v, s, l1, l2]
                # the shortest path from u to v exists if and only if r[u, v] = 1
                m.addConstr(expr == r[u, v])

        # P[s, l1, l2]: number of shortest paths with length s and label l1, l2
        P = m.addVars(N, Ln, Ln, vtype=GRB.INTEGER, lb=0, ub=N * N, name="P")
        for s in range(N):
            for l1 in range(Ln):
                for l2 in range(Ln):
                    # link P[s, l1, l2] with p[u, v, s, l1, l2]
                    expr = 0
                    for u in range(N):
                        for v in range(N):
                            expr += p[u, v, s, l1, l2]
                    m.addConstr(P[s, l1, l2] == expr)

        # IP[s, l1, l2, c]: indicate if P[s, l1, l2] = c (to express P^2 linearly in Kxx)
        IP = m.addVars(N, Ln, Ln, N * N + 1, vtype=GRB.BINARY, name="IP")
        for s in range(N):
            for l1 in range(Ln):
                for l2 in range(Ln):
                    # sum to one constraint
                    m.addConstr(IP.sum(s, l1, l2, "*") == 1)
                    # link IP[s, l1, l2, c] with P[s, l1, l2]
                    expr = 0
                    for c in range(N * N + 1):
                        expr += c * IP[s, l1, l2, c]
                    m.addConstr(P[s, l1, l2] == expr)

    # introduce auxiliary variables for using exponential constraints later
    if exp_option:
        # linear kernels
        R = m.addVars(T, vtype=GRB.CONTINUOUS, lb=-GRB.INFINITY, name="R")
        # exponential kernels
        ER = m.addVars(T, vtype=GRB.CONTINUOUS, name="ER")

    if Ln:
        # S[l]: number of nodes with label l
        S = m.addVars(Ln, vtype=GRB.INTEGER, lb=0, ub=N, name="S")
        # IS[l, c]: indicate if S[l] = c (to express S[l]^2 linearly in Kxx)
        IS = m.addVars(Ln, N + 1, vtype=GRB.BINARY, name="IS")
        for l in range(Ln):
            # link S[l] with Fn[v, l]
            expr = 0
            for v in range(N):
                expr += Fn[v, l]
            m.addConstr(S[l] == expr)
            # sum to one constraints
            m.addConstr(IS.sum(l, "*") == 1)
            # link IS[l, c] with S[l]
            expr = 0
            for c in range(N + 1):
                expr += c * IS[l, c]
            m.addConstr(S[l] == expr)

    for i in range(T):
        Kp, Kn, Ke = 0, 0, 0
        expr = 0
        # covariance over shortest paths
        if kernel_type == "SSP":
            for s in range(min(N, G[i].N)):
                Kp += G[i].D[s] * D[s]
        else:
            for s in range(min(N, G[i].N)):
                for l1 in range(Ln):
                    for l2 in range(Ln):
                        Kp += G[i].P[(s, l1, l2)] * P[s, l1, l2]
        normalization_path = (N * (N - 1) / 2) * (G[i].N * (G[i].N - 1) / 2)
        expr += alpha * Kp / normalization_path

        # covariance over node features
        if Ln:
            for l in range(Ln):
                Kn += G[i].S[l] * S[l]
            normalization_node = N * G[i].N
            expr += beta * Kn / normalization_node

        # covariance over edge features: count edges of g carrying the same label as in G[i]
        if Le:
            for u in range(N):
                for v in range(N):
                    if G[i].edge_attr[u, v] is not None:
                        Ke += Fe[u, v, G[i].edge_attr[u, v]]
            normalization_edge = N * (N - 1) / 2
            expr += gamma * Ke / normalization_edge

        if exp_option:
            # K[i] = variance * exp(expr), written as exp(shift) * exp(expr - shift)
            # so that the general exponential constraint sees a shifted argument
            m.addConstr(R[i] == expr - shift)
            m.addGenConstrExp(R[i], ER[i], "FunNonlinear=1")
            m.addConstr(K[i] == variance * np.exp(shift) * ER[i])
        else:
            m.addConstr(K[i] == variance * expr)

    Kp, Kn, Ke = 0, 0, 0
    expr = 0
    # Kxx over shortest paths (squared counts via the one-hot indicators)
    if kernel_type == "SSP":
        for s in range(N):
            for c in range(N * N + 1):
                Kp += c * c * ID[s, c]
    elif kernel_type == "SP":
        for s in range(N):
            for l1 in range(Ln):
                for l2 in range(Ln):
                    for c in range(N * N + 1):
                        Kp += c * c * IP[s, l1, l2, c]
    normalization_path = (N * (N - 1) / 2) * (N * (N - 1) / 2)
    expr += alpha * Kp / normalization_path
    # Kxx over node features
    if Ln:
        for l in range(Ln):
            for c in range(N + 1):
                Kn += c * c * IS[l, c]
        normalization_node = N * N
        expr += beta * Kn / normalization_node
    # Kxx over edge features: every existing edge matches its own label
    if Le:
        for u in range(N):
            for v in range(N):
                if u != v:
                    Ke += A[u, v]
        normalization_edge = N * (N - 1) / 2
        expr += gamma * Ke / normalization_edge

    # Kxx: the covariance between g and g
    Kxx = m.addMVar(1, vtype=GRB.CONTINUOUS, name="Kxx")
    if exp_option:
        ERxx = m.addMVar(1, vtype=GRB.CONTINUOUS, name="ERxx")
        Rxx = m.addMVar(1, vtype=GRB.CONTINUOUS, lb=-GRB.INFINITY, name="Rxx")
        m.addConstr(Rxx == expr - shift)
        m.addGenConstrExp(Rxx, ERxx, "FunNonlinear=1")
        m.addConstr(Kxx == variance * np.exp(shift) * ERxx)
    else:
        m.addConstr(Kxx == variance * expr)

    # posterior mean
    mean = m.addMVar(1, vtype=GRB.CONTINUOUS, lb=-GRB.INFINITY, name="mean")
    # posterior standard deviation
    std = m.addMVar(1, vtype=GRB.CONTINUOUS, name="std")
    # set objective function, i.e., LCB
    m.setObjective(mean - beta_t * std, GRB.MINIMIZE)

    # LinvK = L^{-1} @ K => L @ LinvK = K
    LinvK = m.addMVar(T, lb=-GRB.INFINITY, vtype=GRB.CONTINUOUS, name="LinvK")
    m.addConstr(L @ LinvK == K)
    # mean = K^T @ KXX^{-1} @ y = (L^{-1} @ K)^T @ (L^{-1} @ y) = LinvK^T @ LinvY
    m.addConstr(LinvK.T @ LinvY == mean)
    # variance = Kxx - K^T @ (KXX)^{-1} @ K = Kxx - (LinvK)^T @ LinvK >= std^2
    m.addConstr(Kxx - LinvK.T @ LinvK >= std * std)

    m.update()

    # do not consider sampled graphs: for every prior sample of the same size,
    # require the candidate to differ in at least one adjacency entry or label
    for g in G:
        if g.N != N:
            continue
        expr = 0
        for u in range(N):
            for v in range(N):
                if g.A[u][v]:
                    expr += 1 - A[u, v]
                else:
                    expr += A[u, v]
        if Ln:
            for v in range(N):
                l = g.node_attr[v]
                expr += 1 - Fn[v, l]
        if Le:
            for u in range(N):
                for v in range(N):
                    l = g.edge_attr[u, v]
                    if l is not None:
                        expr += 1 - Fe[u, v, l]
        m.addConstr(expr >= 1)
    m.update()


def MIP_NASBench101(G, beta_t, GraphGP, N=7, random=False):
    """Build the MIP for a NAS-Bench-101 cell: a node-labeled DAG on ``N`` nodes.

    The cell has the following structure:

    - ``Ln = 5`` node labels; ``MIP_Cell`` reserves label 0 for the input node
      and label ``Ln - 1`` for the output node;
    - at most 9 edges;
    - symmetry-breaking constraints on the node ordering.

    Unless ``random`` is true, the graph-GP kernel and the LCB objective are
    added via ``MIP_graph_kernel``.

    Args:
        G: prior sample graphs (only used when ``random`` is false).
        beta_t: LCB coefficient.
        GraphGP: fitted graph GP (only used when ``random`` is false).
        N: number of nodes in the cell.
        random: if true, return the structural model only (no kernel/objective).

    Returns:
        The Gurobi model.
    """
    # NAS_Bench_101
    # 2: 1, 3: 6, 4: 84, 5: 2441, 6: 62010, 7: 359082
    # (presumably the number of distinct cells per node count — verify)

    m = MIP_graph(N)

    # node-labeled DAG
    Ln = 5
    MIP_Cell(m, N, Ln=Ln, Le=None)

    # retrieve variables in graph encoding
    A = get_var(m, "A", (N, N))
    r = get_var(m, "r", (N, N))
    # Fn was created by MIP_Cell with shape (N, Ln). Fetching it as (N, N) would
    # miss labels >= N (KeyError below whenever N < Ln) and look up names that
    # do not exist whenever N > Ln.
    Fn = get_var(m, "Fn", (N, Ln))

    # maximal number of edges
    E = 9
    expr = 0
    for u in range(N):
        for v in range(u + 1, N):
            expr += A[u, v]
    m.addConstr(expr <= E)

    # symmetry breaking
    # expr_f[u]: binary number encoding which later nodes u can reach;
    # nodes are ordered so that expr_f is non-increasing.
    coef = [2 ** i for i in range(N - 1, -1, -1)]
    expr_f = []
    for u in range(N):
        expr = 0
        for v in range(u + 1, N):
            expr += coef[v] * r[u, v]
        expr_f.append(expr)
    for v in range(N - 1):
        m.addConstr(expr_f[v] >= expr_f[v + 1])

    # expr_g[u]: binary number encoding which earlier nodes can reach u
    expr_g = []
    for u in range(N):
        expr = 0
        for v in range(u):
            expr += coef[v] * r[v, u]
        expr_g.append(expr)

    # expr_l[v]: index of the label of node v
    expr_l = []
    for v in range(N):
        expr = 0
        for l in range(Ln):
            expr += l * Fn[v, l]
        expr_l.append(expr)

    # Lexicographic tie-breaking for u < v. If expr_f ties, require
    # expr_g[v] <= expr_g[u]. If expr_g also ties, require
    # expr_l[u] <= expr_l[v]. The big-M terms deactivate a constraint once an
    # earlier key differs.
    for u in range(N):
        for v in range(u + 1, N):
            m.addConstr(expr_g[v] <= expr_g[u] + (1 << N) * (expr_f[u] - expr_f[v]))
            m.addConstr(
                expr_l[u]
                <= expr_l[v]
                + Ln * ((1 << N) * (expr_f[u] - expr_f[v]) + expr_g[u] - expr_g[v])
            )
    m.update()
    if not random:
        MIP_graph_kernel(m=m, G=G, N=N, beta_t=beta_t, GraphGP=GraphGP)

    return m


def MIP_NASBench201(G, beta_t, GraphGP, N=4, random=False):
    """Build the MIP for a NAS-Bench-201 cell: an edge-labeled DAG on ``N`` nodes.

    The cell uses the default single input/output of ``MIP_Cell``, no node
    labels and 4 edge labels. Unless ``random`` is true, the graph-GP kernel
    and the LCB objective are added via ``MIP_graph_kernel``.

    Args:
        G: prior sample graphs (only used when ``random`` is false).
        beta_t: LCB coefficient.
        GraphGP: fitted graph GP (only used when ``random`` is false).
        N: number of nodes in the cell.
        random: if true, return the structural model only (no kernel/objective).

    Returns:
        The Gurobi model.
    """
    model = MIP_graph(N)

    # edge-labeled DAG without node labels
    num_edge_labels = 4
    MIP_Cell(model, N, Ln=None, Le=num_edge_labels)

    if random:
        return model

    MIP_graph_kernel(m=model, G=G, N=N, beta_t=beta_t, GraphGP=GraphGP)
    return model


def MIP_NASBench301(G, beta_t, GraphGP, N=14, random=False):
    """Build the MIP for a NAS-Bench-301 architecture made of two edge-labeled cells.

    The ``N`` nodes are split into two cells of ``N // 2`` nodes each:

    - nodes ``0 .. N//2 - 1`` form the normal cell;
    - the remaining nodes form the reduce cell;
    - no edge or path goes from the normal cell to the reduce cell.

    Within each cell:

    - the first two nodes are inputs and the last node is the output;
    - the inputs have no direct edge to the output;
    - every intermediate node has exactly two incoming edges;
    - every intermediate node has an edge to its cell output, with fixed label 0.

    Edges carry one of ``Le = 7`` labels. Unless ``random`` is true, the
    graph-GP kernel and the LCB objective are added via ``MIP_graph_kernel``.

    NOTE(review): the index arithmetic assumes ``N`` is even — confirm before
    using an odd ``N``.

    Args:
        G: prior sample graphs (only used when ``random`` is false).
        beta_t: LCB coefficient.
        GraphGP: fitted graph GP (only used when ``random`` is false).
        N: total number of nodes of both cells.
        random: if true, return the structural model only (no kernel/objective).

    Returns:
        The Gurobi model.
    """
    m = MIP_graph(N)

    # edge-labeled DAG
    Le = 7
    cell_size = N // 2
    inputs_idx = [0, 1, cell_size, cell_size + 1]
    outputs_idx = [cell_size - 1, N - 1]
    MIP_Cell(m, N, Ln=None, Le=Le, inputs_idx=inputs_idx, outputs_idx=outputs_idx)

    # retrieve variables in graph encoding
    A = get_var(m, "A", (N, N))
    dis = get_var(m, "dis", (N, N))
    z = get_var(m, "z", (N, N, N))
    r = get_var(m, "r", (N, N))
    Fe = get_var(m, "Fe", (N, N, Le))

    # no edge (and hence no path) from normal cell to reduce cell
    for u in range(cell_size):
        for v in range(cell_size, N):
            A[u, v].ub = 0
            r[u, v].ub = 0
            dis[u, v].lb = N
            for w in range(N):
                if w != u and w != v:
                    z[u, v, w].ub = 0

    # no direct edge from a cell's two input nodes to that cell's output
    for offset in (0, cell_size):
        cell_output = cell_size - 1 + offset
        for cell_input in (offset, offset + 1):
            A[cell_input, cell_output].ub = 0

    # all the intermediate nodes are connected to the cell output and have fixed labels
    normal_out, reduce_out = outputs_idx
    for u in range(2, cell_size - 1):
        A[u, normal_out].lb = 1
        A[u + cell_size, reduce_out].lb = 1
        r[u, normal_out].lb = 1
        r[u + cell_size, reduce_out].lb = 1
        dis[u, normal_out].lb = 1
        dis[u, normal_out].ub = 1
        dis[u + cell_size, reduce_out].lb = 1
        dis[u + cell_size, reduce_out].ub = 1
        # the direct edge is the unique shortest path, so no other node lies on it
        for w in range(cell_size):
            if w != u and w != normal_out:
                z[u, normal_out, w].ub = 0
                z[u + cell_size, reduce_out, w + cell_size].ub = 0
        # use first label as arbitrary fixed label
        Fe[u, normal_out, 0].lb = 1
        Fe[u + cell_size, reduce_out, 0].lb = 1

    # number of incoming edges for each intermediate node is 2 (in both cells)
    for v in range(2, cell_size - 1):
        m.addConstr(gp.quicksum(A[u, v] for u in range(v)) == 2)
        m.addConstr(
            gp.quicksum(A[u + cell_size, v + cell_size] for u in range(v)) == 2
        )

    m.update()
    if not random:
        MIP_graph_kernel(m=m, G=G, N=N, beta_t=beta_t, GraphGP=GraphGP)

    return m
