
from dataclasses import dataclass
import torch
torch.set_default_device('cuda')
from TSProblemDef import get_random_problems, augment_xy_data_by_8_fold

from dataclasses import dataclass
import torch

@dataclass
class Reset_State:
    """Observation handed back by the environment when an episode is reset.

    ``problems`` is mandatory. Every other tensor is optional and remains
    ``None`` unless a caller supplies it.
    """

    problems: torch.Tensor
    batch_idx: torch.Tensor = None
    current_node: torch.Tensor = None
    selected_count: torch.Tensor = None
    ninf_mask: torch.Tensor = None
    finished: torch.Tensor = None

    def to(self, device):
        """Move the held tensors with ``Tensor.to(device)`` in place; return ``self``.

        ``problems`` is always converted; optional fields are converted only
        when they are not ``None``.
        """
        self.problems = self.problems.to(device)
        optional_fields = ('batch_idx', 'current_node', 'selected_count',
                           'ninf_mask', 'finished')
        for field_name in optional_fields:
            value = getattr(self, field_name)
            if value is None:
                continue
            setattr(self, field_name, value.to(device))
        return self
# @dataclass
# class Reset_State:
#     problems: torch.Tensor
#     # def to(self, device):
#     #     self.problems = self.problems.to(device)
#     #     self.batch_idx = self.batch_idx.to(device)
#     #     self.current_node = self.current_node.to(device)
#     #     self.selected_count = self.selected_count.to(device)
#     #     self.ninf_mask = self.ninf_mask.to(device)
#     #     self.finished = self.finished.to(device)
#     #     return self
#     # shape: (batch, problem, 2)


@dataclass
class Step_State:
    """Per-step observation that the environment keeps updating in place.

    Fields:
        BATCH_IDX, POMO_IDX: index grids, shape (batch, pomo).
        current_node: node most recently chosen by each rollout,
            shape (batch, pomo); ``None`` until the first step.
        ninf_mask: mask of shape (batch, pomo, node); entries for nodes a
            rollout has already chosen are set to -inf by the environment.
    """

    BATCH_IDX: torch.Tensor  # (batch, pomo)
    POMO_IDX: torch.Tensor  # (batch, pomo)
    current_node: torch.Tensor = None  # (batch, pomo); None before first step
    ninf_mask: torch.Tensor = None  # (batch, pomo, node)


