from abc import ABC, abstractmethod
from imblearn.over_sampling import RandomOverSampler, SMOTE, SMOTENC
import pandas as pd
import numpy as np
import pandas as pd
import random
import logging
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import  MinMaxScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from collections import Counter
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer, CopulaGANSynthesizer
from src.syngen.baselines.ctabgan.main import CTABGAN
from sdv.sampling import Condition
from sdv.metadata import SingleTableMetadata
from sklearn.neighbors import NearestNeighbors
from src.syngen.baselines.ttvae.ttvae import TTVAE
from src.syngen.baselines.ttvae.ttvae_tbs import TTVAETBS
from src.syngen.cttvae.cttvae import CTTVAE
from src.syngen.cttvae.cttvae_tbs import CTTVAETBS
from sklearn.preprocessing import OneHotEncoder
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from typing import Tuple 
from src.syngen.scripts.utils import set_global_seed


_log = logging.getLogger(__name__)


def split_dataset(df: pd.DataFrame, target_column, test_size=0.2, random_state=42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into a train and a test frame, stratified on ``target_column``.

    Args:
        df: Full dataset, including the target column.
        target_column: Name of the label column used for stratification.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.

    Returns:
        ``(train_data, test_data)``; each frame holds the feature columns
        followed by the target column.
    """
    features = df.drop(columns=[target_column])
    labels = df[target_column]

    # Stratify on the labels so both subsets keep the same class proportions.
    feat_train, feat_test, lab_train, lab_test = train_test_split(
        features,
        labels,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    def _rejoin(feature_part, label_part):
        # Glue the label column back onto its feature rows.
        return pd.concat([feature_part, label_part], axis=1)

    return _rejoin(feat_train, lab_train), _rejoin(feat_test, lab_test)


class InterpolationMethods(ABC):
    """Common encode / resample / decode plumbing for over-sampling generators.

    ``generate`` relies on ``self.sampler`` exposing ``fit_resample(X, y)``;
    this base class does not create it, the subclasses in this file do.
    """

    def __init__(self, sampling_strategy='minority', random_state=None):
        self.sampling_strategy, self.random_state = sampling_strategy, random_state
        # Timing slots; nothing in this class writes to them.
        self.train_time = self.gen_time = None

    def _fit_preprocessor(self, X):
        """Fit a one-hot/passthrough ``ColumnTransformer`` on ``X`` and return the encoded data."""

        def _looks_categorical(col):
            # Object dtype, or an integer column with at most 10 distinct values.
            series = X[col]
            if series.dtype == 'object':
                return True
            return pd.api.types.is_integer_dtype(series) and series.nunique() <= 10

        self.categorical_cols = [col for col in X.columns if _looks_categorical(col)]
        self.numerical_cols = [col for col in X.columns if col not in self.categorical_cols]

        steps = []
        if self.categorical_cols:
            encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
            steps.append(('cat', encoder, self.categorical_cols))
        if self.numerical_cols:
            steps.append(('num', 'passthrough', self.numerical_cols))

        self.preprocessor = ColumnTransformer(steps)
        return self.preprocessor.fit_transform(X)

    def _inverse_preprocess(self, X_encoded):
        """Map an encoded matrix back to a frame with the categorical columns first, then the numerical ones."""
        cat_frame = pd.DataFrame()
        cat_width = 0
        if self.categorical_cols:
            encoder = self.preprocessor.named_transformers_['cat']
            # The one-hot block occupies the leading columns of the encoded matrix.
            cat_width = len(encoder.get_feature_names_out())
            decoded = encoder.inverse_transform(X_encoded[:, :cat_width])
            cat_frame = pd.DataFrame(decoded, columns=self.categorical_cols)

        if not self.numerical_cols:
            restored = cat_frame
        else:
            num_frame = pd.DataFrame(X_encoded[:, cat_width:], columns=self.numerical_cols)
            restored = pd.concat([cat_frame, num_frame], axis=1)

        return restored[self.categorical_cols + self.numerical_cols]

    def generate(self, train_data, target_column, n_to_generate=None):
        """Resample ``train_data`` with ``self.sampler`` and return original plus new rows.

        Rows of the resampled output located after the first ``len(train_data)``
        positions are treated as the newly created ones; ``n_to_generate``
        optionally caps how many of them are kept.
        """
        X = train_data.drop(columns=[target_column])
        y = train_data[target_column]
        encoded = self._fit_preprocessor(X)

        _log.info(Counter(y))
        X_res, y_res = self.sampler.fit_resample(encoded, y)
        _log.info(Counter(y_res))

        start = len(X)
        stop = None if n_to_generate is None else start + n_to_generate
        new_rows = slice(start, stop)

        # Decode only the new rows and attach their labels.
        decoded = self._inverse_preprocess(X_res[new_rows])
        decoded[target_column] = y_res[new_rows].reset_index(drop=True)

        # Same column names and order as the input frame.
        synthetic = decoded[train_data.columns]
        return pd.concat([train_data, synthetic], axis=0).reset_index(drop=True)


class RandomOverSampling(InterpolationMethods):
    """Over-sampler backed by imblearn's ``RandomOverSampler``."""

    def __init__(self, sampling_strategy='minority', random_state=None):
        super().__init__(sampling_strategy, random_state)
        self.sampler = RandomOverSampler(
            sampling_strategy=sampling_strategy,
            random_state=random_state,
        )

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='minority', n_to_generate=None):
        """Return ``train_data`` augmented with resampled rows.

        ``sampling_strategy`` is accepted here but not read; the sampler built
        in ``__init__`` decides what gets resampled.
        """
        augmented = super().generate(
            train_data,
            target_column=conditional_column,
            n_to_generate=n_to_generate,
        )
        return augmented


