import jax
import json
import optax
import model
import jax.numpy as jnp
from flax import nnx
from snapshot import Snapshot
from datetime import datetime
from exp_utils import TreeGenerator, TreeTokenizer
import mlflow
import os


def inference_reasoner(instance, token, start, end):
    """Autoregressively roll out reasoner latents for positions ``start .. end-1``.

    The prompt tokens are embedded, the sequence is right-padded with zero latents
    up to length ``end``, and at every step ``idx`` the reasoner's output at
    position ``idx - 1`` is written into slot ``idx`` (next-latent prediction).

    Args:
        instance: Reasoner exposing ``embed(token)`` and
            ``reason(latent, mask=..., kv_len=..., q_len=...)``.
        token: Prompt token ids with batch on axis 0; presumably ``(batch, start)``
            since the training step passes ``data[:, :start]`` — confirm for other callers.
        start: First position to generate (a static Python int under ``jit``).
        end: One past the last position to generate (static Python int).

    Returns:
        The generated slice ``latent[:, start:end]``, i.e. ``end - start`` positions
        along axis 1.
    """
    prompt = instance.embed(token)
    batch = prompt.shape[0]
    padding = jnp.zeros(shape=(batch, end - start, *prompt.shape[2:]))
    # NOTE(review): jnp.zeros uses the default float dtype, so if the embedding is lower
    # precision (e.g. bfloat16) the concatenate promotes the whole rollout buffer to the
    # wider dtype — confirm this is intended.
    sequence = jnp.concatenate([prompt, padding], axis=1)

    def step(idx, seq):
        # Only the first `idx` positions are treated as valid keys/queries at this step.
        lengths = jnp.full((batch,), idx)
        predicted = instance.reason(seq, mask=None, kv_len=lengths, q_len=lengths)
        return seq.at[:, idx].set(predicted[:, idx - 1])

    generated = jax.lax.fori_loop(start, end, step, sequence)
    return generated[:, start:end]


def loss_fn(instance, reasoner, data, target, mask, start, end):
    """Masked cross-entropy of the talker decoding the reasoner's rolled-out latents.

    The reasoner generates latents for ``[start, end)`` from ``data``; those latents
    are wrapped in ``stop_gradient`` before being decoded by ``instance``, and the
    logits are scored against ``target[:, start:end]``.

    Args:
        instance: Talker model mapping latents to logits.
        reasoner: Reasoner used by ``inference_reasoner``.
        data: Prompt tokens fed to the reasoner.
        target: Integer token labels; only the ``[start, end)`` window is used.
        mask: Per-token weights aligned with ``target``; presumably 0/1 — confirm.
        start: First scored position (static int).
        end: One past the last scored position (static int).

    Returns:
        Scalar mean of the weighted per-token loss.
    """
    window = slice(start, end)
    frozen_latent = jax.lax.stop_gradient(inference_reasoner(reasoner, data, start, end))
    per_token = optax.softmax_cross_entropy_with_integer_labels(
        logits=instance(frozen_latent),
        labels=target[:, window],
    )
    # NOTE(review): the mean runs over every position in the window, masked ones included,
    # so (assuming a 0/1 mask) masked-out tokens dilute the average instead of being excluded.
    return jnp.mean(per_token * mask[:, window])


@nnx.jit(static_argnames=['start', 'end'])
def train(instance, reasoner, optimizer, metrics, data, target, mask, start, end):
    """Run one jitted optimisation step of the talker.

    ``start``/``end`` are static so the rollout window is fixed at trace time.
    Gradients are taken w.r.t. ``instance`` only (the first argument of
    ``loss_fn``, nnx.value_and_grad's default ``argnums=0``); the reasoner is not updated.

    Returns:
        The gradient state for ``instance``, for logging by the caller.
    """
    step_loss, param_grads = nnx.value_and_grad(loss_fn)(
        instance, reasoner, data, target, mask, start, end
    )
    metrics.update(loss=step_loss)
    optimizer.update(instance, param_grads)
    return param_grads


