# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import neural_tangents as nt
import jax
import jax.numpy as jnp
from jax import random 
from jax import vmap,pmap#tree_util, lax
import numpy as nnp
import time
import math
from typing import Sequence

import flax.linen as nn

from samplings.block_predict_utils import prepare_gradient_descent_mse_staged, block_gradient_descent_mse_staged

import os 
# XLA client memory settings.
# - Do not preallocate GPU memory up front.
# - Allocate on demand with the platform allocator.
# - Memory fraction .80.
# NOTE(review): JAX reads these when its backend is first initialised, which
# happens lazily on first device use. They therefore only take effect if no
# JAX computation has run before this point, even though `jax` is already
# imported above. Per the JAX docs, MEM_FRACTION governs preallocation, so it
# likely has no effect while PREALLOCATE is "false". Confirm.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": '.80',
})

# Learning rate handed to prepare_gradient_descent_mse_staged in
# select_new_sample (an earlier value was 1e-2).
learning_rate = 1
# Not referenced anywhere in the visible code.
momentum = 0.9


class MLP(nn.Module):
  """Fully connected network.

  Every entry of ``features`` except the last becomes an ``nn.Dense`` layer
  followed by a ReLU. The last entry is the width of a final, non-activated
  ``nn.Dense`` layer.
  """
  # Output width of each Dense layer, in order; the last one is the output layer.
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    hidden_widths = self.features[:-1]
    h = x
    for width in hidden_widths:
      h = nn.Dense(width)(h)
      h = nn.relu(h)
    # Output layer: no activation.
    out_width = self.features[-1]
    return nn.Dense(out_width)(h)

def ntk_ker_f0(alfeas, num_indim = 512, num_hidden = 512, num_classes = 10, scaleidx = None, batch_ker = True,
               row_batch_size = 2000):
    """Compute the empirical NTK of a freshly initialised MLP on ``alfeas``.

    The network is ``MLP([num_hidden, num_classes])``, initialised with
    ``random.PRNGKey(0)``. Before taking the kernel, its outputs are summed
    over the last axis and divided by ``sqrt(num_classes)``. The kernel is
    then taken with ``trace_axes=()`` on that scalar-per-example function, so
    it has one entry per pair of inputs.

    Args:
        alfeas: input features, one example per row. Presumably of shape
            ``(N, num_indim)``; TODO confirm with callers.
        num_indim: input width, used only to build the dummy input for ``model.init``.
        num_hidden: width of the hidden Dense layer.
        num_classes: width of the output layer.
        scaleidx: optional row selector. When given, only ``alfeas[scaleidx, :]`` is used.
        batch_ker: if True, wrap the kernel function with ``nt.batch``
            (``batch_size=200``, ``device_count=-1``).
        row_batch_size: number of kernel rows computed per outer iteration.

    Returns:
        ``(tntk, fx_all_0)``:
        - ``tntk`` is a NumPy array holding the kernel between the (possibly
          subset) inputs and themselves, built from row blocks of at most
          ``row_batch_size`` rows. It is ``[]`` when there are no inputs.
        - ``fx_all_0`` is the network output at initialisation for the same inputs.
    """
    model = MLP([num_hidden, num_classes])
    rng = dict(params=random.PRNGKey(0))
    params = model.init(rng, nnp.ones((1, num_indim)))

    def apply_fn(params, x):
        # With `mutable=...`, Flax's apply returns (output, updated_state);
        # keep only the output.
        return model.apply(params, x, mutable=['batch_stats'])[0]

    def apply_fn_trace(params, x):
        # Collapse the class axis into one scalar per example.
        out = apply_fn(params, x)
        return jnp.sum(out, axis=-1) / out.shape[-1] ** 0.5

    if scaleidx is not None:
        alfeas = alfeas[scaleidx, :]

    fx_all_0 = apply_fn(params, alfeas)

    kernel_fn = nt.empirical_kernel_fn(
        apply_fn_trace, trace_axes=(), vmap_axes=0,
        implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)
    if batch_ker:
        kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=200)

    # Collect the row blocks and stack them once at the end. Stacking inside
    # the loop would recopy the whole accumulated kernel on every iteration.
    row_blocks = []
    num_batch = math.ceil(len(alfeas) / row_batch_size)
    s = time.time()
    for ib in range(num_batch):
        rows = alfeas[ib * row_batch_size:(ib + 1) * row_batch_size, :]
        row_blocks.append(nnp.array(kernel_fn(rows, alfeas, 'ntk', params)))
        print(ib)
    print(time.time() - s)

    tntk = nnp.vstack(row_blocks) if row_blocks else []
    return tntk, fx_all_0


