import torch as ch
import os
import torchvision
import numpy as np
import jax
import jax.numpy as jnp
from tqdm import tqdm
from .jittools import maybe_jit, set_jit
from math import ceil
from augmax.base import PyTree as AugPyTree

# Turn JIT on for this module.
# NOTE(review): set_jit / maybe_jit come from .jittools (not visible here);
# presumably this makes @maybe_jit-decorated functions such as
# replace_with_poison get compiled with jax.jit — confirm in jittools.
set_jit(True)

# Training-time augmentations (see AUGMAX_AUGS and augment_X):
#   * random shift of up to 2px (reflect-pad by 2px, then random 32x32 crop)
#   * random left-right flip

# Per-channel CIFAR-10 normalization statistics, scaled to the 0-255 range
# so they can be applied directly to raw pixel values (see ds_to_jax).
CIFAR_MEAN = 255 * np.array([0.4914, 0.4822, 0.4465])
CIFAR_STD = 255 * np.array([0.2023, 0.1994, 0.2010])
import augmax

class CutOut(augmax.imagelevel.ImageLevelTransformation):
    """Zero out one randomly placed rectangular patch of an image.

    With probability ``p`` a patch of ``cutout_size`` = (rows, cols) pixels is
    overwritten with zeros; otherwise the image is returned unchanged. The
    patch position is sampled uniformly so that it always lies fully inside
    the image.
    """

    def __init__(self, cutout_size=(8, 8), p=0.5, input_types=None):
        # None instead of a list literal avoids a shared mutable default.
        if input_types is None:
            input_types = [augmax.InputType.IMAGE]
        assert input_types == [augmax.InputType.IMAGE]
        super().__init__(input_types)
        self.cutout_size = cutout_size
        self.probability = p

    def apply(self, rng, inputs: AugPyTree, input_types: AugPyTree, invert=False):
        """Apply the cutout to every image leaf of ``inputs``.

        Args:
            rng: PRNG key; split into one key for the apply/skip decision and
                one for the patch position.
            inputs: pytree of images, each assumed (H, W, C) — the shape is
                unpacked as such below.
            input_types: must be ``augmax.InputType.IMAGE``.
            invert: unused.

        Returns:
            Pytree of the same structure with the patch zeroed (or unchanged
            images when the Bernoulli draw says not to apply).
        """
        assert input_types == augmax.InputType.IMAGE

        key1, key2 = jax.random.split(rng)
        do_apply = jax.random.bernoulli(key1, self.probability)

        def transform_single(raw_image):
            H, W, C = raw_image.shape
            patch_h, patch_w = self.cutout_size
            # Top-left corner in [0, H - patch_h] x [0, W - patch_w] (randint's
            # maxval is exclusive, hence the +1) so the patch stays in bounds.
            # The previous call passed H - patch_h as *minval*, which pinned
            # the patch to the bottom-right corner for square images.
            maxvals = jnp.array([H - patch_h + 1, W - patch_w + 1])
            x, y = jax.random.randint(key2, (2,), 0, maxvals)
            cutout = jnp.zeros((patch_h, patch_w, C), dtype=raw_image.dtype)
            image = jax.lax.dynamic_update_slice(raw_image, cutout, (x, y, 0))
            return jnp.where(do_apply, image, raw_image)

        # jax.tree_map is deprecated (removed in newer JAX); use tree_util.
        return jax.tree_util.tree_map(transform_single, inputs)

# Per-image training augmentation pipeline: random horizontal flip
# (probability 0.5) followed by a random 32x32 crop. augment_X reflect-pads
# images by 2px before applying this, so the crop acts as a random shift.
AUGMAX_AUGS = augmax.Chain(
    augmax.HorizontalFlip(0.5),
    augmax.RandomCrop(height=32, width=32),
)

def ds_to_jax(ds):
    """Convert a torchvision-style dataset into normalized on-device JAX arrays.

    Args:
        ds: object exposing ``.data`` (image array, raw 0-255 pixels) and
            ``.targets`` (integer labels). Images are assumed channels-last so
            that the per-channel CIFAR_MEAN / CIFAR_STD broadcast over the
            last axis — TODO confirm for non-CIFAR datasets.

    Returns:
        ``(idxs, (X, Y))`` where ``idxs`` is a NumPy ``arange`` over the
        examples, ``X`` the normalized images and ``Y`` the int32 labels, both
        placed on the first GPU, or on the default device when no GPU backend
        is available.
    """
    X = np.array(ds.data).astype(np.float32)  # np ndarray uint8 -> float
    X = (X - CIFAR_MEAN) / CIFAR_STD

    idxs = np.arange(X.shape[0])
    Y = np.array(ds.targets).astype(np.int32)  # list of ints

    # Prefer the GPU, but do not crash on CPU-only hosts: jax.devices('gpu')
    # raises RuntimeError when the GPU backend is missing/uninitializable.
    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        device = jax.devices()[0]
    X = jax.device_put(jnp.array(X), device)
    Y = jax.device_put(jnp.array(Y), device)
    return idxs, (X, Y)

