import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import (
    BatchNorm1d,
    Dropout,
    LeakyReLU,
    Linear,
    Module,
    ReLU,
    Sequential,
    Sigmoid,
)

from torch.utils.data import DataLoader, TensorDataset

#from opacus import PrivacyEngine

from snsynth.base import Synthesizer

#from ._generator import Generator
from ._discriminator import Discriminator
from .ctgan.data_sampler import DataSampler
from .slicedKL import slicedKLclass

def sliced_wasserstein_distance_diff_priv(first_samples,
                                          second_samples,
                                          thetas,
                                          p=1,
                                          device='cuda',
                                          sigma_proj=1,
                                          sigma_noise=1,
                                          noise_samples=None
                                          ):
    """Sliced p-Wasserstein distance between noisy 1-D projections of two sample sets.

    Both sets are projected onto the directions ``thetas[:, 0, :]`` and Gaussian
    noise scaled by ``sigma_noise`` is added to every projected coordinate. The
    noise on ``first_samples`` (the data to protect) is the privacy mechanism.

    Args:
        first_samples: (n, d) tensor of real data to protect.
        second_samples: (n, d) tensor of generated data; must have the same number
            of rows as ``first_samples`` (sorted projections are compared pairwise).
        thetas: projection tensor indexed as ``thetas[:, 0, :]`` -> (k, d); k is
            ``thetas.shape[0]``, the number of projection directions.
        p: order of the Wasserstein distance.
        device: device on which noise is created and ``noise_samples`` is placed.
        sigma_proj: unused; kept for interface compatibility.
        sigma_noise: multiplier applied to the standard-normal noise.
        noise_samples: optional pre-drawn standard-normal noise for
            ``first_samples`` with shape (n, k); fresh noise is drawn when None.

    Returns:
        0-d tensor ``(mean_k W_p(slice_k) ** p) ** (1 / p)``.

    Raises:
        ValueError: if the two sample sets have different numbers of rows.
    """
    nb_sample = second_samples.size(0)
    if first_samples.size(0) != nb_sample:
        raise ValueError(
            "first_samples and second_samples must have the same number of rows, "
            f"got {first_samples.size(0)} and {nb_sample}"
        )
    n_proj = thetas.shape[0]
    directions = torch.transpose(thetas[:, 0, :], 0, 1)  # (d, k)

    # Noise on the generated side too, presumably so both projected
    # distributions are smoothed by the same Gaussian. Drawn directly on
    # `device` to avoid a host-to-device copy on every call.
    noise2 = torch.randn((nb_sample, n_proj), device=device) * sigma_noise
    if noise_samples is not None:
        noise = noise_samples.to(device) * sigma_noise
    else:
        noise = torch.randn((nb_sample, n_proj), device=device) * sigma_noise

    first_projections = torch.matmul(first_samples, directions) + noise    # (n, k)
    second_projections = torch.matmul(second_samples, directions) + noise2  # (n, k)

    # In 1-D the optimal coupling matches sorted values, so W_p per slice is the
    # p-mean of |sorted(first) - sorted(second)|.
    sorted_first = torch.sort(first_projections.transpose(0, 1), dim=1)[0]
    sorted_second = torch.sort(second_projections.transpose(0, 1), dim=1)[0]
    per_slice = torch.pow(
        torch.mean(torch.pow(torch.abs(sorted_first - sorted_second), p), dim=1), 1. / p
    )
    # Aggregate over the random directions with the same p-mean.
    return torch.pow(torch.pow(per_slice, p).mean(), 1. / p)


#mmd util start
def euclidsq(x, y):
    """Return the matrix of squared Euclidean distances between rows of x and rows of y."""
    distances = torch.cdist(x, y)
    return distances ** 2

def prepare(x_de, x_nu):
    """Return the squared-distance matrices (de-de, de-nu, nu-nu) via `euclidsq`."""
    pairs = ((x_de, x_de), (x_de, x_nu), (x_nu, x_nu))
    return tuple(euclidsq(left, right) for left, right in pairs)

def gaussian_gramian(esq, σ):
    """Gaussian kernel values exp(-esq / (2 σ^2)) from a tensor of squared distances."""
    denom = 2 * σ ** 2
    return (-esq / denom).exp()

