import warnings
import numpy as np
import scipy.stats
import scipy.special
from sklearn.utils import check_random_state
from sklearn.utils.validation import check_is_fitted, check_X_y, check_array, NotFittedError
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.base import BaseEstimator
import importlib

import pykliep2 
import pykmm 
import adapt
from adapt.instance_based import KLIEP
importlib.reload(pykliep2)
importlib.reload(pykmm)
import adapt.metrics
importlib.reload(adapt.metrics)
from adapt.instance_based import KMM
from adapt.metrics import make_uda_scorer, neg_j_score
from sklearn.linear_model import Ridge

class PropertySplitter(BaseEstimator):
  '''Property-based fold splitter with importance-weight estimation.

  Split values are mapped to u in (0, 1) with a randomised empirical CDF and
  every sample is assigned to exactly one test fold by sampling from
  p(s | u).  Those conditionals come from ``n_splits * sharpness_scale``
  Beta densities (``sharpness_scale`` of them summed per fold), blended with
  a uniform term of weight ``epsilon_base``.

  Parameters
  ----------
  n_splits : int
    Number of folds.
  sharpness_scale : int
    Number of Beta components summed per fold.
  epsilon_base : float
    Mixing weight in [0, 1] of the uniform term.
  random_state : None, int or RandomState
    Forwarded to ``check_random_state`` for the rank shuffling and the
    fold sampling.

  Attributes (available once ``split`` has been iterated, or after ``fit``)
  ----------
  density_ : KDE fitted on the split values.
  z_splits_ : copy of the split values.
  u_splits_ : projection of the split values onto (0, 1).
  splits_ : list of ``(train_indices, test_indices)`` pairs.
  '''

  def __init__(self, n_splits=4, sharpness_scale=10, epsilon_base=0.01,  random_state=None):
    assert n_splits == int(n_splits), 'Should be integer'
    assert sharpness_scale == int(sharpness_scale)
    assert epsilon_base >= 0 and epsilon_base <= 1, 'epsilon_base should be between 0 and 1'

    # Setup empirical quantiles
    # Beta(r, n+1-r) where r is the rank: r=n gives Beta(n,1), r=1 gives Beta(1,n)
    # See Equation 12 in https://arxiv.org/pdf/1905.12466.pdf
    r = np.arange(n_splits)
    shift = np.arange(sharpness_scale).reshape(-1, 1) + 1
    # Broadcasting yields a (sharpness_scale, n_splits) grid of Beta parameters
    param1 = sharpness_scale * r + shift
    param2 = n_splits * sharpness_scale + 1 - param1

    # Save beta distributions
    #  (note: these are the same regardless of the data; they are built only
    #  here, so changing n_splits/sharpness_scale afterwards does not rebuild them)
    self._beta_dists = scipy.stats.beta(param1, param2)

    # Public properties
    self.n_splits = n_splits
    self.sharpness_scale = sharpness_scale
    self.epsilon_base = epsilon_base
    self.random_state = random_state

  def get_n_splits(self, X=None, y=None, groups=None):
    '''Return the number of folds.

    X, y and groups are accepted (and ignored) so the method can be called
    with the same arguments as ``split``.
    '''
    return self.n_splits

  def fit(self, Z_all=None, z_split=None):
    '''Compute and store the splits for the given data; returns self.'''
    # split() is a generator: its body (which stores density_, z_splits_,
    # u_splits_ and splits_) only executes once it is iterated.
    for _ in self.split(Z_all, z_split):
      pass
    return self

  def split(self, Z_all=None, z_split=None, groups=None):
    '''Yield one ``(train_indices, test_indices)`` pair per fold.

    This is a generator, so nothing is computed or stored until it is
    iterated.  The fitted attributes are stored before the first pair is
    yielded.  ``groups`` is not used.
    '''
    Z_all, z_split = check_X_y(Z_all, z_split)

    u = _empirical_cdf_projection(z_split, shuffle=True, random_state=self.random_state)

    split_masks = self._beta_split(
      u, n_splits = self.n_splits, sharpness_scale = self.sharpness_scale,
      epsilon_base = self.epsilon_base)

    indices = np.arange(z_split.shape[0])
    splits = [
      (indices[train_mask], indices[test_mask])
      for train_mask, test_mask in split_masks
    ]

    self.density_ = _compute_kde(z_split) #density of split feature
    self.z_splits_ = z_split.copy() #All split values
    self.u_splits_ = u #ecdf of split values
    self.splits_ = splits #calculated splits

    # Generate train_idx/test_idx pairs
    yield from splits

  def sample_weights(self, z_gen=None, z_held=None, split_index=None, return_n_samples_effective=False, estimation_method=None, proposal_density=None, z_gen_init=None, C=1):
    '''
    Computes sample weights for z_gen with respect to the fold `split_index`.

    For estimation_method == "kde" a proposal density q(z) is needed: a
    caller-supplied `proposal_density` is used as is; otherwise it is fitted
    on `z_gen_init` when given, else on `z_gen`.

    Returns `(weights, proposal_density)`, or
    `(weights, n_samples_effective, proposal_density)` when
    `return_n_samples_effective` is True.  `proposal_density` is returned
    unchanged (possibly None) for the non-"kde" methods.
    '''
    check_is_fitted(self, ['density_', 'z_splits_', 'u_splits_','splits_'])

    if split_index is None:
      warnings.warn('split_index was None, setting to 0 but it should be explicitly specified')
      split_index = 0

    # Only fit a proposal density when the caller did not provide one, so the
    # density that is returned is the one actually used for the weights.
    if estimation_method == "kde" and proposal_density is None:
      if z_gen_init is not None:
        print("Calculating the proposal density using KDE based on initial samples")
        proposal_density = _compute_kde(z_gen_init)  # q(z)
      else:
        print("Calculating the proposal density using KDE based on all samples")
        proposal_density = _compute_kde(z_gen)

    weights, n_samples_effective = self._test_weights(
      z_gen, z_held, split_index, proposal_density, estimation_method=estimation_method,
      n_splits=self.n_splits, sharpness_scale=self.sharpness_scale,
      epsilon_base=self.epsilon_base, C = C)

    if return_n_samples_effective:
      return  weights, n_samples_effective, proposal_density
    return weights, proposal_density

  def ks_evaluate_split(self, Z_gen, z_gen, Z_held, z_held, split_index,
    estimation_method, proposal_density=None, w_samples=None, n_samples_effective=None, C = 1):
    '''Weighted KS test of every column of Z_gen against Z_held.

    The generated samples are weighted by `w_samples` (computed through
    `sample_weights` when not supplied); the held-out samples get unit
    weights.  Returns a list with one `(statistic, p_value)` pair per column.
    '''
    Z_gen, z_gen = check_X_y(Z_gen, z_gen)
    Z_held, z_held = check_X_y(Z_held, z_held)

    if w_samples is None:
      # sample_weights takes only the split values (not the feature matrix)
      # and returns (weights, n_samples_effective, proposal_density).
      w_samples, n_samples_effective, _ = self.sample_weights(
        z_gen, z_held, split_index=split_index,
        return_n_samples_effective=True,
        estimation_method=estimation_method,
        proposal_density=proposal_density, C = C
      )
    elif n_samples_effective is None:
      # Weights were supplied without their effective sample size: derive it
      # with the same formula _test_weights uses.
      w_samples = np.asarray(w_samples)
      n_samples_effective = np.sum(w_samples)**2 / np.sum(w_samples**2)

    # Loop through each feature of Z and compute test statistic / p-value
    return [
      _ks_weighted(
        z_col_gen, z_col_held,
        w_samples * n_samples_effective, np.ones_like(z_held.ravel()))
      for z_col_gen, z_col_held in zip(Z_gen.T, Z_held.T)
    ]

  # Define the Kernel Mean Matching (KMM) function
  # from: https://medium.com/@evertongomede/kernel-mean-matching-kmm-a-powerful-technique-for-domain-adaptation-abb01d53f75e
  def _kmm(self,source_data, target_data, lambd=1.0, kernel_width=1.0, num_iterations=1000):
    '''Iteratively rescale uniform source weights using RBF kernel matrices.

    NOTE(review): this method is not called anywhere in this class.  `lambd`
    and `K_ss` are never used, the kernel matrices are recomputed unchanged
    on every iteration, and `weights @ K_st @ weights` only has compatible
    shapes when n_source == n_target (assuming 1-D inputs) -- confirm the
    intended algorithm before relying on it.
    '''
    n_source = len(source_data)
    n_target = len(target_data)
    
    # Initialize weights for the source data
    weights = np.ones(n_source) / n_source

    for _ in range(num_iterations):
        # Compute kernel matrices
        K_ss = np.exp(-((source_data[:, np.newaxis] - source_data)**2) / (2 * kernel_width**2))
        K_st = np.exp(-((source_data[:, np.newaxis] - target_data)**2) / (2 * kernel_width**2))
        
        # Update weights
        weights = weights * np.sqrt(n_target / (weights @ K_st @ weights))
    
    return weights

  def _log_p_fold_given_z(self, z_gen, fold_idx):
    '''Return log p(s = fold_idx | z) for every value in z_gen.'''
    # Project z_gen into (0, 1) relative to the stored split values, then
    # use p(s|z) = p(s|u).
    test_u = _project_test(z_gen, self.z_splits_)
    return self._compute_cond_probs(test_u, use_log_prob=True)[:, fold_idx]

  def _test_weights(self, z_gen, z_held, fold_idx, proposal_density = None,
    estimation_method = None, n_splits = 4, sharpness_scale = 10, epsilon_base = 0.01, C = 1):
    '''Compute importance weights for z_gen and their effective sample size.

    estimation_method selects the estimator:
      "kde"         p(z|s) / q(z) with KDE densities; q is `proposal_density`
                    (fitted on z_gen when None)
      "kmm1"        pykmm.KMM
      "kmm2"        adapt KMM with gamma = C / std(z_held)
      "kliep1"      pykliep2.DensityRatioEstimator
      "kliep2"      adapt KLIEP
      "kliep_mult"  "kliep1" weights times p(s|z) / p(s)
      "kliep2_mult" "kliep2" weights times p(s|z) / p(s)
      "bratio"      p(s|z) / p(s) only

    `n_splits`, `sharpness_scale` and `epsilon_base` are accepted for
    interface compatibility; the values stored on self are what is used.

    Returns `(weights, n_samples_effective)`.
    Raises ValueError for an unrecognised estimation_method.
    '''
    log_p_s = -np.log(self.n_splits) # Equally probable splits

    if estimation_method == "kde":
      log_p_s_given_z = self._log_p_fold_given_z(z_gen, fold_idx)
      # Density of the generated samples under the model fitted on the
      # original split values
      log_p_z = self.density_.score_samples(z_gen.reshape(-1,1))
      log_p_z_given_s = log_p_z + log_p_s_given_z - log_p_s  # p(z|s) = (p(z) p(s|z)) / p(s)

      # Proposal density q(z): honour the one handed in, fall back to a KDE
      # of the generated samples
      if proposal_density is None:
        proposal_density = _compute_kde(z_gen)
      log_q_z = proposal_density.score_samples(z_gen.reshape(-1, 1))

      weights = np.exp(log_p_z_given_s - log_q_z)

    elif estimation_method == "kmm1":
      kmm = pykmm.KMM(kernel_type='rbf', B = 1)
      weights = kmm.fit(z_gen.reshape(-1, 1),z_held.reshape(-1, 1))

    elif estimation_method == "kmm2":
      kmm = KMM(
        estimator=Ridge(),
        Xt=z_held,
        kernel="rbf",  # Gaussian kernel
        gamma=C/np.std(z_held.reshape(-1, 1)),     # Bandwidth of the kernel
        verbose=0,
        random_state=0
      )
      weights = kmm.fit_weights(z_gen.reshape(-1, 1), z_held.reshape(-1, 1))

    elif estimation_method == "kliep1":
      # https://github.com/srome/pykliep
      # fit(X_train, X_test) then predict(X_train) gives the weights of
      # X_train w.r.t. X_test
      kliep = pykliep2.DensityRatioEstimator(random_state = 42, verbose = 0)
      kliep.fit(z_gen.reshape(-1,1), z_held.reshape(-1,1))
      weights = kliep.predict(z_gen.reshape(-1,1))

    elif estimation_method == "kliep2":
      kliep = KLIEP(kernel="rbf", gamma=[10**(i-4) for i in range(5)], random_state=0,verbose=0)
      weights = kliep.fit_weights(z_gen.reshape(-1, 1), z_held.reshape(-1, 1))

    elif estimation_method == "kliep_mult":
      log_p_s_given_z = self._log_p_fold_given_z(z_gen, fold_idx)
      kliep = pykliep2.DensityRatioEstimator(random_state = 42, verbose = 0)
      kliep.fit(z_gen.reshape(-1,1), z_held.reshape(-1,1))
      weights = kliep.predict(z_gen.reshape(-1,1)) * np.exp(log_p_s_given_z)/np.exp(log_p_s)

    elif estimation_method == "kliep2_mult":
      log_p_s_given_z = self._log_p_fold_given_z(z_gen, fold_idx)
      kliep = KLIEP(kernel="rbf", gamma=[10**(i-4) for i in range(5)], random_state=0,verbose=0)
      weights = kliep.fit_weights(z_gen.reshape(-1, 1), z_held.reshape(-1, 1)) * np.exp(log_p_s_given_z)/np.exp(log_p_s)

    elif estimation_method == "bratio":
      log_p_s_given_z = self._log_p_fold_given_z(z_gen, fold_idx)
      weights = np.exp(log_p_s_given_z)/np.exp(log_p_s)

    else:
      # Without this branch `weights` would be unbound below.
      raise ValueError(
        f"Unknown estimation_method {estimation_method!r}; expected one of "
        "'kde', 'kmm1', 'kmm2', 'kliep1', 'kliep2', 'kliep_mult', "
        "'kliep2_mult', 'bratio'")

    # Effective sample size: (sum w)^2 / sum(w^2), i.e. 1/S_2
    n_samples_effective = np.sum(weights)**2 / np.sum(weights**2)

    return weights, n_samples_effective

  def _compute_cond_probs(self, u, use_log_prob=False):
    '''Return p(s | u) for every u, as an array of shape (n_samples, n_splits).

    Log-probabilities are returned instead when `use_log_prob` is True.
    '''
    # Unpack parameters from self
    n_splits, sharpness_scale, epsilon_base = (
      self.n_splits,
      self.sharpness_scale,
      self.epsilon_base
    )
    sharpness_scale = int(sharpness_scale)

    # Compute conditional probabilities in log space for numerical stability
    log_probs = self._beta_dists.logpdf(u.reshape(-1, 1, 1)) # shape (n_samples, sharpness_scale, n_splits)
    log_probs = scipy.special.logsumexp(log_probs, axis=-2) # Sum over sharpness, new shape (n_samples, n_splits)
    # Normalize across splits to obtain conditionals
    log_cond_probs = log_probs - scipy.special.logsumexp(log_probs, axis=1).reshape(-1,1)
    # There are n_splits * sharpness_scale beta components and an even
    #  mixture of them is the uniform distribution, so dividing by the number
    #  of components must give the same values as the normalization above.
    log_marg_probs = log_probs - np.log(n_splits * sharpness_scale)
    assert np.allclose(log_cond_probs, log_marg_probs), 'Normalization seems off'

    # Blend with a uniform term:
    #  (1 - epsilon_base) * cond_probs + epsilon_base * uniform_pdf
    uniform_pdf = np.ones_like(log_cond_probs) # Uniform pdf is just 1 everywhere between 0 and 1
    # log(0) for epsilon_base in {0, 1} is intended here, hence the errstate
    with np.errstate(divide='ignore'):
      log_cond_probs = scipy.special.logsumexp(np.array([
        np.log1p(-epsilon_base) + log_cond_probs,
        np.log(epsilon_base) + np.log(uniform_pdf)
      ]), axis=0)

    # Renormalize to get true conditional probabilities
    log_cond_probs = log_cond_probs - scipy.special.logsumexp(log_cond_probs, axis=1).reshape(-1,1)
    if use_log_prob:
      return log_cond_probs
    return np.exp(log_cond_probs)

  def _beta_split(self, u, n_splits=4, sharpness_scale=10, epsilon_base=0.01):
    '''Sample one test fold per point from p(s | u).

    Returns a boolean array of shape (n_splits, 2, len(u)); index 0 of the
    middle axis is the train mask and index 1 the test mask of each fold.
    The conditional probabilities come from the parameters stored on self;
    `n_splits` only sizes the output array.
    '''
    cond_probs = self._compute_cond_probs(u)

    rng = check_random_state(self.random_state)
    # One multinomial draw per point picks the fold it is tested in
    test_masks = np.array([
      rng.multinomial(1, prob) == 1
      for prob in cond_probs
    ])
    train_masks = ~test_masks

    # Setup splits
    splits = np.empty((n_splits, 2, len(u)), dtype=bool)
    splits[:,0,:] = train_masks.T
    splits[:,1,:] = test_masks.T
    assert splits[:,1,:].sum() == len(u), 'All test splits should be all points'
    assert np.all(splits[:,0,:].sum(axis=0) == n_splits - 1), 'Each point should be in n_splits-1 train splits'
    assert np.all(splits[:,1,:].sum(axis=0) == 1), 'Each point should only be in one test split'
    return splits