def augment_X(X, inds, seed, num_keys=50_000):
    """Augment a batch of images with per-example PRNG keys.

    Args:
        X: batch of images; axes 1 and 2 are padded by 2px (reflect), then
            AUGMAX_AUGS (flip + random 32x32 crop) is vmapped over axis 0.
        inds: per-example integer ids used to pick each example's key.
        seed: base seed; each example's key is determined by
            ``(seed, inds[j])`` only, not by its position in the batch.
        num_keys: size of the key table split from ``seed``. Every id in
            ``inds`` must be < ``num_keys``; JAX gather clamps out-of-range
            indices, so larger ids would silently reuse the last key. The
            default 50_000 is presumably the CIFAR-10 training-set size.

    Returns:
        The augmented batch.
    """
    rng = jax.random.PRNGKey(seed)
    # Split a fixed-size table and index it (instead of splitting X.shape[0]
    # keys) so an example's augmentation does not depend on batch layout.
    sub_rngs = jax.random.split(rng, num_keys)[inds]
    # pad to 36x36 for 32x32 inputs; the random crop then acts as a shift
    X = jnp.pad(X, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='reflect')
    vmapped_transform = jax.vmap(AUGMAX_AUGS)
    return vmapped_transform(sub_rngs, X)

# def augment_X(X, inds, seed):
    # return X

