"""Data loading and pre-processing utilities."""

from pdb import set_trace
from typing import Tuple, Callable, Sequence, Text, Dict, Union

import os
import json

from absl import logging

import jax.numpy as np
from jax import random
import torch

import pandas as pd

import numpy as onp
from scipy.stats import norm
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity

import pdb


# Dictionary of named arrays (and their whitening statistics) returned by the
# data generators; entries may be scalars or None.
DataReal = Dict[Text, Union[np.ndarray, float, None]]
# (values, xstar, ystar) as returned by `get_synth_data`.
DataSynth = Tuple[DataReal, np.ndarray, np.ndarray]
# A pair of arrays, e.g. the two noise vectors (e_X, e_Y).
ArrayTup = Tuple[np.ndarray, np.ndarray]

# Mapping from equation name ("f_x", "f_y", ...) to the callable implementing it.
Equations = Dict[Text, Callable[..., np.ndarray]]


# =============================================================================
# NOISE SOURCES
# =============================================================================
def std_normal_1d(key: np.ndarray, num: int) -> np.ndarray:
  """Draw `num` i.i.d. standard-normal values (used e.g. for the confounder).

  Args:
    key: A JAX random key.
    num: Number of samples.

  Returns:
    An array of shape (num,).
  """
  sample_shape = (num,)
  return random.normal(key, sample_shape)


def std_normal_2d(key: np.ndarray, num: int) -> ArrayTup:
  """Draw two independent standard-normal vectors, used as the noises e_X, e_Y.

  Args:
    key: A JAX random key; it is split into one subkey per vector.
    num: Length of each vector.

  Returns:
    A pair of arrays, each of shape (num,).
  """
  key_x, key_y = random.split(key)
  shape = (num,)
  e_x = random.normal(key_x, shape)
  e_y = random.normal(key_y, shape)
  return e_x, e_y

def normal_2d(k):
  """Build a sampler for two independent unit-variance Gaussians with mean `k`.

  Args:
    k: Offset added to every sample.

  Returns:
    A function `(key, num) -> (e_X, e_Y)` returning two arrays of shape
    (num,), each equal to `k` plus an independent standard-normal draw.
  """
  def shifted_pair(key: np.ndarray, num: int):
    first_key, second_key = random.split(key)
    return tuple(k + random.normal(sub, (num,)) for sub in (first_key, second_key))
  return shifted_pair

