from dataclasses import dataclass
import torch
from .problem_cvrp import ProblemCVRP
from .problem_vrptw import ProblemVRPTW
from .problem_pcvrp import ProblemPCVRP

from .instance_set import InstanceSet
import numpy as np
import math

@dataclass
class Reset_State:
    """Static per-instance model inputs produced when an episode is (re)started.

    All tensor fields default to ``None`` and are filled by the environment.
    ``problem_feat`` and ``pre_dist`` are real dataclass fields so they can be
    passed to the constructor and show up in ``repr``. They are declared last so
    that existing positional construction keeps its meaning.
    """
    tour_index: torch.Tensor = None  # route index of each customer
    neighbours: torch.Tensor = None  # (previous id, next id) of each customer in its route
    pos_index: torch.Tensor = None   # position of each customer inside its route
    cur_dist: torch.Tensor = None    # cumulative route distance up to each customer
    tour_angle: torch.Tensor = None  # polar angle of each customer around the depot
    problem_feat: object = None      # problem features object (project type)
    pre_dist: torch.Tensor = None    # distance from each customer to its predecessor


@dataclass
class Step_State:
    """Dynamic state handed to the policy before each selection step.

    Every attribute starts as ``None`` and is overwritten by the environment
    on reset / after each step.
    """

    # Index grids used for advanced indexing over (batch, rollout).
    BATCH_IDX: torch.Tensor = None
    ROLLOUT_IDX: torch.Tensor = None

    # Number of selections made so far in the current episode.
    selected_count: int = None

    # Node chosen at the most recent step (None right after a reset).
    current_node: torch.Tensor = None

    # Additive mask: -inf marks nodes that may not be selected.
    ninf_mask: torch.Tensor = None


