import os
import numpy as np
import time
import jax 
from jax.tree_util import Partial
from jax import numpy as jnp
from tqdm import tqdm
from ..cifar_ds import make_make_canary_batch_functions, IMAGE_LOWER_BD, IMAGE_UPPER_BD
from ..resnet9_poison import construct_model, reinit_model
from ..okazaki_cifar import cifar_per_sample_loss, okazaki_cifar, str_hash, cifar_corrects
from ..okazaki_train import add_trees, subtract_trees
import uuid
from time import sleep
from functools import partial
from datetime import datetime
import json
from hashlib import sha256 as my_hash
import optax
import pickle
import sys
from itertools import product

# @jax.jit
def single_grad(state, batch, per_sample_loss, targeted, target_class=6):
    """Return the gradient of the summed per-sample loss on ``batch`` w.r.t. ``state``.

    Args:
        state: Pytree exposing ``params`` and ``batch_stats`` attributes.
        batch: Batch forwarded unchanged to ``per_sample_loss``. In targeted
            mode it is unpacked as ``ixs, (images, (y, y_p))`` and ``y`` must
            be a 2-D array whose ``argmax`` along axis 1 gives the class id.
        per_sample_loss: Callable invoked as
            ``per_sample_loss(params, batch_stats, batch, train=False)``.
        targeted: Python bool. When true, only samples whose label equals
            ``target_class`` contribute to the loss.
        target_class: Class id selected in targeted mode. Defaults to 6, the
            value that used to be hard-coded.

    Returns:
        A pytree with the structure of ``state`` holding the gradient
        (integer leaves are permitted via ``allow_int=True``).
    """
    def summed_loss(train_state):
        test_loss = per_sample_loss(
            train_state.params, train_state.batch_stats, batch, train=False)
        if not targeted:
            return jnp.sum(test_loss)

        _, (_, (y, _)) = batch
        is_target = y.argmax(1) == target_class
        # Mask with `where` instead of boolean indexing: the result keeps a
        # static shape, so the function also works when `batch` is traced
        # (e.g. inside jax.jit), where boolean-mask indexing raises.
        mask = is_target.reshape(
            is_target.shape + (1,) * (test_loss.ndim - is_target.ndim))
        return jnp.sum(jnp.where(mask, test_loss, jnp.zeros_like(test_loss)))

    return jax.grad(summed_loss, allow_int=True)(state)

def big_model_cotangenter(state, per_sample_loss, get_val_batch, num_val_batches, targeted):
    """Accumulate ``single_grad`` over validation batches.

    Batch 0 is always evaluated; batches ``1 .. num_val_batches - 1`` are then
    added on top with ``add_trees``. Returns the accumulated gradient pytree.
    """
    total = single_grad(state, get_val_batch(0), per_sample_loss, targeted)

    for batch_ix in tqdm(range(1, num_val_batches), desc='Computing gradients'):
        batch_grad = single_grad(state, get_val_batch(batch_ix), per_sample_loss, targeted)
        total = add_trees(total, batch_grad)

    return total

def big_model_canary_cotangenter(state, per_sample_loss, get_in_canaries, get_out_canaries, num_batches, targeted):
    """Return (sum of "in"-canary gradients) minus (sum of "out"-canary gradients).

    Both loaders are indexed with ``0 .. num_batches - 1``; batch 0 of each is
    always evaluated. Per-batch gradients come from ``single_grad`` and are
    combined with ``add_trees`` / ``subtract_trees``.
    """
    def batch_grad(loader, ix):
        return single_grad(state, loader(ix), per_sample_loss, targeted)

    in_total = batch_grad(get_in_canaries, 0)
    out_total = batch_grad(get_out_canaries, 0)

    for ix in tqdm(range(1, num_batches), desc='Computing gradients'):
        in_total = add_trees(in_total, batch_grad(get_in_canaries, ix))
        out_total = add_trees(out_total, batch_grad(get_out_canaries, ix))

    return subtract_trees(in_total, out_total)

