import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from math import log
from abc import abstractmethod
import pandas as pd
import numpy as np

from torch.optim.lr_scheduler import MultiplicativeLR
from torch.utils.data import TensorDataset
import tabular_deep_smote.data_transformer as data_transformer
import tabular_deep_smote.metric_learn_losses as losses
import tabular_deep_smote.smote_variations as smote_variations


#######################################################################
## ENCODER
#######################################################################

class Skip(nn.Module):
    """
    Linear layer with a skip input: the layer input is concatenated with a second
    ("skip") tensor along dim=1 before the linear projection.
    """
    def __init__(self, i, o, skip_dim, activation=True):
        """
        :param i: width of the regular layer input
        :param o: width of the layer output
        :param skip_dim: width of the skip tensor concatenated to the input
        :param activation: apply LeakyReLU(0.2) to the output when True
        """
        super().__init__()
        self.fc = nn.Linear(i+skip_dim, o)
        self.act = nn.LeakyReLU(0.2)
        self.activation = activation

    def forward(self, input_, skip):
        """
        :param input_: regular layer input
        :param skip: skip tensor, concatenated to input_ along dim=1
        :return: linear projection of the concatenation (activated if configured)
        """
        joined = torch.cat([input_, skip], dim=1).type(torch.float32)
        projected = self.fc(joined)
        return self.act(projected) if self.activation else projected


class Encoder(nn.Module):
    """
    Fully-connected encoder: input_dim -> hidden_dims... -> latent_dim.

    Every layer except the last is followed by LeakyReLU(0.2); the last layer is linear.
    When skip_connect is True, every layer after the first is a Skip block that is also
    fed the original encoder input.
    """
    def __init__(self, input_dim, hidden_dims=None, latent_dim=None, skip_connect=False):
        """
        :param input_dim: number of input features (int)
        :param hidden_dims: list of int hidden layer widths. Default = [256, 128]
        :param latent_dim: width of the encoding (int). Default = int(log(input_dim, 2))
        :param skip_connect: concatenate the original input to the input of every layer after the first
        """
        super().__init__()

        self.skip_connect = skip_connect

        assert isinstance(input_dim, int)

        if hidden_dims is None:
            hidden_dims = [256, 128]
        else:
            assert isinstance(hidden_dims, list)
            assert all(isinstance(dim, int) for dim in hidden_dims)

        if latent_dim is None:
            # Use the constructor argument: this module never defines an `input_dim`
            # attribute, so reading `self.input_dim` here raised AttributeError.
            latent_dim = int(log(input_dim, 2))
        else:
            assert isinstance(latent_dim, int)

        dim_list = [input_dim] + hidden_dims + [latent_dim]

        seq = []
        last = len(dim_list) - 1
        for i in range(1, len(dim_list)):
            is_output = i == last
            if skip_connect and i != 1:
                # Skip blocks carry their own (optional) activation
                seq.append(Skip(dim_list[i - 1], dim_list[i], input_dim, activation=not is_output))
            else:
                seq.append(nn.Linear(dim_list[i - 1], dim_list[i]))
                if not is_output:
                    seq.append(nn.LeakyReLU(0.2))
        self.layers = nn.Sequential(*seq)

    def forward(self, x):
        """
        :param x: input tensor (cast to float32 before entering the network)
        :return: encoding produced by the layer stack
        """
        if not self.skip_connect:
            return self.layers(x.type(torch.float32))

        out = x.type(torch.float32)
        for layer in self.layers:
            # Skip blocks additionally receive the original input
            out = layer(out, x) if isinstance(layer, Skip) else layer(out)
        return out

#######################################################################
## DECODER
#######################################################################

class Decoder(nn.Module):
    """
    Fully-connected decoder: latent_dim -> hidden_dims... -> input_dim.

    Every layer except the last is followed by LeakyReLU(0.2); the last layer is linear.
    When skip_connect is True, every layer after the first is a Skip block that is also
    fed the original decoder input (the encoding).
    """
    def __init__(self, input_dim, hidden_dims=None, latent_dim=None, skip_connect=False):
        """
        :param input_dim: number of features of the reconstructed output (int)
        :param hidden_dims: list of int hidden layer widths. Default = [64, 128, 256]
        :param latent_dim: width of the encoding fed to the decoder (int). Default = int(log(input_dim, 2))
        :param skip_connect: concatenate the encoding to the input of every layer after the first
        """
        super(Decoder, self).__init__()

        self.skip_connect = skip_connect

        assert isinstance(input_dim, int)

        if latent_dim is None:
            # Use the constructor argument: this module never defines an `input_dim`
            # attribute, so reading `self.input_dim` here raised AttributeError.
            latent_dim = int(log(input_dim, 2))
        else:
            assert isinstance(latent_dim, int)

        if hidden_dims is None:
            hidden_dims = [64, 128, 256]
        else:
            assert isinstance(hidden_dims, list)
            assert all(isinstance(dim, int) for dim in hidden_dims)

        dim_list = [latent_dim] + hidden_dims + [input_dim]

        seq = []
        last = len(dim_list) - 1
        for i in range(1, len(dim_list)):
            is_output = i == last
            if skip_connect and i != 1:
                # Skip blocks carry their own (optional) activation
                seq.append(Skip(dim_list[i - 1], dim_list[i], latent_dim, activation=not is_output))
            else:
                seq.append(nn.Linear(dim_list[i - 1], dim_list[i]))
                if not is_output:
                    seq.append(nn.LeakyReLU(0.2))
        self.layers = nn.Sequential(*seq)

    def forward(self, x):
        """
        :param x: encoding tensor (cast to float32 before entering the network)
        :return: decoded tensor produced by the layer stack
        """
        if not self.skip_connect:
            return self.layers(x.type(torch.float32))

        out = x.type(torch.float32)
        for layer in self.layers:
            # Skip blocks additionally receive the original encoding
            out = layer(out, x) if isinstance(layer, Skip) else layer(out)
        return out