class Smote(InterpolationMethods):
    """SMOTE-based generator.

    ``generate`` either augments the training data through imblearn's
    ``SMOTE`` sampler or, with ``sampling_strategy="all"``, builds a fully
    synthetic dataset by interpolating between same-class neighbours.
    """

    def __init__(self, sampling_strategy='minority', random_state=None):
        super().__init__(sampling_strategy, random_state)
        self.sampler = SMOTE(sampling_strategy=sampling_strategy, random_state=random_state)

    def _generate_fully_synthetic(self, train_data: pd.DataFrame, target_column):
        """Create a synthetic dataset with roughly the class proportions of ``train_data``.

        For every class with at least two rows, new rows are linear
        interpolations between a randomly chosen row and one of its neighbours
        within the same class. Classes with fewer than two rows are skipped.

        Args:
            train_data: Real data, including ``target_column``.
            target_column: Name of the label column.

        Returns:
            DataFrame of synthetic rows with the columns of ``train_data``.

        Raises:
            ValueError: If no class has at least two rows.
        """
        # Local import: sklearn is already a dependency of this module.
        from sklearn.utils import check_random_state

        n_total = len(train_data)
        X = train_data.drop(columns=[target_column])
        y = train_data[target_column]

        X_encoded = np.asarray(self._fit_preprocessor(X))
        if X_encoded.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X_encoded.shape}")
        labels = y.to_numpy()

        # Honour self.random_state; None falls back to the global NumPy state.
        rng = check_random_state(self.random_state)

        synthetic_parts = []
        for cls, prop in y.value_counts(normalize=True).items():
            # round() instead of truncation: 100 * 0.29 is 28.999..., not 29.
            n_samples = int(round(n_total * prop))
            X_cls = X_encoded[labels == cls]
            if len(X_cls) < 2 or n_samples < 1:
                continue

            # kneighbors() without a query excludes the point itself, so at
            # most len(X_cls) - 1 neighbours exist. Asking for 2 with only two
            # rows in the class would raise.
            k = min(2, len(X_cls) - 1)
            neighbor_model = NearestNeighbors(n_neighbors=k).fit(X_cls)
            neighbors = neighbor_model.kneighbors(return_distance=False)

            base_idx = rng.randint(len(X_cls), size=n_samples)
            partner_idx = neighbors[base_idx, -1]
            lam = rng.random_sample((n_samples, 1))
            interpolated = lam * X_cls[base_idx] + (1 - lam) * X_cls[partner_idx]

            # Decode back to the original feature space before storing.
            X_decoded = self._inverse_preprocess(interpolated)
            X_decoded[target_column] = cls
            synthetic_parts.append(X_decoded)

        if not synthetic_parts:
            raise ValueError(
                "Cannot generate synthetic data: no class has at least two samples."
            )

        synthetic_df = pd.concat(synthetic_parts, axis=0).reset_index(drop=True)
        return synthetic_df[train_data.columns]

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='minority', n_to_generate=None):
        """Generate data for ``train_data``.

        ``sampling_strategy="all"`` returns only synthetic rows; any other
        value returns ``train_data`` augmented by the SMOTE sampler configured
        in ``__init__``.
        """
        if sampling_strategy == "all":
            return self._generate_fully_synthetic(train_data, target_column=conditional_column)
        return super().generate(train_data, target_column=conditional_column, n_to_generate=n_to_generate)


class SmoteENC(InterpolationMethods):
    """SMOTENC-based generator for data with categorical and continuous features.

    SMOTENC is applied to the raw feature frame. It must not be fed the
    one-hot matrix produced by the base class, because the categorical indices
    are positions in the raw feature frame.
    """

    def __init__(self, categorical_features, sampling_strategy='minority', random_state=None):
        super().__init__(sampling_strategy, random_state)
        self.categorical_features_names = categorical_features
        self.categorical_indices = None
        self.sampler = None

    def _fit_sampler(self, X, exclude=()):
        """Build the SMOTENC sampler for the feature frame ``X``.

        Args:
            X: Feature frame the sampler will be applied to.
            exclude: Column names to drop from the configured categorical
                features, typically the target column.
        """
        names = [col for col in self.categorical_features_names if col not in exclude]
        # Positions must refer to the frame that is actually resampled.
        self.categorical_indices = [X.columns.get_loc(col) for col in names]
        self.sampler = SMOTENC(
            categorical_features=self.categorical_indices,
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state,
        )

    def _generate_fully_synthetic(self, train_data: pd.DataFrame, target_column):
        """Create a synthetic dataset with roughly the class proportions of ``train_data``.

        Features are one-hot encoded by the base-class preprocessor. New rows
        are linear interpolations between a randomly chosen row and one of its
        same-class neighbours, then decoded back. Classes with fewer than two
        rows are skipped.

        Raises:
            ValueError: If no class has at least two rows.
        """
        # Local import: sklearn is already a dependency of this module.
        from sklearn.utils import check_random_state

        n_total = len(train_data)
        X = train_data.drop(columns=[target_column])
        y = train_data[target_column]

        X_encoded = np.asarray(self._fit_preprocessor(X))
        labels = y.to_numpy()

        # Honour self.random_state; None falls back to the global NumPy state.
        rng = check_random_state(self.random_state)

        synthetic_parts = []
        for cls, prop in y.value_counts(normalize=True).items():
            # round() instead of truncation: 100 * 0.29 is 28.999..., not 29.
            n_samples = int(round(n_total * prop))
            X_cls = X_encoded[labels == cls]
            if len(X_cls) < 2 or n_samples < 1:
                continue

            # kneighbors() without a query excludes the point itself, so at
            # most len(X_cls) - 1 neighbours exist.
            k = min(2, len(X_cls) - 1)
            neighbor_model = NearestNeighbors(n_neighbors=k).fit(X_cls)
            neighbors = neighbor_model.kneighbors(return_distance=False)

            base_idx = rng.randint(len(X_cls), size=n_samples)
            partner_idx = neighbors[base_idx, -1]
            lam = rng.random_sample((n_samples, 1))
            interpolated = lam * X_cls[base_idx] + (1 - lam) * X_cls[partner_idx]

            # Decode back to the original feature space before storing.
            X_decoded = self._inverse_preprocess(interpolated)
            X_decoded[target_column] = cls
            synthetic_parts.append(X_decoded)

        if not synthetic_parts:
            raise ValueError(
                "Cannot generate synthetic data: no class has at least two samples."
            )

        synthetic_df = pd.concat(synthetic_parts, axis=0).reset_index(drop=True)
        return synthetic_df[train_data.columns]

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='minority', n_to_generate=None):
        """Generate data for ``train_data``.

        ``sampling_strategy="all"`` returns only synthetic rows. Any other
        value returns ``train_data`` followed by rows created by SMOTENC,
        optionally capped at ``n_to_generate``.
        """
        if sampling_strategy == "all":
            return self._generate_fully_synthetic(train_data, target_column=conditional_column)

        X = train_data.drop(columns=[conditional_column])
        y = train_data[conditional_column]
        # Rebuilt on every call so the indices always match the current frame.
        self._fit_sampler(X, exclude=(conditional_column,))

        _log.info(Counter(y))
        X_resampled, y_resampled = self.sampler.fit_resample(X, y)
        _log.info(Counter(y_resampled))

        # Rows after the first len(X) positions are treated as the new ones,
        # as in the base-class generate().
        start = len(X)
        stop = None if n_to_generate is None else start + n_to_generate
        synthetic = X_resampled.iloc[start:stop].reset_index(drop=True)
        synthetic[conditional_column] = y_resampled.iloc[start:stop].reset_index(drop=True)
        synthetic = synthetic[train_data.columns]

        return pd.concat([train_data, synthetic], axis=0).reset_index(drop=True)


