from dataclasses import dataclass
import torch

from CVRProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass
class Reset_State:
    """Static description of the current problem batch.

    The first three fields keep their original order so positional
    construction still works.  The remaining fields are the attributes that
    ``CVRPEnv.load_problems`` attaches to this object; declaring them here
    (defaulting to ``None``) means they exist before a problem is loaded and
    show up in the generated ``__init__``/``repr``.
    """

    depot_xy: torch.Tensor = None
    # shape: (batch, 1, 2)
    node_xy: torch.Tensor = None
    # shape: (batch, problem, 2)
    node_demand: torch.Tensor = None
    # shape: (batch, problem)

    # --- populated by CVRPEnv.load_problems ---
    whole_nodes: torch.Tensor = None
    # depot and node coordinates concatenated; shape: (batch, problem+1, 2)
    distances: torch.Tensor = None
    # distance matrix supplied with the problem
    # presumably shape (batch, problem+1, problem+1) -- TODO confirm against CVRProblemDef
    x_depot: torch.Tensor = None
    # (r, theta, demand) of the depot; shape: (batch, 1, 3)
    x_node: torch.Tensor = None
    # (r, theta, demand) of every node; shape: (batch, problem, 3)
    shift_coors: torch.Tensor = None
    # coordinates relative to the depot; shape: (batch, problem+1, 2)
    buckets: torch.Tensor = None
    # node-index buckets, each prefixed with the depot index 0
    # shape: (batch, 2*number_of_locality_att, num_buckets, bucket_size+1)
    sorted_indices: torch.Tensor = None
    # node orderings the buckets were cut from
    # shape: (batch, 2*number_of_locality_att, num_buckets*bucket_size)


@dataclass
class Step_State:
    """Per-step view of the rollout handed to the policy.

    Besides the declared fields, the environment attaches further attributes
    to instances at runtime (for example in ``CVRPEnv.load_problems`` and
    ``CVRPEnv.pre_step``).
    """

    BATCH_IDX: torch.Tensor = None
    POMO_IDX: torch.Tensor = None
    # shape: (batch, pomo)
    selected_count: int = None
    load: torch.Tensor = None
    # shape: (batch, pomo)
    current_node: torch.Tensor = None
    # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None
    old_ninf_mask: torch.Tensor = None
    # shape: (batch, pomo, problem+1)
    finished: torch.Tensor = None
    # shape: (batch, pomo)

    # The names below carry no annotation, so they are plain class attributes
    # rather than dataclass fields (they are not parameters of __init__).
    probs = None
    distances = None
    cs_encoded_nodes = None
    reembeded_nodes = None

    def _distances_from(self, last_node_index):
        """Return, for every rollout, the row of ``self.distances`` of its last node.

        ``last_node_index`` must be 2-D (dim0, dim1); the result has shape
        (dim0, dim1, self.distances.size(-1)) with
        ``out[b, p, j] = self.distances[b, last_node_index[b, p], j]``.
        """
        gather_index = last_node_index.unsqueeze(-1).expand(
            last_node_index.size(0),
            last_node_index.size(1),
            self.distances.size(-1),
        )
        return torch.gather(self.distances, 1, gather_index)

    # NOTE(review): all three lookups below use torch.topk, which selects the
    # LARGEST entries.  They therefore return the nearest nodes only if
    # ``self.distances`` holds negated distances -- confirm what the caller
    # stores there.

    def get_closest_nodes(self, last_node_index, k):
        """Return the indices of the ``k`` top-scoring nodes for each rollout's last node.

        Returns a tensor of shape (dim0, dim1, k) ordered by decreasing value.
        """
        current_dis_list = self._distances_from(last_node_index)
        _, topk_index = torch.topk(current_dis_list, k)
        return topk_index

    def get_closest_cs(self, last_node_index, number_of_css, k=1):
        """Like ``get_closest_nodes`` but restricted to the last ``number_of_css`` columns.

        The returned indices are relative to that trailing slice, i.e. they
        lie in ``[0, number_of_css)``.
        """
        current_dis_list = self._distances_from(last_node_index)
        _, topk_index = torch.topk(current_dis_list[:, :, -number_of_css:], k)
        return topk_index

    def get_closest_node_list(self, last_node_index, ninf_mask, k=20):
        """Return up to ``k`` top-scoring nodes after adding ``ninf_mask``.

        ``ninf_mask`` is added to the gathered distance rows before the
        selection, so entries set to ``-inf`` are ranked last.  ``k`` defaults
        to the previously hard-coded 20 and is clamped to the number of
        candidate nodes so that small problems no longer make ``torch.topk``
        raise.
        """
        current_dis_list = self._distances_from(last_node_index) + ninf_mask
        k = min(k, current_dis_list.size(-1))
        _, topk_index = torch.topk(current_dis_list, k)
        return topk_index


