import torch
import numpy as np
from torch.distributions.bernoulli import Bernoulli
from difusco.utils.tsp_utils import TSPEvaluator, batched_two_opt_torch, merge_tours, make_tour_to_graph, merge_tours_parallel
from difusco.utils.mis_utils import mis_decode_np, mis_decode_degree
from difusco.utils.diffusion_schedulers import InferenceSchedule
import torch.nn.functional as F
import scipy.sparse

def get_reward_bonus(xt,xt_before,batch_size,model_args,np_points,np_edge_index,sparse):
    """Return the per-instance tour-length improvement between two latents.

    Both ``xt`` and ``xt_before`` are turned into a heatmap, decoded into tours
    with ``merge_tours`` and scored by summing the Euclidean length of the
    consecutive tour edges.  The bonus is
    ``length(xt_before) - length(xt)``, i.e. positive when the newer latent
    decodes to a shorter tour.

    Args:
        xt: latent after the denoising step (torch tensor).
        xt_before: latent before the denoising step (torch tensor).
        batch_size: number of instances contained in ``xt``.
        model_args: namespace; ``diffusion_type`` and ``tsp_decoder`` are read.
        np_points: numpy array of city coordinates.
        np_edge_index: numpy edge index, or ``None`` (dense callers in this
            file pass ``None``).
        sparse: truthy for the sparse-graph layout; forwarded to
            ``merge_tours`` as ``sparse_graph``.

    Returns:
        A tensor with one bonus value per instance.
    """

    def _to_heatmap(latent):
        # Gaussian latents are rescaled by 0.5 and shifted by 0.5;
        # categorical ones only get a small epsilon added.
        if model_args.diffusion_type == 'gaussian':
            return latent.cpu().detach().numpy() * 0.5 + 0.5
        return latent.float().cpu().detach().numpy() + 1e-6

    def _tour_lengths(tours, dist_mat, n_instances):
        # Sum dist_mat over consecutive (start, end) city pairs of each tour.
        tours = torch.tensor(tours)
        start_cities = tours[:, :-1]
        end_cities = tours[:, 1:]
        ins_size = dist_mat.size()[-1]
        m_indices = torch.arange(n_instances).reshape(-1, 1).repeat(1, ins_size)
        return dist_mat[m_indices, start_cities, end_cities].sum(dim=-1)

    if sparse:
        # The distance matrix does not depend on the latent: build it once.
        points = torch.from_numpy(np_points).reshape([batch_size, -1, 2])
        dist_mat = torch.cdist(points, points)

        rewards = []
        for latent in (xt, xt_before):
            tours, _, _ = merge_tours(
                _to_heatmap(latent), np_points, np_edge_index,
                sparse_graph=sparse, batch_size=batch_size, guided=True,
                tsp_decoder=model_args.tsp_decoder)
            rewards.append(-_tour_lengths(tours, dist_mat, batch_size))
        return rewards[0] - rewards[1]

    # Dense layout: decode both latents in a single doubled batch.
    stacked = torch.concat([xt, xt_before], dim=0)
    points = torch.concat([torch.from_numpy(np_points), torch.from_numpy(np_points)], dim=0)

    # Dense callers pass np_edge_index=None; np.concatenate([None, None])
    # raises ValueError, so None is forwarded unchanged.
    if np_edge_index is None:
        doubled_edge_index = None
    else:
        doubled_edge_index = np.concatenate([np_edge_index, np_edge_index], axis=0)

    tours, _, _ = merge_tours(
        _to_heatmap(stacked), np.concatenate([np_points, np_points], axis=0), doubled_edge_index,
        sparse_graph=sparse, batch_size=batch_size*2, guided=True,
        tsp_decoder=model_args.tsp_decoder)

    distances = _tour_lengths(tours, torch.cdist(points, points), batch_size*2)
    # First half belongs to xt, second half to xt_before.
    return -distances[0:batch_size] + distances[batch_size:]
    
    
def categorical_posterior(model, model_diffusion, model_args, target_t, t, x0_pred_prob, xt,next_xt,batch_t=False,inference=False,sparse=False, point_indicator=None):
    """Sample from the categorical posterior for a given time step.

    See https://arxiv.org/pdf/2107.03006.pdf for details.

    Args:
        model: unused here; kept for interface compatibility.
        model_diffusion: object exposing the cumulative transition tensor
            ``Q_bar``.  It is moved to ``x0_pred_prob.device`` and stored back
            on the object.
        model_args: namespace; ``task`` is read, and ``batch_size`` when
            ``task == 'tsp'`` and ``sparse`` is true.
        target_t: timestep to denoise to, or ``None`` to use ``t - 1``.
        t: current timestep.  When ``batch_t`` is false it must support
            ``.item()``.
        x0_pred_prob: predicted class probabilities with a trailing axis of
            size 2.
        xt: current binary state; one-hot encoded and reshaped to
            ``x0_pred_prob.shape``.
        next_xt: when not ``None`` it is used as the next state instead of a
            fresh Bernoulli sample (its log-probability is still evaluated).
        batch_t: true when ``t``/``target_t`` hold one timestep per sample.
        inference: when true and ``target_t == 0`` the one-step transition is
            replaced by the identity.
        sparse: true for the sparse layout; the returned state is flattened.
        point_indicator: per-graph lengths, iterated for the MIS tasks.

    Returns:
        ``(xt, log_prob, prob)`` where ``prob`` is the clamped posterior
        probability when ``batch_t`` is false and ``target_t == 0``, otherwise
        ``None``.
    """
    diffusion = model_diffusion
    diffusion.Q_bar = diffusion.Q_bar.to(x0_pred_prob.device)

    # Resolve the default target BEFORE converting to Python scalars;
    # calling ``.item()`` on a ``None`` target would raise AttributeError.
    if target_t is None:
        target_t = t - 1
    if not batch_t:
        t, target_t = t.item(), target_t.item()

    # One-step transition from target_t to t.  single: 2x2, batch: bx2x2
    Q_t = (torch.linalg.inv(diffusion.Q_bar[target_t]) @ diffusion.Q_bar[t]).float()
    if inference and target_t == 0:
        Q_t = torch.eye(2).float().to(x0_pred_prob.device)

    Q_bar_t_source = diffusion.Q_bar[t]  # single: 2x2, batch: bx2x2
    Q_bar_t_target = diffusion.Q_bar[target_t]  # single: 2x2, batch: bx2x2

    xt = F.one_hot(xt.long(), num_classes=2).float()
    xt = xt.reshape(x0_pred_prob.shape)
    if batch_t:
        origin_xt_shape = xt.shape
        x_t_target_prob_part_1 = torch.matmul(xt.view(xt.shape[0], -1, xt.shape[-1]), Q_t.permute((0, 2, 1)).contiguous())
        x_t_target_prob_part_1 = x_t_target_prob_part_1.view(origin_xt_shape)
        x_t_target_prob_part_2 = Q_bar_t_target[:, 0][:, None, None, :]
        x_t_target_prob_part_3 = (Q_bar_t_source[:, 0][:, None, None, :] * xt).sum(dim=-1, keepdim=True)
    else:
        x_t_target_prob_part_1 = torch.matmul(xt, Q_t.permute((1, 0)).contiguous())
        x_t_target_prob_part_2 = Q_bar_t_target[0]
        x_t_target_prob_part_3 = (Q_bar_t_source[0] * xt).sum(dim=-1, keepdim=True)

    # Posterior term that uses row 0 of Q_bar, weighted by x0_pred_prob[..., 0].
    x_t_target_prob = (x_t_target_prob_part_1 * x_t_target_prob_part_2) / x_t_target_prob_part_3
    sum_x_t_target_prob = x_t_target_prob[..., 1] * x0_pred_prob[..., 0]

    if batch_t:
        x_t_target_prob_part_2_new = Q_bar_t_target[:, 1][:, None, None, :]
        x_t_target_prob_part_3_new = (Q_bar_t_source[:, 1][:, None, None, :] * xt).sum(dim=-1, keepdim=True)
    else:
        x_t_target_prob_part_2_new = Q_bar_t_target[1]
        x_t_target_prob_part_3_new = (Q_bar_t_source[1] * xt).sum(dim=-1, keepdim=True)
    # Same term with row 1 of Q_bar, weighted by x0_pred_prob[..., 1].
    x_t_source_prob_new = (x_t_target_prob_part_1 * x_t_target_prob_part_2_new) / x_t_target_prob_part_3_new

    sum_x_t_target_prob += x_t_source_prob_new[..., 1] * x0_pred_prob[..., 1]

    dist = Bernoulli(sum_x_t_target_prob.clamp(0, 1))

    prob = None
    if next_xt is None:
        xt = dist.sample()
    else:
        xt = next_xt.reshape(sum_x_t_target_prob.shape)
    log_prob = dist.log_prob(xt)
    if model_args.task == 'tsp':
        if sparse:
            log_prob = log_prob.reshape(model_args.batch_size, -1).mean(dim=1)
        else:
            log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    elif model_args.task == 'mis_sat' or model_args.task == 'mis_er':
        log_probs = []
        idx = 0
        for length in point_indicator:
            log_probs.append(log_prob[:, idx:idx+length, :].sum()/1300)  ## hard coded for MIS SAT scaling
            idx += length
        log_prob = torch.stack(log_probs)

    if not batch_t and target_t == 0:
        prob = sum_x_t_target_prob.clamp(min=0)
    if sparse:
        xt = xt.reshape(-1)

    return xt, log_prob, prob