# KMM solve strategy: True -> torch.linalg.solve, False -> explicit matrix inverse.
USE_SOLVE = True
# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def kmm_ratios(Kdede, Kdenu, λ):
    """Closed-form kernel-mean-matching weights for the ``de`` samples.

    Solves ``(Kdede + λ I) r = (n_de / n_nu) * Kdenu @ 1`` for ``r``, either with
    ``torch.linalg.solve`` (``USE_SOLVE``) or with an explicit inverse. Both
    paths return the same quantity.

    Args:
        Kdede: (n_de, n_de) Gram matrix among the ``de`` samples.
        Kdenu: (n_de, n_nu) Gram matrix between ``de`` and ``nu`` samples.
        λ: ridge term added to the diagonal of ``Kdede`` when positive.

    Returns:
        (n_de, 1) tensor with one ratio estimate per ``de`` sample.
    """
    n_de, n_nu = Kdenu.shape
    if λ > 0:
        # Create the identity where the Gram matrix lives (device and dtype), so
        # this works regardless of the module-level default device.
        A = Kdede + λ * torch.eye(n_de, device=Kdede.device, dtype=Kdede.dtype)
    else:
        A = Kdede
    scale = n_de / n_nu
    if USE_SOLVE:
        B = torch.sum(Kdenu, 1, keepdim=True)
        return scale * torch.linalg.solve(A, B)
    # Explicit-inverse variant: must carry the same n_de / n_nu factor as the
    # solve path to be equivalent to it.
    ones = torch.ones(n_nu, 1, device=Kdenu.device, dtype=Kdenu.dtype)
    return scale * torch.matmul(torch.matmul(torch.inverse(A), Kdenu), ones)

# Ridge term λ passed to kmm_ratios by estimate_ratio_compute_mmd.
eps_ratio = 1e-4
# Flag for clipping negative ratios; the clipping line that reads it is commented out.
clip_ratio = True

def estimate_ratio_compute_mmd(x_de, x_nu, σs=None):
    """Estimate KMM density ratios averaged over several Gaussian kernel widths.

    Args:
        x_de: (n_de, k) samples whose weights are estimated.
        x_nu: (n_nu, k) reference samples; n_nu may differ from n_de.
        σs: iterable/1-D tensor of kernel widths. When None or empty, a median
            heuristic is used: σ = sqrt(median of all pairwise squared distances),
            scaled by a fixed set of multipliers.

    Returns:
        (n_de, 1) tensor: the mean of ``kmm_ratios`` over all widths in ``σs``.
    """
    dsq_dede, dsq_denu, dsq_nunu = prepare(x_de, x_nu)

    if σs is None or len(σs) == 0:
        with torch.no_grad():
            # Median heuristic (Sugiyama): median of all pairwise squared
            # distances. Flattening lets n_de != n_nu; the median is unchanged.
            all_dsq = torch.cat([dsq_dede.reshape(-1), dsq_denu.reshape(-1), dsq_nunu.reshape(-1)])
            sigma = torch.sqrt(torch.median(all_dsq))
            c = 2
            σs = sigma * torch.as_tensor(
                [1 / c, 0.333 / c, 0.2 / c, 5 / c, 3 / c, .1 / c], device=sigma.device
            )

    ratio = None
    for σ in σs:
        # Only the de-de and de-nu Gram matrices are needed for the ratios.
        step = kmm_ratios(gaussian_gramian(dsq_dede, σ), gaussian_gramian(dsq_denu, σ), eps_ratio)
        ratio = step if ratio is None else ratio + step
    return ratio / len(σs)

