"""Lightning module for training the DIFUSCO TSP model."""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from lightning.pytorch.utilities import rank_zero_info

from difusco.co_datasets.tsp_graph_dataset import TSPGraphDataset
from difusco.pl_meta_model import COMetaModel
from difusco.utils.diffusion_schedulers import InferenceSchedule
from difusco.utils.tsp_utils import (
        TSPEvaluator,
        batched_two_opt_torch,
        merge_tours,
        make_tour_to_graph,
)

from difusco.guides import ValueGuide
import time


class TSPGuidedModel(COMetaModel):
        def __init__(self, param_args=None, classifier=None):
                """Initialise the base model and load the train, test and validation datasets.

                Args:
                        param_args: Forwarded to the parent constructor. ``self.args`` is
                                read afterwards for the storage path, the split file names
                                and the sparse factor.
                        classifier: Accepted but not used by this constructor.
                """
                super().__init__(param_args=param_args, node_feature_only=False)

                def load_split(split_file):
                        # All splits share the same storage root and sparse factor.
                        return TSPGraphDataset(
                                data_file=os.path.join(self.args.storage_path, split_file),
                                sparse_factor=self.args.sparse_factor,
                        )

                self.train_dataset = load_split(self.args.training_split)
                self.test_dataset = load_split(self.args.test_split)
                self.validation_dataset = load_split(self.args.validation_split)

        def forward(self, x, adj, t, edge_index):
                """Delegate to ``self.model``.

                The wrapped model is called as ``model(x, t, adj, edge_index)``, i.e. it
                takes the timestep *before* the adjacency input, unlike this method.
                """
                output = self.model(x, t, adj, edge_index)
                return output

        # classifier training step
        def continuous_classifier_training_step(self, classifier_model, batch, batch_idx):
                """Compute the value-regression loss for one batch under Gaussian diffusion.

                The regression target is the cost of the tour that ``merge_tours``
                decodes from the clean adjacency matrix in ``batch``. The prediction
                is ``self.forward`` applied to

                * the noised adjacency ``xt`` when ``diffusion_type`` is
                  ``"classifier_v1"``, or
                * the ``x_0`` estimate rebuilt from ``classifier_model``'s output when
                  it is ``"classifier_v2"``.

                Args:
                        classifier_model: Model whose ``forward`` output is used as the
                                noise estimate in the ``classifier_v2`` branch; unused for
                                ``classifier_v1``.
                        batch: 5-tuple; the 2nd and 3rd entries are the points and the
                                adjacency matrix.
                        batch_idx: Unused.

                Returns:
                        The MSE loss, which is also logged as ``train/loss``.

                Raises:
                        ValueError: If the model is configured for sparse graphs, or if
                                ``self.args.diffusion_type`` is not one of the two
                                supported classifier variants.
                """
                if self.sparse:
                        # TODO: Implement Gaussian diffusion with sparse graphs
                        raise ValueError("DIFUSCO with sparse graphs are not supported for Gaussian diffusion")

                diffusion_type = self.args.diffusion_type
                if diffusion_type not in ("classifier_v1", "classifier_v2"):
                        # Fail early and clearly; otherwise no prediction would be
                        # produced and the loss computation would hit an unbound name.
                        raise ValueError(
                                f"Unsupported diffusion_type for classifier training: {diffusion_type!r}"
                        )

                _, points, adj_matrix, _, _ = batch
                np_points = points.cpu().numpy()
                batch_size = points.shape[0]

                # Regression target: cost of the tour decoded from the clean x_0.
                # The clone keeps the decoder from touching the tensor that is noised below.
                clean_adj = adj_matrix.clone().cpu().detach().numpy()
                tours, _ = merge_tours(
                        clean_adj,
                        np_points,
                        None,
                        sparse_graph=self.sparse,
                        batch_size=batch_size,
                        guided=True,
                )
                tsp_solver = TSPEvaluator(np_points, batch=True)
                true_cost = tsp_solver.evaluate(torch.tensor(tours))

                # Rescale {0, 1} to roughly {-1, 1} with a small multiplicative jitter.
                adj_matrix = adj_matrix * 2 - 1
                adj_matrix = adj_matrix * (1.0 + 0.05 * torch.rand_like(adj_matrix))
                device = adj_matrix.device

                # Sample from diffusion
                t = np.random.randint(1, self.diffusion.T + 1, adj_matrix.shape[0]).astype(int)
                xt, _ = self.diffusion.sample(adj_matrix, t)
                xt = xt.float().to(device)

                model_points = points.float().to(device)
                model_t = torch.from_numpy(t).float().view(adj_matrix.shape[0]).to(device)

                if diffusion_type == "classifier_v2":
                        # Predict the value from the x_0 estimate instead of from xt.
                        atbar = self.diffusion.alphabar[t]
                        atbar = torch.tensor(atbar).float().to(device).reshape(-1, 1, 1)
                        epsilon = classifier_model.forward(model_points, xt, model_t, None)
                        epsilon = epsilon.squeeze(1)
                        value_input = (xt - torch.sqrt(1.0 - atbar) * epsilon) / torch.sqrt(atbar)
                else:
                        value_input = xt

                cost_pred = self.forward(model_points, value_input.float().to(device), model_t, None)

                # Cast the target to float32, as the discrete training step and the
                # validation step already do, so both mse_loss operands share a dtype.
                target = true_cost.float().reshape(-1, 1).to(device)
                loss = F.mse_loss(cost_pred, target)

                self.log("train/loss", loss)
                return loss

        def discrete_classifier_training_step(self, classifier_model, batch, batch_idx):
                """Compute the value-regression loss for one batch under discrete diffusion.

                Steps, as implemented below:

                1. Noise the one-hot adjacency with ``self.diffusion.sample``.
                2. Run ``classifier_model.forward`` (without gradients) to get
                   ``x0_logit`` and take one ``categorical_posterior`` step with
                   ``classifier_training=True``.
                3. Predict a value with ``self.forward`` and regress it (MSE) onto the
                   cost of the tour decoded by ``merge_tours`` from the resulting ``xt``.

                Args:
                        classifier_model: Model whose ``forward`` produces ``x0_logit``.
                        batch: 5-tuple for dense graphs, 6-tuple for sparse graphs.
                        batch_idx: Unused.

                Returns:
                        The MSE loss, which is also logged as ``train/loss``.
                """
                edge_index = None
                np_edge_index = None
                # NOTE(review): never read in this method.
                stacked_tours = []

                if not self.sparse:
                        # NOTE(review): ``cost`` is unpacked but never read.
                        _, points, adj_matrix, _, cost = batch
                        # A single timestep is drawn and shared by the whole batch.
                        t = np.random.randint(1, self.diffusion.T + 1, 1).repeat(points.shape[0]).astype(int)
                        np_points = points.cpu().numpy()
                        batch_size = points.shape[0]
                else:
                        _, graph_data, point_indicator, edge_indicator, _, _ = batch
                        # Here one timestep is drawn per graph.
                        t = np.random.randint(
                                1, self.diffusion.T + 1, point_indicator.shape[0]
                        ).astype(int)
                        route_edge_flags = graph_data.edge_attr
                        points = graph_data.x
                        edge_index = graph_data.edge_index
                        np_points = points.cpu().numpy()
                        num_edges = edge_index.shape[1]
                        batch_size = point_indicator.shape[0]
                        adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

                adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()

                if self.sparse:
                        adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

                # Sample from diffusion, then rescale to roughly {-1, 1} with a small
                # multiplicative jitter.
                xt = self.diffusion.sample(adj_matrix_onehot, t)
                xt = xt * 2 - 1
                xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

                # NOTE(review): the reshaped one-hot tensor is not read again below.
                adj_matrix_onehot = adj_matrix_onehot.reshape(1, points.shape[0], -1, 2)
                graph_size = points.shape[0] // batch_size


                if self.sparse:
                        xt = xt.reshape(-1)
                        t = torch.from_numpy(t).float()
                        # Repeat each graph's timestep once per edge column of adj_matrix.
                        t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
                        # NOTE(review): xt was already flattened just above.
                        xt = xt.reshape(-1)
                        adj_matrix = adj_matrix.reshape(-1)
                        points = points.reshape(-1, 2)
                        edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
                        edge_index = edge_index.reshape(2, batch_size, -1)
                        # Subtract graph_size * graph position so the numpy edge index
                        # passed to merge_tours is relative to each graph.
                        revise_edge_idx = (
                                torch.arange(0, batch_size)[None, :, None]
                                .repeat(edge_index.shape[0], 1, edge_index.shape[2])
                                .to(edge_index.device)
                        )
                        np_edge_index = (edge_index - graph_size * revise_edge_idx).cpu().numpy()
                        np_edge_index = np_edge_index.reshape(2, -1).astype("int64")
                else:
                        t = torch.from_numpy(t).float().view(adj_matrix.shape[0])
                # Denoise: this part runs without gradients.
                with torch.no_grad():
                        x0_logit = classifier_model.forward(points.float().to(adj_matrix.device),
                                xt.float().to(adj_matrix.device),
                                t.float().to(adj_matrix.device),
                                edge_index)

                        if not self.sparse:
                                        x0_pred_prob = (
                                                x0_logit.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
                                        )
                        else:
                                        x0_pred_prob = x0_logit.reshape((1, points.shape[0], -1, 2)).softmax(
                                                dim=-1
                                        )

                        # Only the first timestep is used for the posterior step.
                        # NOTE(review): the sparse branch draws a different t per graph,
                        # so this looks correct only for the dense branch; confirm.
                        t1 = torch.from_numpy(np.array([t.reshape(-1)[0]]).astype(int)).view(1)
                        t2 = np.array([t.reshape(-1)[0] - 1]).astype(int)
                        xt = (xt > 0).long()
                        # NOTE(review): parm_p is returned but never read afterwards.
                        parm_p, xt = self.categorical_posterior(
                                        t2, t1, x0_pred_prob, xt, classifier=None,classifier_training=True
                                )
                # NOTE(review): the value network is fed x0_logit, not the posterior
                # output; confirm this is the intended input. self.forward is expected
                # to return a pair here, of which only the first element is used.
                value_pred, _ = self.forward(
                        points.float().to(adj_matrix.device),
                        x0_logit.float().to(adj_matrix.device),
                        t.float().to(adj_matrix.device),
                        edge_index,
                )
                # NOTE(review): this reads self.diffusion_space, whereas training_step
                # dispatches on self.args.diffusion_space; presumably the same value.
                if self.diffusion_space == "discrete":
                        xt = (xt > 0).long()
                        adj_mat = xt.float().cpu().detach().numpy() + 1e-6
                else:
                        adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5


                # Decode tours from the post-posterior sample; their cost is the target.
                tours, merge_iterations = merge_tours(
                        adj_mat,
                        np_points,
                        np_edge_index,
                        sparse_graph=self.sparse,
                        batch_size=batch_size,
                        guided=True,
                )

                solved_tours = torch.tensor(tours)
                if self.sparse:
                        splitted_points = np.split(np_points, batch_size, axis=0)
                        tsp_solver = TSPEvaluator(splitted_points, batch=True)
                else:
                        tsp_solver = TSPEvaluator(np_points, batch=True)
                true_cost = tsp_solver.evaluate(solved_tours)
                loss_func = nn.MSELoss()
                loss = loss_func(value_pred, true_cost.reshape(-1, 1).float().to(adj_matrix.device))

                self.log("train/loss", loss)
                return loss

        def training_step(self, batch, batch_idx):
                """Dispatch to the discrete or continuous classifier training step.

                The choice is made on ``self.args.diffusion_space``: ``"discrete"``
                selects the discrete step, any other value the continuous one.
                """
                step_fn = (
                        self.discrete_classifier_training_step
                        if self.args.diffusion_space == "discrete"
                        else self.continuous_classifier_training_step
                )
                return step_fn(self.classifier_model, batch, batch_idx)

        def categorical_denoise_step(
                self, points, xt, t, device, edge_index=None, target_t=None, classifier=None
        ):
                """Take one guided reverse step of the categorical diffusion.

                Args:
                        points: Input passed as ``x`` to ``self.forward`` and to the guide.
                        xt: Current sample.
                        t: Numpy array holding the current timestep (viewed as one element).
                        device: Device on which the networks are evaluated.
                        edge_index: Optional edge index; cast to long when given.
                        target_t: Timestep to move to, forwarded to ``categorical_posterior``.
                        classifier: Optional guide exposing ``gradients(points, t, xt,
                                edge_index)`` and returning ``(value, grad)``. When ``None``,
                                the posterior step is taken without guidance.

                Returns:
                        Whatever ``self.categorical_posterior`` returns for the new sample.
                """
                edge_index_dev = edge_index.long().to(device) if edge_index is not None else None

                with torch.no_grad():
                        t = torch.from_numpy(t).view(1)
                        x0_pred = self.forward(
                                points.float().to(device),
                                xt.float().to(device),
                                t.float().to(device),
                                edge_index_dev,
                        )

                # Turn the logits into per-edge class probabilities over the last axis.
                if not self.sparse:
                        x0_pred_prob = (
                                x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
                        )
                else:
                        x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(
                                dim=-1
                        )

                # The signature advertises classifier=None, so honour it instead of
                # failing on ``None.gradients``; the posterior is then unguided.
                grad = None
                if classifier is not None:
                        _, grad = classifier.gradients(
                                points.to(device),
                                t.float().to(device),
                                xt.float().to(device),
                                edge_index_dev,
                        )

                return self.categorical_posterior(
                        target_t, t, x0_pred_prob, xt, classifier=grad
                )

        def gaussian_denoise_step(
                self, points, xt, t, device, edge_index=None, target_t=None, classifier=None, prev_y=None, pre_grad=None
        ):
                """Take one guided reverse step of the Gaussian diffusion.

                The network prediction is computed without gradients. The guide then
                supplies a value estimate and a gradient, and the gradient is passed
                as ``guidance`` to ``self.gaussian_posterior``.

                Args:
                        points: Input passed as ``x`` to ``self.forward`` and to the guide.
                        xt: Current sample.
                        t: Numpy array holding the current timestep (viewed as one element).
                        device: Device on which the networks are evaluated.
                        edge_index: Optional edge index; cast to long when given.
                        target_t: Timestep to move to.
                        classifier: Guide exposing ``gradients``; must not be ``None``.
                        prev_y: Unused.
                        pre_grad: Unused.

                Returns:
                        Tuple ``(new_xt, value, grad)``.
                """
                timestep = torch.from_numpy(t).view(1)
                timestep_dev = timestep.float().to(device)
                edges_dev = None if edge_index is None else edge_index.long().to(device)

                with torch.no_grad():
                        network_out = self.forward(
                                points.float().to(device),
                                xt.float().to(device),
                                timestep_dev,
                                edges_dev,
                        )

                # The guide receives the prediction before its channel axis is squeezed.
                value, guidance = classifier.gradients(
                        points.to(device),
                        timestep_dev,
                        xt.float().to(device),
                        network_out,
                        self.diffusion.alphabar[timestep],
                        edges_dev,
                        diffusion_type=self.diffusion_type,
                )

                next_xt = self.gaussian_posterior(
                        target_t, timestep, network_out.squeeze(1), xt, guidance=guidance
                )
                return next_xt, value, guidance

        def test_step(self, batch, batch_idx, split="test"):
                """Enable autograd globally, then run the guided test step.

                Gradients are switched on before delegating, presumably because the
                value guide differentiates with respect to the sample.
                """
                torch.set_grad_enabled(True)
                guide_model = self.classifier_model
                metrics = self.classifier_test_step(guide_model, batch, batch_idx, split)
                return metrics

        def classifier_test_step(self, classifier_model, batch, batch_idx, split="test"):
                """Sample tours with value-guided diffusion and log their costs.

                For each of ``args.sequential_sampling`` rounds the method draws noise,
                runs ``args.inference_diffusion_steps`` guided denoising steps, decodes
                tours with ``merge_tours`` and refines them with 2-opt. Costs are
                logged under ``{split}/wo_2opt_cost``, ``{split}/gt_cost`` and
                ``{split}/solved_cost``.

                Args:
                        classifier_model: Module whose ``.model`` is wrapped in a
                                ``ValueGuide`` to provide guidance gradients.
                        batch: 5-tuple for dense graphs, 6-tuple for sparse graphs.
                        batch_idx: Unused.
                        split: Prefix for the logged metric names.

                Returns:
                        Dict holding the ``wo_2opt_cost`` and ``gt_cost`` metrics.

                Raises:
                        ValueError: If ``args.sequential_sampling`` is smaller than 1, in
                                which case no metric could be computed.
                """
                if self.args.sequential_sampling < 1:
                        raise ValueError("sequential_sampling must be at least 1")

                edge_index = None
                np_edge_index = None
                device = batch[-1].device

                classifier = ValueGuide(classifier_model.model)

                if not self.sparse:
                        real_batch_idx, points, adj_matrix, gt_tour, _ = batch
                        np_points = points.cpu().numpy()
                        batch_size = points.shape[0]
                        if batch_size == 1:
                                np_gt_tour = gt_tour.cpu().numpy()[0]
                else:
                        # NOTE(review): this branch never extracts the ground-truth tour
                        # and leaves np_edge_index as None, so the evaluation below looks
                        # unsupported for sparse graphs - confirm before relying on it.
                        real_batch_idx, graph_data, point_indicator, _, _, _ = batch
                        route_edge_flags = graph_data.edge_attr
                        points = graph_data.x
                        edge_index = graph_data.edge_index
                        np_points = points.cpu().numpy()
                        num_edges = edge_index.shape[1]
                        batch_size = point_indicator.shape[0]
                        adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

                stacked_tours = []
                steps = self.args.inference_diffusion_steps

                for _ in range(self.args.sequential_sampling):
                        xt = torch.randn_like(adj_matrix.float())
                        if self.args.parallel_sampling > 1:
                                # Tile to the parallel batch size, then redraw so every
                                # parallel sample starts from independent noise.
                                if not self.sparse:
                                        xt = xt.repeat(self.args.parallel_sampling, 1, 1)
                                else:
                                        xt = xt.repeat(self.args.parallel_sampling, 1)
                                xt = torch.randn_like(xt)

                        # NOTE(review): this checks diffusion_type while the loop below
                        # switches on diffusion_space; confirm both agree in the config.
                        if self.diffusion_type == "gaussian":
                                xt.requires_grad = True
                        else:
                                xt = (xt > 0).long()

                        if self.sparse:
                                xt = xt.reshape(-1)

                        time_schedule = InferenceSchedule(
                                inference_schedule=self.args.inference_schedule,
                                T=self.diffusion.T,
                                inference_T=steps,
                        )

                        # Diffusion iterations
                        y = None
                        grad = None
                        for i in range(steps):
                                t1, t2 = time_schedule(i)
                                t1 = np.array([t1]).astype(int)
                                t2 = np.array([t2]).astype(int)

                                if self.diffusion_space == "continuous":
                                        xt, y, grad = self.gaussian_denoise_step(
                                                points, xt, t1, device, edge_index, target_t=t2, classifier=classifier, prev_y=y, pre_grad=grad
                                        )
                                else:
                                        xt = self.categorical_denoise_step(
                                                points, xt, t1, device, edge_index, target_t=t2, classifier=classifier
                                        )

                        if self.diffusion_space == "continuous":
                                # Map the final sample from roughly [-1, 1] back to [0, 1].
                                adj_mat = (xt * 0.5) + 0.5
                                adj_mat = adj_mat.cpu().detach().numpy()
                        else:
                                adj_mat = xt.float().cpu().detach().numpy() + 1e-6

                        if self.args.save_numpy_heatmap:
                                self.run_save_numpy_heatmap(adj_mat, np_points, real_batch_idx, split)

                        tours, _ = merge_tours(
                                adj_mat,
                                np_points,
                                np_edge_index,
                                sparse_graph=self.sparse,
                                batch_size=batch_size,
                                guided=True,
                        )

                        if batch_size == 1:
                                tsp_solver = TSPEvaluator(np_points[0])
                                # Cost before 2-opt refinement (first decoded tour only).
                                wo_2opt_costs = tsp_solver.evaluate(tours[0])

                                # Refine using 2-opt
                                solved_tours, _ = batched_two_opt_torch(
                                        np_points[0].astype("float64"),
                                        np.array(tours).astype("int64"),
                                        max_iterations=self.args.two_opt_iterations,
                                        device=device,
                                )

                                stacked_tours.append(solved_tours)
                                solved_tours = np.concatenate(stacked_tours, axis=0)

                                gt_cost = tsp_solver.evaluate(np_gt_tour)

                                # Score only the tours gathered so far. Indexing up to
                                # parallel_sampling * sequential_sampling would run past
                                # the stacked array before the final sequential round.
                                all_solved_costs = [
                                        tsp_solver.evaluate(tour) for tour in solved_tours
                                ]
                                best_solved_cost = np.min(all_solved_costs)
                        else:
                                # Cost before 2-opt refinement, averaged over the batch.
                                tsp_solver = TSPEvaluator(np_points, batch=True)
                                tours = torch.tensor(tours)
                                wo_2opt_cost = tsp_solver.evaluate(tours)
                                wo_2opt_costs = wo_2opt_cost.mean()
                                print("without_2opt : ", wo_2opt_cost)

                                # Refine using 2-opt
                                solved_tours, _ = batched_two_opt_torch(
                                        np_points.astype("float64"),
                                        np.array(tours).astype("int64"),
                                        max_iterations=self.args.two_opt_iterations,
                                        device=device,
                                        batch=True,
                                )
                                stacked_tours.append(solved_tours)
                                solved_tours = np.concatenate(stacked_tours, axis=0)
                                solved_tours = torch.tensor(solved_tours)

                                gt_cost = tsp_solver.evaluate(gt_tour).mean()
                                best_solved_cost = tsp_solver.evaluate(solved_tours).mean()

                metrics = {
                        f"{split}/wo_2opt_cost": wo_2opt_costs,
                        f"{split}/gt_cost": gt_cost,
                }
                for k, v in metrics.items():
                        self.log(k, v, on_epoch=True, sync_dist=True)
                self.log(
                        f"{split}/solved_cost",
                        best_solved_cost,
                        prog_bar=True,
                        on_epoch=True,
                        sync_dist=True,
                )
                return metrics

        def value_function_vaild_step(self, classifier_model, batch, batch_idx, split="val"):

                if self.sparse:
                # TODO: Implement Gaussian diffusion with sparse graphs
                        raise ValueError("DIFUSCO with sparse graphs are not supported for Gaussian diffusion")

                _, points, adj_matrix, _, _ = batch
                np_points = points.cpu().numpy()
                batch_size = points.shape[0]
                np_edge_index = None
                step = np.random.randint(1, 10, 1).astype(int)[0]
                if self.args.do_valid_only:
                        cost = cost[step:step+1].repeat(batch_size, 1)
                        adj_matrix = adj_matrix[step:step+1].repeat(batch_size, 1, 1)
                        points = points[step:step+1].reshape(batch_size, 1, 1)
                        t = np.arange(1, self.diffusion.T + 1, step=100).reshape(adj_matrix.shape[0]).astype(int)
                else:
                        t = np.random.randint(1, self.diffusion.T + 1, adj_matrix.shape[0]).astype(int)
             
                adj_mat = adj_matrix.clone()             
                tours, merge_iterations = merge_tours(
                        adj_mat.cpu().detach().numpy(),
                        np_points,
                        np_edge_index,
                        sparse_graph=self.sparse,
                        batch_size=batch_size,
                        guided=True)
                solved_tours = torch.tensor(tours)
                if self.sparse:
                        splitted_points = np.split(np_points, batch_size, axis=0)
                        tsp_solver = TSPEvaluator(splitted_points, batch=True)
                else:
                        tsp_solver = TSPEvaluator(np_points, batch=True)
                true_cost = tsp_solver.evaluate(solved_tours)


                adj_matrix = adj_matrix * 2 - 1
                adj_matrix = adj_matrix * (1.0 + 0.05 * torch.rand_like(adj_matrix))
                # Sample from diffusion

                xt, epsilon = self.diffusion.sample(adj_matrix, t)

                if self.diffusion_type == "classifier_v1":
                        t = torch.from_numpy(t).float().view(adj_matrix.shape[0])
                        cost_pred = self.forward(
                        points.float().to(adj_matrix.device),
                        xt.float().to(adj_matrix.device),
                        t.float().to(adj_matrix.device),
                        None)
                elif self.diffusion_type == "classifier_v2":
                        
                        atbar = self.diffusion.alphabar[t]
                        atbar = torch.tensor(atbar).float().to(adj_matrix.device).reshape(-1, 1, 1)
                        
                        t = torch.from_numpy(t).float().view(adj_matrix.shape[0])
                        # use x_0
                        epsilon = classifier_model.forward(
                                points.float().to(adj_matrix.device),
                                xt.float().to(adj_matrix.device),
                                t.float().to(adj_matrix.device),
                                None,
                        )
                        epsilon = epsilon.squeeze(1)
                        x0_hat = (xt - torch.sqrt(1.0 - atbar) * epsilon) / torch.sqrt(atbar)
                
                        cost_pred = self.forward(
                        points.float().to(adj_matrix.device),
                        x0_hat.float().to(adj_matrix.device),
                        t.float().to(adj_matrix.device),
                        None)
                # print("time : ",t.reshape(-1))
                # print("true cost : ",true_cost.reshape(-1))
                # print("pred cost : ",cost_pred.reshape(-1))

                # # Compute loss
                loss = F.mse_loss(cost_pred, true_cost.float().reshape(-1, 1).to(adj_matrix.device))

                self.log(
                        f"{split}/solved_cost", loss, prog_bar=True, on_epoch=True, sync_dist=True
                )
                # return cost_pred
                # if self.args.diffusion_space == "discrete":
                #         edge_index = None
                #         np_edge_index = None
                #         stacked_tours = []
                # if not self.sparse:
                #         _, points, adj_matrix, gt_tour, _ = batch
                #         # t = np.random.randint(1, self.diffusion.T + 1, points.shape[0]).astype(int)
                #         t = np.random.randint(1, self.diffusion.T + 1, 1).repeat(points.shape[0]).astype(int)
                #         np_points = points.cpu().numpy()
                #         np_gt_tour = gt_tour.cpu().numpy()[0]
                #         batch_size = points.shape[0]
                # else:
                #         _, graph_data, point_indicator, edge_indicator, _, _ = batch
                #         t = np.random.randint(
                #                 1, self.diffusion.T + 1, point_indicator.shape[0]
                #         ).astype(int)
                #         # t = np.random.randint(1, self.diffusion.T + 1, 1).repeat(point_indicator.shape[0]).astype(int)
                #         route_edge_flags = graph_data.edge_attr
                #         points = graph_data.x
                #         edge_index = graph_data.edge_index
                #         np_points = points.cpu().numpy()
                #         num_edges = edge_index.shape[1]
                #         batch_size = point_indicator.shape[0]
                #         adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))

                # # Sample from diffusion
                # adj_matrix_onehot = F.one_hot(adj_matrix.long(), num_classes=2).float()
                # if self.sparse:
                #         adj_matrix_onehot = adj_matrix_onehot.unsqueeze(1)

                # xt = self.diffusion.sample(adj_matrix_onehot, t)
                # xt = xt * 2 - 1
                # xt = xt * (1.0 + 0.05 * torch.rand_like(xt))

                # adj_matrix_onehot = adj_matrix_onehot.reshape(1, points.shape[0], -1, 2)
                # graph_size = points.shape[0] // batch_size

                # if self.sparse:
                #         xt = xt.reshape(-1)

                # if self.sparse:
                #         t = torch.from_numpy(t).float()
                #         t = t.reshape(-1, 1).repeat(1, adj_matrix.shape[1]).reshape(-1)
                #         xt = xt.reshape(-1)
                #         adj_matrix = adj_matrix.reshape(-1)
                #         points = points.reshape(-1, 2)
                #         edge_index = edge_index.float().to(adj_matrix.device).reshape(2, -1)
                #         edge_index = edge_index.reshape(2, batch_size, -1)
                #         revise_edge_idx = (
                #                 torch.arange(0, batch_size)[None, :, None]
                #                 .repeat(edge_index.shape[0], 1, edge_index.shape[2])
                #                 .to(edge_index.device)
                #         )
                #         np_edge_index = (edge_index - graph_size * revise_edge_idx).cpu().numpy()
                #         np_edge_index = np_edge_index.reshape(2, -1).astype("int64")
                # else:
                #         t = torch.from_numpy(t).float().view(adj_matrix.shape[0])
                # with torch.no_grad():
                #         x0_logit = classifier_model.forward(points.float().to(adj_matrix.device),
                #                 xt.float().to(adj_matrix.device),
                #                 t.float().to(adj_matrix.device),
                #                 edge_index)

                #         if not self.sparse:
                #                         x0_pred_prob = (
                #                                 x0_logit.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)
                #                         )
                #         else:
                #                         x0_pred_prob = x0_logit.reshape((1, points.shape[0], -1, 2)).softmax(
                #                                 dim=-1
                #                         )
                #         t_hat = t.clone()
                #         t1 = torch.from_numpy(np.array([t_hat.reshape(-1)[0]]).astype(int)).view(1)

                #         t2 = np.array([t_hat.reshape(-1)[0] - 1]).astype(int)
                #         xt = (xt > 0).long()
                #         parm_p, xt = self.categorical_posterior(
                #                         t2, t1, x0_pred_prob, xt, classifier=None,classifier_training=True
                #                 )
                #         # parm_p = (parm_p + parm_p.transpose(1,2))/2
                #         # print(parm_p)
                # # print(parm_p.shape)
                # # print(xt.shape)
                # # print(t.shape)
                # value_pred, _ = self.forward(
                #         points.float().to(adj_matrix.device),
                #         x0_logit.float().to(adj_matrix.device),
                #         t.float().to(adj_matrix.device),
                #         edge_index,
                # )

                # # # Denoise
                # # value_pred, _ = self.forward(
                # #         points.float().to(adj_matrix.device),
                # #         xt.float().to(adj_matrix.device),
                # #         t.float().to(adj_matrix.device),
                # #         edge_index,
                # # )

                # if self.diffusion_space == "discrete":
                #         xt = (xt > 0).long()
                #         adj_mat = xt.float().cpu().detach().numpy() + 1e-6
                # else:
                #         adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5


                # tours, merge_iterations = merge_tours(
                #         adj_mat,
                #         np_points,
                #         np_edge_index,
                #         sparse_graph=self.sparse,
                #         batch_size=batch_size,
                #         guided=True,
                # )

                # # tours, merge_iterations = merge_tours(
                # #         adj_matrix.cpu().detach().numpy(),
                # #         np_points,
                # #         np_edge_index,
                # #         sparse_graph=self.sparse,
                # #         batch_size=batch_size,
                # #         guided=True,
                # # )

                # solved_tours = torch.tensor(tours)
                # if self.sparse:
                #         splitted_points = np.split(np_points, batch_size, axis=0)
                #         tsp_solver = TSPEvaluator(splitted_points, batch=True)
                # else:
                #         tsp_solver = TSPEvaluator(np_points, batch=True)
                # # print("here3")
                # true_cost = tsp_solver.evaluate(solved_tours)

                # print("true_cost :", true_cost)
                # print("value_pred :", value_pred.squeeze(1))

                # # compute loss
                # loss_func = nn.MSELoss()
                # loss = loss_func(
                #         value_pred, true_cost.reshape(-1, 1).float().to(adj_matrix.device)
                # )

                # self.log(
                #         f"{split}/solved_cost", loss, prog_bar=True, on_epoch=True, sync_dist=True
                # )
                return loss

        def run_save_numpy_heatmap(self, adj_mat, np_points, real_batch_idx, split):
                """Save a heatmap and its point coordinates as ``.npy`` files.

                Both arrays are written below the logger's experiment directory, in
                ``<save_dir>/<name>/<version>/numpy_heatmap``, as
                ``{split}-heatmap-{idx}.npy`` and ``{split}-points-{idx}.npy``.

                Args:
                        adj_mat: Heatmap array handed to ``np.save``.
                        np_points: Point-coordinate array handed to ``np.save``.
                        real_batch_idx: Tensor; its first flattened element is used as
                                the instance index in both file names.
                        split: File-name prefix (e.g. ``"val"``).

                Raises:
                        NotImplementedError: If ``parallel_sampling`` or
                                ``sequential_sampling`` is greater than 1.
                """
                if self.args.parallel_sampling > 1 or self.args.sequential_sampling > 1:
                        raise NotImplementedError("Save numpy heatmap only support single sampling")

                # A logger may report an integer version; os.path.join only accepts
                # str/PathLike components, so normalise it to text first.
                exp_save_dir = os.path.join(
                        self.logger.save_dir, self.logger.name, str(self.logger.version)
                )
                heatmap_path = os.path.join(exp_save_dir, "numpy_heatmap")
                rank_zero_info(f"Saving heatmap to {heatmap_path}")
                os.makedirs(heatmap_path, exist_ok=True)

                # Only the first index in the batch is used to name the files.
                real_batch_idx = real_batch_idx.cpu().numpy().reshape(-1)[0]
                np.save(
                        os.path.join(heatmap_path, f"{split}-heatmap-{real_batch_idx}.npy"),
                        adj_mat,
                )
                np.save(
                        os.path.join(heatmap_path, f"{split}-points-{real_batch_idx}.npy"),
                        np_points,
                )

        def validation_step(self, batch, batch_idx):
                """Route one validation batch to the matching evaluation routine.

                Branches are checked in this order:

                1. ``args.do_train`` or ``args.do_valid_only``: value-function
                   validation via ``value_function_vaild_step``.
                2. ``self.classifier`` is set: ``classifier_test_step``, with
                   autograd switched on first.
                3. ``self.diffusion_type == "classifier"``: value-function
                   validation again.

                Args:
                        batch: Batch produced by the validation dataloader.
                        batch_idx: Index of ``batch`` within the epoch.

                Returns:
                        Whatever the selected step returns; ``None`` when no branch
                        matches.
                """
                if self.args.do_train or self.args.do_valid_only:
                        return self.value_function_vaild_step(
                                self.classifier_model, batch, batch_idx, split="val"
                        )
                elif self.classifier is not None:
                        # NOTE(review): autograd is enabled globally here and never
                        # restored; presumably the classifier step needs gradients
                        # w.r.t. its inputs - confirm.
                        torch.set_grad_enabled(True)
                        return self.classifier_test_step(
                                self.classifier_model, batch, batch_idx, split="val"
                        )
                elif self.diffusion_type == "classifier":
                        # Pass the classifier model explicitly, as the other call sites
                        # do; without it ``batch`` would land in the
                        # ``classifier_model`` slot.
                        return self.value_function_vaild_step(
                                self.classifier_model, batch, batch_idx, split="val"
                        )