def test_eval_co(model, model_diffusion, model_args, batch, inference=False,use_env=False, sparse=False, reward_gap=False):
    """Run diffusion inference on one TSP batch and report the best tour cost.

    The batch is unpacked (dense or sparse layout), replicated
    ``model_args.parallel_sampling`` times, denoised for
    ``model_args.inference_diffusion_steps`` steps, decoded into tours with
    ``merge_tours`` and, when ``model_args.reward_2opt`` is set, refined with
    ``batched_two_opt_torch``.  When ``model_args.rewrite`` is set, the best
    tour is re-noised and denoised again for ``model_args.rewrite_steps``
    rounds, keeping the lowest cost seen.

    Args:
        model: network; ``model.device`` is read, and
            ``model.module.duplicate_edge_index`` is called for sparse
            parallel sampling.
        model_diffusion: diffusion object; ``Q_bar`` is read in the rewrite
            loop and the object is forwarded to ``categorical_denoise_step``.
        model_args: namespace of run options (task, diffusion_type,
            parallel_sampling, sequential_sampling, reward_shaping,
            reward_2opt, rewrite, tsplib, ...).
        batch: tuple from the data loader; 5 entries for dense graphs
            (7 when ``model_args.task == 'tsplib'``), 5 or 6 for sparse ones.
        inference: forwarded to ``categorical_denoise_step``; also selects
            how the initial noise is created.
        use_env: when true (and ``inference`` is false) the initial noise
            shape is built from the edge count instead of the edge labels.
        sparse: true for the sparse-graph layout.
        reward_gap: not used inside this function.

    Returns:
        ``([g_best_cost], [previous_best_cost])``: the best cost after the
        optional rewrite rounds and the best cost before them.  When
        ``model_args.tsplib`` is set both are converted to a percentage gap
        relative to ``gt_costs[0]``.
    """
    
    device = model.device
    if not sparse:
        # Dense layout: no edge index is used.
        edge_index = None
        original_edge_index = None
        np_edge_index = None
        if model_args.task == 'tsplib' :
            real_batch_idx, points, adj_matrix, gt_tour, cost,unnorm_points, optimal_cost = batch 
            real_batch_idx, points, adj_matrix, gt_tour, cost = real_batch_idx.to(device), points.to(device), adj_matrix.to(device), gt_tour.to(device), cost.to(device)
            np_points = points.cpu().numpy()
            np_gt_tour = gt_tour.cpu().numpy()


        else:
            real_batch_idx, points, adj_matrix, gt_tour, cost = batch 
            real_batch_idx, points, adj_matrix, gt_tour, cost = real_batch_idx.to(device), points.to(device), adj_matrix.to(device), gt_tour.to(device), cost.to(device)
            np_points = points.cpu().numpy()
            np_gt_tour = gt_tour.cpu().numpy()
    
    else:
        # NOTE(review): a batch whose length is neither 5 nor 6 leaves these
        # names unbound and fails on the next line.
        if len(batch) == 5:
            real_batch_idx, graph_data, point_indicator, gt_tour,cost = batch
        elif len(batch) == 6:
            real_batch_idx, graph_data, point_indicator, edge_indicator, gt_tour, cost = batch
        route_edge_flags = graph_data.edge_attr
        points = graph_data.x
        edge_index = graph_data.edge_index
        num_edges = edge_index.shape[1]
        batch_size = point_indicator.shape[0]
        if use_env :
            # Only the per-instance edge count is kept; it is used below as a shape.
            adj_matrix = np.array([num_edges//batch_size])
        else :
            route_edge_flags = graph_data.edge_attr
            adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))
        points = points.reshape((-1, 2))
        edge_index = edge_index.reshape((2, -1))
        np_edge_index = edge_index.cpu().numpy()
        # Kept un-duplicated for tour2adj in the rewrite loop.
        original_edge_index = edge_index.clone()

        np_points = points.cpu().numpy()
        np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
        
    # import pdb
    # pdb.set_trace()

    # From here on "batch_size" means the number of parallel samples; this
    # overwrites the value derived from point_indicator in the sparse branch.
    batch_size = model_args.parallel_sampling

    unsolved_tours_list, solved_tours_list = [], []

    if model_args.parallel_sampling > 1:
        if not sparse:
            points = points.repeat(model_args.parallel_sampling, 1, 1)
            # np_points = np_points.repeat(model_args.parallel_sampling, 1, 1)
        else:
            points = points.repeat(model_args.parallel_sampling, 1)
            # np_points = np_points.repeat(model_args.parallel_sampling, 1)
            # edge_index = model.duplicate_edge_index(model_args.parallel_sampling, edge_index, np_points.shape[0], device)
            edge_index = model.module.duplicate_edge_index(model_args.parallel_sampling, edge_index, np_points.shape[0], device)

            np_edge_index = edge_index.cpu().numpy()

        np_points = points.cpu().numpy()
        if model_args.task ==  'tsplib':
            unnorm_points = np.repeat(unnorm_points,np_points.shape[0],axis=0)


    tsp_solvers = [] # generate tsp_solvers
    
    # One evaluator per parallel sample; tsplib uses the un-normalised coordinates.
    if model_args.task ==  'tsplib':
        for batch_idx in range(batch_size): 
            tsp_solvers.append(TSPEvaluator(unnorm_points.reshape([batch_size, -1, 2])[batch_idx], batch=True))
    else :
        
        for batch_idx in range(batch_size): 
            tsp_solvers.append(TSPEvaluator(np_points.reshape([batch_size, -1, 2])[batch_idx], batch=True))

    # Initial noise.
    if not inference:
        if use_env :
            xt = torch.randn([batch_size]+adj_matrix.tolist()).to(model.device)
        else :
            xt = torch.randn_like(adj_matrix.float()).to(model.device)
            
    else:
        xt = torch.randn_like(adj_matrix.float()).to(model.device) # (sparse) xt = [batch, sparse*100]

    if model_args.parallel_sampling > 1:
        if not sparse:
            xt = xt.repeat(model_args.parallel_sampling, 1, 1)
        else:
            xt = xt.repeat(model_args.parallel_sampling, 1)  # [B, E]
        # Re-draw so that each replicated sample gets independent noise.
        xt = torch.randn_like(xt)
    
    if model_args.diffusion_type == 'gaussian':
        xt.requires_grad = True
    else:
        # Categorical diffusion works on a binary state.
        xt = (xt > 0).long()

    if sparse:
        xt = xt.reshape(-1)

    steps = model_args.inference_diffusion_steps
    time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule,
                                                                        T=model_args.diffusion_steps, inference_T=steps)
    # NOTE(review): these trajectory lists are never extended (the appends
    # below are commented out) and are not returned.
    all_latents = [xt.clone().cpu().detach()]
    all_log_probs = []
    all_time_steps = []
    all_rb = []
    for i in range(steps):
        t1, t2 = time_schedule(i)
        t1 = np.array([t1]).astype(int)
        #t1 = (np.ones(points.shape[0])*t1).astype(int)
        t2 = np.array([t2]).astype(int)

        if model_args.reward_shaping :
            xt_before = xt.clone()
        
        if model_args.diffusion_type == 'gaussian':
            xt = model.gaussian_denoise_step(
                    points, xt, t1, device, edge_index, target_t=t2)
        else:
            xt, log_prob, prob, _ = categorical_denoise_step(model,model_diffusion,model_args,
                    points, xt, t1, device, edge_index, target_t=t2,
                    inference=inference, sparse=sparse)
            # print('batchness', batchness, 'xt', xt, 'log_prob',log_prob, 'prob', prob)
        # r_b is computed but not used afterwards in this function.
        if model_args.reward_shaping :
            r_b = get_reward_bonus(xt,xt_before,batch_size,model_args,np_points,np_edge_index,sparse)
        else :
            r_b = torch.zeros(batch_size)
        # if not (log_prob == None):
        #     all_latents.append(xt.cpu().detach())
        #     all_log_probs.append(log_prob.cpu().detach())
        #     all_time_steps.append(torch.LongTensor([t1[0],t2[0]]))
        #     all_rb.append(r_b)
    # if 
    # Turn the final latent into a heatmap for tour decoding.
    if model_args.diffusion_type == 'gaussian':
        # adj_mat = prob.cpu().detach().numpy() * 0.5 + 0.5
        adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5
        
    else:
        adj_mat = xt.float().cpu().detach().numpy() + 1e-6
        #adj_mat = xt.float().cpu().detach().numpy() + 1e-6
    

    # try:
        # print("pass", adj_mat.shape, "np_points", np_points.shape, "np_edge_index", np_edge_index.shape)

    tours, merge_iterations, _ = merge_tours(adj_mat, np_points, np_edge_index,
        sparse_graph=sparse, batch_size=batch_size, guided=True,tsp_decoder=model_args.tsp_decoder) 
    # except:


    # (dense) adj_mat=[batch, 100, 100] np_points = [batch, 100, 2], [batch, 100 + 1] (double-list)
    # (sparse) adj_mat = [batch*sparse*parallel*100], np_points= [batch*100, 2], np_edge_index = [2, sparse*batch*100]

    unsolved_tours = torch.tensor(tours)
    unsolved_tours_list.append(unsolved_tours) # [1, 8, 100 + 1]

            # unsolved_tours = torch.tensor(tours)
            # unsolved_tours_list.append(unsolved_tours) # [1, 8, 100 + 1]
    # pdb.set_trace()
    np_points_reshape = np_points.reshape([batch_size, -1, 2])
    # tours_reshape = tours.reshape([batch_size, -1])

    if model_args.reward_2opt:
        # Refine each decoded tour separately with 2-opt.
        for idx in range(batch_size):
            solved_tours, ns = batched_two_opt_torch(
                    np_points_reshape[idx].astype("float64"),
                    np.array([tours[idx]]).astype("int64"),
                    max_iterations=model_args.two_opt_iterations,
                    device=device,
                    # batch=True,
            )
            solved_tours_list.append(solved_tours) # [1, 8, 100 + 1]
        tour_list = solved_tours_list
        # print('solved_tours_list', solved_tours_list, 'unsolved_tours_list', unsolved_tours_list, solved_tours_list[0]-unsolved_tours_list[0])
    else:
        tour_list = unsolved_tours_list
    
    # print('solved_tours_list',solved_tours_list)
    tour_list = np.concatenate(tour_list, axis=0)
    tour_list = torch.tensor(tour_list)
    # NOTE(review): only one sampling round is run above, so this view
    # presumably requires model_args.sequential_sampling == 1 — confirm.
    tour_list = tour_list.view(model_args.sequential_sampling,batch_size, -1)
    
    # solved_tours_list = np.concatenate(solved_tours_list, axis=0)
    # solved_tours_list = torch.tensor(solved_tours_list)
    # solved_tours_list = solved_tours_list.view(model_args.sequential_sampling, batch_size, -1)
    gt_costs, best_costs = [],  []

    gt_costs = []
    # The two temp_* lists below are not used later in this function.
    temp_solved_costs = []
    temp_unsolved_costs = []


    # Cost of every parallel sample's tour.
    for batch_idx in range(batch_size):
        tsp_solver = tsp_solvers[batch_idx]
        # best_unsolved_costs.append(tsp_solver.evaluate(unsolved_tours_list[:,batch_idx]))
        # if not sparse:
        #     tour_each = tour_list[batch_idx]
        # else:
        #     tour_each = tour_list[batch_idx:]
        # print('tour_list',tour_list)
        # try:
        best_costs.append(tsp_solver.evaluate(tour_list[0,batch_idx].view([1, -1])))
        # except:
    if model_args.tsplib:
        # NOTE(review): gt_tour[0] is evaluated for every batch_idx; only
        # gt_costs[0] is read later.
        for batch_idx in range(batch_size):
            tsp_solver = tsp_solvers[batch_idx]
            # best_unsolved_costs.append(tsp_solver.evaluate(unsolved_tours_list[:,batch_idx]))
            # if not sparse:
            #     tour_each = tour_list[batch_idx]
            # else:
            #     tour_each = tour_list[batch_idx:]
            # print('tour_list',tour_list)
            # try:
            gt_costs.append(tsp_solver.evaluate(gt_tour[0].view([1, -1]).cpu()))
        # import pdb
        # pdb.set_trace()


    best_cost = torch.tensor(best_costs).min().item()

    # best_solved_cost = np.array(best_costs).min()
    best_solved_id = torch.tensor(best_costs).argmin().item()
    g_best_tour = tour_list[0, best_solved_id].cpu().numpy()
    gt_cost = 0
    wo_2opt_costs = best_cost
    # Best cost before any rewrite round; returned as the second value.
    previous_best_cost = best_cost
    g_best_cost = best_cost
    if model_args.rewrite:
        g_best_cost =  best_cost

        for _ in range(model_args.rewrite_steps):
            g_stacked_tours = []
            # g_stacked_unsolved_tours = []
            # optimal adjacent matrix
            # import pdb
            # pdb.set_trace()
            # NOTE(review): tour2adj is not defined in the visible part of this
            # file — presumably defined or imported elsewhere; confirm.
            g_x0 = tour2adj(g_best_tour, np_points_reshape[0], sparse, model_args.sparse_factor, original_edge_index)
            g_x0 = g_x0.unsqueeze(0).to(device)  # [1, N, N] or [1, N]
            if model_args.parallel_sampling > 1:
                if not sparse:
                    g_x0 = g_x0.repeat(model_args.parallel_sampling, 1, 1)  # [1, N ,N] -> [B, N, N]
                    # np_edge_index = [None for _ in range(model_args.parallel_sampling)]
                else:
                    g_x0 = g_x0.repeat(model_args.parallel_sampling, 1)

            if sparse:
                g_x0 = g_x0.reshape(-1)

            g_x0_onehot = F.one_hot(g_x0.long(), num_classes=2).float()  # [B, N, N, 2]
            # if sparse:
            #   g_x0_onehot = g_x0_onehot.unsqueeze(1)

            steps_T = int(model_args.diffusion_steps * model_args.rewrite_ratio)
            # NOTE(review): reads model_args.inference_steps, whereas the first
            # pass uses model_args.inference_diffusion_steps — confirm intended.
            steps_inf = model_args.inference_steps
            time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule,
                                            T=steps_T, inference_T=steps_inf)

            # g_xt = self.diffusion.sample(g_x0_onehot, steps_T)
            Q_bar = model_diffusion.Q_bar[steps_T].float().to(g_x0_onehot.device)
            g_xt_prob = torch.matmul(g_x0_onehot, Q_bar)  # [B, N, N, 2]

            # add noise for the steps_T samples, namely rewrite
            g_xt = torch.bernoulli(g_xt_prob[..., 1].clamp(0, 1))  # [B, N, N]
            g_xt = g_xt * 2 - 1  # project to [-1, 1]
            g_xt = g_xt * (1.0 + 0.05 * torch.rand_like(g_xt))  # add noise
            g_xt = (g_xt > 0).long()

            # Denoise the re-noised tour back to a clean state.
            for i in range(steps_inf):
                t1, t2 = time_schedule(i)
                t1 = np.array([t1]).astype(int)
                t2 = np.array([t2]).astype(int)

                # pdb.set_trace()
                g_xt, _, _, _ = categorical_denoise_step(model,model_diffusion,model_args,
                    points, g_xt, t1, device, edge_index, target_t=t2,
                    inference=inference, sparse=sparse)




            g_adj_mat = g_xt.float().cpu().detach().numpy() + 1e-6
            # if model_args.save_numpy_heatmap:
            # self.run_save_numpy_heatmap(g_adj_mat, np_points, real_batch_idx, split)

            if not sparse:
                g_adj_mat_reshape = g_adj_mat.reshape([model_args.parallel_sampling, g_adj_mat.shape[-1], g_adj_mat.shape[-1]])
                np_points_reshape = np_points.reshape([model_args.parallel_sampling, -1, 2]) 

                g_tours, g_merge_iterations, _ = merge_tours(
                g_adj_mat_reshape, np_points_reshape, None,
                sparse_graph=sparse,
                batch_size = model_args.parallel_sampling,
                tsp_decoder=model_args.tsp_decoder
                )
            else:
                g_tours, g_merge_iterations, _ = merge_tours(g_adj_mat, np_points, np_edge_index,
                sparse_graph=sparse, batch_size=batch_size, guided=True,tsp_decoder=model_args.tsp_decoder) 


            # Refine using 2-opt
            for idx in range(batch_size):
                if model_args.reward_2opt:
                    g_tour, ns = batched_two_opt_torch(
                            np_points_reshape[idx].astype("float64"),
                            np.array([g_tours[idx]]).astype("int64"),
                            max_iterations=model_args.two_opt_iterations,
                            device=device,
                            # batch=True,
                    )
                    g_stacked_tours.append(g_tour)
                else:
                    g_stacked_tours.append(g_tours[idx])


            g_tours = np.concatenate(g_stacked_tours, axis=0).reshape([model_args.parallel_sampling, -1])


            # NOTE(review): np_points is rebound to a 3-D array here and stays
            # that way for later rewrite rounds (including the sparse
            # merge_tours call above) — confirm this is intended.
            np_points = np_points.reshape([model_args.parallel_sampling, -1, 2])
            # All candidates are scored against the first sample's coordinates.
            tsp_solver = TSPEvaluator(np_points[0])  # np_points: [N, 2] ndarray

            g_total_sampling = model_args.parallel_sampling
            g_all_costs = [tsp_solver.evaluate(g_tours[i]) for i in range(g_total_sampling)]

            g_best_cost_tmp, g_best_id = np.min(g_all_costs), np.argmin(g_all_costs)
            # g_best_unsolved_cost_tmp = np.min(g_all_unsolved_costs)
            g_best_cost = min(g_best_cost, g_best_cost_tmp)
            # g_best_wo_2opt = min(g_best_unsolved_cost_tmp, g_best_wo_2opt)

            # The next round starts from this round's best tour, even if it
            # did not improve on g_best_cost.
            g_best_tour = g_tours[g_best_id]

    # if model_args.task == "tsplib" :
    #     g_best_cost = ((g_best_cost-optimal_cost.item())/optimal_cost.item())*100

    if model_args.tsplib:
        past_best = g_best_cost
        past_prev = previous_best_cost
        # Convert both costs to a percentage gap relative to gt_costs[0].
        g_best_cost = (g_best_cost/gt_costs[0]-1)*100
        previous_best_cost = (previous_best_cost/gt_costs[0]-1)*100
        # print('past', past_best, past_prev, 'after', g_best_cost, previous_best_cost)
        print(real_batch_idx, g_best_cost)
    return [g_best_cost], [previous_best_cost]

