import argparse
import os
import jax
from tqdm import tqdm
from functools import partial
import jax.numpy as jnp
import optax
from utils import make_summary_writer
from lira_utils import generate_train_test_model_perms, lira_hinge_loss
from utils import get_datasets
from canary_opt import apply_model_perms, flatten_over_models, get_logits, process_config, update_model
from canary_opt import create_train_state


from jax.lib import xla_bridge
# Report which platform ("cpu", "gpu", "tpu", ...) JAX selected as its default
# backend. `jax.default_backend()` is the public API for this and prints the
# same platform string as the deprecated `jax.lib.xla_bridge.get_backend()`.
print(jax.default_backend())
# Smoke test: allocate a tiny array so device problems surface here rather
# than in the middle of training. `.devices()` replaces the deprecated
# `Array.device()` method; the returned set is intentionally discarded.
jnp.ones(10).devices()

def fit_greedy_canary(initial_canary_params: optax.Params, iniail_model_state, initial_opt_state, model_perms, output_path):
    """Greedily optimise a canary while two model ensembles are being trained.

    The search runs for ``model_perms.shape[0]`` steps. In every step both
    ensembles are trained on ``model_perms[step]`` (one with the canary appended
    to the training set, one without), and then the canary takes a single
    optimiser step that pushes the two ensembles' outputs on the canary apart.

    Besides its arguments, this function reads the module-level names
    ``train_ds``, ``config``, ``canary_optimizer`` and ``summary_writer``,
    which the ``__main__`` section of this file defines.

    Args:
        initial_canary_params: Starting value of the canary.
        iniail_model_state: Initial model state; used as the starting point of
            both the with-canary and the without-canary ensemble.
        initial_opt_state: Initial state of ``canary_optimizer``.
        model_perms: Index permutations; axis 0 enumerates the search steps and
            the last axis enumerates the models. The permutations are expected
            to include the canary index in every step.
        output_path: Path handed to ``jnp.savez`` for progress checkpoints.

    Returns:
        The canary parameters after the last executed step. If the search
        stopped because the canary became NaN, the returned value is that NaN
        canary; the saved history only contains the finite canaries.
    """

    @jax.jit
    def train_model_and_canary(canary_params, opt_state, state_with_canary, state_without_canary, perms):
        """Train both ensembles for one pass over `perms`, then update the canary once.

        canary_params: current canary
        opt_state: current canary optimizer state
        state_with_canary: state of the ensemble trained with the canary appended
        state_without_canary: state of the ensemble trained on the plain dataset
        perms: batch index permutations for this pass (axis 0: batches, last axis: models)

        Returns the updated canary, optimizer state and both model states,
        followed by the canary loss and the per-model mean training loss and
        accuracy of each ensemble.
        """

        num_steps = perms.shape[0]  # number of batches
        num_models = perms.shape[-1]

        # Per-batch, per-model metric buffers filled in by `train_step`.
        epoch_loss_with_canary = jnp.zeros(shape=(num_steps, num_models))
        epoch_accuracy_with_canary = jnp.zeros(shape=(num_steps, num_models))

        epoch_loss_without_canary = jnp.zeros(shape=(num_steps, num_models))
        epoch_accuracy_without_canary = jnp.zeros(shape=(num_steps, num_models))

        @jax.jit
        def train_step(perm_idx, vals, canary=None):
            """`fori_loop` body: one gradient step on batch `perm_idx` for every model."""
            state, epoch_loss, epoch_accuracy = vals
            perm = perms[perm_idx]
            if canary is not None:
                # The canary becomes one extra training example placed after
                # the last real example.
                images = jnp.concatenate([train_ds['image'], canary[None]])
                labels = jnp.append(train_ds['label'], config.canary_label)
            else:
                # NOTE(review): the same `perm` is used here although the
                # dataset has one example fewer than in the canary branch;
                # confirm how apply_model_perms treats the canary index.
                images = train_ds['image']
                labels = train_ds['label']
            grads, loss, accuracy = apply_model_perms(state, images, labels, perm)
            state = update_model(state, grads)
            epoch_loss = epoch_loss.at[perm_idx, :].set(loss)
            epoch_accuracy = epoch_accuracy.at[perm_idx, :].set(accuracy)
            return (state, epoch_loss, epoch_accuracy)

        # Take `num_steps` training steps with each of the two ensembles.

        # Ensemble that sees the canary.
        state_with_canary, epoch_loss_with_canary, epoch_accuracy_with_canary = jax.lax.fori_loop(
            0, num_steps, partial(train_step, canary=canary_params),
            (state_with_canary, epoch_loss_with_canary, epoch_accuracy_with_canary))

        # Ensemble that never sees the canary.
        state_without_canary, epoch_loss_without_canary, epoch_accuracy_without_canary = jax.lax.fori_loop(
            0, num_steps, train_step,
            (state_without_canary, epoch_loss_without_canary, epoch_accuracy_without_canary))

        # Average over batches, leaving one value per model.
        train_loss_with_canary = jnp.mean(epoch_loss_with_canary, axis=0)
        train_accuracy_with_canary = jnp.mean(epoch_accuracy_with_canary, axis=0)

        train_loss_without_canary = jnp.mean(epoch_loss_without_canary, axis=0)
        train_accuracy_without_canary = jnp.mean(epoch_accuracy_without_canary, axis=0)

        def canary_loss_given_model(canary, state_with_canary, state_without_canary, canary_label, loss_type='l2', loss_agg='max'):
            """Negated distance between the two ensembles' outputs on the canary.

            Minimising the returned value maximises the distance. Raises
            NotImplementedError for an unknown `loss_type` or `loss_agg`.
            """
            # `canary[None, :]` turns the canary into a batch of one; `[:, 0, :]`
            # selects that single example again after the forward pass.
            logits_with_canary = flatten_over_models(get_logits(state_with_canary, canary[None, :]))[:, 0, :]
            logits_without_canary = flatten_over_models(get_logits(state_without_canary, canary[None, :]))[:, 0, :]
            if loss_type == 'l2':
                distance_per_model = jnp.mean(optax.l2_loss(logits_with_canary, logits_without_canary), axis=1)
            elif loss_type == 'hinge':
                # NOTE(review): `canary_label[None]` needs an array-like label;
                # a plain Python int is not subscriptable - confirm that
                # process_config converts `config.canary_label`.
                _loss = jax.vmap(lambda _model_logits: lira_hinge_loss(_model_logits, canary_label[None]))
                hinge_loss_with_canary = _loss(logits_with_canary).ravel()
                hinge_loss_without_canary = _loss(logits_without_canary).ravel()
                distance_per_model = hinge_loss_with_canary - hinge_loss_without_canary
            else:
                raise NotImplementedError(f"unsupported loss_type: {loss_type!r}")

            # Aggregate across models.
            if loss_agg == 'max':
                return -jnp.max(distance_per_model)
            if loss_agg == 'mean':
                return -jnp.mean(distance_per_model)
            raise NotImplementedError(f"unsupported loss_agg: {loss_agg!r}")

        @jax.jit
        def canary_step(canary_params, opt_state):
            """Take one optimiser step on the canary against the freshly trained ensembles."""
            loss_fn = partial(canary_loss_given_model, canary_label=config.canary_label, loss_type=config.loss_type, loss_agg=config.loss_agg)
            # Differentiates with respect to the first argument, i.e. the canary.
            loss_value, grads = jax.value_and_grad(loss_fn)(canary_params, state_with_canary, state_without_canary)

            updates, opt_state = canary_optimizer.update(grads, opt_state, canary_params)
            canary_params = optax.apply_updates(canary_params, updates)
            if config.clip_canary:
                canary_params = jnp.clip(canary_params, config.clip_min, config.clip_max)
            return canary_params, opt_state, loss_value

        # Update the canary
        canary_params, opt_state, canary_loss_value = canary_step(canary_params, opt_state)

        return canary_params, opt_state, state_with_canary, state_without_canary, canary_loss_value, train_loss_with_canary, train_loss_without_canary, train_accuracy_with_canary, train_accuracy_without_canary

    def save_progress(per_step_canaries, per_step_canary_losses):
        """Write the canary and loss history collected so far to `output_path`."""
        jnp.savez(output_path,
            per_step_canaries=jnp.array(per_step_canaries),
            per_step_canary_losses=jnp.array(per_step_canary_losses))

    # Both ensembles start from the same initial model state.
    state_with_canary, state_without_canary = iniail_model_state, iniail_model_state

    canary_params = initial_canary_params
    opt_state = initial_opt_state

    # Number of steps to train the canary which is also the number of epochs the models are trained for
    canary_search_steps = model_perms.shape[0]

    per_step_canaries = []
    per_step_canary_losses = []
    progress_bar = tqdm(total=canary_search_steps, desc="Canary Search", unit="step")

    for _step in range(canary_search_steps):
        try:
            (canary_params, opt_state, state_with_canary, state_without_canary,
             canary_loss_value, train_loss_with_canary, train_loss_without_canary,
             train_accuracy_with_canary, train_accuracy_without_canary) = train_model_and_canary(
                canary_params, opt_state, state_with_canary, state_without_canary, model_perms[_step])

            # Diverged: stop searching. The NaN canary is not recorded; the
            # save after the loop persists everything collected before it.
            if jnp.any(jnp.isnan(canary_params)):
                break

            per_step_canaries.append(canary_params)
            per_step_canary_losses.append(canary_loss_value)

            # Checkpoint only after recording this step, so the file on disk
            # includes the current canary instead of lagging one step behind.
            if _step % 10 == 0:
                save_progress(per_step_canaries, per_step_canary_losses)

            progress_bar.set_postfix(canary_loss=f'{canary_loss_value:.4f}')
            progress_bar.update(1)

            summary_writer.scalar('canary_loss', canary_loss_value, _step)
            summary_writer.scalar('train_accuracy_with_canary', train_accuracy_with_canary.mean(), _step)
            summary_writer.scalar('train_accuracy_without_canary', train_accuracy_without_canary.mean(), _step)
            summary_writer.image('canary', canary_params, _step)
            summary_writer.flush()
        except KeyboardInterrupt:
            # Allow a manual stop; progress gathered so far is still saved below.
            break

    progress_bar.close()
    save_progress(per_step_canaries, per_step_canary_losses)

    return canary_params