class TSPEnv:
    """POMO-style Travelling Salesman environment.

    Each of ``pomo_size`` rollouts per instance builds a tour by choosing one
    node per ``step`` call. Once ``problem_size`` nodes have been chosen, the
    reward is the negative length of the closed tour.

    Expected ``env_params`` keys:
        problem_size: number of nodes per instance.
        pomo_size: number of parallel rollouts per instance.
        test_file_path (optional): path of a saved problem tensor. When the
            key is missing or its value is None, random problems are
            generated with ``get_random_problems`` instead.
    """

    def __init__(self, **env_params):

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.pomo_size = env_params['pomo_size']
        # Optional: a missing key behaves like an explicit None.
        self.test_file_path = env_params.get('test_file_path')

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.problems = None
        # shape: (batch, problem, 2)

        # Dynamic
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~problem)
        self.step_state = None
        # Created by reset(); pre_step()/step() must not be called before it.

    def load_problems(self, batch_size, aug_factor=1):
        """Load (or generate) a batch of problem instances and index grids.

        Args:
            batch_size: number of instances before augmentation. When a test
                file is configured, the stored tensor must hold exactly this
                many instances.
            aug_factor: 1 for no augmentation, or 8 to apply
                ``augment_xy_data_by_8_fold`` (multiplies the batch by 8).

        Raises:
            ValueError: if the loaded test data is not a tensor of shape
                (batch_size, problem_size, 2). Such data would otherwise
                fail much later, inside ``_get_travel_distance``.
            NotImplementedError: for any aug_factor > 1 other than 8.
        """
        self.batch_size = batch_size
        if self.test_file_path is not None:
            # NOTE: torch.load deserializes via pickle; only load trusted files.
            self.problems = torch.load(self.test_file_path)
            expected_shape = (batch_size, self.problem_size, 2)
            if not torch.is_tensor(self.problems):
                raise ValueError(
                    f"test file {self.test_file_path!r} must contain a tensor, "
                    f"got {type(self.problems).__name__}"
                )
            if tuple(self.problems.shape) != expected_shape:
                raise ValueError(
                    f"test file {self.test_file_path!r} holds problems of shape "
                    f"{tuple(self.problems.shape)}, expected {expected_shape}"
                )
        else:
            self.problems = get_random_problems(batch_size, self.problem_size)
        # problems.shape: (batch, problem, 2)

        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                self.problems = augment_xy_data_by_8_fold(self.problems)
                # shape: (8*batch, problem, 2)
            else:
                raise NotImplementedError

        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

    def reset(self):
        """Start a new episode on the currently loaded problems.

        Returns:
            (Reset_State(problems), None, False)
        """
        self.selected_count = 0
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~problem)

        # CREATE STEP STATE
        self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX)
        self.step_state.ninf_mask = torch.zeros((self.batch_size, self.pomo_size, self.problem_size))
        # shape: (batch, pomo, problem)

        reward = None
        done = False
        return Reset_State(self.problems), reward, done

    def pre_step(self):
        """Return the current step state without advancing (reward None, done False)."""
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        """Advance every rollout by one chosen node.

        Args:
            selected: node index chosen by each rollout, shape (batch, pomo).

        Returns:
            (step_state, reward, done). ``reward`` is None until
            ``problem_size`` nodes have been chosen. After that it is the
            negative tour length, shape (batch, pomo).
        """
        self.selected_count += 1
        self.current_node = selected
        # shape: (batch, pomo)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~problem)

        # UPDATE STEP STATE
        self.step_state.current_node = self.current_node
        # shape: (batch, pomo)
        # Mark the chosen node as unavailable for the rest of this rollout.
        self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf')
        # shape: (batch, pomo, node)

        # returning values
        done = (self.selected_count == self.problem_size)
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

    def _get_travel_distance(self):
        """Return the closed-tour length of every rollout, shape (batch, pomo).

        The visit order is moved onto the device of ``self.problems`` first,
        so the gather works even if the node list lives elsewhere.
        """
        device = self.problems.device  # the problem data's device is the reference

        gathering_index = self.selected_node_list.unsqueeze(3).expand(self.batch_size, -1, self.problem_size, 2).to(device)
        # shape: (batch, pomo, problem, 2)
        seq_expanded = self.problems[:, None, :, :].expand(self.batch_size, self.pomo_size, self.problem_size, 2)
        # shape: (batch, pomo, problem, 2)

        ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
        # Shifting by one pairs each node with its successor; the last node
        # wraps to the first, closing the tour.
        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, problem)
        travel_distances = segment_lengths.sum(2)
        # shape: (batch, pomo)

        return travel_distances


# from dataclasses import dataclass
# import torch

# from TSProblemDef import get_random_problems, augment_xy_data_by_8_fold


# @dataclass
# class Reset_State:
#     problems: torch.Tensor
#     # shape: (batch, problem, 2)


# @dataclass
# class Step_State:
#     BATCH_IDX: torch.Tensor
#     POMO_IDX: torch.Tensor
#     # shape: (batch, pomo)
#     current_node: torch.Tensor = None
#     # shape: (batch, pomo)
#     ninf_mask: torch.Tensor = None
#     # shape: (batch, pomo, node)


# class TSPEnv:
#     def __init__(self, **env_params):

#         # Const @INIT
#         ####################################
#         self.env_params = env_params
#         self.problem_size = env_params['problem_size']
#         self.pomo_size = env_params['pomo_size']
#         self.test_file_path = env_params['test_file_path']

