import delu
import glob
import itertools
import math
import os
import pandas as pd
import random
import re
import rtdl_num_embeddings
import rtdl_revisiting_models
import sklearn.tree as sklearn_tree
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
import xgboost as xgb

from collections import defaultdict
from collections.abc import Callable
from joblib import Memory
from rtdl_num_embeddings import _check_bins
from sklearn.tree import BaseDecisionTree
from torch import Tensor
from torch.nn import Parameter
from tqdm import tqdm
from typing import Any, Literal, Optional, cast

from lib.util import TaskType, is_oom_exception

# Disk-backed joblib cache rooted at ./joblib_tmp (relative to the current working
# directory); verbose=2 controls how much joblib logs about cache activity.
# NOTE(review): `memory` is not referenced in the visible part of this file --
# presumably it memoizes expensive functions defined further down; confirm.
memory = Memory(location='./joblib_tmp', verbose=2)

# ======================================================================================
# Initialization
# ======================================================================================
def init_rsqrt_uniform_(x: Tensor, d: int) -> Tensor:
    """Fill ``x`` in-place with samples from ``U(-1/sqrt(d), 1/sqrt(d))``.

    ``d`` must be positive. The very same tensor ``x`` is returned.
    """
    assert d > 0
    bound = d**-0.5
    return nn.init.uniform_(x, a=-bound, b=bound)

def _init_rsqrt_uniform_(weight: Tensor, dim: None | int, d: None | int = None) -> None:
    if d is None:
        assert dim is not None
        d = weight.shape[dim]
    else:
        assert dim is None
    d_rsqrt = 1 / math.sqrt(d)
    nn.init.uniform_(weight, -d_rsqrt, d_rsqrt)

@torch.inference_mode()
def init_random_signs_(x: Tensor) -> Tensor:
    """Overwrite ``x`` in-place with values drawn uniformly from {-1, +1}.

    The same tensor ``x`` is returned.
    """
    x.bernoulli_(0.5)  # Fair coin flips: every entry becomes 0 or 1.
    x.mul_(2)  # {0, 1} -> {0, 2}
    x.add_(-1)  # {0, 2} -> {-1, +1}
    return x


# ======================================================================================
# Modules
# ======================================================================================
class Identity(nn.Module):
    """A no-op module: the constructor ignores all arguments, `forward` returns its input."""

    def __init__(self, *args, **kwargs) -> None:
        del args, kwargs  # Accepted only so the class can stand in for other modules.
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return x


