from typing import List, Tuple, Callable
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, EarlyStopping
from tqdm import tqdm

from .torch_model import (
    ArbitraryConditionModel
)
from ...base import *
from ...metrics.high_order import EnergyDistance


class ImpuGen(
    EMAMixIn,
    BasePredictorMixIn,
    BaseImputerMixIn,
    BaseUnconditionalGeneratorMixIn,
    BaseArbitraryConditionalGeneratorMixIn,
    BaseImbalanceMixin,
    Base
):

    """
    ImpuGen: Unified Tabular Imputation and Generation via Task-Aligned Sampling Strategies
    """

    def __init__(
            self,
            lr: float = 1e-3,
            d_model: [int, float, str] = 2048,
            d_time: int = 256,
            batch_mul: int = 1,
            auto_batch_mul: int = 4,
            mask_ratio_scheduler: str = 'linear',
            **kwargs
    ):
        super().__init__()
        self.d_model = self._get_d_model(d_model)
        self.token_embed_dim = 1

        self.model = ArbitraryConditionModel(
            input_dim=self.column_dim,
            condition_dim=self.column_dim,
            n_columns=self.n_columns,
            d_model=self.d_model,
            d_time=d_time,
            d_cond=max(self.column_dim, d_time),
            ode=False
        )
        self.lr = lr
        self.batch_mul = batch_mul
        self.auto_batch_mul = auto_batch_mul
        self.best_par_steps = 1
        self.best_mask_prob = None

        # Mask ratio scheduler function
        self.reference = None
        self.register_buffer('mean', torch.zeros(1, self.column_dim))
        self.register_buffer('std', torch.ones(1, self.column_dim))

    def _generate_uncond(
            self,
            bsz: int,
            num_iter: Optional[int] = None,
            progress: bool = False,
            temperature: float = 0.0,
            scheduler: Optional[str] = None,
            latent_scale: float = 1.0,
            mask_prob: Optional[Union[float, torch.Tensor]] = None,
            mask_target_column: bool = False,
            **kwargs
    ) -> pd.DataFrame:
        """
        Draw `bsz` rows with nothing observed.

        An all-NaN token matrix (one column per numerical and categorical dimension) is
        handed to `sample_tokens`; every sampling option is forwarded unchanged. `scheduler`
        is not a named parameter of `sample_tokens`, so it travels in its `**kwargs`.

        Args:
            bsz (int): Number of rows to generate.

        Returns:
            Whatever `sample_tokens` returns (a DataFrame with its default settings).
        """
        n_token_columns = self.numerical_dim + self.categorical_dim
        blank = torch.full(
            [bsz, n_token_columns], torch.nan, device=self.device, dtype=self.dtype
        )
        sampling_options = dict(
            num_iter=num_iter,
            progress=progress,
            temperature=temperature,
            scheduler=scheduler,
            latent_scale=latent_scale,
            mask_prob=mask_prob,
            mask_target_column=mask_target_column,
        )
        return self.sample_tokens(blank, **sampling_options, **kwargs)

    def _impute(
            self,
            dataframe: pd.DataFrame,
            num_iter: Optional[int] = None,
            num_average: int = 1,
            progress: bool = False,
            temperature: float = 0.0,
            scheduler: Optional[str] = None,
            latent_scale: float = 0.0,
            mask_target_column: bool = False,
            **kwargs
    ) -> pd.DataFrame:
        """
        Fill the missing entries of `dataframe`, averaging over `num_average` draws.

        Each draw is a single-iteration `sample_tokens` call: `num_iter` is accepted for
        signature compatibility, but the call always uses `num_iter=1`. Numerical outputs are
        averaged; categorical outputs are averaged when `model_flags['onehot']` is set and
        reduced with the per-position mode otherwise.

        Returns:
            pd.DataFrame: Result of `tabular_transform.inverse_transform` on the aggregated
            numerical and categorical parts.
        """
        tokens = self.tabular_transform.transform(dataframe, return_as_tensor=True).to(self.device, self.dtype)

        rounds = range(num_average)
        if num_average > 1:
            # Only show a progress bar when there is more than one draw.
            rounds = tqdm(rounds)

        split_at = self.numerical_dim
        numerical_draws, categorical_draws = [], []
        for _ in rounds:
            draw = self.sample_tokens(  # for deterministic imputation
                tokens,
                num_iter=1,
                progress=progress,
                temperature=temperature,
                scheduler=scheduler,
                latent_scale=latent_scale,
                mask_target_column=mask_target_column,
                return_as_dataframe=False,
                **kwargs
            )
            numerical_draws.append(draw[:, :split_at])
            categorical_draws.append(draw[:, split_at:])

        num = torch.stack(numerical_draws).mean(dim=0)
        stacked_cat = torch.stack(categorical_draws)
        if self.model_flags['onehot']:
            cat = stacked_cat.mean(dim=0)
        else:
            cat = stacked_cat.mode(dim=0).values
        return self.tabular_transform.inverse_transform(num, cat)

    def _predict(
            self,
            dataframe: pd.DataFrame,
            num_iter: Optional[int] = None,
            num_average: int = 1,
            progress: bool = False,
            temperature: float = 0.0,
            scheduler: Optional[str] = None,
            latent_scale: float = 0.0,
            **kwargs
    ) -> np.ndarray:
        """
        Predict the target column by imputing `dataframe` and reading column `self.tgt`.

        All options are forwarded to `_impute`.

        Returns:
            np.ndarray: Values of the imputed target column.
        """
        impute_options = dict(
            num_iter=num_iter,
            num_average=num_average,
            progress=progress,
            temperature=temperature,
            scheduler=scheduler,
            latent_scale=latent_scale,
        )
        imputed = self._impute(dataframe, **impute_options, **kwargs)
        target_column = imputed[self.tgt]
        return target_column.values

    def _predict_proba(
            self,
            dataframe: pd.DataFrame,
            num_iter: Optional[int] = None,
            num_average: int = 20,
            progress: bool = False,
            temperature: float = 0.0,
            scheduler: Optional[str] = None,
            latent_scale: float = 0.5,
            mask_target_column: bool = False,
            **kwargs
    ) -> np.ndarray:
        """
        Estimate class probabilities of the target as empirical frequencies over draws.

        Runs `sample_tokens` `num_average` times, re-encodes each sampled frame with
        `onehot=False`, takes the target's categorical code from every draw, and averages
        the one-hot encodings of those codes across draws.

        Returns:
            np.ndarray: Array of shape (rows, n_classes) with per-class frequencies.
        """
        tokens = self.tabular_transform.transform(dataframe, return_as_tensor=True).to(self.device, self.dtype)

        rounds = range(num_average)
        if num_average > 1:
            rounds = tqdm(rounds)

        categorical_draws = []
        for _ in rounds:
            sampled_frame = self.sample_tokens(
                tokens,
                num_iter=num_iter,
                progress=progress,
                temperature=temperature,
                scheduler=scheduler,
                latent_scale=latent_scale,
                mask_target_column=mask_target_column,
                return_as_dataframe=True,
                **kwargs
            )
            encoded = self.tabular_transform.transform(sampled_frame, onehot=False, return_as_tensor=True)
            categorical_draws.append(encoded[:, self.numerical_dim:])

        tgt_index = self.tabular_transform.categorical_columns.index(self.tgt)
        labels = torch.stack(categorical_draws)[:, :, tgt_index]
        n_cls = self.n_categories_per_columns[tgt_index]
        # Rows of the identity matrix act as one-hot vectors for the sampled class codes.
        identity = torch.eye(n_cls, device=self.device)
        frequencies = identity[labels.long()].mean(dim=0)
        return frequencies.cpu().numpy()

    def _generate_by_class(self,
                           dataframe: pd.DataFrame,
                           num_iter: Optional[int] = None,
                           progress: bool = False,
                           temperature: float = 0.0,
                           scheduler: Optional[str] = None,
                           latent_scale: float = 1.0,
                           mask_target_column: bool = False,
                           cfg: float = 1.7,
                           **kwargs) -> pd.DataFrame:
        """
        Generate rows conditioned on the observed entries of `dataframe`.

        The frame is encoded with `tabular_transform` and passed to `sample_tokens`
        together with all sampling options. The `cfg` value defaults to 1.7 here.
        """
        conditioning = self.tabular_transform.transform(dataframe, return_as_tensor=True)
        conditioning = conditioning.to(self.device, self.dtype)
        sampling_options = dict(
            num_iter=num_iter,
            progress=progress,
            temperature=temperature,
            scheduler=scheduler,
            latent_scale=latent_scale,
            mask_target_column=mask_target_column,
            cfg=cfg,
        )
        return self.sample_tokens(conditioning, **sampling_options, **kwargs)

    def _generate_by_condition(self,
                               dataframe: pd.DataFrame,
                               num_iter: Optional[int] = None,
                               progress: bool = False,
                               temperature: float = 0.0,
                               scheduler: Optional[str] = None,
                               latent_scale: float = 1.0,
                               mask_target_column: bool = False,
                               cfg: float = 1.0,
                               **kwargs) -> pd.DataFrame:
        """
        Generate rows conditioned on the observed entries of `dataframe`.

        The frame is encoded with `tabular_transform` and passed to `sample_tokens`
        together with all sampling options. The `cfg` value defaults to 1.0 here.
        """
        conditioning = self.tabular_transform.transform(dataframe, return_as_tensor=True)
        conditioning = conditioning.to(self.device, self.dtype)
        sampling_options = dict(
            num_iter=num_iter,
            progress=progress,
            temperature=temperature,
            scheduler=scheduler,
            latent_scale=latent_scale,
            mask_target_column=mask_target_column,
            cfg=cfg,
        )
        return self.sample_tokens(conditioning, **sampling_options, **kwargs)

    def _get_d_model(self, d_model):
        if isinstance(d_model, int):
            pass
        elif isinstance(d_model, float):
            d_model = int(self.column_dim * d_model)
        elif d_model == 'auto':
            d_model = min(int(np.ceil(self.column_dim * 8 / 512) * 512), 2048)  # stepwise
        else:
            raise NotImplementedError(d_model)
        if self._cfg is not None:
            self._cfg.model.d_model = d_model
        return d_model

    def encode(self, df_or_tensor, *args, **kwargs) -> torch.Tensor:
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)
        # Mean imputation
        x = (x - self.mean) / self.std / 2
        x = x.nan_to_num()
        return x

    def decode(self, tensor: torch.Tensor, *args, **kwargs):
        tensor = tensor.to(self.device, self.dtype).squeeze(-1)

        tensor = tensor * self.std * 2 + self.mean
        num_dim = self.tabular_transform.numerical_dim

        num = tensor[:, :num_dim].view(len(tensor), -1)
        cat = tensor[:, num_dim:].view(len(tensor), -1)
        # View/reshape in case the caller supplied a flattened batch
        return num.view(len(num), -1), cat.view(len(cat), -1)

    def _expand_mask_if_onehot(self, mask):
        if not self.model_flags['onehot']:
            return mask
        mask = torch.cat([mask[:, :self.numerical_dim]] + [e.tile(1, c) for e, c in
                                                           zip(mask[:, self.numerical_dim:].split(1, dim=1),
                                                               self.n_categorical_dim_per_columns)], dim=1)
        return mask

    def _shrink_mask_if_onehot(self, mask):
        if not self.model_flags['onehot']:
            return mask
        mask = torch.cat([mask[:, :self.numerical_dim]] + [e.any(dim=1, keepdim=True) for e in
                                                           mask[:, self.numerical_dim:].split(
                                                               self.n_categorical_dim_per_columns, dim=1)], dim=1)
        return mask

    def masking(
            self,
            tokens: torch.Tensor,
            mask_rate: float = None,
            nan_mask: torch.Tensor = None,
            order: torch.Tensor = None,
            order_bias: torch.Tensor = None,
            masking_order: str = 'random',
            return_order: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Build a boolean mask that selects, per row, the lowest-ranked columns.

        Each row masks `floor(mask_rate * n_columns)` columns: the columns whose
        `order + order_bias` value ranks smallest. The mask is computed per column and
        widened with `_expand_mask_if_onehot` before it is returned.

        Args:
            tokens (torch.Tensor): Only its length (batch size) is used.
            mask_rate (float): Fraction of columns to mask, shared by all rows. When None,
                an independent uniform random rate is drawn for every row.
            nan_mask (torch.Tensor): Optional boolean tensor of positions that MAY be masked.
                Positions where it is False are pushed to the end of the ranking and are
                always removed from the result. It is reduced to column granularity with
                `_shrink_mask_if_onehot` first.
            order (torch.Tensor): Ranking scores of shape (batch_size, n_columns). Only used
                when `masking_order` is not 'random', 'raster' or 'raster-' (those rebuild
                the scores). When `nan_mask` is given this tensor is modified in place.
            order_bias (torch.Tensor): Offset added to the scores before ranking; zeros when
                omitted.
            masking_order (str): 'random' draws fresh random scores; 'raster' uses descending
                scores (so the last columns rank first); 'raster-' uses ascending scores (so
                the first columns rank first); any other value (e.g. 'decay') keeps `order`.
            return_order (bool): If True, also return the score tensor.

        Returns:
            mask (torch.Tensor): Boolean mask (expanded to one-hot width when applicable).
            order (torch.Tensor, optional): Returned when `return_order=True`. With
                `masking_order='decay'` the scores of the masked positions are raised first.
        """
        bsz = len(tokens)
        seq_len = self.n_columns

        # No rate given: draw one uniform rate per row; otherwise use a single shared rate.
        if mask_rate is None:
            mask_rate = torch.Tensor(np.random.rand(bsz)).to(self.device)
        else:
            mask_rate = torch.Tensor([mask_rate]).to(self.device, self.dtype)

        # Default scores and bias when the caller supplied none
        if order is None:
            order = torch.rand(bsz, seq_len, device=self.device)
        if order_bias is None:
            order_bias = torch.zeros_like(order)

        # These modes replace the scores outright; other modes keep the scores from above
        if masking_order == 'random':
            order = torch.rand(bsz, seq_len, device=self.device)
        elif masking_order == 'raster':
            order = torch.arange(seq_len, 0, -1, dtype=torch.float32, device=self.device)[None].repeat(bsz, 1)
        elif masking_order == 'raster-':
            order = torch.arange(seq_len, dtype=torch.float32, device=self.device)[None].repeat(bsz, 1)

        # Turn the rate into a per-row column count with a trailing dim for broadcasting
        mask_rate = torch.floor(mask_rate * seq_len).unsqueeze(-1)

        if nan_mask is not None:
            nan_mask = self._shrink_mask_if_onehot(nan_mask)
            # Ineligible positions get an infinite score (in place) so they rank last
            order[~nan_mask] = torch.inf
            # Double argsort yields each position's rank; keep ranks below the count,
            # then drop anything that is not eligible
            mask = (torch.argsort(torch.argsort(order + order_bias, dim=1), dim=1) < mask_rate) * nan_mask
        else:
            mask = torch.argsort(torch.argsort(order + order_bias, dim=1), dim=1) < mask_rate

        if return_order:
            if masking_order == 'decay':
                # Raise the score of masked positions: next integer level plus a random fraction
                order[mask] = 1 + torch.floor(order[mask]) + torch.rand_like(order[mask])
            mask = self._expand_mask_if_onehot(mask)
            return mask, order
        mask = self._expand_mask_if_onehot(mask)
        return mask

    def mask_with_prob(
            self,
            tokens: torch.Tensor,
            mask_prob: Union[float, torch.Tensor],
            nan_mask: torch.Tensor = None,
            return_order: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Mask tokens independently per column using the provided probabilities.

        Args:
            tokens (torch.Tensor): Input tokens of shape (batch_size, seq_len).
            mask_prob (Union[float, torch.Tensor]): Probability of masking each column. Accepts a scalar,
                1D tensor of length seq_len, or 2D tensor broadcastable to (batch_size, seq_len).
            nan_mask (torch.Tensor): Optional boolean tensor indicating valid positions (True where data is present).
            return_order (bool): When True, also return a placeholder order tensor for API compatibility.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Boolean mask matching `tokens`
            (and None placeholder if `return_order` is True).
        """
        if mask_prob is None:
            raise ValueError("mask_prob must be provided for masking_v2.")
        bsz = len(tokens)
        seq_len = self.n_columns
        if not torch.is_tensor(mask_prob):
            mask_prob = torch.tensor(mask_prob, device=self.device, dtype=self.dtype)
        else:
            mask_prob = mask_prob.to(device=self.device, dtype=self.dtype)
        if mask_prob.ndim == 0:
            mask_prob = mask_prob.repeat(seq_len)
        if mask_prob.ndim == 1:
            if mask_prob.numel() != seq_len:
                raise ValueError(f"Expected mask_prob to have length {seq_len}, got {mask_prob.numel()}.")
            mask_prob = mask_prob.unsqueeze(0).expand(bsz, -1)
        elif mask_prob.ndim == 2:
            if mask_prob.shape[1] != seq_len:
                raise ValueError(f"Expected mask_prob to have {seq_len} columns, got {mask_prob.shape[1]}.")
            if mask_prob.shape[0] == 1:
                mask_prob = mask_prob.expand(bsz, -1)
            elif mask_prob.shape[0] != bsz:
                raise ValueError(f"Expected mask_prob batch dimension {bsz}, got {mask_prob.shape[0]}.")
        else:
            raise ValueError("mask_prob must be a scalar, 1D tensor, or 2D tensor.")
        mask_prob = mask_prob.clamp(0.0, 1.0)
        random_vals = torch.rand(bsz, seq_len, device=self.device, dtype=self.dtype)
        mask = random_vals < mask_prob
        if nan_mask is not None:
            nan_mask = self._shrink_mask_if_onehot(nan_mask)
            mask = mask & nan_mask
        mask = self._expand_mask_if_onehot(mask)
        if return_order:
            order_placeholder = torch.zeros(bsz, seq_len, device=self.device, dtype=self.dtype)
            return mask, order_placeholder
        return mask

    def search_mask_prob(
            self,
            reference: pd.DataFrame,
            mask_prob_values: Optional[List[float]] = None,
            num_average: int = 3,
            verbose: bool = True
    ) -> None:
        """
        Choose between one-step sampling and two-step refinement by energy distance.

        The baseline is unconditional generation with `num_iter=1`. Every candidate
        probability is then evaluated with `num_iter=2`; the setting with the lowest mean
        energy distance wins. The result is stored in `self.best_par_steps` (1 or 2) and
        `self.best_mask_prob` (None when the baseline was never beaten).

        Args:
            reference (pd.DataFrame): Data the generated samples are compared against. Up to
                8192 rows are sampled per trial and passed through `generate_by_condition`.
            mask_prob_values (Optional[List[float]]): Candidate probabilities. Defaults to
                0.1, 0.2, ..., 0.9.
            num_average (int): Number of trials averaged for each setting.
            verbose (bool): Whether to print progress information.
        """
        metric = EnergyDistance(self._transform, self.model_flags['drop_target'])

        if mask_prob_values is None:
            mask_prob_values = [round(p, 1) for p in np.arange(0.1, 1.0, 0.1)]

        def _mean_energy_distance(num_iter: int, mask_prob: Optional[float], label: str):
            # Options for unconditional generation are the same for every trial.
            gen_kwargs = {'num_iter': num_iter}
            if mask_prob is not None:
                gen_kwargs['mask_prob'] = mask_prob

            total = 0.0
            for trial in range(1, num_average + 1):
                if verbose:
                    print(f"{label} ({trial}/{num_average})")
                subset = reference.sample(min(8192, len(reference)))
                df_ref = self.generate_by_condition(subset)
                gen_data = self.generate_uncond(len(df_ref), **gen_kwargs)
                total += metric.evaluation(df_ref, gen_data)['energy_distance']
            return total / num_average

        baseline_error = _mean_energy_distance(1, None, 'Baseline energy distance (num_iter=1)')
        best_error, best_iter, best_mask_prob = baseline_error, 1, None

        mask_errors = {}
        for prob in mask_prob_values:
            avg_error = _mean_energy_distance(2, prob, f"Mask prob searching... (p={prob:.2f})")
            mask_errors[prob] = avg_error
            if not avg_error < best_error:
                continue
            if verbose:
                print(f"energy distance improved: {best_error :.6f} -> {avg_error :.6f} (p={prob:.2f})")
            best_error, best_iter, best_mask_prob = avg_error, 2, float(prob)

        if verbose:
            print('energy distance baseline:', baseline_error)
            print('energy distance per mask_prob:', mask_errors)

        self.best_par_steps = max(best_iter, 1)
        self.best_mask_prob = best_mask_prob

    def _step(self, tokens, mask, update_mask, latents, latent_scale, temperature, cfg, mask_target_column, num_steps=50, **kwargs):
        """
        Run one refinement pass: hide `mask`, sample from the model, write back `update_mask`.

        `tokens` is modified in place (positions in `mask` are set to NaN, positions in
        `update_mask` receive the model output), so callers should pass a copy if they need
        the original.

        Args:
            tokens: Token matrix; NaN marks entries to be generated.
            mask: Boolean tensor of positions to hide from the model in this pass.
            update_mask: Boolean tensor of positions whose values are replaced by the output.
            latents: Noise tensor; scaled by `latent_scale` before it is given to the model.
            latent_scale, temperature, cfg, num_steps: Forwarded to `self.model.sample`.
            mask_target_column: If True (and a target column exists and is not dropped), the
                result is round-tripped through a DataFrame with the target column removed.
            **kwargs: Accepted and ignored.

        Returns:
            The updated token tensor.
        """
        # Force positions in 'mask' to NaN
        tokens[mask] = torch.nan

        # Encode partially masked tokens as condition
        condition = self.encode(tokens)
        observed_column = self._shrink_mask_if_onehot(~mask)
        # Fully-missing input, encoded the same way, serves as the second ("_u") condition
        condition_u = self.encode(torch.full_like(tokens, torch.nan))

        # The model uses condition & latents to produce a refined output
        imputed = self.model.sample(
            condition=condition,
            observed_column=observed_column,
            condition_u=condition_u,
            observed_column_u=None,
            latents=latents * latent_scale,
            temperature=temperature,
            num_steps=num_steps,
            cfg=cfg,
        ).view(len(tokens), self.column_dim, self.token_embed_dim)

        # Decode the results
        num, cat = self.decode(imputed)
        if not self.model_flags['onehot']:
            if self.categorical_dim:
                # NOTE(review): iterating `cat` walks its first dimension; confirm this yields
                # per-column scores in the non-onehot configuration.
                cat = torch.stack([col.argmax(dim=-1) for col in cat], dim=1)
            else:
                # No categorical columns: use an empty block so the concatenation below works
                cat = torch.zeros(len(tokens), 0, device=self.device, dtype=torch.int32)
        next_tokens = torch.cat([num, cat], dim=1)
        # Only the requested positions take the freshly sampled values
        tokens[update_mask] = next_tokens[update_mask]
        if mask_target_column and self._transform.target_column is not None and not \
                self.model_flags['drop_target']:
            # Decode to a DataFrame, drop the target column and encode again.
            # Presumably this re-encodes the target as missing — TODO confirm how
            # `transform` handles the absent column.
            tokens = self.tabular_transform.transform(
                self.tabular_transform.inverse_transform(tokens).drop(self._transform.target_column,
                                                                      axis=1), return_as_tensor=True).to(
                tokens.device,
                tokens.dtype)
        return tokens

    @torch.no_grad()
    def sample_tokens(
            self,
            tokens: torch.Tensor,
            num_iter: Optional[int] = None,
            progress: bool = False,
            temperature: float = 0.0,
            latent_scale: float = 1.0,
            cfg: float = 1.0,
            mask_prob: Optional[Union[float, torch.Tensor]] = None,
            stop: int = -1,
            return_as_dataframe: bool = True,
            mask_target_column: bool = False,
            num_steps: int = 50,
            **kwargs
    ) -> Union[pd.DataFrame, torch.Tensor]:
        """
        Iteratively refine `tokens` using the model's 'impute' method, optionally
        using the autoencoder to encode/decode partial or masked data.

        Args:
            tokens (torch.Tensor): Initial tensor of shape (batch_size, seq_len), possibly with NaNs.
            num_iter (int, optional): Number of refinement steps. If None, uses `self.best_par_steps`.
            progress (bool): If True, display a progress bar via tqdm.
            temperature (float): Temperature for sampling in the model's decode step.
            latent_scale (float): A scaling factor for the latent random noise in each iteration.
            mask_prob (Optional[Union[float, torch.Tensor]]): Per-column masking probability used to update masks. Overrides
                scheduler-based masking when provided.
            stop (int): If > 0, stop early at the specified iteration step. If -1, do all `num_iter`.
            return_as_dataframe (bool): If True, convert the final tokens to a DataFrame.
            **kwargs: Additional arguments, printed if not used.

        Returns:
            pd.DataFrame or torch.Tensor: The refined data, either as a DataFrame or as raw tokens.
        """
        tokens = tokens.clone()
        if num_iter is None:
            num_iter = max(self.best_par_steps, 1)
        assert num_iter >= 1

        iter_range = tqdm(range(num_iter)) if progress else range(num_iter)
        bsz = len(tokens)

        # Keep track of which positions are NaN
        nan_mask = tokens.isnan()
        mask = nan_mask.clone()

        # Sample initial latents
        latents = torch.randn(
            bsz, self.column_dim, self.token_embed_dim,
            device=tokens.device, dtype=torch.float64
        )
        self._latents = latents

        # Determine mask ratio scheduling function
        order = None

        mask_prob_tensor = mask_prob
        if mask_prob_tensor is None and num_iter >= 2 and self.best_mask_prob is not None:
            mask_prob_tensor = self.best_mask_prob
        if mask_prob_tensor is not None:
            if torch.is_tensor(mask_prob_tensor):
                mask_prob_tensor = mask_prob_tensor.to(self.device, self.dtype)
            else:
                mask_prob_tensor = torch.as_tensor(mask_prob_tensor, device=self.device, dtype=self.dtype)

        with self.ema.average_parameters():
            for step in iter_range:
                tokens = self._step(tokens, mask, mask, latents, latent_scale, temperature, cfg, mask_target_column,
                                    num_steps=int(num_steps / num_iter),
                                    **kwargs)
                # Early stopping condition
                if step + 1 == num_iter or step + 1 == stop:
                    break

                # Update mask for next iteration
                mask, order = self.mask_with_prob(tokens, mask_prob_tensor, nan_mask, True)

        if return_as_dataframe:
            return self.tabular_transform.inverse_transform(tokens)
        return tokens

    def forward(self, x: torch.Tensor):
        """
        Compute the training loss for one batch.

        The batch is tiled `batch_mul` times, a random subset of columns is hidden in every
        row (the first tenth of the rows is hidden completely), and the model is asked to
        reconstruct the encoded full data from the encoded corrupted data. The loss is
        averaged over the positions that are not NaN in the tiled input.

        Args:
            x (torch.Tensor): Batched data of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: A scalar loss value.
        """
        with torch.no_grad():
            batch = x.tile([self.batch_mul, 1])
            n_rows = len(batch)

            hidden = self.masking(batch, masking_order='random')
            hidden[: int(n_rows // 10)] = True  # Drop 10%

            corrupted = batch.clone()
            corrupted[hidden] = torch.nan

            target = self.encode(batch).view(n_rows, -1)
            condition = self.encode(corrupted).view(n_rows, -1)
            observed_column = self._shrink_mask_if_onehot(~corrupted.isnan())

        per_entry_loss = self.model(
            target=target,
            condition=condition,
            observed_column=observed_column,
        )
        # Entries that were already missing in the data carry no supervision signal.
        present = ~batch.isnan()
        return per_entry_loss[present].mean()

    def training_step(self, x: Union[torch.Tensor, List[torch.Tensor]]):
        """
        Run one optimisation step: compute the loss via `forward` and log it.

        Args:
            x: A tensor batch, or a sequence whose first element is the tensor batch.

        Returns:
            torch.Tensor: The loss to back-propagate.
        """
        batch = x if isinstance(x, torch.Tensor) else x[0]

        diff_loss = self(batch)
        log_options = dict(on_step=True, on_epoch=True, prog_bar=True)
        self.log('t.loss', diff_loss.item(), **log_options)

        # If the ODE sub-model tracks # function evaluations (nfe), log it
        inner_model = self.model.diffloss.model
        if hasattr(inner_model, 'nfe'):
            self.log('t.nfe', inner_model.nfe, **log_options)
        return diff_loss

    def configure_optimizers(self):
        """
        Define optimizer (Adam) and a scheduler (ReduceLROnPlateau).

        Returns:
            dict: Contains "optimizer" and "lr_scheduler" entries.
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=0)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.95, patience=20,
        )
        sch_config = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": "t.loss_epoch"
        }
        return {"optimizer": self.optimizer, "lr_scheduler": sch_config}

    def configure_callbacks(self, *args, **kwargs):
        """
        Build the training callbacks.

        - LearningRateMonitor, logging once per epoch
        - EarlyStopping on "t.loss_epoch" with patience=500

        Returns:
            list: A list of callback instances.
        """
        lr_monitor = LearningRateMonitor(logging_interval='epoch')
        early_stopping = EarlyStopping(monitor='t.loss_epoch', patience=500, min_delta=0)
        return [lr_monitor, early_stopping]

    def on_train_end(self):
        """
        Called at the end of training, triggers searching for the best masking probability
        and refinement steps.
        """
        if self.reference is not None:
            self.search_mask_prob(self.reference)

    def on_save_checkpoint(self, checkpoint: dict):
        """
        Store the best search hyperparameters in the checkpoint.

        Args:
            checkpoint (dict): Checkpoint data to be saved.
        """
        checkpoint['best_par_steps'] = self.best_par_steps
        checkpoint['best_mask_prob'] = self.best_mask_prob

    def on_load_checkpoint(self, checkpoint: dict):
        """
        Load stored search hyperparameters from the checkpoint if available.

        Args:
            checkpoint (dict): Loaded checkpoint dictionary.
        """
        self.best_par_steps = checkpoint.get('best_par_steps', 0)
        self.best_mask_prob = checkpoint.get('best_mask_prob', None)

    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x):
        """
        Prepare data-dependent state, then delegate training to the parent `fit`.

        Steps performed here:
          1. Reset the searched sampling hyper-parameters.
          2. Load the training CSV from the config and apply `scenario` to it.
          3. Fit the tabular transform and compute per-dimension mean/std over the
             non-missing values (std gets a 1e-4 offset), stored in the `mean`/`std` buffers.
          4. Switch `batch_mul` to `auto_batch_mul` for training sets under 2048 rows.
          5. Keep the training frame as `self.reference` for the post-training search.

        Args:
            scenario: Function applied to the raw training frame before fitting.

        Returns:
            Whatever the parent class's `fit` returns.
        """
        config = self._cfg
        self.best_par_steps, self.best_mask_prob = 1, None

        train_df = scenario(pd.read_csv(config.dataset.train_path))
        self._transform.fit(train_df)
        train_data = self.tabular_transform.transform(train_df, return_as_tensor=True)

        column_means = train_data.nanmean(dim=0, keepdim=True)
        # Standard deviation per dimension, ignoring missing entries.
        column_stds = [column[~column.isnan()].std() for column in train_data.T]
        column_stds = torch.stack(column_stds)[None] + 1e-4
        self.mean.data = column_means
        self.std.data = column_stds

        is_small_dataset = len(train_df) < 2048
        if is_small_dataset:
            self.batch_mul = self.auto_batch_mul
        self.reference = train_df
        return super().fit(scenario)