#         # Const @Load_Problem
#         ####################################
#         self.batch_size = None
#         self.BATCH_IDX = None
#         self.POMO_IDX = None
#         # IDX.shape: (batch, pomo)
#         self.problems = None
#         # shape: (batch, node, node)

#         # Dynamic
#         ####################################
#         self.selected_count = None
#         self.current_node = None
#         # shape: (batch, pomo)
#         self.selected_node_list = None
#         # shape: (batch, pomo, 0~problem)

#     def load_problems(self, batch_size, aug_factor=1):
#         self.batch_size = batch_size
#         if self.test_file_path is not None:
#             self.problems = torch.load(self.test_file_path)
#         else:
#             self.problems = get_random_problems(batch_size, self.problem_size)
#         # problems.shape: (batch, problem, 2)

#         if aug_factor > 1:
#             if aug_factor == 8:
#                 self.batch_size = self.batch_size * 8
#                 self.problems = augment_xy_data_by_8_fold(self.problems)
#                 # shape: (8*batch, problem, 2)
#             else:
#                 raise NotImplementedError

#         self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
#         self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

#     def reset(self):
#         self.selected_count = 0
#         self.current_node = None
#         # shape: (batch, pomo)
#         self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
#         # shape: (batch, pomo, 0~problem)

#         # CREATE STEP STATE
#         self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX)
#         self.step_state.ninf_mask = torch.zeros((self.batch_size, self.pomo_size, self.problem_size))
#         # shape: (batch, pomo, problem)

#         reward = None
#         done = False
#         return Reset_State(self.problems), reward, done

#     def pre_step(self):
#         reward = None
#         done = False
#         return self.step_state, reward, done

#     # def step(self, selected):
#     #     # selected.shape: (batch, pomo)

#     #     self.selected_count += 1
#     #     self.current_node = selected
#     #     # shape: (batch, pomo)
#     #     self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
#     #     # shape: (batch, pomo, 0~problem)

#     #     # UPDATE STEP STATE
#     #     self.step_state.current_node = self.current_node
#     #     # shape: (batch, pomo)
#     #     self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf')
#     #     # shape: (batch, pomo, node)

#     #     # returning values
#     #     done = (self.selected_count == self.problem_size)
#     #     if done:
#     #         reward = -self._get_travel_distance()  # note the minus sign!
#     #     else:
#     #         reward = None

#     #     return self.step_state, reward, done

#     def step(self, selected):
#         """
#         执行环境的一步
        
#         Args:
#             selected: 选择的节点，形状为 (batch, pomo) 或可能具有更多维度
            
#         Returns:
#             state: 更新后的状态
#             reward: 如果完成，则为负旅行距离；否则为None
#             done: 是否完成
#         """
#         # 打印selected的维度以便调试
#         print(f"selected shape at beginning of step: {selected.shape}")
        
#         self.selected_count += 1
#         self.current_node = selected
        
#         # 确保current_node对后续操作有正确的维度
#         original_shape = self.current_node.shape
#         if self.current_node.dim() > 2:
#             # 对于处理selected_node_list，我们保留原始的多维current_node
#             current_node_for_list = self.current_node
            
#             # 对于其他用途，我们将其降维为2D
#             if self.current_node.dim() == 3:
#                 self.current_node = self.current_node[:, :, 0]
#             elif self.current_node.dim() == 4:
#                 self.current_node = self.current_node[:, :, 0, 0]
#             print(f"Reduced current_node from {original_shape} to {self.current_node.shape}")
#         else:
#             current_node_for_list = self.current_node
        
