import functools
import warnings
from typing import Union

import pandas as pd
import torch
from hydra.utils import get_method
from scipy import optimize

from impugen.transform import TabularTransform, TabularTransformSet

# Public API: the config-driven entry point plus the three DataFrame-wrapped missingness mechanisms.
__all__ = ['simulate_nan', 'missing_completely_at_random', 'missing_at_random', 'missing_not_at_random']


def simulate_nan(cfg, transform: TabularTransformSet, data: pd.DataFrame, drop_observed: bool = False) -> pd.DataFrame:
    """
    Build a copy of `data` with missing values injected by the mechanism named in
    `cfg.missing` (e.g. MCAR, MAR, or MNAR).

    Only the feature columns (`transform.no_target.columns`) are passed to the
    missingness function; the target column, if any, is never masked by it.

    Two modes are supported:
      * `drop_observed=False` (default): cells selected by the mechanism become NaN.
      * `drop_observed=True`: the selection is inverted — every feature cell that was
        NOT selected becomes NaN, and the whole target is set to NaN as well.

    Args:
        cfg: Configuration object. `cfg.missing._target_` is the dotted path of the
             missingness function (e.g. "path.to.missing_completely_at_random");
             all entries of `cfg.missing` are forwarded to it as keyword arguments.
        transform (TabularTransformSet): Provides the feature-only transform
            (`no_target`) and the name of the target column (`target_column`).
        data (pd.DataFrame): Source dataset. It is copied, not modified.
        drop_observed (bool, optional): Selects the inverted mode described above.

    Returns:
        pd.DataFrame: Feature columns followed by the target column (if any), with
        NaNs introduced as described.
    """
    features = data.copy()[transform.no_target.columns]

    target_name = transform.target_column
    if target_name is None:
        # No target configured: use a zero-column slice so the concat below adds nothing.
        target = data.copy().iloc[:, 0:0]
    else:
        target = data.copy()[target_name]

    # Resolve the missingness function from its dotted path and compute the cell mask.
    mechanism = get_method(cfg.missing._target_)
    mask = mechanism(features, transform=transform.no_target, **cfg.missing)

    if not drop_observed:
        features[mask] = pd.NA
    else:
        features[~mask] = pd.NA
        target[:] = pd.NA

    return pd.concat([features, target], axis=1)


def dataframe_only(func):
    """
    Decorator that lets a tensor-based missingness function be called with a pandas
    DataFrame instead of a PyTorch tensor.

    The returned wrapper:
      1. Records which cells of the DataFrame are already NaN.
      2. Converts the DataFrame to a CPU tensor via `transform.transform`
         (called with `scaler='none', onehot=False`) and replaces NaNs with 0.5.
      3. Calls `func` on that tensor to obtain a boolean mask.
      4. Writes NaN into the masked tensor cells, maps the tensor back to a DataFrame
         via `transform.inverse_transform`, and reports the cells whose NaN status
         differs from the original.

    The wrapper is created with `functools.wraps`, so the decorated function keeps
    the wrapped function's `__name__`, `__qualname__`, `__module__` and `__doc__`
    (instead of all decorated functions being called ``decorated``). Note that the
    call contract is the wrapper's: a DataFrame in, a NumPy mask out, and a
    `transform` keyword is required.

    Args:
        func (callable):
            The missingness function to wrap (e.g. missing_completely_at_random).
            It is called with keyword arguments `x`, `p`, `p_obs`, `p_params`,
            `random_state` plus any extra keyword arguments.

    Returns:
        callable: A function that accepts a DataFrame and returns a boolean mask of missing cells.
    """

    @functools.wraps(func)
    def decorated(
        x: pd.DataFrame,
        p: float,
        p_obs: float = None,
        p_params: float = None,
        random_state: int = 42,
        transform: TabularTransform = None,
        **kwargs
    ):
        """
        DataFrame front-end for `func` (at runtime `functools.wraps` replaces this
        docstring with the one of `func`).

        Args:
            x (pd.DataFrame): The input DataFrame.
            p (float): The missingness proportion or probability target.
            p_obs (float, optional): Additional parameter used by certain missingness patterns (MAR).
            p_params (float, optional): Additional parameter used by certain missingness patterns (MNAR).
            random_state (int): Random seed for reproducibility.
            transform (TabularTransform, optional): Transformer used to convert DataFrame to tensor.
                Despite the `None` default it is required.
            **kwargs: Additional arguments forwarded to the missingness function `func`.

        Returns:
            np.ndarray: Boolean array (from `DataFrame.isna().to_numpy()`), True where the
            NaN status of the round-tripped DataFrame differs from that of `x`.
        """
        assert isinstance(x, pd.DataFrame) and transform is not None, \
            "Input must be a DataFrame and a TabularTransform must be provided."

        # Track original NaNs in the data
        orig_nan_mask = x.isna().to_numpy()

        # Convert to tensor
        x_tensor = transform.transform(
            x, return_as_tensor=True, scaler='none', onehot=False
        ).cpu()
        # Ensure no NaNs break the missingness logic
        x_tensor = x_tensor.nan_to_num(0.5)

        # Generate mask using the core missingness function
        mask = func(x=x_tensor, p=p, p_obs=p_obs, p_params=p_params, random_state=random_state, **kwargs)

        # Apply the mask to x_tensor
        x_tensor[mask] = torch.nan

        # Convert back to DataFrame
        df = transform.inverse_transform(x_tensor, scaler='none', onehot=False)
        # The final mask marks cells whose NaN status changed relative to the input.
        # NOTE(review): `!=` also flags cells that were NaN in `x` but come back non-NaN
        # (they were filled with 0.5 above and not re-masked) — confirm this is intended.
        final_mask = df.isna().to_numpy() != orig_nan_mask
        return final_mask

    return decorated


