from collections import namedtuple
from tqdm import trange
import numpy as np
from absl.app import run

import jax
import jax.numpy as jnp

from elastic.data import DatasetFactory
from tux import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    get_float_dtype_by_name,
    set_random_seed,
    make_shard_and_gather_fns, define_flags_with_default,
    StreamingCheckpointer
)
from elastic.model import ElasticTokConfig, ElasticTok
from elastic.inference import ElasticInference


# Command-line flags for this evaluation script; each keyword is a flag name
# and its default value. The `get_default_config()` entries supply default
# config objects for the dataset, checkpointer, model and jax.distributed
# setup (presumably exposed as nested flags -- confirm in
# tux.define_flags_with_default).
FLAGS, FLAGS_DEF = define_flags_with_default(
    # Search algorithm name forwarded to ElasticInference.
    search_alg='binary',
    # Reconstruction-loss target. Inference searches at 0.95 * threshold; a
    # result counts as meeting the threshold when recon_loss <= threshold.
    threshold=0.003,
    # Forwarded to ElasticInference.inference (semantics defined there).
    max_prop_codes=1.0,
    default_prop_codes=1.0,
    # Passed to set_random_seed.
    seed=42,
    # Device-mesh shape string; also written into the model config.
    mesh_dim='1,-1,1,1',
    # Float dtype name for the model (resolved by get_float_dtype_by_name).
    dtype='fp32',
    # Maximum number of dataset batches to evaluate.
    eval_steps=128,
    # Path of a saved ElasticTokConfig; empty means build it from --elastic_tok.
    load_elastic_config='',
    # Python expression eval()'d and converted to a dict of config overrides.
    update_elastic_config='',
    # Checkpoint spec passed to StreamingCheckpointer.load_trainstate_checkpoint.
    load_checkpoint='',
    train_dataset=DatasetFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    elastic_tok=ElasticTokConfig.get_default_config(),
    jax_distributed=JaxDistributedConfig.get_default_config(),
)


# Minimal train-state container holding only the model parameters. Its
# structure is what main() shape-evaluates, partitions and loads from the
# checkpoint.
State = namedtuple('State', 'params')


def _load_elastic_config():
    """Build the ElasticTok model config from the command-line flags.

    If ``--load_elastic_config`` names a saved config, that config is loaded
    and only the remat/scan fields listed below are overridden from
    ``--elastic_tok``. Otherwise the config is built entirely from
    ``--elastic_tok``. ``--update_elastic_config`` (if non-empty) is applied
    afterwards, and ``mesh_dim`` is always set from ``--mesh_dim``.

    Returns:
        The resulting ``ElasticTokConfig``.
    """
    if FLAGS.load_elastic_config != '':
        elastic_config = ElasticTokConfig.load_config(FLAGS.load_elastic_config)
        updates = ElasticTokConfig(**FLAGS.elastic_tok)
        # Every field not listed here keeps the value stored in the loaded file.
        overridable = (
            'remat_block', 'remat_attention', 'remat_mlp',
            'scan_attention', 'scan_mlp',
            'scan_query_chunk_size', 'scan_key_chunk_size', 'scan_mlp_chunk_size',
            'scan_layers', 'param_scan_axis',
        )
        elastic_config.update({name: getattr(updates, name) for name in overridable})
    else:
        elastic_config = ElasticTokConfig(**FLAGS.elastic_tok)
    if FLAGS.update_elastic_config != '':
        # NOTE(review): eval() executes arbitrary Python taken from the flag
        # value. That is tolerable only because flags come from the operator.
        # Prefer ast.literal_eval if overrides are always plain literals.
        elastic_config.update(dict(eval(FLAGS.update_elastic_config)))
    elastic_config.update(dict(mesh_dim=FLAGS.mesh_dim))
    return elastic_config


def _flatten_all(values):
    """Concatenate a list of scalars/arrays into a single flat 1-D array.

    Used instead of ``np.mean(list_of_arrays)``, which raises when per-batch
    results differ in size. Here every element is weighted equally.
    """
    return np.concatenate([np.ravel(v) for v in values])


def _print_running_metrics(recon_losses, meets_threshold, ntoks, max_toks):
    """Print metrics averaged over every batch evaluated so far.

    Args:
        recon_losses: per-batch reconstruction losses returned by inference.
        meets_threshold: per-batch results of ``recon_loss <= --threshold``.
        ntoks: per-batch final token counts. Each entry must be at least 1-D
            because the entries are concatenated along axis 0.
        max_toks: ``elastic_config.max_toks``, used to report the token
            count as a proportion.
    """
    all_ntoks = np.concatenate(ntoks)
    print(f'Recon Loss: {np.mean(_flatten_all(recon_losses))}')
    print(f'Prop Meets Threshold: {np.mean(_flatten_all(meets_threshold))}')
    print(f'Average Tokens Absolute: {all_ntoks.mean()}')
    print(f'Average Tokens Proportion: {(all_ntoks / max_toks).mean()}')
    # Mean over axis 0 (all examples so far), keeping any trailing axes.
    print(all_ntoks.mean(0))