def load_reasoner(reasoner_snap: str, vocab_size: int):
    """Build a JEPA-Reasoner from ``./configs/reasoner_tree.json`` and restore its weights.

    The model is created causal, in bfloat16, then switched to eval mode via
    ``reasoner.eval(rope_base=..., max_len=...)`` before the snapshot is loaded.

    Args:
        reasoner_snap: Snapshot path of the form ``<directory>/<name>``.
        vocab_size: Token vocabulary size for the reasoner.

    Returns:
        Whatever ``Snapshot.load`` returns for ``reasoner`` (presumably the reasoner
        with restored weights). ``skip_ema=True`` is passed — presumably this loads
        the raw rather than EMA weights; confirm against ``Snapshot``.
    """
    with open('./configs/reasoner_tree.json', 'r') as file:
        model_cfg = json.load(file)['model']

    # Parse every hyperparameter up front so a malformed config fails before model construction.
    architecture = dict(
        feature=int(model_cfg['Feature']),
        attn_feature=int(model_cfg['ATTN Feature']),
        ffn_feature=int(model_cfg['FFN Feature']),
        num_head=int(model_cfg['Head Count']),
        decoder_count=int(model_cfg['Decoder Count']),
        init_scalar=float(model_cfg['Init Scalar']),
    )
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    reasoner = model.Reasoner(
        **architecture,
        is_causal=True,
        vocab_size=vocab_size,
        key=jax.random.key(0),
        dtype=jnp.bfloat16,
    )
    reasoner.eval(rope_base=rope_base, max_len=max_len)

    snapshot_dir, snapshot_name = os.path.split(reasoner_snap)
    return Snapshot(snapshot_dir).load(snapshot_name, reasoner, skip_ema=True)


