from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    """Static description of a problem batch, populated by ``CVRPEnv.load_problems``.

    Only the three annotated attributes are dataclass fields. Everything else
    that ``load_problems`` attaches to this object is declared below as a plain
    class-level default (no annotation), so the generated ``__init__`` /
    ``__repr__`` / ``__eq__`` stay exactly as before while reading one of these
    attributes before a problem is loaded yields ``None`` instead of raising
    ``AttributeError``.
    """
    depot_xy: torch.Tensor = None
    # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None
    # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None
    # shape: (batch, problem)

    # ---- attached by CVRPEnv.load_problems (not dataclass fields) ----
    cs_xy = None
    # charging-station coordinates, concatenated after depot and customers
    whole_nodes = None
    # depot, customers and charging stations concatenated along dim -2
    distances = None
    # NEGATED pairwise Euclidean distances over depot + customers + stations
    node_sparse_mask = None
    cs_sparse_mask = None
    # boolean masks returned by CVRPEnv.compute_sparse_attention_mask
    shift_coors = None
    # depot + customer coordinates relative to the depot
    x_depot = None
    x_node = None
    # (r, theta, demand) polar features of the depot / the customers
    x_cs = None
    # (r, theta) polar features of the charging stations
    buckets = None
    sorted_indices = None
    # locality orderings of the customers and their bucketed form


@dataclass
class Step_State:
    """Per-step view of the environment that is handed to the policy.

    ``CVRPEnv`` refreshes the attributes in ``pre_step`` / ``step`` and also
    attaches further attributes dynamically in ``load_problems``.
    """
    BATCH_IDX: torch.Tensor = None
    POMO_IDX: torch.Tensor = None
    # shape: (batch, pomo)
    selected_count: int = None
    load: torch.Tensor = None
    # shape: (batch, pomo)
    current_node: torch.Tensor = None
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None
    old_ninf_mask: torch.Tensor = None
    # shape: (batch, pomo, number of selectable nodes)
    finished: torch.Tensor = None
    # shape: (batch, pomo)

    # The names below carry no annotation, so they are ordinary class
    # attributes rather than dataclass fields.
    probs = None
    distances = None
    # NEGATED pairwise distances, indexed as (batch, from_node, to_node);
    # "closest" therefore means "largest value" in the top-k calls below.
    cs_encoded_nodes = None
    reembeded_nodes = None

    def _distances_from(self, last_node_index):
        """Return the row of ``self.distances`` belonging to each last node.

        Args:
            last_node_index: integer tensor of shape (batch, pomo).

        Returns:
            Tensor of shape (batch, pomo, self.distances.size(-1)).
        """
        index = last_node_index.unsqueeze(-1).expand(
            last_node_index.size(0), last_node_index.size(1), self.distances.size(-1))
        return torch.gather(self.distances, 1, index)

    def get_closest_nodes(self, last_node_index, k):
        """Return the indices of the ``k`` nodes closest to each last node.

        The node itself (distance 0) is part of the candidates.
        """
        _, topk_index = torch.topk(self._distances_from(last_node_index), k)
        return topk_index

    def get_closest_cs(self, last_node_index, number_of_css, k=1):
        """Return the ``k`` closest entries among the last ``number_of_css`` columns.

        The returned indices are relative to that trailing slice (0 is the
        first of the last ``number_of_css`` nodes), not absolute node indices.
        """
        from_last = self._distances_from(last_node_index)
        _, topk_index = torch.topk(from_last[:, :, -number_of_css:], k)
        return topk_index

    def get_closest_node_list(self, last_node_index, ninf_mask, k=20):
        """Return the ``k`` closest nodes after adding ``ninf_mask``.

        Args:
            last_node_index: integer tensor of shape (batch, pomo).
            ninf_mask: additive mask broadcastable to (batch, pomo, nodes);
                entries of ``-inf`` push a node to the end of the ranking.
            k: number of neighbours to return (previously fixed at 20).
        """
        masked = self._distances_from(last_node_index) + ninf_mask
        _, topk_index = torch.topk(masked, k)
        return topk_index


