import functools

import haiku as hk
from jax.interpreters import xla
from jax import lax, vmap
from jax import numpy as jnp
import sys
import jax

# Add ./google-research/ to the import search path. The path is relative, so
# it is resolved against the current working directory at import time.
sys.path.append("google-research/")


# He-style initializer: variance scale 2.0 over fan-in, sampled from a
# truncated normal distribution. Shared by all conv and linear layers below.
he_normal = hk.initializers.VarianceScaling(
    scale=2.0, mode="fan_in", distribution="truncated_normal")

def batch_split_axis(data, n_split):
  """Reshape `data` so its leading axis is split into `n_split` groups.

  An array of shape (n, ...) becomes (n_split, n // n_split, ...).

  Args:
    data: Array-like with a `.shape` and `.reshape` (e.g. a JAX/NumPy array).
    n_split: Number of groups for the leading axis.

  Returns:
    The reshaped array.

  Raises:
    ValueError: if the leading dimension is not divisible by `n_split`.
  """
  n = data.shape[0]
  # Exact integer arithmetic; float division can misjudge divisibility for
  # very large sizes.
  n_new, remainder = divmod(n, n_split)
  if remainder:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise ValueError(
        "First axis cannot be split: batch dimension was {} when "
        "n_split was {}.".format(n, n_split))
  return data.reshape([n_split, n_new, *data.shape[1:]])


class FeatureResponseNorm(hk.Module):
  """Filter-Response-Normalization-style layer with a learned threshold.

  Each channel is divided by the root mean square of its activations over
  axes 1 and 2, then scaled by `gamma`, shifted by `beta`, and finally
  clipped from below by the learned threshold `tau`. Parameters have shape
  (1, 1, 1, C), so the input is expected to be 4-D with channels last.
  """

  def __init__(self, eps=1e-6, name="frn"):
    super().__init__(name=name)
    # Added to the mean square before the reciprocal square root.
    self.eps = eps

  def __call__(self, x, **unused_kwargs):
    # Extra keyword arguments (e.g. `is_training`) are accepted and ignored.
    del unused_kwargs
    channels = x.shape[-1]
    param_shape = (1, 1, 1, channels)
    tau = hk.get_parameter("tau", param_shape, x.dtype, init=jnp.zeros)
    beta = hk.get_parameter("beta", param_shape, x.dtype, init=jnp.zeros)
    gamma = hk.get_parameter("gamma", param_shape, x.dtype, init=jnp.ones)
    mean_square = jnp.mean(jnp.square(x), axis=(1, 2), keepdims=True)
    normalized = x * jax.lax.rsqrt(mean_square + self.eps)
    return jnp.maximum(gamma * normalized + beta, tau)


def _resnet_layer(inputs,
                  num_filters,
                  normalization_layer,
                  shift_match_layer,
                  shift_match_mode=None,
                  shift_match_before_act=False,
                  kernel_size=3,
                  strides=1,
                  activation=lambda x: x,
                  use_bias=True,
                  is_training=True):
  """Conv2D -> normalization -> shift-match/activation building block.

  Args:
    inputs: Input feature map; presumably NHWC (hk.Conv2D's default layout)
      -- TODO confirm against callers.
    num_filters: Number of output channels of the convolution.
    normalization_layer: Zero-arg callable returning a module that is called
      as `module(x, is_training=...)`.
    shift_match_layer: Zero-arg callable returning a module that is called as
      `module(x, shift_match_mode)`.
    shift_match_mode: Forwarded unchanged to the shift-match module.
    shift_match_before_act: If True, shift-match before the activation;
      otherwise after it.
    kernel_size: Convolution kernel size.
    strides: Convolution stride.
    activation: Elementwise activation; the default is the identity.
    use_bias: Whether the convolution has a bias term.
    is_training: Forwarded to the normalization layer. When False, the
      normalized output is additionally squeezed.

  Returns:
    The transformed feature map.
  """
  x = inputs
  # "same" padding: spatial size only shrinks through the stride.
  x = hk.Conv2D(
      num_filters,
      kernel_size,
      stride=strides,
      padding="same",
      w_init=he_normal,
      with_bias=use_bias)(
          x)
  x = normalization_layer()(x, is_training=is_training)
  if not is_training:
    # NOTE(review): squeeze() drops *every* size-1 axis, including the batch
    # axis when evaluating a single example; ShiftMatch asserts a 4-D input.
    # Confirm this is intended.
    x = x.squeeze()
  if shift_match_before_act:
    x = shift_match_layer()(x, shift_match_mode)
    x = activation(x)
  else:
    x = activation(x)
    x = shift_match_layer()(x, shift_match_mode)
  return x