def normal_3d(k):
  """Build a sampler for three independent unit-variance Gaussians with mean `k`.

  Args:
    k: Offset added to every sample.

  Returns:
    A function `(key, num) -> (n1, n2, n3)` returning three arrays of shape
    (num,), each equal to `k` plus an independent standard-normal draw.
  """
  def std_normal_3d(
      key: np.ndarray, num: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Derive three distinct subkeys: split once, then split the second half.
    key1, key2 = random.split(key)
    key2, key3 = random.split(key2)
    return (k + random.normal(key1, (num,)),
            k + random.normal(key2, (num,)),
            k + random.normal(key3, (num,)))
  return std_normal_3d

def std_normal_3d(
    key: np.ndarray, num: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Draw three independent standard-normal vectors (one per coordinate).

  Args:
    key: A JAX random key; three distinct subkeys are derived from it.
    num: Length of each vector.

  Returns:
    A 3-tuple of arrays, each of shape (num,).
  """
  # Split once, then split the second half again to get three distinct keys.
  key1, key2 = random.split(key)
  key2, key3 = random.split(key2)
  return random.normal(key1, (num,)), random.normal(key2, (num,)), random.normal(key3, (num,))


# =============================================================================
# SYNTHETIC STRUCTURAL EQUATIONS
# =============================================================================


# Named structural equation models for the instrumental-variable graph
# Z -> X -> Y, with a confounder C entering both X and Y.
#
# 1-D models ("lin1", "lin2", "quad1", "quad2") are looked up by
# `get_synth_data` and provide:
#   "noise":      (key, num) -> (e_X, e_Y)
#   "confounder": (key, num) -> C
#   "f_z":        (key, num) -> Z   (the instrument is sampled directly)
#   "f_x":        (z, c, e_x) -> X
#   "f_y":        (x, c, e_y) -> Y
#
# "*-2d" / "*-3d" models are looked up by `get_synth_data_2d` /
# `get_synth_data_3d`. Their "ex", "confounder" and "f_z" samplers return one
# array per coordinate and "ey" returns a single array. "f_x" is applied
# elementwise to the stacked (d, num) arrays of z, c and ex. "f_y" receives
# the unpacked coordinates: f_y(x1, ..., xd, c1, ..., cd, ey).
structural_equations = {
  # ---- 1-D treatment ----
  "lin1": {
    "noise": std_normal_2d,
    "confounder": std_normal_1d,
    "f_z": std_normal_1d,
    "f_x": lambda z, c, ex: 0.5 * z + 3 * c + ex,
    "f_y": lambda x, c, ey: x - 6 * c + ey,
  },
  "lin2": {
    "noise": std_normal_2d,
    "confounder": std_normal_1d,
    "f_z": std_normal_1d,
    "f_x": lambda z, c, ex: 3.0 * z + 0.5 * c + ex,
    "f_y": lambda x, c, ey: x - 6 * c + ey,
  },
  "quad1": {
    "noise": std_normal_2d,
    "confounder": std_normal_1d,
    "f_z": std_normal_1d,
    "f_x": lambda z, c, ex: 0.5 * z + 3 * c + ex,
    "f_y": lambda x, c, ey: 0.3 * x ** 2 - 1.5 * x * c + ey,
  },
  "quad2": {
    "noise": std_normal_2d,
    "confounder": std_normal_1d,
    "f_z": std_normal_1d,
    "f_x": lambda z, c, ex: 3.0 * z + 0.5 * c + ex,
    "f_y": lambda x, c, ey: 0.3 * x ** 2 - 1.5 * x * c + ey,
  },
  # ---- 2-D treatment ----
  "lin1-2d": {
    "ex": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "f_z": std_normal_2d,
    "f_x": lambda z, c, ex:  0.5 * z + 2 * c + ex, 
    "f_y": lambda x1, x2, c1, c2, ey: x1 +  x2 - 3 * (x1 + x2) * (c1 + c2) + ey,
  },
   "lin2-2d": {
    "ex": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "f_z": std_normal_2d,
    "f_x": lambda z, c, ex: 2 * z + 1 * c + ex, 
    "f_y": lambda x1, x2, c1, c2, ey: 5*x1 + 6*x2 - x1 * (c1 + c2) + ey,
  },
  "lin3-2d": {
    "ex": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "f_z": std_normal_2d,
    "f_x": lambda z, c, ex:  z + 2 * c + ex, 
    "f_y": lambda x1, x2, c1, c2, ey: 2*x1 + x2 - 1 * (c1 + c2) + ey,
  },
  "quad1-2d": {
    "ex": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "f_z": std_normal_2d,
    "f_x": lambda z, c, ex:  z + 2 * c + ex, 
    "f_y": lambda x1, x2, c1, c2, ey: 2*x1**2 + 2*x2**2 - 1 * (c1 + c2) + ey,
  },
  "quad2-2d": {
    "ex": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "f_z": std_normal_2d,
    "f_x": lambda z, c, ex: 2 * z + 1 * c + ex, 
    "f_y": lambda x1, x2, c1, c2, ey: 5*x1**2 + 5*x2**2 - (x1 + x2) * (c1 + c2) + ey,
  },
  # ---- 3-D treatment ----
  # NOTE(review): f_y below uses only c1 and c2 (not c3) — confirm intended.
  "lin1-3d": {
    "ex": std_normal_3d,
    "ey": std_normal_1d,
    "confounder": std_normal_3d,
    "f_z": std_normal_3d,
    "f_x": lambda z, c, ex: z + 2 * c + ex, 
    "f_y": lambda x1, x2, x3, c1, c2, c3, ey: x1 + x2 + 2 * x3 + 2 * (c1 + c2) + ey,
  },
  # NOTE(review): f_y below uses only c1 and c2 (not c3) — confirm intended.
  "lin2-3d": {
    "ex": std_normal_3d,
    "ey": std_normal_1d,
    "confounder": std_normal_3d,
    "f_z": std_normal_3d,
    "f_x": lambda z, c, ex: 2 * z + 1 * c + ex, 
    "f_y": lambda x1, x2, x3, c1, c2, c3, ey: x1 + x2 + 2 * x3 - 0.5 * (x1 + x2) * (c1 + c2) + ey,
  },
  "quad1-3d": {
    "ex": std_normal_3d,
    "ey": std_normal_1d,
    "confounder": std_normal_3d,
    "f_z": std_normal_3d,
    "f_x": lambda z, c, ex:  z + 2 * c + ex, 
    "f_y": lambda x1, x2, x3, c1, c2, c3, ey:  x1**2 + x2**2 + 2 * x3 - 2 * (c1 + c2 + c3) + ey,
  },
  "quad2-3d": {
    "ex": std_normal_3d,
    "ey": std_normal_1d,
    "confounder": std_normal_3d,
    "f_z": std_normal_3d,
    "f_x": lambda z, c, ex: 2 * z + 1 * c + ex, 
    "f_y": lambda x1, x2, x3, c1, c2, c3, ey: 2 * x1**2 + 2 * x2**2 + 2 * x3 \
                  - 0.3 * (x2 + x3) * (c1 + c2 + c3) + ey,
  },
  "test-3d": {
    "ex": std_normal_3d,
    "ey": std_normal_1d,
    "confounder": std_normal_3d,
    "f_z": std_normal_3d,
    "f_x": lambda z, c, ex:  z - 3 * c + ex, 
    "f_y": lambda x1, x2, x3, c1, c2, c3, ey: 3*x1 +  4*x2 - x3 - 2*(x1 + x2 + 2*x3) * (c1 + c2 + c3) + ey,
  },
}

# Named structural equation models for the (leaky) front-door setting used by
# `fd_get_synth_data_2d`: X -> M -> Y with a 2-D treatment and mediator.
#   "ex", "em", "ey": noise samplers (key, num) -> array(s)
#   "confounder":     (key, num) -> (c1, c2); enters both f_m and f_y
#   "u_xy":           (key, num) -> (u1, u2); enters both f_x and f_y
#   "f_x":            (u_xy, ex) -> X, applied to stacked (2, num) arrays
#   "f_m":            (x, c, em) -> M, applied to stacked (2, num) arrays
#   "f_y":            (m1, m2, c1, c2, u1, u2, ey) -> Y
structural_equations_fd = {
  "fd-lin2-2d": {
    "ex": std_normal_2d,
    "em": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "u_xy": std_normal_2d,
    "f_x": lambda u_xy, ex: u_xy + ex,
    "f_m": lambda x, c, em: 3 * x - em + c, 
    "f_y": lambda m1, m2, c1, c2, u1, u2, ey: 2 * m1 + m2 - 0.3 * (m1 + m2) * (c1 + c2 + u1 + u2)  + ey,
  },
  "fd-lin1-2d": {
    "ex": std_normal_2d,
    "em": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "u_xy": std_normal_2d,
    "f_x": lambda u_xy, ex: u_xy + ex,
    "f_m": lambda x, c, em: x - em + 3*c, 
    "f_y": lambda m1, m2, c1, c2, u1, u2, ey: 2 * m1 + m2 -  (m1 + m2) * (c1 + c2 + u1 + u2) + ey,
  },
   "fd-test-2d": {
    "ex": std_normal_2d,
    "em": std_normal_2d,
    "ey": std_normal_1d,
    "confounder": std_normal_2d,
    "u_xy": std_normal_2d,
    "f_x": lambda u_xy, ex: 5*u_xy + ex,
    "f_m": lambda x, c, em: x - em + 3*c, 
    "f_y": lambda m1, m2, c1, c2, u1, u2, ey: 2 * m1 + m2 -  (m1 + m2) * (c1 + c2 + u1 + u2) + ey,
  }
}


# =============================================================================
# DATA GENERATORS
# =============================================================================

def whiten(
  inputs: Dict[Text, np.ndarray]
) -> Dict[Text, Union[float, np.ndarray, None]]:
  """Standardise every non-None array along axis 0.

  For each entry `name -> arr` the result holds `name + "_mu"` (column mean),
  `name + "_std"` (column std, floored at 1e-7 to avoid division by zero) and
  `name` (the standardised array). None entries are passed through unchanged.
  """
  out = {}
  for name, arr in inputs.items():
    if arr is None:
      out[name] = None
      continue
    center = np.mean(arr, 0)
    scale = np.maximum(np.std(arr, 0), 1e-7)
    out[name + "_mu"] = center
    out[name + "_std"] = scale
    out[name] = (arr - center) / scale
  return out


def whiten_with_mu_std(val: np.ndarray, mu: float, std: float) -> np.ndarray:
  """Standardise `val` with the given mean and standard deviation."""
  centered = val - mu
  return centered / std

def unwhiten_with_mu_std(val: np.ndarray, mu: float, std: float) -> np.ndarray:
  """Invert `whiten_with_mu_std`: map standardised values back to raw scale."""
  scaled = val * std
  return scaled + mu

def get_nonconstant_axis(d):
  """Return the first column of the 2-D array `d` whose entries are not all equal.

  Used to recover the varied coordinate of an intervention grid in which
  every other coordinate is held fixed.

  Args:
    d: A 2-D array (rows x columns).

  Returns:
    The first column of `d` that contains at least two distinct values, or
    None if every column is constant (including when `d` has a single row).
  """
  for ax in range(d.shape[1]):
    column = d[:, ax]
    # Compare the whole column against its first entry rather than only the
    # first two rows: a column varying further down is still detected, and a
    # single-row input no longer raises IndexError.
    if np.any(column != column[0]):
      return column
  return None

def get_synth_data(
  key: np.ndarray,
  num: int,
  equations: Text,
  num_xstar: int = 100,
  external_equations: Equations = None,
  disconnect_instrument: bool = False
) -> DataSynth:
  """Generate 1-D synthetic IV data and Monte Carlo draws of its causal effect.

    Args:
      key: A JAX random key.
      num: The number of examples to generate.
      equations: Name of the entry of `structural_equations` to use; ignored
        when `external_equations` is given.
      num_xstar: Size of the grid of interventions on x.
      external_equations: Optional dict used instead of the named equations.
        It must provide "noise" ((key, num) -> (e_x, e_y)), "confounder" and
        "f_z" ((key, num) -> array), "f_x" ((z, c, e_x) -> x) and
        "f_y" ((x, c, e_y) -> y), i.e. the graph Z -> X -> Y with C
        confounding X and Y.
      disconnect_instrument: Whether to replace the returned instrument
        values['z'] with fresh standard-normal draws after everything else
        has been generated. This serves diagnostic purposes, i.e. looking at
        the same x, y data with an uninformative instrument.

    Returns:
      A 3-tuple (values, xstar, ystar):
        values: whitened x, y, z, confounder, ex, ey plus their *_mu/*_std.
        xstar: (num_xstar,) evenly spaced interventions spanning the observed
          range of x, whitened with the statistics of x.
        ystar: (500, num_xstar) Monte Carlo draws of Y under do(xstar),
          whitened with the statistics of y.
  """
  eqs = (structural_equations[equations] if external_equations is None
         else external_equations)

  key, k_noise = random.split(key)
  ex, ey = eqs["noise"](k_noise, num)
  key, k_conf = random.split(key)
  confounder = eqs["confounder"](k_conf, num)
  key, k_inst = random.split(key)
  z = eqs["f_z"](k_inst, num)
  x = eqs["f_x"](z, confounder, ex)
  y = eqs["f_y"](x, confounder, ey)

  values = whiten(dict(x=x, y=y, z=z, confounder=confounder, ex=ex, ey=ey))

  # Monte Carlo draws of Y under do(x*) on an evenly spaced grid of x*.
  xstar = np.linspace(np.min(x), np.max(x), num_xstar)
  draws = []
  for _ in range(500):
    key, k_ey = random.split(key)
    ey_draw = eqs["noise"](k_ey, num_xstar)[1]
    key, k_conf = random.split(key)
    conf_draw = eqs["confounder"](k_conf, num_xstar)
    y_draw = eqs["f_y"](xstar, conf_draw, ey_draw)
    draws.append(whiten_with_mu_std(y_draw, values["y_mu"], values["y_std"]))
  ystar = np.array(draws)
  xstar = whiten_with_mu_std(xstar, values["x_mu"], values["x_std"])

  if disconnect_instrument:
    key, k_inst = random.split(key)
    values['z'] = random.normal(k_inst, shape=z.shape)
  return values, xstar, ystar


def get_synth_data_2d(
  key: np.ndarray,
  num: int,
  equations,
  axis,
  num_xstar_bound,
  num_xstar: int = 100,
  save_equation = True,
  external_equations = None
):
    """Generate synthetic IV data with a 2-D treatment and its causal effect.

    Samples noises, confounders and instruments, pushes them through the
    structural equations Z -> X -> Y (C confounding X and Y) and draws
    Monte Carlo samples of Y under do(x*) along a line that varies treatment
    coordinate `axis`. The other coordinate is held at the mean of a
    fixed-seed reference sample.

    Args:
      key: A JAX random key.
      num: The number of examples to generate.
      equations: Name of the entry of `structural_equations` to use; ignored
        when `external_equations` is given.
      axis: Treatment coordinate (0 or 1) varied along the grids.
      num_xstar_bound: Number of points in `xstar_grid`.
      num_xstar: Number of points in `xstar`.
      save_equation: Unused; kept for backward compatibility.
      external_equations: Optional dict used instead of the named equations.
        It must provide the samplers "ex", "ey", "confounder" and "f_z"
        ((key, num) -> array(s)), "f_x" (applied to the stacked (2, num)
        arrays of z, c, ex) and "f_y" (called as f_y(x1, x2, c1, c2, ey)).

    Returns:
      A 6-tuple (values, xstar, ystar, xstar_grid, xstar_grid_plotting,
      xstar_plotting):
        values: whitened x, y, z, confounder, ex, ey plus their *_mu/*_std.
        xstar: (num_xstar, 2) interventions spanning the observed range of x
          along `axis`.
        ystar: (500, num_xstar) Monte Carlo draws of Y under do(xstar),
          whitened with the statistics of y.
        xstar_grid: (num_xstar_bound, 2) interventions spanning the reference
          sample's range along `axis`, trimmed by 10% at each end.
        xstar_grid_plotting, xstar_plotting: the varying column of
          `xstar_grid` and `xstar` respectively.
      `xstar` and `xstar_grid` are whitened with the statistics of the
      reference sample (not of x). `xstar_grid` therefore depends only on the
      fixed reference sample and is identical across calls.

    Raises:
      ValueError: If `axis` is not 0 or 1.
    """
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 or 1, got {axis!r}")

    if external_equations is not None:
        eqs = external_equations
    else:
        eqs = structural_equations[equations]

    # Reference sample from a fixed seed so the intervention grid (and the
    # statistics used to whiten it) is the same on every call. Each draw uses
    # its own subkey; a key is never sampled from and then split again.
    key_bound = random.PRNGKey(42)
    key_bound, subkey_bound = random.split(key_bound)
    _ex = eqs["ex"](subkey_bound, 2000)
    key_bound, subkey_bound = random.split(key_bound)
    _c = eqs["confounder"](subkey_bound, 2000)
    key_bound, subkey_bound = random.split(key_bound)
    _z = eqs["f_z"](subkey_bound, 2000)
    _x_bound = eqs["f_x"](np.array(_z), np.array(_c), np.array(_ex)).T

    key, subkey = random.split(key)
    ex = eqs["ex"](subkey, num)
    key, subkey = random.split(key)
    ey = eqs["ey"](subkey, num)
    key, subkey = random.split(key)
    c = eqs["confounder"](subkey, num)
    key, subkey = random.split(key)
    z = eqs["f_z"](subkey, num)
    x = eqs["f_x"](np.array(z), np.array(c), np.array(ex))
    x1, x2 = x
    c1, c2 = c
    y = eqs["f_y"](x1, x2, c1, c2, ey)
    x = x.T
    z = np.array(z).T
    c = np.array(c).T
    ex = np.array(ex).T

    values = whiten({'x': x, 'y': y, 'z': z, 'confounder': c,
                     'ex': ex, 'ey': ey, 'x_temp': _x_bound})

    # Intervention lines: coordinate `axis` varies, the other coordinate is
    # fixed at its reference-sample mean.
    lo_b, hi_b = np.min(_x_bound[:, axis]), np.max(_x_bound[:, axis])
    trim = (hi_b - lo_b) / 10
    line = np.linspace(np.min(x[:, axis]), np.max(x[:, axis]), num_xstar)
    line_bound = np.linspace(lo_b + trim, hi_b - trim, num_xstar_bound)
    star_cols, bound_cols = [], []
    for j in range(2):
        if j == axis:
            star_cols.append(line)
            bound_cols.append(line_bound)
        else:
            held = np.mean(_x_bound[:, j])
            star_cols.append(np.full_like(line, held))
            bound_cols.append(np.full_like(line_bound, held))

    # Monte Carlo draws of Y under do(xstar).
    ystar = []
    for _ in range(500):
        key, subkey = random.split(key)
        tmpey = eqs["ey"](subkey, num_xstar)
        key, subkey = random.split(key)
        tc1, tc2 = eqs["confounder"](subkey, num_xstar)
        ystar.append(whiten_with_mu_std(
            eqs["f_y"](*star_cols, tc1, tc2, tmpey),
            values["y_mu"], values["y_std"]))
    ystar = np.array(ystar)

    xstar = whiten_with_mu_std(
        np.vstack(star_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_grid = whiten_with_mu_std(
        np.vstack(bound_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_plotting = get_nonconstant_axis(xstar)
    xstar_grid_plotting = get_nonconstant_axis(xstar_grid)

    # The reference sample was only needed for the grids; drop it.
    for k in [k for k in values if k.startswith("x_temp")]:
        del values[k]

    return values, xstar, ystar, xstar_grid, xstar_grid_plotting, xstar_plotting

def get_synth_data_3d(
  key: np.ndarray,
  num: int,
  equations,
  axis,
  num_xstar_bound,
  num_xstar: int = 100,
  external_equations = None,
  disconnect_instrument: bool = False
):
    """Generate synthetic IV data with a 3-D treatment and its causal effect.

    Samples noises, confounders and instruments, pushes them through the
    structural equations Z -> X -> Y (C confounding X and Y) and draws
    Monte Carlo samples of Y under do(x*) along a line that varies treatment
    coordinate `axis`. The other coordinates are held at the means of a
    fixed-seed reference sample.

    Args:
      key: A JAX random key.
      num: The number of examples to generate.
      equations: Name of the entry of `structural_equations` to use; ignored
        when `external_equations` is given.
      axis: Treatment coordinate (0, 1 or 2) varied along the grids.
      num_xstar_bound: Number of points in `xstar_grid`.
      num_xstar: Number of points in `xstar`.
      external_equations: Optional dict used instead of the named equations.
        It must provide the samplers "ex", "ey", "confounder" and "f_z"
        ((key, num) -> array(s)), "f_x" (applied to the stacked (3, num)
        arrays of z, c, ex) and "f_y" (called as
        f_y(x1, x2, x3, c1, c2, c3, ey)).
      disconnect_instrument: Whether to replace the returned instrument
        values['z'] with fresh standard-normal draws after everything else
        has been generated. This serves diagnostic purposes, i.e. looking at
        the same x, y data with an uninformative instrument.

    Returns:
      A 6-tuple (values, xstar, ystar, xstar_grid, xstar_grid_plotting,
      xstar_plotting):
        values: whitened x, y, z, confounder, ex, ey plus their *_mu/*_std.
        xstar: (num_xstar, 3) interventions spanning the observed range of x
          along `axis`.
        ystar: (500, num_xstar) Monte Carlo draws of Y under do(xstar),
          whitened with the statistics of y.
        xstar_grid: (num_xstar_bound, 3) interventions spanning the reference
          sample's range along `axis`, trimmed by 10% at each end.
        xstar_grid_plotting, xstar_plotting: the varying column of
          `xstar_grid` and `xstar` respectively.
      `xstar` and `xstar_grid` are whitened with the statistics of the
      reference sample (not of x).

    Raises:
      ValueError: If `axis` is not 0, 1 or 2.
    """
    if axis not in (0, 1, 2):
        raise ValueError(f"axis must be 0, 1 or 2, got {axis!r}")

    if external_equations is not None:
        eqs = external_equations
    else:
        eqs = structural_equations[equations]

    # Reference sample from a fixed seed so the intervention grid (and the
    # statistics used to whiten it) is the same on every call.
    key_bound = random.PRNGKey(42)
    key_bound, subkey_bound = random.split(key_bound)
    _ex = eqs["ex"](subkey_bound, 2000)
    key_bound, subkey_bound = random.split(key_bound)
    _c = eqs["confounder"](subkey_bound, 2000)
    key_bound, subkey_bound = random.split(key_bound)
    _z = eqs["f_z"](subkey_bound, 2000)
    _x_bound = eqs["f_x"](np.array(_z), np.array(_c), np.array(_ex)).T

    key, subkey = random.split(key)
    ex = eqs["ex"](subkey, num)
    key, subkey = random.split(key)
    ey = eqs["ey"](subkey, num)
    key, subkey = random.split(key)
    c = eqs["confounder"](subkey, num)
    key, subkey = random.split(key)
    z = eqs["f_z"](subkey, num)
    x = eqs["f_x"](np.array(z), np.array(c), np.array(ex))
    x1, x2, x3 = x
    c1, c2, c3 = c
    y = eqs["f_y"](x1, x2, x3, c1, c2, c3, ey)
    x = x.T
    z = np.array(z).T
    c = np.array(c).T
    ex = np.array(ex).T

    values = whiten({'x': x, 'y': y, 'z': z, 'confounder': c,
                     'ex': ex, 'ey': ey, 'x_temp': _x_bound})

    # Intervention lines: coordinate `axis` varies, the others are fixed at
    # their reference-sample means.
    lo_b, hi_b = np.min(_x_bound[:, axis]), np.max(_x_bound[:, axis])
    trim = (hi_b - lo_b) / 10
    line = np.linspace(np.min(x[:, axis]), np.max(x[:, axis]), num_xstar)
    line_bound = np.linspace(lo_b + trim, hi_b - trim, num_xstar_bound)
    star_cols, bound_cols = [], []
    for j in range(3):
        if j == axis:
            star_cols.append(line)
            bound_cols.append(line_bound)
        else:
            held = np.mean(_x_bound[:, j])
            star_cols.append(np.full_like(line, held))
            bound_cols.append(np.full_like(line_bound, held))

    # Monte Carlo draws of Y under do(xstar).
    ystar = []
    for _ in range(500):
        key, subkey = random.split(key)
        tmpey = eqs["ey"](subkey, num_xstar)
        key, subkey = random.split(key)
        tc1, tc2, tc3 = eqs["confounder"](subkey, num_xstar)
        ystar.append(whiten_with_mu_std(
            eqs["f_y"](*star_cols, tc1, tc2, tc3, tmpey),
            values["y_mu"], values["y_std"]))
    ystar = np.array(ystar)

    xstar = whiten_with_mu_std(
        np.vstack(star_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_grid = whiten_with_mu_std(
        np.vstack(bound_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_plotting = get_nonconstant_axis(xstar)
    xstar_grid_plotting = get_nonconstant_axis(xstar_grid)

    # The reference sample was only needed for the grids; drop it.
    for k in [k for k in values if k.startswith("x_temp")]:
        del values[k]

    # Previously this flag was accepted but ignored; honour it as documented.
    if disconnect_instrument:
        key, subkey = random.split(key)
        values['z'] = random.normal(subkey, shape=z.shape)

    return values, xstar, ystar, xstar_grid, xstar_grid_plotting, xstar_plotting

def find_nearest(array, value):
    """Return the element of `array` closest to `value` (first one on ties)."""
    distances = np.abs(array - value)
    return array[distances.argmin()]

def get_yeast_data(key, num_xstar=5, csv_path='../data/mandelian.csv'):
  """Load the yeast data set and simulate a confounded outcome for it.

  The first 5 CSV columns are used as instruments z, and the remaining
  columns as treatments x. The outcome is simulated as
    y = sum_{i<5} alpha_i * x_i * c_i + sum_{5<=i<15} alpha_i * x_i + e_y,
  with alpha ~ Uniform[0, 1)^15 drawn from a fixed seed, c ~ N(0, I_5) and
  e_y ~ N(0, 1). This assumes x has at least 15 columns.

  Intervention points start from the row of x with the smallest simulated y.
  The i-th point adds N(0, 1) / 5 noise to treatment columns 3i, 3i+1, 3i+2.
  NOTE(review): with 15 treatment coefficients, only num_xstar <= 5 perturbs
  distinct in-range columns — confirm before using larger values.

  Args:
    key: A JAX random key.
    num_xstar: Number of intervention points.
    csv_path: Path of the CSV file; the default is relative to the current
      working directory.

  Returns:
    A 6-tuple (values, xstar, ystar, xstar_grid, xstar_grid_plotting,
    xstar_plotting):
      values: whitened x, y, z, confounder, ey plus their *_mu/*_std.
      xstar: (num_xstar, n_treatments) intervention points, whitened with the
        statistics of x.
      ystar: (num_xstar,) mean over 500 Monte Carlo draws of the simulated
        outcome at the (raw) intervention points, on the raw outcome scale.
      xstar_grid: the same array as `xstar`.
      xstar_grid_plotting, xstar_plotting: the indices 0, ..., num_xstar - 1.
  """
  num_instruments = 5  # leading CSV columns used as instruments
  num_confounded = 5   # treatments 0..4 each interact with one confounder
  alpha = random.uniform(random.PRNGKey(42), (15,))

  def outcome(xs, cs, es):
    """Simulated outcome for treatments xs, confounders cs and noise es."""
    out = np.zeros(xs.shape[0])
    for j in range(alpha.shape[0]):
      if j < num_confounded:
        out = out + alpha[j] * (xs[:, j] * cs[:, j])
      else:
        out = out + (alpha[j] * xs[:, j])
    return out + es.squeeze()

  df_yeast = pd.read_csv(csv_path)
  z = np.array(df_yeast.iloc[:, :num_instruments].values)
  x = np.array(df_yeast.iloc[:, num_instruments:].values)
  # Derive the sample size from the data instead of hard-coding it.
  n = x.shape[0]

  key, subkey = random.split(key)
  c = random.multivariate_normal(
      subkey, np.zeros(num_confounded), np.eye(num_confounded), (n,))
  key, subkey = random.split(key)
  ey = random.multivariate_normal(subkey, np.zeros(1), np.eye(1), (n,))
  y = outcome(x, c, ey)

  # Perturb a different triple of treatment columns for each intervention.
  xstar_base = x[np.argmin(y)]
  xstar_rows = []
  for i in range(num_xstar):
    key, subkey = random.split(key)
    noise = random.multivariate_normal(
        subkey, np.zeros(1), np.eye(1), (3,)).squeeze()
    idx = np.arange(3 * i, 3 * i + 3)
    xstar_rows.append(xstar_base.at[idx].add(noise / 5))
  xstar = np.array(xstar_rows)

  # Monte Carlo mean of the outcome under do(xstar).
  ystars = []
  for _ in range(500):
    key, subkey = random.split(key)
    c_ = random.multivariate_normal(
        subkey, np.zeros(num_confounded), np.eye(num_confounded), (num_xstar,))
    key, subkey = random.split(key)
    ey_ = random.multivariate_normal(
        subkey, np.zeros(1), np.eye(1), (num_xstar,))
    ystars.append(outcome(xstar, c_, ey_))
  ystar = np.array(ystars).mean(0)

  values = whiten({'x': x, 'y': y, 'z': z, 'confounder': c, 'ey': ey})
  xstar = whiten_with_mu_std(xstar, values["x_mu"], values["x_std"])
  # NOTE(review): the synthetic generators in this file whiten ystar with
  # values["y_mu"] / values["y_std"]; here it stays on the raw scale —
  # confirm callers expect that.

  # One label per intervention point (previously hard-coded to 5 labels).
  xstar_plotting = np.arange(num_xstar)
  return values, xstar, ystar, xstar, xstar_plotting, xstar_plotting

def get_hllt(key: np.ndarray,
             num: int,
             rho: float,
             axis,
             num_xstar_bound,
             additive = True,
             num_xstar: int = 100):
  """Simulate the HLLT demand design (as used in the kernel IV paper).

  Training data are generated as
    z ~ N(0, 1), v ~ N(0, 1), s ~ Uniform{1, ..., 7}, t ~ Uniform[0, 10),
    e = (1 - rho**2) * N(0, 1) + rho * v,
    p = 25 + (z + 3) * psi(t) + v,
    y = h(p, t, s) + e            (additive=True)
    y = h(p, t, s) + e * p / 4    (additive=False)
  with h(p, t, s) = 100 + (10 + p) * s * psi(t) - 2 p. The treatment is
  x = (p, t, s) and z is the instrument; v links the price p to the noise e.

  Args:
    key: A JAX random key.
    num: Number of training examples.
    rho: Weight of the price shock v in the outcome noise e (see above).
    axis: Which intervention points to return:
          -1: a line along the first principal component of whitened x
           0: p on [10, 25], t at its sample mean, s = 1
           1: t on [0, 10], p at its sample mean, s = 1
           2: s = 1, ..., 7 with p and t at their sample means
          -2: num_xstar_bound high-density training points (Gaussian KDE)
    num_xstar_bound: Number of points in `xstar_grid` (ignored for axis 2).
    additive: Whether the noise enters y additively (see above).
    num_xstar: Number of points in `xstar` (axes -1, 0 and 1).

  Returns:
    values: Whitened x, y, z and the training noise e (stored as both ex and
      ey), their *_mu / *_std statistics, and confounder=None.
    xstar: Whitened intervention points along the chosen line / point set.
    ystar: Whitened outcome at `xstar` under a single fresh noise draw.
    xstar_grid: Whitened points at which bounds on the effect are computed.
    xstar_grid_plotting: Plot labels for `xstar_grid` (the varied whitened
      coordinate, the PCA step for axis -1, or the negative log density for
      axis -2).
    xstar_plotting: Plot labels for `xstar`.

  Raises:
    ValueError: If `axis` is not one of -2, -1, 0, 1, 2.
  """
  if axis not in (-2, -1, 0, 1, 2):
    raise ValueError(f"axis must be one of -2, -1, 0, 1, 2, got {axis!r}")

  def psi(vals: np.ndarray):
    return 2. * ((vals - 5.) ** 4 / 600. + np.exp(- 4. * (vals - 5.) ** 2) +
                   vals / 10. - 2.)

  def hllt_true(pi: np.ndarray, ti: np.ndarray, si: np.ndarray):
    return 100. + (10. + pi) * si * psi(ti) - 2. * pi

  def y_fun(p, t, s, e):
    if additive:
      return hllt_true(p, t, s) + e
    return hllt_true(p, t, s) + e * p / 4

  def sample_noise(key, n):
    """Draw n outcome-noise values with independent keys for v and the
    innovation (previously both came from the same subkey)."""
    key, key_v, key_e = random.split(key, 3)
    v_ = random.normal(key_v, shape=(n,))
    e_ = random.normal(key_e, shape=(n,)) * (1. - rho ** 2) + rho * v_
    return key, e_

  def grid_columns(p_vals, t_vals, s_vals):
    """Flattened (p, t, s) columns of the Cartesian product of the inputs."""
    pg, tg, sg = np.meshgrid(p_vals, t_vals, s_vals)
    return pg.ravel(), tg.ravel(), sg.ravel()

  key, subkey = random.split(key)
  z = random.normal(subkey, shape=(num,))
  key, subkey = random.split(key)
  v = random.normal(subkey, shape=(num,))
  key, subkey = random.split(key)
  s = random.randint(subkey, shape=(num,), minval=1, maxval=8)
  key, subkey = random.split(key)
  # Sample from the fresh subkey, not from `key` (which is split next).
  t = random.uniform(subkey, shape=(num,)) * 10
  key, subkey = random.split(key)
  e = random.normal(subkey, shape=(num,)) * (1. - rho ** 2) + rho * v

  p = 25. + (z + 3.) * psi(t) + v
  y = y_fun(p, t, s, e)
  x = np.stack((p, t, s)).T

  # Intervention-time noise gets its own name (`e_star`) so that the training
  # noise `e` stored in `values` below is not overwritten.
  if axis == 0:
    t_fixed = np.array([np.mean(t)])
    s_fixed = np.array([1])
    key, e_star = sample_noise(key, num_xstar)
    cols = grid_columns(np.linspace(10, 25, num_xstar), t_fixed, s_fixed)
    xstar = np.stack(cols).T
    ystar = y_fun(*cols, e_star)
    xstar_grid = np.stack(grid_columns(
        np.linspace(10, 25, num_xstar_bound), t_fixed, s_fixed)).T
  elif axis == 1:
    p_fixed = np.array([np.mean(p)])
    s_fixed = np.array([1])
    key, e_star = sample_noise(key, num_xstar)
    cols = grid_columns(p_fixed, np.linspace(-0, 10, num_xstar), s_fixed)
    xstar = np.stack(cols).T
    ystar = y_fun(*cols, e_star)
    xstar_grid = np.stack(grid_columns(
        p_fixed, np.linspace(-0, 10, num_xstar_bound), s_fixed)).T
  elif axis == 2:
    s_vis = np.arange(1, 8)
    key, e_star = sample_noise(key, len(s_vis))
    cols = grid_columns(np.array([np.mean(p)]), np.array([np.mean(t)]), s_vis)
    xstar = np.stack(cols).T
    ystar = y_fun(*cols, e_star)
    xstar_grid = xstar
  elif axis == -2:
    # Keep num_xstar_bound points spread over the 200 highest-density
    # training points under a Gaussian KDE.
    kde = KernelDensity(kernel='gaussian', bandwidth=0.05).fit(x)
    neg_log_density = -kde.score_samples(x)
    sorted_indices = np.argsort(neg_log_density)[:200]
    indices_to_keep = np.array([i * len(sorted_indices) // num_xstar_bound
                                for i in range(num_xstar_bound)])
    indices_to_keep = sorted_indices[indices_to_keep]
    xstar_grid = x[indices_to_keep]
    xstar_grid_plotting = np.array([neg_log_density[i] for i in indices_to_keep])
    xstar = xstar_grid
    xstar_plotting = xstar_grid_plotting
    key, e_star = sample_noise(key, num_xstar_bound)
    ystar = y_fun(xstar[:, 0], xstar[:, 1], xstar[:, 2], e_star)

  values = whiten({'x': x, 'y': y, 'z': z, 'ex': e, 'ey': e})

  if axis != -1:
    xstar = whiten_with_mu_std(xstar, values['x_mu'], values['x_std'])
    ystar = whiten_with_mu_std(ystar, values['y_mu'], values['y_std'])
    xstar_grid = whiten_with_mu_std(xstar_grid, values['x_mu'], values['x_std'])
    if axis != -2:
      xstar_plotting = xstar[:, axis]
      xstar_grid_plotting = xstar_grid[:, axis]
  else:
    # Line through the mean of whitened x along its first principal component.
    key, e_star = sample_noise(key, num_xstar)
    pca = PCA(n_components=2)
    pca.fit(values['x'])
    pc = pca.components_[0]
    x_0 = np.mean(values['x'], 0)
    alphas = np.linspace(-0.5, 0.5, num_xstar)
    xstar = np.array([x_0 + alp * pc for alp in alphas])
    xstar_or = unwhiten_with_mu_std(xstar, values['x_mu'], values['x_std'])
    ystar = y_fun(xstar_or[:, 0].ravel(), xstar_or[:, 1].ravel(),
                  xstar_or[:, 2].ravel(), e_star)

    xstar = whiten_with_mu_std(xstar_or, values['x_mu'], values['x_std'])
    ystar = whiten_with_mu_std(ystar, values['y_mu'], values['y_std'])

    alphas_star = np.linspace(-0.5, 0.5, num_xstar_bound)
    xstar_grid = np.array([x_0 + alp * pc for alp in alphas_star])
    xstar_grid_plotting = alphas_star
    xstar_plotting = alphas

  values['confounder'] = None
  return values, xstar, ystar, xstar_grid, xstar_grid_plotting, xstar_plotting

def fd_get_synth_data_2d(
  key: np.ndarray,
  num: int,
  equations,
  axis,
  num_xstar_bound,
  num_xstar: int = 100,
  external_equations = None
):
    """Generate 2-D synthetic data for the leaky front-door setting.

    Treatment x = f_x(u_xy, e_x), mediator m = f_m(x, c, e_m) and outcome
    y = f_y(m, c, u_xy, e_y). Here u_xy affects both x and y, and c affects
    both m and y. E[Y | do(x*)] is estimated by Monte Carlo along a line that
    varies coordinate `axis` of x*. The other coordinate is held at the mean
    of a fixed-seed reference sample.

    Args:
      key: A JAX random key.
      num: The number of examples to generate.
      equations: Name of the entry of `structural_equations_fd` to use;
        ignored when `external_equations` is given.
      axis: Treatment coordinate (0 or 1) varied along the grids.
      num_xstar_bound: Number of points in `xstar_grid`.
      num_xstar: Number of points in `xstar`.
      external_equations: Optional dict with the same keys as the entries of
        `structural_equations_fd`, used instead of the named equations.

    Returns:
      A 6-tuple (values, xstar, ystar, xstar_grid, xstar_grid_plotting,
      xstar_plotting):
        values: whitened x, y, m plus their *_mu/*_std.
        xstar: (num_xstar, 2) interventions spanning the observed range of x
          along `axis`.
        ystar: (num_xstar,) Monte Carlo mean over 15000 draws of the whitened
          outcome under do(xstar).
        xstar_grid: (num_xstar_bound, 2) interventions spanning the reference
          sample's range along `axis`, trimmed by 10% at each end.
        xstar_grid_plotting, xstar_plotting: the varying column of
          `xstar_grid` and `xstar` respectively.
      `xstar` and `xstar_grid` are whitened with the statistics of the
      reference sample (not of x).

    Raises:
      ValueError: If `axis` is not 0 or 1.
    """
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 or 1, got {axis!r}")

    if external_equations is not None:
        eqs = external_equations
    else:
        eqs = structural_equations_fd[equations]

    # Reference sample from a fixed seed so the intervention grid (and the
    # statistics used to whiten it) is the same on every call. Each draw uses
    # its own subkey; a key is never sampled from and then split again.
    key_bound = random.PRNGKey(40)
    key_bound, subkey_bound = random.split(key_bound)
    _ex = eqs["ex"](subkey_bound, 2000)
    key_bound, subkey_bound = random.split(key_bound)
    _u = eqs["u_xy"](subkey_bound, 2000)
    _x_bound = eqs["f_x"](np.array(_u), np.array(_ex)).T

    key, subkey = random.split(key)
    ex = eqs["ex"](subkey, num)
    key, subkey = random.split(key)
    em = eqs["em"](subkey, num)
    key, subkey = random.split(key)
    ey = eqs["ey"](subkey, num)
    key, subkey = random.split(key)
    c = eqs["confounder"](subkey, num)
    key, subkey = random.split(key)
    u = eqs["u_xy"](subkey, num)
    # This split's subkey is unused; the split is kept so the key sequence
    # for the Monte Carlo draws below (and hence their values) is unchanged.
    key, _ = random.split(key)
    x = eqs["f_x"](np.array(u), np.array(ex))
    m = eqs["f_m"](np.array(x), np.array(c), np.array(em))
    c1, c2 = c
    m1, m2 = m
    u1, u2 = u
    y = eqs["f_y"](m1, m2, c1, c2, u1, u2, np.array(ey))
    x = x.T
    m = m.T

    values = whiten({'x': x, 'y': y, 'm': m, 'x_temp': _x_bound})

    # Intervention lines: coordinate `axis` varies, the other coordinate is
    # fixed at its reference-sample mean.
    lo_b, hi_b = np.min(_x_bound[:, axis]), np.max(_x_bound[:, axis])
    trim = (hi_b - lo_b) / 10
    line = np.linspace(np.min(x[:, axis]), np.max(x[:, axis]), num_xstar)
    line_bound = np.linspace(lo_b + trim, hi_b - trim, num_xstar_bound)
    star_cols, bound_cols = [], []
    for j in range(2):
        if j == axis:
            star_cols.append(line)
            bound_cols.append(line_bound)
        else:
            held = np.mean(_x_bound[:, j])
            star_cols.append(np.full_like(line, held))
            bound_cols.append(np.full_like(line_bound, held))

    # Monte Carlo estimate of E[Y | do(xstar)]: do(x) cuts u_xy -> x, so the
    # mediator is regenerated from xstar with fresh c and e_m, and y from
    # fresh c, u_xy and e_y.
    ystar = []
    for _ in range(15000):
        key, subkey = random.split(key)
        tmpey = eqs["ey"](subkey, num_xstar)
        key, subkey = random.split(key)
        tmpem = eqs["em"](subkey, num_xstar)
        key, subkey = random.split(key)
        tmpconf = eqs["confounder"](subkey, num_xstar)
        key, subkey = random.split(key)
        tmp_u = eqs["u_xy"](subkey, num_xstar)
        tm1, tm2 = eqs["f_m"](np.vstack(star_cols), np.array(tmpconf),
                              np.array(tmpem))
        tc1, tc2 = tmpconf
        tu1, tu2 = tmp_u
        ystar.append(whiten_with_mu_std(
            eqs["f_y"](tm1, tm2, tc1, tc2, tu1, tu2, tmpey),
            values["y_mu"], values["y_std"]))
    ystar = np.array(ystar)

    xstar = whiten_with_mu_std(
        np.vstack(star_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_grid = whiten_with_mu_std(
        np.vstack(bound_cols).T, values["x_temp_mu"], values["x_temp_std"])
    xstar_plotting = get_nonconstant_axis(xstar)
    xstar_grid_plotting = get_nonconstant_axis(xstar_grid)

    # The reference sample was only needed for the grids; drop it.
    for k in [k for k in values if k.startswith("x_temp")]:
        del values[k]

    return values, xstar, ystar.mean(0), xstar_grid, xstar_grid_plotting, xstar_plotting


# Evaluation range (low, high) used for input coordinate `axis` of get_mz when
# that coordinate is varied and the others are held at their sample means.
_MZ_AXIS_RANGES = {0: (-5, 6), 1: (-13, 13), 2: (-15, 15), 3: (5, 7.5)}


def _mz_axis_points(x_means, axis, values_1d):
    """Return a (len(values_1d), 4) array: column `axis` is values_1d, others are x_means."""
    cols = [np.full_like(values_1d, m) for m in x_means]
    cols[axis] = values_1d
    return np.stack(cols, axis=1)


def get_mz(key: np.ndarray,
           num: int,
           axis,
           num_xstar_bound,
           num_xstar: int = 100):
    """Get simulated data newey powell from kernel IV paper.

    Draws instruments z = (z1, z2), a confounder c and four treatments
    x = (x1, x2, x3, x4), computes y = y_fun(x, c), whitens everything and
    builds evaluation points along one direction of x.

    Args:
      key: JAX PRNG key.
      num: number of training samples.
      axis: 0-3 to vary that coordinate of x over a fixed range (see
        `_MZ_AXIS_RANGES`) with the other coordinates held at their sample
        means; -1 to move along the first principal component of the
        whitened x, with coefficients in [-1, 1].
      num_xstar_bound: number of points in `xstar_grid`.
      num_xstar: number of points in `xstar` / `ystar`.

    Returns:
      (values, xstar, ystar, xstar_grid, xstar_grid_plotting, xstar_plotting).
      `values` is the result of `whiten` on the training data, with
      'confounder' set to None. `xstar`, `ystar` and `xstar_grid` are
      whitened with the training statistics. `ystar` evaluates y_fun with one
      confounder draw per point (it is not averaged over c).

    Raises:
      ValueError: if `axis` is not one of -1, 0, 1, 2, 3.
    """
    if axis not in (-1, *_MZ_AXIS_RANGES):
        raise ValueError(f"axis must be one of -1, 0, 1, 2, 3; got {axis!r}")

    # One independent subkey per random quantity: JAX keys must never be
    # reused, neither for two samples nor for sampling *and* further splitting.
    key_z1, key_z2, key_c, key_x3, key_star = random.split(key, 5)

    z1 = random.normal(key_z1, (num,))
    z2 = random.normal(key_z2, (num,))
    c = random.normal(key_c, shape=(num,))  # training confounder

    x1 = 2*z1 - z2*c/5 + 1
    x2 = z1 + 3*z2 + c/2
    x3 = 5*random.normal(key_x3, shape=(num,))
    x4 = np.exp(z1) + 5  # > 5, so log(x4) below is well-defined

    def y_fun(x1, x2, x3, x4, c):
        return x1**2/2 + (c/4+2)*np.abs(x2) + np.log(x4)*x3/4

    y = y_fun(x1, x2, x3, x4, c)
    x = np.stack((x1, x2, x3, x4)).T
    z = np.stack((z1, z2)).T

    # Confounder draw for the evaluation points. Kept under its own name so the
    # training confounder `c` is what gets stored as 'ex' / 'ey' below.
    c_star = random.normal(key_star, shape=(num_xstar,))

    values = whiten({'x': x, 'y': y, 'z': z, 'ex': c, 'ey': c})

    if axis == -1:
        # Walk along the first principal component of the whitened inputs.
        pca = PCA(n_components=2)
        pca.fit(values['x'])
        pc = np.asarray(pca.components_[0])
        x_0 = np.mean(values['x'], 0)

        alphas = np.linspace(-1, 1, num_xstar)
        xstar = x_0[None, :] + alphas[:, None] * pc[None, :]
        xstar_or = unwhiten_with_mu_std(xstar, values['x_mu'], values['x_std'])
        ystar = y_fun(xstar_or[:, 0], xstar_or[:, 1], xstar_or[:, 2],
                      xstar_or[:, 3], c_star)
        ystar = whiten_with_mu_std(ystar, values['y_mu'], values['y_std'])

        alphas_star = np.linspace(-1, 1, num_xstar_bound)
        xstar_grid = x_0[None, :] + alphas_star[:, None] * pc[None, :]
        xstar_grid_plotting = alphas_star
        xstar_plotting = alphas
    else:
        # Vary a single coordinate; hold the others at their (raw) sample means.
        low, high = _MZ_AXIS_RANGES[axis]
        x_means = [np.mean(x1), np.mean(x2), np.mean(x3), np.mean(x4)]

        xstar_raw = _mz_axis_points(x_means, axis,
                                    np.linspace(low, high, num_xstar))
        ystar_raw = y_fun(xstar_raw[:, 0], xstar_raw[:, 1], xstar_raw[:, 2],
                          xstar_raw[:, 3], c_star)
        xstar_grid_raw = _mz_axis_points(x_means, axis,
                                         np.linspace(low, high, num_xstar_bound))

        xstar = whiten_with_mu_std(xstar_raw, values['x_mu'], values['x_std'])
        ystar = whiten_with_mu_std(ystar_raw, values['y_mu'], values['y_std'])
        xstar_grid = whiten_with_mu_std(xstar_grid_raw, values['x_mu'],
                                        values['x_std'])
        xstar_plotting = xstar[:, axis]
        xstar_grid_plotting = xstar_grid[:, axis]

    values['confounder'] = None
    return values, xstar, ystar, xstar_grid, xstar_grid_plotting, xstar_plotting