class BaseSDVGenerator(ABC):
    """Shared fit / sample logic for the SDV single-table synthesizer wrappers.

    Subclasses supply the synthesizer class through ``get_model_class``.
    """

    def __init__(self, epochs, batch_size, sampling_strategy, pac=None, verbose=True, random_state=None):
        # Training configuration
        self.epochs, self.batch_size = epochs, batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.pac = pac
        # Filled in later
        self.model = None
        self.train_time = None
        self.gen_time = None
        self.final_loss = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.random_state = random_state

    @abstractmethod
    def get_model_class(self):
        """Return the synthesizer class to instantiate in ``fit``."""
        pass

    def fit(self, X, categorical_columns=None, conditional_column=None, save_path=''):
        """Detect metadata from ``X``, build the synthesizer and train it.

        ``categorical_columns``, ``conditional_column`` and ``save_path`` are
        accepted but not read by this implementation.
        """
        table_metadata = SingleTableMetadata()
        table_metadata.detect_from_dataframe(X)
        synthesizer_cls = self.get_model_class()

        init_args = dict(
            metadata=table_metadata,
            epochs=self.epochs,
            batch_size=self.batch_size,
            verbose=self.verbose,
        )
        # ``pac`` is only forwarded when it was configured.
        if self.pac is not None:
            init_args["pac"] = self.pac

        self.model = synthesizer_cls(**init_args)
        if self.random_state is not None:
            set_global_seed(self.random_state)
        self.model.fit(X)

    def load_model(self, model_path):
        """Replace ``self.model`` with the synthesizer loaded from ``model_path``."""
        synthesizer_cls = self.get_model_class()
        self.model = synthesizer_cls.load(model_path)

    def generate_minority(self, train_data, conditional_column, n_to_generate=None):
        """
        Generate synthetic samples for just the minority class.

        Returns ``train_data`` followed by the sampled rows, or ``None`` when
        there is nothing to generate or conditional sampling raised.
        """
        counts = train_data[conditional_column].value_counts()
        rare_label = counts.idxmin()

        if n_to_generate is None:
            # Default: bring the rarest class up to the size of the largest one.
            n_to_generate = counts.max() - counts[rare_label]
        _log.info(f"Number of samples to generate for class '{rare_label}': {n_to_generate}")

        if not n_to_generate > 0:
            return None

        try:
            request = Condition(
                column_values={conditional_column: rare_label},
                num_rows=n_to_generate,
            )
            drawn = self.model.sample_from_conditions(conditions=[request])
        except Exception as e:
            _log.warning(f"Conditional generation for class {rare_label} failed: {e}")
            return None

        synthetic = pd.concat([drawn]).reset_index(drop=True)
        return pd.concat([train_data, synthetic], axis=0).reset_index(drop=True)

    def generate_all(self, train_data: pd.DataFrame, n_to_generate=None):
        """Sample ``n_to_generate`` rows (default ``len(train_data)``); ``None`` if sampling raised."""
        try:
            n_rows = len(train_data) if n_to_generate is None else n_to_generate
            _log.info(f"Number of samples to generate: {n_rows}")
            return self.model.sample(n_rows)
        except Exception as e:
            _log.warning(f"Generation failed: {e}")
            return None

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Dispatch to ``generate_minority`` or ``generate_all``.

        Raises:
            ValueError: If ``sampling_strategy`` is neither "minority" nor "all".
        """
        if sampling_strategy == "minority":
            return self.generate_minority(
                train_data,
                conditional_column=conditional_column,
                n_to_generate=n_to_generate,
            )
        if sampling_strategy == "all":
            return self.generate_all(train_data, n_to_generate=n_to_generate)
        raise ValueError(f"Unknown sampling strategy: '{sampling_strategy}'.")


class CTGANGenerator(BaseSDVGenerator):
    """Wrapper around SDV's ``CTGANSynthesizer``."""

    def __init__(self, epochs, batch_size, sampling_strategy, pac=1, verbose=True, random_state=None):
        super().__init__(epochs, batch_size, sampling_strategy, pac, verbose, random_state)
        self.losses = []

    def get_model_class(self):
        """Return the synthesizer class instantiated by ``fit``."""
        return CTGANSynthesizer

    def fit(self, X, categorical_columns=None, conditional_column=None, save_path=''):
        """Train the synthesizer on ``X`` and record the last generator loss.

        ``self.final_loss`` is ``None`` when the model reports no loss values.
        """
        # Forward by keyword: passing ``save_path`` positionally would bind it
        # to the base class's ``conditional_column`` parameter.
        super().fit(
            X,
            categorical_columns=categorical_columns,
            conditional_column=conditional_column,
            save_path=save_path,
        )
        loss_values = self.model.get_loss_values()
        if loss_values is None or len(loss_values) == 0:
            self.final_loss = None
        else:
            self.final_loss = loss_values["Generator Loss"].iloc[-1]