### Timing notes (vmap on CPU): ~22 s/sample with stepsize=200, ~20 s/sample with stepsize=250, ~25 s/sample with stepsize=500.
def compute_batch_permutation_predictions(x_non_channel_shape, A_shape_1, C, rhs, orig_preds, odd, first,
                                          k_test_test, k_test_train, fx_test_0, new_y_train, permutation):
    """Predict pool outputs after adding the rows at ``permutation`` as extra training points.

    The kernel entries of the selected rows are sliced from ``k_test_test``
    (pool x pool) and ``k_test_train`` (pool x train). Together with their
    initial outputs ``fx_test_0[permutation]`` and the labels ``new_y_train``,
    they are passed to ``block_gradient_descent_mse_staged`` alongside the
    precomputed staged quantities ``x_non_channel_shape, A_shape_1, C, rhs,
    orig_preds, odd, first``.

    NOTE(review): ``k_test_test[permutation, permutation]`` is paired
    (diagonal) advanced indexing. It yields shape ``(len(permutation),)``,
    not an a x a sub-block, so the two coincide only when
    ``len(permutation) == 1``. That is the case in select_new_sample, where
    each vmapped ``permutation`` has shape ``(1,)``. Confirm the shape
    ``block_gradient_descent_mse_staged`` expects before passing larger blocks.

    Returns:
        The result of ``block_gradient_descent_mse_staged``; presumably the
        pool predictions (n_test x k). TODO confirm.
    """
    # Slices for the new points; these are independent pure indexing ops.
    new_point_kernel = k_test_test[permutation, permutation]
    new_rows_vs_train = k_test_train[permutation].T
    pool_vs_new = k_test_test[:, permutation]
    new_fx_0 = fx_test_0[permutation]
    return block_gradient_descent_mse_staged(
        k_test_train, x_non_channel_shape, A_shape_1, C, rhs,
        orig_preds, odd, first, pool_vs_new,
        new_rows_vs_train, new_point_kernel, new_y_train, new_fx_0,
    )



# Batch over candidates. vmap maps the last two arguments (new_y_train,
# permutation) along axis 0; the first ten arguments are shared unchanged.
v_compute_batch_permutation_predictions = vmap(
    compute_batch_permutation_predictions,
    in_axes=(None,) * 10 + (0, 0),
)

# Batch over devices. pmap maps the same two arguments over their leading
# (device) axis; gpu_split produces that layout.
# Arguments 0, 1, 5 and 6 (x_non_channel_shape, A_shape_1, odd, first) are
# static: they must be hashable, and new values trigger recompilation.
# Single-device alternative: jax.jit(v_compute_batch_permutation_predictions).
jitcompute = pmap(
    v_compute_batch_permutation_predictions,
    in_axes=(None,) * 10 + (0, 0),
    static_broadcasted_argnums=(0, 1, 5, 6),
)