@dataframe_only
def missing_completely_at_random(
    x: Union[torch.Tensor, pd.DataFrame],
    p: float,
    random_state: int = 42,
    **kwargs
) -> torch.Tensor:
    """
    Missing Completely At Random (MCAR) mask: every cell is selected independently
    with probability `p`, regardless of the data values.

    This is the tensor-level core; the `@dataframe_only` decorator wraps it so the
    public callable is invoked with a DataFrame and a `transform`.

    Args:
        x (torch.Tensor): Input tensor; only its shape and device are used.
        p (float): Per-cell probability of being marked missing.
        random_state (int, optional): Seed for the random generator.
        **kwargs: Ignored.

    Returns:
        torch.Tensor: Boolean tensor shaped like `x`; True marks a missing cell.
    """
    rng = torch.Generator(device=x.device)
    rng.manual_seed(random_state)

    # One uniform draw per cell; a cell is missing when its draw falls below p.
    uniform_draws = torch.rand(x.shape, device=x.device, dtype=torch.float32, generator=rng)
    return uniform_draws < p


@dataframe_only
def missing_at_random(
    x: torch.Tensor,
    p: float,
    p_obs: float,
    random_state: int = 42,
    **kwargs
) -> torch.Tensor:
    """
    Missing At Random (MAR) mask driven by a logistic model.

    The columns are randomly split into a "predictor" group, which is never masked,
    and a "masked" group. For every masked column a logistic model over the predictor
    columns gives each row its probability of being missing; the intercepts are fitted
    so that the mean probability is `p`.

    This is the tensor-level core; the `@dataframe_only` decorator wraps it so the
    public callable is invoked with a DataFrame and a `transform`.

    Args:
        x (torch.Tensor): A 2D tensor of shape (N, D).
        p (float): Target mean missing probability in the masked columns.
        p_obs (float): Fraction of columns kept as predictors; at least one column
            is always used.
        random_state (int, optional): Seed for the random generator.
        **kwargs: Ignored.

    Returns:
        torch.Tensor: Boolean mask of shape (N, D). True indicates a missing cell.
    """
    rng = torch.Generator(device=x.device)
    rng.manual_seed(random_state)
    n_rows, n_cols = x.shape

    # Random split of the columns: the first chunk predicts, the rest gets masked.
    shuffled = torch.randperm(n_cols, generator=rng)
    n_predictors = max(int(p_obs * n_cols), 1)
    n_masked = n_cols - n_predictors
    predictor_cols = shuffled[:n_predictors]
    masked_cols = shuffled[n_predictors:]

    # Logistic model: random weights, intercepts tuned to hit the target rate.
    predictors = x[:, predictor_cols]
    weights = _pick_coeffs(x, predictor_cols, masked_cols, generator=rng)
    biases = _fit_intercepts(predictors, weights, p)
    missing_prob = torch.sigmoid(predictors.mm(weights) + biases)

    # Bernoulli draw per (row, masked column).
    uniform_draws = torch.rand(n_rows, n_masked, generator=rng)

    mask = torch.zeros(x.shape, device=x.device, dtype=torch.bool)
    mask[:, masked_cols] = uniform_draws < missing_prob
    return mask