# def categorical_denoise_step_mis(model,model_diffusion, xt, t, device, edge_index=None, target_t=None, next_xt=None, batch_t=False, inference=False, sparse=False):
#     with torch.no_grad():
#         t = torch.from_numpy(t).view(1)
#         x0_pred = model.forward(
#           xt.float().to(device),
#           t.float().to(device),
#           edge_index.long().to(device) if edge_index is not None else None,
#       )
#         x0_pred_prob = x0_pred.reshape((1, xt.shape[0], -1, 2)).softmax(dim=-1)
#         # xt = categorical_posterior(model, target_t, t, x0_pred_prob, xt)
#         xt = categorical_posterior(model, model_diffusion, target_t, t, x0_pred_prob, xt, next_xt, batch_t,inference,sparse)
#         return xt
def gaussian_posterior(model,model_diffusion, target_t, t, pred, xt,next_xt,batch_t=False,inference=False,sparse=False, point_indicator=None):
    """Take one Gaussian posterior step and score it.

    See https://arxiv.org/pdf/2010.02502.pdf for details.

    The posterior mean is computed with the DDIM update (the only supported
    ``model.args.inference_trick``); the standard deviation is ``eta`` times
    the square root of ``_get_variance(model_diffusion, t, target_t)``.  The
    next state is either freshly sampled or taken from ``next_xt``.  Its
    Gaussian log-density is summed over the segments given by
    ``point_indicator`` and divided by 800.

    Returns:
        ``(xt_target, log_prob, prob)`` where ``prob`` is always ``None``.

    Raises:
        ValueError: if ``model.args.inference_trick`` is not ``'ddim'``.
    """
    eta = 0.1  # hardcoding

    if target_t is None:
        target_t = t - 1
    elif isinstance(target_t, np.ndarray):
        target_t = torch.from_numpy(target_t).view(1)

    atbar = model_diffusion.alphabar[t]
    atbar_target = model_diffusion.alphabar[target_t]

    if model.args.inference_trick != 'ddim':
        raise ValueError('Unknown inference trick {}'.format(model.args.inference_trick))

    if target_t == 0:
        # Final step: uses the single-step alpha of the current timestep.
        at = model_diffusion.alpha[t]
        mean_scale = (1 / np.sqrt(at)).item()
        pred_scale = ((1 - at) / np.sqrt(1 - atbar)).item()
        xt_target_mean = mean_scale * (xt - pred_scale * pred)
    else:
        # Intermediate step: remove the predicted noise, rescale, re-add it.
        rescaled = np.sqrt(atbar_target / atbar).item() * (xt - np.sqrt(1 - atbar).item() * pred)
        xt_target_mean = rescaled + np.sqrt(1 - atbar_target).item() * pred

    std_dev_t = eta * _get_variance(model_diffusion, t, target_t).to(xt.device) ** (0.5)

    if next_xt == None :
        variance_noise = torch.randn_like(xt, dtype=xt.dtype).to(xt.device)
        xt_target = xt_target_mean + std_dev_t * variance_noise
    else:
        xt_target = next_xt.reshape(xt_target_mean.shape)

    # Element-wise Gaussian log-density of the chosen next state.
    quadratic = -((xt_target.detach() - xt_target_mean) ** 2) / (2 * (std_dev_t ** 2))
    log_prob = quadratic - torch.log(std_dev_t) - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))

    # Reduce to one value per segment of point_indicator.
    segment_scores = []
    start = 0
    for span in point_indicator:
        stop = start + span
        segment_scores.append(log_prob[start:stop].sum()/800)  ## hard coded for MIS SAT scaling
        start = stop

    return xt_target, torch.stack(segment_scores), None