def _compute_kde(a, max_samples=None, random_state=0):
  '''Fit a KernelDensity model (default kernel) to the samples in `a`.

  When `max_samples` is given, a shuffled copy of `a` (shuffled with
  `random_state`) is truncated to at most that many samples first; the
  caller's array is left untouched.  The bandwidth is the Scott factor
  reported by scipy's gaussian_kde for the samples.  Returns the fitted
  estimator.
  '''
  if max_samples is not None:
    shuffled = np.asarray(a).copy()
    check_random_state(random_state).shuffle(shuffled)
    a = shuffled[:min(len(shuffled), max_samples)]

  bandwidth = scipy.stats.gaussian_kde(a).scotts_factor()
  estimator = KernelDensity(bandwidth=bandwidth)
  # KernelDensity expects a 2-D (n_samples, n_features) array
  estimator.fit(a.reshape(-1, 1))
  return estimator


def _project_test(test_x, base_x, return_nearest=False):
  '''Project test_x into (0, 1) using the empirical CDF of base_x.

  Every test point is snapped to its nearest unique value of base_x.  Test
  points that share the same nearest value are spread evenly (in sorted
  order) inside the ECDF step belonging to that value, which handles ties
  and discrete values.

  Parameters: test_x is a 1-D array-like of points to project (may be
  empty); base_x holds the reference values and must be non-empty.

  Returns test_u (same order as test_x), plus the nearest base value of each
  test point when `return_nearest` is True.
  '''
  def find_closest(A, target):
    # From https://stackoverflow.com/questions/8914491/finding-the-nearest-value-and-return-the-index-of-array-in-python
    # A must be sorted. Targets outside the range of A map to the first/last index.
    if len(A) == 1:
      # With a single candidate the clip below would leave idx - 1 == -1,
      # which wraps around; every target simply maps to index 0.
      return np.zeros(np.shape(target), dtype=np.intp)
    idx = A.searchsorted(target)
    idx = np.clip(idx, 1, len(A)-1)
    left = A[idx-1]
    right = A[idx]
    # Step back by one where the left neighbour is strictly closer
    closer_to_left = target - left < right - target
    return idx - closer_to_left.astype(idx.dtype)

  test_x = np.asarray(test_x)

  # Sort test points but keep the inverse permutation for the end
  sort_idx = np.argsort(test_x)
  inverse_sort_idx = np.argsort(sort_idx)
  sorted_test_x = test_x[sort_idx]

  # Nearest unique base value for each (sorted) test point
  unique_base, counts_base = np.unique(base_x, return_counts=True)
  nearest_ind = find_closest(unique_base, sorted_test_x)
  ind_nearest_test = unique_base[nearest_ind]

  # Group the test points by the base value they were snapped to.  Because
  # the test points are sorted, each group is a contiguous run.
  unique_nearest_ind, counts_nearest_ind = np.unique(nearest_ind, return_counts=True)
  group_size = np.repeat(counts_nearest_ind, counts_nearest_ind)
  group_start = np.repeat(np.cumsum(counts_nearest_ind) - counts_nearest_ind, counts_nearest_ind)

  # Relative position of each point inside its group, at the middle of its slot
  overlap_rank = np.arange(len(nearest_ind)) - group_start + 1
  relative_u = (overlap_rank - 0.5) / group_size

  # Scale the relative u values into the ECDF step of the matched base value
  ecdf = np.concatenate(([0], np.cumsum(counts_base)/np.sum(counts_base)))
  min_u = np.repeat(ecdf[unique_nearest_ind], counts_nearest_ind)
  max_u = np.repeat(ecdf[unique_nearest_ind + 1], counts_nearest_ind)
  test_u = (max_u - min_u) * relative_u + min_u # Shift relative to absolute

  # Undo the sort so outputs line up with test_x
  ind_nearest_test = ind_nearest_test[inverse_sort_idx]
  test_u = test_u[inverse_sort_idx]

  # Sanity-check validations
  u_sort_idx = np.argsort(test_u)
  assert np.all(test_u > 0) and np.all(test_u < 1), 'u is not between 0 or 1 exclusive'
  assert np.all(np.diff(ind_nearest_test[u_sort_idx]) >= 0), 'Should be non-negative difference (monotonic)'
  assert np.all(np.diff(test_u[u_sort_idx]) >= 0), 'Should be non-negative difference (monotonic)'

  if return_nearest:
    return test_u, ind_nearest_test
  return test_u