def make_resnet_fn(
    num_classes,
    depth,
    normalization_layer,
    shift_match_layer,
    width=16,
    use_bias=True,
    activation=jax.nn.relu,
):
  """Build the forward function of a CIFAR-style ResNet with shift matching.

  The network is a conv stem followed by three stacks of `(depth - 2) // 6`
  residual blocks of two conv layers each. The filter count starts at
  `width` and doubles after every stack. The first block of the second and
  third stacks downsamples with stride 2 and uses a 1x1 projection shortcut.
  Every conv layer is followed by a shift-match module.

  Args:
    num_classes: Size of the output logits.
    depth: Network depth; must be of the form 6n + 2.
    normalization_layer: Zero-arg callable creating the normalization module
      applied after every convolution.
    shift_match_layer: Zero-arg callable creating a shift-match module, which
      is called as `module(x, shift_match_mode)`.
    width: Number of filters in the stem and the first stack.
    use_bias: Whether convolutions have a bias term.
    activation: Activation used in the stem, inside blocks and after each
      residual addition.

  Returns:
    `forward(batch, is_training, shift_match_mode=None, input_only=False,
    feature_only=False, shift_match_before_act=False)`, which returns logits.

  Raises:
    ValueError: if `depth` is not of the form 6n + 2.
  """
  num_res_blocks = (depth - 2) // 6
  if (depth - 2) % 6 != 0:
    raise ValueError("depth must be 6n+2 (e.g. 20, 32, 44).")

  def forward(batch, is_training, shift_match_mode=None, input_only=False,
              feature_only=False, shift_match_before_act=False):
    """Apply the network to `batch` and return logits.

    Args:
      batch: Input images. The final 8x8 / stride-8 average pool presumably
        expects NHWC input whose last stack is 8x8 (e.g. 32x32 images) --
        TODO confirm.
      is_training: Forwarded to every conv layer, including the stem.
      shift_match_mode: Forwarded to every shift-match module.
      input_only: If True, only the input shift-match module receives
        `shift_match_mode`; all later modules receive None.
      feature_only: If True, skip the shift-match module on the raw input.
      shift_match_before_act: Shift-match before (True) or after (False)
        each conv layer's activation.
    """
    num_filters = width
    x = batch
    if not feature_only:
      x = shift_match_layer()(x, shift_match_mode)
    if input_only:
      shift_match_mode = None
    # Arguments shared by every conv layer. The stem must see the same
    # `is_training` as the rest of the network.
    conv_layer = functools.partial(
        _resnet_layer,
        shift_match_mode=shift_match_mode,
        shift_match_before_act=shift_match_before_act,
        use_bias=use_bias,
        is_training=is_training,
        normalization_layer=normalization_layer,
        shift_match_layer=shift_match_layer)
    x = conv_layer(x, num_filters=num_filters, activation=activation)
    for stack in range(3):
      for res_block in range(num_res_blocks):
        # The first block of every stack but the first downsamples.
        downsample = stack > 0 and res_block == 0
        strides = 2 if downsample else 1
        y = conv_layer(x, num_filters=num_filters, strides=strides,
                       activation=activation)
        # The second conv has no activation; it is applied after the sum.
        y = conv_layer(y, num_filters=num_filters)
        if downsample:
          # 1x1 linear projection shortcut so `x` matches the shape of `y`.
          x = conv_layer(x, num_filters=num_filters, kernel_size=1,
                         strides=strides)
        x = activation(x + y)
      num_filters *= 2
    x = hk.AvgPool((8, 8, 1), 8, "VALID")(x)
    x = hk.Flatten()(x)
    return hk.Linear(num_classes, w_init=he_normal)(x)

  return forward


def make_resnet20_frn_fn(match_type, activation=jax.nn.swish,
                         normalization_layer=FeatureResponseNorm):
  """Return the forward function of a 10-class, depth-20 shift-match ResNet.

  Args:
    match_type: Match type given to every `ShiftMatch` module through
      `shift_match_builder`.
    activation: Activation used throughout the network (default swish).
    normalization_layer: Normalization module factory (default
      `FeatureResponseNorm`).
  """
  build_shift_match = functools.partial(shift_match_builder, match_type)
  return make_resnet_fn(
      10,
      depth=20,
      normalization_layer=normalization_layer,
      shift_match_layer=build_shift_match,
      activation=activation)