def categorical_denoise_step(model, model_diffusion, model_args, points, xt, t, device,  edge_index=None, target_t=None, next_xt=None, batch_t=False, inference=False, sparse=False, aux=False, return_x0_pred=False):
    """Run the network once and take one categorical posterior step (TSP).

    ``model.forward`` is called with the points, the current state, the
    timestep and the optional edge index; its first output is turned into
    class probabilities and handed to ``categorical_posterior``.

    Returns:
        ``(xt, log_prob, prob, x0_pred)``, with ``aux_pred`` appended as a
        fifth element when ``aux`` is true.  ``return_x0_pred`` is not read.
    """
    # A numpy timestep becomes a one-element tensor; anything else is used as is.
    timestep = torch.from_numpy(t).view(1) if isinstance(t, np.ndarray) else t

    graph_edges = None if edge_index is None else edge_index.long().to(device)
    x0_pred, aux_pred = model.forward(
        points.float().to(device),
        xt.float().to(device),
        timestep.float().to(device),
        graph_edges,
    )

    # sparse : [batch*sparse*tsp_size,2]
    if sparse:
        x0_pred_prob = x0_pred.reshape((1, points.shape[0], -1, 2)).softmax(dim=-1)
    else:
        x0_pred_prob = x0_pred.permute((0, 2, 3, 1)).contiguous().softmax(dim=-1)

    next_state, log_prob, prob = categorical_posterior(
        model, model_diffusion, model_args, target_t, timestep,
        x0_pred_prob, xt, next_xt, batch_t, inference, sparse)

    outputs = (next_state, log_prob, prob, x0_pred)
    return outputs + (aux_pred,) if aux else outputs


def gaussian_denoise_step_mis(model, model_diffusion, model_args, xt, t, device, edge_index=None, target_t=None, next_xt=None, batch_t=False, inference=False, sparse=False, point_indicator=None):
    """Run the network once and take one Gaussian posterior step (MIS).

    ``model.forward`` receives the current state, the timestep and the
    optional edge index; its output has axis 1 squeezed before being passed
    to ``gaussian_posterior``.  ``model_args`` is not read here.

    Returns:
        ``(xt, log_prob, prob)`` as produced by ``gaussian_posterior``.
    """
    # A numpy timestep becomes a one-element tensor; anything else is used as is.
    timestep = torch.from_numpy(t).view(1) if isinstance(t, np.ndarray) else t

    graph_edges = None if edge_index is None else edge_index.long().to(device)
    noise_pred = model.forward(
        xt.float().to(device),
        timestep.float().to(device),
        graph_edges,
    ).squeeze(1)

    next_state, log_prob, prob = gaussian_posterior(
        model, model_diffusion, target_t, timestep, noise_pred, xt,
        next_xt, batch_t, inference, sparse, point_indicator)
    return next_state, log_prob, prob

def categorical_denoise_step_mis(model, model_diffusion, model_args, xt, t, device, edge_index=None, target_t=None, next_xt=None, batch_t=False, inference=False, sparse=False, point_indicator=None,aux=False):
    """Run the network once and take one categorical posterior step (MIS).

    ``model.forward`` receives the current state, the timestep and the
    optional edge index; its first output is reshaped to
    ``(1, xt.shape[0], -1, 2)``, soft-maxed over the last axis and passed to
    ``categorical_posterior`` together with ``point_indicator``.

    Returns:
        ``(xt, log_prob, prob, x0_pred)``, with ``aux_pred`` appended as a
        fifth element when ``aux`` is true.
    """
    # A numpy timestep becomes a one-element tensor; anything else is used as is.
    timestep = torch.from_numpy(t).view(1) if isinstance(t, np.ndarray) else t

    graph_edges = None if edge_index is None else edge_index.long().to(device)
    x0_pred, aux_pred = model.forward(
        xt.float().to(device),
        timestep.float().to(device),
        graph_edges,
    )

    class_probs = x0_pred.reshape((1, xt.shape[0], -1, 2)).softmax(dim=-1)
    next_state, log_prob, prob = categorical_posterior(
        model, model_diffusion, model_args, target_t, timestep,
        class_probs, xt, next_xt, batch_t, inference, sparse, point_indicator)

    outputs = (next_state, log_prob, prob, x0_pred)
    return outputs + (aux_pred,) if aux else outputs

