import contextlib
import os
from typing import List, Tuple

from utils import *

try:
    from gurobi_onboarder import init_gurobi

    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()
except Exception:
    # The onboarder is optional: if it is missing or fails, fall back to an empty
    # Gurobi environment and record that Gurobi was not set up through it.
    # (Catching Exception rather than a bare except keeps Ctrl-C / SystemExit working.)
    # NOTE(review): an Env(empty=True) normally needs .start() before models can
    # use it — confirm how the gp.read(..., env=gurobi_venv) callers cope with that.
    gurobi_venv = gp.Env(empty=True)
    GUROBI_FOUND = False

# Device used by the training / inference helpers of Forge.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Forge(nn.Module):
    """Graph auto-encoder with a vector-quantized bottleneck (adapted from VQGraph code).

    Two GraphSAGE layers and a linear layer embed every node of a MIP's
    constraint/variable graph. The embedding is vector-quantized, and the quantized
    vectors are decoded back into node features and into a dense adjacency matrix.
    Optional heads map the quantized vectors to one sigmoid output per node:
    ``prob_head`` (variable membership in a solution, trained with BCE) and
    ``cut_head`` (cut / gap ratio).

    Args:
        input_dim: Width of the input node features (also the node decoder's output width).
        hidden_dim: Minimum internal width; the model uses ``max(input_dim, hidden_dim)``.
        codebook_dim: ``codebook_dim`` forwarded to ``VectorQuantize``.
        dropout_ratio: Dropout probability applied after each encoder layer.
        activation: Activation passed to both ``SAGEConv`` layers.
        norm_type: ``"none"`` (default), ``"batch"`` or ``"layer"``: an extra normalization
            applied after the first GraphSAGE layer (VQGraph convention).
        codebook_size: Number of codes per codebook (also the length of the histogram
            returned by ``mip_to_vector``).
        lamb_edge: Weight of the edge reconstruction loss (overwritten by ``train_unsupervised``).
        lamb_node: Weight of the node-feature reconstruction loss (overwritten by
            ``train_unsupervised``).
        separate_codebooks: Use one codebook for node-feature reconstruction and another
            for edge reconstruction instead of a single shared codebook.
        orthogonal_reg_weight: Forwarded to ``VectorQuantize``.
        prob_head: Add the per-node probability head.
        cut_head: Add the per-node cut-ratio head.
        eval_only: Skip loss computation in ``forward`` (the loss is reported as -1).
    """

    # Accepted values of ``norm_type`` other than "none".
    _NORM_LAYERS = {"batch": nn.BatchNorm1d, "layer": nn.LayerNorm}

    def __init__(self,
                 input_dim=10,
                 hidden_dim=1024,
                 codebook_dim=1024,
                 dropout_ratio=0.4,
                 activation=F.relu,
                 norm_type="none",
                 codebook_size=5000,
                 lamb_edge=1,
                 lamb_node=1,
                 separate_codebooks=False,
                 orthogonal_reg_weight=0.0,
                 prob_head=False,
                 cut_head=False,
                 eval_only=False):
        super().__init__()

        # --- Hyper-parameters copied from the constructor arguments ---
        # TODO consider is_* names for the boolean flags (kept as-is for backward compatibility).
        self.norm_type = norm_type
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size
        self.lamb_edge = lamb_edge
        self.lamb_node = lamb_node
        self.separate_codebooks = separate_codebooks
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.prob_head = prob_head  # BCE head: variable membership in a solution
        self.cut_head = cut_head  # head used to predict the cut value
        self.eval_only = eval_only

        # --- Derived state ---
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.trained = False  # True once weights were trained or loaded
        self.updated_input_dim = max(input_dim, hidden_dim)

        # Sub-modules are registered in the same order as before so parameter order and
        # state_dict layout are unchanged.
        self.dropout = nn.Dropout(dropout_ratio)
        self.layers = nn.ModuleList()  # not used by this class; kept in case external code references it
        self.norms = nn.ModuleList()
        if norm_type != "none":
            if norm_type not in self._NORM_LAYERS:
                raise ValueError(f"norm_type must be 'none', 'batch' or 'layer', got {norm_type!r}")
            self.norms.append(self._NORM_LAYERS[norm_type](self.updated_input_dim))

        # Encoder
        self.graph_layer_1 = SAGEConv(input_dim, self.updated_input_dim, activation=activation, aggregator_type='mean')
        self.graph_layer_2 = SAGEConv(self.updated_input_dim, self.updated_input_dim, activation=activation,
                                      aggregator_type='mean')
        self.linear = nn.Linear(self.updated_input_dim, self.updated_input_dim)

        # Optional prediction heads
        if self.prob_head:
            self.prob_layer = nn.Linear(self.updated_input_dim, 1)
        if self.cut_head:
            self.cut_layer = nn.Linear(self.updated_input_dim, 1)

        self.bn1 = nn.BatchNorm1d(self.updated_input_dim)
        self.bn2 = nn.BatchNorm1d(self.updated_input_dim)
        self.bn3 = nn.BatchNorm1d(self.updated_input_dim)

        # Node decoder: quantized embedding -> input features
        self.decoder_node = nn.Linear(self.updated_input_dim, input_dim)

        # Edge decoders: the adjacency is decoded from the product of two 32-wide projections
        self.decoder_edge_1 = nn.Linear(self.updated_input_dim, 32)
        self.decoder_edge_2 = nn.Linear(self.updated_input_dim, 32)

        def make_vq():
            return VectorQuantize(dim=self.updated_input_dim,
                                  codebook_size=codebook_size,
                                  decay=0.8,
                                  commitment_weight=0.25,
                                  use_cosine_sim=True,
                                  orthogonal_reg_weight=self.orthogonal_reg_weight,
                                  codebook_dim=self.codebook_dim)

        if self.separate_codebooks:
            self.vq_node = make_vq()
            self.vq_edge = make_vq()
        else:
            self.vq = make_vq()

    def forward(self, g, feats, num_cons, num_vars):
        """Encode, quantize and decode one graph.

        Args:
            g: DGL graph with edge weights in ``g.edata['weight']``.
            feats: Node features; must have the node decoder's output shape
                (``F.mse_loss(feats, quantized_node)`` compares them).
            num_cons: Number of leading nodes that are constraints; rows ``num_cons:``
                are variables.
            num_vars: Unused; kept for interface compatibility.

        Returns:
            ``(h_list, h, loss, dist, codebook)`` where

            - ``h_list`` = ``[h, quantized, quantized_node, quantized_edge_1,
              quantized_edge_2]``, followed by the probability-head output when
              ``prob_head`` and then the cut-head output when ``cut_head``.
            - ``h`` is the pre-quantization node embedding.
            - ``loss`` is reconstruction + commitment loss, or ``-1`` when ``eval_only``.
            - ``dist`` is the squeezed distance output of the (node) quantizer.
            - ``codebook`` is the quantizer codebook, or ``(codebook_node, codebook_edge)``
              with separate codebooks.
        """
        # Encoder: GraphSAGE -> GraphSAGE -> Linear, each followed by BatchNorm and dropout
        h = self.graph_layer_1(g, feats, edge_weight=g.edata['weight'])
        h = self.bn1(h)
        if len(self.norms) > 0:
            h = self.norms[0](h)
        h = self.dropout(h)

        h = self.graph_layer_2(g, h, edge_weight=g.edata['weight'])
        h = self.bn2(h)
        h = self.dropout(h)

        h = self.linear(h)
        h = F.relu(h)
        h = self.bn3(h)
        h = self.dropout(h)

        # The same "embedding" h is vector-quantized and then decoded
        if self.separate_codebooks:
            quantized_edge, _, commit_loss_edge, _, codebook_edge = self.vq_edge(h)
            # The node-codebook output is used as "quantized" (h_list[1], heads, dist)
            quantized, _, commit_loss_node, dist, codebook_node = self.vq_node(h)
            commit_loss = commit_loss_edge + commit_loss_node
            codebook = (codebook_node, codebook_edge)
        else:
            quantized, _, commit_loss, dist, codebook = self.vq(h)
            quantized_edge = quantized

        quantized_node = self.decoder_node(quantized)
        quantized_edge_1 = self.decoder_edge_1(quantized_edge)
        quantized_edge_2 = self.decoder_edge_2(quantized_edge)

        h_list = [h, quantized, quantized_node, quantized_edge_1, quantized_edge_2]
        if self.prob_head:
            h_list.append(torch.sigmoid(self.prob_layer(quantized)))
        if self.cut_head:
            h_list.append(torch.sigmoid(self.cut_layer(quantized)))

        if self.eval_only:
            loss = -1
        else:
            # Other losses are calculated in the training code
            loss = self._reconstruction_loss(g, feats, num_cons, quantized_node,
                                             quantized_edge_1, quantized_edge_2) + commit_loss

        # Distance matrix: distance from each node's embedding to each code in the codebook
        dist = torch.squeeze(dist)
        return h_list, h, loss, dist, codebook

    def _reconstruction_loss(self, g, feats, num_cons, quantized_node, edge_1, edge_2):
        """Return the weighted node-feature + edge reconstruction loss."""
        feature_rec_loss = self.lamb_node * F.mse_loss(feats, quantized_node)

        adj = g.adjacency_matrix().to_dense().to(feats.device)
        adj_quantized = self._decode_adjacency(edge_1, edge_2)

        # Only score the variable-rows x constraint-columns (bipartite) block
        adj = adj[num_cons:, :num_cons]
        adj_quantized = adj_quantized[num_cons:, :num_cons]

        # Extra penalty for not recreating existing edges: squared error weighted by adj
        pos_edge_rec_loss = self.lamb_edge * torch.mean(torch.square(adj - adj_quantized) * adj)
        edge_rec_loss = self.lamb_edge * torch.sqrt(F.mse_loss(adj, adj_quantized))
        return feature_rec_loss + (edge_rec_loss + pos_edge_rec_loss)

    @staticmethod
    def _decode_adjacency(edge_1, edge_2, eps=1e-12):
        """Decode a dense adjacency estimate from the two edge-decoder outputs.

        Computes ``(E1 E1^T) (E2 E2^T)^T`` min-max rescaled to [0, 1]. The product is
        evaluated as ``E1 (E1^T E2) E2^T``, which is mathematically identical but never
        multiplies two N x N matrices. ``eps`` keeps the rescaling finite when the
        matrix is constant (the result is then all zeros).
        """
        adj = (edge_1 @ (edge_1.t() @ edge_2)) @ edge_2.t()
        lo, hi = adj.min(), adj.max()
        return (adj - lo) / torch.clamp(hi - lo, min=eps)

    def _head_outputs(self, h_list):
        """Return ``(prob, cut)`` from a ``forward`` h_list; ``None`` for a disabled head.

        Head outputs follow the five fixed entries of h_list, prob first, so positional
        indexing like ``h_list[-1]`` is only right for one particular flag combination.
        """
        idx = 5
        prob = cut = None
        if self.prob_head:
            prob = h_list[idx]
            idx += 1
        if self.cut_head:
            cut = h_list[idx]
        return prob, cut

    def _set_heads(self, prob_head, cut_head):
        """Set the head flags, creating any enabled head layer that does not exist yet.

        New layers are created on the CPU; callers move the model to ``device`` afterwards.
        Existing layers are never removed, so a loaded state_dict keeps matching.
        """
        self.prob_head = prob_head
        self.cut_head = cut_head
        if prob_head and not hasattr(self, 'prob_layer'):
            self.prob_layer = nn.Linear(self.updated_input_dim, 1)
        if cut_head and not hasattr(self, 'cut_layer'):
            self.cut_layer = nn.Linear(self.updated_input_dim, 1)

    @contextlib.contextmanager
    def _inference(self):
        """Run the body as inference.

        Moves the model to ``device`` and switches it to eval mode (it stays in eval
        mode afterwards). Sets ``eval_only`` so no losses are computed and disables
        autograd. ``eval_only`` is restored even if the body raises.
        """
        previous_eval_only = self.eval_only
        self.to(device)
        self.eval()
        self.eval_only = True
        try:
            with torch.no_grad():
                yield
        finally:
            self.eval_only = previous_eval_only

    @staticmethod
    def _percentile_ranks(values, invert=False):
        """Map values to integer ranks in [0, 100) via their empirical CDF.

        Each value x gets ``#(values <= x) / len(values) * 100`` (100 minus that when
        ``invert``). Those scores are min-max rescaled (with a 1e-6 guard), multiplied
        by 100 and truncated to int. Empty input gives an empty array.
        """
        values = np.asarray(values, dtype=float).ravel()
        if values.size == 0:
            return np.zeros(0, dtype=int)
        cdf = np.searchsorted(np.sort(values), values, side='right') / len(values) * 100
        if invert:
            cdf = 100 - cdf
        cdf = (cdf - np.min(cdf)) / (np.ptp(cdf) + 1e-6)
        return (cdf * 100).astype(int)

    def train_unsupervised(self, model_save_path, train_list, epochs=10, steps_per_instance=10, lr=1e-4, log_path=None,
                           start_idx=10, max_nodes=21000):
        """Pre-train the auto-encoder on reconstruction + commitment losses.

        Epochs alternate between emphasizing node-feature reconstruction
        (lamb_node=10, lamb_edge=1) and edge reconstruction (1, 10). The weights
        remain set after training. The state dict is saved to ``model_save_path``
        after every epoch.

        Args:
            model_save_path: Destination of ``torch.save(self.state_dict())``.
            train_list: Indexable sequence of ``(g, features, num_cons, num_vars)`` tuples.
            epochs: Number of passes over ``train_list``.
            steps_per_instance: Optimizer steps per instance per epoch (must be >= 1).
            lr: Adam learning rate.
            log_path: If given, one summary line per epoch is appended to this file.
            start_idx: Instances before this index are skipped (previously hard-coded to 10).
            max_nodes: Instances with more nodes are skipped (GPU memory).

        Returns:
            List of per-epoch mean losses rounded to 3 decimals.
        """
        if steps_per_instance < 1:
            raise ValueError("steps_per_instance must be >= 1")

        self.train()
        self.to(device)
        optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=1e-4)
        main_loss_list = []
        skip_list = set()

        for main_epoch in range(epochs):
            # Alternate between prioritizing node feature reconstruction and edge reconstruction
            if main_epoch % 2 == 0:
                self.lamb_node, self.lamb_edge = 10, 1
            else:
                self.lamb_node, self.lamb_edge = 1, 10

            loss_list = []
            epoch_start = time.time()

            for idx in range(start_idx, len(train_list)):
                g, features, num_cons, num_vars = train_list[idx]

                # Some MIP instances are too large to fit in GPU memory
                if g.num_nodes() > max_nodes:
                    skip_list.add(idx)
                    continue

                g, features = g.to(device), features.to(device)
                for _ in range(steps_per_instance):
                    _, _, loss, _, _ = self(g, features, num_cons, num_vars)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                loss_list.append(loss.item())
                print("\rEpoch: ", main_epoch, "| Loss on Instance", idx, ": ", np.round(loss.item(), 3),
                      "| Mean Loss :", np.round(np.mean(loss_list), 3), end='')
                torch.cuda.empty_cache()

            print("\r")
            print()
            print("------")

            mean_loss = np.round(np.mean(loss_list), 3)
            p_string = ("Epoch: " + str(main_epoch) + "| Mean Loss: " + str(mean_loss) + "+/-"
                        + str(np.round(np.std(loss_list), 3)) + " | Time For Epoch : "
                        + str(np.round(time.time() - epoch_start, 3)) + "s")
            print(p_string, end='\n')
            print("------")
            print()

            torch.save(self.state_dict(), model_save_path)
            main_loss_list.append(mean_loss)
            if log_path is not None:
                # Append only this epoch's line (the whole history used to be re-appended)
                with open(log_path, 'a') as file:
                    file.write(p_string + "\n")

            self.trained = True

        return main_loss_list

    def train_triplets(self, pretrained_path, model_save_path, mips_to_triplet=None, mips_to_triplet_path=None,
                       epochs=10, steps_per_instance=10, lr=1e-5, batch_size=1024):
        """Fine-tune from a pre-trained model with triplet-margin + BCE (+ small reconstruction) losses.

        Args:
            pretrained_path: Checkpoint loaded into a default ``Forge()`` whose parameters
                are copied into this model.
            model_save_path: Destination of the state dict, saved after every epoch.
            mips_to_triplet: Dict mapping an instance (anything ``generate_mip_graph`` accepts)
                to ``{'triplets': (anchor, positive, negative) variable-row indices,
                'y_true': BCE target tensor for the probability head}``.
            mips_to_triplet_path: Pickle file holding that dict (used when the dict is None).
            epochs: Number of passes over the instances.
            steps_per_instance: Optimizer steps per instance per epoch (must be >= 1).
            lr: Adam learning rate.
            batch_size: Triplets are processed in chunks of about this size.

        Raises:
            ValueError: if neither the dict nor its path is given, or steps_per_instance < 1.
        """
        if mips_to_triplet is None and mips_to_triplet_path is None:
            raise ValueError(
                'Pass in either the mips_to_triplet dictionary or the path to the dictionary stored as a pickle file')
        if steps_per_instance < 1:
            raise ValueError("steps_per_instance must be >= 1")

        # Initialize with pretrained model; this stage needs the probability head (and its layer)
        self._set_heads(prob_head=True, cut_head=self.cut_head)
        pre_trained = Forge()
        pre_trained.load_model(pretrained_path)
        copy_params(old_model=pre_trained, new_model=self)
        del pre_trained
        torch.cuda.empty_cache()

        self.to(device)
        self.train()
        triplet_loss = nn.TripletMarginLoss(margin=2, p=2, eps=1e-7, reduction='mean')
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)

        if mips_to_triplet is None:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(mips_to_triplet_path, 'rb') as file:
                mips_to_triplet = pkl.load(file)

        keys = list(mips_to_triplet.keys())

        for epoch in range(epochs):
            epoch_loss, b_epoch_loss, t_epoch_loss, r_epoch_loss = [], [], [], []
            i_num = 0
            for i in keys:
                g, features, num_cons, num_vars = generate_mip_graph(i, graph_features=False)

                # Some MIP instances are too large to fit in GPU
                if g.num_nodes() >= 30000:
                    continue

                g, features = g.to(device), features.to(device)
                y_true = mips_to_triplet[i]['y_true'].to(device)

                for _ in range(steps_per_instance):
                    # Forward pass through VQ graph
                    h_list, logits, loss, _, _ = self(g, features, num_cons, num_vars)

                    logits = logits[num_cons:, :]
                    prob, _ = self._head_outputs(h_list)
                    prob = prob[num_cons:, :]

                    triplets = mips_to_triplet[i]['triplets']

                    # Initially, positive variables are those that appear in the same number of solutions.
                    # After a certain point, we shuffle the positives so that a variable appearing in
                    # ANY solution is considered as a positive triplet.
                    # NOTE(review): this shuffles the caller's triplet array in place.
                    if epoch > epochs // 2:
                        perm_1 = np.random.permutation(triplets.shape[0])
                        perm_2 = np.random.permutation(triplets.shape[0])
                        triplets[:, 0] = triplets[perm_1, 0]
                        triplets[:, 1] = triplets[perm_2, 1]

                    # Number of triplets can be quite large - need to sample to fit in GPU
                    if len(triplets) > 3000000:
                        random_indices = np.random.choice(len(triplets), size=2000000, replace=False)
                        triplets = triplets[random_indices]

                    if len(triplets) < batch_size:
                        t_loss = triplet_loss(logits[triplets[:, 0]], logits[triplets[:, 1]],
                                              logits[triplets[:, 2]])
                    else:
                        t_loss = 0
                        for batch in np.array_split(triplets, int(np.ceil(len(triplets) / batch_size))):
                            t_loss += triplet_loss(logits[batch[:, 0]], logits[batch[:, 1]], logits[batch[:, 2]])

                    bce_loss = F.binary_cross_entropy(prob, y_true)
                    final_loss = (10 * t_loss) + (0.05 * bce_loss) + (0.01 * loss)

                    optimizer.zero_grad()
                    final_loss.backward()
                    optimizer.step()

                print("\r", "(", i_num, "/", len(mips_to_triplet), ") Instance : ", i, "| Triplet Loss : ",
                      np.round(t_loss.item(), 3), "| Recon Loss : ", np.round(loss.item(), 3), "| BCE Loss : ",
                      np.round(bce_loss.item(), 3), end='')

                epoch_loss.append(final_loss.item())
                b_epoch_loss.append(bce_loss.item())
                t_epoch_loss.append(t_loss.item())
                r_epoch_loss.append(loss.item())
                i_num += 1

            print("\nEpoch ", epoch + 1, "| Means | Loss : ", np.mean(epoch_loss), "| Triplet Loss : ",
                  np.mean(t_epoch_loss), "| BCE Loss : ", np.mean(b_epoch_loss), "| Recon Loss : ",
                  np.mean(r_epoch_loss))
            print()

            torch.save(self.state_dict(), model_save_path)
            np.random.shuffle(keys)
            self.trained = True

    def train_lp_gaps(self, pretrained_path, model_save_path, mips_to_gaps=None, mips_to_gaps_path=None, epochs=10,
                      steps_per_instance=10, lr=1e-4):
        """Fine-tune the cut head (plus a small BCE term) from a 'warm_start' checkpoint.

        Args:
            pretrained_path: 'warm_start' checkpoint whose parameters are copied in.
            model_save_path: Destination of the state dict, saved after every epoch.
            mips_to_gaps: Dict mapping an instance to ``{'ratio': target ratio (inverted
                when > 1), 'mip_sol': BCE target per variable}``.
            mips_to_gaps_path: Pickle file holding that dict (used when the dict is None).
            epochs: Number of passes over the instances.
            steps_per_instance: Optimizer steps per instance per epoch.
            lr: Adam learning rate.

        Raises:
            ValueError: if neither the dict nor its path is given.
        """
        if mips_to_gaps_path is None and mips_to_gaps is None:
            raise ValueError(
                'Pass in either the mips_to_gaps dictionary or the path to the dictionary stored as a pickle file')

        # Initialize with pretrained model; both heads (and their layers) are needed here
        self._set_heads(prob_head=True, cut_head=True)
        pre_trained = Forge(prob_head=True)
        pre_trained.load_model(pretrained_path, model_type='warm_start')
        copy_params(old_model=pre_trained, new_model=self)
        del pre_trained
        torch.cuda.empty_cache()

        self.to(device)
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)

        if mips_to_gaps is None:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(mips_to_gaps_path, 'rb') as file:
                mips_to_gaps = pkl.load(file)

        keys = list(mips_to_gaps.keys())

        # forward()'s reconstruction loss is not used in this stage; skip computing it
        previous_eval_only = self.eval_only
        self.eval_only = True
        try:
            for epoch in range(epochs):
                epoch_loss, b_epoch_loss, c_epoch_loss = [], [], []

                for i_num, i in enumerate(keys):
                    g, features, num_cons, num_vars = generate_mip_graph(i, graph_features=False)

                    # Some instances are too big to fit in the GPU
                    if g.num_nodes() > 30000:
                        continue

                    g, features = g.to(device), features.to(device)

                    # Targets are fixed per instance
                    cut_true = mips_to_gaps[i]['ratio']
                    if cut_true > 1:
                        cut_true = 1 / cut_true
                    prob_true = torch.as_tensor(mips_to_gaps[i]['mip_sol'], dtype=torch.float32).to(device)

                    for _ in range(steps_per_instance):
                        optimizer.zero_grad()

                        # Forward pass through the GNN
                        h_list, _, _, _, _ = self(g, features, num_cons, num_vars)
                        prob, cut = self._head_outputs(h_list)
                        cut_pred = torch.mean(cut[num_cons:, :])
                        prob_pred = prob[num_cons:, :]

                        try:
                            bce_loss = F.binary_cross_entropy(prob_pred.flatten(), prob_true.flatten())
                        except (RuntimeError, ValueError):
                            # Target incompatible with the prediction (size or value range): skip this step
                            continue
                        cut_loss = torch.abs(cut_pred - cut_true)

                        loss = cut_loss + (0.01 * bce_loss)
                        loss.backward()
                        optimizer.step()

                        print('', '(', i_num, '/', len(keys), ') |', i, ' | Cut Loss :', cut_loss.item(),
                              'BCE Loss :', bce_loss.item(), end='\r')

                        epoch_loss.append(loss.item())
                        b_epoch_loss.append(bce_loss.item())
                        c_epoch_loss.append(cut_loss.item())

                print("\nEpoch ", epoch + 1, "| Means | Loss : ", np.mean(epoch_loss), "| Cut Loss : ",
                      np.mean(c_epoch_loss), "| BCE Loss : ", np.mean(b_epoch_loss))
                print()

                torch.save(self.state_dict(), model_save_path)
                np.random.shuffle(keys)
                self.trained = True
        finally:
            self.eval_only = previous_eval_only

    def load_model(self, gnn_model_path, model_type='none'):
        """Load weights from ``gnn_model_path`` unless this model is already trained/loaded.

        ``model_type`` sets the head flags the checkpoint was trained with. The values
        are 'unsupervised' (no heads), 'warm_start' (prob head) and 'lp_gap' (prob +
        cut heads); any other value keeps the current flags. Missing head layers are
        created before loading.

        NOTE: once ``self.trained`` is True this is a no-op, even for a different path.
        """
        if self.trained:
            return

        head_flags = {
            'unsupervised': (False, False),
            'warm_start': (True, False),
            'lp_gap': (True, True),
        }
        if model_type in head_flags:
            self._set_heads(*head_flags[model_type])

        self.to(device)
        self.load_state_dict(torch.load(gnn_model_path, map_location=device))
        self.trained = True

    def mip_to_vector(self, mip_instance, gnn_model_path=None) -> Tuple[np.ndarray, torch.Tensor, torch.Tensor]:
        """Embed one MIP instance.

        Args:
            mip_instance: Path (str or os.PathLike) to an LP or MPS file, or a Gurobi model object.
            gnn_model_path: Optional trained model to load before the forward pass.

        Returns:
            ``(mip_vec, cons_emb, var_emb)``:

            - ``mip_vec`` is a float histogram of length ``codebook_size``. It counts how
              many nodes are assigned (argmin of ``dist``) to each code.
            - ``cons_emb`` and ``var_emb`` are the quantized embeddings (``h_list[1]``) of
              the first ``num_cons`` nodes and of the remaining nodes. They are torch
              tensors on the model device, computed without autograd.
        """
        if gnn_model_path is None and not self.trained:
            print("Warning: Using an untrained randomly initialized model. "
                  "Pass in a trained model through gnn_model_path.")

        if gnn_model_path is not None:
            self.load_model(gnn_model_path)

        if isinstance(mip_instance, os.PathLike):
            mip_instance = os.fspath(mip_instance)
        gurobi_object = not isinstance(mip_instance, str)

        with self._inference():
            # Process LP or MPS files or a Gurobi object into DGL format
            g, features, num_cons, num_vars = generate_mip_graph(mip_instance,
                                                                 graph_features=False,
                                                                 gurobi_object=gurobi_object)
            h_list, _, _, distances, _ = self(g.to(device), features.to(device), num_cons, num_vars)

        # Distribution of the codes that constraints and variables are assigned to
        assigned_codes = torch.argmin(distances, dim=1).cpu().numpy()
        mip_vec = np.bincount(assigned_codes, minlength=self.codebook_size).astype(float)

        quantized = h_list[1]
        return mip_vec, quantized[:num_cons], quantized[num_cons:]

    def mip_to_lp_cut(self, mip_instance_path, gnn_model_path=None, prob_type='SC', return_metadata=False, threads=1):
        """Predict an objective cut value from the cut head and the LP relaxation.

        The cut head's mean output over variables is the predicted ratio ``r``. It is
        combined with the LP-relaxation objective ``z`` (solved with ``threads`` threads):

        - 'SC': ``r *= 1.02``, cut = ``z + z * (1 - r)``
        - 'MVC': cut = ``z + z * (1 - r)``
        - 'GISP': ``r *= 1.2``, cut = ``z * r``
        - 'CA': ``r *= 1.05``, cut = ``z * r``

        Returns:
            The cut value, or ``(cut_val, cut_ratio, lp_objval)`` when ``return_metadata``.

        Raises:
            ValueError: for an unsupported ``prob_type``.
            RuntimeError: if the model has no cut head.
        """
        if prob_type not in ('SC', 'MVC', 'GISP', 'CA'):
            raise ValueError(f"Unsupported prob_type {prob_type!r}; expected 'SC', 'MVC', 'GISP' or 'CA'")

        if gnn_model_path is None and not self.trained:
            print("Warning: Using an untrained randomly initialized model. "
                  "Pass in a trained model through gnn_model_path.")

        if gnn_model_path is not None:
            self.load_model(gnn_model_path, model_type='lp_gap')

        if not self.cut_head:
            raise RuntimeError("mip_to_lp_cut needs the cut head; load an 'lp_gap' model.")

        with self._inference():
            g, features, num_cons, num_vars = generate_mip_graph(mip_instance_path, graph_features=False)
            h_list, _, _, _, _ = self(g.to(device), features.to(device), num_cons, num_vars)
            _, cut = self._head_outputs(h_list)
            cut_ratio = torch.mean(cut[num_cons:]).item()

        # Read and solve the LP relaxation to generate the initial objective value.
        # Threads is set on the model so the shared Gurobi env is left untouched.
        lp = gp.read(mip_instance_path, env=gurobi_venv).relax()
        lp.setParam("Threads", threads)
        lp.optimize()
        lp_obj = lp.ObjVal

        if prob_type == 'SC':
            cut_ratio += (0.02 * cut_ratio)  # Add a buffer to the ratio to make sure we are not infeasible
            cut_val = lp_obj + (lp_obj * (1 - cut_ratio))
        elif prob_type == 'MVC':
            cut_val = lp_obj + (lp_obj * (1 - cut_ratio))
        elif prob_type == 'GISP':
            cut_ratio += (0.2 * cut_ratio)
            cut_val = lp_obj * cut_ratio
        else:  # 'CA'
            cut_ratio += (0.05 * cut_ratio)  # Add a buffer to the ratio to make sure we are not infeasible
            cut_val = lp_obj * cut_ratio

        if return_metadata:
            return cut_val, cut_ratio, lp_obj
        return cut_val

    def mip_to_hint(self, mip_instance_path, gnn_model_path=None, debug=0, prob_type='SC') -> List[np.ndarray]:
        """Predict warm-start hints (variables fixed to 1 or 0) for a MIP.

        A variable becomes a hint only when all three sources vote for it:
          1. a short (1 s) Gurobi solve of the instance (the "seed" solution),
          2. the probability head (top / bottom percentiles of its output),
          3. distances between variable embeddings and the seed variables.

        Args:
            mip_instance_path: File readable by ``gp.read``.
            gnn_model_path: Optional weights, loaded as a 'warm_start' model.
            debug: Gurobi OutputFlag for the seed solve; non-zero also prints progress.
            prob_type: Selects the percentiles ('GISP', 'CA', default) and, for 'GISP'/'MVC',
                a wider margin for negative predictions.

        Returns:
            ``[hint_ones, hint_zeros, hint_pri_ones, hint_pri_zeros]``. The first two are
            arrays of variable indices hinted to 1 / 0. The last two are lists of int
            priorities in the same order. All are empty when the seed solve yields no
            solution with any variable at 1 (then no hint can collect 3 votes).

        Raises:
            RuntimeError: if the model has no probability head.
        """
        if gnn_model_path is None and not self.trained:
            print(
                "Warning: Using an untrained randomly initialized model. Pass in a trained model through gnn_model_path.")

        if gnn_model_path is not None:
            self.load_model(gnn_model_path, model_type='warm_start')

        if not self.prob_head:
            raise RuntimeError("mip_to_hint needs the probability head; load a 'warm_start' or 'lp_gap' model.")

        # Probability per variable and variable embeddings (constraint rows come first)
        with self._inference():
            g, features, num_cons, num_vars = generate_mip_graph(mip_instance_path, graph_features=False)
            h_list, logits, _, _, _ = self(g.to(device), features.to(device), num_cons, num_vars)
            prob, _ = self._head_outputs(h_list)
            output = prob[num_cons:].detach().cpu().numpy().ravel()
            logits = logits[num_cons:].detach().cpu().numpy()
            del h_list, prob

        # --- Source 2: probability head percentiles ---
        if prob_type == 'GISP':
            upper_perc, lower_perc = 98, 5
        elif prob_type == 'CA':
            upper_perc, lower_perc = 98, 10
        else:
            upper_perc, lower_perc = 95, 10

        bce_ones = np.where(output >= np.percentile(output, upper_perc))[0]
        bce_zeros = np.where(output <= np.percentile(output, lower_perc))[0]

        # --- Source 1: seed solution from a short Gurobi solve ---
        gurobi_venv.setParam("OutputFlag", debug)
        seed = gp.read(mip_instance_path, env=gurobi_venv)
        # Set on the model (not the shared env) so later solves, e.g. mip_to_lp_cut's LP,
        # do not inherit the 1-second time limit.
        seed.setParam("Threads", 1)
        seed.setParam("LPWarmStart", 2)
        seed.setParam("TimeLimit", 1)
        seed.setParam("MIPFocus", 0)

        if debug:
            print("Optimizing Seed MIP\n")
        seed.optimize()

        no_hints = [np.array([], dtype=np.intp), np.array([], dtype=np.intp), [], []]
        if seed.SolCount == 0:
            return no_hints
        seed_xn = np.array([v.X for v in seed.getVars()])

        # Solver values carry small integrality errors, so compare with a tolerance
        seed_sols = np.where(np.isclose(seed_xn, 1.0, rtol=0.0, atol=1e-6))[0]
        seed_zeros = np.where(np.isclose(seed_xn, 0.0, rtol=0.0, atol=1e-6))[0]
        if len(seed_sols) == 0:
            # Without seed ones neither a positive nor a negative hint can reach 3 votes
            return no_hints

        if debug:
            print("Finding Nearest Neighbors\n")

        # --- Source 3: embedding distances ---
        num_neigh = min(50, len(logits))
        nbrs = NearestNeighbors(n_neighbors=num_neigh, algorithm='kd_tree', p=2).fit(logits)
        distances, indices = nbrs.kneighbors(logits[seed_sols])

        # Normalize distances to between 0 and 1.
        # NOTE(review): if all neighbor distances are equal, np.ptp is 0 and this yields NaN,
        # which disables both distance-based votes below.
        distances = (distances - np.min(distances)) / np.ptp(distances)
        max_distance = np.max(distances)

        # Negative candidates: seed zeros farther than the (scaled) maximum normalized neighbor
        # distance from at least one seed one.
        # NOTE(review): the threshold comes from *normalized* kNN distances but is compared
        # against raw embedding distances — confirm this is intended.
        delta = 0.3 if prob_type in ('GISP', 'MVC') else 0.2
        threshold = max_distance + (delta * max_distance)
        seed_emb = logits[seed_sols]
        is_far = np.array([np.linalg.norm(seed_emb - logits[i], axis=1).max() > threshold for i in seed_zeros],
                          dtype=bool)
        pred_zeros = seed_zeros[is_far]

        # Positive candidates: the seed ones plus every neighbor within normalized distance 0.3
        pred_hints = np.union1d(seed_sols, indices[distances <= 0.3])

        # Ensure there are no overlapping nodes between positive and negative predictions
        overlap = np.intersect1d(pred_hints, pred_zeros)
        pred_hints = np.setdiff1d(pred_hints, overlap)
        pred_zeros = np.setdiff1d(pred_zeros, overlap)

        # Votes from the 3 sources
        num_model_vars = len(seed_xn)
        final_ones = np.zeros(num_model_vars)
        final_ones[pred_hints] += 1
        final_ones[seed_sols] += 1
        final_ones[bce_ones] += 1

        final_zeros = np.zeros(num_model_vars)
        final_zeros[pred_zeros] += 1
        final_zeros[seed_zeros] += 1
        final_zeros[bce_zeros] += 1

        # Positive hints: the closer to the nearest seed one, the higher the rank
        pos_dist = [np.linalg.norm(seed_emb - logits[i], axis=1).min() for i in pred_hints]
        pos_ranks = self._percentile_ranks(pos_dist, invert=True)

        # Negative hints: the farther from the farthest positive hint, the higher the rank
        hint_emb = logits[pred_hints]
        neg_dist = [np.linalg.norm(hint_emb - logits[i], axis=1).max() if len(hint_emb) else 0.0
                    for i in pred_zeros]
        neg_ranks = self._percentile_ranks(neg_dist)

        # Probability-head hints: ranked by predicted probability
        b_pos_rank = self._percentile_ranks(output[bce_ones])
        b_neg_rank = self._percentile_ranks(output[bce_zeros], invert=True)

        # Combine the ranks per variable (baseline priority 1)
        global_pos_rank = np.ones(num_vars)
        global_pos_rank[pred_hints] = pos_ranks
        global_pos_rank[bce_ones] += b_pos_rank

        global_neg_rank = np.ones(num_vars)
        global_neg_rank[pred_zeros] = neg_ranks
        global_neg_rank[bce_zeros] += b_neg_rank

        hint_ones = np.where(final_ones == 3)[0]
        hint_zeros = np.where(final_zeros == 3)[0]
        hint_pri_ones = [int(x) for x in global_pos_rank[hint_ones]]
        hint_pri_zeros = [int(x) for x in global_neg_rank[hint_zeros]]
        return [hint_ones, hint_zeros, hint_pri_ones, hint_pri_zeros]