class CVRPEnv:
    """Batched CVRP environment whose tour cost comes from a supplied distance matrix.

    Every problem instance in the batch is rolled out ``pomo_size`` times in
    parallel.  Typical call order: ``load_problems`` -> ``reset`` ->
    ``pre_step`` -> repeated ``step`` until ``done``.
    """

    def __init__(self, **env_params):
        """Store the configuration and create empty problem / rollout state.

        Required keys in ``env_params``: 'problem_size', 'asymmetric_size',
        'alpha', 'pomo_size', 'bucket_size', 'number_of_locality_att'.
        A missing key raises ``KeyError``.
        """

        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.asymmetric_size = env_params['asymmetric_size']
        self.alpha = env_params['alpha']
        self.pomo_size = env_params['pomo_size']
        self.bucket_size = env_params['bucket_size']
        self.number_of_locality_att = env_params['number_of_locality_att']

        self.FLAG__use_saved_problems = False
        self.saved_depot_xy = None
        self.saved_node_xy = None
        self.saved_cs_xy = None
        self.saved_node_demand = None
        # Initialised here so the attribute exists before use_saved_problems().
        self.saved_distances = None
        self.saved_index = None

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.depot_node_xy = None
        # shape: (batch, problem+1, 2)
        self.depot_node_demand = None
        # shape: (batch, problem+1)
        self.extended_cs_embedding = None

        # Dynamic-1
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.last_node_index = None
        # shape: (batch, pomo, 1)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = None
        # shape: (batch, pomo)
        self.load = None
        # shape: (batch, pomo)
        self.charge = None
        # shape: (batch, pomo)
        self.visited_ninf_flag = None
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = None
        # shape: (batch, pomo, problem+1)
        self.finished = None
        # shape: (batch, pomo)
        self.distances = None
        self.whole_theta = None
        self.whole_dist = None
        self.probs = None

        # states to return
        ####################################
        self.reset_state = Reset_State()
        self.step_state = Step_State()

    def use_saved_problems(self, filename, device):
        """Serve problems from ``filename`` instead of sampling them.

        The file is read with ``torch.load(filename, map_location=device)`` and
        must be a mapping with the keys 'depot_xy', 'node_xy', 'node_demand'
        and 'distances'.  Subsequent ``load_problems`` calls take consecutive
        slices starting at index 0.
        """
        self.FLAG__use_saved_problems = True

        # NOTE(review): torch.load unpickles the file; only load problem sets
        # that come from a trusted source.
        loaded_dict = torch.load(filename, map_location=device)
        self.saved_depot_xy = loaded_dict['depot_xy']
        self.saved_node_xy = loaded_dict['node_xy']
        self.saved_node_demand = loaded_dict['node_demand']
        self.saved_distances = loaded_dict['distances']
        self.saved_index = 0

    def load_problems(self, batch_size, aug_factor=1):
        """Load (or sample) a batch of problems and derive all static features.

        Args:
            batch_size: number of instances to take before augmentation.
            aug_factor: 8 enables the 8-fold coordinate augmentation (the
                effective batch size becomes ``8 * batch_size``); values <= 1
                disable augmentation.

        Raises:
            NotImplementedError: for any ``aug_factor`` greater than 1 other
                than 8.
        """
        self.batch_size = batch_size

        if self.FLAG__use_saved_problems:
            window = slice(self.saved_index, self.saved_index + batch_size)
            depot_xy = self.saved_depot_xy[window]
            node_xy = self.saved_node_xy[window]
            node_demand = self.saved_node_demand[window]
            distances = self.saved_distances[window]
            self.saved_index += batch_size
        else:
            depot_xy, node_xy, node_demand, distances = get_random_problems(
                batch_size, self.problem_size, alpha=self.alpha, n=self.asymmetric_size)

        if aug_factor > 1:
            if aug_factor != 8:
                raise NotImplementedError
            self.batch_size = self.batch_size * 8
            depot_xy = augment_xy_data_by_8_fold(depot_xy)
            node_xy = augment_xy_data_by_8_fold(node_xy)
            # Demands and the distance matrix are simply tiled for the 8 copies.
            node_demand = node_demand.repeat(8, 1)
            distances = distances.repeat(8, 1, 1)

        self.depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # shape: (batch, problem+1, 2)
        depot_demand = torch.zeros(size=(self.batch_size, 1))
        # shape: (batch, 1)
        self.depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        # shape: (batch, problem+1)

        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

        self.reset_state.depot_xy = depot_xy
        self.reset_state.node_xy = node_xy
        self.reset_state.whole_nodes = torch.cat((depot_xy, node_xy), -2)
        self.reset_state.node_demand = node_demand
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.POMO_IDX = self.POMO_IDX

        # The distance matrix comes with the problem; it is not recomputed
        # from the coordinates.
        self.step_state.distances = distances
        self.distances = distances
        self.reset_state.distances = distances
        self.step_state.node_distances = -self.step_state.distances

        # Polar coordinates of every location relative to the depot.
        shift_coors = self.depot_node_xy - self.depot_node_xy[:, 0:1, :]
        _x, _y = shift_coors[:, :, 0], shift_coors[:, :, 1]
        r = torch.sqrt(_x ** 2 + _y ** 2)
        theta = torch.atan2(_y, _x)
        # shape of r and theta: (batch, problem+1)
        x = torch.stack((r, theta, self.depot_node_demand), -1)
        # shape: (batch, problem+1, 3)
        self.reset_state.x_depot = x[:, 0:1, :]
        self.reset_state.x_node = x[:, 1:, :]
        self.reset_state.shift_coors = shift_coors

        bucketed_indices, sorted_indices = self._build_locality_buckets(r, theta)
        self.reset_state.buckets = bucketed_indices
        self.reset_state.sorted_indices = sorted_indices
        self.step_state.theta = theta
        self.step_state.r = r
        self.step_state.norm_demand = self.depot_node_demand

    def _build_locality_buckets(self, r, theta):
        """Group the non-depot nodes into fixed-size buckets of similar polar position.

        Args:
            r, theta: polar coordinates relative to the depot,
                shape (batch, problem+1); index 0 is the depot.

        Returns:
            ``(bucketed_indices, sorted_indices)`` where ``sorted_indices`` has
            shape (batch, 2*n, num_buckets*bucket_size) and holds 0-based
            positions within the non-depot nodes, and ``bucketed_indices`` has
            shape (batch, 2*n, num_buckets, bucket_size+1) and holds node
            indices (shifted by +1) with the depot index 0 prepended to every
            bucket.  ``n`` is ``number_of_locality_att``.
        """
        n_att = self.number_of_locality_att

        # Min-max normalise radius and angle over the non-depot nodes.
        node_r = r[:, 1:]
        r_min = node_r.min(dim=-1, keepdim=True).values
        r_max = node_r.max(dim=-1, keepdim=True).values
        r_norm = (node_r - r_min) / (r_max - r_min + 1e-8)

        node_theta = (theta % (2 * torch.pi))[:, 1:]
        theta_min = node_theta.min(dim=-1, keepdim=True).values
        theta_max = node_theta.max(dim=-1, keepdim=True).values
        theta_norm = (node_theta - theta_min) / (theta_max - theta_min + 1e-8)

        # n convex combinations w * r + (1 - w) * theta with w from 0 to 1.
        weights = torch.linspace(0, 1, steps=n_att)  # shape: (n,)
        weight_r = weights.view(1, n_att, 1)  # shape: (1, n, 1)
        weight_theta = (1 - weights).view(1, n_att, 1)  # shape: (1, n, 1)
        scores = weight_r * r_norm.unsqueeze(1) + weight_theta * theta_norm.unsqueeze(1)
        # shape: (batch, n, problem)
        _, sorted_indices = torch.sort(scores, dim=-1)

        # A second set of orderings, rotated by half a bucket, so that the
        # bucket boundaries of the two sets do not coincide.
        half_bucket = self.bucket_size // 2
        rotated = torch.cat((sorted_indices[:, :, half_bucket:], sorted_indices[:, :, :half_bucket]), dim=-1)
        sorted_indices = torch.cat((sorted_indices, rotated), dim=1)
        # shape: (batch, 2n, problem)

        # Drop the tail that does not fill a whole bucket, then cut into buckets.
        num_buckets = sorted_indices.shape[-1] // self.bucket_size
        sorted_indices = sorted_indices[:, :, :num_buckets * self.bucket_size]
        # +1 converts positions among the non-depot nodes into node indices.
        bucketed_indices = sorted_indices.view(
            self.batch_size, n_att * 2, num_buckets, self.bucket_size) + 1
        depot_column = torch.zeros(
            bucketed_indices.size(0), bucketed_indices.size(1), bucketed_indices.size(2), 1,
            dtype=bucketed_indices.dtype, device=bucketed_indices.device)
        bucketed_indices = torch.cat([depot_column, bucketed_indices], dim=-1)
        return bucketed_indices, sorted_indices

    def reset(self):
        """Re-initialise the rollout state for the loaded batch.

        Returns:
            ``(reset_state, None, False)``.
        """
        self.selected_count = 0
        self.current_node = None
        self.last_node_index = torch.zeros((self.batch_size, self.pomo_size, 1), dtype=torch.long)
        # shape: (batch, pomo, 1)
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~)

        # One flag per rollout (previously allocated with pomo_size + 1 columns,
        # which disagreed with every other (batch, pomo) tensor).
        self.at_the_depot = torch.ones(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)
        self.load = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)
        self.charge = torch.ones(size=(self.batch_size, self.pomo_size))
        # shape: (batch, pomo)
        self.visited_ninf_flag = torch.zeros(
            size=(self.batch_size, self.pomo_size, self.problem_size + 1))
        # shape: (batch, pomo, problem+1)
        self.ninf_mask = torch.zeros(size=(self.batch_size, self.pomo_size, self.problem_size + 1))
        # shape: (batch, pomo, problem+1)
        self.finished = torch.zeros(size=(self.batch_size, self.pomo_size), dtype=torch.bool)
        # shape: (batch, pomo)

        reward = None
        done = False
        return self.reset_state, reward, done

    def compute_sparse_attention_mask(self, coords, top_k=20):
        """Build a boolean k-nearest-neighbour mask from Euclidean coordinates.

        Args:
            coords: tensor of shape (batch, num_nodes, 2).
            top_k: number of neighbours kept per node.

        Returns:
            ``(mask, top_k_indices)``.  ``mask`` is a bool tensor of shape
            (batch, num_nodes, num_nodes) that is False for each node's
            selected neighbours and for the whole row and column of node 0,
            and True everywhere else.  ``top_k_indices`` has shape
            (batch, num_nodes, top_k) (fewer columns if ``num_nodes - 1`` is
            smaller than ``top_k``).
        """
        batch_size, num_nodes, _ = coords.shape

        # Pairwise Euclidean distances, shape (batch, num_nodes, num_nodes).
        distances = torch.norm(coords.unsqueeze(2) - coords.unsqueeze(1), dim=-1)

        # Skip the first entry of every sorted row: that is the node itself
        # (distance 0) as long as no other node shares its coordinates.
        top_k_indices = distances.argsort(dim=-1)[:, :, 1:top_k + 1]

        # Start with everything True and clear the selected pairs.
        mask = torch.full((batch_size, num_nodes, num_nodes), True, device=coords.device)
        batch_indices = torch.arange(batch_size, device=coords.device).view(batch_size, 1, 1)
        node_indices = torch.arange(num_nodes, device=coords.device).view(1, num_nodes, 1)
        mask[batch_indices, node_indices, top_k_indices] = False

        # Node 0 is cleared for every pair in both directions.
        mask[:, 0, :] = False
        mask[:, :, 0] = False
        return mask, top_k_indices

    def pre_step(self):
        """Publish the freshly reset rollout state on ``step_state``.

        Returns:
            ``(step_state, None, False)``.
        """
        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.charge = self.charge
        self.step_state.old_charge = self.charge
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished
        self.step_state.if_cannot_make_travel_finished = None
        self.step_state.selected = None
        self.step_state.current_charge = None
        self.step_state.distance = None
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        """Advance every rollout by one move.

        Args:
            selected: long tensor of shape (batch, pomo) with the node chosen
                by each rollout (0 is the depot).

        Returns:
            ``(step_state, reward, done)``.  ``reward`` is ``None`` until all
            rollouts are finished; then it is the negated tour length with
            shape (batch, pomo).
        """
        # selected.shape: (batch, pomo)

        # Dynamic-1
        ####################################
        self.selected_count += 1
        self.current_node = selected
        # shape: (batch, pomo)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~)

        # Dynamic-2
        ####################################
        self.at_the_depot = (selected == 0)

        demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
        # shape: (batch, pomo, problem+1)
        gathering_index = selected[:, :, None]
        # shape: (batch, pomo, 1)
        selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
        # shape: (batch, pomo)
        self.load -= selected_demand
        self.load[self.at_the_depot] = 1  # refill loaded at the depot

        self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
        # shape: (batch, pomo, problem+1)
        # depot is considered unvisited, unless you are AT the depot
        self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

        self.ninf_mask = self.visited_ninf_flag.clone()
        # The epsilon keeps float round-off from masking a node whose demand
        # exactly equals the remaining load.
        round_error_epsilon = 0.00001
        demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
        # shape: (batch, pomo, problem+1)
        self.ninf_mask[demand_too_large] = float('-inf')
        # shape: (batch, pomo, problem+1)

        newly_finished = (self.visited_ninf_flag == float('-inf')).all(dim=2)
        # shape: (batch, pomo)
        self.finished = self.finished | newly_finished
        # shape: (batch, pomo)

        # do not mask depot for finished episode.
        self.ninf_mask[:, :, 0][self.finished] = 0

        self.step_state.selected_count = self.selected_count
        self.step_state.load = self.load
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask
        self.step_state.finished = self.finished

        # returning values
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done

    def _get_travel_distance(self):
        """Return the length of every rollout's closed tour, shape (batch, pomo).

        Consecutive visits are priced with ``self.distances[b, from, to]``, so
        direction matters if the matrix is asymmetric.  The last visited node
        is connected back to the first one.
        """
        from_index = self.selected_node_list
        # shape: (batch, pomo, tour_len)
        to_index = from_index.roll(dims=2, shifts=-1)  # successor of each visit, wrapping around

        # Broadcasts against the (batch, pomo, tour_len) index tensors.
        batch_idx = torch.arange(from_index.size(0), device=from_index.device)[:, None, None]

        leg_lengths = self.distances[batch_idx, from_index, to_index]
        # shape: (batch, pomo, tour_len)
        return leg_lengths.sum(2)