class NLinear(nn.Module):
    """N independent linear layers, the i-th one applied to the i-th slice of the input.

    **Shape**

    - Input: ``(B, N, in_features)``
    - Output: ``(B, N, out_features)``

    In other words, the i-th layer transforms the ``(B, in_features)`` matrix
    ``x[:, i]``.

    This is a reduced version of delu.nn.NLinear
    (https://yura52.github.io/delu/stable/api/generated/delu.nn.NLinear.html):
    only 3D inputs with exactly one batch dimension are accepted, whereas
    delu.nn.NLinear supports any number of batch dimensions.
    """

    def __init__(
        self, n: int, in_features: int, out_features: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(n, in_features, out_features))
        self.bias = Parameter(torch.empty(n, out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-sample the weight and the bias from U(-1/sqrt(in_features), 1/sqrt(in_features))."""
        in_features = self.weight.shape[-2]
        init_rsqrt_uniform_(self.weight, in_features)
        if self.bias is not None:
            init_rsqrt_uniform_(self.bias, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 3
        # The trailing (N, in_features) dimensions must match the weight.
        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]

        # (B, N, in) -> (N, B, in): N becomes the batch dimension of the matmul,
        # so each layer multiplies only its own slice; then restore (B, N, out).
        out = torch.matmul(x.transpose(0, 1), self.weight).transpose(0, 1)
        if self.bias is None:
            return out
        return out + self.bias


class PiecewiseLinearEmbeddings(rtdl_num_embeddings.PiecewiseLinearEmbeddings):
    """The upstream piecewise-linear embeddings with defaults for `activation` and `version`.

    Nothing else is changed: all other arguments are forwarded untouched.
    """

    def __init__(
        self,
        *args,
        activation: bool = False,
        version: None | Literal['A', 'B'] = 'B',
        **kwargs,
    ) -> None:
        defaulted = {'activation': activation, 'version': version}
        super().__init__(*args, **kwargs, **defaulted)


class OneHotEncoding0d(nn.Module):
    """Concatenated one-hot encoding of several categorical features.

    - Input:  ``(*, n_cat_features=len(cardinalities))``
    - Output: ``(*, sum(cardinalities))``
    """

    def __init__(self, cardinalities: list[int]) -> None:
        super().__init__()
        self._cardinalities = cardinalities

    def forward(self, x: Tensor) -> Tensor:
        assert x.ndim >= 1
        assert x.shape[-1] == len(self._cardinalities)

        encoded = []
        for column, cardinality in zip(x.unbind(-1), self._cardinalities):
            # NOTE
            # A quick hack for out-of-vocabulary categories.
            #
            # lib.data.transform_cat encodes categorical features as follows:
            # - in-vocabulary values get indices from `range(cardinality)`;
            # - every out-of-vocabulary value (a category seen in validation/test
            #   data but absent from the training data) gets the index `cardinality`.
            #
            # So we one-hot encode into `cardinality + 1` slots and drop the last
            # one: known categories get the standard one-hot encoding, unknown
            # categories get the all-zeros encoding. This may not be the best way
            # to treat unknown values, but it is enough for our purposes.
            with_unknown_slot = F.one_hot(column, cardinality + 1)
            encoded.append(with_unknown_slot[..., :cardinality])
        return torch.cat(encoded, dim=-1)


class ScaleEnsemble(nn.Module):
    """Element-wise scaling by a trainable ``(k, d)`` weight.

    The weight is broadcast against the trailing dimensions of the input.
    """

    def __init__(
        self,
        k: int,
        d: int,
        *,
        init: Literal['ones', 'normal', 'random-signs'],
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(k, d))
        self._weight_init = init
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the weight according to the `init` scheme chosen at construction."""
        scheme = self._weight_init
        if scheme == 'random-signs':
            init_random_signs_(self.weight)
            return
        if scheme == 'normal':
            nn.init.normal_(self.weight)
            return
        if scheme == 'ones':
            nn.init.ones_(self.weight)
            return
        raise ValueError(f'Unknown weight_init: {self._weight_init}')

    def forward(self, x: Tensor) -> Tensor:
        assert x.ndim >= 2
        return x * self.weight


class LinearEfficientEnsemble(nn.Module):
    """
    A more configurable variant of the "BatchEnsemble" linear layer from
    "BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning"
    (https://arxiv.org/abs/2002.06715).

    Differences from the original layer:

    1. Each "ensembled" part can be switched on or off independently:
       - the input scaling  (r_i in the BatchEnsemble paper)
       - the output scaling (s_i in the BatchEnsemble paper)
       - the output bias    (not mentioned in the BatchEnsemble paper,
                             but present in public implementations)
    2. The initialization of the scaling weights is chosen with `scaling_init`.

    NOTE
    The TabM paper uses the term "adapter" only to tell the story.
    The original BatchEnsemble paper does NOT use this term, and neither does
    this class.
    """

    r: None | Tensor
    s: None | Tensor
    bias: None | Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        *,
        k: int,
        ensemble_scaling_in: bool,
        ensemble_scaling_out: bool,
        ensemble_bias: bool,
        scaling_init: Literal['ones', 'random-signs'],
    ):
        assert k > 0
        if ensemble_bias:
            assert bias
        super().__init__()

        # The weight matrix is shared by all k ensemble members.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # Per-member scalings; a disabled part is registered as None.
        in_scaling = (
            nn.Parameter(torch.empty(k, in_features)) if ensemble_scaling_in else None
        )
        self.register_parameter('r', in_scaling)

        out_scaling = (
            nn.Parameter(torch.empty(k, out_features)) if ensemble_scaling_out else None
        )
        self.register_parameter('s', out_scaling)

        # The bias is either per-member (k, out_features), shared (out_features,)
        # or absent.
        if ensemble_bias:
            bias_parameter = nn.Parameter(torch.empty(k, out_features))
        elif bias:
            bias_parameter = nn.Parameter(torch.empty(out_features))
        else:
            bias_parameter = None
        self.register_parameter('bias', bias_parameter)

        self.in_features = in_features
        self.out_features = out_features
        self.k = k
        self.scaling_init = scaling_init

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the shared weight, the scalings and the bias (in this order)."""
        init_rsqrt_uniform_(self.weight, self.in_features)

        init_scaling_ = {'ones': nn.init.ones_, 'random-signs': init_random_signs_}[
            self.scaling_init
        ]
        for scaling in (self.r, self.s):
            if scaling is not None:
                init_scaling_(scaling)

        if self.bias is None:
            return
        # NOTE: the sampled bias has the shape (out_features,), not
        # (k, out_features), so all members start from the same bias.
        # This is similar to having one shared bias plus
        # k zero-initialized non-shared biases.
        shared_bias = torch.empty(
            self.out_features,
            dtype=self.weight.dtype,
            device=self.weight.device,
        )
        shared_bias = init_rsqrt_uniform_(shared_bias, self.in_features)
        with torch.inference_mode():
            # copy_ broadcasts the shared bias over the members if needed.
            self.bias.copy_(shared_bias)

    def forward(self, x: Tensor) -> Tensor:
        # x.shape == (B, K, D)
        assert x.ndim == 3

        # >>> The equation (5) from the BatchEnsemble paper (arXiv v2).
        scaled_in = x if self.r is None else x * self.r
        out = scaled_in @ self.weight.T
        if self.s is not None:
            out = out * self.s
        # <<<

        return out if self.bias is None else out + self.bias


def make_efficient_ensemble(module: nn.Module, EnsembleLayer, **kwargs) -> None:
    """Recursively swap every `nn.Linear` inside `module` for an `EnsembleLayer`.

    The replacement happens in-place. Each new layer is built as
    ``EnsembleLayer(in_features=..., out_features=..., bias=..., **kwargs)``
    with the sizes (and the presence of a bias) of the layer it replaces;
    the weights of the replaced layer are not copied.

    NOTE
    In the paper, there are no experiments with networks with normalization layers.
    Perhaps, their trainable weights (the affine transformations) also need
    "ensemblification" as in the paper about "FiLM-Ensemble".
    Additional experiments are required to make conclusions.
    """
    # Snapshot the children: the loop replaces entries of `module` on the fly.
    children = list(module.named_children())
    for child_name, child in children:
        if not isinstance(child, nn.Linear):
            make_efficient_ensemble(child, EnsembleLayer, **kwargs)
            continue
        replacement = EnsembleLayer(
            in_features=child.in_features,
            out_features=child.out_features,
            bias=child.bias is not None,
            **kwargs,
        )
        module.add_module(child_name, replacement)


class MLP(nn.Module):
    def __init__(
        self,
        *,
        d_in: None | int = None,
        d_out: None | int = None,
        n_blocks: int,
        d_block: int,
        dropout: float,
        activation: str = 'ReLU',
    ) -> None:
        super().__init__()

        d_first = d_block if d_in is None else d_in
        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(d_first if i == 0 else d_block, d_block),
                    getattr(nn, activation)(),
                    nn.Dropout(dropout),
                )
                for i in range(n_blocks)
            ]
        )
        self.output = None if d_out is None else nn.Linear(d_block, d_out)

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.output is not None:
            x = self.output(x)
        return x


def masked_cumsum(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """
    Compute segment-wise sums over `X`, where segments are defined by a binary mask `Y`.

    Each row of `Y` is interpreted as a sequence of segment boundaries, where a true
    value marks the END of a segment. The values of `X` within each segment are summed
    and the result is written only at the segment end positions (where `Y` is true).
    All other positions are set to zero. Values after the last segment end of a row
    do not belong to any segment and are dropped.

    Parameters
    ----------
    X : torch.Tensor
        Input tensor of shape (A, B) or (A, B, C), where `A` is the batch size, `B` is
        the sequence length, and `C` is the optional feature dimension.
    Y : torch.Tensor
        Mask tensor of shape (A, B). Boolean masks are used as is; any other dtype is
        converted to bool (non-zero means "segment end").

    Returns
    -------
    Z : torch.Tensor
        Output tensor of the same shape as `X`, where for each row, the segment-wise
        sum of `X` is stored at positions where `Y` is true, and all other positions
        are zero.

    Raises
    ------
    ValueError
        If `X` is not 2D/3D or if the shape of `Y` is not (A, B).

    Example
    -------
    >>> X = torch.tensor([[1, 2, 3, 4, 5]])
    >>> Y = torch.tensor([[0, 0, 1, 0, 1]], dtype=bool)
    >>> masked_cumsum(X, Y)
    tensor([[0, 0, 6, 0, 9]])

    Explanation:
        - Segment 1: [1, 2, 3] → sum = 6 → placed at index 2
        - Segment 2: [4, 5]   → sum = 9 → placed at index 4
    """
    squeeze_output = X.ndim == 2
    if squeeze_output:
        X = X.unsqueeze(-1)
    if X.ndim != 3:
        raise ValueError(
            f'X must have the shape (A, B) or (A, B, C), however: {tuple(X.shape)=}'
        )

    A, B, C = X.shape
    if tuple(Y.shape) != (A, B):
        raise ValueError(
            f'Y must have the shape {(A, B)}, however: {tuple(Y.shape)=}'
        )
    # Both torch.where-style selection and mask indexing below need a real bool mask.
    Y = Y.to(torch.bool)

    Z = torch.zeros_like(X)  # [A, B, C]

    if X.numel() > 0:
        # seg_id[a, j] = the number of segment ends strictly before position j,
        # i.e. the index of the segment that position j belongs to. A segment end
        # itself belongs to the segment it closes.
        ends_so_far = Y.cumsum(dim=1)  # [A, B], inclusive count
        seg_id = ends_so_far - Y.to(ends_so_far.dtype)  # [A, B], exclusive count

        # seg_id[a, j] <= j < B, so B slots per row are always enough. Using this
        # static bound (instead of seg_id.max().item()) avoids a device-to-host sync.
        row_offsets = torch.arange(A, device=X.device).view(A, 1) * B
        global_seg_id = seg_id + row_offsets  # [A, B]

        # Accumulate the segment sums.
        segment_sums = torch.zeros((A * B, C), dtype=X.dtype, device=X.device)
        segment_sums.index_add_(0, global_seg_id.reshape(A * B), X.reshape(A * B, C))

        # Write the sums only at the segment ends; at those positions seg_id is
        # exactly the index of the segment being closed.
        Z[Y] = segment_sums[global_seg_id[Y]]

    if squeeze_output:
        Z = Z[..., 0]

    return Z


class MaskedNLinear(nn.Module):
    """A variant of NLinear whose weight rows are merged with `masked_cumsum`.

    In `forward`, `dropout_mask` marks the dropped input positions of each of the
    N layers: the weight rows of dropped positions are summed into the next kept
    position (and zeroed) before the per-layer matrix multiplication.
    """

    def __init__(
        self, n: int, in_features: int, out_features: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(n, in_features, out_features))
        self.bias = Parameter(torch.empty(n, out_features)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-sample the weight and the bias from U(-1/sqrt(in_features), 1/sqrt(in_features))."""
        bound = self.weight.shape[-2] ** -0.5
        tensors = [self.weight] if self.bias is None else [self.weight, self.bias]
        for tensor in tensors:
            nn.init.uniform_(tensor, -bound, bound)

    def forward(self, x: torch.Tensor, dropout_mask: torch.Tensor) -> torch.Tensor:
        """Apply the N masked linear layers to `x` of the shape (B, N, in_features)."""
        if x.ndim != 3:
            raise ValueError(
                '_NLinear supports only inputs with exactly one batch dimension,'
                ' so `x` must have a shape like (BATCH_SIZE, N_FEATURES, D_EMBEDDING).'
            )
        assert x.shape[-(self.weight.ndim - 1) :] == self.weight.shape[:-1]

        # The kept positions (~dropout_mask) are the segment ends of masked_cumsum.
        merged_weight = masked_cumsum(self.weight, ~dropout_mask)

        # (B, N, in) -> (N, B, in) for the per-layer matmul, then back to (B, N, out).
        out = (x.transpose(0, 1) @ merged_weight).transpose(0, 1)
        if self.bias is None:
            return out
        return out + self.bias


class _MaskedPiecewiseLinearEncodingImpl(nn.Module):
    """Piecewise-linear encoding of numerical features with droppable bins.

    For every feature, `bins` provides the bin edges. The bin widths are stored
    as logits (trainable if `parameterized_bins` is true, a fixed buffer otherwise)
    and turned back into widths with a softmax scaled by the total feature range.
    At call time, `dropout_mask` (True = dropped bin) removes bins: the width of
    a dropped bin is merged into the next kept bin with `masked_cumsum`.

    Features may have different numbers of bins; all per-bin tensors are padded
    to `max_n_bins` columns and `mask` marks the real (non-padding) bins.
    """

    mask: Tensor
    bin_width_sum: Tensor
    bin_width_logit_offset: Tensor

    def __init__(self, bins: list[Tensor], parameterized_bins: bool) -> None:
        assert len(bins) > 0
        super().__init__()

        self.n_features = len(bins)
        # k + 1 edges define k bins.
        self.n_bins = [len(x) - 1 for x in bins]
        self.max_n_bins = max(self.n_bins)

        # (n_features,): True for features with exactly one bin.
        single_bin_mask = torch.tensor(self.n_bins) == 1
        self.register_buffer('single_bin_mask', single_bin_mask)
        # (n_features,): True for features with the maximum number of bins.
        max_bin_mask = torch.tensor(self.n_bins) == self.max_n_bins
        self.register_buffer('max_bin_mask', max_bin_mask)

        # (n_features, max_n_bins): True for the real bins of a feature,
        # False for the padding columns.
        self.register_buffer(
            'mask',
            torch.row_stack(
                [
                    torch.cat(
                        [
                            torch.ones(len(x) - 1, dtype=torch.bool),
                            torch.zeros(self.max_n_bins - (len(x) - 1), dtype=torch.bool),
                        ]
                    )
                    for x in bins
                ]
            ),
        )

        # (n_features, max_n_bins): True only in the first column.
        self.register_buffer(
            'first_bin_mask',
            torch.row_stack(
                [
                    torch.cat(
                        [
                            torch.ones(1, dtype=torch.bool),
                            torch.zeros(self.max_n_bins - 1, dtype=torch.bool),
                        ]
                    )
                    for x in bins
                ]
            ),
        )
        # (n_features, max_n_bins): True for the last real bin of a feature
        # AND for all padding columns after it.
        self.register_buffer(
            'last_bin_mask',
            torch.row_stack(
                [
                    torch.cat(
                        [
                            torch.zeros(len(x) - 2, dtype=torch.bool),
                            torch.ones(self.max_n_bins - (len(x) - 2), dtype=torch.bool),
                        ]
                    )
                    for x in bins
                ]
            ),
        )

        # Per feature: the left-most edge, the total range and the bin-width logits.
        bin_edges_0 = torch.zeros(self.n_features)
        bin_width_sum = torch.zeros(self.n_features)
        bin_width_logit = torch.zeros(self.n_features, self.max_n_bins)
        for i, bin_edges in enumerate(bins):
            bin_width = bin_edges.diff()
            bin_edges_0[i] = bin_edges[0]
            bin_width_sum[i] = bin_edges[-1] - bin_edges[0]
            # log(width / mean width): the softmax of these logits is
            # width / sum(widths), so the original widths are recovered by
            # multiplying the softmax with bin_width_sum (see get_bin_edges).
            bin_width_logit[i, :len(bin_width)] = (bin_width / bin_width.mean()).log()
        self.register_buffer('bin_edges_0', bin_edges_0)
        self.register_buffer('bin_width_sum', bin_width_sum)
        if parameterized_bins:
            # Trainable bin widths.
            self.register_parameter('bin_width_logit_offset', nn.Parameter(bin_width_logit))
        else:
            # Fixed bin widths.
            self.register_buffer('bin_width_logit_offset', bin_width_logit)


    def get_bin_edges(self, dropout_mask: Tensor) -> tuple[Tensor, Tensor]:
        """Return the bin edges and the bin widths after merging the dropped bins.

        `dropout_mask` has the shape (n_features, max_n_bins); True marks a dropped bin.

        Returns `(bin_edges_with_0, bin_width)` of the shapes
        (n_features, max_n_bins + 1) and (n_features, max_n_bins).
        """
        min_value = torch.finfo(self.bin_width_logit_offset.dtype).min
        tiny_value = torch.finfo(self.bin_width_logit_offset.dtype).tiny

        # Logits -> widths: the padding columns are excluded from the softmax
        # and then receive a tiny positive width (a zero width would produce
        # a division by zero in get_weight_bias).
        bin_width = self.bin_width_logit_offset.masked_fill(~self.mask, min_value).softmax(-1)
        bin_width = bin_width * self.bin_width_sum.unsqueeze(-1)
        bin_width = bin_width.masked_fill(~self.mask, tiny_value)
        # Merge every run of dropped bins into the next kept bin: the kept bin
        # gets the summed width and the dropped positions become zero, ...
        bin_width = masked_cumsum(bin_width, ~dropout_mask)
        # ... which is then replaced with a tiny positive width.
        bin_width = bin_width.masked_fill(dropout_mask, tiny_value)

        # Edges = the left-most edge + the running sum of the widths.
        bin_edges = bin_width.cumsum(-1)
        bin_edges_with_0 = self.bin_edges_0.view(-1, 1).repeat(1, self.max_n_bins + 1)
        bin_edges_with_0[:, 1:] += bin_edges

        return bin_edges_with_0, bin_width

    def get_weight_bias(self, dropout_mask: Tensor) -> tuple[Tensor, Tensor]:
        """Return per-bin `weight` and `bias` such that `weight * x + bias`
        equals `(x - left_edge) / width`.

        Both are zero for the padding columns and for the dropped bins.
        """
        bin_edges, bin_width = self.get_bin_edges(dropout_mask)
        w = 1.0 / bin_width
        b = -bin_edges[:, :-1] / bin_width
        weight = torch.where(self.mask & ~dropout_mask, w, torch.zeros_like(w))
        bias = torch.where(self.mask & ~dropout_mask, b, torch.zeros_like(b))
        return weight, bias

    def get_max_n_bins(self) -> int:
        """Return the largest number of bins among all features."""
        return self.max_n_bins

    def forward(self, x: Tensor, dropout_mask: Tensor) -> Tensor:
        """Encode `x`; the result has one extra trailing dimension of the size max_n_bins.

        The only caller in this file passes `x` of the shape (batch, n_features).
        """
        weight, bias = self.get_weight_bias(dropout_mask)
        # bias + weight * x, broadcast over the bins.
        x = torch.addcmul(bias, weight, x[..., None])
        if x.shape[-1] > 1:
            x_clamp_both = x.clamp(min=0, max=1)
            x_clamp_min = x.clamp(min=0)
            x_clamp_max = x.clamp(max=1)
            # Columns under last_bin_mask are clamped only from below, the first
            # column only from above (the first-column rule is applied last, so
            # it wins where both masks are true), all other columns to [0, 1].
            xs = torch.where(self.last_bin_mask, x_clamp_min, x_clamp_both)
            xs = torch.where(self.first_bin_mask, x_clamp_max, xs)

            # single_bin_mask is always registered as a tensor in __init__,
            # so the `else` branch is the one that runs.
            if self.single_bin_mask is None:
                x = xs
            else:
                # Features with a single bin keep the unclamped values.
                x = torch.where(
                    self.single_bin_mask[..., None],
                    x,
                    xs,
                )
        return x
        

class GGPLEmbeddings(nn.Module):
    """Embeddings for numerical features: linear embeddings plus a masked piecewise-linear part.

    The output is the sum of two terms:

    - plain linear embeddings of the raw feature values;
    - the piecewise-linear encoding of the features projected to `d_embedding`
      by a zero-initialized `MaskedNLinear` (so at initialization the module
      outputs the linear embeddings only).

    During training, every bin except the last one of each feature is dropped
    with the probability `dropout_ratio`; in evaluation mode no bins are dropped.
    """

    def __init__(
        self,
        bins: list[Tensor],
        d_embedding: int,
        dropout_ratio: float = 0.0,
        parameterized_bins: bool = True,
    ) -> None:
        if d_embedding <= 0:
            raise ValueError(
                f'd_embedding must be a positive integer, however: {d_embedding=}'
            )
        rtdl_num_embeddings._check_bins(bins)
        super().__init__()

        self.linear0 = rtdl_num_embeddings.LinearEmbeddings(len(bins), d_embedding)
        self.impl = _MaskedPiecewiseLinearEncodingImpl(bins, parameterized_bins)
        self.linear = MaskedNLinear(
            len(bins),
            self.impl.get_max_n_bins(),
            d_embedding,
            bias=False,
        )
        # The piecewise-linear term starts as an exact zero.
        nn.init.zeros_(self.linear.weight)
        self.dropout_ratio = dropout_ratio

    def forward(self, x: Tensor) -> Tensor:
        """Embed `x` of the shape (batch, n_features)."""
        if x.ndim != 2:
            raise ValueError(
                'For now, only inputs with exactly one batch dimension are supported.'
            )

        # One mask entry per (feature, bin); True means "this bin is dropped".
        mask_shape = (x.shape[1], self.impl.get_max_n_bins())
        if self.training:
            dropped_bins = torch.rand(*mask_shape, device=x.device) < self.dropout_ratio
            # Never drop the last bin of a feature (nor the padding columns).
            dropped_bins = dropped_bins & ~self.impl.last_bin_mask
        else:
            dropped_bins = torch.zeros(*mask_shape, device=x.device, dtype=torch.bool)

        linear_term = self.linear0(x)
        piecewise_term = self.linear(self.impl(x, dropped_bins), dropped_bins)

        return linear_term + piecewise_term

class CategoricalEmbeddings1d(nn.Module):
    """One trainable embedding table per categorical feature.

    - Input:  ``(*, n_cat_features=len(cardinalities))``
    - Output: ``(*, n_cat_features, d_embedding)``
    """

    def __init__(self, cardinalities: list[int], d_embedding: int) -> None:
        super().__init__()
        tables = []
        for cardinality in cardinalities:
            # NOTE: `+ 1` reserves a row for unknown values, which are expected
            # to be encoded as `max-known-category + 1`.
            # This is not a good way to handle unknown values. This is just a quick
            # hack to stop failing on some datasets.
            tables.append(nn.Embedding(cardinality + 1, d_embedding))
        self.embeddings = nn.ModuleList(tables)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-sample all tables from U(-1/sqrt(d_embedding), 1/sqrt(d_embedding))."""
        for table in self.embeddings:
            _init_rsqrt_uniform_(table.weight, -1)  # type: ignore[code]

    def forward(self, x: Tensor) -> Tensor:
        assert x.ndim >= 1
        per_feature = []
        for i, table in enumerate(self.embeddings):
            per_feature.append(table(x[..., i]))
        return torch.stack(per_feature, dim=-2)

class CLSEmbedding(nn.Module):
    """A single trainable embedding vector, expanded to the requested batch dimensions."""

    def __init__(self, d_embedding: int) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(d_embedding))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-sample the embedding from U(-1/sqrt(d_embedding), 1/sqrt(d_embedding))."""
        bound = self.weight.shape[-1] ** -0.5
        nn.init.uniform_(self.weight, a=-bound, b=bound)

    def forward(self, batch_dims: tuple[int]) -> Tensor:
        """Return the embedding as a ``(*batch_dims, 1, d_embedding)`` view (no copy)."""
        if batch_dims:
            return self.weight.expand(*batch_dims, 1, -1)
        raise ValueError('The input must be non-empty')


        
# The registry of non-`torch.nn` modules that `make_module` can build by name.
# The key of each entry is the class name, see:
# https://docs.python.org/3/library/stdtypes.html#definition.__name__
_CUSTOM_MODULES = {
    module_cls.__name__: module_cls
    for module_cls in (
        rtdl_num_embeddings.LinearEmbeddings,
        rtdl_num_embeddings.LinearReLUEmbeddings,
        rtdl_num_embeddings.PeriodicEmbeddings,
        PiecewiseLinearEmbeddings,
        MLP,
        GGPLEmbeddings,
        CategoricalEmbeddings1d,
        CLSEmbedding,
    )
}


def make_module(type: str, *args, **kwargs) -> nn.Module:
    """Build a module by its class name.

    `torch.nn` is searched first; names that are missing there are looked up in
    `_CUSTOM_MODULES` (an unknown name raises `KeyError`). The remaining
    arguments are passed to the constructor.
    """
    module_cls = getattr(nn, type, None)
    return (_CUSTOM_MODULES[type] if module_cls is None else module_cls)(
        *args, **kwargs
    )


def get_n_parameters(m: nn.Module):
    """Count the scalar entries of all trainable (`requires_grad`) parameters of `m`."""
    n_parameters = 0
    for parameter in m.parameters():
        if parameter.requires_grad:
            n_parameters += parameter.numel()
    return n_parameters

def get_d_out(n_classes: None | int) -> int:
    return 1 if n_classes is None or n_classes == 2 else n_classes

@torch.inference_mode()
def compute_parameter_stats(module: nn.Module) -> dict[str, dict[str, float]]:
    """Collect norms of the parameters of `module` and of their gradients.

    Returns a dict with three sections, each mapping parameter names to floats:

    - ``'norm'``: the norm of every parameter;
    - ``'gradnorm'``: the norm of the gradient of every parameter that has one;
    - ``'gradratio'``: the mean of ``|grad| / max(|param|, 1e-6)`` for every
      parameter that has a gradient and at least one entry with ``|x| > 1e-6``.

    In addition, ``'norm'`` and ``'gradnorm'`` contain the key ``'model'`` with
    the norm over all parameters (all available gradients) taken together.
    It is ``0.0`` when there is nothing to aggregate, e.g. before the first
    backward pass or for a module without parameters.
    """

    def global_norm(tensors: list[Tensor]) -> float:
        # torch.cat rejects an empty list; the norm of "nothing" is zero.
        if not tensors:
            return 0.0
        return torch.cat([x.flatten() for x in tensors]).norm().item()

    stats: dict[str, dict[str, float]] = {'norm': {}, 'gradnorm': {}, 'gradratio': {}}
    for name, parameter in module.named_parameters():
        stats['norm'][name] = parameter.norm().item()
        if parameter.grad is None:
            continue
        stats['gradnorm'][name] = parameter.grad.norm().item()
        # Avoid computing statistics for zero-initialized parameters.
        if (parameter.abs() > 1e-6).any():
            stats['gradratio'][name] = (
                (parameter.grad.abs() / parameter.abs().clamp_min_(1e-6))
                .mean()
                .item()
            )

    stats['norm']['model'] = global_norm(list(module.parameters()))
    stats['gradnorm']['model'] = global_norm(
        [x.grad for x in module.parameters() if x.grad is not None]
    )
    return stats


# ======================================================================================
# Optimization
# ======================================================================================
def default_zero_weight_decay_condition(
    module_name: str, module: nn.Module, parameter_name: str, parameter: Parameter
):
    """Tell whether a parameter should be excluded from weight decay.

    This is the case for parameters whose name ends with 'bias' and for all
    parameters of normalization layers and of the listed numerical-embedding
    modules. `module_name` and `parameter` are accepted for the sake of
    the callback signature and are not used.
    """
    from rtdl_num_embeddings import _Periodic

    del module_name, parameter
    if parameter_name.endswith('bias'):
        return True
    no_weight_decay_modules = (
        nn.BatchNorm1d,
        nn.LayerNorm,
        nn.InstanceNorm1d,
        rtdl_revisiting_models.LinearEmbeddings,
        rtdl_num_embeddings.LinearEmbeddings,
        rtdl_num_embeddings.LinearReLUEmbeddings,
        _Periodic,
    )
    return isinstance(module, no_weight_decay_modules)


def make_parameter_groups(
    module: nn.Module,
    zero_weight_decay_condition=default_zero_weight_decay_condition,
    custom_groups: None | list[dict[str, Any]] = None,
) -> list[dict[str, Any]]:
    """Split the parameters of `module` into optimizer parameter groups.

    Args:
        module: the module whose parameters are grouped.
        zero_weight_decay_condition: a callable
            ``(module_name, module, parameter_name, parameter) -> bool``.
            It is evaluated for every (module, parameter) pair produced by
            ``module.named_modules()`` x ``m.named_parameters()``; a parameter
            gets ``weight_decay=0.0`` if the condition holds for at least one pair.
        custom_groups: ready-made parameter groups that are passed through as is.
            Their parameters are excluded from the two automatic groups and must
            not intersect across the groups.

    Returns:
        ``[default_group, zero_weight_decay_group, *custom_groups]``.
        Within the two automatic groups, the parameters follow the order of
        ``module.parameters()``, so the result is the same on every run.
    """
    if custom_groups is None:
        custom_groups = []
    custom_params = frozenset(
        itertools.chain.from_iterable(group['params'] for group in custom_groups)
    )
    assert len(custom_params) == sum(
        len(group['params']) for group in custom_groups
    ), 'Parameters in custom_groups must not intersect'
    zero_wd_params = frozenset(
        p
        for mn, m in module.named_modules()
        for pn, p in m.named_parameters()
        if p not in custom_params and zero_weight_decay_condition(mn, m, pn, p)
    )

    # NOTE: the groups are filled by walking module.parameters() instead of
    # iterating over the frozenset: the iteration order of a set of tensors
    # depends on their ids (memory addresses) and thus changes between runs,
    # while optimizers identify the state of a parameter by its position in the group.
    default_params = []
    ordered_zero_wd_params = []
    for p in module.parameters():
        if p in custom_params:
            continue
        if p in zero_wd_params:
            ordered_zero_wd_params.append(p)
        else:
            default_params.append(p)

    return [
        {'params': default_params},
        {'params': ordered_zero_wd_params, 'weight_decay': 0.0},
        *custom_groups,
    ]

def make_optimizer(type: str, **kwargs) -> torch.optim.Optimizer:
    """Build the optimizer `torch.optim.<type>`; `kwargs` go to its constructor."""
    optimizer_cls = getattr(torch.optim, type)
    return optimizer_cls(**kwargs)

def get_loss_fn(task_type: TaskType, **kwargs) -> Callable[..., Tensor]:
    """Return the loss function for the given task type.

    - `TaskType.BINCLASS`: `F.binary_cross_entropy_with_logits`
    - `TaskType.MULTICLASS`: `F.cross_entropy`
    - any other task type: `F.mse_loss`

    If `kwargs` are provided, they are bound to the returned function
    with `functools.partial`.
    """
    # NOTE: `partial` is not imported at the top of this file, hence the local import.
    from functools import partial

    if task_type == TaskType.BINCLASS:
        loss_fn = F.binary_cross_entropy_with_logits
    elif task_type == TaskType.MULTICLASS:
        loss_fn = F.cross_entropy
    else:
        loss_fn = F.mse_loss
    return partial(loss_fn, **kwargs) if kwargs else loss_fn  # type: ignore[return-value,arg-type]

def zero_grad_forward_backward(
    optimizer: torch.optim.Optimizer,
    step_fn: Callable[[Tensor], Tensor],  # step_fn: chunk_idx -> loss
    batch_idx: Tensor,
    chunk_size: int,
) -> tuple[Tensor, int]:
    """Zero the gradients and run forward + backward, halving the chunk size on OOM.

    If the batch fits into one chunk, it is processed in a single step. Otherwise,
    it is processed chunk by chunk with the chunk losses weighted by the chunk
    sizes, so the accumulated gradients match the gradients of the full batch.
    On an out-of-memory error, the memory is freed, `chunk_size` is halved and
    the whole attempt (including `zero_grad`) is repeated.

    Returns the loss and the chunk size that worked. In the chunked mode,
    the returned loss is detached.

    Raises `RuntimeError` if `chunk_size` reaches zero; non-OOM runtime errors
    are re-raised as is.
    """
    n_samples = len(batch_idx)
    total_loss = None
    while chunk_size != 0:
        optimizer.zero_grad()

        try:
            if n_samples > chunk_size:
                # Forward-backward by chunks.
                # Mathematically, this is equivalent to the simple forward-backward.
                # Technically, this implementation uses less memory.
                total_loss = None
                for chunk_idx in batch_idx.split(chunk_size):
                    weighted_loss = step_fn(chunk_idx) * (len(chunk_idx) / n_samples)
                    weighted_loss.backward()
                    contribution = weighted_loss.detach()
                    total_loss = (
                        contribution
                        if total_loss is None
                        else total_loss + contribution
                    )
            else:
                # The simple forward-backward.
                total_loss = step_fn(batch_idx)
                total_loss.backward()
        except RuntimeError as err:
            if not is_oom_exception(err):
                raise
            delu.cuda.free_memory()
            chunk_size //= 2

        else:
            break

    if not chunk_size:
        raise RuntimeError('Not enough memory even for chunk_size=1')
    return cast(Tensor, total_loss), chunk_size

def extract_feature_threshold_gains(model: xgb.XGBModel) -> dict:
    """Sum the split gains of a fitted XGBoost model per (feature, threshold).

    The trees of the booster are dumped with `trees_to_dataframe()`; every row
    whose 'Feature' is not 'Leaf' is treated as a split, and its 'Gain' is added
    to the entry of its ('Feature', 'Split') pair.

    Returns:
        ``{feature_name: {threshold: total_gain}}`` as plain dicts.
    """
    df = model.get_booster().trees_to_dataframe()
    splits = df[df['Feature'] != 'Leaf']

    feature_threshold_gain = defaultdict(lambda: defaultdict(float))
    # Iterate over the three relevant columns directly: `DataFrame.iterrows`
    # would build a full Series for every row, which is much slower on big models.
    for feature, threshold, gain in zip(
        splits['Feature'], splits['Split'], splits['Gain']
    ):
        feature_threshold_gain[feature][threshold] += gain

    return {f: dict(thresh_gain) for f, thresh_gain in feature_threshold_gain.items()}

def select_top_thresholds_global(threshold_gain_dict: dict, n_feats: int, n_bins: int) -> list:
    """Keep the `n_feats * n_bins` thresholds with the highest gains across all features.

    Args:
        threshold_gain_dict: ``{feature_name: {threshold: gain}}``.
        n_feats: the number of features; they are expected to be named
            ``num_0``, ``num_1``, ..., ``num_{n_feats - 1}``.
        n_bins: the average budget of thresholds per feature.

    Returns:
        A list of length `n_feats`; the i-th item is the sorted list of the
        selected thresholds of the feature ``num_{i}`` (possibly empty).
    """
    candidates = [
        (feat, thresh, gain)
        for feat, thresh_gain in threshold_gain_dict.items()
        for thresh, gain in thresh_gain.items()
    ]
    # Highest gain first; the sort is stable, so ties keep the insertion order.
    candidates.sort(key=lambda item: -item[2])

    budget = n_feats * n_bins
    selected_by_feature = defaultdict(list)
    for feat, thresh, _gain in candidates[:budget]:
        selected_by_feature[feat].append(thresh)

    return [
        sorted(selected_by_feature.get(f"num_{i}", []))
        for i in range(n_feats)
    ]

# modified from https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py
def compute_bins(
    X: torch.Tensor,
    X_cat: Optional[torch.Tensor] = None,
    n_bins: int = 48,
    *,
    tree_kwargs: Optional[dict[str, Any]] = None,
    y: Optional[Tensor] = None,
    regression: Optional[bool] = None,
    verbose: bool = False,
) -> list[Tensor]:
    """Compute the bin boundaries for `PiecewiseLinearEncoding` and `PiecewiseLinearEmbeddings`.

    **Usage**

    Compute bins using quantiles (Section 3.2.1 in the paper):

    >>> X_train = torch.randn(10000, 2)
    >>> bins = compute_bins(X_train)

    Compute bins using decision trees (Section 3.2.2 in the paper):

    >>> X_train = torch.randn(10000, 2)
    >>> y_train = torch.randn(len(X_train))
    >>> bins = compute_bins(
    ...     X_train,
    ...     y=y_train,
    ...     regression=True,
    ...     tree_kwargs={'min_samples_leaf': 64, 'min_impurity_decrease': 1e-4},
    ... )

    Args:
        X: the training features (a finite 2D tensor with at least two rows and
            at least two distinct values in every column).
        X_cat: optional categorical training features. They are used only when
            ``tree_kwargs['type'] == 'xgb_global'``, where they are appended to
            the XGBoost training frame as pandas ``category`` columns.
        n_bins: the number of bins (``1 < n_bins < len(X)``).
        tree_kwargs: keyword arguments for `sklearn.tree.DecisionTreeRegressor`
            (if ``regression=True``) or `sklearn.tree.DecisionTreeClassifier`
            (if ``regression=False``).
            The optional key ``'type'`` selects the tree-based strategy and is
            not forwarded to scikit-learn:

            - ``'default'`` (the default): one scikit-learn tree per feature;
            - ``'xgb_global'``: one XGBoost model over all features; the
              highest-gain split thresholds are selected globally and the
              remaining entries of ``tree_kwargs`` are not used.

            The passed dict is not modified.
            NOTE: requires ``scikit-learn>=1.0,<2`` to be installed.
        y: the training labels (must be provided if ``tree_kwargs`` is not None).
        regression: whether the labels are regression labels
            (must be provided if ``tree_kwargs`` is not None).
        verbose: if True and ``tree_kwargs`` is not None, then ``tqdm``
            (must be installed) will report the progress while fitting trees.

    Returns:
        A list of bin edges for all features. For one feature:

        - the maximum possible number of bin edges is ``n_bins + 1``
          (with ``'xgb_global'`` the threshold budget is shared between
          features, so this per-feature bound is not enforced there);
        - the minimum possible number of bin edges is ``1``.

    Raises:
        ValueError: if the inputs are inconsistent or malformed (see the checks
            below), or if ``tree_kwargs['type']`` is not a known strategy.
        RuntimeError: if tree-based bins are requested without scikit-learn.
        ImportError: if ``verbose=True`` and tqdm is unavailable.
    """  # noqa: E501

    if not isinstance(X, Tensor):
        raise ValueError(f'X must be a PyTorch tensor, however: {type(X)=}')
    if X.ndim != 2:
        raise ValueError(f'X must have exactly two dimensions, however: {X.ndim=}')
    if X.shape[0] < 2:
        raise ValueError(f'X must have at least two rows, however: {X.shape[0]=}')
    if X.shape[1] < 1:
        raise ValueError(f'X must have at least one column, however: {X.shape[1]=}')
    if not X.isfinite().all():
        raise ValueError('X must not contain nan/inf/-inf.')
    if (X == X[0]).all(dim=0).any():
        raise ValueError(
            'All columns of X must have at least two distinct values.'
            ' However, X contains columns with just one distinct value.'
        )
    if n_bins <= 1 or n_bins >= len(X):
        raise ValueError(
            'n_bins must be more than 1, but less than len(X), however:'
            f' {n_bins=}, {len(X)=}'
        )

    if tree_kwargs is None:
        if y is not None or regression is not None or verbose:
            raise ValueError(
                'If tree_kwargs is None, then y must be None, regression must be None'
                ' and verbose must be False'
            )

        _upper = 2**24  # 16_777_216
        if len(X) > _upper:
            warnings.warn(
                f'Computing quantile-based bins for more than {_upper} objects'
                ' may not be possible due to the limitation of PyTorch'
                ' (for details, see https://github.com/pytorch/pytorch/issues/64947;'
                ' if that issue is successfully resolved, this warning may be irrelevant).'  # noqa
                ' As a workaround, subsample the data, i.e. instead of'
                '\ncompute_bins(X, ...)'
                '\ndo'
                '\ncompute_bins(X[torch.randperm(len(X), device=X.device)[:16_777_216]], ...)'  # noqa
                '\nOn CUDA, the computation can still fail with OOM even after'
                ' subsampling. If this is the case, try passing features by groups:'
                '\nbins = sum('
                '\n    compute_bins(X[:, idx], ...)'
                '\n    for idx in torch.arange(X.shape[1], device=X.device).split(group_size),'  # noqa
                '\n    start=[]'
                '\n)'
                '\nAnother option is to perform the computation on CPU:'
                '\ncompute_bins(X.cpu(), ...)'
            )
        del _upper

        # NOTE[DIFF]
        # The code below is more correct than the original implementation,
        # because the original implementation contains an unintentional divergence
        # from what is written in the paper. That divergence affected only the
        # quantile-based embeddings, but not the tree-based embeddings.
        # For historical reference, here is the original, less correct, implementation:
        # https://github.com/yandex-research/tabular-dl-num-embeddings/blob/c1d9eb63c0685b51d7e1bc081cdce6ffdb8886a8/bin/train4.py#L612C30-L612C30
        # (explanation: limiting the number of quantiles by the number of distinct
        #  values is NOT the same as removing identical quantiles after computing them).
        bins = [
            q.unique()
            for q in torch.quantile(
                X, torch.linspace(0.0, 1.0, n_bins + 1).to(X), dim=0
            ).T
        ]
        _check_bins(bins)
        return bins

    else:
        if sklearn_tree is None:
            raise RuntimeError(
                'The scikit-learn package is missing.'
                ' See README.md for installation instructions'
            )
        if y is None or regression is None:
            raise ValueError(
                'If tree_kwargs is not None, then y and regression must not be None'
            )
        if y.ndim != 1:
            raise ValueError(f'y must have exactly one dimension, however: {y.ndim=}')
        if len(y) != len(X):
            raise ValueError(
                f'len(y) must be equal to len(X), however: {len(y)=}, {len(X)=}'
            )
        if 'max_leaf_nodes' in tree_kwargs:
            raise ValueError(
                'tree_kwargs must not contain the key "max_leaf_nodes"'
                ' (it will be set to n_bins automatically).'
            )

        if verbose:
            if tqdm is None:
                raise ImportError('If verbose is True, tqdm must be installed')
            tqdm_ = tqdm
        else:
            tqdm_ = lambda x: x  # noqa: E731

        if X.device.type != 'cpu' or y.device.type != 'cpu':
            warnings.warn(
                'Computing tree-based bins involves the conversion of the input PyTorch'
                ' tensors to NumPy arrays. The provided PyTorch tensors are not'
                ' located on CPU, so the conversion has some overhead.',
                UserWarning,
            )
        X_numpy = X.cpu().numpy()
        y_numpy = y.cpu().numpy()

        # Work on a shallow copy: popping 'type' from the caller's dict would make
        # a second call with the same dict silently fall back to 'default'.
        tree_kwargs = dict(tree_kwargs)
        tree_type = tree_kwargs.pop('type', 'default')
        if tree_type == 'default':
            bins = []
            for column in tqdm_(X_numpy.T):
                # The feature range is always part of the edges.
                feature_bin_edges = [float(column.min()), float(column.max())]
                tree = (
                    (
                        sklearn_tree.DecisionTreeRegressor
                        if regression
                        else sklearn_tree.DecisionTreeClassifier
                    )(max_leaf_nodes=n_bins, **tree_kwargs)
                    .fit(column.reshape(-1, 1), y_numpy)
                    .tree_
                )
                for node_id in range(tree.node_count):
                    # The following condition is True only for split nodes. Source:
                    # https://scikit-learn.org/1.0/auto_examples/tree/plot_unveil_tree_structure.html#tree-structure
                    if tree.children_left[node_id] != tree.children_right[node_id]:
                        feature_bin_edges.append(float(tree.threshold[node_id]))
                bins.append(torch.as_tensor(feature_bin_edges).unique())

        elif tree_type == 'xgb_global':
            # The fitted thresholds are cached on disk through the module-level
            # joblib ``memory`` object.
            @memory.cache
            def fit_xgboost(model_cls, X_df, y_numpy):
                n_jobs = 16 if X_df.shape[0] > 30000 and X_df.shape[1] > 500 else 1
                model = model_cls(
                    enable_categorical=True,  # Needed to make use of category dtype
                    tree_method="hist",       # Required for categorical support
                    n_jobs=n_jobs,
                )
                model.fit(X_df, y_numpy)
                threshold_gain_dict = extract_feature_threshold_gains(model)
                return threshold_gain_dict

            # Numerical columns are named "num_{i}"; the threshold selection
            # below looks features up by exactly these names.
            num_columns = [f"num_{i}" for i in range(X_numpy.shape[1])]
            df_num = pd.DataFrame(X_numpy, columns=num_columns)
            if X_cat is None:
                X_df = df_num
            else:
                X_cat_numpy = X_cat.cpu().numpy()
                cat_columns = [f"cat_{i}" for i in range(X_cat_numpy.shape[1])]
                df_cat = pd.DataFrame(X_cat_numpy, columns=cat_columns)
                for col in df_cat.columns:
                    df_cat[col] = df_cat[col].astype("category")
                X_df = pd.concat([df_num, df_cat], axis=1)
            model_cls = xgb.XGBRegressor if regression else xgb.XGBClassifier
            threshold_gain_dict = fit_xgboost(model_cls, X_df, y_numpy)
            # n_bins - 1 thresholds per feature on average; the two range
            # endpoints are added to every feature afterwards.
            bins = select_top_thresholds_global(threshold_gain_dict, X.shape[1], n_bins - 1)
            for i in range(X.shape[1]):
                bins[i].extend([X_numpy[:, i].min(), X_numpy[:, i].max()])
            bins = [torch.as_tensor(v).unique() for v in bins]

        else:
            # An explicit exception instead of ``assert``: assertions are removed
            # under ``python -O``, which would leave ``bins`` undefined.
            raise ValueError(
                f'Unknown tree type: {tree_type!r}.'
                ' Expected "default" or "xgb_global".'
            )

        _check_bins(bins)

        return [x.to(device=X.device, dtype=X.dtype) for x in bins]