def canaries_eval(state, per_sample_loss, get_in_canaries, get_out_canaries, num_batches, targeted):
    """Sum the loss over the "in" and "out" canary batches separately.

    Returns a pair ``(in_total, out_total)`` of scalar losses. In targeted
    mode only samples whose label argmax equals 6 are counted.
    """
    def batch_loss(batch):
        per_sample = per_sample_loss(state.params, state.batch_stats, batch, train=False)
        if not targeted:
            return jnp.sum(per_sample)
        ixs, ((x, x_p, a), (y, y_p)) = batch
        return jnp.sum(per_sample[y.argmax(1) == 6])

    in_total = batch_loss(get_in_canaries(0))
    out_total = batch_loss(get_out_canaries(0))

    for ix in tqdm(range(1, num_batches), desc='Computing gradients'):
        in_total = in_total + batch_loss(get_in_canaries(ix))
        out_total = out_total + batch_loss(get_out_canaries(ix))

    return in_total, out_total

def one_sample_cotangenter(state, per_sample_loss, bs, get_val_batch, test_index):
    """Gradient w.r.t. ``state`` of the loss on a single validation example.

    ``test_index`` is a global example index; it is split into a batch number
    and an in-batch offset using the batch size ``bs``. The single-example
    loss is multiplied by ``bs`` before differentiation.
    """
    # Locate the example: which batch it lives in and where inside that batch.
    batch_no = int(test_index // bs)
    offset = int(test_index - batch_no * bs)
    sel = slice(offset, offset + 1)

    idx, ((x, x_p, a), (y, y_p)) = get_val_batch(batch_no)
    # Keep the batch layout (idx, (images, labels)) but with one example only.
    single_example = (idx[sel], ((x[sel], x_p, a), (y[sel], y_p)))
    print(f'Test example {test_index} has label {y[sel]}')

    def scaled_loss(train_state):
        per_sample = per_sample_loss(
            train_state.params, train_state.batch_stats, single_example, train=False) * bs
        # Exactly one loss value is expected for the one selected example.
        assert per_sample.shape == (1,)
        return jnp.sum(per_sample)

    return jax.grad(scaled_loss, allow_int=True)(state)

def poison_batch_cotangenter(losser0, get_grads, apply_grads, num_poison):
    """Wrap ``get_grads`` / ``apply_grads`` so they take a leading ``eps`` argument.

    Returns ``(eps, new_get_grads, new_apply_grads)`` where ``eps`` is a pair
    of zero arrays of shapes ``(num_poison, 32, 32, 3)`` and
    ``(num_poison, 10)`` placed on the first GPU. ``new_get_grads`` adds
    ``eps`` to the poison images/labels of the batch before delegating.
    """
    print('Recompiling cotangenter for poison size {}'.format(num_poison))
    gpu0 = jax.devices('gpu')[0]
    zero_ims = jnp.zeros((num_poison, 32, 32, 3), device=gpu0)
    zero_labs = jnp.zeros((num_poison, 10), device=gpu0)

    def new_get_grads(eps, state, batch):
        idx, ((good_ims, poison_ims, augmenter), (good_labs, poison_labs)) = batch
        shifted_ims = (good_ims, poison_ims + eps[0], augmenter)
        shifted_labs = (good_labs, poison_labs + eps[1])
        return get_grads(state, (idx, (shifted_ims, shifted_labs)), losser0)

    def new_apply_grads(eps, state, grads, updates):
        # eps is accepted only to match the signature of new_get_grads.
        return apply_grads(state, grads, updates)

    return (zero_ims, zero_labs), new_get_grads, new_apply_grads

def safe_norm(v, **kwargs):
    """Return ``v`` divided by its norm, substituting noise for an all-zero ``v``.

    The zero check uses the default norm of ``v``; ``kwargs`` are forwarded
    only to the norm used for the final division. The replacement noise is
    seeded from the current wall-clock time, so it differs between calls.
    """
    magnitude = jnp.linalg.norm(v)
    if magnitude == 0:
        print(f'Gradient with shape {v.shape} is zero, using random noise instead')
        seed = int(time.time())
        v = jax.random.normal(jax.random.PRNGKey(seed), shape=v.shape, dtype=v.dtype)
    denom = jnp.linalg.norm(v, **kwargs)
    return v / denom

def backtrack_search_step(x0, full_grad, step, f, f_x0, alpha, beta):
    """Backtracking line search along ``step`` starting at ``x0``.

    Args:
        x0: Pytree; indexed as ``[0]`` (clipped to the image bounds) and ``[1]``.
        full_grad: Pytree with the same structure as ``x0``.
        step: Pytree with the same structure as ``x0``.
        f: Callable mapping a candidate point to a scalar loss.
        f_x0: Scalar equal to ``f(x0)``.
        alpha: Factor applied to the directional derivative.
        beta: Factor the step size is multiplied by after each rejected trial.

    Returns:
        The last candidate tried (at most 10 trials), whether or not it
        satisfied the sufficient-decrease test.

    Raises:
        ValueError: If the scaled directional derivative is positive.
    """
    step_size = 1
    print('>> BACKTRACKING from initial loss', f_x0)

    # Per-leaf inner products <grad, step>, summed across the leaves.
    leaf_dots = jax.tree_util.tree_map(lambda g, s: jnp.sum(g * s), full_grad, step)
    slope = alpha * sum(leaf_dots)

    if slope > 0:
        raise ValueError(f'Slope is {slope}, not negative')

    accepted = False
    for _ in range(10):
        moved = jax.tree_util.tree_map(lambda x, s: x + step_size * s, x0, step)
        x_guess = (jnp.clip(moved[0], IMAGE_LOWER_BD, IMAGE_UPPER_BD), moved[1])
        f_x_guess = f(x_guess)
        # Accept once the loss is below the linear model at this step size.
        if f_x_guess <= f_x0 + step_size * slope:
            accepted = True
            break
        step_size = step_size * beta

    if accepted:
        print('Backtracking succeeded, final loss:', f_x_guess, 'initial loss:', f_x0)
    else:
        print('Backtracking failed, final loss:', f_x_guess, 'initial loss:', f_x0)

    if f_x0 < f_x_guess:
        print('Backtracking made negative progress, final loss:', f_x_guess, 'initial loss:', f_x0)

    return x_guess

def get_kwargs(model_args):
    """Return the default training hyperparameters overlaid with ``model_args``.

    Args:
        model_args: Mapping of overrides (entries here win over the defaults,
            and unknown keys are passed through), or ``None`` for no
            overrides, mirroring ``get_model_init_kwargs``.

    Returns:
        A new dict; neither the defaults nor ``model_args`` are mutated.
    """
    if model_args is None:
        model_args = {}
    return {
        'seed': 0,
        'test_index': 32,
        'lr': 0.2,
        'b1': 0.85,  # previously 0.9
        'wd': 1e-5,
        'min_lr_relative': 10000,
        'bs': 250,
        'epochs': 18,
        'pct_start': 0.5,
        'optimizer': 'sgd',
        'exclude_bn': True,
        'maxits': 50_000,
        'nesterov': True,
        'schedule_momentum': False,
        'momentum_pct_final': None,
        'momentum_min_relative': None,
        'momentum_min_relative_final': None,
    } | model_args

def get_model_init_kwargs(model_init_args):
    """Return the default model-construction options overlaid with ``model_init_args``.

    ``None`` is treated as "no overrides". Entries in ``model_init_args`` take
    precedence over the defaults; a new dict is returned.
    """
    overrides = {} if model_init_args is None else model_init_args
    defaults = dict(
        bn_eps=1e-5,
        bn_momentum=0.5,
        use_fast_variance=True,
        final_bias=True,
        width_multiplier=2.,
        residual_scale=1.,
        final_scale=0.125,
        init_scale=2.,
        batchnorm_before_act=True,
        activation_fn='gelu',
        big_first_conv=False,
        init_distribution='truncated_normal',
        max_pool=False,
        tta=True,
    )
    return defaults | overrides

def jvp_to_grad(vjp, num_poison, train_loader):
    """Scatter per-batch (image, label) gradients into per-poison arrays.

    ``vjp`` is a two-level mapping ``{segment: {batch_id: (im_grad, lab_grad)}}``
    which is flattened over segments first. For every batch the indices are
    read from ``train_loader(batch_id)`` and the rows below ``num_poison // 2``
    receive the leading rows of that batch's gradients. Batches with an empty
    image gradient, or whose image-gradient norm exceeds 1.0, are skipped.

    Returns ``(im_grad, lab_grad)`` with shapes ``(num_poison, 32, 32, 3)`` and
    ``(num_poison, 10)``, allocated on the first GPU.
    """
    per_batch = {}
    for seg_start in tqdm(vjp.keys()):
        per_batch.update(vjp[seg_start].items())

    gpu0 = jax.devices('gpu')[0]
    im_grad = jnp.zeros((num_poison, 32, 32, 3), device=gpu0)
    lab_grad = jnp.zeros((num_poison, 10), device=gpu0)

    for batch_id, (this_im_grad, this_lab_grad) in per_batch.items():
        idx, *_ = train_loader(batch_id)
        if this_im_grad.shape[0] == 0:
            continue

        first_ids = idx[0]
        corresponding_inds = first_ids[first_ids < num_poison // 2]

        # Drop batches whose image gradient is implausibly large.
        this_norm = jnp.linalg.norm(this_im_grad)
        if this_norm > 1.0:
            print('Gradient norm is too high, skipping update at segment', batch_id, 'with norm', this_norm)
            continue

        n_rows = len(corresponding_inds)
        im_grad = im_grad.at[corresponding_inds].add(this_im_grad[:n_rows])
        lab_grad = lab_grad.at[corresponding_inds].add(this_lab_grad[:n_rows])

    return im_grad, lab_grad

def init_poison(base_dir, resume, freeze_labels, num_poison, optimizers, clip=True, start_poison_X=None, start_poison_Y=None):
    """Load or create the poison images/labels and their optimizer states.

    When ``resume`` is true, the sibling directories of ``base_dir`` (entries
    of its parent whose names parse as integers) are scanned from the highest
    number down, and the first one containing ``poison_ims.npy`` is loaded.
    Otherwise, or when nothing is found, the poison set is taken from
    ``start_poison_X`` / ``start_poison_Y`` or drawn randomly.

    Args:
        base_dir: Directory the initial snapshot is written to.
        resume: Whether to try to load a previous run.
        freeze_labels: On a fresh start, replace the labels by
            ``one_hot(argmax) * 20``. Not applied to resumed labels.
        num_poison: Number of poison samples (used for random initialisation).
        optimizers: Pair ``(im_optimizer, lab_optimizer)``; ``lab_optimizer``
            may be ``None``.
        clip: Clip randomly initialised images to the image bounds.
        start_poison_X: Optional initial images (used as-is, not clipped).
        start_poison_Y: Optional initial labels.

    Returns:
        ``(base_poison_X, base_poison_Y, all_losses, opt_states)`` where
        ``opt_states`` is ``(image_state, label_state)``.
    """
    outer_dir = os.path.dirname(base_dir)
    did_resume = False
    im_optimizer, lab_optimizer = optimizers
    opt_states = None
    if resume and os.path.exists(outer_dir):
        print('>> RESUMING')
        if freeze_labels:
            print('>> FREEZING LABELS')
        # Only integer-named entries are run directories; anything else in
        # the parent (stray files, editor/OS metadata) is ignored instead of
        # crashing the numeric sort.
        numbered = []
        for name in os.listdir(outer_dir):
            try:
                numbered.append((int(name), name))
            except ValueError:
                continue
        sibling_dirs = [name for _, name in sorted(numbered, key=lambda pair: pair[0])]
        # Walk from the highest-numbered run down to the first one with a snapshot.
        for resume_dir in reversed(sibling_dirs):
            resume_dir = os.path.join(outer_dir, resume_dir)
            print('Resuming from', resume_dir)
            if os.path.exists(f'{resume_dir}/poison_ims.npy'):
                base_poison_X = jnp.load(f'{resume_dir}/poison_ims.npy')
                base_poison_Y = jnp.load(f'{resume_dir}/poison_labs.npy')
                # NOTE: allow_pickle / pickle.load execute arbitrary code from
                # the file; only resume from directories this code wrote.
                all_losses = list(jnp.load(f'{resume_dir}/losses.npy', allow_pickle=True))
                with open(f'{resume_dir}/im_optimizer_state.pkl', 'rb') as f:
                    im_optimizer_state = pickle.load(f)
                with open(f'{resume_dir}/lab_optimizer_state.pkl', 'rb') as f:
                    lab_optimizer_state = pickle.load(f)
                opt_states = (im_optimizer_state, lab_optimizer_state)
                did_resume = True
                break
        else:
            print('>> NO POISON IMAGES FOUND, INITIALIZING RANDOMLY')

    if not did_resume:
        if start_poison_X is None:
            im = jnp.zeros((32, 32, 3), device=jax.devices('gpu')[0])
            base_poison_X = im + jax.random.normal(jax.random.PRNGKey(0), (num_poison, 32, 32, 3)) * 0.1
            if clip:
                base_poison_X = jnp.clip(base_poison_X, IMAGE_LOWER_BD, IMAGE_UPPER_BD)
        else:
            base_poison_X = start_poison_X
        if start_poison_Y is None:
            # Make the poison logits random
            base_poison_Y = jax.random.uniform(jax.random.PRNGKey(0), (num_poison, 10), minval=0., maxval=1.)
        else:
            base_poison_Y = start_poison_Y
        all_losses = []
        if freeze_labels:
            base_poison_Y = jnp.argmax(base_poison_Y, axis=1)
            base_poison_Y = jax.nn.one_hot(base_poison_Y, 10) * 20.

    if opt_states is None:
        im_opt_state = im_optimizer.init(base_poison_X)
        lab_opt_state = lab_optimizer.init(base_poison_Y) if lab_optimizer is not None else None
        opt_states = (im_opt_state, lab_opt_state)

    # Make sure the target directory exists so the saves below cannot fail
    # just because the caller has not created it yet.
    os.makedirs(base_dir, exist_ok=True)

    # Save the initial poison images and labels
    jnp.save(f'{base_dir}/poison_ims.npy', base_poison_X)
    jnp.save(f'{base_dir}/poison_labs.npy', base_poison_Y)
    jnp.save(f'{base_dir}/losses.npy', all_losses)

    # Pickle the optimizer states
    with open(f'{base_dir}/im_optimizer_state.pkl', 'wb') as f:
        pickle.dump(opt_states[0], f)
    with open(f'{base_dir}/lab_optimizer_state.pkl', 'wb') as f:
        pickle.dump(opt_states[1], f)

    return base_poison_X, base_poison_Y, all_losses, opt_states

def run_optimization(args, model_args, logdir, datapath,
               model_init_args=None, antipoison=False, resume=False, jit_cotangent=False, clip=True,
               num_steps=1000):
    """Run the outer optimisation loop over the poison/canary images and labels.

    Each iteration re-initialises the model, trains it through
    ``okazaki_cifar``, assembles image/label gradients from the returned VJP
    plus the direct gradient on the canary batches, applies an optimizer step
    and writes the current poison set, losses and optimizer states to disk.

    Args:
        args: Outer-loop settings. Keys read here: ``val_set_size``,
            ``num_poison``, ``freeze_labels``, ``targeted``, ``optimizer``
            (only ``'lion'`` is accepted), ``im_lr``, ``drop_freq``,
            ``im_b1``, ``im_wd``, ``lab_lr``, ``lab_b1``, ``lab_wd``,
            ``minibatch_size`` and ``vary_seed``.
        model_args: Overrides passed to ``get_kwargs``.
        logdir: Root directory; output goes to
            ``{logdir}/{hash of args}/{unix timestamp}``.
        datapath: Forwarded to ``make_make_canary_batch_functions``.
        model_init_args: Optional overrides for ``get_model_init_kwargs``.
        antipoison: When true, the logged entry per step is
            ``losses[test_index]`` instead of the metrics dict.
        resume: Forwarded to ``init_poison``.
        jit_cotangent: Wrap the model cotangenter in ``jax.jit``.
        clip: Clip images to the image bounds after every update.
        num_steps: Upper bound of the iteration counter.

    Returns:
        The list of logged per-iteration entries.

    Raises:
        ValueError: If ``args['optimizer']`` is not ``'lion'``.
    """
    args_str = my_hash(json.dumps(args | model_args).encode('utf-8')).hexdigest()
    if model_init_args is not None:
        args_str = args_str + '_' + my_hash(json.dumps(model_init_args).encode('utf-8')).hexdigest()
    val_set_size = args['val_set_size']
    num_poison = args['num_poison']
    freeze_labels = args['freeze_labels']
    targeted = args['targeted']
    optimizer_name = args['optimizer']
    # Learning-rate schedule: halve the image learning rate at 1x, 2x and 3x
    # `drop_freq` optimizer steps, then keep it constant.
    lr_schedule = optax.piecewise_constant_schedule(init_value=args['im_lr'],
                    boundaries_and_scales={args['drop_freq']: 0.5, 2 * args['drop_freq']: 0.5, 3 * args['drop_freq']: 0.5})
    if optimizer_name == 'lion':
        # NOTE(review): b2 is fed from 'im_b1' as well; confirm this is
        # intended rather than a missing 'im_b2' setting.
        im_optimizer = optax.lion(learning_rate=lr_schedule,
                                  b1=args['im_b1'],
                                  b2=args['im_b1'],
                                  weight_decay=args['im_wd'])
    else:
        raise ValueError(f'Optimizer {optimizer_name} not recognized')

    lab_optimizer = None
    if not args['freeze_labels']:
        lab_optimizer = optax.lion(learning_rate=args['lab_lr'],
                                   b1=args['lab_b1'],
                                   weight_decay=args['lab_wd'])

    # Get current timestamp in seconds since epoch
    timestamp = int(datetime.now().timestamp())
    base_dir = f'{logdir}/{args_str}/{timestamp}'
    os.makedirs(base_dir, exist_ok=True)

    # NOTE(review): read but not used anywhere below.
    minibatch_size = args['minibatch_size']

    kwargs = get_kwargs(model_args)

    model_init_kwargs = get_model_init_kwargs(model_init_args)

    # Model initialization
    raw_model, params = construct_model(init_params=True, **model_init_kwargs)
    model = jax.tree_util.Partial(raw_model.apply)
    # Loss functions for training and validation
    per_sample_loss = jax.tree_util.Partial(partial(cifar_per_sample_loss, bs=kwargs['bs']), model=model)
    val_per_sample_loss = jax.tree_util.Partial(cifar_corrects, model=model, bs=kwargs['bs'])
    make_batch_functions = make_make_canary_batch_functions(test_is_subset=False, test_set_size=val_set_size, datapath=datapath)

    # Initialise poison images, labels and optimizer states
    optimizers = (im_optimizer, lab_optimizer)
    start_poison_X, start_poison_Y = make_batch_functions(kwargs['seed'], kwargs['bs'], kwargs['epochs'], None, range(num_poison), init_poison=True)
    base_poison_X, base_poison_Y, all_losses, opt_states = init_poison(base_dir, resume, freeze_labels, num_poison, optimizers, clip, start_poison_X, start_poison_Y)
    im_optimizer_state, lab_optimizer_state = opt_states

    assert val_set_size % kwargs['bs'] == 0
    num_val_batches = (num_poison//2) // kwargs['bs']

    def get_grad(state, batch, per_sample_loss):
        """Gradient of the summed loss w.r.t. the first images / first labels of ``batch``."""
        def losser(state, batch):
            params = state.params
            test_loss = per_sample_loss(params, state.batch_stats, batch, train=False)
            return jnp.sum(test_loss)
        x = jax.grad(losser, allow_int=True, argnums=1)(state, batch)
        image_grad = x[1][0][0]
        label_grad = x[1][1][0]
        return image_grad, label_grad

    # Metagradient training
    for local_it in range(max(0,len(all_losses)),num_steps):
        it = len(all_losses)
        # Clip the labels to be in the range [-20, 20] so that we don't get NaNs in the gradients
        base_poison_Y = jnp.clip(base_poison_Y, -20., 20.)

        # Fix the seed and reinitialize the model
        this_it_seed = kwargs['seed'] + it if args['vary_seed'] else kwargs['seed']
        reinit_params = reinit_model(raw_model, seed=this_it_seed)
        kwargs['save_dir'] = f'{base_dir}/metadata/save-{it}'
        kwargs['cache_dir'] = f'{base_dir}/metadata/cache-{it}'

        # Create the training, validation and canary loaders
        tl, vl, posl, negl, shuffled_canary_order = make_batch_functions(this_it_seed, kwargs['bs'], kwargs['epochs'], base_poison_X, base_poison_Y)
        model_cotangenter = jax.tree_util.Partial(partial(big_model_canary_cotangenter, get_in_canaries=posl,
                                        get_out_canaries=negl,
                                        per_sample_loss=per_sample_loss,
                                        num_batches=num_val_batches, targeted=targeted))
        if jit_cotangent:
            model_cotangenter = jax.jit(model_cotangenter)

        # Train and get the VJP and losses
        vjp, losses, final_state, vvls, _ = okazaki_cifar(model=model,
                                                loaders=(tl, vl),
                                                per_sample_loss=per_sample_loss,
                                                val_per_sample_loss=val_per_sample_loss,
                                                params=reinit_params,
                                                model_cotangenter=model_cotangenter,
                                                batch_cotangenter=jax.tree_util.Partial(poison_batch_cotangenter),
                                                **kwargs)
        pos_canaries_evaluation, neg_canaries_evaluation = canaries_eval(final_state, per_sample_loss, posl, negl, num_val_batches, targeted)

        # Convert the VJP to gradients for images and labels
        im_grad, lab_grad = jvp_to_grad(vjp, num_poison, tl)

        # Undo the canary shuffle so row k refers to poison sample k.
        final_im_grad = jnp.zeros((num_poison, 32, 32, 3), device=jax.devices('gpu')[0])
        final_lab_grad = jnp.zeros((num_poison, 10), device=jax.devices('gpu')[0])
        final_im_grad = final_im_grad.at[shuffled_canary_order].set(im_grad)
        final_lab_grad = final_lab_grad.at[shuffled_canary_order].set(lab_grad)

        # Add the direct gradient of the final loss w.r.t. the canaries:
        # positive sign for "in" canaries, negative for "out" canaries.
        for i in tqdm(range(0, num_val_batches), desc='Computing gradients'):
            # Fetch each batch once and reuse it for both gradient and indices.
            pos_batch = posl(i)
            im_grad, label_grad = get_grad(final_state, pos_batch, per_sample_loss)
            final_im_grad = final_im_grad.at[pos_batch[0]].add(im_grad)
            final_lab_grad = final_lab_grad.at[pos_batch[0]].add(label_grad)

            neg_batch = negl(i)
            im_grad, label_grad = get_grad(final_state, neg_batch, per_sample_loss)
            final_im_grad = final_im_grad.at[neg_batch[0]].add(-im_grad)
            final_lab_grad = final_lab_grad.at[neg_batch[0]].add(-label_grad)

        im_grad = final_im_grad
        lab_grad = final_lab_grad

        im_grad_norm = jnp.linalg.norm(im_grad)
        lab_grad_norm = jnp.linalg.norm(lab_grad)
        if max(im_grad_norm, lab_grad_norm) > 1.0:
            print('Gradient norm is too high, skipping update')
            # NOTE(review): -1 is a sentinel; if real 'canaries_eval' values
            # are above -1 it will win the `min` below — confirm intent.
            all_losses.append({
                'iteration': len(all_losses),
                'pos_canaries_eval': -1,
                'neg_canaries_eval': -1,
                'canaries_eval': -1,
            })
            continue

        # Update images and labels
        im_update, im_optimizer_state = im_optimizer.update(im_grad, im_optimizer_state, base_poison_X)
        base_poison_X = optax.apply_updates(base_poison_X, im_update)
        if clip:
            base_poison_X = jnp.clip(base_poison_X, IMAGE_LOWER_BD, IMAGE_UPPER_BD)

        if not args['freeze_labels']:
            lab_update, lab_optimizer_state = lab_optimizer.update(lab_grad, lab_optimizer_state, base_poison_Y)
            base_poison_Y = optax.apply_updates(base_poison_Y, lab_update)

        if antipoison:
            all_losses.append(losses[kwargs['test_index']])
        else:
            # np.asarray copies only when it has to; np.array(..., copy=False)
            # raises under NumPy 2 whenever a copy is required.
            losses_np = np.asarray(losses)
            vvls_np = np.asarray(vvls)
            split = kwargs['bs'] * num_val_batches
            all_losses.append({
                'iteration': len(all_losses),
                'val_loss': np.mean(losses_np[:split]),
                'test_loss': np.mean(losses_np[split:]),
                'val_acc': np.mean(vvls_np[:split]),
                'test_acc': np.mean(vvls_np[split:]),
                'im_grad_norm': np.linalg.norm(np.asarray(im_grad)),
                'lab_grad_norm': np.linalg.norm(np.asarray(lab_grad)),
                'pos_canaries_eval': pos_canaries_evaluation.item(),
                'neg_canaries_eval': neg_canaries_evaluation.item(),
                'canaries_eval': (pos_canaries_evaluation - neg_canaries_evaluation).item(),
            })
        print('Losses:', all_losses[-1])

        # Best-checkpoint tracking needs the metrics dict; in antipoison mode
        # the logged entry is a bare loss value, so indexing it would fail.
        if not antipoison:
            best_iteration = min(all_losses, key=lambda x: x['canaries_eval'])
            current_it_metric = all_losses[-1]['canaries_eval']
            if current_it_metric == best_iteration['canaries_eval']:
                jnp.save(f'{base_dir}/poison_ims_{local_it}_{current_it_metric:.3f}.npy', base_poison_X)
                jnp.save(f'{base_dir}/poison_labs_{local_it}_{current_it_metric:.3f}.npy', base_poison_Y)

        print('>> SAVING to ', f'{base_dir}/poison_ims.npy')
        jnp.save(f'{base_dir}/poison_ims.npy', base_poison_X)
        jnp.save(f'{base_dir}/poison_labs.npy', base_poison_Y)
        jnp.save(f'{base_dir}/losses.npy', all_losses)
        with open(f'{base_dir}/im_optimizer_state.pkl', 'wb') as f:
            pickle.dump(im_optimizer_state, f)
        with open(f'{base_dir}/lab_optimizer_state.pkl', 'wb') as f:
            pickle.dump(lab_optimizer_state, f)

        # Keep a numbered snapshot every 10 iterations.
        if local_it % 10 == 0:
            jnp.save(f'{base_dir}/poison_ims-{local_it}.npy', base_poison_X)
            jnp.save(f'{base_dir}/poison_labs-{local_it}.npy', base_poison_Y)
            jnp.save(f'{base_dir}/losses-{local_it}.npy', all_losses)

        sleep(10)

    return all_losses

def dict_product(d):
    """Yield one dict per element of the Cartesian product of ``d``'s values.

    ``d`` maps each key to an iterable of candidate values; every yielded
    dict has the same keys as ``d`` with one candidate picked per key.
    """
    names = d.keys()
    yield from (dict(zip(names, combo)) for combo in product(*d.values()))

if __name__ == '__main__':
    # Example run configuration. Key order is kept as-is because the run
    # directory name is derived from a hash of the JSON dump of these dicts.
    poison_args = {
        'im_lr': 0.5,
        'lab_lr': 0.03,
        'truncate': 0,
        'drop_freq': 150,
        'backtrack': False,
        'vary_seed': True,
        # 'alpha': 0.1,
        # 'beta': 0.5,
        'val_set_size': 0,
        'minibatch_size': 1_000,
        'num_poison': 1000,
        'targeted': False,
        'optimizer': 'lion',
        'im_b1': 0.25,
        'lab_b1': 0.25,
        'im_wd': 1e-3,
        'lab_wd': 1e-3,
        'freeze_labels': False,
    }
    training_overrides = {'maxits': 10000}

    run_optimization(
        poison_args,
        training_overrides,
        logdir='insert_your_logdir_here',
        datapath='insert_your_data_path_here',
        antipoison=False,
        resume=True,
    )