from .okazaki_train import okazaki_vjp
import pandas as pd
from tqdm import tqdm
import random
import numpy as np
from scipy.stats import spearmanr
import os
import dill as pickle
from itertools import product
from jax.tree_util import tree_flatten
from .optimizers import TrainState
import inspect
import jax.numpy as jnp
import jax
from pathlib import Path
from functools import partial
import hashlib
import json
from operator import getitem
from .utils import set_dtype
from typing import Any
import optax
from .optimizers import AdamOptimizer
from .sgd import SGDOptimizer
from uuid import uuid4
from .utils import schedule_to_lr

@partial(jnp.vectorize, signature='(n),()->()')
def correct_logit(logits, label):
    """Return the logit stored at position ``label``.

    The core signature is ``(n),()->()``: one vector of logits plus one scalar
    label yields one scalar. ``jnp.vectorize`` maps that core over any leading
    dimensions of the inputs.
    """
    label_logit = logits[label]
    return label_logit

@partial(jnp.vectorize, signature='(n),()->()')
def highest_incorrect_logit(logits, label):
    """Return the largest logit among all positions other than ``label``.

    The entry at ``label`` is replaced by ``-inf`` before taking the maximum,
    so it can never win. The core signature is ``(n),()->()`` and
    ``jnp.vectorize`` maps it over any leading dimensions.
    """
    positions = jnp.arange(logits.shape[0])
    is_label = positions == label
    masked = jnp.where(is_label, -jnp.inf, logits)
    return jnp.max(masked)

@partial(jax.jit, static_argnums=(1,))
def custom_lr_schedule(step, kwargs):
    """Evaluate an optax linear one-cycle schedule at ``step``.

    ``kwargs`` is unpacked as the keyword arguments of
    ``optax.linear_onecycle_schedule``. It is declared static for ``jax.jit``
    (argument index 1), so it must be a hashable mapping; a plain ``dict``
    will be rejected by jit.
    """
    return optax.linear_onecycle_schedule(**kwargs)(step)

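# batch: idxs, ((ims_good, ims_poison, augmenter), (labs_good, labs_poison));
# idxs may itself be an (idxs, inverse_padding) tuple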
def cifar_per_sample_loss(params, batch_stats, batch, train, model, bs, leave_out_mask=None,
                          total_num_poisons=0, aug_idxs=None):
    """Per-sample softmax cross-entropy, each entry divided by ``bs``.

    Args:
        params, batch_stats: passed to ``model`` as
            ``{'params': params, 'batch_stats': batch_stats}``.
        batch: ``(idxs, ((ims_good, ims_poison, augmenter),
            (labs_good, labs_poison)))``. ``idxs`` may be a tuple
            ``(idxs, inverse_padding)``.
        train: selects the training path (mutable batch stats, poison images
            included) or the evaluation path (good images only).
        model: callable invoked as ``model(variables, inputs, ...)``.
        bs: divisor applied to every per-sample loss.
        leave_out_mask: optional multiplier applied to the losses; only used
            on the training path.
        total_num_poisons: accepted but not used here.
        aug_idxs: only validated (must not be combined with poison images);
            augmentation indices are always ``arange`` over the batch.

    Returns:
        ``(losses / bs, updates)`` when ``train`` is truthy, otherwise
        ``losses / bs``.
    """
    idxs, (images, labels) = batch
    ims_good, ims_poison, augmenter = images
    labs_good, labs_poison = labels

    inverse_padding = None
    if isinstance(idxs, tuple):
        _, inverse_padding = idxs

    if aug_idxs is not None:
        assert ims_poison.shape[0] == 0, "Poison incompatible with augmentation"

    if augmenter is not None:
        ims_good = augmenter(X=ims_good, inds=jnp.arange(ims_good.shape[0]))
        if len(ims_poison) > 0:
            # Poison indices continue right after the (augmented) good images.
            poison_inds = ims_good.shape[0] + jnp.arange(ims_poison.shape[0])
            ims_poison = augmenter(X=ims_poison, inds=poison_inds)

    # Labels are stored unnormalised; softmax turns them into target distributions.
    if len(labs_poison) > 0:
        labs = jnp.concatenate([labs_good, labs_poison])
    else:
        labs = labs_good
    labs = jax.nn.softmax(labs)
    if inverse_padding is not None:
        labs = labs[inverse_padding]

    variables = {'params': params, 'batch_stats': batch_stats}

    if not train:
        eval_logits = model(variables, (ims_good,), train=False)
        return optax.softmax_cross_entropy(eval_logits, labs) / bs

    if len(ims_poison) == 0:
        model_inputs, extra_kw = (ims_good,), {}
    else:
        model_inputs = (ims_good, ims_poison)
        extra_kw = {'inverse_padder': inverse_padding}
    logits, updates = model(variables, model_inputs, mutable=['batch_stats'],
                            train=True, **extra_kw)

    losses = optax.softmax_cross_entropy(logits, labs)
    if leave_out_mask is not None:
        losses = losses * leave_out_mask
    return losses / bs, updates

def cifar_corrects(params, batch_stats, batch, model, bs, train):
    """Per-sample correctness indicators on the good images, divided by ``bs``.

    Args:
        params, batch_stats: passed to ``model`` as
            ``{'params': params, 'batch_stats': batch_stats}``.
        batch: ``(idxs, ((ims_good, ims_poison, augmenter),
            (labs_good, labs_poison)))``. Only ``ims_good``, ``augmenter`` and
            ``labs_good`` are read; the poison slots are ignored.
        model: callable invoked as ``model(variables, (ims_good,), train=False)``.
        bs: divisor applied to every indicator.
        train: must be ``False``; this function is evaluation-only.

    Returns:
        Array with ``1 / bs`` where the argmax of the logits matches the
        argmax of ``labs_good`` and ``0`` elsewhere.

    Raises:
        AssertionError: if ``train`` is not ``False``.
    """
    assert train == False, "cifar_corrects only supports evaluation (train=False)"
    _, ((ims_good, _ims_poison, augmenter), (labs_good, _labs_poison)) = batch
    if augmenter is not None:
        # NOTE(review): called positionally here, whereas cifar_per_sample_loss
        # uses augmenter(X=..., inds=...) -- confirm the augmenter accepts both.
        ims_good = augmenter(ims_good)
    # Poison images and the combined soft labels were previously computed here
    # but never used; only the good images and their labels affect the result.
    variables = {'params': params, 'batch_stats': batch_stats}
    logits = model(variables, (ims_good,), train=False)
    corrects = logits.argmax(axis=-1) == labs_good.argmax(axis=-1)
    return corrects / bs

# hash str(x) to an MD5 hex-digest string
def str_hash(x):
    """Return the MD5 hex digest of ``str(x)``.

    Used as a deterministic identifier, not for any security purpose.
    """
    digest = hashlib.md5()
    digest.update(str(x).encode())
    return digest.hexdigest()

def clean_jaxdict(d):
    """Return a new dict with the same keys and every value passed through ``str``."""
    stringified = {}
    for key, value in d.items():
        stringified[key] = str(value)
    return stringified

# Module-level slot holding the one-cycle LR schedule created by make_optimizer.
# (A `global` statement is a no-op at module scope, so plain assignment suffices.)
my_lr_schedule = None

# Module-level slot reserved for a momentum schedule.
my_momentum_schedule = None

def make_optimizer(initial_params, lr, wd, pct_start, pct_final,
                   num_iters, b1, b2, min_lr_relative, final_min_lr_relative,
                   eps, eps_sqrt, kahan, selective_wd, dtype, optimizer,
                   exclude_bn, nesterov=None, tanh_factor=None, 
                   factored_lr_wd=False, schedule_momentum=False,
                   momentum_pct_start=None, momentum_pct_final=None,
                   momentum_min_relative=None, momentum_min_relative_final=None,
                   custom_lr_schedule=None):
    """Build the optimizer and wrap it, with the initial variables, in a TrainState.

    Args:
        initial_params: mapping unpacked into ``TrainState.create``; its
            ``'params'`` entry is also used to build the BatchNorm exclusion tree.
        lr: peak learning rate (also the ``max_lr`` / ``peak_lr`` of the optimizer).
        wd, b1, b2, eps, eps_sqrt, selective_wd, factored_lr_wd: forwarded to
            ``AdamOptimizer``; ``wd`` and ``b1`` (as momentum) are forwarded to
            ``SGDOptimizer``.
        pct_start, pct_final, num_iters, min_lr_relative, final_min_lr_relative:
            arguments of ``optax.linear_onecycle_schedule``.
        kahan: forwarded to ``TrainState.create``.
        optimizer: ``'adam'`` selects ``AdamOptimizer``; anything else selects
            ``SGDOptimizer``.
        exclude_bn: when truthy, leaves whose path contains ``'BatchNorm'`` are
            flagged in ``should_exclude_tree`` (SGD only).
        nesterov: required (non-None) when ``optimizer == 'sgd'``.
        schedule_momentum: forwarded to ``SGDOptimizer``.
        custom_lr_schedule: optional sequence; when given, the learning rate is
            ``schedule_to_lr`` bound to it instead of the one-cycle schedule.
            ``num_iters`` must be divisible by ``len(custom_lr_schedule) - 1``.
        dtype, tanh_factor, momentum_pct_start, momentum_pct_final,
        momentum_min_relative, momentum_min_relative_final: accepted for
            interface compatibility; not used in this function.

    Returns:
        The result of ``TrainState.create(optimizer=..., kahan=kahan, **initial_params)``.
    """
    if optimizer == 'sgd':
        assert nesterov is not None

    global my_lr_schedule
    # The one-cycle schedule is cached at module level so repeated calls with
    # the same hyperparameters reuse the same schedule object. The cache is
    # keyed on those hyperparameters: previously it was never invalidated, so a
    # later call with a different lr / num_iters silently trained with the
    # first call's schedule.
    schedule_key = (num_iters, lr, pct_start, pct_final,
                    min_lr_relative, final_min_lr_relative)
    cached_key = getattr(make_optimizer, '_lr_schedule_key', None)
    # A schedule with no recorded key was installed from outside; leave it alone.
    is_stale = cached_key is not None and cached_key != schedule_key
    if my_lr_schedule is None or is_stale:
        print('Instantiating LR schedule')
        my_lr_schedule = optax.linear_onecycle_schedule(num_iters, 
                                                lr, 
                                                pct_start,
                                                pct_final,
                                                div_factor=min_lr_relative,
                                                final_div_factor=final_min_lr_relative)
        make_optimizer._lr_schedule_key = schedule_key

    if custom_lr_schedule is not None:
        assert num_iters % (len(custom_lr_schedule) - 1) == 0, \
            f'num_iters {num_iters} must be divisible by {len(custom_lr_schedule) - 1}'
        # Partial keeps the schedule values as pytree leaves of the callable.
        lr_schedule = jax.tree_util.Partial(schedule_to_lr, 
                                            epoch_schedule=custom_lr_schedule, 
                                            total_iters=num_iters, 
                                            log_param=False)
    else:
        lr_schedule = my_lr_schedule

    # Boolean tree mirroring the params: truthy for BatchNorm leaves when
    # exclude_bn is set, falsy everywhere otherwise.
    should_exclude_tree = jax.tree_util.tree_map_with_path(
        lambda path, _: exclude_bn and ('BatchNorm' in str(path)),
        initial_params['params'])

    if optimizer == 'adam':
        tx = AdamOptimizer(lr_schedule, wd, b1, b2, eps, eps_sqrt,
                           selective_wd, factored_lr_wd, max_lr=lr)
    else:
        tx = SGDOptimizer(lr=lr_schedule, 
                          wd=wd, momentum=b1, 
                          peak_lr=lr, 
                          should_exclude_tree=should_exclude_tree,
                          nesterov=nesterov,
                          schedule_momentum=schedule_momentum)
    return TrainState.create(optimizer=tx, kahan=kahan, **initial_params)

def okazaki_cifar(*, model, bs, seed, lr, wd, epochs, loaders, params, test_index, 
                  model_cotangenter, batch_cotangenter, per_sample_loss, 
                  exclude_bn=False, val_per_sample_loss=None, pct_start=0.5, 
               pct_final=1, min_lr_relative=10_000, final_min_lr_relative=10,
               b1=0.9, b2=0.999, eps=1e-5, eps_sqrt=1e-5, maxits=None,
               selective_wd=False, dtype=jnp.float32, eval_every=4000,
               drop_cfx=[], just_train=False, optimal_k_factor=1.0,
               optimizer='adam', save_model=False, 
               cache_dir=None, save_dir=None, stem='auto',
               leave_out_inds=None, leave_out_weight=None,
               nesterov=None, schedule_momentum=False,
               momentum_pct_start=0., momentum_pct_final=1.,
               momentum_min_relative=1., momentum_min_relative_final=1.,
               custom_lr_schedule=None,
               return_final_state=False):
    """Build the optimizer state and run configuration, then hand off to ``_okazaki_cifar``.

    All arguments are keyword-only. Every argument is stringified and recorded
    (and, with ``stem='auto'``, hashed into the run directory name), including
    ones that are otherwise unused here such as ``seed``, ``drop_cfx`` and
    ``save_model``.

    Notable arguments:
        loaders: ``(get_train_batch, get_val_batch)``; each must expose
            ``dataset_size``.
        epochs, bs, maxits: the number of training iterations is
            ``epochs * (train dataset_size // bs)`` and validation iterations is
            ``val dataset_size // bs``; both are capped at ``maxits`` when given.
        val_per_sample_loss: currently ignored; validation always uses
            ``cifar_corrects`` bound to ``model`` and ``bs``.
        cache_dir, save_dir: required (asserted non-None).
        custom_lr_schedule: forwarded to ``make_optimizer``; also disables
            ``should_wash_state`` when provided.

    Returns:
        Whatever ``_okazaki_cifar`` returns.

    Raises:
        AssertionError: if ``cache_dir`` or ``save_dir`` is None.
    """
    # NOTE: drop_cfx keeps its original `[]` default on purpose. It is never
    # mutated here, and its string form feeds the run hash, so changing the
    # default would change every auto-generated stem.
    assert cache_dir is not None
    assert save_dir is not None

    # Snapshot of the call arguments. locals() returns a plain dict containing
    # only the parameters at this point. frame.f_locals is a live write-through
    # proxy on Python 3.13+ (PEP 667) and would pick up `curr_kw` itself,
    # changing the recorded args and the run hash.
    curr_kw = clean_jaxdict(locals())

    set_dtype('tf32', True)

    get_train_batch, get_val_batch = loaders
    n_train_ba = epochs * (get_train_batch.dataset_size // bs)
    n_val_ba = 1 * (get_val_batch.dataset_size // bs)

    if maxits is not None:
        n_train_ba = min(maxits, n_train_ba)
        n_val_ba = min(maxits, n_val_ba)

    total_acc_fn = jax.tree_util.Partial(cifar_corrects, model=model, bs=bs)

    opt_kw = {
        'initial_params': params,
        'lr': lr,
        'wd': wd,
        'pct_start': pct_start,
        'pct_final': pct_final,
        'num_iters': n_train_ba,
        'b1': b1,
        'b2': b2,
        'min_lr_relative': min_lr_relative,
        'final_min_lr_relative': final_min_lr_relative,
        'eps': eps,
        'eps_sqrt': eps_sqrt,
        'kahan': False,
        'selective_wd': selective_wd,
        'dtype': dtype,
        'optimizer': optimizer,
        'exclude_bn': exclude_bn,
        'nesterov': nesterov,
        'schedule_momentum': schedule_momentum,
        'momentum_pct_start': momentum_pct_start,
        'momentum_pct_final': momentum_pct_final,
        'momentum_min_relative': momentum_min_relative,
        'momentum_min_relative_final': momentum_min_relative_final,
        'custom_lr_schedule': custom_lr_schedule
    }

    state = make_optimizer(**opt_kw)

    common_kw = {
        'state': state,
        'train_batch_maker': get_train_batch,
        'val_batch_maker': get_val_batch,
        'train_its': n_train_ba,
        'val_its': n_val_ba,
        'per_sample_loss': per_sample_loss,
        # The val_per_sample_loss argument is not used; accuracy is reported instead.
        'val_per_sample_loss': total_acc_fn,
        'eval_every': eval_every,
        'minibs': None,
        # Merged over _okazaki_cifar's own should_wash_state default downstream.
        'should_wash_state': (custom_lr_schedule is None)
    }

    return _okazaki_cifar(curr_kw, common_kw, test_index, bs, get_val_batch, 
                          model_cotangenter, batch_cotangenter,
                        per_sample_loss, optimal_k_factor, cache_dir, save_dir, stem, 
                        just_train=just_train, leave_out_inds=leave_out_inds,
                        leave_out_weight=leave_out_weight,
                        return_final_state=return_final_state)

def _okazaki_cifar(orig_kw, common_kw, test_index, bs, get_val_batch, 
                   model_cotangenter, batch_cotangenter,
                   per_sample_loss, optimal_k_factor, cache_dir, save_dir, stem, 
                   just_train=False, leave_out_inds=None, leave_out_weight=None,
                   return_final_state=False, should_wash_state=True):
    """Prepare the run directories, persist/verify the run arguments, and call ``okazaki_vjp``.

    Args:
        orig_kw: JSON-serialisable dict of the original call arguments; written
            to ``<save_dir>/<stem>/args.json`` on first use and compared against
            that file on later runs.
        common_kw: extra keyword arguments for ``okazaki_vjp``; its entries
            override the ones assembled here on key collisions (including
            ``should_wash_state``).
        cache_dir, save_dir: base directories; a ``stem`` subdirectory is
            created under each.
        stem: run name. ``'auto'`` or ``None`` derives it from the hash of
            ``orig_kw``.
        just_train: forwarded as ``exit_after_forward``.
        leave_out_inds, leave_out_weight, model_cotangenter, batch_cotangenter,
        optimal_k_factor, return_final_state, should_wash_state: forwarded to
            ``okazaki_vjp``.
        test_index, bs, get_val_batch, per_sample_loss: accepted but not used
            directly in this function.

    Returns:
        The result of ``okazaki_vjp``.

    Raises:
        AssertionError: if an existing ``args.json`` does not match ``orig_kw``.
    """
    auto_stem = stem == 'auto' or stem is None
    if auto_stem:
        stem = str_hash(orig_kw)

    cache_dir = Path(cache_dir) / stem
    cache_dir.mkdir(parents=True, exist_ok=True)

    save_dir = Path(save_dir) / stem
    save_dir.mkdir(parents=True, exist_ok=True)

    args_path = save_dir / 'args.json'

    # print path info
    print('>> Caching in:', cache_dir)
    print('>> Saving in:', save_dir)

    # Record the arguments on the first run; on later runs make sure the
    # directory is not being reused with different arguments.
    if not args_path.exists():
        with open(args_path, 'w') as f:
            json.dump(orig_kw, f)
    else:
        with open(args_path, 'r') as f:
            old = json.load(f)

        assert old == orig_kw, (old, orig_kw)
        # The directory name equals the argument hash only when the stem was
        # auto-derived. With a user-supplied stem this check could never hold,
        # which made every rerun with an explicit stem fail.
        if auto_stem:
            assert str_hash(old) == stem, (old, orig_kw)

    okazaki_kw = {
        'model_cotangenter': model_cotangenter,
        'batch_cotangenter': batch_cotangenter,
        'optimal_k_factor': optimal_k_factor,
        'cache_dir': cache_dir,
        'save_dir': save_dir,
        'return_final_state': return_final_state,
        'should_wash_state': should_wash_state
    } | common_kw  # right-hand side wins on duplicate keys

    vjp = okazaki_vjp(exit_after_forward=just_train, 
                      leave_out_inds=leave_out_inds, 
                      leave_out_weight=leave_out_weight,
                      **okazaki_kw)
    return vjp