
import torch
from logging import getLogger

from .env import Env
from .model import Model

from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler
from torch.optim.lr_scheduler import LinearLR, ConstantLR, SequentialLR

from .logging_utils import *

from .validator import Validator
import itertools

import wandb


class Trainer:
    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params,
                 logger_params):
        """Build model, environment, optimizer and scheduler, then restore saved state.

        Args:
            env_params: Keyword arguments forwarded to ``Env`` (after a leading ``True``).
            model_params: Keyword arguments forwarded to ``Model``; ``z_dim`` is also read
                here to enumerate every binary vector of that length.
            optimizer_params: Dict with an ``'optimizer'`` entry (kwargs for Adam) and a
                ``'scheduler'`` entry (consumed by ``_create_scheduler``).
            trainer_params: Dict providing ``use_cuda``, ``cuda_device_num`` (when CUDA is
                used), ``model_load``, ``train_batch_size`` and ``rollout_size``.
            logger_params: Dict providing ``wandb`` (``enable``, ``project`` and an
                optional ``entity``) and ``desc``; also forwarded to ``Validator``.
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # cuda
        # The default tensor type is set before any component is built so that tensors
        # created without an explicit device land on the selected one.
        USE_CUDA = self.trainer_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            self.device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        # Main Components
        self.model = Model(**self.model_params)
        self.model_frozen = Model(**self.model_params)
        self.env = Env(True, **self.env_params)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])

        # Initialize scheduler with warmup support
        self.scheduler = self._create_scheduler()
        # Loss scaling is only meaningful for CUDA mixed precision. Without the explicit
        # flag the scaler stays enabled whenever CUDA is merely *available*, even though
        # this run was configured for CPU.
        self.scaler = torch.cuda.amp.GradScaler(enabled=bool(USE_CUDA))

        # Restore
        self.start_epoch = 1
        self.wandb_run_id = None

        # Optional warm start from an explicitly named checkpoint (fine-tuning).
        model_load = trainer_params['model_load']
        if model_load['enable']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
            # Use partial loading to handle architecture changes (e.g., TourLayer modifications)
            partial_load = model_load.get('partial_load', True)
            if partial_load:
                checkpoint = self._load_checkpoint_partial(checkpoint_fullname, strict=False)
            else:
                checkpoint = torch.load(checkpoint_fullname, map_location=self.device, weights_only=False)
                self.model.load_state_dict(checkpoint['model_state_dict'])

            # Load optimizer state with partial compatibility (if enabled)
            load_optimizer_state = model_load.get('load_optimizer_state', True)
            if load_optimizer_state:
                self._load_optimizer_state_partial(checkpoint['optimizer_state_dict'])
            else:
                self.logger.info('Skipping optimizer state loading (disabled in config)')

        # Automatic resume: a latest_model.pt in the result folder takes precedence over
        # whatever the warm start above loaded.
        if os.path.isfile(self.result_folder + '/latest_model.pt'):
            checkpoint_fullname = self.result_folder + '/latest_model.pt'
            # For resume, we typically want strict loading since it's the same architecture
            checkpoint = torch.load(checkpoint_fullname, map_location=self.device, weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.start_epoch = 1 + checkpoint['epoch']
            self.result_log.set_raw_data(checkpoint['result_log'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # NOTE(review): checkpoints written by run() contain 'scheduler_state_dict',
            # but it is not restored here, so the LR schedule (including warmup) starts
            # over after a resume - confirm this is intended.
            # Checkpoints without a run id simply start a new wandb run.
            self.wandb_run_id = checkpoint.get('wandb_run_id')
            self.logger.info('Resuming training run!')

        # utility
        self.time_estimator = TimeEstimator()

        # Validator
        self.validator = Validator(self.device, self.env_params, self.trainer_params, self.model_params, logger_params)

        # Training parameters
        self.batch_size = self.trainer_params['train_batch_size']
        self.rollout_size = self.trainer_params['rollout_size']

        # Every binary vector of length z_dim, one per row: shape (2 ** z_dim, z_dim).
        self.binary_string_pool = torch.Tensor(
            [list(i) for i in itertools.product([0, 1], repeat=model_params['z_dim'])])

        # Logging
        self.use_wandb = logger_params['wandb']['enable']
        if self.use_wandb:
            run = wandb.init(
                project=logger_params["wandb"]['project'],
                name=logger_params["desc"],
                # Overridable through the config; the default keeps the previous entity.
                entity=logger_params["wandb"].get('entity', "cbhua"),
                config={
                    'env_params': env_params,
                    'model_params': model_params,
                    'optimizer_params': optimizer_params,
                    'trainer_params': trainer_params
                },
                # Passing the stored id together with resume='allow' continues the
                # earlier run when one exists and starts a new one otherwise.
                id=self.wandb_run_id,
                resume='allow'
            )
            self.wandb_run_id = run.id

    def _create_scheduler(self):
        """Create scheduler with optional warmup support"""
        scheduler_params = self.optimizer_params['scheduler']
        
        # Check if warmup is configured
        if 'warmup' in scheduler_params:
            warmup_config = scheduler_params['warmup']
            warmup_steps = warmup_config['steps']
            warmup_type = warmup_config.get('type', 'linear')
            
            # Create warmup scheduler
            if warmup_type == 'linear':
                warmup_scheduler = LinearLR(
                    self.optimizer, 
                    start_factor=0.01,  # Start with 1% of the learning rate
                    end_factor=1.0, 
                    total_iters=warmup_steps
                )
            elif warmup_type == 'constant':
                warmup_scheduler = ConstantLR(
                    self.optimizer, 
                    factor=0.01,  # Use 1% of the learning rate during warmup
                    total_iters=warmup_steps
                )
            else:
                raise ValueError(f"Unsupported warmup type: {warmup_type}")
            
            # Create main scheduler
            main_scheduler = Scheduler(self.optimizer, **{k: v for k, v in scheduler_params.items() if k != 'warmup'})
            
            # Create sequential scheduler
            scheduler = SequentialLR(
                self.optimizer,
                schedulers=[warmup_scheduler, main_scheduler],
                milestones=[warmup_steps]
            )
            
            self.logger.info(f"Created scheduler with {warmup_type} warmup for {warmup_steps} steps")
        else:
            # No warmup, use regular scheduler
            scheduler = Scheduler(self.optimizer, **scheduler_params)
            self.logger.info("Created scheduler without warmup")
        
        return scheduler

    def _try_partial_parameter_load(self, checkpoint_param, current_param, param_name):
        """
        Try to load compatible parts of a parameter when shapes don't match.
        This is useful for cases like tour_combiner where we add new input dimensions.
        
        Args:
            checkpoint_param: Parameter from checkpoint
            current_param: Current model parameter
            param_name: Name of the parameter (for logging)
        
        Returns:
            bool: True if partial loading was successful, False otherwise
        """
        try:
            # Handle common cases for tour_combiner and similar linear layers
            if len(checkpoint_param.shape) == 2 and len(current_param.shape) == 2:
                # Linear layer: [out_features, in_features]
                checkpoint_out, checkpoint_in = checkpoint_param.shape
                current_out, current_in = current_param.shape
                
                # Case 1: Same output size, expanded input size (e.g., adding pos_emb)
                if checkpoint_out == current_out and checkpoint_in < current_in:
                    # Load the existing part and keep the new part as initialized
                    current_param.data[:, :checkpoint_in] = checkpoint_param.data
                    self.logger.info(f"Partial load {param_name}: loaded {checkpoint_in}/{current_in} input features")
                    return True
                
                # Case 2: Expanded output size, same input size
                elif checkpoint_out < current_out and checkpoint_in == current_in:
                    # Load the existing part and keep the new part as initialized
                    current_param.data[:checkpoint_out, :] = checkpoint_param.data
                    self.logger.info(f"Partial load {param_name}: loaded {checkpoint_out}/{current_out} output features")
                    return True
                
                # Case 3: Both dimensions expanded
                elif checkpoint_out < current_out and checkpoint_in < current_in:
                    # Load the compatible part
                    current_param.data[:checkpoint_out, :checkpoint_in] = checkpoint_param.data
                    self.logger.info(f"Partial load {param_name}: loaded {checkpoint_out}x{checkpoint_in}/{current_out}x{current_in} features")
                    return True
            
            # Handle bias vectors
            elif len(checkpoint_param.shape) == 1 and len(current_param.shape) == 1:
                checkpoint_size = checkpoint_param.shape[0]
                current_size = current_param.shape[0]
                
                if checkpoint_size < current_size:
                    # Load the existing part and keep the new part as initialized
                    current_param.data[:checkpoint_size] = checkpoint_param.data
                    self.logger.info(f"Partial load {param_name}: loaded {checkpoint_size}/{current_size} bias elements")
                    return True
            
            # If no compatible pattern found, return False
            return False
            
        except Exception as e:
            self.logger.warning(f"Failed to partially load {param_name}: {e}")
            return False

    def _load_optimizer_state_partial(self, checkpoint_optimizer_state):
        """
        Load optimizer state with partial compatibility for models with different architectures.
        This handles cases where some parameters have changed shapes.
        
        Args:
            checkpoint_optimizer_state: Optimizer state dict from checkpoint
        """
        try:
            # Try to load the optimizer state normally first
            self.optimizer.load_state_dict(checkpoint_optimizer_state)
            self.logger.info("Optimizer state loaded successfully")
        except (ValueError, RuntimeError) as e:
            self.logger.warning(f"Failed to load optimizer state directly: {e}")
            self.logger.info("Attempting partial optimizer state loading...")
            
            # Get current optimizer state
            current_optimizer_state = self.optimizer.state_dict()
            
            # Load compatible parts of the optimizer state
            loaded_state_params = 0
            skipped_state_params = 0
            
            # Load state for each parameter group
            for group_idx, (current_group, checkpoint_group) in enumerate(
                zip(current_optimizer_state['param_groups'], checkpoint_optimizer_state['param_groups'])
            ):
                # Load group settings (lr, weight_decay, etc.)
                for key in ['lr', 'weight_decay', 'eps', 'betas']:
                    if key in checkpoint_group:
                        current_group[key] = checkpoint_group[key]
                
                # Handle parameter states
                current_params = current_group['params']
                checkpoint_params = checkpoint_group['params']
                
                # Create mapping between current and checkpoint parameters
                param_mapping = {}
                for i, current_param_id in enumerate(current_params):
                    # Find matching parameter in checkpoint by comparing shapes and positions
                    if i < len(checkpoint_params):
                        param_mapping[current_param_id] = checkpoint_params[i]
                
                # Load state for each parameter
                for current_param_id, checkpoint_param_id in param_mapping.items():
                    if checkpoint_param_id in checkpoint_optimizer_state['state']:
                        current_param_state = checkpoint_optimizer_state['state'][checkpoint_param_id]
                        
                        # Check if we can load this state
                        if self._can_load_optimizer_param_state(current_param_id, current_param_state):
                            current_optimizer_state['state'][current_param_id] = current_param_state
                            loaded_state_params += 1
                        else:
                            skipped_state_params += 1
                    else:
                        skipped_state_params += 1
            
            # Load the partially updated optimizer state
            self.optimizer.load_state_dict(current_optimizer_state)
            
            self.logger.info(f"Partial optimizer state loading completed:")
            self.logger.info(f"  - Loaded parameter states: {loaded_state_params}")
            self.logger.info(f"  - Skipped parameter states: {skipped_state_params}")

    def _can_load_optimizer_param_state(self, param_id, param_state):
        """
        Check if we can load the optimizer state for a specific parameter.
        This is mainly for Adam optimizer states (exp_avg, exp_avg_sq).
        """
        try:
            current_param = None
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    if id(p) == param_id:
                        current_param = p
                        break
                if current_param is not None:
                    break
            
            if current_param is None:
                return False
            
            # Check if the state tensors have compatible shapes
            for state_key, state_tensor in param_state.items():
                if isinstance(state_tensor, torch.Tensor):
                    if state_tensor.shape != current_param.shape:
                        return False
            
            return True
        except Exception:
            return False

    def _load_checkpoint_partial(self, checkpoint_path, strict=False):
        """
        Load checkpoint with partial compatibility for models with different architectures.
        This is useful when fine-tuning with modified layers (e.g., adding pos_emb to TourLayer).
        
        Args:
            checkpoint_path: Path to the checkpoint file
            strict: If True, requires exact parameter match. If False, allows partial loading.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        model_state_dict = checkpoint['model_state_dict']
        current_model_state_dict = self.model.state_dict()
        
        # Track which parameters were loaded and which were skipped
        loaded_params = []
        skipped_params = []
        shape_mismatch_params = []
        
        for name, param in model_state_dict.items():
            if name in current_model_state_dict:
                current_param = current_model_state_dict[name]
                
                # Check if shapes match
                if param.shape == current_param.shape:
                    current_model_state_dict[name] = param
                    loaded_params.append(name)
                else:
                    # Try to load compatible parts of the parameter
                    if self._try_partial_parameter_load(param, current_param, name):
                        shape_mismatch_params.append(f"{name}: {param.shape} -> {current_param.shape} (partial load)")
                        loaded_params.append(name)
                    else:
                        shape_mismatch_params.append(f"{name}: {param.shape} -> {current_param.shape} (full reinit)")
                        if strict:
                            raise RuntimeError(f"Shape mismatch for parameter {name}: "
                                             f"checkpoint {param.shape} vs model {current_param.shape}")
            else:
                skipped_params.append(name)
                if strict:
                    raise RuntimeError(f"Parameter {name} not found in current model")
        
        # Load the partially updated state dict
        self.model.load_state_dict(current_model_state_dict)
        
        # Log the loading results
        self.logger.info(f"Checkpoint loading completed:")
        self.logger.info(f"  - Loaded parameters: {len(loaded_params)}")
        self.logger.info(f"  - Skipped parameters: {len(skipped_params)}")
        self.logger.info(f"  - Shape mismatches: {len(shape_mismatch_params)}")
        
        if shape_mismatch_params:
            self.logger.info("Shape mismatches:")
            for mismatch in shape_mismatch_params:
                self.logger.info(f"  - {mismatch}")
        
        if skipped_params:
            self.logger.info("Skipped parameters (not in current model):")
            for skipped in skipped_params:
                self.logger.info(f"  - {skipped}")
        
        return checkpoint


    def run(self):
        """Run the training loop from ``self.start_epoch`` through the configured last epoch.

        Per epoch: train, step the LR scheduler, log a time estimate, refresh the
        training-curve images (from the second epoch on), write checkpoints, validate,
        and - after the last epoch - print the whole result log.
        """
        self.time_estimator.reset(self.start_epoch)
        final_epoch = self.trainer_params['epochs']
        image_prefix = '{}/latest'.format(self.result_folder)
        validation_interval = 1  # validate after every epoch

        for epoch in range(self.start_epoch, final_epoch + 1):
            self.logger.info('=================================================================')

            # Train
            train_score, train_loss, cost_avg, final_costs = self._train_one_epoch(epoch)
            epoch_metrics = (
                ('train_score', train_score),
                ('train_loss', train_loss),
                ('cost_avg', cost_avg),
                ('final_costs', final_costs),
            )
            for metric_name, metric_value in epoch_metrics:
                self.result_log.append(metric_name, epoch, metric_value)

            # LR Decay
            self.scheduler.step()

            ############################
            # Logs & Checkpoint
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, final_epoch)
            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, final_epoch, elapsed_time_str, remain_time_str))

            is_last_epoch = epoch == final_epoch
            save_every = self.trainer_params['model_save_interval']

            # Refresh the training-curve images; plotting problems must not stop training
            if epoch > 1:
                self.logger.info("Saving log_image")
                try:
                    for curve in ('cost_avg', 'train_loss', 'final_costs'):
                        util_save_log_image_with_label(image_prefix, self.result_log, labels=[curve])
                except Exception:
                    self.logger.info("Error creating plots")

            # One snapshot serves both the periodic and the "latest" checkpoint file
            snapshot = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'result_log': self.result_log.get_raw_data(),
                'model_params': self.model_params,
                'env_params': self.env_params,
                'wandb_run_id': self.wandb_run_id
            }

            # Periodic (and final) numbered checkpoint
            if is_last_epoch or epoch % save_every == 0:
                self.logger.info("Saving trained_model")
                torch.save(snapshot, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            # Rolling checkpoint, rewritten every epoch
            torch.save(snapshot, '{}/latest_model.pt'.format(self.result_folder))

            ############################
            # Validation
            ############################
            if epoch % validation_interval == 0:
                aug_score = self.validator.run(self.model, self.model_frozen, epoch)

                self.result_log.append('valid_aug_score', epoch, aug_score)
                util_save_log_image_with_label(image_prefix, self.result_log, labels=['valid_aug_score'])

            # All-done announcement
            if is_last_epoch:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)

    def _train_one_epoch(self, epoch):
        """Train for one epoch and return its averaged statistics.

        Repeatedly draws a fresh set of instances, optionally runs a number of
        search-only iterations on them, then trains on them for
        ``iterations_per_instance`` iterations, applying an optimizer step every
        ``grad_acc_iterations`` iterations. Stops once the iteration counter
        (``batch_size * iterations_per_instance`` per instance set) reaches
        ``iterations_per_epoch``.

        Args:
            epoch: Current epoch number (used for logging only).

        Returns:
            Tuple ``(score_avg, loss_avg, cost_avg, mean_final_cost)`` where the first
            three are running averages of the values reported by ``_train_one_batch``
            and the last is the mean of the instance costs collected after each
            instance set was processed.
        """
        trainer_cfg = self.trainer_params
        accumulate_every = trainer_cfg['grad_acc_iterations']
        steps_per_instance = self.env_params['iterations_per_instance']
        epoch_budget = trainer_cfg['iterations_per_epoch']

        assert steps_per_instance % accumulate_every == 0
        assert accumulate_every <= steps_per_instance

        score_meter = AverageMeter()
        loss_meter = AverageMeter()
        cost_meter = AverageMeter()
        improved_meter = AverageMeter()

        collected_final_costs = []
        logged_sets = 0
        done_iterations = 0
        started_at = time.time()

        self.model.zero_grad()
        while done_iterations < epoch_budget:

            # Fresh instances for this round
            self.env.init_instances(self.batch_size, self.rollout_size, self.device)

            # Search-only iterations: no gradients, no statistics
            for _ in range(trainer_cfg['nb_skipped_iterations']):
                self._search_one_batch(self.batch_size)

            for step_no in range(1, steps_per_instance + 1):
                batch_score, batch_loss, batch_cost, improved_count = self._train_one_batch(self.batch_size)

                # Gradients accumulate across iterations until the step is applied
                if step_no % accumulate_every == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.model.zero_grad()

                score_meter.update(batch_score, self.batch_size)
                loss_meter.update(batch_loss, self.batch_size)
                cost_meter.update(batch_cost, self.batch_size)
                improved_meter.update(improved_count / self.batch_size, self.batch_size)

            # Costs of the instances once all iterations on them are finished
            collected_final_costs.extend(self.env.instanceSet.costs)

            done_iterations += self.batch_size * steps_per_instance

            # Progress lines for the first 10 instance sets, first epoch of this run only
            if epoch == self.start_epoch:
                logged_sets += 1
                if logged_sets <= 10:
                    self.logger.info(
                        'Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%)  Best Reward: {:.4f},  Loss: {:.4f},  Reward: {:.5f}, Improved: {:.4f} Final costs: {:.4f}'
                        .format(epoch, done_iterations, epoch_budget, 100. * done_iterations / epoch_budget,
                                score_meter.avg, loss_meter.avg, cost_meter.avg, improved_meter.avg,
                                np.mean(collected_final_costs)))

        mean_final_cost = np.mean(collected_final_costs)

        # Epoch summary
        self.logger.info(
            'Epoch {:3d}: Train ({:3.0f}%)  Best Reward: {:.4f},  Loss: {:.4f}, Reward: {:.5f}, Improved.: {:.4f} Final costs: {:.4f}'
            .format(epoch, 100. * done_iterations / epoch_budget,
                    score_meter.avg, loss_meter.avg, cost_meter.avg, improved_meter.avg,
                    mean_final_cost))

        if self.use_wandb:
            wandb_payload = {
                "train/max_reward": score_meter.avg,
                "train/loss": loss_meter.avg,
                "train/mean_reward": cost_meter.avg,
                "train/Improvement": improved_meter.avg,
                'train/final_costs': mean_final_cost,
                "time/epoch": time.time() - started_at,
            }
            wandb.log(step=epoch, data=wandb_payload)

        return score_meter.avg, loss_meter.avg, cost_meter.avg, mean_final_cost

    def _train_one_batch(self, batch_size):
        z_dim = self.model_params['z_dim']

        self.model.train()

        state = self.env.reset()
        reset_state = self.env.get_model_input(self.device)

        # Sample z vectors
        z_idx = torch.multinomial((torch.ones(batch_size, 2 ** z_dim) / 2 ** z_dim),
                                  self.rollout_size, replacement=False)
        z = self.binary_string_pool[z_idx].reshape(batch_size, 1, self.rollout_size, z_dim)
        z = z.transpose(1, 2).reshape(batch_size, self.rollout_size, z_dim)

        with torch.amp.autocast(device_type=self.device.type):
            self.model.pre_forward(reset_state, z)

        prob_list = torch.zeros(size=(batch_size, self.env.rollout_size, 0))
        probs_list = torch.zeros(size=(batch_size, self.env.rollout_size, self.env.problem_size, 0))
        # shape: (batch, pomo, 0~problem)

        # DLD Rollout
        ###############################################
        done = False
        while not done:
            with torch.amp.autocast(device_type=self.device.type):
                selected, prob, probs = self.model(state)
            # shape: (batch, pomo)

            state, done = self.env.step(selected)
            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)
            probs_list = torch.cat((probs_list, probs[:, :, :, None]), dim=3)

        selected_nodes = self.env.selected_node_list.cpu().numpy()

        reward = self.destroy_repair(selected_nodes)

        # Loss
        reward_pop = reward.reshape(batch_size, self.rollout_size, -1)
        # Mean is calculated over different z samples
        advantage = reward_pop - reward_pop.mean(dim=1, keepdim=True)
        advantage = advantage.reshape(batch_size, -1)
        # shape: (batch, pomo)
        log_prob = prob_list.log().sum(dim=2)
        # size = (batch, pomo)

        # Finding the best rollout of each z sample
        costs = -reward.reshape(batch_size, self.rollout_size, -1)
        best_idx = costs.argsort(1).argsort(1)
        best_idx = best_idx.reshape(batch_size, -1)
        mask = best_idx < 1
        # mask = torch.clamp(mask + (self.trainer_params["mask_leak_alpha"]/z_sample_size), max=1)

        log_prob *= mask

        loss = - advantage * log_prob  # Minus Sign: To Increase REWARD
        # shape: (batch, rollout)
        loss_mean = loss.mean()

        # Step & Return
        ###############################################
        self.scaler.scale(loss_mean).backward()

        # Score
        ###############################################
        max_rollout_reward, _ = reward.max(dim=1)  # get best results from pomo
        score_mean = max_rollout_reward.float().mean()  # negative sign to make positive value
        nb_improved_instances = (max_rollout_reward > 1e-5).sum().item()


        return score_mean.item(), loss_mean.item(), reward.mean().item(), nb_improved_instances


    def destroy_repair(self, selected_nodes):
        recreate_n = self.env_params['recreate_n']
        reward_type = self.trainer_params['reward_type']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']

        # needs to return reward of size (batch, pomo)
        reward = np.zeros((self.batch_size, self.rollout_size))

        old_costs = np.array(self.env.instanceSet.costs)
        all_costs = self.env.instanceSet.remove_recreate(selected_nodes, recreate_n, "singleImp", beta=beta, insert_in_new_tours_only=insert_in_new_tours_only)
        all_costs = np.array(all_costs)

        for b_idx in range(self.batch_size):

            old_cost = old_costs[b_idx]  # cost before destroy
            costs = all_costs[b_idx]

            if reward_type == 'b':
                r = (costs < old_cost - 0.0001).astype('float') + ((old_cost - costs) * 0.0001)
            else:
                abs_improv = old_cost - costs
                r = np.maximum(abs_improv, 0)

            reward[b_idx] = r

        reward = torch.Tensor(reward)
        return reward  # negative


    def _search_one_batch(self, batch_size):
        z_dim = self.model.model_params['z_dim']
        recreate_n = self.env_params['recreate_n']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        rollout_size = 100

        # Ready
        ###############################################
        self.model.eval()
        self.env.change_rollout_size(rollout_size)
        with torch.no_grad():

            state = self.env.reset()
            reset_state = self.env.get_model_input(self.device)

            # Sample z vectors
            z_idx = torch.multinomial((torch.ones(batch_size, 2 ** z_dim) / 2 ** z_dim),
                                      self.env.rollout_size, replacement=False)
            z = self.binary_string_pool[z_idx].reshape(batch_size, 1, self.env.rollout_size, z_dim)
            z = z.transpose(1, 2).reshape(batch_size, self.env.rollout_size, z_dim)

            with torch.amp.autocast(device_type=self.device.type):
                self.model.pre_forward(reset_state, z)

            prob_list = torch.zeros(size=(batch_size, self.env.rollout_size, 0))
            # shape: (batch, pomo, 0~problem)

            # DLD Rollout
            ###############################################
            done = False
            while not done:
                with torch.amp.autocast(device_type=self.device.type):
                    selected, prob, _ = self.model(state)
                    # shape: (batch, pomo)

                state, done = self.env.step(selected)
                prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

            selected_nodes = self.env.selected_node_list.cpu().numpy()

            # Repair
            self.env.instanceSet.remove_recreate(selected_nodes, recreate_n, "allImp", beta=beta,
                                                 insert_in_new_tours_only=insert_in_new_tours_only)

        self.env.change_rollout_size(self.rollout_size)
