from collections import namedtuple
import numpy as np
import torch
from torch import optim
from torch import nn
import torch.utils.data
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LeakyReLU,
    Linear,
    Module,
    ReLU,
    Sequential,
    Sigmoid,
)
import warnings

#import opacus
from ...base import Synthesizer

from ...transform.table import TableTransformer
from .ctgan.data_sampler import DataSampler
from .ctgan.ctgan import CTGANSynthesizer

#new
from .slicedKL import slicedKLclass
def set_requires_grad(model, requires_grad=True):
    """Switch gradient tracking on or off for every parameter of ``model``.

    :param model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    :param requires_grad: Value written to each parameter's ``requires_grad``.
    """
    for weight in model.parameters():
        weight.requires_grad = requires_grad

class Discriminator(Module):
    """Fully connected critic that scores ``pac`` rows at a time.

    Every hidden width in ``discriminator_dim`` contributes a
    ``Linear -> LeakyReLU(0.2) -> Dropout(0.5)`` stage.  A final ``Linear``
    maps to a single output, followed by ``Sigmoid`` only when ``loss`` is
    ``"cross_entropy"``.

    :param input_dim: Width of a single input row.
    :param discriminator_dim: Iterable of hidden layer widths.
    :param loss: Loss name; only ``"cross_entropy"`` changes the architecture.
    :param pac: Number of rows concatenated into one network input.
    """

    def __init__(self, input_dim, discriminator_dim, loss, pac=1):
        super().__init__()
        # Reseeding the global torch RNGs makes the weight initialisation
        # below reproducible; note that this is a process-wide side effect.
        torch.cuda.manual_seed(0)
        torch.manual_seed(0)

        self.pac = pac
        self.pacdim = input_dim * pac

        layers = []
        width_in = self.pacdim
        for width_out in discriminator_dim:
            layers.extend(
                (Linear(width_in, width_out), LeakyReLU(0.2), Dropout(0.5))
            )
            width_in = width_out

        layers.append(Linear(width_in, 1))
        if loss == "cross_entropy":
            layers.append(Sigmoid())
        self.seq = Sequential(*layers)

    def calc_gradient_penalty(
        self, real_data, fake_data, device="cpu", pac=1, lambda_=10
    ):
        """Return the gradient penalty on random real/fake interpolates.

        One mixing coefficient is drawn per group of ``pac`` rows and shared
        by every row and column in that group.  The penalty is
        ``lambda_ * mean((||grad||_2 - 1) ** 2)`` with the norm taken over each
        group of ``pac`` rows.

        :param real_data: 2-D tensor of real rows.
        :param fake_data: 2-D tensor of generated rows, same shape.
        :param device: Device on which the random coefficients are created.
        :param pac: Group size used for the coefficients and the norm.
        :param lambda_: Multiplier applied to the mean penalty.
        :return: Scalar tensor that keeps its autograd graph.
        """
        n_rows, n_cols = real_data.size(0), real_data.size(1)

        mix = torch.rand(n_rows // pac, 1, 1, device=device)
        mix = mix.repeat(1, pac, n_cols).view(-1, n_cols)

        blended = mix * real_data + ((1 - mix) * fake_data)
        scores = self(blended)

        (grads,) = torch.autograd.grad(
            outputs=scores,
            inputs=blended,
            grad_outputs=torch.ones(scores.size(), device=device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )

        grad_norms = grads.view(-1, pac * n_cols).norm(2, dim=1)
        return ((grad_norms - 1) ** 2).mean() * lambda_

    def forward(self, input):
        """Score ``input`` after folding every ``self.pac`` rows into one."""
        n_rows = input.size()[0]
        assert n_rows % self.pac == 0
        packed = input.view(-1, self.pacdim)
        return self.seq(packed)


class Residual(Module):
    """Linear -> BatchNorm1d -> ReLU stage whose output is joined with its input.

    :param i: Width of the incoming features.
    :param o: Width produced by the linear layer.
    """

    def __init__(self, i, o):
        super().__init__()
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input):
        """Return ``[relu(bn(fc(input))), input]`` concatenated along dim 1."""
        activated = self.relu(self.bn(self.fc(input)))
        return torch.cat((activated, input), dim=1)


class Generator(Module):
    """Stack of ``Residual`` stages followed by a linear projection.

    Each ``Residual`` concatenates its output with its input, so the running
    feature width grows by the hidden width at every stage.

    :param embedding_dim: Width of the generator input.
    :param generator_dim: Iterable of hidden widths, one per ``Residual``.
    :param data_dim: Width of the generated output row.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        super().__init__()
        width = embedding_dim
        stages = []
        for hidden in generator_dim:
            stages.append(Residual(width, hidden))
            width += hidden
        stages.append(Linear(width, data_dim))
        self.seq = Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through the sequential stack and return the result."""
        return self.seq(input)


# Custom replacement that accumulates grad_sample across multiple loss.backward() calls
def _custom_create_or_extend_grad_sample(
    param: torch.Tensor, grad_sample: torch.Tensor, batch_dim: int
) -> None:
    """
    Create a 'grad_sample' attribute in the given parameter, or accumulate it
    if the 'grad_sample' attribute already exists.
    This custom code will not work when using optimizer.virtual_step()
    """

    if hasattr(param, "grad_sample"):
        param.grad_sample = param.grad_sample + grad_sample
        # param.grad_sample = torch.cat((param.grad_sample, grad_sample), batch_dim)
    else:
        param.grad_sample = grad_sample
def sliced_wasserstein_distance_diff_priv(first_samples,
                                second_samples,
                                thetas,
                                p=1,
                                device='cuda',
                                sigma_proj=1,
                                sigma_noise = 1,
                                noise_samples=None
                                ):
    """Sliced p-Wasserstein distance between noisy 1-D projections.

    Both sample sets are projected onto ``thetas[:, 0, :]`` (the first row of
    every slice), Gaussian noise scaled by ``sigma_noise`` is added to each
    projection, and the sorted projections are compared per direction.

    :param first_samples: 2-D tensor; per the original note, the data to protect.
    :param second_samples: 2-D tensor; per the original note, the generated data.
    :param thetas: Tensor indexed as ``thetas[:, 0, :]`` to get the directions.
    :param p: Order of the distance.
    :param device: Device the noise tensors are moved to.
    :param sigma_proj: Unused.
    :param sigma_noise: Scale applied to both noise tensors.
    :param noise_samples: Optional pre-drawn noise for ``first_samples``; fresh
        noise is drawn when this is ``None``.
    :return: Scalar tensor.
    """
    n_rows = second_samples.size(0)
    n_dirs = thetas.shape[0]

    # Noise for the second (generated) set is always drawn first, so the RNG
    # stream is consumed in the same order whether or not noise_samples is given.
    fake_noise = (torch.randn((n_rows, n_dirs)) * sigma_noise).to(device)
    if noise_samples is None:
        real_noise = torch.randn((n_rows, n_dirs)) * sigma_noise
    else:
        real_noise = noise_samples * sigma_noise
    real_noise = real_noise.to(device)

    directions = torch.transpose(thetas[:, 0, :], 0, 1)
    real_proj = torch.matmul(first_samples, directions) + real_noise
    fake_proj = torch.matmul(second_samples, directions) + fake_noise

    real_sorted = torch.sort(real_proj.transpose(0, 1), dim=1)[0]
    fake_sorted = torch.sort(fake_proj.transpose(0, 1), dim=1)[0]
    gaps = torch.abs(real_sorted - fake_sorted)

    # Average over the sorted samples within each direction ...
    per_direction = torch.pow(torch.mean(torch.pow(gaps, p), dim=1), 1. / p)
    # ... then over the directions.
    return torch.pow(torch.pow(per_direction, p).mean(), 1. / p)

#mmd util start
def euclidsq(x, y):
    """Return the squared pairwise Euclidean distances between rows of x and y."""
    pairwise = torch.cdist(x, y)
    return pairwise.pow(2)

def prepare(x_de, x_nu):
    """Return squared-distance matrices for (de, de), (de, nu) and (nu, nu)."""
    within_de = euclidsq(x_de, x_de)
    across = euclidsq(x_de, x_nu)
    within_nu = euclidsq(x_nu, x_nu)
    return within_de, across, within_nu

def gaussian_gramian(esq, σ):
    """Gaussian kernel values ``exp(-esq / (2 σ²))`` for squared distances ``esq``."""
    two_var = 2 * σ**2
    return torch.exp(-esq / two_var)

# Selects how kmm_ratios solves its linear system: True uses
# torch.linalg.solve, False forms the explicit inverse.
USE_SOLVE = True
# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def kmm_ratios(Kdede, Kdenu, λ):
    """Compute ``(n_de / n_nu) * (Kdede + λ I)^-1 Kdenu 1``.

    :param Kdede: Square kernel matrix with ``n_de`` rows.
    :param Kdenu: Kernel matrix of shape ``(n_de, n_nu)``.
    :param λ: Ridge coefficient added to the diagonal when positive.
    :return: Tensor of shape ``(n_de, 1)``.
    """
    n_de, n_nu = Kdenu.shape
    if λ > 0:
        # Build the identity on the same device and dtype as the kernel matrix
        # so CPU inputs keep working when the module-level default is CUDA.
        ridge = torch.eye(n_de, device=Kdede.device, dtype=Kdede.dtype)
        A = Kdede + λ * ridge
    else:
        A = Kdede
    # Row sums of Kdenu equal Kdenu @ ones(n_nu, 1) without allocating the ones.
    B = torch.sum(Kdenu, 1, keepdim=True)
    scale = n_de / n_nu
    # Both paths apply the same scale, so they give the same result up to
    # numerical error.
    if USE_SOLVE:
        return scale * torch.linalg.solve(A, B)
    return scale * torch.matmul(torch.inverse(A), B)

# Ridge coefficient passed as λ to kmm_ratios by estimate_ratio_compute_mmd.
eps_ratio = 0.0001
# In the visible part of this file this flag is only referenced by a
# commented-out ReLU clip inside estimate_ratio_compute_mmd.
clip_ratio = True

def estimate_ratio_compute_mmd(x_de, x_nu, σs=()):
    """Average the ``kmm_ratios`` estimate over several Gaussian bandwidths.

    Despite the name, only the ratio is returned; the MMD computation is
    disabled in this implementation.

    :param x_de: Samples passed as the first argument to ``prepare``.
    :param x_nu: Samples passed as the second argument to ``prepare``.
    :param σs: Sequence or 1-D tensor of bandwidths.  When empty or ``None``,
        six bandwidths are derived from the median pairwise distance.
    :return: The mean over bandwidths of ``kmm_ratios(Kdede, Kdenu, eps_ratio)``.
    """
    dsq_dede, dsq_denu, dsq_nunu = prepare(x_de, x_nu)

    if σs is None or len(σs) == 0:
        with torch.no_grad():
            # A heuristic is to use the median of pairwise distances as σ,
            # suggested by Sugiyama's book.  The blocks are flattened before
            # pooling so differing sample counts do not break the concatenation.
            pooled = torch.cat(
                [dsq_dede.reshape(-1), dsq_denu.reshape(-1), dsq_nunu.reshape(-1)]
            )
            sigma = torch.sqrt(torch.median(pooled))
            c = 2
            # Keep the multipliers on the data's own device and dtype rather
            # than the module-level default device.
            σs = sigma * torch.as_tensor(
                [1/c, 0.333/c, 0.2/c, 5/c, 3/c, .1/c],
                device=sigma.device,
                dtype=sigma.dtype,
            )

    ratio = None
    for σ in σs:
        Kdede = gaussian_gramian(dsq_dede, σ)
        Kdenu = gaussian_gramian(dsq_denu, σ)
        term = kmm_ratios(Kdede, Kdenu, eps_ratio)
        ratio = term if ratio is None else ratio + term

    return ratio / len(σs)

def sliced_kl_kmm_diff_priv(first_samples,
                                second_samples,
                                thetas, #n_slice * slice_dim * d
                                p=1,
                                device='cuda',
                                sigma_proj=1,
                                sigma_noise = 1,
                                noise_samples=None,
                                n_slice=40,
                                slice_dim=2
                                ):
    """Estimate a sliced KL-type divergence from noisy low-dimensional projections.

    Both sample sets are projected with ``thetas``, Gaussian noise scaled by
    ``sigma_noise`` is added, and ``estimate_ratio_compute_mmd`` is applied to
    every slice via ``torch.vmap``.  The result is ``mean(r * log r)`` where
    ``r`` is the ratio floored at ``1e-10``.

    :param first_samples: 2-D tensor; per the original note, the data to protect.
    :param second_samples: 2-D tensor; per the original note, the generated data.
    :param thetas: Projection tensor, documented as ``n_slice * slice_dim * d``.
    :param p: Unused.
    :param device: Device for the projections and noise.
    :param sigma_proj: Unused.
    :param sigma_noise: Noise scale, as a Python/NumPy scalar or a tensor.
    :param noise_samples: Optional pre-drawn noise for ``first_samples``.
    :param n_slice: Number of slices; must agree with ``thetas``.
    :param slice_dim: Dimension of each slice; must agree with ``thetas``.
    :return: Scalar tensor.
    """
    nb_sample = second_samples.size(0)
    thetas = thetas.to(device)
    # Accept plain scalars (including the default) as well as tensors.
    sigma_noise = torch.as_tensor(sigma_noise, device=device)

    noise_shape = (n_slice, nb_sample, slice_dim)
    noise2 = torch.randn(noise_shape, device=device) * sigma_noise
    if noise_samples is not None:
        noise = torch.as_tensor(noise_samples, device=device) * sigma_noise
    else:
        # Draw directly on the target device so the sum below cannot mix devices.
        noise = torch.randn(noise_shape, device=device) * sigma_noise

    projector = torch.transpose(thetas, 1, 2)
    first_projections = torch.matmul(first_samples, projector) + noise
    second_projections = torch.matmul(second_samples, projector) + noise2

    # Find one set of kernel widths to share across the slices.
    with torch.no_grad():
        # A heuristic is to use the median of pairwise distances as σ,
        # suggested by Sugiyama's book; here the median is taken over the
        # squared norms of the projected first samples.
        norms = torch.linalg.vector_norm(first_projections, dim=2)
        sigma = torch.sqrt(torch.median(norms**2))
        c = 2
        σs = sigma * torch.as_tensor(
            [1/c, 0.333/c, 0.2/c, 5/c, 3/c, .1/c], device=sigma.device
        )

    def ratio_func(first, second):
        return estimate_ratio_compute_mmd(first, second, σs=σs)

    vec_mmd = torch.vmap(ratio_func, in_dims=0, out_dims=0)
    ratio = vec_mmd(first_projections, second_projections).view(n_slice, -1)

    # Floor the ratio so the logarithm stays finite.
    floor = torch.full(ratio.size(), 1e-10, device=ratio.device, dtype=ratio.dtype)
    ratio = torch.maximum(ratio, floor)
    kl = torch.mean(ratio * torch.log(ratio))
    return kl
#mmd util end

class SFDCTGAN(CTGANSynthesizer, Synthesizer):
    """SFDCTGAN Synthesizer.

    GAN-based synthesizer that uses conditional masks to learn tabular data.

    :param epsilon: Privacy budget for the model.
    :param sigma: The noise scale for the gradients.  Noise scale and batch size influence
        how fast the privacy budget is consumed, which in turn influences the
        convergence rate and the quality of the synthetic data.
    :param batch_size: The batch size for training the model.
    :param epochs: The number of epochs to train the model.
    :param embedding_dim: The dimensionality of the embedding layer.
    :param generator_dim: The dimensionality of the generator layer.
    :param discriminator_dim: The dimensionality of the discriminator layer.
    :param generator_lr: The learning rate for the generator.
    :param discriminator_lr: The learning rate for the discriminator.
    :param verbose: Whether to print the training progress.
    :param disabled_dp: Allows training without differential privacy, to diagnose
        whether any model issues are caused by privacy or are simply the
        result of GAN instability or other issues with hyperparameters.

    """
    def __init__(
        self,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        generator_decay=1e-6,
        discriminator_lr=2e-4,
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        verbose=True,
        epochs=300,
        pac=1,
        cuda=True,
        disabled_dp=True,
        delta=None,
        sigma=5,
        max_per_sample_grad_norm=1.0,
        epsilon=1,
        loss="sfd-dp",
        n_KL_slices=100,
        n_KL_slice_dim=2,
    ):
        """Record hyperparameters and resolve the torch device.

        No networks are built here: the transformer, data sampler and
        generator attributes are initialised to ``None``.  ``batch_size``
        must be even.  ``cuda`` may be a falsy value (CPU), a device string,
        or any other truthy value (``"cuda"``); CPU is used whenever CUDA is
        unavailable.
        """
        assert batch_size % 2 == 0

        # Network shapes.
        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        # Optimiser settings.
        self._generator_lr, self._generator_decay = generator_lr, generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        # Training schedule.
        self._batch_size = batch_size
        self._discriminator_steps = discriminator_steps
        self._verbose = verbose
        self._epochs = epochs
        self.pac = pac

        # Privacy-related parameters (named after the Opacus settings).
        self.sigma = sigma
        self.disabled_dp = disabled_dp
        self.delta = delta
        self.max_per_sample_grad_norm = max_per_sample_grad_norm
        self.epsilon = epsilon

        # Per-run bookkeeping lists.
        self.epsilon_list, self.alpha_list = [], []
        self.loss_d_list, self.loss_g_list = [], []
        self.verbose = verbose
        self.loss = loss

        # Resolve the device: CPU unless CUDA is both requested and available.
        wants_gpu = bool(cuda) and torch.cuda.is_available()
        if not wants_gpu:
            target = "cpu"
        else:
            target = cuda if isinstance(cuda, str) else "cuda"
        self._device = torch.device(target)

        self._transformer = None
        self._data_sampler = None
        self._generator = None

        # The Opacus grad_sample monkeypatch that used to be applied for
        # non-cross-entropy losses is disabled, so nothing is done for it here.

        # Sliced-KL configuration.
        self.n_KL_slices = n_KL_slices
        self.n_KL_slice_dim = n_KL_slice_dim

        self.wd_clf = 1
        print('sliced KL')
        
        
    def train(
        self,
        data,
        transformer=None,
        categorical_columns=[],
        ordinal_columns=[],
        continuous_columns=[],
        update_epsilon=None, 
        preprocessor_eps=0.0,
        nullable=False,
    ):
        if update_epsilon:
            self.epsilon = update_epsilon
            
        train_data = self._get_train_data(
            data,
            style='gan',
            transformer=transformer,
            categorical_columns=categorical_columns, 
            ordinal_columns=ordinal_columns, 
            continuous_columns=continuous_columns, 
            nullable=nullable,
            preprocessor_eps=preprocessor_eps
        )

        train_data = np.array([
            [float(x) if x is not None else 0.0 for x in row] for row in train_data
        ])
        noise_data = torch.randn(train_data.shape[0], self.n_KL_slices * self.n_KL_slice_dim).to(self._device)
        noise_data_kmm = torch.randn(self.n_KL_slices, train_data.shape[0], self.n_KL_slice_dim).to(self._device)
        #self.data_dim = train_data.shape[1]
        self._data_sampler = DataSampler(
            train_data,
            self._transformer.transformers,
            return_ix = True
        )

        data_dim = self._transformer.output_width
        #print(data_dim)
        
        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(),
            self._generator_dim,
            data_dim,
        ).to(self._device)
        