def _empirical_cdf_projection(x, shuffle=True, return_ranks=False, return_order=False, random_state=None):
  '''Project x onto (0, 1) through its empirical CDF.

  Each value gets rank r in 1..n and is mapped to u = (r - 0.5) / n, i.e.
  the middle of its ECDF step.  With `shuffle=True` the ranks are computed
  on a random permutation of x (seeded by `random_state`) and mapped back
  to the original positions; otherwise a plain argsort is used.

  Returns u, or a tuple starting with u followed by `ranks` and/or `order`
  (indices that sort by rank) when `return_ranks` / `return_order` are set.
  '''
  x = np.asarray(x)
  n_samples = len(x)

  if not shuffle:
    # Compute ranks without shuffling
    order = np.argsort(x)
    ranks = np.argsort(order) + 1
  else:
    # Compute ranks with random permutation, i.e., shuffling
    rng = check_random_state(random_state)
    rand_perm = rng.permutation(n_samples)
    perm_order = np.argsort(x[rand_perm])
    ranks = np.empty_like(perm_order)
    ranks[rand_perm] = np.argsort(perm_order) + 1 # 1...len(x)
    order = np.argsort(ranks)

  # Project to (0,1) at the middle of each step: the first point lands on
  # 0.5/n where the ECDF itself would be 1/n.
  u = (ranks - 0.5)/n_samples

  # Quick validations
  assert np.all(u > 0) and np.all(u < 1), 'u is not between 0 or 1 exclusive'
  assert np.all(np.diff(x[order]) >= 0), 'When reordered all x should always have a non-negative diff'
  assert np.all(np.diff(ranks[order]) >= 0), 'When reordered all x should always have a non-negative diff'
  if n_samples > 1:
    # The spacing check needs two points; with fewer there is nothing to compare.
    sorted_u = u[order]
    assert np.allclose(np.diff(sorted_u), sorted_u[1] - sorted_u[0]), 'The difference between each point should be the same'

  if not (return_ranks or return_order):
    return u
  out = (u,)
  if return_ranks:
    out = (*out, ranks)
  if return_order:
    out = (*out, order)
  return out
    

