import numpy as np
import torch as th

from ctgan.synthesizers.ctgan import CTGANSynthesizer


class CtganOversampler:
    """Oversample the label-1 class with synthetic rows drawn from a CTGAN.

    Features and the label are fitted jointly by a ``CTGANSynthesizer`` (the
    label is appended as an extra discrete column). Rows are then sampled from
    the fitted model and only those whose generated label equals ``1`` are kept,
    until the label-1 class is as large as the rest of the data.
    """

    def __init__(self, emb_dim, epochs=150, batch_size=32, pac=8):
        """
        Args:
            emb_dim: ``embedding_dim`` passed to ``CTGANSynthesizer``.
            epochs: ``epochs`` passed to ``CTGANSynthesizer``.
            batch_size: ``batch_size`` passed to ``CTGANSynthesizer``.
            pac: ``pac`` passed to ``CTGANSynthesizer``.
        """
        self.emb_dim = emb_dim
        self.epochs = epochs
        self.batch_size = batch_size
        self.pac = pac

    def fit_resample(self, x_train, y_train, categorical_features):
        """Fit a CTGAN on ``(x_train, y_train)`` and append synthetic label-1 rows.

        Args:
            x_train: 2-D torch tensor of features, ``(n_samples, n_features)``.
            y_train: 1-D torch tensor of labels. Assumes binary 0/1 labels:
                every nonzero label is counted as the class to oversample, and
                only generated rows whose label equals ``1`` are kept.
            categorical_features: Column indices of ``x_train`` that are
                discrete, or a falsy value when there are none.

        Returns:
            ``(x_bal, y_bal)``: the original rows followed by
            ``count(y == 0) - count(y != 0)`` synthetic rows labeled ``1``.
            Synthetic features use ``x_train``'s dtype/device. If that count is
            not positive, copies of the inputs are returned and no CTGAN is
            trained.
        """
        num_features = x_train.shape[1]
        num_positive = int(th.count_nonzero(y_train))
        # Rows needed so the positive class matches the other one: (n - k) - k.
        num_required_minority = y_train.shape[0] - 2 * num_positive
        if num_required_minority <= 0:
            # Nothing to add. Sampling 0 (or a negative number of) rows would
            # otherwise never terminate / fail inside CTGAN.
            return x_train.clone(), y_train.clone()

        ctgan = CTGANSynthesizer(epochs=self.epochs, batch_size=self.batch_size, pac=self.pac,
                                 embedding_dim=self.emb_dim,
                                 generator_dim=(256, 128), discriminator_dim=(64, 128, 256))
        # Label goes in as the last column so CTGAN models features and label jointly.
        # detach/cpu makes the numpy conversion work for GPU or grad-tracking tensors.
        data_w_label = np.concatenate((x_train.detach().cpu().numpy(),
                                       y_train[:, None].detach().cpu().numpy()), 1)
        discrete_columns = list(categorical_features) if categorical_features else []
        discrete_columns.append(num_features)  # the appended label column
        ctgan.fit(data_w_label, discrete_columns=discrete_columns)

        minority_sampled = self._sample_minority(ctgan, num_required_minority, num_features,
                                                 x_train.dtype, x_train.device)
        y_gen = th.ones(minority_sampled.shape[0], dtype=y_train.dtype, device=y_train.device)
        x_bal = th.cat([x_train, minority_sampled])
        y_bal = th.cat([y_train, y_gen])
        return (x_bal, y_bal)

    @staticmethod
    def _sample_minority(ctgan, num_required, num_features, dtype, device):
        """Sample from ``ctgan`` until ``num_required`` label-1 rows are collected.

        Returns only the feature columns (label column dropped) as a
        ``(num_required, num_features)`` tensor with the given dtype/device.
        """
        chunks = []
        collected = 0
        # Rejection sampling: keep only generated rows whose label column is 1.
        # NOTE(review): this never terminates if the CTGAN stops emitting label 1.
        while collected < num_required:
            sampled = th.as_tensor(ctgan.sample(num_required))
            positives = sampled[sampled[:, num_features] == 1][:, :num_features]
            chunks.append(positives)
            collected += positives.shape[0]
        return th.cat(chunks)[:num_required].to(dtype=dtype, device=device)