#         # 处理selected_node_list
#         if not hasattr(self, 'selected_node_list'):
#             # 初始化selected_node_list
#             if current_node_for_list.dim() == 2:
#                 # 对于2D current_node，创建3D列表
#                 self.selected_node_list = current_node_for_list[:, :, None]
#             else:
#                 # 对于更高维度，创建相应维度的列表
#                 self.selected_node_list = current_node_for_list.unsqueeze(2)
#             # 保存当前状态用于可能的恢复
#             self._prev_selected_node_list = self.selected_node_list.clone()
#         else:
#             # 连接到现有列表
#             try:
#                 # 尝试直接连接
#                 if current_node_for_list.dim() == self.selected_node_list.dim() - 1:
#                     # 如果current_node的维度比selected_node_list小1，添加一个维度
#                     current_node_expanded = current_node_for_list.unsqueeze(2)
#                     self.selected_node_list = torch.cat((self.selected_node_list, current_node_expanded), dim=2)
#                 elif current_node_for_list.dim() == self.selected_node_list.dim():
#                     # 如果维度相同，则检查是否需要在其他维度上扩展
#                     if current_node_for_list.shape[2:] == self.selected_node_list.shape[3:]:
#                         current_node_expanded = current_node_for_list.unsqueeze(2)
#                         self.selected_node_list = torch.cat((self.selected_node_list, current_node_expanded), dim=2)
#                     else:
#                         # 不兼容的形状，重新创建
#                         print(f"Warning: Incompatible shapes, recreating selected_node_list")
#                         self.selected_node_list = current_node_for_list.unsqueeze(2)
#                 else:
#                     # 不兼容的维度，重新创建
#                     print(f"Warning: Incompatible dimensions, recreating selected_node_list")
#                     self.selected_node_list = current_node_for_list.unsqueeze(2)
                
#                 # 保存当前状态用于可能的恢复
#                 self._prev_selected_node_list = self.selected_node_list.clone()
#             except RuntimeError as e:
#                 # 如果连接失败，打印错误并重新创建
#                 print(f"Error concatenating: {e}")
#                 print(f"selected_node_list shape: {self.selected_node_list.shape}")
#                 print(f"current_node_for_list shape: {current_node_for_list.shape}")
                
#                 # 重新初始化为所有节点的列表 - 使用problem_size而不是selected_count
#                 self.selected_node_list = torch.zeros(
#                     *current_node_for_list.shape[:2], self.problem_size, *current_node_for_list.shape[2:],
#                     device=current_node_for_list.device, dtype=current_node_for_list.dtype
#                 )
                
#                 # 复制已有的节点选择
#                 if hasattr(self, '_prev_selected_node_list') and self._prev_selected_node_list is not None:
#                     # 如果我们保存了之前的selected_node_list，复制它的内容
#                     prev_size = min(self._prev_selected_node_list.size(2), self.selected_count - 1)
#                     self.selected_node_list[:, :, :prev_size] = self._prev_selected_node_list[:, :, :prev_size]
                
#                 # 将当前节点放入最后位置
#                 if current_node_for_list.dim() == 2:
#                     self.selected_node_list[:, :, self.selected_count-1] = current_node_for_list
#                 elif current_node_for_list.dim() == 3:
#                     self.selected_node_list[:, :, self.selected_count-1, :] = current_node_for_list
#                 elif current_node_for_list.dim() == 4:
#                     self.selected_node_list[:, :, self.selected_count-1, :, :] = current_node_for_list
        
#         # 更新step_state
#         self.step_state.current_node = self.current_node
        
#         # 更新ninf_mask
#         try:
#             self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf')
#         except Exception as e:
#             print(f"Error updating ninf_mask: {e}")
#             print(f"current_node shape: {self.current_node.shape}")
#             print(f"ninf_mask shape: {self.step_state.ninf_mask.shape}")
        