def main(argv):
    """Evaluate an ElasticTok checkpoint against a reconstruction-loss threshold.

    Builds the model and its parameter sharding and loads
    ``--load_checkpoint``. For up to ``--eval_steps`` dataset batches it then
    runs ``ElasticInference`` and prints running averages of:
    - the reconstruction loss;
    - the fraction of results with loss <= ``--threshold``;
    - the final token count, both absolute and as a proportion of
      ``max_toks``.

    Args:
        argv: Positional command-line arguments from absl. Unused.
    """
    del argv  # Unused; absl.app.run always passes it.
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    print(f'Started {jax.process_index()} / {jax.process_count()}')
    set_random_seed(FLAGS.seed)

    elastic_config = _load_elastic_config()

    mesh = ElasticTokConfig.get_jax_mesh(FLAGS.mesh_dim)
    node_info = ElasticTokConfig.get_ranks_and_size(mesh)

    dataset = DatasetFactory.load_dataset(
        FLAGS.train_dataset, node_info=node_info, mesh=mesh,
        elastic_config=elastic_config
    )

    model = ElasticTok(
        elastic_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )

    def init_fn(rng):
        # Initialise parameters from dummy inputs. This is only called through
        # jax.eval_shape below, so only the shapes/dtypes of the inputs matter.
        rng_generator = JaxRNG(rng)
        # Batch size = product of the 'dp' and 'fsdp' mesh axis sizes.
        batch = mesh.shape['dp'] * mesh.shape['fsdp']
        seq_len = elastic_config.max_sequence_length
        # Last dim is prod(patch_size) * 3 (presumably 3 colour channels).
        patch_dim = np.prod(elastic_config.patch_size) * 3
        params = model.init(
            vision=jnp.zeros((batch, seq_len, patch_dim), dtype=jnp.int32),
            encoding_mask=jnp.ones((batch, seq_len), dtype=bool),
            attention_mask=jnp.ones((batch, seq_len), dtype=bool),
            segment_ids=jnp.zeros((batch, seq_len), dtype=jnp.int32),
            position_ids=jnp.zeros((batch, seq_len), dtype=jnp.int32),
            rngs=rng_generator(elastic_config.rng_keys()),
        )
        return State(params)

    param_shapes = jax.eval_shape(init_fn, next_rng())
    param_partition = match_partition_rules(
        ElasticTokConfig.get_partition_rules(elastic_config.scan_layers, elastic_config.param_scan_axis),
        param_shapes
    )

    shard_fns, _ = make_shard_and_gather_fns(
        param_partition, param_shapes
    )
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, None,
        enable=jax.process_index() == 0,
    )

    inference = ElasticInference(
        model, elastic_config, dataset.config, mesh, node_info,
        param_partition.params, FLAGS.search_alg
    )

    # The search targets a slightly stricter loss (95% of --threshold),
    # presumably to leave headroom. Success is still judged against the
    # unscaled --threshold below.
    search_threshold = FLAGS.threshold * 0.95
    max_toks = elastic_config.max_toks

    print(f"Threshold: {FLAGS.threshold}")
    with mesh:
        _, params = checkpointer.load_trainstate_checkpoint(
            FLAGS.load_checkpoint, param_shapes, shard_fns
        )

        recon_losses, meets_threshold, ntoks = [], [], []
        # zip pulls from the progress bar first, so at most eval_steps batches
        # are drawn from the dataset.
        for _, batch in zip(trange(FLAGS.eval_steps, ncols=0), dataset):
            _, recon_loss, final_ntoks = inference.inference(
                params, batch, search_threshold,
                FLAGS.default_prop_codes, FLAGS.max_prop_codes,
            )
            recon_loss, final_ntoks = jax.device_get((recon_loss, final_ntoks))
            recon_losses.append(recon_loss)
            meets_threshold.append(recon_loss <= FLAGS.threshold)
            ntoks.append(final_ntoks)
            _print_running_metrics(recon_losses, meets_threshold, ntoks, max_toks)


if __name__ == "__main__":
    # Hand control to absl, which parses the command-line flags and then
    # calls main(argv).
    run(main)
