import pandas as pd
import torch
from typing import Callable

from .model.wgf_imp import NeuralGradFlowImputer
from ...base import *
from ...base.base import _IDENTITY


class KnewImp(
    BaseImputerMixIn,
    Base
):
    """
    Tabular imputer built around ``NeuralGradFlowImputer`` (KnewImp / NewImp).

    https://github.com/JustusvLiebig/NewImp
    Chen, Z., Li, H., Wang, F., Zhang, O., Xu, H., Jiang, X., ... & Wang, E. H. (2024).
    Rethinking the Diffusion Models for Numerical Tabular Data Imputation from the Perspective of
    Wasserstein Gradient Flow.
    CoRR.

    NOTE(review): ``column_dim``, ``numerical_dim``, ``device``, ``dtype``,
    ``tabular_transform``, ``_transform`` and ``_cfg`` are not defined in this
    class; they are assumed to be provided by ``Base`` / ``BaseImputerMixIn``.
    """

    def __init__(self, batch_size, args, kernel_batch_size=4096, **kwargs):
        """
        Build the wrapped ``NeuralGradFlowImputer`` and register the ``mean`` buffer.

        :param batch_size: forwarded to the model as ``batchsize``.
        :param args: mapping of additional keyword arguments unpacked into the model constructor.
        :param kernel_batch_size: forwarded to the model as ``kernel_batch_size``.
        :param kwargs: accepted for call-site compatibility; not used by this constructor.
        """
        super().__init__()
        self.model = NeuralGradFlowImputer(batchsize=batch_size, kernel_batch_size=kernel_batch_size, **args)
        # Per-column mean used to initialise missing entries; overwritten in ``fit``.
        self.register_buffer('mean', torch.zeros(1, self.column_dim))

    def normalize(self, df_or_tensor, *args, **kwargs) -> torch.Tensor:
        """
        Convert the input to a tensor on ``self.device`` with ``self.dtype``.

        A ``pd.DataFrame`` is first passed through ``self.tabular_transform.transform``;
        any other input is assumed to already be a tensor. No scaling is applied here.
        """
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)

        x = x.to(self.device, self.dtype)
        return x

    def denormalize(self, tensor: torch.Tensor, *args, **kwargs):
        """
        Split a tensor column-wise at ``self.numerical_dim``.

        :return: ``(num, cat)`` where ``num`` holds the first ``self.numerical_dim``
            columns and ``cat`` the remaining ones (presumably the categorical part),
            each reshaped to ``(rows, -1)``.
        """
        x = tensor.to(self.device, self.dtype)

        num = x[:, :self.numerical_dim]
        cat = x[:, self.numerical_dim:]
        # Column slices are not guaranteed to be viewable with the requested shape;
        # ``reshape`` returns a view when it can and copies otherwise.
        return num.reshape(len(num), -1), cat.reshape(len(cat), -1)

    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = _IDENTITY):  # type: ignore[override]
        """
        Fit the tabular transform and the gradient-flow model on the training CSV.

        The CSV at ``self._cfg.dataset.train_path`` is loaded, passed through
        ``scenario``, used to fit ``self._transform`` and then transformed to a
        tensor. The NaN-ignoring column mean is stored in the ``mean`` buffer and
        the model is fitted on the transformed data.

        :param scenario: callable applied to the training frame before fitting.
        :return: whatever ``super().fit(scenario)`` returns.
        """
        cfg = self._cfg
        train_df = pd.read_csv(cfg.dataset.train_path)
        train_df = scenario(train_df)
        self._transform.fit(train_df)
        data = self.tabular_transform.transform(train_df, return_as_tensor=True)

        # Keep the buffer on the device it is registered on; assigning ``.data``
        # directly would otherwise move it to wherever ``data`` happens to live.
        column_mean = data.nanmean(dim=0, keepdim=True)
        self.mean.data = column_mean.to(self.mean.device)

        # ``Tensor.numpy()`` only works for CPU tensors, while ``normalize`` moves
        # the data to ``self.device``; bring it back to host memory first.
        train_array = self.normalize(data).detach().cpu().numpy()
        self.model.fit_transform(train_array)
        return super().fit(scenario)

    def _impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Impute the NaN entries of ``df`` with the fitted model.

        Missing entries are initialised with ``mean + noise * N(0, 1)`` and then
        refined by ``self.model.knew_imp_sampling_minibatch``; observed entries are
        flagged through ``mask_matrix``. The result is split with ``denormalize``
        and mapped back through ``self.tabular_transform.inverse_transform``.
        """
        tokens = self.tabular_transform.transform(df, return_as_tensor=True).to(self.device, self.dtype)
        # Capture the mask before the tensor is filled in below: ``normalize`` may
        # return the very same tensor object when device/dtype already match.
        nan_mask = tokens.isnan()
        X = self.normalize(tokens)

        # Align the buffer with X so the sum and the masked assignment cannot fail
        # on a device or dtype mismatch.
        mean = self.mean.to(device=X.device, dtype=X.dtype)
        init = self.model.noise * torch.randn(nan_mask.shape, device=X.device, dtype=X.dtype) + mean
        X[nan_mask] = init[nan_mask]

        imputed = self.model.knew_imp_sampling_minibatch(X, score_func=self.model.score_net.functorch_score,
                                                         iter_steps=self.model.sampling_step,
                                                         mask_matrix=~nan_mask)
        num, cat = self.denormalize(imputed)
        return self.tabular_transform.inverse_transform(num, cat)
