import pandas as pd
import torch

from .models.ctgan import CTGAN as CTGANModel
from .models.ctgan import Generator
from .models.tvae import Decoder
from .models.tvae import TVAE as TVAEModel
from ...base import *
from ...utils import get_kwargs


class CTGAN(
    BaseUnconditionalGeneratorMixIn,
    Base):

    """CTGAN tabular synthesizer adapted to this project's generator interface.

    Wraps ``CTGANModel`` (from ``.models.ctgan``) and exposes training on the
    configured CSV, unconditional sampling and per-class sampling.

    https://github.com/sdv-dev/CTGAN
    Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019).
    Modeling tabular data using conditional gan.
    Advances in neural information processing systems, 32.
    """

    def __init__(self, embedding_dim, generator_dim, discriminator_dim,
                 generator_lr, generator_decay, discriminator_lr, discriminator_decay,
                 batch_size, max_epochs, **kwargs):
        """Build the underlying CTGAN model.

        Args:
            embedding_dim: Embedding size forwarded to ``CTGANModel``.
            generator_dim: Generator hidden-layer sizes: a comma-separated
                string such as ``"256,256"``, a single integer, or a
                sequence of integers.
            discriminator_dim: Discriminator hidden-layer sizes, in the same
                formats as ``generator_dim``.
            generator_lr, generator_decay: Generator learning rate and decay,
                forwarded to ``CTGANModel``.
            discriminator_lr, discriminator_decay: Discriminator learning rate
                and decay, forwarded to ``CTGANModel``.
            batch_size: Training batch size.
            max_epochs: Number of training epochs (passed as ``epochs``).
            **kwargs: Filtered through ``get_kwargs`` and passed to the base
                class constructor.
        """
        super().__init__(**get_kwargs(**kwargs))
        self.model = CTGANModel(
            embedding_dim=embedding_dim,
            generator_dim=self._parse_dims(generator_dim),
            discriminator_dim=self._parse_dims(discriminator_dim),
            generator_lr=generator_lr,
            generator_decay=generator_decay,
            discriminator_lr=discriminator_lr,
            discriminator_decay=discriminator_decay,
            batch_size=batch_size,
            epochs=max_epochs,
        )

    @staticmethod
    def _parse_dims(dims):
        """Return hidden-layer sizes as a list of ints.

        Comma-separated strings (the original format, e.g. ``"256,256"``) are
        split. A YAML config value such as ``256`` arrives as an int and a
        YAML list arrives as a sequence; neither has ``.split``, so they are
        handled explicitly as a single size or as a sequence of sizes.
        """
        if isinstance(dims, str):
            return [int(x) for x in dims.split(',')]
        try:
            items = iter(dims)
        except TypeError:
            # A single scalar size.
            return [int(dims)]
        return [int(x) for x in items]

    def _restore_categorical_dtypes(self, df):
        """Cast each categorical column of ``df`` to its label encoder's
        ``_dtype`` (in place) and return ``df``."""
        transform = self.tabular_transform
        for col, enc in zip(transform.categorical_columns, transform.label_encoders):
            df[col] = df[col].astype(enc._dtype)
        return df

    def fit(self, scenario=lambda x: x):
        """Train the CTGAN model on the configured training CSV.

        The CSV at ``self._cfg.dataset.train_path`` is passed through
        ``scenario``, restricted to ``self.tabular_transform.columns``, and
        rows with missing values are dropped before training.

        Returns:
            Whatever the base class ``fit`` returns for the same scenario.
        """
        raw_df = scenario(pd.read_csv(self._cfg.dataset.train_path))[self.tabular_transform.columns].dropna()
        self.model.fit(raw_df, self.tabular_transform.categorical_columns)
        return super().fit(scenario)

    def generate_uncond(self, n: int, seed=None, **kwargs):
        """Sample ``n`` rows without conditioning.

        Categorical columns are cast back to their label encoders' dtypes.
        ``seed`` is accepted for interface compatibility but is currently
        ignored.
        """
        with torch.no_grad():
            df = self.model.sample(n)
        return self._restore_categorical_dtypes(df)

    def _generate_uncond(self, n, **kwargs):
        # Intentionally empty: sampling is done directly in generate_uncond.
        pass

    def generate_by_class(self, y, seed=None, **kwargs):
        """Sample rows that reproduce the class distribution of ``y``.

        For each distinct value of the target column in ``y``, as many rows
        as that value occurs in ``y`` are sampled with the model conditioned
        on it. Categorical columns are cast back to their label encoders'
        dtypes, matching ``generate_uncond``. ``seed`` is accepted for
        interface compatibility but is currently ignored.

        Returns:
            The concatenated samples with a fresh index, or an empty frame
            with ``self.tabular_transform.columns`` when ``y`` has no rows
            (``pd.concat`` on an empty list would raise ``ValueError``).
        """
        target_column = self._transform.target_column
        counts = y[target_column].value_counts()
        if counts.empty:
            return pd.DataFrame(columns=list(self.tabular_transform.columns))
        with torch.no_grad():
            parts = [
                self.model.sample(int(count), target_column, target)
                for target, count in counts.items()
            ]
        return self._restore_categorical_dtypes(pd.concat(parts, ignore_index=True))

    def _generate_by_class(self, y, **kwargs):
        # Intentionally empty: sampling is done directly in generate_by_class.
        pass

    def on_save_checkpoint(self, checkpoint):
        """Store the fitted data transformer and data sampler in ``checkpoint``."""
        checkpoint['transformer'] = self.model._transformer
        checkpoint['data_sampler'] = self.model._data_sampler

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore transformer/sampler from ``checkpoint`` and rebuild the generator.

        The generator input size is the embedding size plus the sampler's
        conditional-vector size, and its output size is the transformer's
        output dimension. Only the architecture is rebuilt here.
        NOTE(review): presumably the generator weights are loaded from the
        checkpoint's state dict afterwards — confirm.
        """
        self.model._transformer = checkpoint['transformer']
        self.model._data_sampler = checkpoint['data_sampler']
        self.model._generator = Generator(
            self.model._embedding_dim + self.model._data_sampler.dim_cond_vec(),
            self.model._generator_dim,
            self.model._transformer.output_dimensions
        ).to(self.model._device)


class TVAE(BaseUnconditionalGeneratorMixIn, Base):

    """TVAE tabular synthesizer adapted to this project's generator interface.

    Wraps ``TVAEModel`` (from ``.models.tvae``). TVAE is introduced in the
    same paper as CTGAN:

    https://github.com/sdv-dev/CTGAN
    Xu, L., Skoularidou, M., Cuesta-Infante, A., & Veeramachaneni, K. (2019).
    Modeling tabular data using conditional gan.
    Advances in neural information processing systems, 32.
    """

    def __init__(self, embedding_dim, compress_dims, decompress_dims, l2scale, loss_factor,
                 batch_size, max_epochs, **kwargs):
        """Build the underlying TVAE model.

        Args:
            embedding_dim: Embedding size forwarded to ``TVAEModel``.
            compress_dims: Encoder hidden-layer sizes: a comma-separated
                string such as ``"128,128"``, a single integer, or a
                sequence of integers.
            decompress_dims: Decoder hidden-layer sizes, in the same formats
                as ``compress_dims``.
            l2scale: Forwarded to ``TVAEModel`` as ``l2scale``.
            loss_factor: Forwarded to ``TVAEModel`` as ``loss_factor``.
            batch_size: Training batch size.
            max_epochs: Number of training epochs (passed as ``epochs``).
            **kwargs: Filtered through ``get_kwargs`` and passed to the base
                class constructor.
        """
        super().__init__(**get_kwargs(**kwargs))
        self.model = TVAEModel(
            embedding_dim=embedding_dim,
            compress_dims=self._parse_dims(compress_dims),
            decompress_dims=self._parse_dims(decompress_dims),
            l2scale=l2scale,
            batch_size=batch_size,
            epochs=max_epochs,
            loss_factor=loss_factor,
        )

    @staticmethod
    def _parse_dims(dims):
        """Return hidden-layer sizes as a list of ints.

        Comma-separated strings (the original format, e.g. ``"128,128"``) are
        split. A YAML config value such as ``128`` arrives as an int and a
        YAML list arrives as a sequence; neither has ``.split``, so they are
        handled explicitly as a single size or as a sequence of sizes.
        """
        if isinstance(dims, str):
            return [int(x) for x in dims.split(',')]
        try:
            items = iter(dims)
        except TypeError:
            # A single scalar size.
            return [int(dims)]
        return [int(x) for x in items]

    def fit(self, scenario=lambda x: x):
        """Train the TVAE model on the configured training CSV.

        The CSV at ``self._cfg.dataset.train_path`` is passed through
        ``scenario``, restricted to ``self.tabular_transform.columns``, and
        rows with missing values are dropped before training.

        Returns:
            Whatever the base class ``fit`` returns for the same scenario.
        """
        raw_df = scenario(pd.read_csv(self._cfg.dataset.train_path))[self.tabular_transform.columns].dropna()
        self.model.fit(raw_df, self.tabular_transform.categorical_columns)
        return super().fit(scenario)

    def generate_uncond(self, n: int, seed=None, **kwargs):
        """Sample ``n`` rows without conditioning.

        Categorical columns are cast back to their label encoders' dtypes.
        ``seed`` is accepted for interface compatibility but is currently
        ignored.
        """
        with torch.no_grad():
            df = self.model.sample(n)
        transform = self.tabular_transform
        for col, enc in zip(transform.categorical_columns, transform.label_encoders):
            df[col] = df[col].astype(enc._dtype)
        return df

    def _generate_uncond(self, n, **kwargs):
        # Intentionally empty: sampling is done directly in generate_uncond.
        pass

    def on_save_checkpoint(self, checkpoint):
        """Store the fitted data transformer in ``checkpoint``."""
        checkpoint['transformer'] = self.model.transformer

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore the transformer from ``checkpoint`` and rebuild the decoder.

        The decoder is sized from the embedding size, the decompression
        layer sizes, and the transformer's output dimension. Only the
        architecture is rebuilt here.
        NOTE(review): presumably the decoder weights are loaded from the
        checkpoint's state dict afterwards — confirm.
        """
        self.model.transformer = checkpoint['transformer']
        self.model.decoder = Decoder(
            self.model.embedding_dim,
            self.model.decompress_dims,
            self.model.transformer.output_dimensions
        ).to(self.model._device)