def main(depth: int, reasoner_snap: str, init_snap: str | None = None):
    # Load configuration
    with open('./configs/reasoner_tree.json', 'r') as file:
        config = json.load(file)
        latent_feature = int(config['model']['Feature'])

    with open('./configs/talker_tree.json', 'r') as file:
        config = json.load(file)
        mlflow.log_params(config['model'])
        mlflow.log_params(config['train'])

        feature = int(config['model']['Feature'])
        attn_feature = int(config['model']['ATTN Feature'])
        ffn_feature = int(config['model']['FFN Feature'])
        num_head = int(config['model']['Head Count'])
        decoder_count = int(config['model']['Decoder Count'])
        init_scalar = float(config['model']['Init Scalar'])
        max_len = int(config['model']['Max Length'])
        rope_base = float(config['model']['RoPE Base'])

        peak_lr = float(config['train']['Peak LR'])
        gradient_clipping = float(config['train']['Grads Clipping'])
        weight_decay = float(config['train']['Weight Decay'])
        total_steps = int(config['train']['Total Steps'])
        warmup_steps = int(config['train']['Warmup Steps'])
        anneal_steps = int(config['train']['Anneal Steps'])
        batch_size = int(config['train']['Batch Size'])

        if 'Accumulation' in config['train']:
            accu_steps = int(config['train']['Accumulation'])
        else:
            accu_steps = 1

    context_len = 14 if depth == 2 else 26 if depth == 3 else 52 if depth == 4 else 64

    # Build model instance
    key = jax.random.key(0)
    reasoner = load_reasoner(reasoner_snap, vocab_size=32)
    instance = model.MonoTalker(
        feature=feature,
        latent_feature=latent_feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=False,
        init_scalar=init_scalar,
        vocab_size=32,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    if init_snap is not None:  # Continue training from constructed initialization
        snap = Snapshot(os.path.dirname(init_snap))
        instance = snap.load(os.path.basename(init_snap), instance)

    # Build optimizer & learning rate scheduler
    scheduler = optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value=0.0,
                end_value=peak_lr,
                transition_steps=warmup_steps,
            ),
            optax.constant_schedule(value=peak_lr),
            optax.linear_schedule(
                init_value=peak_lr,
                end_value=0.0,
                transition_steps=anneal_steps
            )
        ],
        boundaries=[warmup_steps, total_steps - anneal_steps]
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6
        )
    )

    if accu_steps == 1:
        optimizer = nnx.Optimizer(instance, optimizer, wrt=nnx.Param)
    else:
        optimizer = nnx.Optimizer(instance, optax.MultiSteps(optimizer, every_k_schedule=accu_steps), wrt=nnx.Param)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime('snapshot/talker/train_%d-%m-%y_%H:%M'))

    # Initialize TreeGenerator and TreeTokenizer
    tree_generator = TreeGenerator(
        depth=depth,
        node_width=2,  # Binary trees
        batch_size=batch_size
    )
    tree_tokenizer = TreeTokenizer(context_len=context_len)
    start = 10 if depth == 2 else 22 if depth == 3 else 46 if depth == 4 else 64
    end = 12 if depth == 2 else 25 if depth == 3 else 50 if depth == 4 else 64

    # Train
    for step in range(total_steps * accu_steps):
        tree_strings = tree_generator.generate_trees()
        data, _, mask, _ = tree_tokenizer.encode(tree_strings)
        # Talker is only in charge of reconstruction, so input is also the target
        grads = train(instance, reasoner, optimizer, metrics, data[:, :start], data, mask, start, end)

        if step % (accu_steps * 500) == 0:
            train_loss = metrics.compute().get('loss')
            if accu_steps != 1:
                grads = optimizer.opt_state.acc_grads
            metrics.reset()

            try:
                metric_dict = {
                    'train loss': train_loss,
                    'grad: global': optax.global_norm(grads),
                    'grad: embed': optax.global_norm(grads['featurizer']),
                    'grad: output norm': optax.global_norm(grads['output_norm']),
                    'learning rate': scheduler(step // accu_steps),
                }

                for i in range(decoder_count):
                    metric_dict.update({
                        f'grad: decoder-{i:02d}': float(optax.global_norm(grads['decoders'][i])),
                        f'grad: attention-{i:02d}-k': float(optax.global_norm(grads['decoders'][i]['attention']['w_k'])),
                        f'grad: attention-{i:02d}-q': float(optax.global_norm(grads['decoders'][i]['attention']['w_q'])),
                        f'grad: attention-{i:02d}-v': float(optax.global_norm(grads['decoders'][i]['attention']['w_v']))
                    })

                mlflow.log_metrics(metric_dict, step=step // accu_steps)

            except Exception as e:
                print(f'Logging error at step {step}: {e}\n')

            snap.save(f'train_{step // accu_steps}', instance)

    snap.save(f'train_{total_steps}', instance)


if __name__ == '__main__':
    # Script entry point: prepare the MLflow / GPU environment, then run one talker training job.
    from util import prepare_training_env

    DEPTH = 4
    REASONER_SNAP = 'snapshot/tree-search/reasoner/exp_sst_tree-search-d4/exp_tree-search_6000'

    # Placeholders — fill in real values before running.
    mlflow_credential = 'YOUR_MLFLOW_CREDENTIAL'
    mlflow_tracking_uri = 'YOUR_MLFLOW_TRACKING_URI'
    # Passed to prepare_training_env; presumably the XLA GPU memory preallocation fraction — confirm.
    gpu_preallocation_fraction = '.99'

    mlflow.end_run()  # end any MLflow run still active in this process before starting a new one
    prepare_training_env(mlflow_tracking_uri, gpu_preallocation_fraction, mlflow_credential)

    mlflow.set_experiment('JEPA-Reasoner')
    run_name = datetime.now().strftime('exp-talker-train_%d/%m/%y-%H:%M')
    with mlflow.start_run(run_name=run_name):
        main(DEPTH, REASONER_SNAP, init_snap=None)