if __name__ == "__main__":
    # NOTE: `config`, `train_ds`, `summary_writer` and `canary_optimizer` are
    # read as module-level globals by `fit_greedy_canary`, so they must stay
    # bound at module scope under exactly these names.

    parser = argparse.ArgumentParser(description="Generate greedy canaries.")

    parser.add_argument('--output-dir', type=str, default='.', help='output directory')
    parser.add_argument('--logdir', type=str, default='./runs/', help='tensorboard log directory')
    parser.add_argument('--num-models', type=int, default=20, help='number of models')
    parser.add_argument('--learning-rate', type=float, default=0.1, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.96, help='momentum')
    parser.add_argument('--num-epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--batch-size', type=int, default=1024, help='batch size')
    parser.add_argument('--canary-label', type=int, default=0, help='canary label')
    parser.add_argument('--canary-learning-rate', type=float, default=0.1, help='canary learning rate')
    parser.add_argument('--canary-momentum', type=float, default=0.99, help='canary momentum')
    parser.add_argument('--clip-canary', action='store_true', help='clip canary')
    parser.add_argument('--loss-type', type=str, default='l2', help='loss type')
    parser.add_argument('--loss-agg', type=str, default='max', help='loss aggregation')
    parser.add_argument('--architecture', type=str, default='MLP', help='model architecture')
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--clip-min', type=float, default=0, help='clip min')
    parser.add_argument('--clip-max', type=float, default=1, help='clip max')
    parser.add_argument('--force-num-steps-per-epoch', type=int, default=None, help='This will increase the number of greedy canary updates by this amount. Note that canary gets inserted more than once per actual pass (actual epoch) over the dataset.')

    parser.add_argument("--dataset", type=str, default='mnist', help="Dataset to use. Can be 'mnist' or 'cifar10'.")

    config = parser.parse_args()
    config = process_config(config)

    key = jax.random.PRNGKey(config.seed)

    # Load the dataset. The canary is appended after the last real training
    # example, so its index equals the size of the training set.
    train_ds, test_ds = get_datasets(config.dataset, None)
    train_ds_size = train_ds['label'].shape[0]
    canary_index = train_ds_size

    # Generate the data permutations.
    key, subkey = jax.random.split(key)

    train_perms, test_perms, model_perms = generate_train_test_model_perms(subkey, config.num_models, config.num_epochs, train_ds_size, config.batch_size,
                                                                           force_num_steps_per_epoch=config.force_num_steps_per_epoch)

    # Indices seen by model 0 during the first epoch.
    first_epoch_indices = model_perms[0, :, :, 0].ravel()

    # Sanity check: no index is repeated within that epoch. Comparing sizes
    # avoids an elementwise comparison between arrays of different lengths,
    # which would either broadcast (and pass wrongly) or raise a shape error
    # instead of failing the assertion.
    assert jnp.unique(first_epoch_indices).size == first_epoch_indices.size

    if config.force_num_steps_per_epoch is None:
        # Sanity check: the canary index is present in that first epoch.
        assert jnp.any(first_epoch_indices == canary_index)

    key, subkey = jax.random.split(key)
    init_model_state = create_train_state(jax.random.split(subkey, config.num_models), config.learning_rate, config.momentum, config.num_models, architecture=config.architecture, image_shape=config.image_shape)

    key, subkey = jax.random.split(key)
    init_canary = jax.random.uniform(subkey, shape=config.image_shape)

    # Set up the tensorboard summary writer.
    experiment_name = \
                      f"g{config.force_num_steps_per_epoch}canary_{config.dataset}_arch_{config.architecture}_{config.canary_label}_clipped_{config.clip_canary}_loss_{config.loss_type}_agg_{config.loss_agg}_models_{config.num_models}"

    summary_writer = make_summary_writer(os.path.join(config.logdir, experiment_name), vars(config))

    # Instantiate the canary optimizer
    canary_optimizer = optax.sgd(learning_rate=config.canary_learning_rate, momentum=config.canary_momentum)
    # Initialize the canary optimizer state
    initial_opt_state = canary_optimizer.init(init_canary)

    # Create the output directory up front; otherwise the first checkpoint
    # write fails only after a full training step has already been spent.
    if config.output_dir:
        os.makedirs(config.output_dir, exist_ok=True)

    output_path = os.path.join(config.output_dir, experiment_name)
    final_canary = fit_greedy_canary(init_canary, init_model_state, initial_opt_state, model_perms, output_path)