def apply_nchw(X):
    """Return the batch ``X`` with axes reordered from NHWC to NCHW (via jnp)."""
    nchw_order = (0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    return jnp.transpose(X, axes=nchw_order)

# Per-pixel bounds of the normalized image space: raw pixel values 255 and 0
# mapped through (x - CIFAR_MEAN) / CIFAR_STD, broadcast to a 32x32x3 image.
IMAGE_UPPER_BD = (jnp.full((32, 32, 3), 255.0) - CIFAR_MEAN) / CIFAR_STD
IMAGE_LOWER_BD = (jnp.zeros((32, 32, 3)) - CIFAR_MEAN) / CIFAR_STD

@maybe_jit
def replace_with_poison(data, poison):
    """Overwrite the leading rows of ``data`` with ``poison``.

    Returns ``data`` unchanged when ``poison`` is None; otherwise a new array
    (functional ``.at[].set`` update) whose first ``poison.shape[0]`` rows are
    ``poison``.
    """
    if poison is not None:
        n_rows = poison.shape[0]
        data = data.at[:n_rows].set(poison)
    return data

def make_make_canary_batch_functions(test_is_subset=False, test_set_size=25_000, single_canary=False, datapath=None):
    """Load CIFAR-10 once and return a factory for canary-aware batch functions.

    Both CIFAR-10 splits are loaded with torchvision (download=True),
    normalized and placed on device with ``ds_to_jax``. Labels become one-hot
    vectors scaled by 10. The training set is shuffled with a fixed seed (1).
    When ``test_is_subset`` is True, the last ``test_set_size`` shuffled
    training examples are removed from training and prepended to the test
    split. The resulting test set is then shuffled with seed 1.

    Args:
        test_is_subset: move the tail of the shuffled training set into the
            test set.
        test_set_size: number of training examples moved when
            ``test_is_subset`` is True.
        single_canary: if True, every canary is inserted into training and
            there is no held-out (negative) canary set. Otherwise half of the
            canaries are inserted and the other half are held out.
        datapath: ``root`` passed to ``torchvision.datasets.CIFAR10``.

    Returns:
        The closure ``make_canary_batch_functions`` (see its docstring).
    """
    ds = torchvision.datasets.CIFAR10(root=datapath, train=True,
                                      download=True)
    ds_test = torchvision.datasets.CIFAR10(root=datapath,
                                           train=False, download=True)

    ixs, (X, Y) = ds_to_jax(ds)
    Y = jax.nn.one_hot(Y, 10) * 10  # Roughly one-hot after softmax

    # Shuffle the training data with a fixed seed.
    # NOTE(review): this step was previously marked "BUG HERE". ``ixs`` is
    # left unpermuted on purpose: it stays arange(len(X)) and is used as row
    # positions into the *shuffled* X/Y (train_batcher indexes
    # new_X[batch_ixs]), so these ids are NOT torchvision dataset indices.
    # Confirm downstream consumers expect positions in the shuffled order.
    rand_perm = np.random.RandomState(1).permutation(len(X))
    X = X[rand_perm]
    Y = Y[rand_perm]

    ixs_test, (X_test, Y_test) = ds_to_jax(ds_test)
    Y_test = jax.nn.one_hot(Y_test, 10) * 10  # Roughly one-hot after softmax
    if test_is_subset:
        ixs_test = jnp.concatenate([ixs[-test_set_size:], ixs_test])
        X_test = jnp.concatenate([X[-test_set_size:], X_test])
        Y_test = jnp.concatenate([Y[-test_set_size:], Y_test])
        X = X[:-test_set_size]
        Y = Y[:-test_set_size]
        ixs = ixs[:-test_set_size]

    # Shuffle the (possibly extended) test set with a fixed seed (1).
    rand_perm = np.random.RandomState(1).permutation(len(X_test))
    X_test = X_test[rand_perm]
    Y_test = Y_test[rand_perm]
    ixs_test = ixs_test[rand_perm]

    n_train = len(X)
    n_test = len(X_test)

    def make_canary_batch_functions(seed, bs, epochs, poison_X, poison_Y,
                                    shuffle_train=True, use_nchw=False,
                                    init_poison=False, canary_dir=None):
        """Build train / test / positive / negative batch functions.

        Args:
            seed: base seed. It is used as ``seed + epoch`` for epoch
                shuffles, ``seed`` for the canary assignment, ``seed + 100``
                for the test scramble and ``seed + i`` for batch i's
                augmentation.
            bs: batch size; must divide both the train and test set sizes.
            epochs: number of epochs of training order to pre-build.
            poison_X, poison_Y: canary images / labels, or None for no canaries.
            shuffle_train: reshuffle the training order every epoch.
            use_nchw: transpose evaluation batches from NHWC to NCHW.
            init_poison: if True, return initial canary data instead of
                batchers. Requires ``poison_X is None`` and ``poison_Y``
                given; without ``canary_dir``, the first ``len(poison_Y)``
                training rows are copied.
            canary_dir: with ``init_poison``, load ``poison_ims.npy`` and
                ``poison_labs.npy`` from this directory instead.

        Returns:
            ``(poison_X, poison_Y)`` when ``init_poison``; otherwise
            ``(train_batcher, test_batcher, pos_batcher, neg_batcher,
            canary_ids)``. Each batcher carries a ``dataset_size`` attribute.
        """
        assert n_train % bs == 0
        assert n_test % bs == 0
        if init_poison:
            assert (poison_X is None) and (poison_Y is not None)
            print('Initializing poison data...')
            if canary_dir is not None:
                print("Loading canaries from ", canary_dir)
                poison_X = jnp.load(os.path.join(canary_dir, 'poison_ims.npy'))
                poison_Y = jnp.load(os.path.join(canary_dir, 'poison_labs.npy'))
            else:
                poison_X = X[:len(poison_Y)]
                poison_Y = Y[:len(poison_Y)]
            return poison_X, poison_Y

        # Pre-build the training order for all epochs, concatenated.
        train_indices = []
        for epoch_ii in range(epochs):
            if shuffle_train:
                shuffle_key = jax.random.PRNGKey(seed + epoch_ii)
                this_epoch_indices = jax.random.permutation(shuffle_key, ixs)
            else:
                this_epoch_indices = ixs
            train_indices.append(this_epoch_indices)
        train_indices = jnp.concatenate(train_indices)

        canary_key = jax.random.PRNGKey(seed)
        num_poison = 0 if poison_X is None else poison_X.shape[0]
        canary_ids = jax.random.permutation(canary_key, num_poison)
        # canary_ids[:num_inserted] are trained on ("positives"). In
        # single-canary mode that is every canary (no negatives). Previously
        # only half were inserted even though train_batcher and pos_batcher
        # treated all of them as inserted.
        num_inserted = num_poison if single_canary else num_poison // 2
        inserted_ids = canary_ids[:num_inserted]
        held_out_ids = canary_ids[num_poison // 2:]

        # Inserted canaries overwrite training rows [0, num_inserted). Guard
        # against poison_X/poison_Y being None (num_poison == 0 above), which
        # previously raised TypeError when indexed.
        new_X = replace_with_poison(
            X, None if poison_X is None else poison_X[inserted_ids])
        new_Y = replace_with_poison(
            Y, None if poison_Y is None else poison_Y[inserted_ids])

        # Scramble only the first test_set_size examples.
        # NOTE(review): these this_* arrays are only used by the alternative
        # test_batcher kept commented out below.
        if test_is_subset:
            rand_perm = np.random.RandomState(seed + 100).permutation(test_set_size)
            this_X_test = X_test.at[:test_set_size].set(X_test[:test_set_size][rand_perm])
            this_Y_test = Y_test.at[:test_set_size].set(Y_test[:test_set_size][rand_perm])
            this_ixs_test = ixs_test.at[:test_set_size].set(ixs_test[:test_set_size][rand_perm])
        else:
            this_X_test = X_test
            this_Y_test = Y_test
            this_ixs_test = ixs_test

        # Twice the expected number of canaries per batch, rounded up.
        average_poison_per_batch = float(ceil(num_poison * bs / float(n_train))) * 2

        def train_batcher(i):
            """Return training batch ``i`` split into clean and canary parts."""
            s, e = i * bs, (i + 1) * bs
            bseed = seed + i
            batch_ixs = train_indices[s:e]
            sel_X, sel_Y = new_X[batch_ixs], new_Y[batch_ixs]

            # Rows [0, num_inserted) of new_X hold the inserted canaries.
            good_inds = batch_ixs >= num_inserted
            good_X_batch, good_Y_batch = sel_X[good_inds], sel_Y[good_inds]
            poison_X_batch, poison_Y_batch = sel_X[~good_inds], sel_Y[~good_inds]
            augmenter = jax.tree_util.Partial(augment_X, seed=bseed)

            # Round the canary count up to a multiple of
            # average_poison_per_batch so the padded shape takes few distinct
            # values (presumably to limit recompilation downstream).
            num_padding = 0 if num_poison == 0 else int(ceil(poison_X_batch.shape[0] / average_poison_per_batch) * average_poison_per_batch)
            padded_good_X_batch = jnp.zeros((bs, 32, 32, 3)).at[:good_X_batch.shape[0]].set(good_X_batch)
            padded_good_Y_batch = jnp.zeros((bs, 10)).at[:good_Y_batch.shape[0]].set(good_Y_batch)
            padded_poison_X_batch = jnp.zeros((num_padding, 32, 32, 3)).at[:poison_X_batch.shape[0]].set(poison_X_batch)
            padded_poison_Y_batch = jnp.zeros((num_padding, 10)).at[:poison_Y_batch.shape[0]].set(poison_Y_batch)

            # Positions of the real rows if the padded clean (size bs) and
            # canary arrays are concatenated in that order.
            inverse_padding_inds = jnp.concatenate([jnp.arange(good_X_batch.shape[0]), bs + jnp.arange(poison_X_batch.shape[0])])
            return (batch_ixs, inverse_padding_inds), \
                ((padded_good_X_batch, padded_poison_X_batch, augmenter),
                 (padded_good_Y_batch, padded_poison_Y_batch))

        def _eval_batch(X_batch, Y_batch):
            # Evaluation batches have an empty canary part and no augmenter,
            # mirroring the structure returned by train_batcher.
            X_batch = apply_nchw(X_batch) if use_nchw else X_batch
            X_pois = jnp.zeros((0, 32, 32, 3))
            Y_pois = jnp.zeros((0, 10))
            return (X_batch, X_pois, None), (Y_batch, Y_pois)

        def test_batcher(i):
            """Return batch ``i`` of the canary set (in its stored order).

            NOTE(review): this evaluates on poison_X, not the CIFAR test
            split. The previous variant returned np.array(this_ixs_test[slc])
            with this_X_test[slc] / this_Y_test[slc].
            """
            slc = slice(i * bs, (i + 1) * bs)
            return jnp.arange(slc.start, slc.stop), _eval_batch(poison_X[slc], poison_Y[slc])

        def pos_batcher(i):
            """Return batch ``i`` of the inserted (trained-on) canaries."""
            ids = inserted_ids[i * bs:(i + 1) * bs]
            return np.array(ids), _eval_batch(poison_X[ids], poison_Y[ids])

        def neg_batcher(i):
            """Return batch ``i`` of the held-out canaries (None if single_canary)."""
            if single_canary:
                return None
            ids = held_out_ids[i * bs:(i + 1) * bs]
            return np.array(ids), _eval_batch(poison_X[ids], poison_Y[ids])

        # Sizes computed from counts so they also work when poison_X is None.
        pos_batcher.dataset_size = num_inserted
        neg_batcher.dataset_size = 0 if single_canary else num_poison - num_poison // 2
        train_batcher.dataset_size = len(X)
        test_batcher.dataset_size = num_poison
        return train_batcher, test_batcher, pos_batcher, neg_batcher, canary_ids

    return make_canary_batch_functions
