# -*- coding: utf-8 -*-
"""

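Block-wise prediction update utilities: update kernel MSE predictions when a training point is added, reusing a cached Cholesky factor through a Schur-complement (block-matrix inverse) update.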

@author: Anonymous Author
"""
import jax
import jax.numpy as np
import jax.scipy as sp
from jax._src.scipy.linalg import _cho_solve as _jsp_cho_solve

from samplings.neural_tangents_f.utils import utils#, dataclasses
from samplings.neural_tangents_f.utils.typing import Axes

import scipy as osp
import numpy as onp
from functools import partial

# ### Example usage
# staged_x_non_channel_shape, staged_C, staged_rhs, staged_orig_preds, staged_odd, staged_first = \
#     prepare_gradient_descent_mse_staged(kll0, kul0, pslbl, fl0, fu0, learning_rate, trace_axes=(-1,))
# staged_A = kll0.shape[1] * 10  # rows of the 2-D train-train kernel (n * channels); presumably 10 output channels here

# perm = [0]  # index (into kul0's rows) of the point being added
# perm = onp.array(perm)

# extra_k_train_train_col = kul0[perm].T
# extra_k_train_train_point = kuu0[perm, perm]
# extra_k_test_train_col = kuu0[:, perm]
# extra_fx_train_0 = fu0[perm]
# extra_y_train_onehot = pslbl[-1:]

# fx_test_t_blockn = block_gradient_descent_mse_staged(
#     kul0, staged_x_non_channel_shape, staged_A, staged_C, staged_rhs,
#     staged_orig_preds, staged_odd, staged_first, extra_k_test_train_col,
#     extra_k_train_train_col, extra_k_train_train_point, extra_y_train_onehot, extra_fx_train_0,2
# )

# ###


# Should be called at the start of each batch/sequential choice.
def prepare_gradient_descent_mse_staged(
        k_train_train: np.ndarray,
        k_test_train: np.ndarray,
        y_train: np.ndarray,
        fx_train_0: np.ndarray,
        fx_test_0: np.ndarray,
        learning_rate: float = 1.,
        diag_reg: float = 0.,
        diag_reg_absolute_scale: bool = False,
        trace_axes: Axes = (-1,)):
  """Precompute the state reused by `block_gradient_descent_mse_staged`.

  Cholesky-factorizes the 2-D view of `k_train_train` once (on the CPU
  backend) and computes the baseline predictions
  `fx_test_0 + k_test_train K^{-1} (y_train - fx_train_0)`. Adding one extra
  training point later then only needs a Schur-complement update instead of
  a new factorization.

  If the smallest eigenvalue of the 2-D kernel is not strictly positive, a
  jitter of `1.1 * max(-min_eigval, 1e-6)` is added to the diagonal before
  factorizing.

  Args:
    k_train_train: train-train kernel (2-D, or a 4-D kernel whose axes
      `1::2` give the non-channel shape).
    k_test_train: test-train kernel using the same axis layout.
    y_train: training targets.
    fx_train_0: initial outputs on the training points.
    fx_test_0: initial outputs on the test points.
    learning_rate: accepted for API compatibility; currently unused.
    diag_reg: accepted for API compatibility; currently unused (the
      eigenvalue-based jitter above is used instead).
    diag_reg_absolute_scale: accepted for API compatibility; currently unused.
    trace_axes: axes of `y_train` that are traced over by the kernel.

  Returns:
    Tuple `(x_non_channel_shape, C, rhs, orig_preds, odd, first)`:
      x_non_channel_shape: `k_train_train.shape[1::2]`.
      C: Cholesky factor (`lower=False`) of the (possibly jittered) 2-D
        train-train kernel.
      rhs: `y_train - fx_train_0`.
      orig_preds: baseline test predictions.
      odd, first: contraction axes from `_get_axes(k_train_train)`.
  """
  _, odd, first, _ = _get_axes(k_train_train)

  # Express trace axes as negative indices: they are passed as a static
  # (hashable) argument to the jitted solve below.
  trace_axes = utils.canonicalize_axis(trace_axes, y_train)
  trace_axes = tuple(-y_train.ndim + a for a in trace_axes)
  last_t_axes = tuple(range(-len(trace_axes), 0))

  rhs = y_train - fx_train_0
  x_non_channel_shape = k_train_train.shape[1::2]

  # TODO: if k_train_train doesn't fit in one device's memory, revise this.
  min_eigval = osp.linalg.eigh(make_2d_cpu(k_train_train), eigvals_only=True,
                               subset_by_index=[0, 0])[0]
  # A singular (min eigenvalue == 0) or indefinite kernel breaks the Cholesky
  # factorization; JAX yields NaNs rather than raising. So jitter the diagonal
  # whenever the smallest eigenvalue is not strictly positive.
  if min_eigval <= 0:
    dr = 1.1 * max(-min_eigval, 1e-6)
    C = jit_prepare_cho_solve_staged_with_dr(k_train_train, dr)[0]
  else:
    C = jit_prepare_cho_solve_staged(k_train_train)[0]

  dfx_test = jit_raw_cho_solve_multiply_staged(C, k_test_train, rhs, trace_axes,
                                               x_non_channel_shape, odd, first)
  # The solve leaves the traced axes last; move them back to their positions.
  dfx_test = onp.moveaxis(dfx_test, last_t_axes, trace_axes)
  orig_preds = fx_test_0 + dfx_test  # (nt + n) x k

  return x_non_channel_shape, C, rhs, orig_preds, odd, first

def make_2d_cpu(mat):
  """Flatten a kernel to a 2-D matrix with `utils.make_2d`, jitted on CPU."""
  return utils.make_2d(mat)


make_2d_cpu = jax.jit(make_2d_cpu, backend='cpu')

def block_gradient_descent_mse_staged(k_test_train: np.ndarray, x_non_channel_shape, A_shape_1, C, rhs, orig_preds, odd,
                                      first, extra_k_test_train_col, extra_k_train_train_col,
                                      extra_k_train_train_point, extra_y_train, extra_fx_train_0, diag_reg_coe = 0.1):
  """Update staged predictions after appending one training point.

  Uses the block-matrix inverse of [[K, d], [d^T, c]] so that the cached
  Cholesky factor `C` of the original train-train kernel K can be reused.
  With v = K^{-1} d and the Schur complement s = c - d^T v, the new
  predictions are

      orig_preds + (A v - a) s^{-1} (v^T rhs - (y_new - f0_new)),

  where A = `k_test_train` and a = `extra_k_test_train_col`.

  Args:
    k_test_train: test-train kernel that produced `orig_preds` (2-D or 4-D).
    x_non_channel_shape: value returned by `prepare_gradient_descent_mse_staged`.
    A_shape_1: number of rows of the 2-D view of K (must match `C`); only used
      for 4-D kernels.
    C: value returned by `prepare_gradient_descent_mse_staged`.
    rhs: value returned by `prepare_gradient_descent_mse_staged`.
    orig_preds: value returned by `prepare_gradient_descent_mse_staged`.
    odd: value returned by `prepare_gradient_descent_mse_staged`.
    first: value returned by `prepare_gradient_descent_mse_staged`.
    extra_k_test_train_col: kernel between the test points and the new
      point (a).
    extra_k_train_train_col: kernel between the old training points and the
      new point (d).
    extra_k_train_train_point: kernel of the new point with itself (c).
    extra_y_train: target of the new point.
    extra_fx_train_0: initial output on the new point.
    diag_reg_coe: regularizer added to the diagonal of the Schur complement,
      relative to its mean diagonal entry. Only used for 4-D kernels; the
      2-D path divides by the unregularized Schur complement.

  Returns:
    Updated predictions, broadcast against `orig_preds`.
  """
  if k_test_train.ndim == 4:
    # v = K^{-1} d: the new point's axes (1, 3) become right-hand-side columns.
    d_axes = utils.canonicalize_axis((1, 3), extra_k_train_train_col)
    x_shape = x_non_channel_shape + tuple(
        extra_k_train_train_col.shape[a] for a in d_axes)
    d = np.moveaxis(extra_k_train_train_col, d_axes, range(-len(d_axes), 0))
    v = _jsp_cho_solve(C, d.reshape((A_shape_1, -1)), False).reshape(x_shape)
    v = np.moveaxis(v, (-2, -1), (1, 3))  # (n, 1, k, k)

    # Schur complement s = c - d^T v, 1 x 1 x k x k.
    dTv = np.tensordot(extra_k_train_train_col, v, ((0, 2), (0, 2)))
    dTv = np.moveaxis(dTv, (-2, -1), (1, 3))
    s = extra_k_train_train_point - dTv

    # v^T rhs - (y_new - f0_new), 1 x k.
    vT = np.transpose(v, (1, 0, 3, 2))  # 1 x n x k x k
    vTF_minus_f = (np.tensordot(vT, rhs, (odd, first))
                   - (extra_y_train - extra_fx_train_0))

    # s^{-1} (v^T rhs - e) through a regularized Cholesky of the 2-D view of s.
    # There are no trace axes here, so the RHS is flattened to a single column.
    s_non_channel_shape = s.shape[1::2]
    s = utils.make_2d(s)
    s_factor = sp.linalg.cho_factor(
        _add_diagonal_regularizer(s, diag_reg_coe, False), False)
    sI_vTF_minus_f = sp.linalg.cho_solve(
        s_factor, vTF_minus_f.reshape((s.shape[1], -1)))
    sI_vTF_minus_f = sI_vTF_minus_f.reshape(s_non_channel_shape)

    # A v - a, (nt + n) x 1 x k x k.
    Av = np.tensordot(k_test_train, v, ((1, 3), (0, 2)))  # (nt + n) x k x 1 x k
    Av = np.moveaxis(Av, (-2, -1), (1, 3))
    Av_minus_a = Av - extra_k_test_train_col

    # s^{-1} is already applied above, so no further division is needed here.
    change = np.tensordot(Av_minus_a, sI_vTF_minus_f, (odd, first))  # (nt + n) x k
    return orig_preds + change

  # 2-D kernels: the new point contributes a (presumably 1 x 1) Schur
  # complement u, applied by division.
  v = _jsp_cho_solve(C, extra_k_train_train_col, False)
  vTF_minus_f = np.dot(v.T, rhs) - (extra_y_train - extra_fx_train_0)
  Av_minus_a = np.dot(k_test_train, v) - extra_k_test_train_col
  change = np.dot(Av_minus_a, vTF_minus_f)
  u = extra_k_train_train_point - np.dot(extra_k_train_train_col.T, v)
  return orig_preds + change / u

def _add_diagonal_regularizer(A: np.ndarray,
                              diag_reg: float,
                              diag_reg_absolute_scale: bool) -> np.ndarray:
  """Return `A` with a multiple of the identity added to it.

  If `diag_reg_absolute_scale` is true, `diag_reg` is used as-is. Otherwise
  it is scaled by `trace(A) / A.shape[0]` (the mean diagonal entry for a
  square `A`), which makes the regularizer relative to A's magnitude.
  """
  size = A.shape[0]
  if diag_reg_absolute_scale:
    shift = diag_reg
  else:
    shift = diag_reg * (np.trace(A) / size)
  return A + shift * np.eye(size)

@partial(jax.jit, backend='cpu')
def jit_prepare_cho_solve_staged_with_dr(A: np.ndarray, dr):
  """Cholesky-factorize `make_2d(A) + dr * I` on the CPU backend.

  Returns the `(c, lower)` pair from `jax.scipy.linalg.cho_factor`, computed
  with `lower=False`.
  """
  mat = utils.make_2d(A)
  jitter = dr * np.eye(mat.shape[0])
  return sp.linalg.cho_factor(mat + jitter, lower=False)

@partial(jax.jit, backend='cpu')
def jit_prepare_cho_solve_staged(A: np.ndarray):
  """Cholesky-factorize `make_2d(A)` on the CPU backend.

  Returns the `(c, lower)` pair from `jax.scipy.linalg.cho_factor`, computed
  with `lower=False`.
  """
  mat = utils.make_2d(A)
  factor = sp.linalg.cho_factor(mat, lower=False)
  return factor

@partial(jax.jit, backend='cpu', static_argnums=(3, 4, 5, 6))
def jit_raw_cho_solve_multiply_staged(C, k_test_train, b: np.ndarray, b_axes: tuple, x_non_channel_shape, odd, first) -> np.ndarray:
  """Contract `k_test_train` with `K^{-1} b`, given the Cholesky factor `C` of K.

  The `b_axes` of `b` are moved to the end and treated as right-hand-side
  columns of the solve (upper factor, `lower=False`). The solution is reshaped
  to `x_non_channel_shape + <sizes of b_axes>` and contracted with
  `k_test_train` over `(odd, first)`.
  """
  axes = utils.canonicalize_axis(b_axes, b)
  channel_dims = tuple(b.shape[a] for a in axes)
  rhs_2d = onp.moveaxis(b, axes, range(-len(axes), 0)).reshape((C.shape[1], -1))
  solution = _jsp_cho_solve(C, rhs_2d, False).reshape(x_non_channel_shape + channel_dims)
  return np.tensordot(k_test_train, solution, (odd, first))

def _get_axes(x: np.ndarray):
  """Split the axes of `x` into four index tuples.

  Returns `(even_axes, odd_axes, first_half, second_half)`:
    even_axes:   0, 2, 4, ... below `x.ndim`
    odd_axes:    1, 3, 5, ... below `x.ndim`
    first_half:  0 .. x.ndim // 2 - 1
    second_half: x.ndim // 2 .. x.ndim - 1
  """
  ndim = x.ndim
  half = ndim // 2
  even_axes = tuple(range(0, ndim, 2))
  odd_axes = tuple(range(1, ndim, 2))
  first_half = tuple(range(half))
  second_half = tuple(range(half, ndim))
  return even_axes, odd_axes, first_half, second_half

