
import jax, optax, itertools, collections, flax


import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax.training import train_state
from flax.jax_utils import prefetch_to_device
from jax.profiler import start_trace, stop_trace
from collections import namedtuple


import os
from transformers import FlaxViTForImageClassification

# ## Load data to CPU
# Read the CIFAR-100 arrays from disk and place them on the first CPU device.
train_images = jnp.array(
    np.load("./data/cifar100/train_images.npy"),
    device=jax.devices("cpu")[0],
)
train_labels = jnp.array(
    np.load("./data/cifar100/train_labels.npy"),
    device=jax.devices("cpu")[0],
)

# Side length (in pixels) every image is resized to.
dimension = 224


def resizer(x):
    """Bilinearly resize a single image to ``(3, dimension, dimension)``."""
    return jax.image.resize(x, shape=(3, dimension, dimension), method="bilinear")


# Only the first 10000 images are resized and kept.
train_images = jax.vmap(resizer)(train_images[:10000])

# Bare expression: displays the shape in a notebook/REPL, no effect in a script.
train_images.shape


## define some jax utility functions
@jax.jit
def add_trees(x, y):
    """Return the leaf-wise sum of two pytrees that share the same structure."""

    def _sum_leaves(left, right):
        return left + right

    return jax.tree_util.tree_map(_sum_leaves, x, y)

# ## Find the max logical batch-size


# Sampling rate, dataset size and physical (per-step) batch size used below.
q = 0.1
full_data_size = train_images.shape[0]
physical_bs = 64

alpha = 1e-9 # failure prob.

from scipy.stats import binom

binom_dist = binom(full_data_size, q)

# Find the smallest k for which the probability that a
# Binomial(full_data_size, q) draw exceeds k * physical_bs drops below alpha.
for k in itertools.count(1):
    right_prob = binom_dist.sf(k * physical_bs)
    if right_prob < alpha:
        break

# Logical batches are padded to this multiple of the physical batch size.
max_logical_batch_size = k * physical_bs

 
# ## Main functions for DP-SGD


@jax.jit
def compute_per_example_gradients(state, batch_X, batch_y):
    """Compute one cross-entropy gradient per example of a batch.

    Args:
        state: Object exposing ``apply_fn`` and ``params``. ``apply_fn`` is
            called as ``apply_fn(X, params=params)`` and the first element of
            its result is used as the logits.
        batch_X: Inputs; the leading axis is mapped over (one example each).
        batch_y: Integer class labels; the leading axis is mapped over.

    Returns:
        The gradients with respect to ``state.params``, with an extra leading
        per-example axis on every leaf. No loss or accuracy is returned.
    """

    def loss_fn(params, X, y):
        logits = state.apply_fn(X, params=params)[0]
        # The one-hot width must match the logits, so read the number of
        # classes from the model output instead of hard-coding it.
        one_hot = jax.nn.one_hot(y, logits.shape[-1])
        loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).flatten()
        # Trace-time sanity check: each mapped call must see exactly one example.
        assert len(loss) == 1
        return loss.sum()

    def grad_fn(X, y):
        # Differentiate with respect to the parameters (first argument) only.
        return jax.grad(loss_fn)(state.params, X, y)

    px_grads = jax.vmap(grad_fn, in_axes=(0, 0))(batch_X, batch_y)

    return px_grads

@jax.jit
def process_a_physical_batch(px_grads, mask, C):
    """Clip per-example gradients to norm ``C``, apply ``mask`` and sum them.

    Args:
        px_grads: Pytree whose leaves carry a leading per-example axis.
        mask: One weight per example, multiplied into that example's gradient.
        C: Clipping threshold for the per-example gradient norm.

    Returns:
        Pytree with the structure of ``px_grads`` and the leading axis summed
        away.
    """
    # Squared L2 norm of each example's slice, one vector per leaf.
    sq_norms_by_leaf = [
        jnp.linalg.norm(leaf.reshape(leaf.shape[0], -1), axis=-1) ** 2
        for leaf in jax.tree_util.tree_leaves(px_grads)
    ]

    # Per-example norm over all leaves combined, and the resulting scale factor.
    px_grad_norms = jnp.sqrt(jnp.sum(jnp.array(sq_norms_by_leaf), axis=0))
    clipping_multiplier = jnp.minimum(1., C / px_grad_norms)

    def _scale_and_reduce(leaf):
        # Reshape the per-example vectors so they broadcast over the leaf.
        broadcast_shape = (-1,) + (1,) * (leaf.ndim - 1)
        keep = mask.reshape(broadcast_shape)
        scale = clipping_multiplier.reshape(broadcast_shape)
        return jnp.sum(leaf * keep * scale, axis=0)

    return jax.tree.map(_scale_and_reduce, px_grads)

