import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from functools import partial
from flax.training.train_state import TrainState as ParentTrainState
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import chex
import wandb

class TrainState(ParentTrainState):
    """Flax ``TrainState`` extended with a per-state dropout PRNG key."""
    # Key handed to the network as its 'dropout' rng during training steps;
    # it is advanced (split) after every update.
    dropout_rng: chex.PRNGKey

class Network(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages, dropout, 10-way head.

    The explicit submodule names (``conv1``, ``conv2``, ``head``) fix the keys of
    the parameter tree, so the classifier head can be addressed as
    ``params["params"]["head"]``.
    """

    @nn.compact
    def __call__(self, x, training=False):
        # Stage i uses a stride-2 3x3 conv with the given width, named conv<i>.
        for index, width in enumerate((32, 64), start=1):
            x = nn.Conv(features=width, kernel_size=(3, 3), strides=(2, 2), name=f"conv{index}")(x)
            # NOTE(review): no `strides` is passed to max_pool, so flax's default
            # (presumably stride 1) applies — confirm that is intended.
            x = nn.max_pool(nn.relu(x), window_shape=(2, 2))

        flat = x.reshape((x.shape[0], -1))
        # Dropout is active only when training=True.
        flat = nn.Dropout(rate=0.5, deterministic=not training)(flat)
        return nn.Dense(features=10, name="head")(flat)

# Initialize population
def init(rng):
    """Build one freshly initialised ``TrainState`` from ``rng``.

    The optimiser stored here is a placeholder SGD (learning rate 0,
    momentum 0). ``train`` swaps in each member's own SGD settings at update
    time. After the ``vmap`` rebinding below, ``init`` takes a batch of keys
    and returns a batched ``TrainState``.
    """
    params_rng, dropout_rng = jax.random.split(rng)
    model = Network()
    dummy_batch = jnp.ones([1, 28, 28, 1])
    return TrainState.create(
        apply_fn=model.apply,
        params=model.init(params_rng, dummy_batch, training=False),
        tx=optax.sgd(learning_rate=0., momentum=0.),
        dropout_rng=dropout_rng,
    )
init = jax.vmap(jax.jit(init))

def init_hyperparams(rng, hyperparam_ranges):
    """Sample one value per hyperparameter uniformly from its range.

    Args:
        rng: PRNG key for a single population member.
        hyperparam_ranges: dict mapping hyperparameter name to a length-2 array
            ``[minval, maxval]``.

    Returns:
        dict with the same keys, each holding a scalar drawn with
        ``jax.random.uniform(minval=..., maxval=...)``.

    After the ``vmap`` rebinding below, ``init_hyperparams`` takes a batch of
    keys plus one shared ``hyperparam_ranges`` dict and returns a dict of
    per-member arrays.
    """
    # One independent key per hyperparameter, assigned in the dict's key order.
    subkeys = jax.random.split(rng, len(hyperparam_ranges))
    key_tree = dict(zip(hyperparam_ranges.keys(), subkeys))
    return jax.tree_util.tree_map(
        lambda key, bounds: jax.random.uniform(key, minval=bounds[0], maxval=bounds[1]),
        key_tree,
        hyperparam_ranges,
    )
# Batch over keys; the ranges dict is shared by every member.
init_hyperparams = jax.vmap(jax.jit(init_hyperparams), (0, None))

@jax.jit
def train(rng, train_states, hyperparams, dataset):
    """Train every population member for a fraction of an epoch.

    Before training, each member's ``head`` parameters are replaced with
    freshly initialised ones (taken from ``init``); all other parameters are
    kept. Each member then takes SGD steps with its own ``LR`` and ``MOMENTUM``
    from ``hyperparams``, all members seeing the same shuffled batches.

    Args:
        rng: PRNG key used to shuffle the data and to re-initialise the heads.
        train_states: batched ``TrainState`` with a leading population axis.
        hyperparams: dict with ``"LR"`` and ``"MOMENTUM"`` arrays, one entry
            per member.
        dataset: dict with ``"image"`` and ``"label"`` arrays sharing their
            leading (example) axis.

    Returns:
        ``(train_states, mean_losses)``: the updated states and, per member,
        the training loss averaged over the steps taken.
    """
    batch_size = 32
    # Only this fraction of a full epoch is used per call.
    train_fraction = 0.1

    # Split up-front so no key is consumed more than once.
    rng_perm, rng_head = jax.random.split(rng)

    dataset_size = len(dataset['image'])
    num_steps = int(train_fraction * (dataset_size // batch_size))
    perms = jax.random.permutation(rng_perm, dataset_size)
    perms = perms[:num_steps * batch_size]  # Skip an incomplete batch
    perms = perms.reshape((num_steps, batch_size))
    batches = jax.tree_util.tree_map(lambda x: jnp.take(x, perms, axis=0), dataset)

    @jax.vmap
    def reset_head(params, fresh_params):
        # Build a new tree rather than mutating the input in place.
        inner = dict(params["params"])
        inner["head"] = fresh_params["params"]["head"]
        return {**params, "params": inner}

    num_members = jax.tree_util.tree_leaves(train_states)[0].shape[0]
    fresh_states = init(jax.random.split(rng_head, num_members))
    train_states = train_states.replace(
        params=reset_head(train_states.params, fresh_states.params))

    def train_step(train_states, batch):
        @jax.vmap
        def update(train_state, member_hparams):
            # Only `tx` is swapped for this member's SGD settings; the stored
            # `opt_state` is reused as-is.
            tx = optax.sgd(member_hparams["LR"], momentum=member_hparams["MOMENTUM"])
            tmp_state = train_state.replace(tx=tx)
            # Use one half of the split for this step's dropout and carry the
            # other half forward, so a key is never reused.
            dropout_rng, next_dropout_rng = jax.random.split(train_state.dropout_rng)

            def loss_fn(params):
                logits = tmp_state.apply_fn(params, batch['image'], training=True, rngs={'dropout': dropout_rng})
                loss = jnp.mean(optax.softmax_cross_entropy(
                    logits=logits,
                    labels=jax.nn.one_hot(batch['label'], num_classes=10)))
                return loss, logits

            (loss, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(tmp_state.params)
            tmp_state = tmp_state.apply_gradients(grads=grads)
            # Copy the updated fields back onto the original state (keeping its
            # placeholder `tx`), presumably so the scan carry's static fields
            # stay identical across iterations.
            new_state = train_state.replace(
                step=tmp_state.step,
                params=tmp_state.params,
                opt_state=tmp_state.opt_state,
                dropout_rng=next_dropout_rng,
            )
            return new_state, loss

        return update(train_states, hyperparams)

    train_states, losses = jax.lax.scan(train_step, train_states, batches)
    return train_states, losses.mean(axis=0)

@jax.jit
def evaluate(train_states, dataset):
    """Score every population member on ``dataset`` without dropout.

    The data is visited in order in batches of 100; examples beyond the last
    full batch are dropped. Per-batch loss and accuracy are averaged over the
    batches. Because all batches have the same size, this equals the mean over
    the evaluated examples.

    Args:
        train_states: batched ``TrainState`` with a leading population axis.
        dataset: dict with ``"image"`` and ``"label"`` arrays sharing their
            leading (example) axis.

    Returns:
        dict ``{"loss": ..., "accuracy": ...}`` with one value per member.
    """
    batch_size = 100

    dataset_size = len(dataset['image'])
    num_batches = dataset_size // batch_size
    inds = jnp.arange(num_batches * batch_size).reshape((num_batches, batch_size))
    batches = jax.tree_util.tree_map(lambda x: jnp.take(x, inds, axis=0), dataset)

    def eval_step(carry, batch):
        @jax.vmap
        def eval_minibatch(train_state):
            logits = train_state.apply_fn(train_state.params, batch["image"], training=False)
            loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(batch["label"], num_classes=10)))
            accuracy = jnp.mean(jnp.argmax(logits, -1) == batch["label"])
            return {
                'loss': loss,
                'accuracy': accuracy
            }
        return carry, eval_minibatch(train_states)

    _, metrics = jax.lax.scan(eval_step, None, batches)
    return jax.tree_util.tree_map(lambda x: x.mean(axis=0), metrics)

def get_datasets():
    """Download (if needed) and load the full MNIST train and test splits.

    Each split is loaded in one piece (``batch_size=-1``) as a dict of arrays.
    ``'image'`` is converted to float32 and scaled into [0, 1].

    Returns:
        ``(train_ds, test_ds)``.
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def load_split(split_name):
        ds = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
        ds['image'] = jnp.float32(ds['image']) / 255.0
        return ds

    return load_split('train'), load_split('test')

@jax.jit
def step_and_eval(rng, train_states, hyperparams, train_data, test_data):
    """Train the population for one round, then score it on ``test_data``.

    Args:
        rng: PRNG key forwarded to ``train``.
        train_states: batched ``TrainState`` with a leading population axis.
        hyperparams: per-member ``"LR"`` / ``"MOMENTUM"`` dict forwarded to ``train``.
        train_data: dataset dict used for training.
        test_data: dataset dict used for evaluation.

    Returns:
        ``(train_states, fitness, test_loss)``: the updated states, per-member
        test accuracy (used as the PBT fitness), and per-member test loss.
    """
    # The training losses are not needed; only held-out metrics are returned.
    train_states, _ = train(rng, train_states, hyperparams, train_data)
    eval_metrics = evaluate(train_states, test_data)
    return train_states, eval_metrics["accuracy"], eval_metrics["loss"]

@partial(jax.jit, static_argnames=("num_elite",))
def exploit(rng, h, theta, fitness, num_elite=6):
    """PBT exploit step: members outside the top ``num_elite`` copy an elite member.

    A member whose fitness is strictly below the ``num_elite``-th best fitness
    is "exploiting". It replaces both its hyperparameters ``h`` and its
    parameters ``theta`` with those of an elite member chosen uniformly at
    random. Members tied at the threshold count as elite.

    Args:
        rng: PRNG key.
        h: per-member hyperparameters, leading axis = population.
        theta: pytree of per-member parameters/states, leading axis = population.
        fitness: 1-D array of per-member fitness (higher is better).
        num_elite: how many top members are kept (static). It is clamped to
            ``[1, len(fitness)]`` so small populations do not index out of range.

    Returns:
        ``(h, theta, exploit_mask)``, where ``exploit_mask`` is True for the
        members that were overwritten.
    """
    pop_size = len(fitness)
    k = min(max(int(num_elite), 1), pop_size)
    kth_best_fitness = jnp.sort(fitness)[-k]
    # Uniform sampling weights over elite members, zero elsewhere.
    elite_weights = (fitness >= kth_best_fitness).astype(jnp.float32)
    rngs = jax.random.split(rng, pop_size)

    def member_exploit(rng_i, h_i, theta_i, fitness_i):
        exploit_bool = fitness_i < kth_best_fitness
        copy_id = jax.random.choice(rng_i, pop_size, p=elite_weights)
        theta_i = jax.tree_util.tree_map(
            lambda pop_leaf, own_leaf: jax.lax.select(exploit_bool, pop_leaf[copy_id], own_leaf),
            theta, theta_i)
        h_i = jax.lax.select(exploit_bool, h[copy_id], h_i)
        return h_i, theta_i, exploit_bool

    return jax.vmap(member_exploit)(rngs, h, theta, fitness)

@partial(jax.jit, static_argnums=(4, 5, 6))
def explore(rng, h, theta, explore_mask, final_meta_rate=None, self_referential=False, fixed_pbt=False):
    """PBT explore step: perturb the (meta-)hyperparameters of flagged members.

    ``h`` stacks hyperparameter orders along axis 1. In the ``__main__`` script
    it has shape ``(popsize, norder, n_hyperparams)``.

    Modes (all static):
      * ``fixed_pbt=True``: classic PBT. Every entry of a flagged member's
        ``h_i`` is multiplied by one factor drawn from {0.8, 1.2}.
      * otherwise: Gaussian noise. Order ``k`` is perturbed with standard
        deviation ``h_i[k + 1]``. The highest order uses itself when
        ``self_referential`` is True, else the constant ``final_meta_rate``.

    Args:
        rng: PRNG key.
        h: per-member (meta-)hyperparameters, leading axis = population.
        theta: per-member parameters; returned unchanged.
        explore_mask: boolean per member; only True members are perturbed.
        final_meta_rate: noise scale for the highest order in the
            non-self-referential Gaussian mode.
        self_referential: see modes above.
        fixed_pbt: see modes above.

    Returns:
        ``(h, theta)``.

    Raises:
        ValueError: if the Gaussian non-self-referential mode is selected
            without a ``final_meta_rate``.
    """
    if not fixed_pbt and not self_referential and final_meta_rate is None:
        raise ValueError(
            "final_meta_rate must be set when neither self_referential nor fixed_pbt is enabled")

    rngs = jax.random.split(rng, len(explore_mask))

    def member_explore(rng_i, h_i, theta_i, explore_bool):
        if fixed_pbt:
            factor = jax.random.choice(rng_i, jnp.array([0.8, 1.2]))
            perturbed = h_i * factor
        else:
            # std for order k is the order-(k+1) value; the top order gets its own rate.
            top_std = h_i[-1] if self_referential else final_meta_rate
            std = jnp.roll(h_i, -1, axis=0).at[-1].set(top_std)
            perturbed = h_i + std * jax.random.normal(rng_i, h_i.shape)
        h_i = jax.lax.select(explore_bool, perturbed, h_i)
        return h_i, theta_i

    return jax.vmap(member_explore)(rngs, h, theta, explore_mask)

if __name__=="__main__":
    import argparse
    # Import explicitly: `import jax` alone does not guarantee that the
    # `jax.flatten_util` submodule is loaded.
    from jax.flatten_util import ravel_pytree

    parser = argparse.ArgumentParser()
    parser.add_argument("--self_referential", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fixed_pbt", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--final_meta_rate", type=float, default=1e-7)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--norder", type=int, default=3)
    parser.add_argument("--popsize", type=int, default=30)
    parser.add_argument("--LR", type=float, default=1e-9)
    parser.add_argument("--sr_mult", type=float, default=2.)
    config = vars(parser.parse_args())

    wandb.login()

    popsize = config["popsize"]
    norder = config["norder"]

    train_data, test_data = get_datasets()

    rng = jax.random.PRNGKey(config["seed"])

    # Ranges used to sample the initial population (min == max gives a constant).
    init_ranges = {
        "LR": jnp.array([config["LR"], config["LR"]]),
        "MOMENTUM": jnp.array([0., 0.]),
    }

    rng, rng_pop, rng_h = jax.random.split(rng, 3)
    train_states = init(jax.random.split(rng_pop, popsize))
    hyperparams = init_hyperparams(jax.random.split(rng_h, popsize), init_ranges)

    # Ranges used to clip the 0-th order hyperparameters after each explore step.
    clip_ranges = {
        "LR": jnp.array([1e-12, 10.]),
        "MOMENTUM": jnp.array([0., 1.]),
    }

    # Flat-vector <-> dict conversion for one member, batched over the population.
    _, unravel = ravel_pytree(jax.tree_util.tree_map(lambda x: x[0], hyperparams))
    unravel = jax.vmap(unravel)

    @jax.vmap
    def flatten(hyperparams):
        flat, _ = ravel_pytree(hyperparams)
        return flat

    @jax.jit
    def clip(hyperparams, ranges):
        """Clip only the 0-th order hyperparameters into ``ranges``."""
        hyperparams_tree = unravel(hyperparams[:, 0])
        hyperparams_tree = jax.tree_util.tree_map(lambda x, r: x.clip(r[0], r[1]), hyperparams_tree, ranges)
        return hyperparams.at[:, 0].set(flatten(hyperparams_tree))

    hyperparams = flatten(hyperparams)

    # Add arbitrary-order meta-hyperparams: shape (popsize, orders, n_hyperparams).
    hyperparams = hyperparams.reshape(hyperparams.shape[0], 1, hyperparams.shape[1])
    if norder > 1:
        hyperparams = jnp.repeat(hyperparams, norder, axis=1)
        # Order k starts at sr_mult**k times the base value.
        hyperparams *= (config["sr_mult"] ** jnp.arange(norder))[None, :, None]
    print("(META) HYPERPARAM SHAPE:", hyperparams.shape)
    print(hyperparams)

    # Record the CLI configuration with the run so results can be reproduced.
    run = wandb.init(project="MNIST2", config=config)

    print(jax.tree_util.tree_map(lambda x: x.shape, train_states.params))

    for gen in range(500):
        rng, rng_step, rng_exploit, rng_explore = jax.random.split(rng, 4)
        train_states, fitness, losses = step_and_eval(rng_step, train_states, unravel(hyperparams[:, 0]), train_data, test_data)
        hyperparams, train_states, explore_mask = exploit(rng_exploit, hyperparams, train_states, fitness)
        if norder > 0:
            hyperparams, train_states = explore(rng_explore, hyperparams, train_states, explore_mask, config["final_meta_rate"], config["self_referential"], config["fixed_pbt"])
        hyperparams = clip(hyperparams, clip_ranges)

        # Shift every label by one (mod 10) each generation so the task keeps changing.
        train_data["label"] = (train_data["label"] + 1) % 10
        test_data["label"] = (test_data["label"] + 1) % 10

        log_dict = {"fitness": fitness.max()}
        best_id = fitness.argmax()
        log_dict["loss"] = losses[best_id]
        for i in range(max(norder, 1)):
            order_tree = unravel(hyperparams[:, i])
            log_dict.update({f"{i}-th_order/best_{key}": value[best_id] for key, value in order_tree.items()})
            log_dict.update({f"{i}-th_order/{key}": value for key, value in order_tree.items()})
        print(log_dict["fitness"], log_dict["0-th_order/best_LR"], log_dict["0-th_order/best_MOMENTUM"])
        wandb.log(log_dict)

    run.finish()
