# License: https://github.com/yandex-research/tabm/blob/main/LICENSE

import itertools
from typing import Any, Literal, Optional

import torch
import torch.nn as nn
from torch import Tensor
import numpy as np

from model_utils.TabM import *

# def swish(x):
#     return x * torch.sigmoid(x)

# def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
#     """
#     Embed a diffusion step $t$ into a higher dimensional space
#     E.g. the embedding vector in the 128-dimensional space is
#     [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]

#     Parameters:
#     diffusion_steps (torch.long tensor, shape=(batchsize, 1)):     
#                                 diffusion steps for batch data
#     diffusion_step_embed_dim_in (int, default=128):  
#                                 dimensionality of the embedding space for discrete diffusion steps
    
#     Returns:
#     the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
#     """

#     assert diffusion_step_embed_dim_in % 2 == 0

#     half_dim = diffusion_step_embed_dim_in // 2
#     _embed = np.log(10000) / (half_dim - 1)
#     _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device)
#     _embed = diffusion_steps * _embed
#     diffusion_step_embed = torch.cat((torch.sin(_embed),
#                                       torch.cos(_embed)), 1)

#     return diffusion_step_embed



# ======================================================================================
# The model
# ======================================================================================
class TabM(nn.Module):
    """MLP & TabM.

    An MLP backbone applied to the concatenation of (optionally embedded)
    numerical features and one-hot encoded categorical features. Except for
    ``arch_type='plain'``, the backbone is turned into a parameter-efficient
    ensemble of ``k`` members (TabM variants). ``forward`` returns the mean
    prediction over the ensemble members.
    """

    def __init__(
        self,
        *,
        config,
        device,
        n_num_features: int,
        cat_cardinalities: list[int],
        n_output: Optional[int],
        backbone: dict,
        bins,  # For piecewise-linear encoding/embeddings.
        num_embeddings,
        arch_type: Literal[
            # Plain feed-forward network without any kind of ensembling.
            'plain',
            #
            # TabM
            'tabm',
            #
            # TabM-mini
            'tabm-mini',
            #
            # TabM-packed
            'tabm-packed',
            #
            # TabM. The first adapter is initialized from the normal distribution.
            # This variant was not used in the paper, but it may be useful in practice.
            'tabm-normal',
            #
            # TabM-mini. The adapter is initialized from the normal distribution.
            # This variant was not used in the paper.
            'tabm-mini-normal',
        ],
        k,
        share_training_batches: bool = True,
        diffusion_step_embed_dim_in=1024,
        diffusion_step_embed_dim_mid=2048,
    ) -> None:
        """Build the model.

        Args:
            config: Mapping from which only ``config["model_type"]`` is read;
                the value ``"CDTD"`` sets ``self.add_noise`` to False.
            device: Stored as ``self.device``; not otherwise used by this class.
            n_num_features: Number of numerical features (must be >= 0).
            cat_cardinalities: Cardinality of each categorical feature; their
                sum is the width of the one-hot block in the backbone input.
            n_output: Output size of the final linear layer(s).
            backbone: Keyword arguments for ``make_module`` (must contain
                ``'d_block'``); ``d_in`` is supplied here.
            bins: Passed to the numerical embedding module. Must be None unless
                ``num_embeddings['type']`` starts with
                ``'PiecewiseLinearEmbeddings'``.
            num_embeddings: Keyword arguments for ``make_module`` describing the
                numerical-feature embedding (must contain ``'d_embedding'``), or
                None to feed raw numerical features.
            arch_type: Architecture variant (see the ``Literal`` annotation).
            k: Number of ensemble members; must be None for ``'plain'`` and a
                positive integer otherwise.
            share_training_batches: If True, all members see the same batch
                during training. If False, ``forward`` expects ``B * k`` rows in
                training mode and gives each member its own slice.
            diffusion_step_embed_dim_in: Stored as an attribute. The timestep
                embedding layers that would use it are currently disabled.
            diffusion_step_embed_dim_mid: Currently unused (see above).
        """
        # >>> Validate arguments.
        assert n_num_features >= 0
        assert n_num_features or cat_cardinalities
        if arch_type == 'plain':
            assert k is None
            assert (
                share_training_batches
            ), 'If `arch_type` is set to "plain", then `share_training_batches` must remain True'
        else:
            assert k is not None
            assert k > 0

        super().__init__()
        self.device = device
        # False only for CDTD. NOTE(review): `forward` does not currently read
        # this flag because the timestep-embedding addition is disabled.
        self.add_noise = config["model_type"] != "CDTD"
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # >>> Continuous (numerical) features
        first_adapter_sections = []  # See the comment in `_init_first_adapter`.

        if n_num_features == 0:
            assert bins is None
            self.num_module = None
            d_num = 0

        elif num_embeddings is None:
            assert bins is None
            self.num_module = None
            d_num = n_num_features
            first_adapter_sections.extend(1 for _ in range(n_num_features))

        else:
            if bins is None:
                self.num_module = make_module(
                    **num_embeddings, n_features=n_num_features
                )
            else:
                assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
                self.num_module = make_module(**num_embeddings, bins=bins)
            d_num = n_num_features * num_embeddings['d_embedding']
            first_adapter_sections.extend(
                num_embeddings['d_embedding'] for _ in range(n_num_features)
            )

        # >>> Categorical features
        self.cat_module = (
            OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
        )
        first_adapter_sections.extend(cat_cardinalities)
        d_cat = sum(cat_cardinalities)

        # >>> Backbone
        d_flat = d_num + d_cat
        self.d_backbone_in = d_flat
        self.minimal_ensemble_adapter = None
        # Any backbone can be here but we provide only MLP
        self.backbone = make_module(d_in=d_flat, **backbone)

        if arch_type != 'plain':
            assert k is not None
            first_adapter_init = (
                None
                if arch_type == 'tabm-packed'
                else 'normal'
                if arch_type in ('tabm-mini-normal', 'tabm-normal')
                # For other arch_types, the initialization depends
                # on the presence of num_embeddings.
                else 'random-signs'
                if num_embeddings is None
                else 'normal'
            )

            if arch_type in ('tabm', 'tabm-normal'):
                # Like BatchEnsemble, but all multiplicative adapters,
                # except for the very first one, are initialized with ones.
                assert first_adapter_init is not None
                make_efficient_ensemble(
                    self.backbone,
                    LinearEfficientEnsemble,
                    k=k,
                    ensemble_scaling_in=True,
                    ensemble_scaling_out=True,
                    ensemble_bias=True,
                    scaling_init='ones',
                )
                _init_first_adapter(
                    _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
                    first_adapter_init,
                    first_adapter_sections,
                )

            elif arch_type in ('tabm-mini', 'tabm-mini-normal'):
                # MiniEnsemble
                assert first_adapter_init is not None
                self.minimal_ensemble_adapter = ScaleEnsemble(
                    k,
                    d_flat,
                    init='random-signs' if num_embeddings is None else 'normal',
                )
                _init_first_adapter(
                    self.minimal_ensemble_adapter.weight,  # type: ignore[code]
                    first_adapter_init,
                    first_adapter_sections,
                )

            elif arch_type == 'tabm-packed':
                # Packed ensemble.
                # In terms of the Packed Ensembles paper by Laurent et al.,
                # TabM-packed is PackedEnsemble(alpha=k, M=k, gamma=1).
                assert first_adapter_init is None
                make_efficient_ensemble(self.backbone, NLinear, n=k)

            else:
                raise ValueError(f'Unknown arch_type: {arch_type}')

        # >>> Output
        d_block = backbone['d_block']
        d_out = n_output
        self.output = (
            nn.Linear(d_block, d_out)
            if arch_type == 'plain'
            else NLinear(k, d_block, d_out)  # type: ignore[code]
        )

        # >>>
        self.arch_type = arch_type
        self.k = k
        self.share_training_batches = share_training_batches

        # NOTE: the timestep-embedding layers (previously `fc_t1` / `fc_t2`,
        # mapping diffusion_step_embed_dim_in -> _mid -> d_flat) are disabled.

    def forward(
        self,
        x_num: Optional[Tensor] = None,
        x_cat: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the network and average the ensemble members' predictions.

        Args:
            x_num: Numerical features. They are passed through
                ``self.num_module`` when one was built. May be None if
                ``x_cat`` is given.
            x_cat: Optional categorical features, encoded by
                ``self.cat_module``. When None (the default), only ``x_num`` is
                used, so existing callers that pass a single tensor behave as
                before.

        Returns:
            The mean over the ensemble dimension with a kept singleton axis,
            i.e. ``(B, 1, D_OUT)``. This assumes the output layer yields
            ``(B, K, D_OUT)``; plain networks get a singleton member axis first.

        Raises:
            ValueError: If neither input is given, or if ``x_cat`` is given but
                the model was built without categorical features.
        """
        parts = []
        if x_num is not None:
            parts.append(x_num if self.num_module is None else self.num_module(x_num))
        if x_cat is not None:
            if self.cat_module is None:
                raise ValueError(
                    'x_cat was given, but the model has no categorical features'
                )
            # Cast so the encoded block can be stacked with float features.
            parts.append(self.cat_module(x_cat).float())
        if not parts:
            raise ValueError('At least one of x_num and x_cat must be provided')
        x = torch.column_stack([part.flatten(1, -1) for part in parts])

        # NOTE: no timestep conditioning is added here (the embedding path is
        # disabled), so `self.add_noise` has no effect on the computation.

        if self.k is not None:
            if self.share_training_batches or not self.training:
                # (B, D) -> (B, K, D): every member sees the same rows.
                x = x[:, None].expand(-1, self.k, -1)
            else:
                # (B * K, D) -> (B, K, D): each member gets its own rows.
                x = x.reshape(len(x) // self.k, self.k, *x.shape[1:])
            if self.minimal_ensemble_adapter is not None:
                x = self.minimal_ensemble_adapter(x)
        else:
            assert self.minimal_ensemble_adapter is None

        x = self.backbone(x)
        x = self.output(x)

        if self.k is None:
            # Give plain networks the same layout as ensembles:
            # (B, D_OUT) -> (B, 1, D_OUT).
            x = x[:, None]

        # Average the members, keeping a singleton member dimension.
        return x.mean(dim=1, keepdim=True)
    


@torch.inference_mode()
def _init_first_adapter(
    weight: Tensor,
    distribution: Literal['normal', 'random-signs'],
    init_sections: list[int],
) -> None:
    """Initialize the first adapter.

    NOTE
    The `init_sections` argument is a historical artifact that accidentally leaked
    from irrelevant experiments to the final models. Perhaps, the code related
    to `init_sections` can be simply removed, but this was not tested.
    """
    assert weight.ndim == 2
    assert weight.shape[1] == sum(init_sections)

    if distribution == 'normal':
        init_fn_ = nn.init.normal_
    elif distribution == 'random-signs':
        init_fn_ = init_random_signs_
    else:
        raise ValueError(f'Unknown distribution: {distribution}')

    section_bounds = [0, *torch.tensor(init_sections).cumsum(0).tolist()]
    for i in range(len(init_sections)):
        # NOTE
        # As noted above, this section-based initialization is an arbitrary historical
        # artifact. Consider the first adapter of one ensemble member.
        # This adapter vector is implicitly split into "sections",
        # where one section corresponds to one feature. The code below ensures that
        # the adapter weights in one section are initialized with the same random value
        # from the given distribution.
        w = torch.empty((len(weight), 1), dtype=weight.dtype, device=weight.device)
        init_fn_(w)
        weight[:, section_bounds[i] : section_bounds[i + 1]] = w


def _get_first_ensemble_layer(backbone: MLP) -> LinearEfficientEnsemble:
    """Return the first layer of the first block of an ``MLP`` backbone.

    Raises:
        RuntimeError: If ``backbone`` is not an ``MLP``.
    """
    if not isinstance(backbone, MLP):
        raise RuntimeError(f'Unsupported backbone: {backbone}')
    first_block = backbone.blocks[0]
    return first_block[0]  # type: ignore[code]