def sliced_kl_kmm_diff_priv(first_samples,
                            second_samples,
                            thetas,  # n_slice * slice_dim * d
                            p=1,
                            device='cuda',
                            sigma_proj=1,
                            sigma_noise=1,
                            noise_samples=None,
                            n_slice=40,
                            slice_dim=2
                            ):
    """Sliced KL divergence between noisy projections, with ratios estimated by KMM.

    Each sample set is projected onto ``n_slice`` subspaces of dimension
    ``slice_dim`` and Gaussian noise scaled by ``sigma_noise`` is added. The noise
    on ``first_samples`` (the data to protect) is the privacy mechanism. Per slice,
    KMM ratios of first vs. second projections are estimated. The result is
    ``mean(r * log r)`` over all slices and samples, with r clamped at 1e-10.

    Args:
        first_samples: (n, d) tensor of real data to protect.
        second_samples: (n, d) tensor of generated data.
        thetas: (n_slice, slice_dim, d) projection tensor (per the shape comment).
        p, sigma_proj: unused; kept for interface compatibility.
        device: device on which noise, thetas and sigma_noise are placed.
        sigma_noise: scalar or 0-d tensor multiplier for the standard-normal noise.
        noise_samples: optional pre-drawn standard-normal noise for
            ``first_samples`` with shape (n_slice, n, slice_dim).
        n_slice, slice_dim: number and dimension of the projection slices.

    Returns:
        0-d tensor with the sliced KL estimate.
    """
    nb_sample = second_samples.size(0)
    thetas = thetas.to(device)
    # as_tensor accepts the scalar default as well as a tensor.
    sigma_noise = torch.as_tensor(sigma_noise, device=device)

    noise2 = torch.randn((n_slice, nb_sample, slice_dim), device=device) * sigma_noise
    if noise_samples is not None:
        noise = noise_samples.to(device) * sigma_noise
    else:
        noise = torch.randn((n_slice, nb_sample, slice_dim), device=device) * sigma_noise

    # (n, d) @ (n_slice, d, slice_dim) broadcasts to (n_slice, n, slice_dim).
    first_projections = torch.matmul(first_samples, torch.transpose(thetas, 1, 2)) + noise
    second_projections = torch.matmul(second_samples, torch.transpose(thetas, 1, 2)) + noise2

    # One set of kernel widths shared across all slices: median heuristic
    # (Sugiyama) on the squared norms of the noisy real projections.
    with torch.no_grad():
        proj_norms = torch.linalg.vector_norm(first_projections, dim=2)  # (n_slice, n)
        sigma = torch.sqrt(torch.median(proj_norms ** 2))
        c = 2
        σs = sigma * torch.as_tensor([1 / c, 0.333 / c, 0.2 / c, 5 / c, 3 / c, .1 / c], device=device)

    def ratio_func(first, second):
        return estimate_ratio_compute_mmd(first, second, σs=σs)

    # vmap over the slice dimension; each call returns (n, 1) ratios.
    vec_mmd = torch.vmap(ratio_func, in_dims=0, out_dims=0)
    ratio = vec_mmd(first_projections, second_projections).view(n_slice, -1)

    # Keep log() finite for non-positive ratio estimates.
    ratio = torch.clamp(ratio, min=1e-10)
    return torch.mean(ratio * torch.log(ratio))
#mmd util end

class Residual(Module):
    """Dense residual unit: Linear -> BatchNorm1d -> ReLU, concatenated with its input.

    The output has ``o + i`` features: the ``o`` transformed features first,
    followed by the unchanged input.
    """

    def __init__(self, i, o):
        super().__init__()
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input):
        hidden = self.relu(self.bn(self.fc(input)))
        return torch.cat([hidden, input], dim=1)