def difusco_with_logprob(model, model_diffusion, model_args,batch,inference=False,use_env=False,sparse=False, reward_gap=False,reward_2opt='none'):
    """Sample TSP tours by reverse diffusion and record the sampling trajectory.

    For each of ``model_args.sequential_sampling`` rounds the function draws a
    random initial state, runs ``model_args.inference_diffusion_steps``
    denoising steps, decodes the final state into tours with ``merge_tours``
    and, if requested, refines every tour with ``batched_two_opt_torch``.
    After all rounds the tours are scored with ``TSPEvaluator`` and the best
    (minimum) cost over the rounds is kept per instance.

    Args:
        model: Denoising network; must expose ``device``. On the Gaussian
            branch ``model.gaussian_denoise_step`` is called.
        model_diffusion: Diffusion object forwarded to
            ``categorical_denoise_step``.
        model_args: Options object. Attributes read here:
            ``sequential_sampling``, ``diffusion_type``,
            ``inference_diffusion_steps``, ``inference_schedule``,
            ``diffusion_steps``, ``reward_shaping``, ``tsp_decoder``,
            ``kl_aux``, ``kl_grdy``, ``use_critic``, ``reward_2opt``,
            ``two_opt_iterations``.
        batch: Dense mode: ``(idx, points, adj_matrix, gt_tour, cost)``.
            Sparse mode: ``(idx, graph_data, point_indicator, gt_tour, cost)``
            or the 6-tuple variant with an extra ``edge_indicator`` before
            ``gt_tour``.
        inference: When true, only the cost lists are returned.
        use_env: When true (and ``inference`` is false) the initial state is
            sized from ``adj_matrix.tolist()`` instead of mirroring
            ``adj_matrix``; in sparse mode ``adj_matrix`` then only holds the
            per-instance edge count.
        sparse: Selects the sparse-graph batch layout.
        reward_gap: When true, costs are divided by the reference ``cost``.
        reward_2opt: Any value other than ``'none'`` enables 2-opt refinement
            (as does a truthy ``model_args.reward_2opt``).

    Returns:
        If ``inference``: ``(gt_cost, wo_2opt_costs, best_solved_cost)`` where
        ``gt_cost`` is always an empty list.
        Otherwise: ``(all_latents, np_edge_index, all_log_probs, -rewards,
        all_time_steps, all_rb, new_target, all_aux_pred)``. The trajectory
        lists and ``new_target`` belong to the last sampling round, while the
        rewards are the per-instance minimum over all rounds.

    Raises:
        ValueError: If a sparse batch does not have 5 or 6 entries.
    """
    device = model.device

    if not sparse:
        edge_index = None
        np_edge_index = None
        _, points, adj_matrix, _, cost = batch
        points, adj_matrix, cost = points.to(device), adj_matrix.to(device), cost.to(device)
        np_points = points.cpu().numpy()
        batch_size = points.shape[0]
    else:
        if len(batch) == 5:
            _, graph_data, point_indicator, _, cost = batch
        elif len(batch) == 6:
            _, graph_data, point_indicator, _, _, cost = batch
        else:
            raise ValueError(
                f"sparse batch must have 5 or 6 entries, got {len(batch)}")
        points = graph_data.x
        edge_index = graph_data.edge_index
        num_edges = edge_index.shape[1]
        batch_size = point_indicator.shape[0]
        if use_env:
            # Only the per-instance edge count is needed to size the noise.
            adj_matrix = np.array([num_edges // batch_size])
        else:
            adj_matrix = graph_data.edge_attr.reshape((batch_size, num_edges // batch_size))
        points = points.reshape((-1, 2))
        edge_index = edge_index.reshape((2, -1))
        np_points = points.cpu().numpy()
        np_edge_index = edge_index.cpu().numpy()

    np_points_reshape = np_points.reshape([batch_size, -1, 2])
    tsp_solvers = [TSPEvaluator(np_points_reshape[batch_idx], batch=True)
                   for batch_idx in range(batch_size)]

    use_two_opt = reward_2opt != 'none' or model_args.reward_2opt
    need_targets = model_args.kl_aux > 0 or model_args.kl_grdy > 0 or model_args.use_critic > 0

    unsolved_tours_list, solved_tours_list = [], []

    for _ in range(model_args.sequential_sampling):
        if use_env and not inference:
            xt = torch.randn([batch_size] + adj_matrix.tolist()).to(device)
        else:
            xt = torch.randn_like(adj_matrix.float()).to(device)  # (sparse) xt = [batch, sparse*100]

        if model_args.diffusion_type == 'gaussian':
            xt.requires_grad = True
        else:
            xt = (xt > 0).long()

        if sparse:
            xt = xt.reshape(-1)

        steps = model_args.inference_diffusion_steps
        time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule, T=model_args.diffusion_steps, inference_T=steps)
        all_latents = [xt.clone().cpu().detach()]
        all_log_probs = []
        all_time_steps = []
        all_rb = []
        all_aux_pred = []
        for i in range(steps):
            t1, t2 = time_schedule(i)
            t1 = np.array([t1]).astype(int)
            t2 = np.array([t2]).astype(int)

            if model_args.reward_shaping:
                xt_before = xt.clone()

            # The Gaussian step only returns the new state, so these must be
            # pre-set; otherwise the check below raises NameError.
            log_prob, aux_pred = None, None
            if model_args.diffusion_type == 'gaussian':
                xt = model.gaussian_denoise_step(
                        points, xt, t1, device, edge_index, target_t=t2)
            else:
                xt, log_prob, _, _, aux_pred = categorical_denoise_step(model, model_diffusion, model_args,
                        points, xt, t1, device, edge_index, target_t=t2,
                        inference=inference, sparse=sparse, aux=True)

            if model_args.reward_shaping:
                r_b = get_reward_bonus(xt, xt_before, batch_size, model_args, np_points, np_edge_index, sparse)
            else:
                r_b = torch.zeros(batch_size)

            # Steps without a log-probability are not part of the trajectory.
            if log_prob is not None:
                all_aux_pred.append(aux_pred.cpu().detach() if aux_pred is not None else None)
                all_latents.append(xt.cpu().detach())
                all_log_probs.append(log_prob.cpu().detach())
                all_time_steps.append(torch.LongTensor([t1[0], t2[0]]))
                all_rb.append(r_b)

        if model_args.diffusion_type == 'gaussian':
            adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5
        else:
            adj_mat = xt.float().cpu().detach().numpy() + 1e-6

        # (dense) adj_mat=[batch, 100, 100] np_points = [batch, 100, 2], [batch, 100 + 1] (double-list)
        # (sparse) adj_mat = [batch*sparse*parallel*100], np_points= [batch*100, 2], np_edge_index = [2, sparse*batch*100]
        tours, _, splitted_real_adj_mat = merge_tours(adj_mat, np_points, np_edge_index,
                sparse_graph=sparse, batch_size=batch_size, guided=True, tsp_decoder=model_args.tsp_decoder)

        new_target = []
        for i in range(len(tours)):
            if need_targets:
                new_adj = torch.from_numpy(splitted_real_adj_mat[i])
                if sparse:
                    # Keep only the entries that lie on the instance's edges.
                    edgeidx = graph_data[i].edge_index.to(new_adj.device)
                    new_adj = (new_adj[edgeidx[0], edgeidx[1]] == 1).to(new_adj.dtype)
            else:
                # Placeholder when no target is consumed downstream.
                new_adj = torch.tensor(1)
            new_target.append(new_adj)

        unsolved_tours_list.append(torch.tensor(tours))

        for idx in range(batch_size):
            if use_two_opt:
                solved_tours, _ = batched_two_opt_torch(
                        np_points_reshape[idx].astype("float64"),
                        np.array([tours[idx]]).astype("int64"),
                        max_iterations=model_args.two_opt_iterations,
                        device=device,
                )
                solved_tours_list.append(solved_tours)
            else:
                solved_tours_list.append(tours[idx])

    # Aggregate once, after every sampling round has contributed. Doing this
    # inside the loop turned the accumulators into tensors and broke any
    # sequential_sampling > 1.
    all_unsolved = torch.tensor(np.concatenate(unsolved_tours_list, axis=0))
    all_unsolved = all_unsolved.view(model_args.sequential_sampling, batch_size, -1)
    all_solved = torch.tensor(np.concatenate(solved_tours_list, axis=0))
    all_solved = all_solved.view(model_args.sequential_sampling, batch_size, -1)

    best_unsolved_costs, best_solved_costs = [], []
    for batch_idx, tsp_solver in enumerate(tsp_solvers):
        best_unsolved_costs.append(tsp_solver.evaluate(all_unsolved[:, batch_idx]).min())
        best_solved_costs.append(tsp_solver.evaluate(all_solved[:, batch_idx]).min())

    gt_cost = []  # ground-truth costs are not computed here
    wo_2opt_costs = best_unsolved_costs
    best_solved_cost = best_solved_costs

    if reward_gap:
        # NOTE(review): only `wo_2opt_costs` and `best_solved_costs` are
        # normalised; `best_solved_cost` (returned for inference) and
        # `best_unsolved_costs` (used as reward without 2-opt) keep their raw
        # values. Kept as in the original; confirm this is intended.
        reference = cost.to('cpu').detach().clone().flatten()
        wo_2opt_costs = torch.tensor(wo_2opt_costs).to('cpu').detach().clone().flatten() / reference
        best_solved_costs = torch.tensor(best_solved_costs).to('cpu').detach().clone().flatten() / reference

    if model_args.reward_2opt:
        rewards = best_solved_costs
    else:
        rewards = best_unsolved_costs

    if inference:
        return gt_cost, wo_2opt_costs, best_solved_cost

    if sparse > 0:
        np_edge_index = torch.from_numpy(np_edge_index)

    return all_latents, np_edge_index, all_log_probs, - torch.tensor(rewards), all_time_steps, all_rb, new_target, all_aux_pred


# def plot_graph(nodes, adjacency_matrix, threshold=0.9, filename='graph.png'):
#     import matplotlib.pyplot as plt
#     import seaborn as sns
#     import networkx as nx
#     sns.set_style("white")
#     fig, ax = plt.subplots(figsize=(12, 10))

#     # 그래프 생성
#     G = nx.Graph()

#     # 노드 추가
#     for i, (x, y) in enumerate(nodes):
#         G.add_node(i, pos=(x, y))

#     # 엣지 추가
#     for i in range(len(nodes)):
#         for j in range(i+1, len(nodes)):
#             if adjacency_matrix[i, j] > threshold:
#                 G.add_edge(i, j, weight=adjacency_matrix[i, j])

#     # 노드 위치 가져오기
#     pos = nx.get_node_attributes(G, 'pos')

#     # 엣지 가중치 가져오기
#     edges = G.edges()
#     weights = [G[u][v]['weight'] for u, v in edges]

#     # 알파 값 계산 (수정된 부분)
#     alphas = [((w - threshold) / (1 - threshold))*0.5 for w in weights]


#     # alphas = [alpha_func(w) for w in weights]

#     # 엣지 색상 설정 (원래의 Blues 컬러맵 사용)
#     cmap = plt.cm.get_cmap('Blues')
#     edge_colors = [cmap((w - threshold) / (1 - threshold)) for w in weights]

#     # 엣지 그리기
#     for (u, v), alpha, color in zip(edges, alphas, edge_colors):
#         nx.draw_networkx_edges(G, pos, ax=ax, edgelist=[(u, v)],
#                                edge_color=[color],
#                                alpha=alpha,
#                                width=6)

#     # 노드 그리기 (단일 색상)
#     nx.draw_networkx_nodes(G, pos, ax=ax, node_size=250, node_color='#e74c3c',  # 빨간색 계열
#                            edgecolors='none')

#     # 컬러바 추가
#     sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=threshold, vmax=1))
#     sm.set_array([])

#     ax.axis('off')

#     # 검정색 사각형 그리기
#     ax.add_patch(plt.Rectangle((-0.05, -0.05), 1.1, 1.1, fill=False, edgecolor='black', linewidth=2))

#     plt.tight_layout()
#     plt.savefig(filename, dpi=300, bbox_inches='tight')
#     plt.close()  # 메모리에서 그래프 해제
def plot_graph(nodes, adjacency_matrix, threshold=0.5, filename='graph.png'):
    """Draw a weighted graph over 2-D node positions and save it to ``filename``.

    An edge ``(i, j)`` with ``i < j`` is drawn when
    ``abs(adjacency_matrix[i, j]) > threshold`` or
    ``adjacency_matrix[i, j] <= -0.5``. Edges with weight ``>= 1.4`` are drawn
    in red with alpha 0.5. All other edges are coloured with the ``Blues``
    colormap; their weight is first raised to ``threshold`` if it is lower,
    so such edges get alpha 0.

    Args:
        nodes: Sequence of ``(x, y)`` coordinates; must be non-empty.
        adjacency_matrix: Object indexable as ``adjacency_matrix[i, j]``.
        threshold: Edge-weight cut-off; expected to be below 1 because the
            colour scale divides by ``1 - threshold``.
        filename: Output image path; its directory must already exist.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import networkx as nx

    sns.set_style("white")
    fig, ax = plt.subplots(figsize=(12, 10))

    # Close the figure even if drawing or saving fails, so repeated calls do
    # not accumulate open figures.
    try:
        # Build the graph: one node per coordinate pair.
        G = nx.Graph()
        for i, (x, y) in enumerate(nodes):
            G.add_node(i, pos=(x, y))

        # Add every upper-triangle edge that passes the weight filter.
        num_nodes = len(nodes)
        for i in range(num_nodes):
            for j in range(i + 1, num_nodes):
                weight = adjacency_matrix[i, j]
                if abs(weight) > threshold or weight <= -0.5:
                    G.add_edge(i, j, weight=weight)

        pos = nx.get_node_attributes(G, 'pos')

        # Draw edges one by one so each gets its own colour and alpha.
        for u, v, w in G.edges(data='weight'):
            if w >= 1.4:
                alpha, color = 0.5, 'red'
            else:
                if w <= threshold:
                    w = threshold
                intensity = (w - threshold) / (1 - threshold)
                # Matplotlib rejects alpha > 1.
                alpha = min(intensity * 0.5, 1.0)
                color = plt.cm.Blues(intensity)
            nx.draw_networkx_edges(G, pos, ax=ax, edgelist=[(u, v)],
                                   edge_color=[color],
                                   alpha=alpha,
                                   width=6)

        # Draw nodes in a single colour (dark grey).
        nx.draw_networkx_nodes(G, pos, ax=ax, node_size=250, node_color='#A9A9A9',
                               edgecolors='none')

        ax.axis('off')

        # Fit the view to the node bounding box plus a 5% margin.
        x_values, y_values = zip(*pos.values())
        x_min, x_max = min(x_values), max(x_values)
        y_min, y_max = min(y_values), max(y_values)
        x_margin = (x_max - x_min) * 0.05
        y_margin = (y_max - y_min) * 0.05
        ax.set_xlim(x_min - x_margin, x_max + x_margin)
        ax.set_ylim(y_min - y_margin, y_max + y_margin)

        # Black frame around the plotted area.
        ax.add_patch(plt.Rectangle((x_min - x_margin, y_min - y_margin),
                                   x_max - x_min + 2 * x_margin,
                                   y_max - y_min + 2 * y_margin,
                                   fill=False, edgecolor='black', linewidth=2))

        fig.tight_layout()
        fig.savefig(filename, dpi=300, bbox_inches='tight', pad_inches=0.1)
    finally:
        plt.close(fig)  # release the figure's memory

def difusco_with_logprob_heatmap(model, model_diffusion, model_args,batch,inference=False,use_env=False,sparse=False, reward_gap=False,reward_2opt='none'):
    """Sample TSP tours by reverse diffusion and dump diagnostic figures.

    Runs the same sampling / decoding / scoring pipeline as
    ``difusco_with_logprob`` and additionally, for every decoded instance,
    saves three PNG files under ``figures/`` through ``plot_graph``: the
    decoded solution, the network's soft prediction, and the reference tour.
    It also prints progress information to stdout.

    Args:
        model: Denoising network; must expose ``device``.
        model_diffusion: Diffusion object forwarded to
            ``categorical_denoise_step``.
        model_args: Options object (reads, among others,
            ``sequential_sampling``, ``diffusion_type``,
            ``inference_diffusion_steps``, ``sparse_factor``, ``seed``).
        batch: Dense mode: ``(idx, points, adj_matrix, gt_tour, cost)``.
            Sparse mode: a 5- or 6-tuple containing ``graph_data`` and
            ``point_indicator``.
        inference: When true, only the cost lists are returned.
        use_env: When true (and ``inference`` is false) the initial state is
            sized from ``adj_matrix.tolist()``.
        sparse: Selects the sparse-graph batch layout.
        reward_gap: When true, costs are divided by the reference ``cost``.
        reward_2opt: Any value other than ``'none'`` enables 2-opt refinement.

    Returns:
        If ``inference``: ``(gt_cost, wo_2opt_costs, best_solved_cost)`` where
        ``gt_cost`` is always an empty list. Otherwise:
        ``(all_latents, np_edge_index, all_log_probs, -rewards,
        all_time_steps, all_rb, new_target, all_aux_pred)``.
    """
    # gpu_id = self.args.gpu_id[0]
    # set_target_gpu(gpu_id)
    # Reward_Reg = 0
    from difusco.utils.tsp_utils import calculate_distance_matrix
    import torch.nn.functional as F
    import copy as cp
    device = model.device

    # ---- Unpack the batch (dense vs. sparse layout) ----
    if not sparse:
        edge_index = None
        np_edge_index = None
        real_batch_idx, points, adj_matrix, gt_tour, cost = batch 
        
        real_batch_idx, points, adj_matrix, gt_tour, cost = real_batch_idx.to(device), points.to(device), adj_matrix.to(device), gt_tour.to(device), cost.to(device)
        
        np_points = points.cpu().numpy()
        np_gt_tour = gt_tour.cpu().numpy()
        batch_size = points.shape[0]
        
    else:
        # NOTE(review): a sparse batch of any other length leaves the names
        # below unbound.
        if len(batch) == 5:
            real_batch_idx, graph_data, point_indicator, gt_tour,cost = batch
        elif len(batch) == 6:
            real_batch_idx, graph_data, point_indicator, edge_indicator, gt_tour, cost = batch
        route_edge_flags = graph_data.edge_attr
        points = graph_data.x
        edge_index = graph_data.edge_index
        num_edges = edge_index.shape[1]
        batch_size = point_indicator.shape[0]
        if use_env :
            # Only the per-instance edge count is stored; it sizes the noise.
            adj_matrix = np.array([num_edges//batch_size])
        else :
            route_edge_flags = graph_data.edge_attr
            adj_matrix = route_edge_flags.reshape((batch_size, num_edges // batch_size))
        points = points.reshape((-1, 2))
        edge_index = edge_index.reshape((2, -1))
        np_points = points.cpu().numpy()
        np_gt_tour = gt_tour.cpu().numpy().reshape(-1)
        np_edge_index = edge_index.cpu().numpy()
    
    unsolved_tours_list, solved_tours_list = [], []

    # One evaluator per instance, built from that instance's coordinates.
    tsp_solvers = [] # generate tsp_solvers
    for batch_idx in range(batch_size): 
        tsp_solvers.append(TSPEvaluator(np_points.reshape([batch_size, -1, 2])[batch_idx], batch=True))

    all_probs = []
    # NOTE(review): everything down to the cost evaluation is inside this
    # loop, and the tour accumulators are rebound to tensors at its end, so
    # this looks usable only with sequential_sampling == 1 -- confirm.
    for _ in range(model_args.sequential_sampling):
        # ---- Initial random state ----
        if not inference:
            if use_env:
                xt = torch.randn([batch_size]+adj_matrix.tolist()).to(model.device)
            else:
                xt = torch.randn_like(adj_matrix.float()).to(model.device)
                
        else:
            xt = torch.randn_like(adj_matrix.float()).to(model.device) # (sparse) xt = [batch, sparse*100]

        if model_args.diffusion_type == 'gaussian':
            xt.requires_grad = True
        else:
            # Categorical diffusion starts from a binary state.
            xt = (xt > 0).long()

        if sparse:
            xt = xt.reshape(-1)

        # ---- Reverse diffusion ----
        steps = model_args.inference_diffusion_steps
        time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule, T=model_args.diffusion_steps, inference_T=steps)
        all_latents = [xt.clone().cpu().detach()]
        all_log_probs = []
        all_time_steps = []
        all_rb = []
        all_aux_pred = []
        for i in range(steps):
            t1, t2 = time_schedule(i)
            t1 = np.array([t1]).astype(int)
            #t1 = (np.ones(points.shape[0])*t1).astype(int)
            t2 = np.array([t2]).astype(int)
            # Debug output: magnitude of the current state.
            print(i, 'xt', xt.abs().sum())
            if model_args.reward_shaping :
                xt_before = xt.clone()
            
            # NOTE(review): the gaussian branch assigns only `xt`; `log_prob`,
            # `aux_pred` and `x0_pred` (all used below) are set only by the
            # categorical branch.
            if model_args.diffusion_type == 'gaussian':
                xt = model.gaussian_denoise_step(
                        points, xt, t1, device, edge_index, target_t=t2)
            else:
                xt, log_prob, prob, x0_pred , aux_pred = categorical_denoise_step(model,model_diffusion,model_args,
                        points, xt, t1, device, edge_index, target_t=t2,
                        inference=inference, sparse=sparse, aux=True)
                # print('batchness', batchness, 'xt', xt, 'log_prob',log_prob, 'prob', prob)
            if model_args.reward_shaping :
                r_b = get_reward_bonus(xt,xt_before,batch_size,model_args,np_points,np_edge_index,sparse)
            else:
                r_b = torch.zeros(batch_size)
            # Record the step only when a log-probability was produced.
            if not (log_prob == None):
                if aux_pred is not None:
                    all_aux_pred.append(aux_pred.cpu().detach())
                else:
                    all_aux_pred.append(None)
                all_latents.append(xt.cpu().detach())
                all_log_probs.append(log_prob.cpu().detach())
                all_time_steps.append(torch.LongTensor([t1[0],t2[0]]))
                all_rb.append(r_b)
        # all_probs.append(prob.cpu().detach())
        
        # ---- Decode the final state into tours ----
        # if 
        if model_args.diffusion_type == 'gaussian':
            # adj_mat = prob.cpu().detach().numpy() * 0.5 + 0.5
            adj_mat = xt.cpu().detach().numpy() * 0.5 + 0.5
            
        else:
            # adj_mat = prob.float().cpu().detach().numpy() + 1e-6
            adj_mat = xt.float().cpu().detach().numpy()  + 1e-6
        
        tours, merge_iterations, splitted_real_adj_mat = merge_tours(adj_mat, np_points, np_edge_index,
                sparse_graph=sparse, batch_size=batch_size, guided=True,tsp_decoder=model_args.tsp_decoder) 
        # (dense) adj_mat=[batch, 100, 100] np_points = [batch, 100, 2], [batch, 100 + 1] (double-list)
        # (sparse) adj_mat = [batch*sparse*parallel*100], np_points= [batch*100, 2], np_edge_index = [2, sparse*batch*100]
        new_target = []
        opt_target = []
        dist = calculate_distance_matrix(torch.from_numpy(np_points))
        opt_costs = []
        sol_costs = []
        check_feasibility =[]
        # An instance is flagged "feasible" when the symmetrised decoded
        # adjacency equals the symmetrised thresholded sample, i.e. decoding
        # did not have to change any edge.
        for i in range(0,len(tours)):
            decoded = splitted_real_adj_mat[i]>0.5
            decoded = decoded + decoded.transpose(-1,-2)
            raw = (adj_mat[i]>0.5) + (adj_mat[i]>0.5).transpose(-1,-2)
            check_feasibility.append((decoded == raw).all())

        for i in range(0,len(tours)) :
            if model_args.kl_aux>0 or model_args.kl_grdy>0 or model_args.use_critic>0:
                new_adj =  torch.from_numpy(splitted_real_adj_mat[i])
                if sparse :
                    # Keep only the entries that lie on the instance's edges.
                    edgeidx = graph_data[i].edge_index.to(new_adj.device)
                    new_adj = (new_adj[edgeidx[0], edgeidx[1]]==1).to(new_adj.dtype)
            else:
                # Placeholder when no target is consumed downstream.
                new_adj = torch.tensor(1)
            new_target.append(new_adj)
            # Reference-tour adjacency and the costs of reference vs. decoded
            # solution (distance-weighted sums over the adjacency matrices).
            # NOTE(review): np_points[i] / dist[i] index per instance, which
            # presumably assumes the dense [batch, n, 2] layout -- confirm.
            opt_adj = tour2adj(gt_tour[i], np_points[i], sparse, model_args.sparse_factor, np_edge_index)
            opt_cost = (dist[i]*opt_adj).sum()
            sol_cost = (dist[i]*splitted_real_adj_mat[i]).sum()
            opt_adj = opt_adj + opt_adj.T
            opt_target.append(opt_adj)
            opt_costs.append(opt_cost)
            sol_costs.append(sol_cost)
            
        # The local name `time` is rebound from the module to a float
        # timestamp; it is only used in the output file names below.
        import time
        time = time.time()
        
        # Turn the two-channel prediction of the last denoising step into a
        # logit gap (channel 1 minus channel 0, channel 0 zeroed), symmetrise
        # it with an element-wise maximum, and soft-max over the channel axis.
        x0_pred_gap = cp.deepcopy(x0_pred)
        x0_pred_gap[:,1] = x0_pred[:,1] - x0_pred[:,0]
        x0_pred_gap[:,0] = 0
        x0_pred_gap = torch.maximum(x0_pred_gap, x0_pred_gap.transpose(-1,-2))
        x0_pred_softmax = F.softmax(x0_pred_gap, dim=1)
        opt_target_sum = torch.stack(opt_target)
        # Despite its name this is the per-instance mean cross-entropy of the
        # prediction against the reference adjacency.
        kl_loss = F.cross_entropy(x0_pred_gap.cpu(), opt_target_sum.long(), reduction='none').mean([1,2])
        

        # ---- Write the figures (the `figures/` directory must exist) ----
        for i in range(0,len(tours)):
            print('drawing', i)
            # Relative excess cost over the reference, in percent.
            gap = (sol_costs[i]/opt_costs[i] -1)*100
            # Combined encoding passed to plot_graph: twice the symmetrised
            # decoded adjacency minus the reference adjacency.
            adj_mat = 2*(splitted_real_adj_mat[i]+splitted_real_adj_mat[i].T)-1*opt_target_sum[i].cpu().numpy()
            # adj_mat = np.where(adj_mat>-1, (splitted_real_adj_mat[i]+splitted_real_adj_mat[i].T), -opt_target_sum[i].cpu().numpy())
            plot_graph(np_points[i],adj_mat, filename=f'figures/decoded_{i}_feas{int(check_feasibility[i])}_sd{model_args.seed}_kl{kl_loss[i]:.3f}_gap{gap:.3f}_{time}.png')

            soft_mat = 2*(x0_pred_softmax[i, 1].cpu().numpy())-1*opt_target_sum[i].cpu().numpy()
            # soft_mat = np.where(soft_mat>-1 x0_pred_softmax[i, 1].cpu().numpy(), -opt_target_sum[i].cpu().numpy())
            plot_graph(np_points[i], soft_mat, filename=f'figures/heatmap_{i}_feas{int(check_feasibility[i])}_sd{model_args.seed}_kl{kl_loss[i]:.3f}_gap{gap:.3f}_{time}.png')
            # The reference figure's file name carries no timestamp.
            plot_graph(np_points[i], opt_target_sum[i].cpu().numpy(), filename=f'figures/opt_{i}_feas{int(check_feasibility[i])}_sd{model_args.seed}_kl{kl_loss[i]:.3f}_gap{gap:.3f}.png')

        
        unsolved_tours = torch.tensor(tours)
        unsolved_tours_list.append(unsolved_tours) # [1, 8, 100 + 1]

        

                # unsolved_tours = torch.tensor(tours)
                # unsolved_tours_list.append(unsolved_tours) # [1, 8, 100 + 1]

        np_points_reshape = np_points.reshape([batch_size, -1, 2])
        # tours_reshape = tours.reshape([batch_size, -1])

        # ---- Optional 2-opt refinement, one instance at a time ----
        for idx in range(batch_size):
            if reward_2opt!='none' or model_args.reward_2opt:
                solved_tours, ns = batched_two_opt_torch(
                        np_points_reshape[idx].astype("float64"),
                        np.array([tours[idx]]).astype("int64"),
                        max_iterations=model_args.two_opt_iterations,
                        device=device,
                        # batch=True,
                )
                solved_tours_list.append(solved_tours)
            else:
                solved_tours_list.append(tours[idx])

        # ---- Score the tours; keep the minimum cost per instance ----
        unsolved_tours_list = np.concatenate(unsolved_tours_list, axis=0)
        unsolved_tours_list = torch.tensor(unsolved_tours_list)
        unsolved_tours_list = unsolved_tours_list.view(model_args.sequential_sampling, batch_size, -1)
        
        solved_tours_list = np.concatenate(solved_tours_list, axis=0)
        solved_tours_list = torch.tensor(solved_tours_list)
        solved_tours_list = solved_tours_list.view(model_args.sequential_sampling, batch_size, -1)
        gt_costs, best_unsolved_costs, best_solved_costs = [], [], []

        for batch_idx in range(batch_size):
            tsp_solver = tsp_solvers[batch_idx]
            best_unsolved_costs.append(tsp_solver.evaluate(unsolved_tours_list[:,batch_idx]).min())
            best_solved_costs.append(tsp_solver.evaluate(solved_tours_list[:,batch_idx]).min())

        # gt_costs is never filled, so gt_cost is an empty list.
        gt_cost = gt_costs
        wo_2opt_costs = best_unsolved_costs
        best_solved_cost = best_solved_costs

    # NOTE(review): only `wo_2opt_costs` and `best_solved_costs` are rebound
    # to normalised values here; `best_solved_cost` and `best_unsolved_costs`
    # keep the raw costs -- confirm this is intended.
    if reward_gap:
        wo_2opt_costs = torch.tensor(wo_2opt_costs).to('cpu').detach().clone().flatten()/cost.to('cpu').detach().clone().flatten()
        best_solved_costs = torch.tensor(best_solved_costs).to('cpu').detach().clone().flatten()/cost.to('cpu').detach().clone().flatten()

    if model_args.reward_2opt:
        rewards = best_solved_costs
    else:
        rewards = best_unsolved_costs

    if inference :
        return gt_cost, wo_2opt_costs, best_solved_cost
    else :
        if sparse>0:
            np_edge_index = torch.from_numpy(np_edge_index)
        
        # Rewards are negated costs.
        return all_latents, np_edge_index, all_log_probs, - torch.tensor(rewards), all_time_steps, all_rb, new_target, all_aux_pred



def test_eval_mis(model, model_diffusion, model_args, batch, inference=False, sparse=False, decode_heatmap=False):
    """Evaluate the categorical diffusion model on one MIS instance.

    A random binary state is denoised for
    ``model_args.inference_diffusion_steps`` steps, the probabilities from the
    last step are decoded with ``mis_decode_np`` (once per parallel sample),
    and the largest decoded set size is kept. When ``model_args.rewrite`` is
    set, the best solution is repeatedly re-noised and denoised again for
    ``model_args.rewrite_steps`` rounds, and the best size seen is tracked.

    Args:
        model: Denoising network; must expose ``device`` (and
            ``module.duplicate_edge_index`` when
            ``model_args.parallel_sampling > 1``).
        model_diffusion: Diffusion object; ``Q_bar`` is read during rewriting.
        model_args: Options object. Attributes read here:
            ``parallel_sampling``, ``diffusion_type``,
            ``inference_diffusion_steps``, ``inference_schedule``,
            ``diffusion_steps``, ``rewrite`` and, when rewriting,
            ``rewrite_steps``, ``rewrite_ratio``, ``inference_steps``.
        batch: ``(idx, graph_data, point_indicator)``; ``graph_data`` must
            provide ``x`` and ``edge_index``.
        inference: Forwarded to ``categorical_denoise_step_mis``.
        sparse: Forwarded to ``categorical_denoise_step_mis``.
        decode_heatmap: Unused; kept for interface compatibility.

    Returns:
        ``([guided_best_cost], [best_cost])``: the best set size after the
        optional rewriting rounds, and the best size from the initial
        sampling.

    Raises:
        ValueError: If ``model_args.diffusion_type`` is ``'gaussian'``.
    """
    device = model.device
    _, graph_data, point_indicator = batch
    node_labels = graph_data.x
    edge_index = graph_data.edge_index.to(node_labels.device).reshape(2, -1)
    edge_index_np = edge_index.cpu().numpy()
    adj_mat = scipy.sparse.coo_matrix(
        (np.ones_like(edge_index_np[0]), (edge_index_np[0], edge_index_np[1])),
    )

    xt = torch.randn_like(node_labels.float()).to(device)

    if model_args.parallel_sampling > 1:
        edge_index = model.module.duplicate_edge_index(model_args.parallel_sampling, edge_index, node_labels.shape[0], device)

        # Independent noise for every parallel sample.
        xt = xt.repeat(model_args.parallel_sampling, 1, 1)
        xt = torch.randn_like(xt)

    if model_args.diffusion_type == 'gaussian':
        raise ValueError('Gaussian diffusion not supported for MIS')

    xt = (xt > 0).long().reshape(-1)

    steps = model_args.inference_diffusion_steps
    time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule,
                                      T=model_args.diffusion_steps, inference_T=steps)

    for i in range(steps):
        t1, t2 = time_schedule(i)
        t1 = np.array([t1]).astype(int)
        t2 = np.array([t2]).astype(int)

        xt, _, prob, _ = categorical_denoise_step_mis(
            model, model_diffusion, model_args, xt, t1, device, edge_index, target_t=t2,
            inference=inference, sparse=sparse, point_indicator=point_indicator)

    # Decode from the probabilities returned by the final denoising step,
    # one chunk per parallel sample.
    prob_labels = (prob.float().cpu().detach().numpy() + 1e-6).flatten()
    splitted_prob_labels = np.split(prob_labels, model_args.parallel_sampling)
    solved_solutions_prob = np.array(
        [mis_decode_np(labels, adj_mat) for labels in splitted_prob_labels])

    solved_costs = solved_solutions_prob.reshape([model_args.parallel_sampling, -1]).sum(axis=1)
    best_solved_cost = np.max(solved_costs)
    best_solved_id = np.argmax(solved_costs)

    g_best_solved_cost = best_solved_cost
    if model_args.rewrite:
        g_best_solution = solved_solutions_prob[best_solved_id]
        for _ in range(model_args.rewrite_steps):
            g_x0 = torch.from_numpy(g_best_solution).unsqueeze(0).to(device)
            g_x0 = F.one_hot(g_x0.long(), num_classes=2).float()

            steps_T = int(model_args.diffusion_steps * model_args.rewrite_ratio)
            steps_inf = model_args.inference_steps

            time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule,
                                              T=steps_T, inference_T=steps_inf)

            # Re-noise the current best solution up to timestep steps_T.
            Q_bar = model_diffusion.Q_bar[steps_T].float().to(g_x0.device)

            g_xt_prob = torch.matmul(g_x0, Q_bar)  # [B, N, 2]
            g_xt = torch.bernoulli(g_xt_prob[..., 1].clamp(0, 1)).to(g_x0.device)  # [B, N]
            g_xt = g_xt * 2 - 1  # project to [-1, 1]
            # The positive scaling does not change the sign tested below; the
            # call is kept so the random-number stream stays the same.
            g_xt = g_xt * (1.0 + 0.05 * torch.rand_like(g_xt))

            if model_args.parallel_sampling > 1:
                g_xt = g_xt.repeat(model_args.parallel_sampling, 1, 1)

            g_xt = (g_xt > 0).long().reshape(-1)
            for i in range(steps_inf):
                t1, t2 = time_schedule(i)
                t1 = np.array([t1]).astype(int)
                t2 = np.array([t2]).astype(int)

                g_xt, _, _, _ = categorical_denoise_step_mis(
                    model, model_diffusion, model_args, g_xt, t1, device, edge_index, target_t=t2,
                    inference=inference, sparse=sparse, point_indicator=point_indicator)

            g_predict_labels = g_xt.float().cpu().detach().numpy() + 1e-6

            g_splitted_predict_labels = np.split(g_predict_labels, model_args.parallel_sampling)
            g_solved_solutions = [mis_decode_np(labels, adj_mat) for labels in g_splitted_predict_labels]
            g_solved_costs = [solution.sum() for solution in g_solved_solutions]
            g_best_solved_cost = np.max([g_best_solved_cost, np.max(g_solved_costs)])

            # The next round restarts from this round's best solution.
            g_best_solution = g_solved_solutions[np.argmax(g_solved_costs)]

    return [g_best_solved_cost], [best_solved_cost]
    


def difusco_with_logprob_mis(model, model_diffusion, model_args, batch, inference=False, sparse=False, decode_heatmap=False):
    """Sample MIS solutions by reverse diffusion and record the sampling trajectory.

    The sampler is run ``model_args.sequential_sampling`` times from fresh random
    noise. The final state (``xt``) and the last predicted probabilities
    (``prob``) of every run are decoded with ``mis_decode_np``; only the decoding
    of the FIRST run is used for the costs/targets, while the recorded
    trajectory (latents, log-probs, time steps, aux predictions) is the one of
    the LAST run, because the trajectory lists are re-created on every run.

    Args:
        model: network passed to the denoise-step helpers; ``model.device`` is
            where the initial noise is placed.
        model_diffusion: diffusion object forwarded to the denoise-step helpers.
        model_args: configuration; the fields read here are ``diffusion_type``,
            ``sequential_sampling``, ``inference_diffusion_steps``,
            ``inference_schedule``, ``diffusion_steps``, ``use_critic`` and
            ``kl_aux``.
        batch: 3-tuple ``(batch index, graph_data, point_indicator)``. The index
            is unused. ``graph_data`` must expose ``x``, ``edge_length`` and
            ``edge_index``. ``point_indicator`` is iterated and its entries are
            used as consecutive slice lengths over the decoded node vector
            (presumably the node count of each graph in the batch -- confirm
            against the data loader).
        inference: selects the return value (see below) and is forwarded to the
            categorical denoise step.
        sparse: forwarded to the categorical denoise step.
        decode_heatmap: when true, rewards/targets come from the decoded
            ``prob`` heatmap; otherwise from the decoded final ``xt``.

    Returns:
        If ``inference``: ``(gt_cost, solved_costs_xt, solved_costs_prob)`` where
        only the list selected by ``decode_heatmap`` is populated and the other
        one is empty.
        Otherwise: ``(all_latents, edge_length, edge_index, all_log_probs,
        rewards, all_time_steps, new_target, all_aux_pred)`` with ``rewards``
        and ``new_target`` converted to tensors.
    """
    device = model.device
    _, graph_data, point_indicator = batch
    node_labels = graph_data.x  # originally for gt costs
    edge_length = graph_data.edge_length
    edge_index = graph_data.edge_index
    edge_index = edge_index.to(node_labels.device).reshape(2, -1)
    edge_index_np = edge_index.cpu().numpy()
    adj_mat = scipy.sparse.coo_matrix(
        (np.ones_like(edge_index_np[0]), (edge_index_np[0], edge_index_np[1])),
    )

    is_gaussian = model_args.diffusion_type == 'gaussian'
    stacked_xt_labels = []
    stacked_prob_labels = []

    for _ in range(model_args.sequential_sampling):
        # Training and inference start from the same kind of random noise.
        xt = torch.randn_like(node_labels.float()).to(device)

        if is_gaussian:
            xt.requires_grad = True
        else:
            # Categorical diffusion starts from random binary labels.
            xt = (xt > 0).long()

        xt = xt.reshape(-1)

        steps = model_args.inference_diffusion_steps
        time_schedule = InferenceSchedule(inference_schedule=model_args.inference_schedule,
                                          T=model_args.diffusion_steps, inference_T=steps)

        # Trajectory buffers are reset per run, so only the last run is returned.
        all_latents = [xt.clone().cpu().detach()]
        all_log_probs = []
        all_time_steps = []
        all_aux_pred = []
        for i in range(steps):
            t1, t2 = time_schedule(i)
            t1 = np.array([t1]).astype(int)
            t2 = np.array([t2]).astype(int)

            if is_gaussian:
                xt, log_prob, prob = gaussian_denoise_step_mis(
                    model, model_diffusion, xt, t1, device, edge_index, target_t=t2, point_indicator=point_indicator)
                # The gaussian step yields no auxiliary prediction; bind the name
                # so the bookkeeping below cannot hit an unbound local.
                aux_pred = None
            else:
                xt, log_prob, prob, _, aux_pred = categorical_denoise_step_mis(
                    model, model_diffusion, model_args, xt, t1, device, edge_index, target_t=t2,
                    inference=inference, sparse=sparse, point_indicator=point_indicator, aux=True)

            if log_prob is not None:
                wants_aux = model_args.use_critic > 0 or model_args.kl_aux > 0
                if wants_aux and aux_pred is not None:
                    all_aux_pred.append(aux_pred.cpu().detach())
                all_latents.append(xt.cpu().detach())
                all_log_probs.append(log_prob.cpu().detach())
                all_time_steps.append(torch.LongTensor([t1[0], t2[0]]))

        if is_gaussian:
            # Map from [-1, 1] to [0, 1].
            xt_labels = xt.float().cpu().detach().numpy() * 0.5 + 0.5
            prob_labels = prob.float().cpu().detach().numpy() * 0.5 + 0.5
        else:
            xt_labels = xt.float().cpu().detach().numpy() + 1e-6
            prob_labels = prob.float().cpu().detach().numpy() + 1e-6

        stacked_xt_labels.append(xt_labels)
        stacked_prob_labels.append(prob_labels)

    xt_labels = np.concatenate(stacked_xt_labels, axis=0)
    prob_labels = np.concatenate(stacked_prob_labels, axis=0)

    splitted_xt_labels = np.split(xt_labels, model_args.sequential_sampling)
    splitted_prob_labels = np.split(prob_labels, model_args.sequential_sampling)

    # Every run is decoded, but only the first run's solution is kept.
    solved_solutions_xt = np.array([mis_decode_np(labels, adj_mat) for labels in splitted_xt_labels])[0]
    solved_solutions_prob = np.array([mis_decode_np(labels, adj_mat) for labels in splitted_prob_labels])[0]

    gt_cost = node_labels.cpu().numpy().sum()

    solved_costs_xt = []
    solved_costs_prob = []
    if decode_heatmap:
        new_target, rewards = solved_solutions_prob, solved_costs_prob
    else:
        new_target, rewards = solved_solutions_xt, solved_costs_xt

    # `rewards` aliases the selected cost list, so appending fills it in place.
    idx = 0
    for length in point_indicator:
        rewards.append(new_target[idx:idx + length].sum())
        idx += length

    if inference:
        return gt_cost, solved_costs_xt, solved_costs_prob
    # reward is positive in MIS
    return (all_latents, edge_length, edge_index, all_log_probs, torch.tensor(rewards),
            all_time_steps, torch.tensor(new_target), all_aux_pred)
    

def _get_variance(diffusion, timestep, prev_timestep):
    """Return the variance used for the reverse step ``timestep -> prev_timestep``.

    With ``a_t = alphabar[timestep]`` and ``a_prev = alphabar[prev_timestep]``
    (``alphabar[0]`` is substituted wherever ``prev_timestep`` is negative):

    * where ``prev_timestep == 0``:  ``1 - a_t / a_prev``
    * elsewhere:  ``((1 - a_prev) / (1 - a_t)) * (1 - a_t / a_prev)``

    Args:
        diffusion: object exposing ``alphabar`` as a 1-D numpy array.
        timestep: integer tensor of current time steps; the result is placed on
            its device.
        prev_timestep: integer tensor of target time steps, same shape as
            ``timestep``.

    Returns:
        Tensor of variances, one per time-step pair.
    """
    alphabar = torch.from_numpy(diffusion.alphabar)
    out_device = timestep.device
    # alphabar lives on the CPU, so the lookups are done with CPU indices.
    t_idx = timestep.cpu()
    prev_idx = prev_timestep.cpu()

    alpha_prod_t = torch.gather(alphabar, 0, t_idx).to(out_device)
    # gather() runs before where() selects, so negative indices have to be
    # clamped into range or the lookup itself fails.
    alpha_prod_t_prev = torch.where(
        prev_idx >= 0,
        alphabar.gather(0, prev_idx.clamp(min=0)),
        alphabar[0],
    ).to(out_device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    step_term = 1 - alpha_prod_t / alpha_prod_t_prev
    # Element-wise selection keeps the function usable with batched time steps;
    # a plain `if prev_timestep == 0` only works for single-element tensors.
    return torch.where(
        (prev_idx == 0).to(out_device),
        step_term,
        (beta_prod_t_prev / beta_prod_t) * step_term,
    )

def tour2adj(tour, points, sparse, sparse_factor, edge_index):
    """Encode a tour as edge labels.

    Dense mode (``sparse`` false): returns a float ``(N, N)`` tensor, with
    ``N = points.shape[0]``, holding a 1 at ``[tour[i], tour[i + 1]]`` for every
    consecutive pair of the tour and 0 elsewhere.

    Sparse mode: builds a successor table (``successor[tour[i]] = tour[i + 1]``),
    repeats every entry ``sparse_factor`` times and compares it element-wise
    with ``edge_index[1]``; the result is an int tensor with a 1 for each listed
    edge whose target is the successor at the same position.
    """
    num_nodes = points.shape[0]

    if sparse:
        successor = np.zeros(num_nodes, dtype=np.int64)
        successor[tour[:-1]] = tour[1:]
        # One copy of each node's successor per outgoing edge slot.
        per_edge_successor = (
            torch.from_numpy(successor)
            .reshape((-1, 1))
            .repeat(1, sparse_factor)
            .reshape(-1)
        )
        return torch.eq(edge_index[1].cpu(), per_edge_successor).to(torch.int)

    dense = torch.zeros((num_nodes, num_nodes))
    for src, dst in zip(tour[:-1], tour[1:]):
        dense[src, dst] = 1
    return dense