@dataframe_only
def missing_not_at_random(
    x: torch.Tensor,
    p: float,
    p_params: float,
    exclude_inputs: bool = True,
    random_state: int = 42,
    **kwargs
) -> torch.Tensor:
    """
    Missing Not At Random (MNAR) mask driven by a logistic model.

    With `exclude_inputs=True`:
      * a random subset of columns (the "inputs") feeds the logistic model,
      * the remaining columns are masked according to that model,
      * the input columns themselves are additionally masked independently at rate `p`.

    With `exclude_inputs=False`, all columns are both inputs and masked columns, so the
    missingness of a column can depend on its own values (self-masking).

    This is the tensor-level core; the `@dataframe_only` decorator wraps it so the
    public callable is invoked with a DataFrame and a `transform`.

    Args:
        x (torch.Tensor): A 2D tensor of shape (N, D).
        p (float): Target mean missing probability used to fit the logistic intercepts;
            also the independent masking rate of the input columns when
            `exclude_inputs=True`.
        p_params (float): Fraction of columns used as logistic inputs (at least one
            column). Only used when `exclude_inputs=True`.
        exclude_inputs (bool): Whether input columns are kept out of the model-driven
            masking (see above).
        random_state (int, optional): Seed for the random generator.
        **kwargs: Ignored.

    Returns:
        torch.Tensor: Boolean mask of shape (N, D). True indicates a missing cell.
    """
    rng = torch.Generator(device=x.device)
    rng.manual_seed(random_state)
    n_rows, n_cols = x.shape

    # The permutation is always drawn, even when unused, so the random stream that
    # follows does not depend on `exclude_inputs`.
    shuffled = torch.randperm(n_cols, generator=rng)

    if exclude_inputs:
        n_inputs = max(int(p_params * n_cols), 1)
        n_masked = n_cols - n_inputs
        input_cols = shuffled[:n_inputs]
        masked_cols = shuffled[n_inputs:]
    else:
        n_inputs = n_masked = n_cols
        input_cols = torch.arange(n_cols)
        masked_cols = torch.arange(n_cols)

    # Logistic model: random weights, intercepts tuned to hit the target rate.
    inputs = x[:, input_cols]
    weights = _pick_coeffs(x, input_cols, masked_cols, generator=rng)
    biases = _fit_intercepts(inputs, weights, p)
    missing_prob = torch.sigmoid(inputs.mm(weights) + biases)

    # Bernoulli draw per (row, masked column).
    uniform_draws = torch.rand(n_rows, n_masked, generator=rng)

    mask = torch.zeros(x.shape, device=x.device, dtype=torch.bool)
    mask[:, masked_cols] = uniform_draws < missing_prob

    if exclude_inputs:
        # The input columns get plain random missingness at rate p.
        mask[:, input_cols] = torch.rand(n_rows, n_inputs, generator=rng) < p

    return mask


def _pick_coeffs(
    x: torch.Tensor,
    idx_obs: torch.Tensor,
    idx_nas: torch.Tensor,
    generator: torch.Generator
) -> torch.Tensor:
    """
    Generate random coefficients to map from observed columns (idx_obs) to
    "missing" columns (idx_nas). Normalize the coefficients so that the linear
    combination has a reasonable scale.

    Args:
        x (torch.Tensor): The data array of shape (N, D).
        idx_obs (torch.Tensor): Indices of columns used as predictors in the logistic model.
        idx_nas (torch.Tensor): Indices of columns that will be assigned missingness.
        generator (torch.Generator): Random generator for reproducibility.

    Returns:
        torch.Tensor: A matrix of shape (len(idx_obs), len(idx_nas)) containing coefficients.
    """
    d_obs = len(idx_obs)
    d_na = len(idx_nas)
    coeffs = torch.randn(d_obs, d_na, dtype=x.dtype, generator=generator)

    # Normalize scale
    w = x[:, idx_obs].mm(coeffs)
    coeffs /= torch.std(w, 0, keepdim=True)
    return coeffs


def _fit_intercepts(X: torch.Tensor, coeffs: torch.Tensor, p: float) -> torch.Tensor:
    """
    For each column in the 'missing' set, find an intercept that achieves a mean
    missingness probability of p. This uses a simple bisection on the intercept
    to match the average predicted missingness to p.

    Args:
        X (torch.Tensor): The observed columns for the logistic model (N, d_obs).
        coeffs (torch.Tensor): The coefficient matrix of shape (d_obs, d_na).
        p (float): Desired average missingness rate.

    Returns:
        torch.Tensor: A 1D tensor of intercepts for each 'missing' column in coeffs.
    """
    d_obs, d_na = coeffs.shape
    intercepts = torch.zeros(d_na)

    for j in range(d_na):
        def f(bias):
            # Evaluate mean probability
            return torch.sigmoid(X.mv(coeffs[:, j]) + bias).mean().item() - p

        try:
            # Bisection to find intercept where mean(prob) = p
            intercept = optimize.bisect(f, -500, 500)
        except Exception:
            intercept = 0.  # If bisection fails, default intercept to 0
        intercepts[j] = intercept

    return intercepts