class Generator(Module):
    """Stack of `Residual` units followed by a Linear projection to ``data_dim``.

    Each `Residual` concatenates its input, so the width grows by every entry of
    ``generator_dim`` before the final Linear layer.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        super().__init__()
        layers = []
        width = embedding_dim
        for hidden in generator_dim:
            layers.append(Residual(width, hidden))
            width += hidden
        layers.append(Linear(width, data_dim))
        self.seq = Sequential(*layers)

    def forward(self, input):
        return self.seq(input)

class SFDPGAN(Synthesizer):
    """Sliced-divergence differentially private generator ("sliced private GAN").

    Training is not adversarial: only the generator is optimised, by minimising a
    sliced divergence between random projections of a real batch and of a
    generated batch. Each real record carries fixed standard-normal noise (drawn
    once per ``train`` call and looked up by row id) that is scaled and added to
    its projections. The scale is calibrated from ``(epsilon, delta)`` in
    :meth:`train`.

    ``loss="kmm"`` uses ``sliced_kl_kmm_diff_priv``; any other value uses
    ``sliced_wasserstein_distance_diff_priv``.
    """

    def __init__(
        self,
        binary=False,
        latent_dim=64,
        generator_dim=(128, 128),
        batch_size=64,
        epochs=1000,
        delta=None,
        epsilon=1.0,
        lr=1e-4,
        loss="kmm",
        n_KL_slices=100,
        n_KL_slice_dim=2,
    ):
        self.binary = binary  # stored only; not read by this class
        self.latent_dim = latent_dim
        self.generator_dim = generator_dim
        self.batch_size = batch_size
        self.epochs = epochs
        self.delta = delta
        self.epsilon = epsilon
        self.lr = lr
        self.loss = loss
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.pd_cols = None
        self.pd_index = None

        self.n_KL_slices = n_KL_slices
        self.n_KL_slice_dim = n_KL_slice_dim

        self.wd_clf = 1  # weight applied to the divergence to form the generator loss
        print('sliced private GAN')

    @staticmethod
    def _calibrate_noise_sigma(epsilon, delta, mprime, d):
        """Pick the noise multiplier sigma for a target (epsilon, delta).

        Iteratively refines an order ``alpha`` and ``sigma``. The
        ``log(1/delta) / (alpha - 1)`` term suggests a Rényi-DP style conversion.

        Args:
            epsilon, delta: target privacy parameters.
            mprime: number of projected coordinates per record.
            d: data dimension.

        Returns:
            ``(sigma, epsilon_actual)``: the chosen multiplier and the epsilon it
            implies under the formula below.
        """
        log_inv_delta = np.log(1 / delta)
        # First, try the approx-opt alpha expression to guess sigma.
        aa = mprime / d
        bb = 2 * np.sqrt(mprime * log_inv_delta / d)
        cc = epsilon
        sigma = ((-bb + np.sqrt(bb**2 + 4 * aa * cc)) / (2 * aa)) ** (-1)  # real and positive

        for _ in range(10):
            # Implied approximately optimal alpha.
            alpha = 1 + np.sqrt(sigma**2 * d * log_inv_delta / mprime)
            # Clamp alpha to the allowed range (quadratic formula, always real).
            if (alpha**2 - alpha) > d * sigma**2 / 2:
                alpha = 0.5 * (1 + np.sqrt(1 + 2 * d * sigma**2))
            # Recompute sigma using the exact epsilon formula.
            val = 2 * d * (epsilon - log_inv_delta / (alpha - 1)) / (mprime * alpha)
            while val <= 0:
                # NOTE(review): this bound uses d*sigma^2 while the clamp above uses
                # d*sigma^2/2; confirm which constraint is intended.
                if ((1.2 * alpha) ** 2 - 1.2 * alpha) > d * sigma**2:
                    val = 0.001
                    print('WARNING: unable to find a valid sigma to achieve specified epsilon, delta combination. Using a large sigma.')
                    break
                alpha *= 1.2
                val = 2 * d * (epsilon - log_inv_delta / (alpha - 1)) / (mprime * alpha)
            val2 = (alpha**2 - alpha) / d
            sigma = np.sqrt((1 / val) + val2)  # satisfies the constraint when val > 0

        epsilon_actual = mprime * alpha / (2 * sigma**2 * (d - (alpha**2 - alpha) * sigma**(-2))) + log_inv_delta / (alpha - 1)
        return sigma, epsilon_actual

    def train(
        self,
        data,
        categorical_columns=None,
        ordinal_columns=None,
        update_epsilon=None,
        transformer=None,
        continuous_columns=None,
        preprocessor_eps=0.0,
        nullable=False,
    ):
        """Preprocess ``data``, calibrate the privacy noise and fit the generator.

        Raises:
            ValueError: if the preprocessed data has fewer rows than
                ``batch_size`` (no full batch exists because ``drop_last=True``).
        """
        if update_epsilon:
            self.epsilon = update_epsilon

        train_data = self._get_train_data(
            data,
            style='gan',
            transformer=transformer,
            categorical_columns=categorical_columns,
            ordinal_columns=ordinal_columns,
            continuous_columns=continuous_columns,
            nullable=nullable,
            preprocessor_eps=preprocessor_eps,
        )

        data = np.array(train_data)
        n_rows = data.shape[0]
        self.data_dim = data.shape[1]

        # Fixed per-record standard-normal noise, looked up by row id each batch
        # so that a record receives the same noise in every epoch.
        noise_data = torch.randn(n_rows, self.n_KL_slices * self.n_KL_slice_dim).to(self.device)
        noise_data_kmm = torch.randn(self.n_KL_slices, n_rows, self.n_KL_slice_dim).to(self.device)

        # Row ids travel as their own integer tensor so they stay exact for any
        # dataset size (a float32 column loses precision above 2**24 rows).
        dataset = TensorDataset(
            torch.arange(n_rows, device=self.device),
            torch.from_numpy(data.astype("float32")).to(self.device),
        )
        dataloader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, drop_last=True
        )
        if len(dataloader) == 0:
            raise ValueError(
                f"Need at least batch_size={self.batch_size} rows to train; got {n_rows}."
            )

        # ---- privacy calibration ----
        # NOTE(review): preprocessor_eps is not deducted from epsilon here;
        # confirm whether _get_train_data accounts for it.
        epsilon = self.epsilon
        if self.delta is None:
            # Default scales with the dataset size. NOTE(review): a delta well
            # below 1/n (e.g. 1/(n*sqrt(n))) is often recommended.
            self.delta = 1 / n_rows
        delta = self.delta
        mprime = self.n_KL_slice_dim * self.n_KL_slices  # projected coordinates per record
        d = self.data_dim
        sigma, epsilon_actual = self._calibrate_noise_sigma(epsilon, delta, mprime, d)
        print('User specified (epsilon, delta)=(' + str(epsilon) + ',' + str(delta) + '); Chosen sigma = ' + str(sigma) + '; Actual epsilon = ' + str(epsilon_actual))
        self.sigma_noise = sigma

        self.generator = Generator(
            self.latent_dim, self.generator_dim, self.data_dim,
        ).to(self.device)
        self.kldiv = slicedKLclass(d, self.n_KL_slices, self.n_KL_slice_dim, self.device).to(self.device)
        optimizer_g = optim.Adam(self.generator.parameters(), lr=self.lr)

        # Largest row norm seen during epoch 0; scales the noise / normalisation.
        # NOTE(review): this bound is computed from the private data without
        # noise, and batches early in epoch 0 see a smaller value than later ones.
        maxi_maxi_norm = 0
        for epoch in range(self.epochs):
            for ix, real_data in dataloader:
                self.generator.zero_grad()

                noise = torch.randn(
                    self.batch_size, self.latent_dim, 1, 1, device=self.device
                ).view(-1, self.latent_dim)
                fake_data = self.generator(noise)

                if epoch == 0:
                    with torch.no_grad():
                        maxi_norm = torch.sqrt(torch.max(torch.sum(real_data ** 2, dim=1)))
                        if maxi_norm > maxi_maxi_norm:
                            maxi_maxi_norm = maxi_norm

                if self.loss == 'kmm':
                    kl = sliced_kl_kmm_diff_priv(
                        real_data,
                        fake_data,
                        self.kldiv.thetas,
                        2,
                        self.device,
                        sigma_noise=self.sigma_noise * 2 * maxi_maxi_norm,
                        noise_samples=noise_data_kmm[:, ix, :],
                        n_slice=self.n_KL_slices,
                        slice_dim=self.n_KL_slice_dim,
                    )
                    loss_g = self.wd_clf * kl
                else:
                    # Sliced Wasserstein on data normalised by 2 * max-norm; the
                    # distance is rescaled back afterwards.
                    scale = 2 * maxi_maxi_norm
                    # One noise column per projection direction (thetas.shape[0]).
                    n_proj = self.kldiv.thetas.shape[0]
                    wasserstein_distance = sliced_wasserstein_distance_diff_priv(
                        torch.div(real_data, scale),
                        torch.div(fake_data, scale),
                        self.kldiv.thetas,
                        2,
                        self.device,
                        sigma_noise=self.sigma_noise,
                        noise_samples=noise_data[ix, :n_proj],
                    )
                    loss_g = self.wd_clf * torch.mul(wasserstein_distance, scale)

                loss_g.backward()
                optimizer_g.step()

            print(
                "sliced KL Epoch %d, Loss G: %.4f, Loss D: %.4f"
                % (epoch + 1, loss_g.detach().cpu(), 0),
                flush=True,
            )

    def generate(self, n):
        """Generate ``n`` rows and map them back through the fitted transformer."""
        steps = n // self.batch_size + 1
        data = []
        # Sampling only: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(steps):
                noise = torch.randn(
                    self.batch_size, self.latent_dim, 1, 1, device=self.device
                ).view(-1, self.latent_dim)
                data.append(self.generator(noise).cpu().numpy())

        data = np.concatenate(data, axis=0)[:n]
        return self._transformer.inverse_transform(data)

    def fit(self, data, *ignore, transformer=None, categorical_columns=[], ordinal_columns=[], continuous_columns=[], preprocessor_eps=0.0, nullable=False):
        """Alias for :meth:`train` with the Synthesizer ``fit`` signature."""
        self.train(data, transformer=transformer, categorical_columns=categorical_columns, ordinal_columns=ordinal_columns, continuous_columns=continuous_columns, preprocessor_eps=preprocessor_eps, nullable=nullable)

    def sample(self, n_samples):
        """Alias for :meth:`generate`."""
        return self.generate(n_samples)