@jax.jit
def noise_addition(rng_key, accumulated_clipped_grads, noise_std, C):
    """Add Gaussian noise scaled by ``noise_std * C`` to every gradient leaf.

    Args:
        rng_key: PRNG key that is split into one key per leaf.
        accumulated_clipped_grads: Pytree of summed, clipped gradients.
        noise_std: Noise multiplier.
        C: Clipping threshold; scales the noise together with ``noise_std``.

    Returns:
        Pytree with the same structure holding ``grads + noise``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(accumulated_clipped_grads)

    # One key per leaf plus a leading spare key that is not used here.
    _spare_key, *leaf_keys = jax.random.split(rng_key, num=len(leaves) + 1)
    key_tree = jax.tree_util.tree_unflatten(treedef, leaf_keys)

    def _draw(g, k):
        return noise_std * C * jax.random.normal(k, shape=g.shape, dtype=g.dtype)

    noise = jax.tree_util.tree_map(_draw, accumulated_clipped_grads, key_tree)
    return add_trees(accumulated_clipped_grads, noise)


 
# ## Define a data loader with prefetch


def prepare_data(xs):
    """Reshape every leaf of ``xs`` to ``(local_device_count, -1, *shape[1:])``."""
    n_local = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda leaf: leaf.reshape((n_local, -1) + leaf.shape[1:]),
        xs,
    )

def prefetch_to_device(iterator, size, device=None):
    """Yield items of ``iterator`` after moving them to a device, buffered ahead.

    NOTE: this definition shadows the ``prefetch_to_device`` imported from
    ``flax.jax_utils`` at the top of the file.

    Args:
        iterator: Iterable of pytrees. Any iterable is accepted; it is wrapped
            with ``iter`` so that lists are consumed once rather than
            restarted on every refill.
        size: Number of items to keep enqueued ahead of the consumer.
        device: Target device. Defaults to the first GPU
            (``jax.devices("gpu")[0]``), matching the previous behaviour.

    Yields:
        The items of ``iterator`` in order, with every leaf placed on the
        target device.
    """
    # islice on a re-iterable (e.g. a list) would restart from the first
    # element on each call and never terminate, so pin down a single iterator.
    iterator = iter(iterator)
    queue = collections.deque()

    def _prefetch(xs):
        # Resolve the default lazily so nothing touches the GPU backend
        # unless there is actually data to transfer.
        target = jax.devices("gpu")[0] if device is None else device
        return jax.device_put(xs, target)

    def enqueue(n):  # Enqueues *up to* `n` elements from the iterator.
        for data in itertools.islice(iterator, n):
            queue.append(jax.tree_util.tree_map(_prefetch, data))

    enqueue(size)  # Fill up the buffer.
    while queue:
        yield queue.popleft()
        enqueue(1)  # Top the buffer back up after handing one item out.

 
# ### Parameters for training


# Hyper-parameters for the optimizer. The record type is instantiated (rather
# than assigning attributes on the namedtuple class itself), so ``config`` is a
# real, immutable tuple; use ``config._replace(...)`` to derive a variant.
Config = namedtuple("Config", ["momentum", "learning_rate"])
config = Config(momentum=1, learning_rate=1e-3)


def create_train_state(model_name, num_labels, config):
    """Build the initial ``TrainState`` from a pretrained Flax ViT classifier.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        num_labels: Number of output classes requested from the model.
        config: Object providing ``learning_rate`` and ``momentum`` for SGD.

    Returns:
        A ``TrainState`` holding the model's call function, its pretrained
        parameters and an SGD optimizer.
    """
    pretrained = FlaxViTForImageClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        return_dict=False,
        ignore_mismatched_sizes=True,
    )

    # Plain SGD configured from the supplied hyper-parameters.
    optimizer = optax.sgd(config.learning_rate, config.momentum)

    return train_state.TrainState.create(
        apply_fn=pretrained.__call__,
        params=pretrained.params,
        tx=optimizer,
    )

@jax.jit
def update_model(state, grads):
    """Apply ``grads`` through the state's optimizer and return the new state."""
    updated_state = state.apply_gradients(grads=grads)
    return updated_state
 
# ## DP-SGD parameters


# Noise multiplier (0.0 makes the added noise zero) and clipping threshold.
noise_std = 0.0
C = 1.0
num_iter = 2

# # fori_loop

# Model/input description for the ViT classifier.
num_classes = 100
dimension = 224
input_shape = (1, 3, dimension, dimension) # vit

model_name = "google/vit-base-patch16-224"

# Initial training state built from the pretrained checkpoint.
state = create_train_state(model_name=model_name, num_labels=num_classes, config=config)


import time
jax.clear_caches()

num_iter = 10
# Selects which of the two physical-batch slicing strategies body_fun uses.
dynamic_slice = True


def _accumulate_physical_batch(state, accumulated_clipped_grads, pb, yb, mask):
    """Add the clipped, masked gradient sum of one physical batch to the accumulator."""
    per_example_gradients = compute_per_example_gradients(state, pb, yb)
    sum_of_clipped_grads_from_pb = process_a_physical_batch(per_example_gradients, mask, C)
    return add_trees(accumulated_clipped_grads, sum_of_clipped_grads_from_pb)


