from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    """Static problem description handed out by ``CVRPEnv.reset``.

    Only the three annotated names are dataclass fields. ``CVRPEnv.load_problems``
    additionally attaches ``is_xy``, ``whole_nodes``, ``distances``,
    ``shift_coors``, ``x_depot``, ``x_node``, ``x_is``, ``buckets`` and
    ``sorted_indices`` to the instance at runtime.
    """

    depot_xy: torch.Tensor = None  # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None  # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None  # shape: (batch, problem)


@dataclass
class Step_State:
    """Per-step observation passed from CVRPEnv to the policy.

    Only the annotated names below are dataclass fields. ``probs``,
    ``distances``, ``is_encoded_nodes`` and ``reembeded_nodes`` are plain class
    attributes (default ``None``). CVRPEnv also attaches further attributes
    (``distances``, ``xy``, ``remaining_travel_distance``, ...) to instances at
    runtime.
    """
    BATCH_IDX: torch.Tensor = None
    POMO_IDX: torch.Tensor = None
    # shape: (batch, pomo)
    selected_count: int = None
    load: torch.Tensor = None
    # shape: (batch, pomo)
    current_node: torch.Tensor = None
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None
    old_ninf_mask: torch.Tensor = None
    # shape: (batch, pomo, problem+1) after CVRPEnv.reset,
    #        (batch, pomo, problem+IS+1) after CVRPEnv.step
    finished: torch.Tensor = None
    # shape: (batch, pomo)

    # Not dataclass fields (no annotation); overwritten per instance.
    probs = None
    distances = None
    is_encoded_nodes = None
    reembeded_nodes = None

    def _distance_rows(self, last_node_index):
        """Return the row of ``self.distances`` for every current node.

        ``self.distances`` is indexed as (batch, num_nodes, num_nodes) and
        ``last_node_index`` as a 2-D (batch, pomo) long tensor. The result has
        shape (batch, pomo, num_nodes). CVRPEnv stores *negated* Euclidean
        distances here, so larger values mean closer nodes.
        """
        batch, pomo = last_node_index.size(0), last_node_index.size(1)
        index = last_node_index.unsqueeze(-1).expand(batch, pomo, self.distances.size(-1))
        return torch.gather(self.distances, 1, index)

    def get_closest_nodes(self, last_node_index, k):
        """Return indices of the ``k`` largest entries of each distance row.

        With negated distances these are the ``k`` nearest nodes, including
        the node itself. Shape: (batch, pomo, k).
        """
        _, topk_index = torch.topk(self._distance_rows(last_node_index), k)
        return topk_index

    def get_closest_is(self, last_node_index, number_of_iss, k=1):
        """Like ``get_closest_nodes``, restricted to the last ``number_of_iss`` columns.

        The returned indices are relative to that trailing slice (0 is the first
        of the last ``number_of_iss`` nodes), not absolute node ids.
        """
        rows = self._distance_rows(last_node_index)
        _, topk_index = torch.topk(rows[:, :, -number_of_iss:], k)
        return topk_index

    def get_closest_node_list(self, last_node_index, ninf_mask, k=20):
        """Return the top-``k`` nodes after adding ``ninf_mask`` to the distance rows.

        Entries set to ``-inf`` in ``ninf_mask`` rank last, so they are excluded
        whenever at least ``k`` unmasked nodes remain.

        Args:
            last_node_index: (batch, pomo) long tensor of current nodes.
            ninf_mask: additive mask broadcastable to (batch, pomo, num_nodes).
            k: number of neighbours (previously hard-coded to 20, which is
                still the default). Must not exceed ``num_nodes``.
        """
        masked_rows = self._distance_rows(last_node_index) + ninf_mask
        _, topk_index = torch.topk(masked_rows, k)
        return topk_index