#      #######################Get dimension of embedded data
        mean = torch.zeros(2, self._embedding_dim, device=self._device)
        std = mean + 1
        fakez = torch.normal(mean=mean, std=std)
        condvec = self._data_sampler.sample_condvec(2)

        if condvec is None:
            c1, m1, col, opt = None, None, None, None
        else:
            c1, m1, col, opt = condvec
            c1 = torch.from_numpy(c1).to(self._device)
            m1 = torch.from_numpy(m1).to(self._device)
            fakez = torch.cat([fakez, c1], dim=1)

        fake = self._generator(fakez)
        fakeact = self._apply_activate(fake)
        if c1 is not None:
            fakey = torch.cat([fakeact, c1], dim=1)
        else:
            fakey = fakeact
        #print(fakey.shape)
        d=fakey.shape[1] #596 #588 #data_dim#train_data.shape[1]
        print('d')
        print(d)
        ####################FIGURE OUT WHAT NOISE TO ADD
        epsilon = self.epsilon #- preprocessor_eps
        if self.delta is None:
            print(train_data.shape[0])
            
            self.delta =  1 / np.sqrt(train_data.shape[0]) #* np.sqrt(48842)
            
        delta=self.delta
        mprime = self.n_KL_slice_dim * self.n_KL_slices
        
        #print(data_dim)
        #print(epsilon)
        # First, try approx-opt alpha expression to guess sigma
        aa = mprime/d
        bb = 2 * np.sqrt(mprime * np.log(1/delta) / d)
        cc = epsilon
        sigma_propose = ( (-bb + np.sqrt(bb**2 + 4 * aa *cc))  /(2*aa) )**(-1) # Will always be real and positive

        # Iterative optimization
        iters = 10
        for i in range(iters):
            # Compute implied approximate-opt alpha 
            alpha_star = 1 + np.sqrt(sigma_propose**2 * d * np.log(1/delta) / mprime)

            # Check if implied alpha is outside allowed range and pick closest allowed
            if (alpha_star**2 - alpha_star) > d * sigma_propose**2 / 2:
                # quadratic formula, won't be imaginary or unambiguous
                alpha_star = 0.5 * (1 + np.sqrt(1 + 2 * d * sigma_propose**2))
            # Recompute sigma in case alpha changed, using exact formula for epsilon
            val = 2 * d * (epsilon - np.log(1/delta) / (alpha_star - 1) ) / (mprime * alpha_star  )
            while val <= 0:
                if ((1.2*alpha_star)**2 - 1.2*alpha_star) > d * sigma_propose**2:
                    val = 0.001
                    print('WARNING: unable to find a valid sigma to achieve specified epsilon, delta combination. Using a large sigma.')
                    break
                else:
                    alpha_star *= 1.2
                    #print('WARNING: unable to find a valid sigma to achieve specified epsilon, delta. Using a large sigma.')
                    val = 2 * d * (epsilon - np.log(1/delta) / (alpha_star - 1) ) / (mprime * alpha_star  ) #.01
            val2 = (alpha_star**2 - alpha_star) / d
            sigma_propose = np.sqrt(  (1 / val) + val2 ) #Automatically satisfies constraint if val > 0
        sigma = sigma_propose
        noise = sigma
        epsilon_actual = mprime * alpha_star / (2 * sigma**2 * (d - (alpha_star**2 - alpha_star)*sigma**(-2))) + np.log(1/delta)/(alpha_star - 1)
        print('User specified (epsilon, delta)=(' + str(epsilon) + ',' + str(delta) + '); Chosen sigma = ' + str(sigma) + '; Actual epsilon = ' + str(epsilon_actual))
        
        self.sigma_noise = sigma
        ################################################

        

        discriminator = Discriminator(
            data_dim + self._data_sampler.dim_cond_vec(),
            self._discriminator_dim,
            self.loss,
            self.pac,
        ).to(self._device)
        
        self.kldiv = slicedKLclass(d,self.n_KL_slices,self.n_KL_slice_dim,self._device).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )
        optimizerD = optim.Adam(
            self.kldiv.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5,0.9),
            weight_decay=self._discriminator_decay,
        )

        #privacy_engine = opacus.PrivacyEngine(
        #    discriminator,
        #    batch_size=self._batch_size,
        #    sample_size=train_data.shape[0],
        #    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
        #    noise_multiplier=self.sigma,
        #    max_grad_norm=self.max_per_sample_grad_norm,
        #    clip_per_layer=True,
        #)
        self.disabled_dp = True #Remove since we are doing KL
        if 0: #not self.disabled_dp:
            privacy_engine.attach(optimizerD)

        one = torch.tensor(1, dtype=torch.float).to(self._device)
        mone = one * -1

        real_label = 1
        fake_label = 0
        criterion = nn.BCELoss()

        assert self._batch_size % 2 == 0
        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        steps_per_epoch = max(len(train_data) // self._batch_size, 1) 
        dsteps_per_epoch = 3
        for i in range(self._epochs):
            if not self.disabled_dp:
                # if self.loss == 'cross_entropy':
                #    autograd_grad_sample.clear_backprops(discriminator)
                # else:
                for p in discriminator.parameters():
                    if hasattr(p, "grad_sample"):
                        del p.grad_sample

                if self.delta is None:
                    self.delta = 1 / (
                        train_data.shape[0] * np.sqrt(train_data.shape[0])
                    )

                epsilon, best_alpha = optimizerD.privacy_engine.get_privacy_spent(
                    self.delta
                )

                self.epsilon_list.append(epsilon)
                self.alpha_list.append(best_alpha)
                if self.epsilon < epsilon:
                    if self._epochs == 1:
                        raise ValueError(
                            "Inputted epsilon and sigma parameters are too small to"
                            + " create a private dataset. Try increasing either parameter "
                            + "and rerunning."
                        )
                    else:
                        break
            else:
                epsilon = 0
                best_alpha = 0 
            maxi_maxi_norm = 0
            for id_ in range(steps_per_epoch):
                for d_id_ in range(dsteps_per_epoch):
                    
                    ############################
                    # New test line
                    ###########################
                    set_requires_grad(self.kldiv, requires_grad = True)
                    set_requires_grad(self._generator, requires_grad = False)
                    ##############################
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real, ix = self._data_sampler.sample_data(self._batch_size, col, opt)
                        #real = real_and_noise[:,:self.data_dim]
                        X_noise = noise_data[ix,:] # real_and_noise[:,self.data_dim:]
                        X_noise_kmm = noise_data_kmm[:,ix,:]
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real,ix = self._data_sampler.sample_data(
                            self._batch_size, col[perm], opt[perm]
                        )
                
                        c2 = c1[perm]
                        #real = real_and_noise[:,:(self.data_dim-c2.shape[1])]
                        X_noise = noise_data[ix,:]
                        X_noise_kmm = noise_data_kmm[:,ix,:]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype("float32")).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact
                    optimizerD.zero_grad()
                    if self.loss != 'SWD-DP' and self.loss != 'kmm':
                        if self.loss == "cross_entropy":
                            y_fake = discriminator(fake_cat)

                            #   print ('y_fake is {}'.format(y_fake))
                            label_fake = torch.full(
                                (int(self._batch_size / self.pac),),
                                fake_label,
                                dtype=torch.float,
                                device=self._device,
                            )

                            #    print ('label_fake is {}'.format(label_fake))

                            error_d_fake = criterion(y_fake.squeeze(), label_fake)
                            error_d_fake.backward()
                            optimizerD.step()

                            # train with real
                            label_true = torch.full(
                                (int(self._batch_size / self.pac),),
                                real_label,
                                dtype=torch.float,
                                device=self._device,
                            )
                            y_real = discriminator(real_cat)
                            error_d_real = criterion(y_real.squeeze(), label_true)
                            error_d_real.backward()
                            optimizerD.step()

                            loss_d = error_d_real + error_d_fake

                        elif self.loss == 'sfd':
                            X = real_cat.view(real_cat.shape[0], -1)
                            Y = fake_cat.view(fake_cat.shape[0], -1)


                            error_d = self.kldiv.forward(X,Y)
                            error_d.backward()
                            optimizerD.step()

                            loss_d = self.wd_clf * error_d
                        elif self.loss == 'sfd-dp':
                            maxi_norm = torch.sqrt(torch.max(torch.sum(real_cat.view(real_cat.shape[0], -1)**2,dim=1))).to(self._device)
                            source_features_norm = torch.div(real_cat,1 ) #(2*maxi_norm))
                            target_features_norm = torch.div(fake_cat,1 ) #(2*maxi_norm))
                            X = source_features_norm.view(source_features_norm.shape[0], -1)
                            Y = target_features_norm.view(target_features_norm.shape[0], -1)

                             #   KL_est = self.kldiv.forward(X,Y,sigma_noise = self.sigma_noise)
                             #   wasserstein_distance = KL_est
                             #   wasserstein_distance = torch.add(wasserstein_distance,torch.log(2*maxi_norm))

                            #Y = real_cat.view(real_cat.shape[0], -1)
                            #X = fake_cat.view(fake_cat.shape[0], -1)

                            error_d = self.kldiv.forward(X,Y,X_noise,sigma_noise = self.sigma_noise * 2*maxi_norm)
                            #error_d = torch.add(error_d, -torch.log(2*maxi_norm))
                            error_d.backward()
                            optimizerD.step()

                            loss_d = self.wd_clf * error_d
                        else:

                            y_fake = discriminator(fake_cat)
                            mean_fake = torch.mean(y_fake)
                            mean_fake.backward(one)

                            y_real = discriminator(real_cat)
                            mean_real = torch.mean(y_real)
                            mean_real.backward(mone)

                            optimizerD.step()

                            loss_d = -(mean_real - mean_fake)
                    else:
                        loss_d = 0

                #max_grad_norm = []
                #for p in self.kldiv.parameters():
                #    param_norm = p.grad.data.norm(2).item()
                #    max_grad_norm.append(param_norm)
                # pen = calc_gradient_penalty(discriminator, real_cat, fake_cat, self.device)

                # pen.backward(retain_graph=True)
                # loss_d.backward()
                # optimizer_d.step()
                
                ############################
                # New test line
                ###########################
                set_requires_grad(self.kldiv, requires_grad = False)
                set_requires_grad(self._generator, requires_grad = True)
                ##############################

                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(self._batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if self.loss != 'sfd' and self.loss != 'sfd-dp':
                    if c1 is not None:
                        y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                    else:
                        y_fake = discriminator(fakeact)

                # if condvec is None:
                cross_entropy = 0
                # else:
                #    cross_entropy = self._cond_loss(fake, c1, m1)

                if self.loss == "cross_entropy":
                    label_g = torch.full(
                        (int(self._batch_size / self.pac),),
                        real_label,
                        dtype=torch.float,
                        device=self._device,
                    )
                    # label_g = torch.full(int(self.batch_size/self.pack,),1,device=self.device)
                    loss_g = criterion(y_fake.squeeze(), label_g)
                    loss_g = loss_g + cross_entropy
                elif self.loss == 'sfd':
                    XX = real_cat.view(real_cat.shape[0], -1)
                    if c1 is not None:
                        fakey = torch.cat([fakeact, c1], dim=1)
                        YY = fakey.view(fakey.shape[0], -1)
                    else:
                        YY = fakeact.view(fakeact.shape[0],-1)

                        
                    error_g = self.kldiv.forward(XX,YY)
                    
                    loss_g = -self.wd_clf * error_g
                
                
                elif self.loss == 'sfd-dp':
                    #XX = real_cat.view(real_cat.shape[0], -1)
                    if c1 is not None:
                        fakey = torch.cat([fakeact, c1], dim=1)
                        #YY = fakey.view(fakey.shape[0], -1)
                    else:
                        fakey = fakeact
                        #YY = fakeact.view(fakeact.shape[0],-1)
                        
                    
                    with torch.no_grad():
                        maxi_norm = torch.sqrt(torch.max(torch.sum(real_cat.view(real_cat.shape[0], -1)**2,dim=1))).to(self._device)
                    source_features_norm = torch.div(real_cat,1)#(2*maxi_norm))
                    target_features_norm = torch.div(fakey,1)#(2*maxi_norm))
                    X = source_features_norm.view(source_features_norm.shape[0], -1)
                    Y = target_features_norm.view(target_features_norm.shape[0], -1)


                         #   KL_est = self.kldiv.forward(X,Y,sigma_noise = self.sigma_noise)
                         #   wasserstein_distance = KL_est
                         #   wasserstein_distance = torch.add(wasserstein_distance,torch.log(2*maxi_norm))





                        #Y = real_cat.view(real_cat.shape[0], -1)
                        #X = fake_cat.view(fake_cat.shape[0], -1)

  
                    error_g = self.kldiv.forward(X,Y,X_noise,sigma_noise = self.sigma_noise * 2 * maxi_norm)
                    #error_g = torch.add(error_g, -torch.log(2*maxi_norm))

                        
                    #error_g = self.kldiv.forward(XX,YY)
                    
                    loss_g = -self.wd_clf * error_g
                elif self.loss == 'SWD-DP':
                    #XX = real_cat.view(real_cat.shape[0], -1)
                    if c1 is not None:
                        fakey = torch.cat([fakeact, c1], dim=1)
                        #YY = fakey.view(fakey.shape[0], -1)
                    else:
                        fakey = fakeact
                        #YY = fakeact.view(fakeact.shape[0],-1)
                        
                    #if epoch > self.epoch_to_start_align:
                    with torch.no_grad():
                        maxi_norm = torch.sqrt(torch.max(torch.sum(real_cat.view(real_cat.shape[0], -1)**2,dim=1))).to(self._device)
                        if maxi_norm > maxi_maxi_norm and id_ == 0:
                            maxi_maxi_norm = maxi_norm

                    source_features_norm = torch.div(real_cat,(2*maxi_maxi_norm))
                    target_features_norm = torch.div(fakey,(2*maxi_maxi_norm))
                    wasserstein_distance = sliced_wasserstein_distance_diff_priv(source_features_norm.view(source_features_norm.shape[0], -1),      
                                                                                 target_features_norm.view(target_features_norm.shape[0], -1),
                                                                                 self.kldiv.thetas,
                                                        2,
                                                        self._device,
                                                        sigma_noise=self.sigma_noise, 
                                                                                 noise_samples=X_noise)
                    wasserstein_distance = torch.mul(wasserstein_distance,2*maxi_maxi_norm)

         
                    #source_preds = self.data_classifier(source_features)                
                    #self.criterion = nn.CrossEntropyLoss()
                    #clf_loss = self.criterion(source_preds, source_y)
                    loss_g = self.wd_clf * wasserstein_distance
                    
                elif self.loss == 'kmm':
                    if c1 is not None:
                        fakey = torch.cat([fakeact, c1], dim=1)
                        #YY = fakey.view(fakey.shape[0], -1)
                    else:
                        fakey = fakeact
                        #YY = fakeact.view(fakeact.shape[0],-1)
                        
                    #if epoch > self.epoch_to_start_align:
                    with torch.no_grad():
                        maxi_norm = torch.sqrt(torch.max(torch.sum(real_cat.view(real_cat.shape[0], -1)**2,dim=1))).to(self._device)
                        if maxi_norm > maxi_maxi_norm and id_ == 0:
                            maxi_maxi_norm = maxi_norm

                    source_features_norm = torch.div(real_cat,1) #(2*maxi_norm))
                    target_features_norm = torch.div(fakey,1) #(2*maxi_norm))
                    
                    kl = sliced_kl_kmm_diff_priv(source_features_norm.view(source_features_norm.shape[0], -1), 
                                                target_features_norm.view(target_features_norm.shape[0], -1),
                                                self.kldiv.thetas,
                                                2,
                                                self._device,
                                                sigma_noise= self.sigma_noise * 2 * maxi_maxi_norm, 
                                                noise_samples=X_noise_kmm,
                                                n_slice=self.n_KL_slices, 
                                                slice_dim=self.n_KL_slice_dim)
                    loss_g = self.wd_clf * kl
                else:
                    loss_g = -torch.mean(y_fake) + cross_entropy
                              
                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            self.loss_d_list.append(loss_d)
            self.loss_g_list.append(loss_g)
            if self.verbose:
                if self.loss == 'SWD-DP' or self.loss == 'kmm':
                    ld = 0
                else:
                    ld = loss_d.detach().cpu()
                print(
                    "sliced KL Epoch %d, Loss G: %.4f, Loss D: %.4f"                
                    % (i + 1, loss_g.detach().cpu(), ld),
                    flush=True,
                )
                print("epsilon is {e}, alpha is {a}".format(e=self.epsilon, a=best_alpha))

        return self.loss_d_list, self.loss_g_list, self.epsilon_list, self.alpha_list

    def generate(self, n, condition_column=None, condition_value=None):
        """
        Draw ``n`` synthetic rows from the generator.

        Latent noise is sampled one batch of ``self._batch_size`` rows at a
        time, optionally concatenated with a conditional vector obtained from
        ``self._data_sampler.sample_condvec``, pushed through
        ``self._generator`` and ``self._apply_activate``, and the first ``n``
        resulting rows are handed to ``self._transformer.inverse_transform``.

        The generator is switched to eval mode and is left in that mode.

        TODO: Add condition_column support from CTGAN

        :param n: Number of rows to return. Must be non-negative.
        :param condition_column: Not supported yet; ignored (a warning is
            emitted when it is provided).
        :param condition_value: Not supported yet; ignored (a warning is
            emitted when it is provided).
        :return: Whatever ``self._transformer.inverse_transform`` returns for
            the first ``n`` generated rows.
        :raises ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative, got {}".format(n))
        if condition_column is not None or condition_value is not None:
            # Conditional generation is still a TODO; tell the caller instead
            # of silently returning unconditional samples.
            warnings.warn(
                "condition_column and condition_value are not supported yet "
                "and will be ignored."
            )

        self._generator.eval()

        # output_info = self._transformer.output_info
        # Number of whole batches needed to cover n rows. At least one batch
        # is always generated so that n == 0 still yields an empty array with
        # the right number of columns for inverse_transform.
        steps = max(1, (n + self._batch_size - 1) // self._batch_size)
        data = []
        # Sampling never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in range(steps):
                # Noise is drawn on the CPU and then moved, which keeps the
                # random stream independent of the target device.
                mean = torch.zeros(self._batch_size, self._embedding_dim)
                std = mean + 1
                fakez = torch.normal(mean=mean, std=std).to(self._device)

                condvec = self._data_sampler.sample_condvec(self._batch_size)
                if condvec is not None:
                    # Only the conditional vector itself is fed to the
                    # generator; the remaining tuple entries are unused here.
                    c1, _m1, _col, _opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)
                data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        # The last batch may overshoot; keep exactly n rows.
        data = data[:n]

        return self._transformer.inverse_transform(data)

    def fit(self, data, *ignore, transformer=None, categorical_columns=None, ordinal_columns=None, continuous_columns=None, preprocessor_eps=0.0, nullable=False):
        """
        Fit the synthesizer on ``data`` by delegating to :meth:`train`.

        All options are forwarded to ``self.train`` unchanged. The column
        lists default to a fresh empty list on every call (rather than a
        shared mutable default), so in-place changes made downstream cannot
        leak into later calls.

        :param data: Training data, passed straight to ``train``.
        :param ignore: Extra positional arguments; accepted and discarded.
        :param transformer: Forwarded to ``train``.
        :param categorical_columns: Forwarded to ``train``; ``None`` means an
            empty list.
        :param ordinal_columns: Forwarded to ``train``; ``None`` means an
            empty list.
        :param continuous_columns: Forwarded to ``train``; ``None`` means an
            empty list.
        :param preprocessor_eps: Forwarded to ``train``.
        :param nullable: Forwarded to ``train``.
        :return: None. The value returned by ``train`` is discarded.
        """
        self.train(
            data,
            transformer=transformer,
            categorical_columns=[] if categorical_columns is None else categorical_columns,
            ordinal_columns=[] if ordinal_columns is None else ordinal_columns,
            continuous_columns=[] if continuous_columns is None else continuous_columns,
            preprocessor_eps=preprocessor_eps,
            nullable=nullable,
        )

    def sample(self, n_samples):
        """
        Produce ``n_samples`` synthetic rows.

        Thin alias that forwards to :meth:`generate` without any conditioning
        arguments.

        :param n_samples: Number of rows to generate.
        :return: Whatever ``generate`` returns for that row count.
        """
        synthetic = self.generate(n_samples)
        return synthetic