class CVRPEnv:
    def __init__(self, **env_params):
        """Store the configuration and declare every piece of runtime state.

        Args:
            **env_params: must provide 'problem_size', 'pomo_size',
                'charging_station_size', 'bucket_size' and
                'number_of_locality_att'; a missing key raises ``KeyError``.

        All tensors are placeholders (``None``) until ``load_problems`` and
        ``reset`` have been called.
        """

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.pomo_size = env_params['pomo_size']
        self.charging_station_size = env_params['charging_station_size']
        self.bucket_size = env_params['bucket_size']
        self.number_of_locality_att = env_params['number_of_locality_att']
        # NOTE(review): load_problems() also reads self.node_top_k_count and
        # self.cs_top_k_count, which are not assigned in this constructor;
        # presumably they are set by the caller - confirm.

        self.FLAG__use_saved_problems = False
        self.saved_depot_xy = None
        self.saved_node_xy = None
        self.saved_cs_xy = None
        self.saved_node_demand = None
        self.saved_index = None

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.CS_IDX = None
        # shape: (batch, charging_station_size)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_cs_xy = None
        # depot, customers and charging stations concatenated along dim 1
        self.cs_xy = None
        # charging-station coordinates
        self.depot_node_demand = None
        # shape: (batch, problem+1)
        self.depot_node_cs_demand = None
        # depot_node_demand followed by one zero per charging station
        self.max_travel_distance = None
        # shape: (batch, 1)
        self.extended_cs_embedding = None

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.last_node_index = None
        # shape: (batch, pomo, 1)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.at_cs = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.charge = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+charging_station_size+1)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+1) after reset();
        # (batch, pomo, problem+charging_station_size+1) after step()
        self.finished = None
        # shape: (batch, pomo)
        self.distances = None
        self.whole_theta = None
        self.whole_dist = None
        self.probs = None

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

    def use_saved_problems(self, filename, device):
        self.FLAG__use_saved_problems = True

        loaded_dict = torch.load(filename, map_location=device)
        self.saved_depot_xy = loaded_dict['depot_xy']
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_cs_xy = loaded_dict['cs_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0

    def load_problems(self, batch_size, aug_factor=1):
        self.batch_size = batch_size

        if not self.FLAG__use_saved_problems:
            depot_xy, node_xy, cs_xy, node_demand = get_random_problems(batch_size, self.problem_size,
                                                                        self.charging_station_size - 1)
        else:
            depot_xy = self.saved_depot_xy[self.saved_index:self.saved_index + batch_size]
            node_xy = self.saved_node_xy[self.saved_index:self.saved_index + batch_size]
            cs_xy = self.saved_cs_xy[self.saved_index:self.saved_index + batch_size]
            node_demand = self.saved_node_demand[self.saved_index:self.saved_index + batch_size]
            self.saved_index += batch_size

        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                cs_xy = augment_xy_data_by_8_fold(cs_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError

        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        self.depot_node_cs_xy = torch.cat((depot_xy, node_xy, cs_xy), dim=1)
        self.cs_xy = cs_xy
        self.max_travel_distance = torch.ones(size=(self.batch_size, 1)) * 2

        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        cs_demand = torch.zeros(size=(self.batch_size, cs_xy.size(1)))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        self.depot_node_cs_demand = torch.cat((depot_demand, node_demand, cs_demand), dim=1)
        # shape: (batch, problem+1)

        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)
        self.CS_IDX = torch.arange(self.problem_size + 1, self.problem_size + self.charging_station_size + 1)[None, :].expand(
            self.batch_size, self.charging_station_size)

        self.reset_state.depot_xy = depot_xy
        self.reset_state.cs_xy = cs_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.whole_nodes = torch.cat((depot_xy, node_xy, cs_xy), -2)
        self.reset_state.node_demand = node_demand

        self.step_state.CS_IDX = self.CS_IDX
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX
        diff = self.depot_node_cs_xy.unsqueeze(2) - self.depot_node_cs_xy.unsqueeze(1)
        self.step_state.distances = - torch.norm(diff, dim=-1)
        self.reset_state.distances = self.step_state.distances

        diff = self.depot_node_xy.unsqueeze(2) - self.depot_node_xy.unsqueeze(1)
        self.step_state.node_distances = - torch.norm(diff, dim=-1)

        diff = self.cs_xy.unsqueeze(2) - self.cs_xy.unsqueeze(1)
        self.step_state.cs_distances = - torch.norm(diff, dim=-1)

        self.step_state.node_sparse_mask, self.step_state.node_top_k_indices = self.compute_sparse_attention_mask(self.depot_node_xy, top_k=self.node_top_k_count)
        self.step_state.cs_sparse_mask, self.step_state.cs_top_k_indices = self.compute_sparse_attention_mask(self.cs_xy, top_k=self.cs_top_k_count)
        self.reset_state.node_sparse_mask = self.step_state.node_sparse_mask
        self.reset_state.cs_sparse_mask = self.step_state.cs_sparse_mask

        shift_coors = self.depot_node_cs_xy - self.depot_node_cs_xy[:, 0:1, :]
        _x, _y = shift_coors[:, :, 0], shift_coors[:, :, 1]
        r = torch.sqrt(_x ** 2 + _y ** 2)
        theta = torch.atan2(_y, _x)
        x = torch.stack((r, theta, self.depot_node_cs_demand), -1)
        x_depot = x[:, 0:1, :]
        x_node = x[:, 1:-self.charging_station_size, :]
        x_cs = x[:, -self.charging_station_size:, 0:2]
        shift_coors_without_cs = self.depot_node_xy - self.depot_node_xy[:, 0:1, :]
        self.reset_state.shift_coors = shift_coors_without_cs
        self.reset_state.x_depot = x_depot
        self.reset_state.x_node = x_node
        self.reset_state.x_cs = x_cs

        r_min = r[:,1:-self.charging_station_size].min(dim=-1, keepdim=True).values
        r_max = r[:,1:-self.charging_station_size].max(dim=-1, keepdim=True).values
        r_norm = (r[:,1:-self.charging_station_size] - r_min) / (r_max - r_min + 1e-8)
        theta2 = theta % (2 * torch.pi)
        theta_min = theta2[:,1:-self.charging_station_size].min(dim=-1, keepdim=True).values
        theta_max = theta2[:,1:-self.charging_station_size].max(dim=-1, keepdim=True).values
        theta_norm = (theta2[:,1:-self.charging_station_size] - theta_min) / (theta_max - theta_min + 1e-8)

        weights = torch.linspace(0, 1, steps=self.number_of_locality_att)  # shape: (n,)
        weight_A = weights.view(1, self.number_of_locality_att, 1)  # shape: (n, 1, 1)
        weight_B = (1 - weights).view(1, self.number_of_locality_att, 1)  # shape: (n, 1, 1)

        # Step 2: Apply weights and sum A and B
        # Expand A and B to shape: (1, 2, 111) → broadcast to (n, 2, 111)
        A_exp = r_norm.unsqueeze(1)  # (1, 2, 111)
        B_exp = theta_norm.unsqueeze(1)
        eee = weight_A * A_exp
        output = (weight_A * A_exp + weight_B * B_exp)  # shape: (n, 2, 111)
        _, sorted_indices = torch.sort(output, dim=-1)
        half_bucket = int(self.bucket_size/2)
        x_reordered = torch.cat((sorted_indices[:, :, half_bucket:], sorted_indices[:, :, :half_bucket]), dim=-1)
        sorted_indices = torch.cat((sorted_indices, x_reordered), dim=1)

        total = sorted_indices.shape[-1]
        num_buckets = total // self.bucket_size
        sorted_indices = sorted_indices[:, :, :num_buckets * self.bucket_size]  # remove extra
        bucketed_indices = sorted_indices.view(self.batch_size, self.number_of_locality_att*2, num_buckets, self.bucket_size) +1
        zero_padding = torch.zeros(bucketed_indices.size(0), bucketed_indices.size(1), bucketed_indices.size(2), 1, dtype=bucketed_indices.dtype, device=bucketed_indices.device)
        bucketed_indices = torch.cat([zero_padding, bucketed_indices], dim=-1)
        # shift_coors = self.depot_node_cs_xy.unsqueeze(2) - self.depot_node_cs_xy.unsqueeze(1)
        # r = torch.sqrt(torch.sum(shift_coors ** 2, dim=-1))  # Shape: (2, 100, 100)
        # theta = torch.atan2(shift_coors[..., 1], shift_coors[..., 0])
        self.reset_state.buckets = bucketed_indices
        self.reset_state.sorted_indices = sorted_indices
        self.step_state.xy = self.depot_node_cs_xy
        self.step_state.theta = theta
        self.step_state.r = r
        self.step_state.norm_demand = self.depot_node_cs_demand
        # self.step_state.whole_theta = theta[:, None, :, :].expand(theta.size(0), self.pomo_size, theta.size(1),
        #                                                           theta.size(2))
        # self.step_state.whole_dist = r[:, None, :, :].expand(r.size(0), self.pomo_size, r.size(1), r.size(2))

    def reset(self):
        self.selected_count = 0
        self.current_node = None
        self.last_node_index = torch.zeros((self.batch_size, self.pomo_size, 1), dtype=torch.long)
        # shape: (batch, pomo)
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size + self.charging_station_size + 1),
                                       dtype=torch.bool)
        # shape: (batch, pomo)
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)
        self.charge = torch.ones(size=(self.batch_size, self.pomo_size))

        # shape: (batch, pomo)
        self.visited_ninf_flag = torch.zeros(
            size=(self.batch_size, self.pomo_size, self.problem_size + self.charging_station_size + 1))
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size + 1))
        # shape: (batch, pomo, problem+1)
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)

        self.distances = None
        # shape: (batch, pomo)

        reward = None
        done = False
        return self.reset_state, reward, done

    def compute_sparse_attention_mask(self, coords, top_k=20):
        batch_size, num_nodes, _ = coords.shape

        # Compute pairwise Euclidean distance (batch-wise)
        coords_exp = coords.unsqueeze(2)  # Shape: (batch_size, num_nodes, 1, 2)
        distances = torch.norm(coords_exp - coords.unsqueeze(1), dim=-1)  # (batch_size, num_nodes, num_nodes)

        # Get indices of the top_k closest nodes (excluding itself)
        top_k_indices = distances.argsort(dim=-1)[:, :, 1:top_k + 1]  # Shape: (batch_size, num_nodes, top_k)

        # Create a full mask initialized to -inf
        mask = torch.full((batch_size, num_nodes, num_nodes), True, device=coords.device)

        # Create an index mask to select top_k indices without looping
        batch_indices = torch.arange(batch_size, device=coords.device).view(batch_size, 1, 1)
        node_indices = torch.arange(num_nodes, device=coords.device).view(1, num_nodes, 1)

        # Assign 0 to the top_k closest neighbors
        mask[batch_indices, node_indices, top_k_indices] = False
        mask[:, 0, :] = False
        mask[:, :, 0] = False
        # mask[:, -self.charging_station_size:, :] = 0
        # mask[:, :, -self.charging_station_size:] = 0
        return mask, top_k_indices

    def pre_step(self):
        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.charge = self.charge
        self.step_state.old_charge = self.charge
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished
        self.step_state.if_cannot_make_travel_finished = None
        self.step_state.selected = None
        self.step_state.current_charge = None
        self.step_state.distance = None
        self.step_state.distance = None
        self.step_state.distance = None
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        # selected.shape: (batch, pomo)

        # Dynamic-1
        ####################################
        self.selected_count += 1
        self.current_node = selected
        factor = 10_00
        last_xy = torch.gather(self.depot_node_cs_xy, 1,
                               self.last_node_index.expand(self.last_node_index.size(0), self.last_node_index.size(1),
                                                           2))
        current_xy = torch.gather(self.depot_node_cs_xy, 1,
                                  selected[..., None].expand(selected.size(0), selected.size(1), 2))
        traveled_distance = torch.norm(last_xy - current_xy, dim=2)
        # zero_dist = (traveled_distance < 0.0001)
        # traveled_distance[zero_dist] = 0
        traveled_distance = (traveled_distance * factor).int().float() / factor
        required_energy_for_last_move = traveled_distance / self.max_travel_distance
        required_energy_for_last_move = (required_energy_for_last_move * factor).int().float() / factor
        self.charge = self.charge - required_energy_for_last_move

        # self.charge = (self.charge * factor).int().float() / factor
        # if torch.any(self.charge < -0.0001):
        #     raise ValueError(" Charge contains negative elements.")

        if torch.any(self.load < -0.00001):
            raise ValueError(" load contains negative elements.")

        self.last_node_index = selected.clone().unsqueeze(-1)
        # shape: (batch, pomo)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = (selected == 0)
        self.at_cs = torch.any(selected.unsqueeze(-1) == self.CS_IDX.unsqueeze(1), dim=-1)
        # charge_list = self.charge[:,None,:].expand(self.batch_size, self.pomo_size, -1)

        demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        demand_list2 = self.depot_node_cs_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        # shape: (batch, pomo, problem+1)
        gathering_index = selected[:, :, None]
        # shape: (batch, pomo, 1)
        selected_demand = demand_list2.gather(dim=2, index=gathering_index).squeeze(dim=2)
        # shape: (batch, pomo)
        self.load -= selected_demand
        self.load[self.at_the_depot] = 1  # refill loaded at the depot
        self.charge[self.at_the_depot] = 1
        self.charge[self.at_cs] = 1
        # ---------------------------------------------------------------------
        factor = 10_00
        # distance_from_current_to_all_others = torch.cdist(current_xy, self.depot_node_cs_xy, p=2)
        diff = current_xy.unsqueeze(2) - self.depot_node_cs_xy.unsqueeze(1)  # Shape: (2, 100, 106, 2)
        distance_from_current_to_all_others = torch.norm(diff, dim=-1)
        # if not torch.equal(distance_from_current_to_all_others, c):
        #     print("sdfasdfasdfa")
        # distance_from_all_nodes_to_css, _ = torch.min(torch.cdist(self.depot_node_cs_xy, self.cs_xy, p=2), -1)
        # xx = torch.cdist(self.depot_node_cs_xy, self.cs_xy, p=2)
        diff2 = self.depot_node_cs_xy.unsqueeze(2) - self.cs_xy.unsqueeze(1)  # Shape: (2, 100, 106, 2)
        xxx = torch.norm(diff2, dim=-1)
        # if not torch.equal(xx, xxx):
        #     print("sdfsadfsadfsadf")
        distance_from_all_nodes_to_css, _ = torch.min(xxx, -1)
        # zero_dist = (distance_from_current_to_all_others < 0.0001)
        # distance_from_current_to_all_others[zero_dist] = 0
        distance_from_current_to_all_others = (distance_from_current_to_all_others * factor).int().float() / factor

        # zero_dist = (distance_from_all_nodes_to_css < 0.0001)
        # distance_from_all_nodes_to_css[zero_dist] = 0
        distance_from_all_nodes_to_css = (distance_from_all_nodes_to_css * factor).int().float() / factor

        # distance_to_next_node_and_cs = distance_from_current_to_all_others + distance_from_all_nodes_to_css[:, None,
        #                                                                      :].expand_as(distance_from_current_to_all_others)
        #
        # distance_from_all_nodes_to_css = distance_from_all_nodes_to_css[:, None, :].expand_as(
        #     distance_from_current_to_all_others)
        #
        # required_energy_for_last_move1 = distance_from_current_to_all_others / self.max_travel_distance[:,None,:]
        # required_energy_for_last_move1 = (required_energy_for_last_move1 * factor).int().float() / factor
        #
        # required_energy_for_last_move2 = distance_from_all_nodes_to_css / self.max_travel_distance[:,None,:]
        #
        # required_energy_for_last_move2 = (required_energy_for_last_move2 * factor).int().float() / factor
        # current_charge = self.charge.clone()
        # current_charge = current_charge[...,None].expand_as(required_energy_for_last_move1) - required_energy_for_last_move1 - required_energy_for_last_move2
        # current_charge = (current_charge * factor).int().float() / factor
        # if_cannot_make_travel_finished = current_charge < 0

        distance_to_next_node_and_cs = distance_from_current_to_all_others + distance_from_all_nodes_to_css[:, None,
                                                                             :].expand_as(
            distance_from_current_to_all_others)
        aaa = ((self.charge) * self.max_travel_distance).unsqueeze(-1)
        # aaa = (aaa * factor).int().float() / factor
        if_cannot_make_travel_finished1 = (distance_to_next_node_and_cs > aaa - 0.0001)
        if_cannot_make_travel_finished2 = (distance_to_next_node_and_cs > aaa + 0.0001)
        # if_cannot_make_travel_finished3 = (distance_to_next_node_and_cs-0.0001 > aaa)
        # if_cannot_make_travel_finished4 = (distance_to_next_node_and_cs+0.0001 > aaa)
        if_cannot_make_travel_finished = if_cannot_make_travel_finished1 & if_cannot_make_travel_finished2  # & if_cannot_make_travel_finished3 & if_cannot_make_travel_finished4
        # --------------------------------------------------

        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, self.problem_size + 1:] = 0
        # shape: (batch, pomo, problem+1)
        self.visited_ninf_flag[:, :, 0][
            ~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot
        cs_extended = self.at_cs.unsqueeze(-1).expand(-1, -1, self.charging_station_size)
        depot_extended = self.at_the_depot.unsqueeze(-1).expand(-1, -1, self.charging_station_size)
        overall_extended = cs_extended | depot_extended
        self.visited_ninf_flag[:, :, self.problem_size + 1:][overall_extended] = float('-inf')
        # print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")
        # for i in range(2):
        #     print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        #     for j in range(100):
        #         print(self.visited_ninf_flag[i,j,:])

        self.ninf_mask = self.visited_ninf_flag.clone()
        round_error_epsilon = 0.00001
        demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
        # shape: (batch, pomo, problem+1)
        self.ninf_mask[:, :, :self.problem_size + 1][demand_too_large] = float('-inf')
        # self.ninf_mask[self.BATCH_IDX,self.POMO_IDX, :self.pomo_size+1][demand_too_large] = float('-inf')
        self.ninf_mask[if_cannot_make_travel_finished] = float('-inf')

        # shape: (batch, pomo, problem+1)
        # if torch.all(self.ninf_mask == float('-inf')):
        #     print("All elements in tensor mask are -inf.")
        #     if torch.all(self.visited_ninf_flag== float('-inf')):
        #         print("All elements in visited_ninf_flag mask are -inf.")
        #     if torch.all(if_cannot_make_travel_finished== True):
        #         print("All elements in if_cannot_make_travel_finished mask are -inf.")

        # raise ValueError("All elements in tensor mask are -inf.")
        newly_finished = (
                self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, 0: self.problem_size] == float('-inf')).all(
            dim=2)
        # shape: (batch, pomo)
        self.finished = self.finished + newly_finished
        # shape: (batch, pomo)

        # do not mask depot for finished episode.
        self.ninf_mask[:, :, 0][self.finished] = 0

        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.current_node = self.current_node
        self.step_state.old_ninf_mask = self.step_state.ninf_mask
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished
        self.step_state.if_cannot_make_travel_finished = if_cannot_make_travel_finished
        # self.step_state.old_charge = self.step_state.charge
        self.step_state.charge = self.charge
        # self.step_state.old_aaa = self.step_state.aaa
        # self.step_state.aaa = aaa
        # self.step_state.old_distance_can_travel = self.step_state.distance_can_travel
        # self.step_state.distance_can_travel = distance_to_next_node_and_cs
        # self.step_state.distance_from_current_to_all_others = distance_from_current_to_all_others
        # self.step_state.distance_from_all_nodes_to_css = distance_from_all_nodes_to_css
        # self.step_state.current_charge = current_charge

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

    def _get_travel_distance(self):
        gathering_index = self.selected_node_list[:, :, :, None].expand(-1, -1, -1, 2)
        # for i in range(self.selected_node_list.size(0)):
        #     for j in range(self.selected_node_list.size(1)):
        #         print(self.selected_node_list[i,j,:])
        # shape: (batch, pomo, selected_list_length, 2)
        all_xy = self.depot_node_cs_xy[:, None, :, :].expand(-1, self.pomo_size, -1, -1)
        # shape: (batch, pomo, problem+1, 2)

        ordered_seq = all_xy.gather(dim=2, index=gathering_index)
        # shape: (batch, pomo, selected_list_length, 2)

        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, selected_list_length)

        travel_distances = segment_lengths.sum(2)
        # min_value, min_index = torch.min(travel_distances, 1)
        # k=-1
        # for i in min_index:
        #     k+=1
        #     print(self.selected_node_list[k, i, ...])
        # shape: (batch, pomo)
        return travel_distances
