import jax.numpy as jnp

from .. import load_ae, load_vqvae
from .teco import TECOShard
from .diffusion.latent_fdm import LatentFDMShard
from .perceiver_ar import PerceiverARShard
from .debug.teco import TECO
from .debug.diffusion.latent_fdm import LatentFDM
from .debug.perceiver_ar import PerceiverAR


def get_model(config, need_encode=None, **kwargs):
    """Build the model for ``config.model`` together with its ``*Shard`` variant.

    Args:
        config: run configuration. Reads ``model``, ``data_path``, optionally
            ``half_precision``, and ``vqvae_ckpt`` (teco / perceiver_ar) or
            ``ae_ckpt`` (latent_fdm).
        need_encode: forwarded to ``load_vqvae`` / ``load_ae``. When ``None``
            it is inferred as ``True`` unless ``'encoded'`` appears in
            ``config.data_path``.
        **kwargs: extra keyword arguments passed to both model constructors.
            ``training`` is discarded unless ``config.model == 'vqgan'``;
            ``dtype`` is always overwritten from ``config.half_precision``.

    Returns:
        ``(model, model_shard)``, where ``model`` comes from the ``.debug``
        package and ``model_shard`` is the matching ``*Shard`` class.

    Raises:
        ValueError: if ``config.model`` is not ``'teco'``, ``'latent_fdm'``
            or ``'perceiver_ar'``.
    """
    supported = ('teco', 'latent_fdm', 'perceiver_ar')
    # Validate first so an unknown model fails fast with a clear message
    # instead of loading a VQ-VAE/AE checkpoint and then hitting an
    # UnboundLocalError on the return line.
    if config.model not in supported:
        raise ValueError(f'Model {config.model} not supported')

    if config.model != 'vqgan':
        # presumably only a vqgan constructor accepts `training`; drop it otherwise
        kwargs.pop('training', None)

    if need_encode is None:
        need_encode = 'encoded' not in config.data_path

    if config.model in ('teco', 'perceiver_ar'):
        vq_fns, vqvae = load_vqvae(config.vqvae_ckpt, need_encode)
        kwargs.update(vq_fns=vq_fns, vqvae=vqvae)
    elif config.model == 'latent_fdm':
        ae_fns, ae = load_ae(config.ae_ckpt, need_encode)
        kwargs.update(ae_fns=ae_fns, ae=ae)

    try:
        half_precision = config.half_precision
    except AttributeError:
        # Older configs may predate the half_precision option.
        print('Did not find half_precision, defaulting to FP32 for', config.model)
        half_precision = False
    kwargs['dtype'] = jnp.bfloat16 if half_precision else jnp.float32

    if config.model == 'teco':
        model_shard = TECOShard(config, **kwargs)
        model = TECO(config, **kwargs)
    elif config.model == 'latent_fdm':
        model_shard = LatentFDMShard(config, **kwargs)
        model = LatentFDM(config, **kwargs)
    else:  # 'perceiver_ar' (guaranteed by the validation above)
        model_shard = PerceiverARShard(config, **kwargs)
        model = PerceiverAR(config, **kwargs)

    return model, model_shard


def load_ckpt(ckpt_path, training=False, replicate=True, return_config=False,
              default_if_none=None, need_encode=None, **kwargs):
    """Restore a model and its ``TrainState`` from a run directory.

    The directory is expected to contain a pickled config at ``args`` and a
    Flax checkpoint directory at ``checkpoints``.

    Args:
        ckpt_path: run directory.
        training: forwarded to ``get_model`` (which discards it unless
            ``config.model == 'vqgan'``).
        replicate: if True, pass the state through ``jax_utils.replicate``.
        return_config: if True, also return the (possibly modified) config.
        default_if_none: mapping of config attribute -> value, applied only
            for attributes the loaded config does not have (``hasattr``).
        need_encode: forwarded to ``get_model``.
        **kwargs: config attributes to override unconditionally.

    Returns:
        ``(model, state)``, or ``(model, state, config)`` if ``return_config``.

    Raises:
        FileNotFoundError: if no checkpoint could be restored.
    """
    import os.path as osp
    import pickle
    from flax import jax_utils
    from flax.training import checkpoints
    from .train_utils import TrainState

    if default_if_none is None:
        default_if_none = {}

    # NOTE(security): pickle.load can execute arbitrary code; only load run
    # directories from trusted sources.
    with open(osp.join(ckpt_path, 'args'), 'rb') as f:
        config = pickle.load(f)
    for k, v in kwargs.items():
        setattr(config, k, v)
    for k, v in default_if_none.items():
        if not hasattr(config, k):
            print('did not find', k, 'setting default to', v)
            setattr(config, k, v)

    model, _ = get_model(config, training=training, need_encode=need_encode)

    # With target=None, restore_checkpoint hands back that None target when
    # no checkpoint file exists, so check before indexing into the result.
    restored = checkpoints.restore_checkpoint(osp.join(ckpt_path, 'checkpoints'), None)
    if restored is None:
        raise FileNotFoundError(f'No checkpoint found in {ckpt_path}')

    state = TrainState(
        step=restored['step'],
        params=restored['params'],
        opt_state=restored['opt_state'],
        model_state=restored['model_state'],
        apply_fn=model.apply,
        tx=None
    )

    if replicate:
        state = jax_utils.replicate(state)

    if return_config:
        return model, state, config
    return model, state


def get_sample(config):
    """Return the ``sample`` function matching ``config.model``.

    Any model name containing ``'teco'`` resolves to ``.teco.sample``;
    ``'latent_fdm'`` and ``'perceiver_ar'`` must match exactly.

    Raises:
        Exception: if ``config.model`` matches none of the above.
    """
    name = config.model
    if 'teco' in name:
        from .teco import sample as sample_fn
        return sample_fn
    if name == 'latent_fdm':
        from .diffusion.latent_fdm import sample as sample_fn
        return sample_fn
    if name == 'perceiver_ar':
        from .perceiver_ar import sample as sample_fn
        return sample_fn
    raise Exception(f'Model {name} not supported')