#######################################################################
## Aux Classes
#######################################################################


class SwapNoise(object):
    """
    Swap-noise corruption for tabular data: each cell is, with probability swap_prob,
    replaced by the value found in the same column of a randomly permuted copy of the batch.
    """
    def __init__(self, swap_prob):
        self.swap_prob = swap_prob

    def apply(self, x):
        """
        :param x: batch to corrupt (rows are permuted along dim 0 to obtain replacement values)
        :return: (corrupted_x, swap_map) where swap_map holds 1 for every swapped cell and 0 otherwise
        """
        # Per-cell Bernoulli draw deciding which cells get swapped
        cell_probs = self.swap_prob * torch.ones((x.shape)).to(x.device)
        swap_map = torch.bernoulli(cell_probs)
        # Row-shuffled copy of the batch provides the replacement values
        row_order = torch.randperm(x.shape[0])
        donor_rows = x[row_order]
        corrupted_x = torch.where(swap_map == 1, donor_rows, x)
        return corrupted_x, swap_map


class OversampleResult:
    """
    Plain container for the values produced by the oversample method:
    the combined data/labels, the generated samples and the interpolations.
    """
    def __init__(self, x_all, y_all, x_gen, interpolation):
        self.x_all, self.y_all = x_all, y_all
        self.x_gen, self.interpolation = x_gen, interpolation


class TrainResults:
    """
    Plain container for the values produced by the fit method:
    the index of the best epoch and the losses recorded at that epoch.
    """
    def __init__(self, best_epoch, best_losses):
        self.best_epoch, self.best_losses = best_epoch, best_losses

#######################################################################
## Aux Functions
#######################################################################


def get_num_maj_min(y):
    """
    Count the samples of each class.
    :param y: tensor of binary labels (1 = minority ; 0 = majority) - the minority count is the sum of the labels
    :return: (num_majority, num_minority)
    """
    total = y.shape[0]
    minority_count = int(y.sum(dim=0))
    majority_count = total - minority_count
    return majority_count, minority_count


def get_num_samples(num_maj, num_min, oversample_ratio):
    """
    Compute how many minority samples must be generated.
    :param num_maj: #majority samples
    :param num_min: #minority samples
    :param oversample_ratio: desired oversample ratio in (0, 1], such that
        #minority + #generated = oversample_ratio * #majority.
        None requests full balancing (#minority + #generated = #majority).
    :return: number of minority samples to be generated (always > 0)
    :raises AssertionError: if oversample_ratio is outside (0, 1] or no sample needs to be generated
    """
    if oversample_ratio is None:
        # Previously unreachable: the range assertion ran first and rejected every falsy
        # ratio (None raised TypeError), so the full-balance branch could never execute.
        num_samples = num_maj - num_min
    else:
        assert 0 < oversample_ratio <= 1, "oversample_ratio must be in (0, 1]"
        num_samples = int(num_maj * oversample_ratio) - num_min
    assert num_samples > 0, "no minority samples need to be generated"
    return num_samples


def df_to_tensor(df):
    """
    :param df: pandas DataFrame
    :return: float32 tensor built from the DataFrame values
    """
    values = df.values
    tensor = torch.from_numpy(values)
    return tensor.float()


def convert_to_tensor_if_needed(x, squeeze=False):
    """
    Return x as a torch tensor.
    :param x: pandas DataFrame (converted via df_to_tensor), numpy array (converted via
        torch.from_numpy) or torch tensor (used as is)
    :param squeeze: when True, squeeze() is applied to the resulting tensor
    :return: tensor version of x
    :raises Exception: if x is of any other type
    """
    if not torch.is_tensor(x):
        if isinstance(x, pd.DataFrame):
            x = df_to_tensor(x)
        elif isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        else:
            raise Exception("Type not supported")
    if squeeze:
        return x.squeeze()
    return x

#######################################################################
## Base Model (AE + VAE)
#######################################################################

class BaseModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    @abstractmethod
    def forward(self, *inputs):
        raise NotImplementedError

    @abstractmethod
    def loss(self, *inputs):
        raise NotImplementedError

    @abstractmethod
    def _encode(self, *inputs):
        """
        Encoding method called at NN training. Namely:
        1. Needs to include gradient calculations (do not use torch.no_grad())
        2. Should not include preprocessing prior to the NN
        3. Assumes input tensors are on the model's device
        :param inputs: input tensors to the NN Encoder
        :return: output tensors (encoded tensors - outputs of the NN)
        """
        raise NotImplementedError

    @abstractmethod
    def _decode(self, *inputs):
        """
        Decoding method called at NN training. Namely:
        1. Needs to include gradient calculations (do not use torch.no_grad())
        2. Should not include post-processing after the NN
        3. Assumes input tensors are on the model's device
        :param inputs: input tensors to the NN Decoder
        :return: output tensors (decoded tensors - outputs of the NN prior to post-processing)
        """
        raise NotImplementedError

    def encode(self, x):
        """
        Full encoding scheme - preprocessing + encoding via NN
        :param x: input to be encoded
        :return: encodings (returned on same device as input x)
        """
        orig_device = x.device
        if self.transform_data:
            if orig_device.type != 'cpu':
                x = x.cpu()
            x = torch.from_numpy(self.transformer.transform(x))
        x = x.to(self.device)
        with torch.no_grad():
            encoding = self._encode(x)
        return encoding.to(orig_device)

    def decode(self, z):
        """
        Full decoding scheme - decoding via NN + post-processing
        :param z: encoding to be decoded
        :return: decodings (returned on same device as input z)
        """
        orig_device = z.device
        z = z.to(self.device)
        with torch.no_grad():
            x_hat = self._decode(z)
        if self.transform_data:
            if self.device != 'cpu':
                x_hat = x_hat.cpu()
            x_hat = torch.from_numpy(self.transformer.inverse_transform(x_hat.numpy()))
        return x_hat.to(orig_device)

    def step(self):
        self.enc_dec_optim.step()
        self.enc_dec_scheduler.step()
        for optimizer in self.loss_optimizers:
            optimizer.step()

    def rec_loss(self, x_batch, x_hat, y_batch):
        if self.transform_data:
            output_info = self.transformer.output_info_list
            st = 0
            loss = []
            for column_info in output_info:
                for span_info in column_info:
                    if span_info.activation_fn is None:
                        ed = st + span_info.dim
                        eq = x_batch[:, st] - x_hat[:, st]
                        loss.append(eq**2)
                        st = ed
                    elif span_info.activation_fn == 'softmax':
                        ed = st + span_info.dim
                        loss.append(cross_entropy(x_hat[:, st:ed], torch.argmax(x_batch[:, st:ed], dim=-1), reduction='none'))
                        st = ed
                    else:
                        raise Exception('activation_fn not supported')
            loss = torch.mean(torch.stack(loss), dim=0)
            if self.rec_reweight_loss:
                return losses.weighted_loss(loss, y_batch, self.reweight_factors)
            else:
                return torch.mean(loss)
        else:
            if self.rec_reweight_loss:
                return losses.weighted_mse_loss(x_batch, x_hat, y_batch, self.reweight_factors)
            else:
                return nn.functional.mse_loss(x_batch, x_hat)

    def fit(self, train_data, validation_data, best_checkpoint_save_path):
        """
        Build the encoder/decoder for the given data and train them, keeping the best checkpoint.

        :param train_data: (x,y) tuple used for training (DataFrame / numpy array / tensor)
        :param validation_data: (x,y) tuple used for validation, or None to select the best
            checkpoint on the training loss. The tuple is wrapped in a TensorDataset as is.
        :param best_checkpoint_save_path: file name used for storing the best model parameters
        :return: TrainResults holding the best epoch and the losses recorded at that epoch
        """
        # Data ########################################################################################

        x_train, y_train = train_data
        x_train = convert_to_tensor_if_needed(x_train)
        y_train = convert_to_tensor_if_needed(y_train, squeeze=True)
        assert(x_train.shape[0] == y_train.shape[0])

        if self.transform_data:
            self.transformer = data_transformer.DataTransformer()
            discrete_columns = self.categorical_features if self.categorical_features is not None else []
            self.transformer.fit(x_train,
                                 discrete_columns=discrete_columns,
                                 transform_numericals=self.mode_specific_normalization)
            x_train = torch.from_numpy(self.transformer.transform(x_train))

        x_train = x_train.type(torch.float32)
        y_train = y_train.type(torch.float32)
        x_min = x_train[y_train == 1]
        x_maj = x_train[y_train == 0]

        if self.train_on == 'all':
            if self.balance_b4_train:
                from imblearn.over_sampling import RandomOverSampler
                x_train_balanced, y_train_balanced = RandomOverSampler(random_state=42).fit_resample(x_train, y_train)
                tabular_bal = TensorDataset(torch.from_numpy(x_train_balanced), torch.from_numpy(y_train_balanced))
            else:
                tabular_bal = TensorDataset(x_train, y_train)
        elif self.train_on == 'min':
            tabular_bal = TensorDataset(x_min, torch.ones(x_min.shape[0]))
        else:
            raise Exception("Argument train_on is not supported")

        train_loader = torch.utils.data.DataLoader(tabular_bal,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=True)

        if validation_data is not None:
            x_val, y_val = validation_data
            if self.transform_data:
                x_val = torch.from_numpy(self.transformer.transform(x_val))
            val_dataset = TensorDataset(x_val, y_val)
            validation_loader = torch.utils.data.DataLoader(val_dataset,
                                                            batch_size=self.batch_size,
                                                            shuffle=True,
                                                            num_workers=0,
                                                            drop_last=True)

        if self.reweight_loss or self.rec_reweight_loss:
            self.init_reweight_factors(num_min=x_min.shape[0], num_maj=x_maj.shape[0])

        # Arch ################################################################################
        # Model Arch is data-dependent. Hence the models are defined at fit
        # (once the data itself is available)

        def resolve_hidden_dims(dims):
            # Accept both plain ints and '<int>x' strings (= int * num_features).
            # Indexing every entry with dim[-1] raised TypeError for ints, including the
            # already-resolved ints left on the instance by a previous fit call.
            resolved = []
            for dim in dims:
                if isinstance(dim, str) and dim.endswith('x'):
                    resolved.append(int(dim[:-1]) * self.input_dim)
                else:
                    resolved.append(int(dim))
            return resolved

        self.input_dim = x_train.shape[1]
        self.latent_dim = int(x_train.shape[1] * self.latent_dim_ratio)
        self.enc_hidden_dims = resolve_hidden_dims(self.enc_hidden_dims)
        self.dec_hidden_dims = resolve_hidden_dims(self.dec_hidden_dims)

        self.encoder = Encoder(self.input_dim, self.enc_hidden_dims, self.latent_dim).to(self.device)
        self.decoder = Decoder(self.input_dim, self.dec_hidden_dims, self.latent_dim).to(self.device)
        self.enc_dec_optim = torch.optim.AdamW(list(self.encoder.parameters()) + list(self.decoder.parameters()),
                                               lr=self.lr)
        # The scheduler is advanced inside self.step(), i.e. once per batch
        self.enc_dec_scheduler = MultiplicativeLR(self.enc_dec_optim, lr_lambda=(lambda epoch: self.lr_decay))
        self.declare_losses()

        # Train Loop ###########################################################################

        # Per-epoch loss histories (collected, but not consumed inside this method)
        train_loss_history = {loss: [] for loss in self.loss_list}
        val_loss_history = {loss: [] for loss in self.loss_list}

        best_train_loss = None
        best_validation_loss = None
        best_losses = None
        iter_no_improve = 0
        early_stop_train = self.early_stop_train
        early_stop_val = self.early_stop_val
        best_epoch = 0
        epoch = 0

        swap_noising = SwapNoise(self.swap_prob)

        # NOTE(review): `epoch <= self.epochs` runs self.epochs + 1 epochs - confirm this is intended
        while self.early_stop_no_limit or epoch <= self.epochs:
            self.reset_losses()
            self.train()
            for x_batch, y_batch in train_loader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                # noise (only the encoder input is corrupted; the loss target stays clean)
                if self.swap_prob != 0:
                    x_batch_to_encode, _ = swap_noising.apply(x_batch)
                else:
                    x_batch_to_encode = x_batch
                # zero gradients for each batch
                self.zero_grad()
                # Forward
                model_outputs = self(x_batch_to_encode)
                # Calc Loss
                comb_loss = self.loss(x_batch, y_batch, x_train, y_train, model_outputs)
                # Backwards
                comb_loss.backward()
                # Steps
                self.step()

            # update epoch loss (average)
            epoch_losses = {loss: val / len(train_loader) for loss, val in self.losses.items()}
            epoch_losses_str = [f'{loss}:{val:.6f}' for loss, val in epoch_losses.items()]
            if self.verbose:
                print(f'Epoch: {epoch} {epoch_losses_str}')
            for loss in self.loss_list:
                train_loss_history[loss].append(epoch_losses[loss])

            if validation_data is not None:
                self.reset_losses()
                self.eval()
                with torch.no_grad():
                    for x_batch, y_batch in validation_loader:
                        x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                        model_outputs = self(x_batch)
                        self.loss(x_batch, y_batch, x_train, y_train, model_outputs)
                    # update epoch loss
                    val_epoch_losses = {loss: val / len(validation_loader) for loss, val in self.losses.items()}
                    epoch_losses_str = [f'{loss}:{val:.6f}' for loss, val in val_epoch_losses.items()]
                    if self.verbose:
                        print(f'Valid: {epoch} {epoch_losses_str}')
                    for loss in self.loss_list:
                        val_loss_history[loss].append(val_epoch_losses[loss])

                    # checkpoint selection criterion: validation loss alone, or validation + train
                    if self.early_stop_val_type == 'only_val':
                        val_loss = val_epoch_losses['combined_loss']
                    else:
                        val_loss = val_epoch_losses['combined_loss'] + epoch_losses['combined_loss']
                    if best_validation_loss is None or val_loss < best_validation_loss:
                        best_validation_loss = val_loss
                        iter_no_improve = 0
                        if self.verbose:
                            print('Saving..')
                        self.save(best_checkpoint_save_path, self.dataset)
                        best_epoch = epoch
                        best_losses = val_epoch_losses
                    else:
                        iter_no_improve += 1
                    if iter_no_improve > early_stop_val:
                        break

            else:
                if best_train_loss is None or epoch_losses['combined_loss'] < best_train_loss:
                    if self.verbose:
                        print('Saving..')
                    self.save(best_checkpoint_save_path, self.dataset)
                    best_epoch = epoch
                    best_losses = epoch_losses
                    best_train_loss = epoch_losses['combined_loss']
                    iter_no_improve = 0
                else:
                    iter_no_improve += 1
                # without validation data, early stopping applies only in the no-limit mode
                if self.early_stop_no_limit and iter_no_improve > early_stop_train:
                    break
            epoch += 1

        # Load the best checkpoint
        self.load_state_dict(torch.load(best_checkpoint_save_path))
        return TrainResults(best_epoch, best_losses)

    def oversample(self, data, new_minority_save_path, oversample_ratio=1):
        """
        Oversample the dataset (x,y): interpolate minority samples in the latent space,
        decode them, optionally filter them with a baseline classifier, and append them to the data.
        :param data: (x,y) tuple. y are the labels (0 = majority; 1 = minority)
        :param new_minority_save_path: store the newly generated minority samples in new_minority_save_path. The path should include the file name.
        :param oversample_ratio: #minority + #new_minority = oversample_ratio * #majority.
        :return: OversampleResult(x_all, y_all, x_gen, interpolations) - the original data followed by
            the generated samples, the generated samples alone, and the latent interpolations of the
            last generation round
        """
        x, y = data
        x = convert_to_tensor_if_needed(x)
        y = convert_to_tensor_if_needed(y, squeeze=True)
        # balanced baseline classifier is used
        if self.filter_margin:
            x_min = x[y == 1]
            if self.classifier_type is None:
                raise Exception('classifier_type is required')
            from experiments import experiment_utils
            filter_classifier = experiment_utils.create_classifier(type=self.classifier_type, seed=42, class_weight='balanced')
            filter_classifier.fit(x.numpy(), y.numpy())

        num_maj, num_min = get_num_maj_min(y)
        num_samples = get_num_samples(num_maj, num_min, oversample_ratio)
        x_enc = self.encode(x)
        remaining_num_sample = num_samples
        new_samples = torch.zeros((0, x.shape[1]))
        i = 0
        relax_factor = 1
        while remaining_num_sample > 0:
            i += 1
            # every 10th round the acceptance threshold is relaxed by 5%
            relax_factor = 0.95 * relax_factor if i % 10 == 0 else relax_factor
            # Generate (at least 100 candidates per round)
            interpolations, base_indices, neighbor_indices, ratio = smote_variations.deep_smote_oversample(x, x_enc, y,
                                                                                                           self.smote_algo_type,
                                                                                                           self.m_neighbors,
                                                                                                           self.k_neighbors,
                                                                                                           self.knn_algorithm,
                                                                                                           num_samples=max(remaining_num_sample, 100),
                                                                                                           importance_oversampling=self.importance_oversampling,
                                                                                                           )
            new_decoded = self.decode(interpolations)
            # Filter
            if self.filter_margin:
                base_scores = filter_classifier.predict_proba(x_min[base_indices].numpy())[:, 1]
                neighbor_scores = filter_classifier.predict_proba(x_min[neighbor_indices].numpy())[:, 1]
                syn_scores = filter_classifier.predict_proba(new_decoded.numpy())[:, 1]
                # threshold: ratio-weighted geometric mean of the two parents' scores,
                # divided by the filter margin and relaxed over time
                thr = (neighbor_scores**ratio.reshape(-1) * base_scores**(1-ratio.reshape(-1))) / self.filter_margin
                thr = thr * relax_factor
                new_decoded = new_decoded[syn_scores > thr]
            # Add: keep at most the number of samples still missing, and count only what was
            # actually kept. (Slicing to `remaining - 1` while subtracting the full candidate
            # count returned fewer samples than requested.)
            accepted = new_decoded[:remaining_num_sample]
            new_samples = torch.concat((new_samples, accepted), dim=0)
            remaining_num_sample -= accepted.shape[0]

        x_gen = new_samples.cpu()
        y_gen = torch.ones(x_gen.shape[0], dtype=y.dtype)
        torch.save([x_gen, y_gen], new_minority_save_path)
        x_all = torch.cat([x] + [x_gen])
        y_all = torch.cat([y] + [y_gen])
        return OversampleResult(x_all, y_all, x_gen, interpolations)