#         # 返回值
#         done = (self.selected_count == self.problem_size)
#         if done:
#             # 确保在计算旅行距离前selected_node_list的大小正确
#             if self.selected_node_list.size(2) != self.problem_size:
#                 print(f"Warning: selected_node_list size {self.selected_node_list.size(2)} doesn't match problem_size {self.problem_size}")
#                 # 如果大小不匹配，调整它
#                 if self.selected_node_list.size(2) < self.problem_size:
#                     # 如果太小，填充它
#                     padded = torch.zeros(
#                         *self.selected_node_list.shape[:2], self.problem_size, *self.selected_node_list.shape[3:],
#                         device=self.selected_node_list.device, dtype=self.selected_node_list.dtype
#                     )
#                     padded[:, :, :self.selected_node_list.size(2)] = self.selected_node_list
#                     self.selected_node_list = padded
#                 else:
#                     # 如果太大，截断它
#                     self.selected_node_list = self.selected_node_list[:, :, :self.problem_size]
            
#             reward = -self._get_travel_distance()  # 注意负号！
#         else:
#             reward = None
        
#         return self.step_state, reward, done


#     def _get_travel_distance(self):
#         """
#         计算旅行距离
        
#         Returns:
#             所选节点路径的总旅行距离
#         """
#         # 获取当前selected_node_list的实际大小
#         actual_size = self.selected_node_list.size(2)
        
#         # 打印调试信息
#         print(f"_get_travel_distance - selected_node_list shape: {self.selected_node_list.shape}")
#         print(f"_get_travel_distance - problems shape: {self.problems.shape}")
        
#         # 使用实际大小而不是self.problem_size
#         try:
#             gathering_index = self.selected_node_list.unsqueeze(3).expand(self.batch_size, self.pomo_size, actual_size, 2)
#             # shape: (batch, pomo, actual_size, 2)
            
#             # 确保problems的第二个维度与actual_size匹配或可以扩展
#             if self.problems.size(1) >= actual_size:
#                 # 如果problems的大小大于等于actual_size，取前actual_size个元素
#                 seq_expanded = self.problems[:, None, :actual_size, :].expand(self.batch_size, self.pomo_size, actual_size, 2)
#             else:
#                 # 如果problems的大小小于actual_size，使用填充
#                 padded_problems = torch.zeros(self.batch_size, actual_size, 2, device=self.problems.device)
#                 padded_problems[:, :self.problems.size(1), :] = self.problems
#                 seq_expanded = padded_problems[:, None, :, :].expand(self.batch_size, self.pomo_size, actual_size, 2)
            
#             # 获取排序后的序列
#             ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
#             # shape: (batch, pomo, actual_size, 2)
            
#             # 计算段长度
#             rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
#             segment_lengths = ((ordered_seq-rolled_seq)**2).sum(3).sqrt()
#             # shape: (batch, pomo, actual_size)
            
#             # 计算总距离
#             travel_distances = segment_lengths.sum(2)
#             # shape: (batch, pomo)
            
#             print(f"_get_travel_distance - travel_distances shape: {travel_distances.shape}")
#             return travel_distances
#         except RuntimeError as e:
#             # 如果上述方法失败，尝试更直接的方法
#             print(f"Error in _get_travel_distance: {e}")
#             print(f"Trying alternative method...")
            
#             # 手动计算旅行距离
#             travel_distances = torch.zeros(self.batch_size, self.pomo_size, device=self.selected_node_list.device)
            
#             for batch_idx in range(self.batch_size):
#                 for pomo_idx in range(self.pomo_size):
#                     total_distance = 0.0
#                     for i in range(actual_size):
#                         # 当前节点和下一个节点（循环回到第一个）
#                         current_idx = self.selected_node_list[batch_idx, pomo_idx, i].item()
#                         next_idx = self.selected_node_list[batch_idx, pomo_idx, (i+1) % actual_size].item()
                        
#                         # 确保索引在有效范围内
#                         if current_idx < self.problems.size(1) and next_idx < self.problems.size(1):
#                             # 获取坐标
#                             current_xy = self.problems[batch_idx, current_idx]
#                             next_xy = self.problems[batch_idx, next_idx]
                            
#                             # 计算距离
#                             distance = torch.sqrt(((current_xy - next_xy) ** 2).sum())
#                             total_distance += distance
                    
#                     travel_distances[batch_idx, pomo_idx] = total_distance
            
#             return travel_distances