class CTABGANGenerator:
    """Wrapper around the project's ``CTABGAN`` baseline.

    List/dict options default to ``None`` and are replaced by a fresh empty
    container per instance, so instances never share mutable defaults.
    """

    def __init__(self, epochs, batch_size, sampling_strategy, verbose=True, l2scale=1e-5, categorical_columns=None,
                log_columns=None, mixed_columns=None, general_columns=None, non_categorical_columns=None, integer_columns=None,
                problem_type=None, class_dim=(256, 256, 256, 256), random_dim=100, num_channels=64, lr=2e-4, random_state=None):
        self.epochs = epochs
        self.batch_size = batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.l2scale = l2scale
        self.categorical_columns = [] if categorical_columns is None else categorical_columns
        self.log_columns = [] if log_columns is None else log_columns
        self.mixed_columns = {} if mixed_columns is None else mixed_columns
        self.general_columns = [] if general_columns is None else general_columns
        self.non_categorical_columns = [] if non_categorical_columns is None else non_categorical_columns
        self.integer_columns = [] if integer_columns is None else integer_columns
        self.problem_type = {} if problem_type is None else problem_type
        self.class_dim = tuple(class_dim)
        self.random_dim = random_dim
        self.num_channels = num_channels
        self.lr = lr
        self.model = None
        self.loss_values = None
        # Set by fit(); kept here so generate_all() can check it safely.
        self.target_column = None

        self.train_time = None
        self.gen_time = None
        self.final_loss = None

        self.random_state = random_state

    def fit(self, X, categorical_columns=None, conditional_column=None, save_path=''):
        """Build and train the CTABGAN model on ``X``, then save it to ``save_path``.

        Args:
            X: Training frame handed to ``CTABGAN`` as ``df``.
            categorical_columns: Overrides the list given at construction when
                not ``None``.
            conditional_column: Target column; it is added to the categorical
                columns when missing from them.
            save_path: Passed to the model's ``save``.
        """
        self.target_column = conditional_column
        if categorical_columns is None:
            categorical_columns = self.categorical_columns
        categorical_columns = list(categorical_columns)

        if conditional_column and conditional_column not in categorical_columns:
            categorical_columns.append(conditional_column)

        self.categorical_columns = categorical_columns

        self.model = CTABGAN(
            df=X,
            test_ratio=0.0,
            categorical_columns=categorical_columns,
            log_columns=self.log_columns,
            mixed_columns=self.mixed_columns,
            general_columns=self.general_columns,
            non_categorical_columns=self.non_categorical_columns,
            integer_columns=self.integer_columns,
            problem_type=self.problem_type,
            class_dim=self.class_dim,
            random_dim=self.random_dim,
            num_channels=self.num_channels,
            l2scale=self.l2scale,
            batch_size=self.batch_size,
            epochs=self.epochs,
            lr=self.lr,
        )
        if self.random_state is not None:
            set_global_seed(self.random_state)
        self.model.fit()
        self.model.save(save_path)
        self.loss_values = pd.DataFrame(self.model.history)
        self.final_loss = self.loss_values["g_loss"].iloc[-1] if not self.loss_values.empty else None

    def generate_all(self, train_data, n_to_generate=None):
        """Sample ``n_to_generate`` rows (default ``len(train_data)``) from the model."""
        if n_to_generate is None:
            n_to_generate = len(train_data)
        samples = self.model.generate_samples(n_to_generate)

        # Diagnostic only; skipped when no target column is known.
        target = getattr(self, "target_column", None)
        if target is not None and target in samples.columns:
            _log.info(f"Generated target values: {samples[target].unique()}")

        return samples

    def generate_minority(self, train_data: pd.DataFrame, conditional_column: str, n_to_generate=None):
        """Return ``train_data`` plus generated rows of its rarest class.

        The model is sampled once with twice the requested size and the result
        is filtered to the minority class, so fewer than ``n_to_generate``
        rows may be added.
        """
        class_counts = train_data[conditional_column].value_counts()
        minority_class = class_counts.idxmin()

        if n_to_generate is None:
            n_to_generate = class_counts.max() - class_counts[minority_class]

        if n_to_generate <= 0:
            # Nothing to add; avoid asking the model for zero or negative rows.
            return train_data.reset_index(drop=True)

        extra_samples = self.model.generate_samples(n_to_generate * 2)
        filtered = extra_samples[extra_samples[conditional_column] == minority_class]

        if len(filtered) >= n_to_generate:
            final_samples = filtered.sample(n=n_to_generate, random_state=42)
        else:
            _log.warning(
                f"Only {len(filtered)} of {n_to_generate} requested samples for class "
                f"'{minority_class}' were generated."
            )
            final_samples = filtered

        return pd.concat([train_data, final_samples], axis=0).reset_index(drop=True)

    def generate(self, train_data, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Dispatch to ``generate_all`` or ``generate_minority``.

        Raises:
            ValueError: If ``sampling_strategy`` is neither 'all' nor 'minority'.
        """
        if sampling_strategy == 'all':
            return self.generate_all(train_data, n_to_generate)
        if sampling_strategy == 'minority':
            return self.generate_minority(train_data, conditional_column, n_to_generate)
        raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")


class CopulaGANGenerator(BaseSDVGenerator):
    """Wrapper around SDV's ``CopulaGANSynthesizer``."""

    def __init__(self, epochs, batch_size, sampling_strategy, pac=1, verbose=True, random_state=None):
        super().__init__(epochs, batch_size, sampling_strategy, pac, verbose, random_state)
        self.losses = []

    def get_model_class(self):
        """Return the synthesizer class instantiated by ``fit``."""
        return CopulaGANSynthesizer

    def fit(self, X, categorical_columns=None, conditional_column=None, save_path=''):
        """Train the synthesizer on ``X`` and record the last generator loss.

        ``self.final_loss`` is ``None`` when the model reports no loss values.
        """
        # Forward by keyword: passing ``save_path`` positionally would bind it
        # to the base class's ``conditional_column`` parameter.
        super().fit(
            X,
            categorical_columns=categorical_columns,
            conditional_column=conditional_column,
            save_path=save_path,
        )
        loss_values = self.model.get_loss_values()
        if loss_values is None or len(loss_values) == 0:
            self.final_loss = None
        else:
            self.final_loss = loss_values["Generator Loss"].iloc[-1]


class TVAEGenerator(BaseSDVGenerator):
    """Wrapper around SDV's ``TVAESynthesizer``."""

    def __init__(self, epochs, batch_size, sampling_strategy, pac=None, verbose=True, random_state=None):
        super().__init__(epochs, batch_size, sampling_strategy, pac, verbose, random_state)

    def get_model_class(self):
        """Return the synthesizer class instantiated by ``fit``."""
        return TVAESynthesizer

    def fit(self, X, categorical_columns=None, conditional_column=None, save_path=''):
        """Train the synthesizer on ``X`` and record the last reported loss.

        ``self.final_loss`` is ``None`` when the model reports no loss values.
        """
        # Forward by keyword: passing ``save_path`` positionally would bind it
        # to the base class's ``conditional_column`` parameter.
        super().fit(
            X,
            categorical_columns=categorical_columns,
            conditional_column=conditional_column,
            save_path=save_path,
        )
        loss_values = self.model.get_loss_values()
        if loss_values is None or len(loss_values) == 0:
            self.final_loss = None
        else:
            self.final_loss = loss_values["Loss"].iloc[-1]


class TTVAEGenerator:
    """Wrapper around the project's ``TTVAE`` baseline."""

    # Upper bound on sampling rounds in generate_minority(), so a model that
    # never emits the minority class cannot loop forever.
    max_sampling_attempts = 100

    def __init__(self, epochs, batch_size, sampling_strategy, verbose=True, l2scale=1e-5, loss_factor=2, latent_dim=32,
                 embedding_dim=128, nhead=8, dim_feedforward=1028, dropout=0.1, random_state=None):
        self.epochs = epochs
        self.batch_size = batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.l2scale = l2scale
        self.loss_factor = loss_factor
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.model = None
        self.training_history = []
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _log.info(f"Using device: {self.device}")
        _log.info(f"GPU available: {torch.cuda.is_available()}")

        self.train_time = None
        self.gen_time = None
        self.final_loss = None

        self.random_state = random_state

    def fit(self, df: pd.DataFrame, categorical_columns, conditional_column=None, save_path=''):
        """Train a TTVAE model on ``df``.

        Args:
            df: Training frame.
            categorical_columns: Names of the discrete columns.
            conditional_column: Target column; defaults to the last column of
                ``df``. It is always treated as discrete.
            save_path: Forwarded to the model's ``fit``.

        Raises:
            ValueError: If ``conditional_column`` is given but not in ``df``.
        """
        if conditional_column is not None:
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
            target_col = conditional_column
        else:
            target_col = df.columns[-1]
        _log.info(f"Categorical columns: {categorical_columns}")

        discrete_cols = list(categorical_columns or [])
        # Do not list the target twice when it is already declared categorical.
        if target_col not in discrete_cols:
            discrete_cols.append(target_col)

        if self.random_state is not None:
            set_global_seed(self.random_state)

        self.model = TTVAE(
            epochs=self.epochs,
            batch_size=self.batch_size,
            device=self.device,
            verbose=self.verbose,
            l2scale=self.l2scale,
            loss_factor=self.loss_factor,
            latent_dim=self.latent_dim,
            embedding_dim=self.embedding_dim,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout
        )

        _log.info(f"Saving model to {save_path}")
        _log.info(f"Fitting TTVAE model...")
        self.model.fit(df, discrete_columns=discrete_cols, save_path=save_path)

        self.training_history = self.model.loss_values
        self.final_loss = self.training_history['Loss'].iloc[-1] if not self.training_history.empty else None

    def load_model(self, model_path):
        """Load a previously saved model object from ``model_path``."""
        # NOTE(review): torch.load unpickles the file; only load trusted paths.
        self.model = torch.load(model_path, map_location=self.device)

    def generate_all(self, train_data: pd.DataFrame, n_to_generate=None):
        """Sample ``n_to_generate`` rows (default ``len(train_data)``); ``None`` if sampling raised."""
        try:
            if n_to_generate is None:
                n_to_generate = len(train_data)
            _log.info(f"Number of samples to generate: {n_to_generate}")
            generated_data = self.model.sample(n_samples=n_to_generate)
            return pd.DataFrame(generated_data)
        except Exception as e:
            _log.warning(f"Generation failed: {e}")
            return None

    def generate_minority(self, train_data: pd.DataFrame, conditional_column, n_to_generate=None):
        """Return ``train_data`` plus generated rows of its rarest class.

        The model is sampled unconditionally in batches and the minority rows
        are kept until ``n_to_generate`` rows are collected (default: the gap
        between the largest and the rarest class). After
        ``max_sampling_attempts`` batches whatever was collected is used.
        """
        class_counts = train_data[conditional_column].value_counts()
        minority_class = class_counts.idxmin()

        if n_to_generate is None:
            n_to_generate = class_counts.max() - class_counts[minority_class]
        n_to_generate = int(n_to_generate)
        _log.info(f"Number of samples to generate for class '{minority_class}': {n_to_generate}")

        if n_to_generate <= 0:
            _log.warning("No need to generate minority samples - already balanced.")
            return train_data.reset_index(drop=True)

        batch_size = 2 * n_to_generate  # generate more to filter from
        collected = []
        n_collected = 0  # rows gathered so far (not the number of batches)

        for _ in range(max(1, self.max_sampling_attempts)):
            synthetic_df = pd.DataFrame(self.model.sample(n_samples=batch_size))
            minority_df = synthetic_df[synthetic_df[conditional_column] == minority_class]

            needed = n_to_generate - n_collected
            if len(minority_df) > needed:
                minority_df = minority_df.sample(n=needed, random_state=42)

            collected.append(minority_df)
            n_collected += len(minority_df)
            if n_collected >= n_to_generate:
                break
        else:
            _log.warning(
                f"Only {n_collected} of {n_to_generate} requested samples for class "
                f"'{minority_class}' were generated."
            )

        final_minority_samples = pd.concat(collected, axis=0)
        augmented_df = pd.concat([train_data, final_minority_samples], axis=0).reset_index(drop=True)
        return augmented_df

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Seed (when ``random_state`` is set) and dispatch to the matching generator.

        Raises:
            ValueError: If ``sampling_strategy`` is neither 'all' nor 'minority'.
        """
        if self.random_state is not None:
            set_global_seed(self.random_state)

        if sampling_strategy == 'all':
            synthetic_data = self.generate_all(train_data, n_to_generate=n_to_generate)
        elif sampling_strategy == 'minority':
            synthetic_data = self.generate_minority(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        else:
            raise ValueError(f"Unknown sampling strategy: '{sampling_strategy}'.")

        return synthetic_data


class TTVAETBSGenerator:
    """Wrapper around the project's ``TTVAETBS`` baseline (TTVAE with TBS)."""

    # Upper bound on sampling rounds in generate_minority(), so a model that
    # never emits the minority class cannot loop forever.
    max_sampling_attempts = 100

    def __init__(self, epochs, batch_size, sampling_strategy, verbose=True, l2scale=1e-5, loss_factor=2, latent_dim=32,
                 embedding_dim=128, nhead=8, dim_feedforward=1028, dropout=0.1, tbs_strategy='specific',
                 conditional_column=None, lambda_scale=0.4, random_state=None):
        self.epochs = epochs
        self.batch_size = batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.l2scale = l2scale
        self.loss_factor = loss_factor
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.tbs_strategy = tbs_strategy
        self.lambda_scale = lambda_scale
        self.model = None
        self.training_history = []
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _log.info(f"Using device: {self.device}")
        _log.info(f"GPU available: {torch.cuda.is_available()}")

        self.conditional_column = conditional_column

        self.train_time = None
        self.gen_time = None
        self.final_loss = None
        self.random_state = random_state

    def fit(self, df: pd.DataFrame, categorical_columns, conditional_column=None, save_path=''):
        """Train a TTVAETBS model on ``df``.

        Args:
            df: Training frame.
            categorical_columns: Names of the discrete columns.
            conditional_column: Target column, forwarded to the model as
                ``condition_column``. When ``None`` the last column of ``df``
                is still added to the discrete columns.
            save_path: Forwarded to the model's ``fit``.

        Raises:
            ValueError: If ``conditional_column`` is given but not in ``df``.
        """
        if conditional_column is not None:
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
            target_col = conditional_column
        else:
            target_col = df.columns[-1]
        _log.info(f"Categorical columns: {categorical_columns}")

        discrete_cols = list(categorical_columns or [])
        # Do not list the target twice when it is already declared categorical.
        if target_col not in discrete_cols:
            discrete_cols.append(target_col)

        if self.random_state is not None:
            set_global_seed(self.random_state)

        self.model = TTVAETBS(
            epochs=self.epochs,
            batch_size=self.batch_size,
            device=self.device,
            verbose=self.verbose,
            l2scale=self.l2scale,
            loss_factor=self.loss_factor,
            latent_dim=self.latent_dim,
            embedding_dim=self.embedding_dim,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            tbs_strategy=self.tbs_strategy,
            lambda_scale=self.lambda_scale
        )

        _log.info(f"Saving model to {save_path}")
        _log.info(f"Fitting TTVAE model with TBS...")
        self.model.fit(df, discrete_columns=discrete_cols, condition_column=conditional_column, save_path=save_path)

        self.training_history = self.model.loss_values
        self.final_loss = self.training_history['Loss'].iloc[-1] if not self.training_history.empty else None

    def load_model(self, model_path):
        """Load a previously saved model object from ``model_path``."""
        # NOTE(review): torch.load unpickles the file; only load trusted paths.
        self.model = torch.load(model_path, map_location=self.device)

    def generate_all(self, train_data: pd.DataFrame, n_to_generate=None):
        """Sample ``n_to_generate`` rows (default ``len(train_data)``); ``None`` if sampling raised."""
        try:
            if n_to_generate is None:
                n_to_generate = len(train_data)
            _log.info(f"Number of samples to generate: {n_to_generate}")
            generated_data = self.model.sample(n_samples=n_to_generate)
            return pd.DataFrame(generated_data)
        except Exception as e:
            _log.warning(f"Generation failed: {e}")
            return None

    def generate_minority(self, train_data: pd.DataFrame, conditional_column, n_to_generate=None):
        """Return ``train_data`` plus generated rows of its rarest class.

        The model is sampled unconditionally in batches and the minority rows
        are kept until ``n_to_generate`` rows are collected (default: the gap
        between the largest and the rarest class). After
        ``max_sampling_attempts`` batches whatever was collected is used.
        """
        class_counts = train_data[conditional_column].value_counts()
        minority_class = class_counts.idxmin()

        if n_to_generate is None:
            n_to_generate = class_counts.max() - class_counts[minority_class]
        n_to_generate = int(n_to_generate)
        _log.info(f"Number of samples to generate for class '{minority_class}': {n_to_generate}")

        if n_to_generate <= 0:
            _log.warning("No need to generate minority samples - already balanced.")
            return train_data.reset_index(drop=True)

        batch_size = 2 * n_to_generate  # generate more to filter from
        collected = []
        n_collected = 0  # rows gathered so far (not the number of batches)

        for _ in range(max(1, self.max_sampling_attempts)):
            synthetic_df = pd.DataFrame(self.model.sample(n_samples=batch_size))
            minority_df = synthetic_df[synthetic_df[conditional_column] == minority_class]

            needed = n_to_generate - n_collected
            if len(minority_df) > needed:
                minority_df = minority_df.sample(n=needed, random_state=42)

            collected.append(minority_df)
            n_collected += len(minority_df)
            if n_collected >= n_to_generate:
                break
        else:
            _log.warning(
                f"Only {n_collected} of {n_to_generate} requested samples for class "
                f"'{minority_class}' were generated."
            )

        final_minority_samples = pd.concat(collected, axis=0)
        augmented_df = pd.concat([train_data, final_minority_samples], axis=0).reset_index(drop=True)
        return augmented_df

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Seed (when ``random_state`` is set) and dispatch to the matching generator.

        Raises:
            ValueError: If ``sampling_strategy`` is neither 'all' nor 'minority'.
        """
        if self.random_state is not None:
            set_global_seed(self.random_state)

        if sampling_strategy == 'all':
            synthetic_data = self.generate_all(train_data, n_to_generate=n_to_generate)
        elif sampling_strategy == 'minority':
            synthetic_data = self.generate_minority(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        else:
            raise ValueError(f"Unknown sampling strategy: '{sampling_strategy}'.")

        return synthetic_data


class CTTVAEGenerator:
    """Train a ``CTTVAE`` model and draw class-conditional synthetic rows from it.

    The constructor only records hyper-parameters; the model itself is built
    in :meth:`fit` or restored with :meth:`load_model`. ``train_time`` and
    ``gen_time`` start as ``None`` and are not updated by this class.
    """

    def __init__(self, epochs, batch_size, sampling_strategy, verbose=True, l2scale=1e-5, loss_factor=2, triplet_factor=1, latent_dim=32,
                 embedding_dim=128, nhead=8, dim_feedforward=1028, dropout=0.1, triplet_margin=0.2, cond_strategy='specific',
                 conditional_column=None, random_state=None):
        """Store the hyper-parameters and select the torch device.

        Args:
            epochs, batch_size, verbose, l2scale, loss_factor, triplet_factor,
            latent_dim, embedding_dim, nhead, dim_feedforward, dropout,
            triplet_margin: Forwarded unchanged to ``CTTVAE`` in :meth:`fit`.
            sampling_strategy: Stored only; :meth:`generate` takes its own
                ``sampling_strategy`` argument.
            cond_strategy: Stored only (see the review note in :meth:`fit`).
            conditional_column: When not ``None`` it overrides the column
                passed to :meth:`fit` and is the fallback column for
                :meth:`generate_all`.
            random_state: Seed handed to ``set_global_seed`` before fitting
                and before generating; ``None`` skips seeding.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.l2scale = l2scale
        self.loss_factor = loss_factor
        self.triplet_factor = triplet_factor
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        # NOTE(review): the default of 1028 looks like it may be a typo for
        # 1024 — left untouched because it changes the trained model.
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.triplet_margin = triplet_margin
        self.model = None
        self.training_history = []
        self.random_state = random_state
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _log.info(f"Using device: {self.device}")
        _log.info(f"GPU available: {torch.cuda.is_available()}")

        self.cond_strategy = cond_strategy
        self.conditional_column = conditional_column

        self.train_time = None
        self.gen_time = None
        self.final_loss = None

    def fit(self, df: pd.DataFrame, categorical_columns, conditional_column=None, save_path=''):
        """Build a fresh ``CTTVAE`` and fit it on *df*.

        Args:
            df: Training data.
            categorical_columns: Names of the discrete feature columns.
            conditional_column: Column to condition on. When ``None`` the last
                column of *df* is used. ``self.conditional_column`` (set in
                the constructor) takes precedence over this argument.
            save_path: Forwarded to ``CTTVAE.fit``.

        Raises:
            ValueError: If the requested or the effective conditional column
                is not a column of *df*.
        """
        if conditional_column is not None:
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
            target_col = conditional_column
        else:
            target_col = df.columns[-1]
        _log.info(f"Categorical columns: {categorical_columns}")

        # The target is always treated as discrete. ``dict.fromkeys`` removes
        # duplicates (e.g. the target already listed as categorical) while
        # preserving order.
        discrete_cols = list(dict.fromkeys([*categorical_columns, target_col]))

        if self.conditional_column is not None:
            # The constructor-level column wins, so it needs the same
            # existence check the argument received above.
            conditional_column = self.conditional_column
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
        elif conditional_column is None:
            conditional_column = target_col

        _log.info(f"Conditional column: {conditional_column}")

        if self.random_state is not None:
            set_global_seed(self.random_state)

        self.model = CTTVAE(
            epochs=self.epochs,
            batch_size=self.batch_size,
            device=self.device,
            verbose=self.verbose,
            l2scale=self.l2scale,
            loss_factor=self.loss_factor,
            triplet_factor=self.triplet_factor,
            latent_dim=self.latent_dim,
            embedding_dim=self.embedding_dim,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            triplet_margin=self.triplet_margin
        )

        _log.info(f"Saving model to {save_path}")
        _log.info("Fitting CTTVAE model...")
        # NOTE(review): ``self.cond_strategy`` is stored by ``__init__`` but the
        # strategy passed here is hard-coded — confirm which one is intended.
        self.model.fit(df, discrete_columns=discrete_cols, cond_strategy='specific_column', condition_column=conditional_column, save_path=save_path)

        self.training_history = self.model.loss_values
        self.final_loss = self.training_history['Loss'].iloc[-1] if not self.training_history.empty else None

    def load_model(self, model_path):
        """Replace ``self.model`` with the object stored at *model_path*.

        NOTE(review): ``torch.load`` unpickles the file, so only load files
        from a trusted source. Recent PyTorch releases default to
        ``weights_only=True``, which may refuse a fully pickled model object —
        confirm against the installed version.
        """
        self.model = torch.load(model_path, map_location=self.device)

    @staticmethod
    def _allocate_class_counts(proportions, total):
        """Split *total* into integer per-class counts that sum exactly to *total*.

        Every class first receives the floor of its exact share; the rows that
        are still missing go to the classes with the largest fractional
        remainders (ties keep the order of *proportions*).
        """
        total = max(int(total), 0)
        exact = [(label, float(share) * total) for label, share in proportions.items()]
        counts = {label: int(value) for label, value in exact}
        shortfall = max(total - sum(counts.values()), 0)
        # ``sorted`` is stable, so equal remainders keep their original order.
        by_remainder = sorted(exact, key=lambda item: item[1] - int(item[1]), reverse=True)
        for label, _ in by_remainder[:shortfall]:
            counts[label] += 1
        return counts

    def generate_all(self, train_data: pd.DataFrame, conditional_column=None, n_to_generate=None):
        """Generate a synthetic dataset that mirrors the class mix of *train_data*.

        Args:
            train_data: Reference data; only its class proportions and length
                are used.
            conditional_column: Class column. Falls back to
                ``self.conditional_column`` when ``None``.
            n_to_generate: Total number of rows; defaults to
                ``len(train_data)``.

        Returns:
            A DataFrame containing synthetic rows only (the training rows are
            not included), or ``None`` when generation fails; the failure is
            logged as a warning.
        """
        try:
            if conditional_column is None:
                conditional_column = self.conditional_column
            if conditional_column is None or conditional_column not in train_data.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in training data.")

            if n_to_generate is None:
                n_to_generate = len(train_data)
            n_to_generate = int(n_to_generate)
            _log.info(f"Generating {n_to_generate} synthetic samples using CTTVAE...")

            if n_to_generate <= 0:
                # Nothing requested: hand back an empty frame instead of
                # asking the model for zero rows.
                return train_data.iloc[0:0].reset_index(drop=True)

            # Compute class distribution
            class_props = train_data[conditional_column].value_counts(normalize=True)
            _log.info(f"Class distribution: {class_props.to_dict()}")

            # Rounding each share independently can miss the requested total
            # (e.g. round(2.5) == 2 twice for a 50/50 split of 5 rows), so the
            # counts are allocated jointly.
            per_class = self._allocate_class_counts(class_props, n_to_generate)

            synthetic_dfs = []
            for class_value, n_class_samples in per_class.items():
                if n_class_samples <= 0:
                    continue
                _log.info(f"Generating {n_class_samples} samples for class {class_value}...")

                synthetic_samples = self.model.sample(
                    n_samples=n_class_samples,
                    condition_column=conditional_column,
                    condition_value=class_value
                )
                synthetic_dfs.append(pd.DataFrame(synthetic_samples))

            full_synthetic_df = pd.concat(synthetic_dfs, axis=0).reset_index(drop=True)

            # Log final class distribution
            if conditional_column in full_synthetic_df.columns:
                distribution = full_synthetic_df[conditional_column].value_counts(normalize=True)
                _log.info("Synthetic distribution (all):")
                _log.info(distribution)

            return full_synthetic_df

        except Exception as e:
            # Best-effort contract: callers receive None on failure. The
            # traceback is attached so the cause is not lost.
            _log.warning(f"Generation failed: {e}", exc_info=True)
            return None

    def generate_minority(self, train_data: pd.DataFrame, conditional_column, n_to_generate=None):
        """Append synthetic rows of the rarest class to *train_data*.

        Args:
            train_data: Training data to augment.
            conditional_column: Class column used to find the minority class.
            n_to_generate: Number of rows to add; defaults to the gap between
                the largest and the smallest class.

        Returns:
            A new DataFrame with the training rows followed by the synthetic
            rows, re-indexed from zero. When there is nothing to generate the
            training rows are returned (re-indexed) without calling the model.
        """
        class_counts = train_data[conditional_column].value_counts()
        minority_class = class_counts.idxmin()

        if n_to_generate is None:
            n_to_generate = class_counts.max() - class_counts[minority_class]
        # pandas yields a NumPy integer here; hand the model a built-in int.
        n_to_generate = int(n_to_generate)

        if n_to_generate <= 0:
            _log.info("Nothing to generate for the minority class; returning the training data unchanged.")
            return train_data.reset_index(drop=True)

        # Sample synthetic data conditioned on the minority class
        synthetic_samples = self.model.sample(
            n_samples=n_to_generate,
            condition_column=conditional_column,
            condition_value=minority_class
        )
        minority_samples = pd.DataFrame(synthetic_samples)

        return pd.concat([train_data, minority_samples], axis=0).reset_index(drop=True)

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Seed (when ``random_state`` is set) and run the selected strategy.

        Args:
            train_data: Training data.
            conditional_column: Class column.
            sampling_strategy: ``'all'`` returns synthetic rows only;
                ``'minority'`` returns the training rows plus synthetic
                minority rows.
            n_to_generate: Optional row count forwarded unchanged.

        Raises:
            ValueError: If *sampling_strategy* is not recognised.
        """
        if self.random_state is not None:
            set_global_seed(self.random_state)

        if sampling_strategy == 'all':
            return self.generate_all(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        if sampling_strategy == 'minority':
            return self.generate_minority(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        raise ValueError(f"Unknown sampling strategy: '{sampling_strategy}'.")


class CTTVAETBSGenerator:
    """Train a ``CTTVAETBS`` model and draw class-conditional synthetic rows from it.

    The constructor only records hyper-parameters; the model itself is built
    in :meth:`fit` or restored with :meth:`load_model`. ``train_time`` and
    ``gen_time`` start as ``None`` and are not updated by this class.
    """

    def __init__(self, epochs, batch_size, sampling_strategy, verbose=True, l2scale=1e-5, loss_factor=2, triplet_factor=1, latent_dim=32,
                 embedding_dim=128, nhead=8, dim_feedforward=1028, dropout=0.1, triplet_margin=0.2, tbs_strategy='specific',
                 conditional_column=None, lambda_scale=0.4, random_state=None):
        """Store the hyper-parameters and select the torch device.

        Args:
            epochs, batch_size, verbose, l2scale, loss_factor, triplet_factor,
            latent_dim, embedding_dim, nhead, dim_feedforward, dropout,
            triplet_margin, lambda_scale, tbs_strategy: Forwarded unchanged to
                ``CTTVAETBS`` in :meth:`fit`.
            sampling_strategy: Stored only; :meth:`generate` takes its own
                ``sampling_strategy`` argument.
            conditional_column: When not ``None`` it overrides the column
                passed to :meth:`fit` and is the fallback column for
                :meth:`generate_all`.
            random_state: Seed handed to ``set_global_seed`` before fitting
                and before generating; ``None`` skips seeding.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.sampling_strategy = sampling_strategy
        self.verbose = verbose
        self.l2scale = l2scale
        self.loss_factor = loss_factor
        self.triplet_factor = triplet_factor
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        # NOTE(review): the default of 1028 looks like it may be a typo for
        # 1024 — left untouched because it changes the trained model.
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.triplet_margin = triplet_margin
        self.lambda_scale = lambda_scale
        self.model = None
        self.training_history = []
        self.random_state = random_state
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _log.info(f"Using device: {self.device}")
        _log.info(f"GPU available: {torch.cuda.is_available()}")

        self.tbs_strategy = tbs_strategy
        self.conditional_column = conditional_column

        self.train_time = None
        self.gen_time = None
        self.final_loss = None

    def fit(self, df: pd.DataFrame, categorical_columns, conditional_column=None, save_path=''):
        """Build a fresh ``CTTVAETBS`` and fit it on *df*.

        Args:
            df: Training data.
            categorical_columns: Names of the discrete feature columns.
            conditional_column: Column to condition on. When ``None`` the last
                column of *df* is used. ``self.conditional_column`` (set in
                the constructor) takes precedence over this argument.
            save_path: Forwarded to ``CTTVAETBS.fit``.

        Raises:
            ValueError: If the requested or the effective conditional column
                is not a column of *df*.
        """
        if conditional_column is not None:
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
            target_col = conditional_column
        else:
            target_col = df.columns[-1]
        _log.info(f"Categorical columns: {categorical_columns}")

        # The target is always treated as discrete. ``dict.fromkeys`` removes
        # duplicates (e.g. the target already listed as categorical) while
        # preserving order.
        discrete_cols = list(dict.fromkeys([*categorical_columns, target_col]))

        if self.conditional_column is not None:
            # The constructor-level column wins, so it needs the same
            # existence check the argument received above.
            conditional_column = self.conditional_column
            if conditional_column not in df.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in DataFrame.")
        elif conditional_column is None:
            conditional_column = target_col

        _log.info(f"Conditional column: {conditional_column}")

        if self.random_state is not None:
            set_global_seed(self.random_state)

        self.model = CTTVAETBS(
            epochs=self.epochs,
            batch_size=self.batch_size,
            device=self.device,
            verbose=self.verbose,
            l2scale=self.l2scale,
            loss_factor=self.loss_factor,
            triplet_factor=self.triplet_factor,
            latent_dim=self.latent_dim,
            embedding_dim=self.embedding_dim,
            nhead=self.nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            triplet_margin=self.triplet_margin,
            lambda_scale=self.lambda_scale,
            tbs_strategy=self.tbs_strategy
        )

        _log.info(f"Saving model to {save_path}")
        _log.info("Fitting CTTVAE model with TBS...")
        self.model.fit(df, discrete_columns=discrete_cols, condition_column=conditional_column, save_path=save_path)

        self.training_history = self.model.loss_values
        self.final_loss = self.training_history['Loss'].iloc[-1] if not self.training_history.empty else None

    def load_model(self, model_path):
        """Replace ``self.model`` with the object stored at *model_path*.

        NOTE(review): ``torch.load`` unpickles the file, so only load files
        from a trusted source. Recent PyTorch releases default to
        ``weights_only=True``, which may refuse a fully pickled model object —
        confirm against the installed version.
        """
        self.model = torch.load(model_path, map_location=self.device)

    @staticmethod
    def _allocate_class_counts(proportions, total):
        """Split *total* into integer per-class counts that sum exactly to *total*.

        Every class first receives the floor of its exact share; the rows that
        are still missing go to the classes with the largest fractional
        remainders (ties keep the order of *proportions*).
        """
        total = max(int(total), 0)
        exact = [(label, float(share) * total) for label, share in proportions.items()]
        counts = {label: int(value) for label, value in exact}
        shortfall = max(total - sum(counts.values()), 0)
        # ``sorted`` is stable, so equal remainders keep their original order.
        by_remainder = sorted(exact, key=lambda item: item[1] - int(item[1]), reverse=True)
        for label, _ in by_remainder[:shortfall]:
            counts[label] += 1
        return counts

    def generate_all(self, train_data: pd.DataFrame, conditional_column=None, n_to_generate=None):
        """Generate a synthetic dataset that mirrors the class mix of *train_data*.

        Args:
            train_data: Reference data; only its class proportions and length
                are used.
            conditional_column: Class column. Falls back to
                ``self.conditional_column`` when ``None``.
            n_to_generate: Total number of rows; defaults to
                ``len(train_data)``.

        Returns:
            A DataFrame containing synthetic rows only (the training rows are
            not included), or ``None`` when generation fails; the failure is
            logged as a warning.
        """
        try:
            if conditional_column is None:
                conditional_column = self.conditional_column
            if conditional_column is None or conditional_column not in train_data.columns:
                raise ValueError(f"Conditional column '{conditional_column}' not found in training data.")

            if n_to_generate is None:
                n_to_generate = len(train_data)
            n_to_generate = int(n_to_generate)
            _log.info(f"Generating {n_to_generate} synthetic samples using CTTVAE...")

            if n_to_generate <= 0:
                # Nothing requested: hand back an empty frame instead of
                # asking the model for zero rows.
                return train_data.iloc[0:0].reset_index(drop=True)

            # Compute class distribution
            class_props = train_data[conditional_column].value_counts(normalize=True)
            _log.info(f"Class distribution: {class_props.to_dict()}")

            # Rounding each share independently can miss the requested total
            # (e.g. round(2.5) == 2 twice for a 50/50 split of 5 rows), so the
            # counts are allocated jointly.
            per_class = self._allocate_class_counts(class_props, n_to_generate)

            synthetic_dfs = []
            for class_value, n_class_samples in per_class.items():
                if n_class_samples <= 0:
                    continue
                _log.info(f"Generating {n_class_samples} samples for class {class_value}...")

                synthetic_samples = self.model.sample(
                    n_samples=n_class_samples,
                    condition_column=conditional_column,
                    condition_value=class_value
                )
                synthetic_dfs.append(pd.DataFrame(synthetic_samples))

            full_synthetic_df = pd.concat(synthetic_dfs, axis=0).reset_index(drop=True)

            # Log final class distribution
            if conditional_column in full_synthetic_df.columns:
                distribution = full_synthetic_df[conditional_column].value_counts(normalize=True)
                _log.info("Synthetic distribution (all):")
                _log.info(distribution)

            return full_synthetic_df

        except Exception as e:
            # Best-effort contract: callers receive None on failure. The
            # traceback is attached so the cause is not lost.
            _log.warning(f"Generation failed: {e}", exc_info=True)
            return None

    def generate_minority(self, train_data: pd.DataFrame, conditional_column, n_to_generate=None):
        """Append synthetic rows of the rarest class to *train_data*.

        Args:
            train_data: Training data to augment.
            conditional_column: Class column used to find the minority class.
            n_to_generate: Number of rows to add; defaults to the gap between
                the largest and the smallest class.

        Returns:
            A new DataFrame with the training rows followed by the synthetic
            rows, re-indexed from zero. When there is nothing to generate the
            training rows are returned (re-indexed) without calling the model.
        """
        class_counts = train_data[conditional_column].value_counts()
        minority_class = class_counts.idxmin()

        if n_to_generate is None:
            n_to_generate = class_counts.max() - class_counts[minority_class]
        # pandas yields a NumPy integer here; hand the model a built-in int.
        n_to_generate = int(n_to_generate)

        if n_to_generate <= 0:
            _log.info("Nothing to generate for the minority class; returning the training data unchanged.")
            return train_data.reset_index(drop=True)

        # Sample synthetic data conditioned on the minority class
        synthetic_samples = self.model.sample(
            n_samples=n_to_generate,
            condition_column=conditional_column,
            condition_value=minority_class
        )
        minority_samples = pd.DataFrame(synthetic_samples)

        return pd.concat([train_data, minority_samples], axis=0).reset_index(drop=True)

    def generate(self, train_data: pd.DataFrame, conditional_column, sampling_strategy='all', n_to_generate=None):
        """Seed (when ``random_state`` is set) and run the selected strategy.

        Args:
            train_data: Training data.
            conditional_column: Class column.
            sampling_strategy: ``'all'`` returns synthetic rows only;
                ``'minority'`` returns the training rows plus synthetic
                minority rows.
            n_to_generate: Optional row count forwarded unchanged.

        Raises:
            ValueError: If *sampling_strategy* is not recognised.
        """
        if self.random_state is not None:
            set_global_seed(self.random_state)

        if sampling_strategy == 'all':
            return self.generate_all(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        if sampling_strategy == 'minority':
            return self.generate_minority(train_data, conditional_column=conditional_column, n_to_generate=n_to_generate)
        raise ValueError(f"Unknown sampling strategy: '{sampling_strategy}'.")
    