#######################################################################
## AE
#######################################################################

class TDSMOTE(BaseModel):

    """
    Tabular Deep SMOTE based on an autoencoder (AE).

    The constructor only records hyper-parameters.  ``encoder``, ``decoder``,
    ``input_dim``, ``latent_dim`` and ``reweight_factors`` are initialised to
    ``None`` and are expected to be assigned later (at fit time).

    NOTE: every constructor argument is required - the signature declares no
    defaults.  The "typical" values quoted below are carried over from the
    original documentation and are presumably the defaults applied by the
    calling code; verify them there.

    Args:

        dataset_name (string):
            Dataset name for temporary files (e.g. model checkpoints).
            Stored as ``self.dataset``.

        categorical_features (tuple or list of ints, or None):
            Indices of categorical features. Typical: None.

        lambda_metric_learn (float):
            Factor of the metric-learn loss relative to the reconstruction loss.
            Typical: 0.5.

        metric_learn_type (string):
            Valid strings - according to metric_learn_losses.py.
            Typical: 'normalized_softmax'.

        label_smoothing (float):
            Label smoothing value (forwarded to the metric-learn loss).

        mode_specific_normalization (boolean):
            Represent numeric features using Mode Specific Normalization.
            Typical: False.

        reweight_loss (boolean):
            Reweight the loss target by class (forwarded to the metric-learn loss).
            Typical: False.

        rec_reweight_loss:
            Stored on the instance; not used inside this class.
            Presumably toggles class re-weighting of the reconstruction loss -
            TODO confirm against BaseModel.

        balance_b4_train (boolean):
            Random oversample the minority class before train. Typical: False.

        latent_dim_ratio (float):
            Latent dimension of the AE is latent_dim_ratio * num_features.
            Typical: 0.75.

        enc_hidden_dims (tuple or list of ints or formatted strings):
            Specifies the hidden layer dimensions of the Encoder.
            A valid value is either int or "(int)x" (e.g. '16x') which denotes
            int*num_features. Typical: ['32x', '16x'].

        dec_hidden_dims (tuple or list of ints or formatted strings):
            Specifies the hidden layer dimensions of the Decoder.
            A valid value is either int or "(int)x" (e.g. '16x') which denotes
            int*num_features. Typical: ['8x', '16x', '32x'].

        swap_prob (float):
            Swap noise probability - the probability to replace a feature value
            with a value of the same feature taken from a different sample.
            Typical: 0.0.

        batch_size (int):
            Batch size. Typical: 8.

        lr (float):
            Learning rate. Typical: 0.001.

        lr_decay (float):
            Learning rate decay. Typical: 0.9999.

        epochs (int):
            Number of training epochs (unless stopped otherwise, e.g. early stop).
            Typical: 100.

        train_on:
            Stored on the instance; not used inside this class.
            Semantics are not visible here - TODO document from the caller.

        early_stop_no_limit (boolean):
            Early stop according to early_stop_train. Typical: False.

        early_stop_train (int):
            Stop training after no improvement on the training set for
            early_stop_train epochs. Typical: 10.

        early_stop_val (int):
            Stop training after no improvement on the validation set for
            early_stop_val epochs. Typical: 20.

        early_stop_val_type (string):
            Valid options:
                'only_val' - early stop on validation alone
                'val_n_train' - early stop on train + validation
            Typical: 'only_val'.

        smote_algo_type (string):
            Typical: 'orig' (original SMOTE algorithm).

        m_neighbors:
            Stored on the instance; not used inside this class.
            Presumably a neighbourhood size used by some SMOTE variants -
            TODO confirm.

        k_neighbors (int):
            Number of neighbors in the KNN run (for deep smote). Typical: 5.

        knn_algorithm (string):
            KNN algorithm (distance metric). Typical: 'brute'.

        importance_oversampling (boolean):
            Details in documentation. Typical: False.

        filter_margin (float or None):
            Details in documentation. Typical: None.
            None - synthetic filtering is not applied.

        classifier_type (string):
            'catboost' or 'svm'. Typical: 'catboost'.

        gen_visuals (boolean):
            Plot PCA of latent space and original space. Plot loss(epoch).
            Typical: True.

        verbose (boolean):
            Print training errors. Typical: True.

        device:
            Forwarded to ``BaseModel.__init__`` and ``self.to``; also stored as
            ``self.device`` and handed to the metric-learn loss.

    """

    def __init__(self,
                 dataset_name,
                 categorical_features,
                 lambda_metric_learn,
                 metric_learn_type,
                 label_smoothing,
                 mode_specific_normalization,
                 reweight_loss,
                 rec_reweight_loss,
                 balance_b4_train,
                 latent_dim_ratio,
                 enc_hidden_dims,
                 dec_hidden_dims,
                 # Train
                 swap_prob,
                 batch_size,
                 lr,
                 lr_decay,
                 epochs,
                 train_on,
                 early_stop_no_limit,
                 early_stop_train,
                 early_stop_val,
                 early_stop_val_type,
                 # Oversample
                 smote_algo_type,
                 m_neighbors,
                 k_neighbors,
                 knn_algorithm,
                 importance_oversampling,
                 filter_margin,
                 classifier_type,
                 gen_visuals,
                 verbose,
                 device):
        """Record all hyper-parameters; see the class docstring for their meaning."""
        super().__init__(device)

        # --- Data representation ---
        self.dataset = dataset_name
        self.categorical_features = categorical_features

        # --- Loss configuration ---
        self.lambda_metric_learn = lambda_metric_learn
        self.metric_learn_type = metric_learn_type
        self.label_smoothing = label_smoothing

        # --- Model / pre-processing configuration ---
        self.mode_specific_normalization = mode_specific_normalization
        self.reweight_loss = reweight_loss
        self.rec_reweight_loss = rec_reweight_loss
        self.balance_b4_train = balance_b4_train
        self.latent_dim_ratio = latent_dim_ratio
        self.enc_hidden_dims = enc_hidden_dims
        self.dec_hidden_dims = dec_hidden_dims
        self.swap_prob = swap_prob

        # --- Training configuration ---
        self.batch_size = batch_size
        self.lr = lr
        self.lr_decay = lr_decay
        self.epochs = epochs
        self.train_on = train_on
        self.early_stop_no_limit = early_stop_no_limit
        self.early_stop_train = early_stop_train
        self.early_stop_val = early_stop_val
        self.early_stop_val_type = early_stop_val_type

        # --- Oversampling configuration ---
        self.smote_algo_type = smote_algo_type
        self.m_neighbors = m_neighbors
        self.k_neighbors = k_neighbors
        self.knn_algorithm = knn_algorithm
        self.importance_oversampling = importance_oversampling
        self.filter_margin = filter_margin
        self.classifier_type = classifier_type

        # --- Reporting ---
        self.gen_visuals = gen_visuals
        self.verbose = verbose

        # Data must be transformed when either mode-specific normalization is
        # requested or categorical features are declared.
        self.transform_data = self.mode_specific_normalization or (self.categorical_features is not None)

        # NOTE(review): encoder/decoder do not exist yet at this point, so they
        # have to be moved to `device` when they are built - confirm the fit
        # code does so.
        self.to(device)
        self.device = device

        self.reweight_factors = None  # set at fit call
        self.latent_dim = None  # set at fit call
        self.input_dim = None  # set at fit call
        self.encoder = None  # set at fit call
        self.decoder = None  # set at fit call

    def declare_losses(self):   # called at fit (once latent_dim is determined)
        """Build the metric-learn loss, its optimizer and the loss accumulators.

        Reads ``self.latent_dim`` and ``self.reweight_factors``, so both should
        be assigned before this is called.
        """
        self.ml_loss = losses.MetricLearnLoss(loss_type=self.metric_learn_type, device=self.device,
                                              latent_dim=self.latent_dim,
                                              reweight_loss=self.reweight_loss,
                                              reweight_factors=self.reweight_factors,
                                              label_smoothing=self.label_smoothing)
        # The metric-learn loss exposes its own parameters and learning rate,
        # so it gets a dedicated optimizer.
        self.contr_optim = torch.optim.Adam(self.ml_loss.parameters(), lr=self.ml_loss.lr)
        self.loss_optimizers = [
            self.contr_optim
        ]
        self.loss_list = ['model_loss',
                          'metric_learn',
                          'combined_loss']
        # Single source of truth for zeroing the accumulators.
        self.reset_losses()

    def init_reweight_factors(self, num_min, num_maj, factor=0.75):
        """Compute the (minority, majority) loss re-weighting factors.

        Stores ``(sqrt(num_maj / num_min) * factor, sqrt(num_min / num_maj) / factor)``
        in ``self.reweight_factors``.

        Args:
            num_min: number of minority-class samples (must be non-zero).
            num_maj: number of majority-class samples (must be non-zero).
            factor: scale applied to the minority factor and divided out of the
                majority factor. Defaults to 0.75, the value that used to be
                hard-coded here.
        """
        from math import sqrt

        min_factor = sqrt(num_maj / num_min) * factor
        maj_factor = sqrt(num_min / num_maj) / factor
        self.reweight_factors = min_factor, maj_factor

        # Other experimented options (kept for reference):
        #
        # Option 1
        #   total = num_min + num_maj
        #   min_factor = total / (2 * num_min)
        #   maj_factor = total / (2 * num_maj)
        #
        # Option 2 - Inverse of Square Root of Number of Samples (ISNS)
        #   min_factor = 1 / sqrt(num_min)
        #   maj_factor = 1 / sqrt(num_maj)
        #
        # Option 3 - Effective number (Google)
        #   effective_num = 1.0 - np.power(0.99, [num_min, num_maj])
        #   weights = (1.0 - 0.99) / np.array(effective_num)
        #   weights = weights / np.sum(weights) * 2
        #   min_factor = weights[0]
        #   maj_factor = weights[1]
        #
        # Option 4 - my
        #   min_factor = sqrt(num_maj / num_min)
        #   maj_factor = 1

    def _encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def _decode(self, z):
        """Return the decoder output for the latent code ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode then decode ``x``; return ``(x_hat, z)``."""
        z = self._encode(x)
        x_hat = self._decode(z)
        return x_hat, z

    def model_loss(self, x_batch, x_hat, y_batch):
        """Return the reconstruction loss (delegates to ``self.rec_loss``, which is not defined in this class)."""
        return self.rec_loss(x_batch, x_hat, y_batch)

    def reset_losses(self):
        """Zero every accumulator named in ``self.loss_list``."""
        self.losses = {loss: 0.0 for loss in self.loss_list}

    def loss(self, x_batch, y_batch, x_train, y_train, model_outputs):
        """Compute the combined training loss for one batch and accumulate it.

        Args:
            x_batch: input batch.
            y_batch: labels of the batch.
            x_train: unused here (presumably kept for a shared interface - TODO confirm).
            y_train: unused here (presumably kept for a shared interface - TODO confirm).
            model_outputs: output of ``forward``; index 0 is ``x_hat``, index 1 is ``z``.

        Returns:
            ``model_loss + lambda_metric_learn * metric_learn_loss``.
        """
        x_hat = model_outputs[0]
        z = model_outputs[1]

        model_loss = self.model_loss(x_batch, x_hat, y_batch)
        metric_learn_loss = self.ml_loss(z, y_batch)
        comb_loss = model_loss + self.lambda_metric_learn * metric_learn_loss

        # update epoch losses
        self.losses['model_loss'] += model_loss.item()
        self.losses['metric_learn'] += metric_learn_loss.item()
        self.losses['combined_loss'] += comb_loss.item()

        return comb_loss

    def save(self, best_checkpoint_save_path, dataset):
        """Write ``self.state_dict()`` to ``best_checkpoint_save_path``.

        ``dataset`` is accepted but not used by this implementation.
        """
        torch.save(self.state_dict(), best_checkpoint_save_path)