def _ks_weighted(data1, data2, wei1, wei2, alternative='two-sided'):
  '''Weighted two-sample Kolmogorov-Smirnov test.

  Adapted from
  https://stackoverflow.com/questions/40044375/how-to-calculate-the-kolmogorov-smirnov-statistic-between-two-weighted-samples
  with the p-value based on the effective sample size of each weighted
  sample, so that unit weights reduce to the unweighted sample sizes.

  data1/data2 are the samples and wei1/wei2 their non-negative weights
  (array-likes of matching length).  Any `alternative` other than
  'two-sided' uses the exponential approximation below; note that the
  statistic itself is always the maximum absolute CDF difference.

  Returns `(d, prob)`: the KS statistic and its p-value.
  '''
  data1 = np.asarray(data1)
  data2 = np.asarray(data2)
  wei1 = np.asarray(wei1, dtype=float)
  wei2 = np.asarray(wei2, dtype=float)

  # Sort each sample and carry its weights along
  ix1 = np.argsort(data1)
  ix2 = np.argsort(data2)
  data1 = data1[ix1]
  data2 = data2[ix2]
  wei1 = wei1[ix1]
  wei2 = wei2[ix2]

  # Weighted empirical CDFs evaluated on the pooled sample
  data = np.concatenate([data1, data2])
  cwei1 = np.hstack([0, np.cumsum(wei1)/np.sum(wei1)])
  cwei2 = np.hstack([0, np.cumsum(wei2)/np.sum(wei2)])
  cdf1we = cwei1[np.searchsorted(data1, data, side='right')]
  cdf2we = cwei2[np.searchsorted(data2, data, side='right')]
  d = np.max(np.abs(cdf1we - cdf2we))

  # Effective sample sizes, n_effective = 1/S_2 from
  # Numerical Methods of Statistics, 2nd edition, John F. Monahan,
  # Section 12.4 on Importance Sampling and Weighted Observations
  n1 = np.sum(wei1) ** 2 / np.sum(wei1 ** 2)
  n2 = np.sum(wei2) ** 2 / np.sum(wei2 ** 2)
  m, n = sorted([float(n1), float(n2)], reverse=True)
  en = m * n / (m + n)
  if alternative == 'two-sided':
    prob = scipy.stats.distributions.kstwo.sf(d, np.round(en))
  else:
    z = np.sqrt(en) * d
    # Use Hodges' suggested approximation Eqn 5.3
    # Requires m to be the larger of (n1, n2)
    expt = -2 * z**2 - 2 * z * (m + 2*n)/np.sqrt(m*n*(m+n))/3.0
    prob = np.exp(expt)
  return d, prob