def gpu_split(arr, over=0, use_gpu=True, num_devices=None):
    """Split axis ``over`` of ``arr`` evenly into a new leading device axis.

    The result has shape
    ``(num_devices, *arr.shape[:over], arr.shape[over] // num_devices, *arr.shape[over + 1:])``.
    ``result[d]`` holds the ``d``-th contiguous chunk of ``arr`` along axis
    ``over``, which is the layout ``pmap`` expects for a mapped argument.

    Args:
        arr: NumPy or JAX array (needs ``.shape``, ``.ndim``, ``.reshape`` and ``.transpose``).
        over: axis to split. Negative values count from the end.
        use_gpu: if True count GPU devices, otherwise CPU devices. Ignored
            when ``num_devices`` is given.
        num_devices: explicit number of chunks. Defaults to
            ``jax.device_count('gpu' if use_gpu else 'cpu')``.

    Returns:
        The split array, or ``arr`` itself (unchanged) when there are more
        devices than entries along ``over``.

    Raises:
        ValueError: if ``arr.shape[over]`` is not divisible by the device count.
    """
    if num_devices is None:
        num_devices = jax.device_count('gpu' if use_gpu else 'cpu')
    if over < 0:
        over += arr.ndim
    size = arr.shape[over]
    if num_devices > size:
        return arr
    if size % num_devices:
        raise ValueError(
            f'Cannot split axis {over} of size {size} evenly across {num_devices} devices.')
    # Split axis `over` in place into (num_devices, chunk).
    split = arr.reshape(*arr.shape[:over], num_devices, size // num_devices, *arr.shape[over + 1:])
    if over == 0:
        return split
    # Move the device axis (now at position `over`) to the front. Reshaping
    # directly to the target shape would instead regroup memory along axis 0,
    # handing each device the wrong data.
    order = (over, *range(over), *range(over + 1, split.ndim))
    return split.transpose(order)

data_func =gpu_split

def _one_hot(labels, num_class):
    """Return a float64 NumPy array with one row per label and a 1 in that label's column."""
    labels = nnp.asarray(labels).reshape(-1)
    onehot = nnp.zeros((len(labels), num_class))
    onehot[nnp.arange(len(labels)), labels] = 1.
    return onehot


def _select_first_sample(totntk, totf0, preds, num_class):
    """Pick the first index when nothing has been selected yet.

    Shuffles all indices, then takes the first 1000 as candidates and the last
    7000 as the pool. For each candidate and each ``ic``, it solves
    ``nt.predict.gradient_descent_mse`` with that single candidate as the
    training set (``t=None``, labelled one-hot with ``preds[ic][i]``). It then
    measures the norm of the change in pool outputs relative to
    ``totf0[ic]``. The candidate with the largest summed change is returned.
    """
    ridx = list(range(len(totf0[0])))
    nnp.random.shuffle(ridx)
    candidx = ridx[:1000]
    uidx = ridx[-7000:]
    uidx_j = jnp.array(uidx)
    # Pool-side arrays depend only on ic, so build them once, not per candidate.
    pool_f0_np = [nnp.array(totf0[ic][uidx, :]) for ic in range(len(preds))]
    pool_f0 = [totf0[ic][uidx_j] for ic in range(len(preds))]

    totgain = []
    s = time.time()
    for i in candidx:
        tgain = 0
        for ic in range(len(preds)):
            pslbl = _one_hot([preds[ic][i]], num_class)
            predict_fn = nt.predict.gradient_descent_mse(nnp.array([[totntk[ic][i, i]]]), pslbl)
            _, fx_test_t = predict_fn(None, nnp.array([totf0[ic][i, :]]), pool_f0_np[ic],
                                      nnp.array([totntk[ic][uidx, i]]).T)
            tgain += jnp.linalg.norm(fx_test_t - pool_f0[ic])
        totgain.append(tgain)
        if len(totgain) % 500 == 1:
            print(len(totgain), time.time() - s)
    return candidx[nnp.argmax(totgain)]


def select_new_sample(alidx, totntk, totf0, preds, num_budget, num_class, outpath, stepsize = 250):
    """Greedily extend the selected index list ``alidx`` by ``num_budget`` indices.

    If ``alidx`` is empty, the first index comes from ``_select_first_sample``
    and counts against the budget. Each later round does the following:

    1. Shuffle the not-yet-selected indices. Take up to 10000 as the pool
       ``uidx`` and the first 1000 of those as candidates.
    2. For every ``ic``:
       a. Stage a gradient-descent-MSE prediction of the pool from the
          current selection, using ``prepare_gradient_descent_mse_staged``
          with one-hot labels from ``preds[ic]``.
       b. Score each candidate by the norm of the change in pool predictions
          (vs. the staged ``orig_preds``) when that candidate is added with
          its ``preds[ic]`` label. Candidates are evaluated ``stepsize`` at a
          time through ``jitcompute``.
    3. Sum the gains over ``ic``, zero out NaN gains, and append the argmax candidate.

    Args:
        alidx: already-selected indices (list or array). It is not modified;
            a new list is returned.
        totntk: per-``ic`` square kernel matrices indexed by sample index.
        totf0: per-``ic`` initial outputs; ``totf0[ic][j, :]`` is sample ``j``.
            Presumably with ``num_class`` columns; TODO confirm.
        preds: per-``ic`` integer labels used as pseudo-labels.
        num_budget: number of indices to add.
        num_class: number of one-hot label columns.
        outpath: unused.
        stepsize: candidates scored per ``jitcompute`` call. With several
            devices, each chunk size must be divisible by the device count
            (see ``data_func``).

    Returns:
        List of selected indices: the input ones followed by the new ones.
        Fewer than ``num_budget`` are added if the unselected indices run out.
    """
    # Copy, so the caller's list is not mutated. This also avoids `+=` on a
    # NumPy `alidx`, which would broadcast-add instead of appending.
    alidx = list(alidx)
    if len(alidx) == 0:
        alidx.append(_select_first_sample(totntk, totf0, preds, num_class))
        num_budget = num_budget - 1

    s = time.time()
    for tal in range(num_budget):
        selected = set(alidx)  # O(1) membership tests for the filter below
        ridx = [ttt for ttt in range(len(totf0[0])) if ttt not in selected]
        nnp.random.shuffle(ridx)
        uidx = ridx[:10000]
        candidx = uidx[:1000]  # candidate c sits at row c of the pool kernels
        if not candidx:
            print('No unselected samples left')
            break
        totgain = nnp.zeros(len(candidx))

        for ic in range(len(preds)):
            labels = nnp.asarray(preds[ic])
            tntkll = totntk[ic][nnp.ix_(alidx, alidx)]
            tntkul = totntk[ic][nnp.ix_(uidx, alidx)]
            tntkuu = totntk[ic][nnp.ix_(uidx, uidx)]
            uf0 = totf0[ic][uidx, :]
            pslbl = _one_hot(labels[alidx], num_class)

            staged_x_non_channel_shape, staged_C, staged_rhs, staged_orig_preds, staged_odd, staged_first = \
                prepare_gradient_descent_mse_staged(tntkll, tntkul, pslbl, totf0[ic][alidx, :], totf0[ic][uidx, :],
                                                    learning_rate, trace_axes=(-1,))
            staged_A = tntkll.shape[1] * num_class
            base_preds = jnp.asarray(staged_orig_preds)
            # Loop-invariant device copies of the pool kernels.
            k_uu = jnp.array(tntkuu)
            k_ul = jnp.array(tntkul)

            for start in range(0, len(candidx), stepsize):
                stop = min(start + stepsize, len(candidx))
                # Pool-row positions of this chunk's candidates, one per vmapped call.
                perms = jnp.arange(start, stop)[:, None]
                perm_y_trains = _one_hot(labels[candidx[start:stop]], num_class)

                batch_fx_test_ts = jitcompute(staged_x_non_channel_shape, staged_A, staged_C, staged_rhs,
                                              staged_orig_preds, staged_odd, staged_first, k_uu, k_ul, uf0,
                                              data_func(perm_y_trains), data_func(perms))
                # pmap output is (devices, per_device, ...); flatten so every
                # device's results are kept, not only device 0's.
                batch_fx_test_ts = batch_fx_test_ts.reshape(-1, *batch_fx_test_ts.shape[2:])
                gains = jnp.linalg.norm(batch_fx_test_ts - base_preds, axis=(1, 2))
                totgain[start:start + len(gains)] += nnp.asarray(gains)

        print(tal, time.time() - s, len(alidx))

        nan_mask = nnp.isnan(totgain)
        if nan_mask.any():
            print('Nan')
        totgain[nan_mask] = 0
        alidx.append(candidx[int(nnp.argmax(totgain))])

    return alidx