class Env:
    """Batched environment in which a policy selects customer nodes to remove.

    Each :meth:`step` records one selected node per (instance, rollout) pair and
    masks it out with ``-inf``. An episode is done once ``num_nodes_to_remove``
    selections have been made. Instances come from a CVRP / VRPTW / PCVRP problem
    object; current tours are read from an ``InstanceSet`` via ``getTours()``.
    """

    def __init__(self, use_multiprocessing, **env_params):
        """Create the environment.

        Args:
            use_multiprocessing: Forwarded to ``InstanceSet``.
            **env_params: Must contain ``problem`` ("cvrp", "vrptw" or "pcvrp"),
                ``problem_size`` and ``num_nodes_to_remove``. It may contain
                ``generator_params``, which is forwarded to the problem class.

        Raises:
            NotImplementedError: If ``env_params['problem']`` is not supported.
        """
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.rollout_size = None

        problem_classes = {
            "cvrp": ProblemCVRP,
            "vrptw": ProblemVRPTW,
            "pcvrp": ProblemPCVRP,
        }
        problem_name = env_params['problem']
        if problem_name not in problem_classes:
            raise NotImplementedError(f"Unsupported problem: {problem_name!r}")
        self.problem = problem_classes[problem_name](
            self.problem_size, env_params.get('generator_params', None)
        )

        self.num_nodes_to_remove = env_params['num_nodes_to_remove']

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.ROLLOUT_IDX = None

        # Instance & Solution properties
        ####################################
        self.instanceSet = InstanceSet(env_params['problem'], use_multiprocessing)

        # Dynamic
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~)

        self.ninf_mask = None
        # shape: (batch, pomo, problem+1)

        # Features & State
        ####################################
        self.step_state = Step_State()

        self.problem_data = None
        self.problem_feat = None

    def load_problem_dataset_pkl(self, filename, num_problems, index_begin=0):
        """Delegate loading of a pickled problem dataset to the problem object."""
        self.problem.load_problem_dataset_pkl(filename, num_problems, index_begin)

    def load_problem_dataset_pt(self, filename, device):
        """Delegate loading of a ``.pt`` problem dataset to the problem object."""
        self.problem.load_problem_dataset_pt(filename, device)

    def init_instances(self, nb_instances, rollout_size, device, aug_factor=1):
        """Create a new batch of problems and build their model input.

        Args:
            nb_instances: Number of instances requested from the problem object.
            rollout_size: Number of rollouts per instance.
            device: Forwarded to :meth:`get_model_input`.
            aug_factor: Augmentation factor forwarded to ``init_problems``.

        Returns:
            The ``Reset_State`` produced by :meth:`get_model_input`. It used to be
            computed and discarded; returning it is backward-compatible.
        """
        self.batch_size, self.problem_data, self.problem_feat = self.problem.init_problems(nb_instances, aug_factor)
        self.change_rollout_size(rollout_size)

        self.instanceSet.init_instances(self.problem_data)
        return self.get_model_input(device)

    def change_rollout_size(self, rollout_size):
        """Set the rollout size and rebuild the (batch, rollout) index grids.

        Requires ``self.batch_size`` to be set.
        """
        self.rollout_size = rollout_size
        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.rollout_size)
        self.ROLLOUT_IDX = torch.arange(self.rollout_size)[None, :].expand(self.batch_size, self.rollout_size)
        self.step_state.BATCH_IDX = self.BATCH_IDX
        self.step_state.ROLLOUT_IDX = self.ROLLOUT_IDX

    def reset(self):
        """Start a new episode: clear selections and mask only the depot (id 0)."""
        self.selected_count = 0
        self.current_node = None
        self.selected_node_list = torch.zeros((self.batch_size, self.rollout_size, 0), dtype=torch.long)

        self.ninf_mask = torch.zeros(size=(self.batch_size, self.rollout_size, self.problem_size + 1))
        self.ninf_mask[:, :, 0] = float('-inf')  # prevent network from removing depot

        self.step_state.selected_count = self.selected_count
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask

        return self.step_state

    @staticmethod
    def _build_route_tensor(tours, batch_size, device):
        """Pack ragged tours into a ``[batch_size, R_max, L_max]`` long tensor padded with -1.

        Each route is closed with the depot (id 0) at both ends *before* the padded
        width is measured, so no customer is truncated. An empty route becomes ``[0]``.
        """
        closed_tours = []
        for routes in tours:
            closed_routes = []
            for path in routes:
                path = list(path)
                if not path or path[0] != 0:
                    path = [0, *path]
                if path[-1] != 0:
                    path = [*path, 0]
                closed_routes.append(path)
            closed_tours.append(closed_routes)

        r_max = max((len(rs) for rs in closed_tours), default=0)
        l_max = max((len(p) for rs in closed_tours for p in rs), default=0)

        # Fill on the CPU, then do a single transfer instead of one per route.
        routes_t = torch.full((batch_size, r_max, l_max), -1, dtype=torch.long)
        for b, closed_routes in enumerate(closed_tours):
            for r, path in enumerate(closed_routes):
                routes_t[b, r, :len(path)] = torch.tensor(path, dtype=torch.long)
        return routes_t.to(device)

    @staticmethod
    def _route_features(routes, dist_full, num_nodes):
        """Scatter per-position route information into per-customer tensors.

        Args:
            routes: ``[B, R, L]`` long tensor of depot-closed routes padded with -1.
            dist_full: ``[B, N+1, N+1]`` pairwise distances where id 0 is the depot.
            num_nodes: Number of customers ``N``.

        Returns:
            ``(tour_index, neighbours, pos_index, cur_dist)`` with shapes
            ``[B, N]``, ``[B, N, 2]``, ``[B, N]`` and ``[B, N]``. A customer that
            appears in no route keeps -1, or 0.0 for ``cur_dist``.
        """
        B, R, L = routes.shape
        dev = routes.device
        tour_index = torch.full((B, num_nodes), -1, dtype=torch.long, device=dev)
        pos_index = torch.full((B, num_nodes), -1, dtype=torch.long, device=dev)
        neighbours = torch.full((B, num_nodes, 2), -1, dtype=torch.long, device=dev)
        cur_dist = torch.zeros((B, num_nodes), dtype=dist_full.dtype, device=dev)
        if R == 0 or L < 3:
            # No slot can hold a customer between two depots.
            return tour_index, neighbours, pos_index, cur_dist

        # Sliding windows (prev, curr, next); curr skips the leading depot.
        curr = routes[..., 1:-1]                 # [B, R, L-2]
        prev_ = routes[..., :-2]                 # [B, R, L-2]
        next_ = routes[..., 2:]                  # [B, R, L-2]
        P = L - 2

        # Cumulative distance along each route. Padding (-1) is clamped to the
        # depot, so padded edges are depot->depot and add 0.
        a = routes[..., :-1].clamp_min(0)        # [B, R, L-1]
        b = routes[..., 1:].clamp_min(0)         # [B, R, L-1]
        batch_idx = torch.arange(B, device=dev)[:, None, None]
        edge_csum = torch.cumsum(dist_full[batch_idx, a, b], dim=-1)
        # The customer at route position k arrives after edges 0..k-1.
        cur_dist_pos = edge_csum[..., :-1]       # [B, R, L-2], aligned with curr

        tour_id = torch.arange(R, device=dev)[None, :, None].expand(B, R, P)
        pos0 = torch.arange(P, device=dev)[None, None, :].expand(B, R, P)

        mask = curr > 0                          # customers only (excludes depot and padding)
        bi = torch.arange(B, device=dev)[:, None, None].expand(B, R, P)[mask]
        ni = curr[mask] - 1                      # map ids 1..N -> 0..N-1

        tour_index[bi, ni] = tour_id[mask]
        pos_index[bi, ni] = pos0[mask]
        neighbours[bi, ni, 0] = prev_[mask]      # prev id (0 = depot)
        neighbours[bi, ni, 1] = next_[mask]      # next id (0 = depot)
        cur_dist[bi, ni] = cur_dist_pos[mask]
        return tour_index, neighbours, pos_index, cur_dist

    def get_model_input(self, device):
        """Build a fresh ``Reset_State`` from the problem features and current tours.

        Args:
            device: Unused. All tensors are created on ``node_xy``'s device. Kept for
                interface compatibility.

        Returns:
            A new ``Reset_State`` instance (never a shared class object) with
            ``problem_feat``, ``tour_index``, ``neighbours``, ``pos_index``,
            ``cur_dist`` and ``tour_angle``. ``tour_angle`` is each customer's polar
            angle around the depot, measured relative to customer 1, in [0, 2π).
        """
        depot_xy = self.problem_feat.depot_xy
        node_xy = self.problem_feat.node_xy
        tours = self.instanceSet.getTours()

        dev = node_xy.device
        B, N, _ = node_xy.shape

        # Distance matrix including the depot (ids 0..N).
        full_xy = torch.cat([depot_xy, node_xy], dim=1)    # [B, N+1, 2]
        dist_full = torch.cdist(full_xy, full_xy, p=2)     # [B, N+1, N+1]

        # Polar angle of every customer around the depot.
        vec = node_xy - depot_xy.expand(-1, N, -1)         # [B, N, 2]
        theta = torch.atan2(vec[..., 1], vec[..., 0])      # [-π, π]
        two_pi = torch.tensor(2 * math.pi, dtype=node_xy.dtype, device=dev)
        theta = torch.remainder(theta, two_pi)             # [0, 2π)
        # Customer id 1 (index 0) is the reference direction and gets angle 0.
        tour_angle = torch.remainder(theta - theta[:, :1], two_pi)

        routes = self._build_route_tensor(tours, B, dev)
        tour_index, neighbours, pos_index, cur_dist = self._route_features(routes, dist_full, N)

        reset_state = Reset_State()
        reset_state.problem_feat = self.problem_feat
        reset_state.tour_index = tour_index
        reset_state.neighbours = neighbours
        reset_state.pos_index = pos_index
        reset_state.cur_dist = cur_dist
        reset_state.tour_angle = tour_angle
        return reset_state

    def get_model_input_nds(self, device):
        """Build a ``Reset_State`` with a per-node Python loop, including ``pre_dist``.

        Each route in ``getTours()`` is wrapped as ``[0, *route, 0]``.
        ``pre_dist`` holds the Euclidean distance from each customer to its
        predecessor in the route; the depot is the predecessor of a route's first
        customer. Unvisited customers keep -1 in ``tour_index``, ``pos_index`` and
        ``pre_dist``, and 0 in ``neighbours``.

        Args:
            device: Unused; kept for interface compatibility.
        """
        depot_xy = self.problem_feat.depot_xy
        node_xy = self.problem_feat.node_xy
        tours = self.instanceSet.getTours()

        # neighbours: (prev id, next id), with 0 = depot.
        neighbours = np.zeros((self.batch_size, self.problem_size, 2), dtype=np.int_)
        # tour_index: index of the route each customer belongs to.
        tour_index = np.full((self.batch_size, self.problem_size), -1, dtype=np.int_)
        # pos_index: 0-based position of the customer inside its route.
        pos_index = np.full((self.batch_size, self.problem_size), -1, dtype=np.int_)

        for b_idx in range(self.batch_size):
            for t_i, route in enumerate(tours[b_idx]):
                closed = [0, *route, 0]
                for pos, (prev_id, node_id, next_id) in enumerate(zip(closed, closed[1:], closed[2:])):
                    tour_index[b_idx, node_id - 1] = t_i
                    neighbours[b_idx, node_id - 1] = prev_id, next_id
                    pos_index[b_idx, node_id - 1] = pos

        tour_index_t = torch.from_numpy(tour_index).long()  # (batch, problem)
        neighbours_t = torch.from_numpy(neighbours).long()  # (batch, problem, 2)
        pos_index_t = torch.from_numpy(pos_index).long()    # (batch, problem)

        # Distances between all ids (0 = depot, i = customer i): (batch, problem+1, problem+1).
        depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        dist_matrix = torch.cdist(depot_node_xy, depot_node_xy, p=2)

        # Float distances; casting to long would truncate them.
        pre_dist = torch.full((self.batch_size, self.problem_size), -1.0, dtype=torch.float32)
        b_ids, n_ids = torch.nonzero(tour_index_t >= 0, as_tuple=True)
        if b_ids.numel() > 0:
            dist_dev = dist_matrix.device
            prev_ids = neighbours_t[b_ids, n_ids, 0]
            gathered = dist_matrix[b_ids.to(dist_dev), (n_ids + 1).to(dist_dev), prev_ids.to(dist_dev)]
            pre_dist[b_ids, n_ids] = gathered.to(device=pre_dist.device, dtype=pre_dist.dtype)

        reset_state = Reset_State()
        reset_state.problem_feat = self.problem_feat
        reset_state.tour_index = tour_index_t
        reset_state.neighbours = neighbours_t
        reset_state.pos_index = pos_index_t
        reset_state.pre_dist = pre_dist
        return reset_state

    def step(self, selected):
        """Record one selection per (instance, rollout) and mask the chosen nodes.

        Args:
            selected: Node ids of shape ``(batch, pomo)``.

        Returns:
            ``(step_state, done)``. ``done`` is True once ``num_nodes_to_remove``
            selections have been made.
        """
        # Dynamic-1
        ####################################
        self.selected_count += 1
        self.current_node = selected
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

        # A node selected once can never be selected again in this episode.
        self.ninf_mask[self.BATCH_IDX, self.ROLLOUT_IDX, selected] = float('-inf')

        self.step_state.selected_count = self.selected_count
        self.step_state.current_node = self.current_node
        self.step_state.ninf_mask = self.ninf_mask

        done = self.selected_count == self.num_nodes_to_remove

        return self.step_state, done
