"""
feature_mixins.py
=================

Composable mix-ins that add *specific* capabilities—an identity
encoder/decoder, prediction, imputation, unconditional generation,
class-conditional generation, target rebalancing, and arbitrary-conditional
generation—to any subclass of :class:`~impugen.base.Base`.

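Each mix-in keeps no per-instance state of its own and cooperates via
Python’s MRO: the ``evaluation`` hooks end with ``super().evaluation(cfg)``
and the ``__new__`` hooks end with ``super().__new__(cls)``, so later mix-ins
or the base class can continue the call-chain. The ``__new__`` hooks wrap the
public entry points (``predict``, ``predict_proba``, ``impute``,
``generate_*``) on the class with ``seed_context`` (and, for the
DataFrame-returning ones, with ``create_postprocess(..., postprocess_df)``).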

Utilities
---------
`postprocess_df`
    Release cached CUDA memory (``torch.cuda.empty_cache``) and rebuild the
    DataFrame from an in-memory CSV round-trip. The returned frame holds only
    freshly parsed values and keeps no references to the original objects.
"""
from __future__ import annotations

import warnings
from abc import abstractmethod
from copy import deepcopy
from io import StringIO

import numpy as np
import pandas as pd
import torch

from ..utils import create_postprocess, seed_context
from ..utils.data import get_conditional_y_from_rare_targets
from ..utils.eval import evaluate_generation, evaluate_imbalance_learning, evaluate_missing, evaluate_prediciton


def postprocess_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a fresh copy of *df* rebuilt from an in-memory CSV round-trip.

    The CUDA caching allocator is first asked to release its unused cached
    memory. The frame is then written to CSV text (without its index) and
    parsed back with :func:`pandas.read_csv`. As a result:

    * dtypes are re-inferred;
    * the index becomes a fresh ``RangeIndex``;
    * the result shares no objects with *df*.
    """
    torch.cuda.empty_cache()
    csv_text = df.to_csv(index=False)
    return pd.read_csv(StringIO(csv_text))


class BaseAutoEncoderMixIn:
    """Identity "auto-encoder": rows map directly to their transformed values.

    The host class is expected to provide ``tabular_transform``, ``device``
    and ``dtype``.
    """

    d_model: int = 1  # no latent projection → one value per column

    def encode(self, df_or_tensor, *args, **kwargs) -> torch.Tensor:
        """Convert a DataFrame (or an already-transformed tensor) to model input.

        A DataFrame is first passed through
        ``self.tabular_transform.transform(..., return_as_tensor=True)``.
        Non-finite values are then replaced with :meth:`torch.Tensor.nan_to_num`
        (NaN → 0 with torch's defaults), and the tensor is moved to
        ``self.device`` / ``self.dtype``. Extra ``*args`` / ``**kwargs`` are
        accepted for interface compatibility and ignored.
        """
        x = df_or_tensor
        if isinstance(x, pd.DataFrame):
            x = self.tabular_transform.transform(x, return_as_tensor=True)
        # Replace NaNs so downstream math won’t break
        return x.nan_to_num().to(self.device, self.dtype)

    def decode(self, tensor: torch.Tensor, *args, **kwargs):
        """Split a batch into its numerical and categorical parts.

        The first ``self.tabular_transform.numerical_dim`` entries along dim 1
        are treated as numerical and the rest as categorical. Each part is
        flattened to ``(batch, -1)``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(numerical, categorical)``.
        """
        tensor = tensor.to(self.device, self.dtype)
        num_dim = self.tabular_transform.numerical_dim

        num = tensor[:, :num_dim]
        cat = tensor[:, num_dim:]
        # Flatten trailing dims in case the caller supplied a flattened batch.
        # flatten(1) is used rather than view(len(x), -1): the latter cannot
        # infer the -1 dimension of an empty (0-element) batch and raises. It
        # also copies instead of failing when the slice is not viewable.
        return num.flatten(start_dim=1), cat.flatten(start_dim=1)


class BasePredictorMixIn:
    """Mix-in adding batched ``predict`` / ``predict_proba``.

    Concrete models implement ``_predict`` (and ``_predict_proba`` for
    categorical targets). The host class is expected to provide
    ``save_model_mode``, ``eval``, ``load_model_mode``, ``tabular_transform``,
    ``_transform`` and ``tgt``.
    """

    def __new__(cls, *args, **kwargs):
        # ``__new__`` runs on every instantiation. Wrap each entry point with
        # ``seed_context`` only when the attribute currently resolved on
        # ``cls`` is not the wrapper installed earlier (remembered in
        # ``_wrapped_<name>``). Re-wrapping unconditionally would nest one
        # more wrapper layer per created instance. A subclass that overrides
        # the method still gets its override wrapped.
        for name in ("predict", "predict_proba"):
            current = getattr(cls, name)
            marker = f"_wrapped_{name}"
            if getattr(cls, marker, None) is not current:
                wrapped = seed_context(current)
                setattr(cls, name, wrapped)
                setattr(cls, marker, wrapped)
        return super().__new__(cls)

    def _iter_masked_batches(self, df: pd.DataFrame, batch_size: int):
        """Yield consecutive row batches of *df* in the form fed to the model.

        Each batch:

        * is restricted to ``self.tabular_transform.columns``;
        * has a fresh ``RangeIndex``;
        * has its target column (``self._transform.target_column``) set to
          ``pd.NA``, so the label is withheld from the model.

        Rows are sliced by position, so duplicate index labels in *df* do not
        duplicate rows.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        columns = self.tabular_transform.columns
        target = self._transform.target_column
        for start in range(0, len(df), batch_size):
            minibatch = df.iloc[start:start + batch_size].loc[:, columns]
            minibatch = minibatch.reset_index(drop=True).copy()
            minibatch[target] = pd.NA
            yield minibatch

    @torch.no_grad()
    def predict(
            self,
            df: pd.DataFrame,
            batch_size: int = 4096,
            *,
            seed=None,  # passed to seed_context
            **kwargs,
    ) -> np.ndarray:
        """Predict the target for every row of *df*, in ``batch_size`` batches.

        The model is switched to eval mode for the call. Its previous mode is
        restored afterwards, even if ``_predict`` raises. Per-batch outputs of
        ``_predict`` are concatenated along axis 0. An empty *df* makes
        :func:`numpy.concatenate` raise ``ValueError``.
        """
        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                predictions = [
                    self._predict(minibatch, **kwargs)
                    for minibatch in self._iter_masked_batches(df, batch_size)
                ]
                return np.concatenate(predictions, axis=0)
        finally:
            self.load_model_mode()

    @abstractmethod
    def _predict(self, df: pd.DataFrame, **kwargs) -> np.ndarray:
        raise NotImplementedError

    @torch.no_grad()
    def predict_proba(
            self,
            df: pd.DataFrame,
            batch_size: int = 4096,
            *,
            seed=None,  # passed to seed_context
            **kwargs,
    ) -> np.ndarray:
        """Predict class probabilities for every row of *df*.

        If ``self.tgt`` is not a categorical column, this falls back to
        :meth:`predict`. Otherwise the per-batch outputs of ``_predict_proba``
        are concatenated along axis 0. The model mode is restored as in
        :meth:`predict`.
        """
        if self.tgt not in self.tabular_transform.categorical_columns:
            # Non-categorical target: no class probabilities, return predictions.
            return self.predict(df, batch_size, seed=seed, **kwargs)
        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                probabilities = [
                    self._predict_proba(minibatch, **kwargs)
                    for minibatch in self._iter_masked_batches(df, batch_size)
                ]
                return np.concatenate(probabilities, axis=0)
        finally:
            self.load_model_mode()

    @abstractmethod
    def _predict_proba(self, df: pd.DataFrame, **kwargs) -> np.ndarray:
        raise NotImplementedError


class BaseImputerMixIn:
    """Mix-in adding batched ``impute``, which fills missing cells of a DataFrame.

    Concrete models implement ``_impute``. The host class is expected to
    provide ``save_model_mode``, ``eval``, ``load_model_mode``,
    ``tabular_transform``, ``_transform`` and ``log_dir``.
    """

    def __new__(cls, *args, **kwargs):
        # ``__new__`` runs on every instantiation. Wrap ``impute`` only when
        # the attribute currently resolved on ``cls`` is not the wrapper
        # installed earlier (remembered in ``_wrapped_impute``). Otherwise
        # every new instance would add another nested wrapper layer.
        current = cls.impute
        if getattr(cls, "_wrapped_impute", None) is not current:
            wrapped = seed_context(create_postprocess(current, postprocess_df))
            cls.impute = wrapped
            cls._wrapped_impute = wrapped
        return super().__new__(cls)

    @torch.no_grad()
    def impute(
            self,
            df: pd.DataFrame,
            batch_size: int = 4096,
            *,
            seed=None,  # passed to seed_context
            mask_target_column: bool = False,
            **kwargs
    ) -> pd.DataFrame:
        """Return a copy of *df* whose missing model-column cells are imputed.

        Rows are processed in batches of ``batch_size``. Only cells that are
        missing in *df* (within ``self.tabular_transform.columns``) are
        replaced by the output of ``_impute``; present values are kept.

        When ``mask_target_column`` is true, the target column is blanked
        before calling ``_impute`` and the model's target output is discarded.
        Missing targets therefore stay missing.

        The model's previous mode is restored afterwards, even on error.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                df = df.copy(deep=True)
                columns = self.tabular_transform.columns
                target = self._transform.target_column if mask_target_column else None

                for start in range(0, len(df), batch_size):
                    # NOTE: rows are addressed by label below; assumes
                    # df.index has unique labels.
                    idx_chunk = df.index[start:start + batch_size]
                    minibatch = df.loc[idx_chunk, columns]
                    # Cells missing in the *input*; only these get overwritten.
                    mask = minibatch.isna()

                    model_input = minibatch.reset_index(drop=True)
                    if mask_target_column:
                        model_input = model_input.copy()
                        model_input[target] = pd.NA

                    imputed = self._impute(model_input,
                                           mask_target_column=mask_target_column,
                                           **kwargs)
                    imputed.index = idx_chunk

                    if mask_target_column:
                        # Never fill the target from the model's output.
                        imputed[target] = pd.NA

                    df.loc[idx_chunk, columns] = minibatch.where(~mask, imputed)
            return df
        finally:
            self.load_model_mode()

    @abstractmethod
    def _impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """
        Abstract method to be implemented by concrete imputation model.
        It should return a DataFrame with imputed values.

        Args:
            df (pd.DataFrame): Batch of data to be imputed.
            **kwargs: Additional arguments as needed.

        Returns:
            pd.DataFrame: Imputed DataFrame.
        """
        raise NotImplementedError

    def evaluation(self, cfg):
        """Run the missing-data evaluation, then continue the MRO chain."""
        evaluate_missing(cfg, self, self.log_dir)
        return super().evaluation(cfg)


class BaseUnconditionalGeneratorMixIn:
    """Mix-in adding batched ``generate_uncond`` (sampling without conditions).

    Concrete models implement ``_generate_uncond``. The host class is
    expected to provide ``save_model_mode``, ``eval``, ``load_model_mode``,
    ``tabular_transform`` and ``log_dir``.
    """

    def __new__(cls, *args, **kwargs):
        # ``__new__`` runs on every instantiation. Wrap ``generate_uncond``
        # only when the attribute currently resolved on ``cls`` is not the
        # wrapper installed earlier, so wrappers do not nest once per instance.
        current = cls.generate_uncond
        if getattr(cls, "_wrapped_generate_uncond", None) is not current:
            wrapped = seed_context(create_postprocess(current, postprocess_df))
            cls.generate_uncond = wrapped
            cls._wrapped_generate_uncond = wrapped
        return super().__new__(cls)

    @torch.no_grad()
    def generate_uncond(
            self,
            n: int,
            batch_size: int = 4096,
            seed=None,  # passed to seed_context
            **kwargs
    ) -> pd.DataFrame:
        """Generate ``n`` rows by calling ``_generate_uncond`` in batches.

        Batch sizes are ``batch_size`` except possibly the last. For ``n == 0``
        an empty frame with ``self.tabular_transform.columns`` is returned.
        The model's previous mode is restored afterwards, even on error.

        Raises:
            ValueError: if ``n`` is negative or ``batch_size`` is smaller than 1.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n!r}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                results = [
                    self._generate_uncond(min(batch_size, n - start), **kwargs)
                    for start in range(0, n, batch_size)
                ]
        finally:
            self.load_model_mode()

        if not results:
            # pd.concat([]) raises; an empty request yields an empty frame.
            return pd.DataFrame(columns=self.tabular_transform.columns)
        return pd.concat(results, ignore_index=True)

    @abstractmethod
    def _generate_uncond(self, n: int, **kwargs) -> pd.DataFrame:
        raise NotImplementedError

    def evaluation(self, cfg):
        """Run the generation evaluation, then continue the MRO chain."""
        evaluate_generation(cfg, self, self.log_dir)
        return super().evaluation(cfg)


class BaseClassConditionalGeneratorMixIn:
    """Mix-in adding ``generate_by_class``: sample rows given target values.

    Concrete models implement ``_generate_by_class``. The host class is
    expected to provide ``save_model_mode``, ``eval``, ``load_model_mode``,
    ``tabular_transform`` and ``tgt``.
    """

    def __new__(cls, *args, **kwargs):
        # ``__new__`` runs on every instantiation. Wrap ``generate_by_class``
        # only when the attribute currently resolved on ``cls`` is not the
        # wrapper installed earlier, so wrappers do not nest once per instance.
        current = cls.generate_by_class
        if getattr(cls, "_wrapped_generate_by_class", None) is not current:
            wrapped = seed_context(create_postprocess(current, postprocess_df))
            cls.generate_by_class = wrapped
            cls._wrapped_generate_by_class = wrapped
        return super().__new__(cls)

    def _preprocess_y_for_class_condition(self, y) -> pd.DataFrame:
        """Turn a class-condition spec into a condition frame.

        Accepted forms of *y*:

        * ``pd.DataFrame``: only its target column is kept as a condition. If
          any other column holds a non-missing value, a warning is issued and
          those columns are blanked to ``pd.NA``.
        * ``list`` / ``tuple``: one row per element, used as the target value.
        * ``dict``: ``{label: count}``; each label is repeated ``count`` times.
        * ``int`` (including NumPy integers): that many target values drawn
          with ``np.random.choice`` from
          ``self.tabular_transform.target_distribution``.

        Non-DataFrame inputs produce a frame over
        ``self.tabular_transform.columns`` that is missing everywhere except
        the target column.

        Raises:
            NotImplementedError: for any other type of *y*.
        """
        target_column = self.tgt

        if isinstance(y, pd.DataFrame):
            others = y.drop(target_column, axis=1)
            if not others.isna().all().all():
                # warnings.warn (not WarningMessage, which only builds a record
                # object and needs more arguments) so the caller is told.
                warnings.warn(
                    "input `y` contains values outside the target column; "
                    "they are ignored for class-conditional generation.",
                    stacklevel=2,
                )
                blank = pd.DataFrame(pd.NA, index=y.index, columns=others.columns)
                y = pd.concat([blank, y[[target_column]]], axis=1)
            return y

        columns = self.tabular_transform.columns

        if isinstance(y, (list, tuple)):
            # Directly create a DataFrame with target_column = y
            return pd.DataFrame({target_column: y}, columns=columns)

        if isinstance(y, dict):
            # Expand each label into `count` rows (linear, unlike sum(lists, [])).
            repeated = [label for label, count in y.items() for _ in range(count)]
            return pd.DataFrame({target_column: repeated}, columns=columns)

        if isinstance(y, (int, np.integer)):
            # Sample `y` rows from the known target distribution.
            random_values = np.random.choice(
                self.tabular_transform.target_distribution, y
            )
            return pd.DataFrame({target_column: random_values}, columns=columns)

        raise NotImplementedError(
            "cond_gen only supports pd.DataFrame, list, tuple, dict, or int."
        )

    @torch.no_grad()
    def generate_by_class(
            self,
            y: pd.DataFrame | list | tuple | dict | int,
            batch_size: int = 4096,
            seed=None,
            **kwargs
    ) -> pd.DataFrame:
        """Generate rows conditioned on target values described by *y*.

        *y* is converted by ``_preprocess_y_for_class_condition`` (on a deep
        copy, so the caller's object is untouched). Missing cells in
        ``self.tabular_transform.columns`` are then filled chunk by chunk from
        ``_generate_by_class``; supplied values are kept. The model's previous
        mode is restored afterwards, even on error.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        # Build the condition frame outside the warning filter below, so the
        # warning about ignored non-target values reaches the caller.
        df = self._preprocess_y_for_class_condition(deepcopy(y))
        columns = self.tabular_transform.columns

        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                for start in range(0, len(df), batch_size):
                    idx_chunk = df.index[start:start + batch_size]
                    minibatch = df.loc[idx_chunk, columns]
                    mask = minibatch.isna()

                    cond_generated = self._generate_by_class(minibatch.reset_index(drop=True), **kwargs)
                    cond_generated.index = idx_chunk

                    # Fill only this chunk's missing cells. This touches just
                    # the chunk rather than realigning the whole frame per chunk.
                    df.loc[idx_chunk, columns] = minibatch.where(~mask, cond_generated)
        finally:
            self.load_model_mode()
        return df

    @abstractmethod
    def _generate_by_class(self, y: pd.DataFrame, **kwargs) -> pd.DataFrame:
        raise NotImplementedError


class BaseImbalanceMixin:
    """Mix-in that rebalances a frame by oversampling via ``generate_by_class``.

    The host class is expected to provide ``generate_by_class`` (e.g. from a
    class-conditional generator mix-in), ``tabular_transform`` and ``tgt``.
    """

    def rebalance_targets(self, df: pd.DataFrame, seed=None, *, rare_thr: float = 0.8, **kwargs):
        """Return *df* with synthetic rows appended to rebalance ``self.tgt``.

        * Categorical target: every class is topped up to the size of the
          most frequent class.
        * Numerical target: the conditioning values come from
          ``get_conditional_y_from_rare_targets(df[tgt], rare_thr=rare_thr)``.
          The default ``rare_thr`` of 0.8 is the previously hard-coded value.

        Extra ``**kwargs`` are forwarded to ``generate_by_class``. The result
        has a fresh ``RangeIndex``.

        Raises:
            NotImplementedError: if ``self.tgt`` is neither a categorical nor
                a numerical column of ``self.tabular_transform``.
        """
        # by default, oversampling applied.
        target = self.tgt
        cond_gen = pd.DataFrame(columns=self.tabular_transform.columns)

        if target in self.tabular_transform.categorical_columns:
            counts = df[target].value_counts()
            deficit = counts.max() - counts
            # One conditioning row per missing sample (linear, unlike sum(lists, [])).
            cond_gen[target] = [label for label, k in deficit.items() for _ in range(k)]

        elif target in self.tabular_transform.numerical_columns:
            cond_gen[target] = get_conditional_y_from_rare_targets(df[target], rare_thr=rare_thr)

        else:
            raise NotImplementedError(
                f"target column {target!r} is neither categorical nor numerical; "
                "cannot rebalance."
            )

        oversampled = self.generate_by_class(cond_gen, seed=seed, **kwargs)
        return pd.concat([df, oversampled], ignore_index=True)


class BaseArbitraryConditionalGeneratorMixIn(BaseClassConditionalGeneratorMixIn):
    """Mix-in adding ``generate_by_condition``: complete partially given rows.

    Every missing cell of the condition frame (within
    ``self.tabular_transform.columns``) is filled from the model's
    ``_generate_by_condition``; supplied values are kept. ``generate_by_class``
    is inherited from :class:`BaseClassConditionalGeneratorMixIn`.
    """

    def __new__(cls, *args, **kwargs):
        # ``__new__`` runs on every instantiation. Wrap ``generate_by_condition``
        # only when the attribute currently resolved on ``cls`` is not the
        # wrapper installed earlier, so wrappers do not nest once per instance.
        current = cls.generate_by_condition
        if getattr(cls, "_wrapped_generate_by_condition", None) is not current:
            wrapped = seed_context(create_postprocess(current, postprocess_df))
            cls.generate_by_condition = wrapped
            cls._wrapped_generate_by_condition = wrapped
        # Continue the MRO chain (normally BaseClassConditionalGeneratorMixIn,
        # which wraps ``generate_by_class``).
        return super().__new__(cls)

    def _preprocess_y_for_condition(self, y) -> pd.DataFrame:
        """Turn an arbitrary-condition spec into a condition frame.

        * ``pd.DataFrame``: returned unchanged; any non-missing cell is a
          condition.
        * ``list`` / ``tuple``: one row per element, used as the target value.
        * ``dict``: ``{label: count}``; each label is repeated ``count`` times.
        * ``int`` (including NumPy integers): that many target values drawn
          with ``np.random.choice`` from
          ``self.tabular_transform.target_distribution``.

        Non-DataFrame inputs produce a frame over
        ``self.tabular_transform.columns`` that is missing everywhere except
        the target column.

        Raises:
            NotImplementedError: for any other type of *y*.
        """
        target_column = self.tgt

        if isinstance(y, pd.DataFrame):
            return y

        columns = self.tabular_transform.columns

        if isinstance(y, (list, tuple)):
            # Directly create a DataFrame with target_column = y
            return pd.DataFrame({target_column: y}, columns=columns)

        if isinstance(y, dict):
            # Expand each label into `count` rows (linear, unlike sum(lists, [])).
            repeated = [label for label, count in y.items() for _ in range(count)]
            return pd.DataFrame({target_column: repeated}, columns=columns)

        if isinstance(y, (int, np.integer)):
            # Sample `y` rows from the known target distribution.
            random_values = np.random.choice(
                self.tabular_transform.target_distribution, y
            )
            return pd.DataFrame({target_column: random_values}, columns=columns)

        raise NotImplementedError(
            "cond_gen only supports pd.DataFrame, list, tuple, dict, or int."
        )

    @torch.no_grad()
    def generate_by_condition(
            self,
            y: pd.DataFrame | list | tuple | dict | int,
            batch_size: int = 4096,
            seed=None,
            **kwargs
    ) -> pd.DataFrame:
        """Fill every missing cell of the condition described by *y*.

        *y* is converted by ``_preprocess_y_for_condition`` (on a deep copy,
        so the caller's object is untouched). Missing cells in
        ``self.tabular_transform.columns`` are then filled chunk by chunk from
        ``_generate_by_condition``. The model's previous mode is restored
        afterwards, even on error.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        df = self._preprocess_y_for_condition(deepcopy(y))
        columns = self.tabular_transform.columns

        self.save_model_mode()
        try:
            self.eval()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                for start in range(0, len(df), batch_size):
                    idx_chunk = df.index[start:start + batch_size]
                    minibatch = df.loc[idx_chunk, columns]
                    mask = minibatch.isna()

                    cond_generated = self._generate_by_condition(minibatch.reset_index(drop=True), **kwargs)
                    cond_generated.index = idx_chunk

                    # Fill only this chunk's missing cells. This touches just
                    # the chunk rather than realigning the whole frame per chunk.
                    df.loc[idx_chunk, columns] = minibatch.where(~mask, cond_generated)
        finally:
            self.load_model_mode()
        return df

    @abstractmethod
    def _generate_by_condition(self, condition: pd.DataFrame, **kwargs) -> pd.DataFrame:
        raise NotImplementedError