class CVRPEnv:
    """POMO-style CVRP environment with intermediate stops (IS).

    Node numbering used throughout:
        0                                     -> depot
        1 .. problem_size                     -> customers
        problem_size+1 .. problem_size+IS     -> intermediate stops

    Visiting the depot refills the load and resets the remaining travel budget.
    Visiting an intermediate stop refills only the load.

    Required ``env_params`` keys: ``problem_size``, ``pomo_size``,
    ``intermediate_stop_size``, ``bucket_size``, ``number_of_locality_att``.
    Optional: ``max_travel_distance`` (default 4), the per-route distance
    budget that travelled distances are normalised by.
    """

    # Distances are truncated to 3 decimals (x * 1000 -> int -> / 1000)
    # before they are charged against / compared with the travel budget.
    _ROUND_FACTOR = 1000

    def __init__(self, **env_params):
        """Store configuration; problem tensors are created by ``load_problems``."""

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.pomo_size = env_params['pomo_size']
        self.intermediate_stop_size = env_params['intermediate_stop_size']
        self.bucket_size = env_params['bucket_size']
        self.number_of_locality_att = env_params['number_of_locality_att']
        # Per-route travel budget (previously hard-coded to 4, still the default).
        self.max_travel_distance_value = env_params.get('max_travel_distance', 4)

        self.FLAG__use_saved_problems = False
        self.saved_depot_xy = None
        self.saved_node_xy = None
        self.saved_is_xy = None
        self.saved_node_demand = None
        self.saved_index = None

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.IS_IDX = None
        # shape: (batch, IS)
        self.depot_xy = None
        self.is_xy = None
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_is_xy = None
        # shape: (batch, problem+IS+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)
        self.depot_node_is_demand = None
        # shape: (batch, problem+IS+1)
        self.max_travel_distance = None
        # shape: (batch, 1)
        self.extended_is_embedding = None

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.last_node_index = None
        # shape: (batch, pomo, 1)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        self.at_is = None
        # shape: (batch, pomo)
        self.load = None
        self.remaining_travel_distance = None
        # shape: (batch, pomo)
        self.distance_can_travel = None
        self.distance_from_all_nodes_to_depot = None
        # shape: (batch, pomo, problem+IS+1)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+IS+1)
        self.ninf_mask = None
        self.finished = None
        # shape: (batch, pomo)
        self.distances = None
        self.whole_theta = None
        self.whole_dist = None
        self.probs = None

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

    def use_saved_problems(self, filename, device):
        """Switch to reading instances sequentially from a saved file.

        The file must hold a dict with keys ``depot_xy``, ``node_xy``, ``is_xy``
        and ``node_demand``. Note: ``torch.load`` unpickles the file, so only
        load files from trusted sources.
        """
        self.FLAG__use_saved_problems = True

        loaded_dict = torch.load(filename, map_location=device)
        self.saved_depot_xy = loaded_dict['depot_xy']
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_is_xy = loaded_dict['is_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_index = 0

    @classmethod
    def _truncate(cls, values):
        """Truncate non-negative distances to 3 decimals (toward zero)."""
        return (values * cls._ROUND_FACTOR).int().float() / cls._ROUND_FACTOR

    @staticmethod
    def _pairwise_distances(xy):
        """Return the Euclidean distance between every pair of points: (batch, n, n)."""
        return torch.norm(xy.unsqueeze(2) - xy.unsqueeze(1), dim=-1)

    def _build_locality_buckets(self, r_nodes, theta_nodes):
        """Group customers into fixed-size buckets by blended polar order.

        ``r_nodes`` / ``theta_nodes`` are the (batch, problem) polar coordinates
        of the customers relative to the depot. For each of
        ``number_of_locality_att`` weights ``w`` evenly spaced in [0, 1],
        customers are sorted by ``w * r_norm + (1 - w) * theta_norm``. Each
        ordering is also added rotated left by half a bucket. Orderings are cut
        into ``bucket_size`` chunks; a trailing remainder that does not fill a
        bucket is dropped. Indices are shifted by +1 to node numbering, and the
        depot (index 0) is prepended to every bucket.

        Returns:
            (buckets, sorted_indices) with shapes
            (batch, 2*n_att, num_buckets, bucket_size+1) and
            (batch, 2*n_att, num_buckets*bucket_size).
        """
        r_min = r_nodes.min(dim=-1, keepdim=True).values
        r_max = r_nodes.max(dim=-1, keepdim=True).values
        r_norm = (r_nodes - r_min) / (r_max - r_min + 1e-8)
        theta2 = theta_nodes % (2 * torch.pi)
        theta_min = theta2.min(dim=-1, keepdim=True).values
        theta_max = theta2.max(dim=-1, keepdim=True).values
        theta_norm = (theta2 - theta_min) / (theta_max - theta_min + 1e-8)

        weights = torch.linspace(0, 1, steps=self.number_of_locality_att)
        weight_r = weights.view(1, self.number_of_locality_att, 1)
        weight_theta = (1 - weights).view(1, self.number_of_locality_att, 1)
        blended = weight_r * r_norm.unsqueeze(1) + weight_theta * theta_norm.unsqueeze(1)
        # shape: (batch, n_att, problem)
        _, sorted_indices = torch.sort(blended, dim=-1)

        half_bucket = int(self.bucket_size / 2)
        rotated = torch.cat((sorted_indices[:, :, half_bucket:], sorted_indices[:, :, :half_bucket]), dim=-1)
        sorted_indices = torch.cat((sorted_indices, rotated), dim=1)
        # shape: (batch, 2*n_att, problem)

        num_buckets = sorted_indices.shape[-1] // self.bucket_size
        sorted_indices = sorted_indices[:, :, :num_buckets * self.bucket_size]  # drop the incomplete tail
        bucketed = sorted_indices.reshape(self.batch_size, self.number_of_locality_att * 2,
                                          num_buckets, self.bucket_size) + 1
        depot_column = torch.zeros(bucketed.size(0), bucketed.size(1), bucketed.size(2), 1,
                                   dtype=bucketed.dtype, device=bucketed.device)
        bucketed = torch.cat([depot_column, bucketed], dim=-1)
        return bucketed, sorted_indices

    def load_problems(self, batch_size, aug_factor=1):
        """Sample (or read saved) instances and precompute static features.

        Args:
            batch_size: number of instances to load (before augmentation).
            aug_factor: 1 (none) or 8 (8-fold coordinate augmentation, which
                multiplies the effective batch size by 8).

        Raises:
            NotImplementedError: for any other ``aug_factor`` > 1.
        """
        self.batch_size = batch_size

        if not self.FLAG__use_saved_problems:
            depot_xy, node_xy, is_xy, node_demand = get_random_problems(batch_size, self.problem_size,
                                                                        self.intermediate_stop_size)
        else:
            window = slice(self.saved_index, self.saved_index + batch_size)
            depot_xy = self.saved_depot_xy[window]
            node_xy = self.saved_node_xy[window]
            is_xy = self.saved_is_xy[window]
            node_demand = self.saved_node_demand[window]
            self.saved_index += batch_size

        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                depot_xy = augment_xy_data_by_8_fold(depot_xy)
                node_xy = augment_xy_data_by_8_fold(node_xy)
                is_xy = augment_xy_data_by_8_fold(is_xy)
                node_demand = node_demand.repeat(8, 1)
            else:
                raise NotImplementedError

        self.depot_xy = depot_xy
        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        self.depot_node_is_xy = torch.cat((depot_xy, node_xy, is_xy), dim=1)
        # shape: (batch, problem+IS+1, 2)
        self.is_xy = is_xy
        self.max_travel_distance = torch.ones(size=(self.batch_size, 1)) * self.max_travel_distance_value
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        is_demand = torch.zeros(size=(self.batch_size, is_xy.size(1)))
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)
        self.depot_node_is_demand = torch.cat((depot_demand, node_demand, is_demand), dim=1)
        # shape: (batch, problem+IS+1)

        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)
        self.IS_IDX = torch.arange(self.problem_size + 1,
                                   self.problem_size + self.intermediate_stop_size + 1)[None, :].expand(
            self.batch_size, self.intermediate_stop_size)

        self.reset_state.depot_xy = depot_xy
        self.reset_state.is_xy = is_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.whole_nodes = torch.cat((depot_xy, node_xy, is_xy), -2)
        self.reset_state.node_demand = node_demand

        self.step_state.Is_IDX = self.IS_IDX
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX
        # Negated distances: larger value == closer node.
        self.step_state.distances = -self._pairwise_distances(self.depot_node_is_xy)
        self.reset_state.distances = self.step_state.distances
        self.step_state.node_distances = -self._pairwise_distances(self.depot_node_xy)
        self.step_state.is_distances = -self._pairwise_distances(self.is_xy)

        # Distance from every node (depot, customers, IS) back to the depot.
        to_depot = torch.norm(self.depot_node_is_xy - self.depot_xy, dim=-1)
        # shape: (batch, problem+IS+1)
        to_depot = to_depot[:, None, :].expand(self.batch_size, self.pomo_size, self.depot_node_is_xy.size(1))
        self.distance_from_all_nodes_to_depot = self._truncate(to_depot)

        # Polar features relative to the depot. Slice customers by count, not
        # by "-IS", so that intermediate_stop_size == 0 also works.
        first_is_index = 1 + node_xy.size(1)
        shift_coors = self.depot_node_is_xy - self.depot_node_is_xy[:, 0:1, :]
        _x, _y = shift_coors[:, :, 0], shift_coors[:, :, 1]
        r = torch.sqrt(_x ** 2 + _y ** 2)
        theta = torch.atan2(_y, _x)
        x = torch.stack((r, theta, self.depot_node_is_demand), -1)
        self.reset_state.shift_coors = self.depot_node_xy - self.depot_node_xy[:, 0:1, :]
        self.reset_state.x_depot = x[:, 0:1, :]
        self.reset_state.x_node = x[:, 1:first_is_index, :]
        self.reset_state.x_is = x[:, first_is_index:, 0:2]

        bucketed_indices, sorted_indices = self._build_locality_buckets(r[:, 1:first_is_index],
                                                                        theta[:, 1:first_is_index])
        self.reset_state.buckets = bucketed_indices
        self.reset_state.sorted_indices = sorted_indices
        self.step_state.xy = self.depot_node_is_xy
        self.step_state.theta = theta
        self.step_state.r = r
        self.step_state.norm_demand = self.depot_node_is_demand

    def reset(self):
        """Start new episodes for the loaded batch.

        Returns:
            (reset_state, reward=None, done=False)
        """
        self.selected_count = 0
        self.current_node = None
        self.last_node_index = torch.zeros((self.batch_size, self.pomo_size, 1), dtype=torch.long)
        # shape: (batch, pomo, 1)
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        # Every rollout starts at the depot; step() overwrites this each move.
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        self.remaining_travel_distance = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)

        self.visited_ninf_flag = torch.zeros(
            size=(self.batch_size, self.pomo_size, self.problem_size + self.intermediate_stop_size + 1))
        # shape: (batch, pomo, problem+IS+1)
        # NOTE(review): unlike visited_ninf_flag (and ninf_mask after step()),
        # this initial mask excludes the intermediate stops. Kept as-is in case
        # the policy relies on it for the first move -- confirm.
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size + 1))
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        self.distances = None
        self.distance_can_travel = None

        reward = None
        done = False
        return self.reset_state, reward, done

    def compute_sparse_attention_mask(self, coords, top_k=20):
        """Build a boolean attention mask that keeps each node's nearest neighbours.

        Args:
            coords: (batch, num_nodes, dim) point coordinates.
            top_k: number of neighbours kept per node.

        Returns:
            (mask, top_k_indices). ``mask`` is (batch, num_nodes, num_nodes)
            bool, where True means "blocked". Row 0 and column 0 are always
            unblocked. The argsort position 0 is skipped on the assumption that
            it is the node itself (distance 0); duplicate coordinates can break
            that assumption.
        """
        batch_size, num_nodes, _ = coords.shape

        distances = torch.norm(coords.unsqueeze(2) - coords.unsqueeze(1), dim=-1)
        # shape: (batch, num_nodes, num_nodes)

        top_k_indices = distances.argsort(dim=-1)[:, :, 1:top_k + 1]
        # shape: (batch, num_nodes, top_k)

        # Start fully blocked, then unblock the neighbours.
        mask = torch.full((batch_size, num_nodes, num_nodes), True, device=coords.device)
        batch_indices = torch.arange(batch_size, device=coords.device).view(batch_size, 1, 1)
        node_indices = torch.arange(num_nodes, device=coords.device).view(1, num_nodes, 1)
        mask[batch_indices, node_indices, top_k_indices] = False
        mask[:, 0, :] = False
        mask[:, :, 0] = False
        return mask, top_k_indices

    def pre_step(self):
        """Publish the freshly reset dynamic state into ``step_state``.

        Returns:
            (step_state, reward=None, done=False)
        """
        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.remaining_travel_distance = self.remaining_travel_distance
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished
        self.step_state.if_cannot_make_travel_finished = None
        self.step_state.selected = None
        self.step_state.current_charge = None
        self.step_state.distance = None
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        """Move every (batch, pomo) rollout to ``selected`` and update the masks.

        Args:
            selected: (batch, pomo) long tensor of node indices in the combined
                depot/customer/IS numbering.

        Returns:
            (step_state, reward, done). ``reward`` is the negated tour length,
            shape (batch, pomo), once every rollout has finished; otherwise None.
        """
        # Dynamic-1
        ####################################
        self.selected_count += 1
        self.current_node = selected
        last_xy = torch.gather(self.depot_node_is_xy, 1,
                               self.last_node_index.expand(self.last_node_index.size(0),
                                                           self.last_node_index.size(1), 2))
        current_xy = torch.gather(self.depot_node_is_xy, 1,
                                  selected[..., None].expand(selected.size(0), selected.size(1), 2))
        # Charge the (truncated) edge length against the normalised budget.
        traveled_distance = self._truncate(torch.norm(last_xy - current_xy, dim=2))
        self.remaining_travel_distance -= traveled_distance / self.max_travel_distance

        self.last_node_index = selected.clone().unsqueeze(-1)
        # shape: (batch, pomo, 1)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = (selected == 0)
        self.at_is = torch.any(selected.unsqueeze(-1) == self.IS_IDX.unsqueeze(1), dim=-1)

        demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        # shape: (batch, pomo, problem+1)
        demand_list_with_is = self.depot_node_is_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        # shape: (batch, pomo, problem+IS+1)
        selected_demand = demand_list_with_is.gather(dim=2, index=selected[:, :, None]).squeeze(dim=2)
        # shape: (batch, pomo)
        self.load -= selected_demand
        self.load[self.at_the_depot] = 1  # refill load at the depot
        self.remaining_travel_distance[self.at_the_depot] = 1  # budget resets only at the depot
        self.load[self.at_is] = 1  # intermediate stops refill the load too

        # Infeasible next nodes: going there and then back to the depot would
        # exceed the remaining budget (with a 1e-4 tolerance).
        diff = current_xy.unsqueeze(2) - self.depot_node_is_xy.unsqueeze(1)
        distance_from_current_to_all_others = self._truncate(torch.norm(diff, dim=-1))
        # shape: (batch, pomo, problem+IS+1)
        try:
            round_trip = distance_from_current_to_all_others + self.distance_from_all_nodes_to_depot
        except RuntimeError:
            # Pomo dimension of `selected` is not broadcastable with pomo_size.
            # Every pomo row of the cached to-depot table is identical, so reuse
            # the first row (and keep it for later steps).
            self.distance_from_all_nodes_to_depot = self.distance_from_all_nodes_to_depot[:, 0:1, :].expand_as(
                distance_from_current_to_all_others)
            round_trip = distance_from_current_to_all_others + self.distance_from_all_nodes_to_depot
        distance_to_next_node_and_depot_norm = round_trip / self.max_travel_distance[:, None]
        # The original "> rem - 1e-4 AND > rem + 1e-4" reduces to the stricter test.
        if_cannot_make_travel_finished = (
            distance_to_next_node_and_depot_norm > self.remaining_travel_distance.unsqueeze(-1) + 0.0001)

        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        # Intermediate stops are never permanently marked as visited.
        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, self.problem_size + 1:] = 0
        # shape: (batch, pomo, problem+IS+1)
        # depot is considered unvisited, unless you are AT the depot
        self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0
        # Right after a refill stop (depot or IS), forbid moving to an IS.
        at_refill_stop = (self.at_is | self.at_the_depot).unsqueeze(-1)
        self.visited_ninf_flag[:, :, self.problem_size + 1:].masked_fill_(at_refill_stop, float('-inf'))

        self.ninf_mask = self.visited_ninf_flag.clone()
        round_error_epsilon = 0.00001
        demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
        # shape: (batch, pomo, problem+1); IS nodes carry no demand
        self.ninf_mask[:, :, :self.problem_size + 1][demand_too_large] = float('-inf')
        self.ninf_mask[if_cannot_make_travel_finished] = float('-inf')

        # A rollout is finished once the depot AND every customer are flagged,
        # i.e. all customers were served and the vehicle is back at the depot.
        # (Previously sliced with pomo_size, which skipped the last customer.)
        newly_finished = (self.visited_ninf_flag[:, :, :self.problem_size + 1] == float('-inf')).all(dim=2)
        # shape: (batch, pomo)
        self.finished = self.finished | newly_finished

        # do not mask depot for finished episode.
        self.ninf_mask[:, :, 0][self.finished] = 0

        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.current_node = self.current_node
        self.step_state.old_ninf_mask = self.step_state.ninf_mask
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished
        self.step_state.if_cannot_make_travel_finished = if_cannot_make_travel_finished
        self.step_state.remaining_travel_distance = self.remaining_travel_distance

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

    def _get_travel_distance(self):
        """Return the length of each rollout's visited sequence, shape (batch, pomo).

        Sums the Euclidean lengths between consecutive entries of
        ``selected_node_list``, including the closing edge from the last node
        back to the first (via ``roll``).
        """
        gathering_index = self.selected_node_list[:, :, :, None].expand(-1, -1, -1, 2)
        # shape: (batch, pomo, selected_list_length, 2)
        all_xy = self.depot_node_is_xy[:, None, :, :].expand(-1, self.pomo_size, -1, -1)
        # shape: (batch, pomo, problem+IS+1, 2)

        ordered_seq = all_xy.gather(dim=2, index=gathering_index)
        # shape: (batch, pomo, selected_list_length, 2)

        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, selected_list_length)

        return segment_lengths.sum(2)