"""

class VAE(BaseModel):

    def __init__(self, args, hparams, device):
        super().__init__(device)
        self.transform_data = (not args.no_mode_specific_normalization) or (args.categorical_features is not None)
        self.hparams = hparams
        self.kld_weight = hparams['kld_weight']
        self.device = device

        self.to(device)

        self.input_dim = None  # set at fit call
        self.enc_hidden_dims = None  # set at fit call
        self.dec_hidden_dims = None  # set at fit call
        self.encoder = None  # set at fit call
        self.decoder = None  # set at fit call

    def declare_losses(self):
        # lambda hot clustering
        min_center = self.hparams['lambda_centers'] * torch.ones(self.latent_dim, device=self.device) #torch.tensor([1.5, 0, 0, 0, 0, 0])  # torch.rand(self.latent_dim)
        maj_center = (-1) * min_center  # torch.rand(self.latent_dim)

        self.loss_optimizers = []

        self.loss_list = [
            'mse',
            'kld',
            'combined_loss'
        ]

        self.losses = {loss: 0.0 for loss in self.loss_list}

    def _encode(self, x: torch.Tensor, is_forward=False):
        encoded = self.encoder(x)
        mu, log_var = torch.split(encoded, encoded.shape[-1] // 2, dim=1)
        if is_forward:
            eps = torch.randn_like(mu)
            std = torch.exp(0.5 * log_var)
            z = eps * std + mu
            return z, mu, log_var
        return mu

    def _decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z, mu, log_var = self._encode(x, is_forward=True)
        x_hat = self._decode(z)
        return x_hat, mu, log_var

    def reset_losses(self):
        self.losses = {loss: 0.0 for loss in self.loss_list}

    def loss(self, x_batch, y_batch, x_train, y_train, model_outputs):
        x_hat = model_outputs[0]
        mu = model_outputs[1]
        log_var = model_outputs[2]

        for i in range(self.batch_size):
            if i == 0:
                if y_batch[0] == 0:
                    class_centers = self.maj_center[None, :]
                else:
                    class_centers = self.min_center[None, :]
            else:
                if y_batch[i] == 0:
                    class_centers = torch.cat((class_centers, self.maj_center[None, :]), 0)
                else:
                    class_centers = torch.cat((class_centers, self.min_center[None, :]), 0)


        kld_loss = torch.mean(0.5 * torch.sum((mu-class_centers)** 2 + log_var.exp() - 1 - log_var, dim=1), dim=0)

        mse = self.rec_loss(x_batch, x_hat, y_batch)
        comb_loss = mse + self.kld_weight * kld_loss

        self.losses['mse'] += mse.item()
        self.losses['kld'] += kld_loss.item()
        self.losses['combined_loss'] += comb_loss.item()

        return comb_loss

    def save(self, best_checkpoint_save_path):
        torch.save(self.state_dict(), best_checkpoint_save_path)

"""