if dynamic_slice:
    @jax.jit
    def body_fun(t, args):
        """Loop body: process the ``t``-th physical batch via ``dynamic_slice``.

        ``args`` is the loop carry ``(state, accumulated_clipped_grads,
        logical_batch_X, logical_batch_y, masks)``; only the accumulator
        changes between iterations.
        """
        state, accumulated_clipped_grads, logical_batch_X, logical_batch_y, masks = args
        # slice
        start_idx = t * physical_bs
        # Take the per-example shape from the input rather than hard-coding
        # the image size, so it always agrees with the data that is passed in.
        example_shape = logical_batch_X.shape[1:]
        pb = jax.lax.dynamic_slice(
            logical_batch_X,
            (start_idx,) + (0,) * len(example_shape),
            (physical_bs,) + example_shape,
        )
        yb = jax.lax.dynamic_slice(logical_batch_y, (start_idx,), (physical_bs,))
        mask = jax.lax.dynamic_slice(masks, (start_idx,), (physical_bs,))

        # compute grads and clip
        accumulated_clipped_grads = _accumulate_physical_batch(
            state, accumulated_clipped_grads, pb, yb, mask
        )

        return state, accumulated_clipped_grads, logical_batch_X, logical_batch_y, masks

else:
    def body_fun(t, args):
        """Loop body: process the ``t``-th physical batch by indexing chunks.

        ``args`` is the loop carry ``(state, accumulated_clipped_grads,
        logical_batch_X, logical_batch_y, masks)``; only the accumulator
        changes between iterations.
        """
        state, accumulated_clipped_grads, logical_batch_X, logical_batch_y, masks = args

        # slice
        # Chunk the loop arguments themselves instead of reading the
        # module-level ``*_split`` arrays: values read from globals inside a
        # traced loop body are captured at trace time, so the function would
        # not be a pure function of its inputs. Reshaping to
        # (num_chunks, physical_bs, ...) gives the same chunks as an
        # equal-sized split along axis 0, and works whether ``masks`` arrives
        # flat or already split.
        pb = logical_batch_X.reshape((-1, physical_bs) + logical_batch_X.shape[1:])[t]
        yb = logical_batch_y.reshape(-1, physical_bs)[t]
        mask = masks.reshape(-1, physical_bs)[t]

        # compute grads and clip
        accumulated_clipped_grads = _accumulate_physical_batch(
            state, accumulated_clipped_grads, pb, yb, mask
        )

        return state, accumulated_clipped_grads, logical_batch_X, logical_batch_y, masks

# One DP-SGD step per iteration: sample a logical batch, accumulate clipped
# per-example gradients over k physical batches, add noise, update the model.
for t in range(num_iter):
    sampling_rng = jax.random.PRNGKey(t + 1)
    batch_rng, binomial_rng, noise_rng = jax.random.split(sampling_rng, 3)

    # Draw max_logical_batch_size distinct examples in random order.
    indices = jax.random.permutation(batch_rng, full_data_size)[:max_logical_batch_size]
    logical_batch_X = train_images[indices].reshape(-1, 1, 3, dimension, dimension)
    logical_batch_y = train_labels[indices]

    # poisson subsample
    logical_batch_size = len(logical_batch_X)
    sampled_batch_size = int(
        jax.random.bernoulli(binomial_rng, shape=(full_data_size,), p=q).sum()
    )
    # The padded logical batch was sized so that the sampled size exceeds it
    # only with probability below alpha; clamp so that rare case cannot
    # produce a negative mask length.
    actual_batch_size = min(sampled_batch_size, logical_batch_size)
    n_masked_elements = logical_batch_size - actual_batch_size

    # masks: 1 for the first actual_batch_size examples, 0 for the padding.
    masks = jnp.concatenate([jnp.ones(actual_batch_size), jnp.zeros(n_masked_elements)])

    # cast to GPU (the masks go there directly, without a detour via the CPU)
    gpu_device = jax.devices("gpu")[0]
    logical_batch_X = logical_batch_X.to_device(gpu_device)
    logical_batch_y = logical_batch_y.to_device(gpu_device)
    masks = masks.to_device(gpu_device)

    if not dynamic_slice:
        # Pre-split everything into k physical batches for the indexing variant.
        masks = jnp.array(jnp.split(masks, k))
        logical_batch_X_split = jnp.array(jnp.split(logical_batch_X, k))
        logical_batch_y_split = jnp.array(jnp.split(logical_batch_y, k))


    ### gradient accumulation
    params = state.params

    # Exact zeros with the shape/dtype of each parameter; unlike `0. * x`
    # this cannot carry over NaN/inf from the parameters.
    accumulated_clipped_grads0 = jax.tree.map(jnp.zeros_like, params)

    # perf_counter is monotonic and meant for measuring elapsed time.
    start = time.perf_counter()

    _, accumulated_clipped_grads, *_ = jax.lax.fori_loop(
        0,
        k,
        body_fun,
        (state, accumulated_clipped_grads0, logical_batch_X, logical_batch_y, masks),
    )
    noisy_grad = noise_addition(noise_rng, accumulated_clipped_grads, noise_std, C)

    #update
    # Block until the update has finished so the timing covers the real work.
    state = jax.block_until_ready(update_model(state, noisy_grad))
    end = time.perf_counter()
    print(end - start)