class ShiftMatch(hk.Module):
  """Aligns feature statistics of a batch to accumulated training statistics.

  `__call__` operates in one of three modes:

  * mode=None: create the Haiku state and return the input unchanged.
  * mode='acc': accumulate training statistics into state and return the
    input unchanged.
  * mode='match': return the input transformed so that its statistics match
    the accumulated ones.

  `match_type` selects how the (N, H, W, C) feature map is flattened before
  statistics are computed (see `_reshape`). 'None' disables the module.
  """

  def __init__(self, match_type, name="shift_match"):
    assert match_type in [
        'None',
        'full',
        'feature',
        'feature_cov',
        'channel_wise_joint',
        'channel_wise_sep',
        'spatial_joint',
        'spatial_joint_cov',
        'spatial_sep',
        'spatial_sep_cov',
        'fft_spatial',
        'channel_wise_sep_cov',
        'batch_norm',
        'spatial_sep_cov_mean',
        'channel_wise_sep_cov_mean'
    ]
    self.match_type = match_type
    # Cached square root of the training covariance used by `_match`. Match
    # paths that call `_match` twice reset it to None in between.
    self.sqrt_cov_h_train = None
    super().__init__(name=name)

  def _reshape(self, x, match_type, mode):
    """Flatten an (N, H, W, C) feature map into the layout `match_type` uses.

    Layouts:
      'full': (N, H*W*C). 'batch_norm': unchanged. 'fft_spatial': (N, C, H, W).
      'spatial_joint*': (N*C, H*W).
      'spatial_sep*': stack of (N*W*C, H) and (N*H*C, W) views (H == W);
        in 'match' mode only the first view is returned (only its shape is
        used).
      'channel_wise_joint': (C, N, H*W).
      'channel_wise_sep*': stack of (C, N*W, H) and (C, N*H, W) views; in
        'match' mode a (C, N*W, H) array used only for its shape.
      'feature*': (N*H*W, C).
    """
    N, H, W, C = x.shape
    if match_type == 'full':
      return x.reshape(N, -1)
    if match_type == 'batch_norm':
      return x
    if match_type == 'fft_spatial':
      return jnp.transpose(x, (0, 3, 1, 2))  # (N, C, H, W)
    elif 'spatial_joint' in match_type:
      x = jnp.transpose(x, (0, 3, 1, 2))  # (N, C, H, W)
      return x.reshape(N * C, H * W)
    elif 'spatial_sep' in match_type:
      assert H == W
      if mode == 'match':
        # Only the shape matters in match mode.
        return x.transpose(0, 2, 3, 1).reshape(N * W * C, H)
      x_H = x.transpose(0, 2, 3, 1).reshape(N * W * C, H)
      x_W = x.transpose(0, 1, 3, 2).reshape(N * H * C, W)
      return jnp.stack([x_H, x_W])
    elif match_type == 'channel_wise_joint':
      x = jnp.transpose(x, (3, 0, 1, 2))  # (C, N, H, W)
      return x.reshape(C, N, H * W)
    elif 'channel_wise_sep' in match_type:
      if mode == 'match':
        # Only the shape matters in match mode.
        return x.reshape(C, N * W, H)
      x_H = x.transpose(3, 0, 2, 1).reshape(C, N * W, H)
      x_W = x.transpose(3, 0, 1, 2).reshape(C, N * H, W)
      return jnp.stack([x_H, x_W])
    elif 'feature' in match_type:
      return x.reshape(-1, C)

  def _shape_back(self, x, origin_shape, match_type):
    """Invert `_reshape` for the layouts matched through this helper."""
    N, H, W, C = origin_shape
    if match_type == 'fft_spatial':
      return x.transpose((0, 2, 3, 1))
    if match_type == 'channel_wise_joint':
      return x.reshape(C, N, H, W).transpose((1, 2, 3, 0))
    elif 'spatial_joint' in match_type:
      return x.reshape(N, C, H, W).transpose((0, 2, 3, 1))
    else:
      return x.reshape(origin_shape)

  def _matrix_sqrt(self, cov, neg=False):
    """Return cov^(1/2), or cov^(-1/2) if `neg`, via eigendecomposition.

    Eigenvalues are clamped below at 1e-10 so the inverse root stays finite.
    """
    # lax.linalg.eigh returns (eigenvectors, eigenvalues).
    eigvecs, eigvals = lax.linalg.eigh(cov)
    eigvals = jnp.maximum(eigvals, jnp.ones_like(eigvals) * 1e-10)
    root = (1 / eigvals) ** 0.5 if neg else eigvals ** 0.5
    return eigvecs @ jnp.diag(root) @ eigvecs.T

  def _batch_outer_product(self, x):
    """Row-wise outer products: (B, D) -> (B, D, D)."""
    assert len(x.shape) == 2
    return vmap(lambda _x: jnp.outer(_x, _x))(x)

  def _fft_match(self, h_test, spec_train_sqr, batch_num=4):
    """Rescale the 2-D FFT of `h_test` so its power spectrum matches training.

    The per-frequency test amplitude sqrt(mean_n |F(h)|^2) is replaced by
    sqrt(spec_train_sqr), keeping phases; the real part of the inverse FFT is
    returned. FFTs are computed in up to `batch_num` chunks along axis 0
    (each sample is transformed independently, so chunking does not change
    the result).
    """
    n = h_test.shape[0]
    chunk = max(1, -(-n // batch_num))  # ceil(n / batch_num)
    starts = range(0, n, chunk)
    fft_test = jnp.concatenate(
        [jnp.fft.fft2(h_test[s:s + chunk]) for s in starts], 0)
    spec_test = jnp.sqrt(jnp.mean(jnp.abs(fft_test) ** 2, 0)) + 1E-10
    matched = (fft_test / spec_test) * jnp.sqrt(spec_train_sqr)
    ifft_test = jnp.concatenate(
        [jnp.fft.ifft2(matched[s:s + chunk]) for s in starts], 0)
    return jnp.real(ifft_test)

  def _match(self, h_test, cov_x_train):
    """Whiten `h_test` with its own covariance and recolor with `cov_x_train`.

    Computes h @ cov_test^(-1/2) @ cov_train^(1/2) with cov_test = h^T h / n
    (not mean-centered). For channel-wise match types the inputs carry a
    leading channel axis and the computation is vmapped over it.

    The square root of `cov_x_train` is cached on `self.sqrt_cov_h_train`;
    callers reset it to None before matching a different covariance.
    """
    n_devices = jax.local_device_count()
    match_ind_channel = 'wise' in self.match_type

    def _inner(h_test, sqrt_cov_h_train):
      cov_h_test = h_test.T @ h_test / len(h_test)
      neg_sqrt_cov_h_test = self._matrix_sqrt(cov_h_test, neg=True)
      if n_devices > 1:
        # Split rows across local devices. `pmap` is not imported by name in
        # this file, so reference it through `jax`.
        out = jax.pmap(
            lambda h: h @ neg_sqrt_cov_h_test @ sqrt_cov_h_train
        )(batch_split_axis(h_test, n_devices))
        return out.reshape((out.shape[0] * out.shape[1],) + out.shape[2:])
      return h_test @ neg_sqrt_cov_h_test @ sqrt_cov_h_train

    # `is not None`: a cached JAX array has no unambiguous truth value.
    if self.sqrt_cov_h_train is None:
      if match_ind_channel:
        self.sqrt_cov_h_train = vmap(self._matrix_sqrt)(cov_x_train)
      else:
        self.sqrt_cov_h_train = self._matrix_sqrt(cov_x_train)

    match_func = vmap(_inner) if match_ind_channel else _inner
    return match_func(h_test, self.sqrt_cov_h_train)

  def __call__(self, x, mode=None):
    """Create/update shift-match state, or match `x` to training statistics.

    Args:
      x: 4-D feature map, assumed (N, H, W, C) as in `_reshape`.
      mode: None (create state, return `x`), 'acc' (accumulate statistics,
        return `x`) or 'match' (return `x` matched to stored statistics).

    Returns:
      An array with the same shape as `x`.

    Raises:
      ValueError: if `mode` is truthy but neither 'acc' nor 'match'.
    """
    if self.match_type == 'None':
      return x
    assert len(x.shape) == 4  # This module currently only supports CNN features.
    C = x.shape[-1]
    _x = self._reshape(x, self.match_type, mode)
    D = _x.shape[-1]
    old_shape = x.shape
    # Number of rows accumulated so far; normalizes the running sums.
    cov_counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.zeros)
    if self.match_type == 'batch_norm':
      # Per-channel running E[x] ('mean') and E[x^2] ('var').
      mean = hk.get_state('mean', shape=(C,), dtype=jnp.float32, init=jnp.zeros)
      var = hk.get_state('var', shape=(C,), dtype=jnp.float32, init=jnp.zeros)
    elif self.match_type == 'fft_spatial':
      # NOTE(review): declared (D, D), but the 'acc' update broadcasts it to
      # the (C, H, W) power spectrum; confirm the intended state shape.
      spec_train_sqr = hk.get_state('spec_train_sqr', shape=(D, D), dtype=jnp.float32, init=jnp.zeros)
    elif 'channel_wise' in self.match_type:
      if 'sep' in self.match_type:
        cov_H = hk.get_state('cov_H', shape=(C, D, D), dtype=jnp.float32, init=jnp.zeros)
        mu_H = hk.get_state('mu_H', shape=(C, D), dtype=jnp.float32, init=jnp.zeros)
        cov_W = hk.get_state('cov_W', shape=(C, D, D), dtype=jnp.float32, init=jnp.zeros)
        mu_W = hk.get_state('mu_W', shape=(C, D), dtype=jnp.float32, init=jnp.zeros)
        cov = (cov_H, cov_W)
        mu = (mu_H, mu_W)
      else:
        mu = hk.get_state('mu', shape=(C, D), dtype=jnp.float32, init=jnp.zeros)
        cov = hk.get_state('cov', shape=(C, D, D), dtype=jnp.float32, init=jnp.zeros)
    elif 'spatial_sep' in self.match_type:
      cov_H = hk.get_state('cov_H', shape=(D, D), dtype=jnp.float32, init=jnp.zeros)
      mu_H = hk.get_state('mu_H', shape=(D,), dtype=jnp.float32, init=jnp.zeros)
      cov_W = hk.get_state('cov_W', shape=(D, D), dtype=jnp.float32, init=jnp.zeros)
      mu_W = hk.get_state('mu_W', shape=(D,), dtype=jnp.float32, init=jnp.zeros)
      cov = (cov_H, cov_W)
      mu = (mu_H, mu_W)
    else:
      cov = hk.get_state('cov', shape=(D, D), dtype=jnp.float32, init=jnp.zeros)
      mu = hk.get_state('mu', shape=(D,), dtype=jnp.float32, init=jnp.zeros)

    if not mode:
      return x
    elif mode == 'acc':
      # 'cov*' states hold running sums of x^T x; 'mu*' states hold running
      # means; 'counter' counts the rows seen.
      if self.match_type == 'batch_norm':
        batch_size = x.shape[0] * x.shape[1] * x.shape[2]
        # No keepdims: the batch moments keep the declared (C,) state shape.
        mu = jnp.mean(x, [0, 1, 2])
        mu_2 = jnp.mean(x ** 2, [0, 1, 2])
        # Count-weighted running means of E[x] and E[x^2].
        hk.set_state('mean', mean + (mu - mean) / (cov_counter / batch_size + 1))
        hk.set_state('var', var + (mu_2 - var) / (cov_counter / batch_size + 1))
        hk.set_state('counter', cov_counter + batch_size)
      elif self.match_type == 'fft_spatial':
        assert _x.shape[-1] == _x.shape[-2]
        fft_train = jnp.fft.fft2(_x)
        batch_size = _x.shape[0]
        # Running mean of the per-frequency power |F(x)|^2.
        hk.set_state('spec_train_sqr',
          (jnp.sum(jnp.abs(fft_train) ** 2, 0) -
          batch_size * spec_train_sqr) / (cov_counter + batch_size) + spec_train_sqr
        )
        hk.set_state('counter', cov_counter + batch_size)
        return x
      elif 'channel_wise' in self.match_type:
        if 'sep' in self.match_type:
          cov_H, cov_W = cov
          mu_H, mu_W = mu
          x_H, x_W = _x
          batch_size = x_H.shape[1]
          new_cov_H = vmap(lambda m: m.T @ m)(x_H)  # (C, D, D)
          new_mu_H = x_H.sum(1)
          new_cov_W = vmap(lambda m: m.T @ m)(x_W)
          new_mu_W = x_W.sum(1)
          hk.set_state('cov_H', cov_H + new_cov_H)
          hk.set_state('mu_H', mu_H + (new_mu_H - mu_H * batch_size) / (cov_counter + batch_size))
          hk.set_state('cov_W', cov_W + new_cov_W)
          hk.set_state('mu_W', mu_W + (new_mu_W - mu_W * batch_size) / (cov_counter + batch_size))
          hk.set_state('counter', cov_counter + x_H.shape[1])
        else:
          new_cov = vmap(lambda m: m.T @ m)(_x)  # (C, D, D), vmap over channels
          hk.set_state('cov', cov + new_cov)
          hk.set_state('counter', cov_counter + _x.shape[1])
      elif 'spatial_sep' in self.match_type:
        cov_H, cov_W = cov
        mu_H, mu_W = mu
        x_H, x_W = _x
        new_mu_H = x_H.sum(0)
        new_mu_W = x_W.sum(0)
        new_cov_H = x_H.T @ x_H
        new_cov_W = x_W.T @ x_W
        batch_size = x_H.shape[0]
        hk.set_state('cov_H', cov_H + new_cov_H)
        hk.set_state('mu_H', mu_H + (new_mu_H - mu_H * batch_size) / (cov_counter + batch_size))
        hk.set_state('cov_W', cov_W + new_cov_W)
        hk.set_state('mu_W', mu_W + (new_mu_W - mu_W * batch_size) / (cov_counter + batch_size))
        hk.set_state('counter', cov_counter + len(x_H))
      else:
        new_cov = _x.T @ _x  # (D, D)
        new_mu = _x.sum(0)
        batch_size = _x.shape[0]
        hk.set_state('cov', cov + new_cov)
        hk.set_state('mu', mu + (new_mu - mu * batch_size) / (cov_counter + batch_size))
        hk.set_state('counter', cov_counter + len(_x))
      return x
    elif mode == 'match':
      if self.match_type == 'batch_norm':
        # Standardize per channel with test moments, then apply train moments.
        train_std = jnp.sqrt(jnp.maximum(var - (mean ** 2), 0) + 1e-6)
        train_mean = mean
        test_std = jnp.maximum(jnp.std(x, [0, 1, 2], keepdims=True), 0) + 1e-6
        test_mean = jnp.mean(x, [0, 1, 2], keepdims=True)
        return (x - test_mean) * train_std / test_std + train_mean
      elif self.match_type == 'fft_spatial':
        assert _x.shape[-1] == _x.shape[-2]
        matched_x = self._fft_match(_x, spec_train_sqr)
        return self._shape_back(matched_x, old_shape, self.match_type)
      elif self.match_type == 'spatial_sep':
        # Match along H, then along W, with autocorrelation = sum / counter.
        N, H, W, C = old_shape
        cov_H, cov_W = cov
        x = x.transpose(0, 2, 3, 1).reshape(N * W * C, H)  # (N, W, C, H)
        x = self._match(x, cov_H / cov_counter).reshape(N, W, C, H).transpose(0, 3, 2, 1)  # (N, H, C, W)
        self.sqrt_cov_h_train = None
        x = x.reshape(N * H * C, W)
        x = self._match(x, cov_W / cov_counter).reshape(N, H, C, W).transpose(0, 1, 3, 2)
        return x
      elif self.match_type == 'spatial_sep_cov':
        # NOTE(review): cov_H / cov_W are used as stored (the 'acc' branch
        # stores raw sums of x^T x, not divided by counter nor centered).
        # Confirm the state is converted to a covariance before matching.
        N, H, W, C = old_shape
        cov_H, cov_W = cov
        x = x.transpose(0, 2, 3, 1).reshape(N * W * C, H)  # (N, W, C, H)
        mu_test = x.mean(0)
        x = self._match(x - mu_test, cov_H) + mu_test
        x = x.reshape(N, W, C, H).transpose(0, 3, 2, 1)  # (N, H, C, W)
        self.sqrt_cov_h_train = None
        x = x.reshape(N * H * C, W)
        mu_test = x.mean(0)
        x = self._match(x - mu_test, cov_W)
        x = x + mu_test
        x = x.reshape(N, H, C, W).transpose(0, 1, 3, 2)
        return x
      elif self.match_type == 'spatial_sep_cov_mean':
        # Like 'spatial_sep_cov' (same NOTE on stored cov) but re-centered on
        # the training means instead of the test means.
        N, H, W, C = old_shape
        cov_H, cov_W = cov
        mu_H, mu_W = mu
        x = x.transpose(0, 2, 3, 1).reshape(N * W * C, H)  # (N, W, C, H)
        mu_test = x.mean(0)
        x = self._match(x - mu_test, cov_H) + mu_H
        x = x.reshape(N, W, C, H).transpose(0, 3, 2, 1)  # (N, H, C, W)
        self.sqrt_cov_h_train = None
        x = x.reshape(N * H * C, W)
        mu_test = x.mean(0)
        x = self._match(x - mu_test, cov_W)
        x = x + mu_W
        x = x.reshape(N, H, C, W).transpose(0, 1, 3, 2)
        return x
      elif self.match_type == 'channel_wise_sep':
        N, H, W, C = old_shape
        cov_H, cov_W = cov
        x = x.transpose(3, 0, 2, 1).reshape(C, N * W, H)
        x = self._match(x, cov_H / cov_counter)
        x = x.reshape(C, N, W, H).transpose(0, 1, 3, 2)
        self.sqrt_cov_h_train = None
        x = x.reshape(C, N * H, W)
        x = self._match(x, cov_W / cov_counter).reshape(C, N, H, W).transpose(1, 2, 3, 0)
        return x
      elif self.match_type == 'channel_wise_sep_cov':
        N, H, W, C = old_shape
        R_H, R_W = cov  # Autocorrelation sums.
        mu_H, mu_W = mu
        x = x.transpose(3, 0, 2, 1).reshape(C, N * W, H)
        mu_test = x.mean(1, keepdims=True)
        # Covariance = E[x x^T] - mu mu^T, per channel.
        x = self._match(x - mu_test,
            R_H / cov_counter - self._batch_outer_product(mu_H)) + mu_test
        x = x.reshape(C, N, W, H).transpose(0, 1, 3, 2)
        self.sqrt_cov_h_train = None
        x = x.reshape(C, N * H, W)
        mu_test = x.mean(1, keepdims=True)
        x = self._match(x - mu_test,
            R_W / cov_counter - self._batch_outer_product(mu_W)) + mu_test
        x = x.reshape(C, N, H, W).transpose(1, 2, 3, 0)
        return x
      elif self.match_type == 'channel_wise_sep_cov_mean':
        N, H, W, C = old_shape
        R_H, R_W = cov  # Autocorrelation sums.
        mu_H, mu_W = mu
        x = x.transpose(3, 0, 2, 1).reshape(C, N * W, H)
        mu_test = x.mean(1, keepdims=True)
        x = self._match(x - mu_test,
            R_H / cov_counter - self._batch_outer_product(mu_H)) + mu_H.reshape(mu_test.shape)
        x = x.reshape(C, N, W, H).transpose(0, 1, 3, 2)
        self.sqrt_cov_h_train = None
        x = x.reshape(C, N * H, W)
        mu_test = x.mean(1, keepdims=True)
        x = self._match(x - mu_test,
            R_W / cov_counter - self._batch_outer_product(mu_W)) + mu_W.reshape(mu_test.shape)
        x = x.reshape(C, N, H, W).transpose(1, 2, 3, 0)
        return x
      else:
        if 'cov' in self.match_type:
          # NOTE(review): `cov` is used as stored (raw sum of x^T x from
          # 'acc'), plus a small diagonal ridge; confirm it is converted to a
          # covariance before matching.
          matched_x = self._match(_x - _x.mean(0), jnp.eye(_x.shape[-1]) * 1e-7 + cov) + mu
        else:
          matched_x = self._match(_x, cov / cov_counter)
        return self._shape_back(matched_x, old_shape, self.match_type)
    raise ValueError(
        "Unknown shift-match mode {!r}; expected None, 'acc' or 'match'."
        .format(mode))


def shift_match_builder(match_mode='channel_wise_joint'):
  """Factory for a `ShiftMatch` module.

  Args:
    match_mode: Passed to `ShiftMatch` as its `match_type`.

  Returns:
    A new `ShiftMatch` instance. Haiku modules are normally constructed inside
    an `hk.transform`-ed function.
  """
  module = ShiftMatch(match_mode)
  return module
