import sys
import jax
import json
import numpy as np
import optax
import model
import jax.numpy as jnp
from flax import nnx
from dataloader import DataLoader
from snapshot import Snapshot
from datetime import datetime
import mlflow
import os
import logging

# Configure the root logger to emit records at INFO level and above, and
# create the module-level logger used by the training loop.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def loss_fn(instance, reasoner, data, target, mask, indices):
    """Compute the mask-weighted cross-entropy of the talker's predictions.

    The reasoner embeds `data` and turns the embeddings into a latent
    sequence; the talker (`instance`) maps that latent plus the input tokens
    to logits, which are scored against `target`.

    Args:
        instance: Talker model, called as
            ``instance(latent, input_tokens=..., latent_len=..., token_len=...)``.
        reasoner: Model exposing ``embed`` and ``reason``.
        data: Input token ids fed to both the reasoner and the talker.
        target: Integer labels for the logits.
        mask: Per-position weights multiplied into the loss (presumably 0 on
            padding positions -- confirm against the data loader).
        indices: Passed to the talker as both ``latent_len`` and ``token_len``.

    Returns:
        Scalar loss: the sum of the weighted per-position losses divided by
        the sum of the weights.
    """
    embeddings = reasoner.embed(data)
    latent = reasoner.reason(embeddings)
    logits = instance(latent, input_tokens=data, latent_len=indices, token_len=indices)

    weights = jnp.asarray(mask, dtype=jnp.float32)
    per_position = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target) * weights

    # Normalise by the total weight rather than by the number of positions:
    # a plain mean also counts masked-out positions in the denominator, so the
    # loss (and gradient scale) would shrink with the amount of masking in
    # each batch. Guard against an all-zero mask to avoid dividing by zero.
    total_weight = jnp.sum(weights)
    return jnp.sum(per_position) / jnp.where(total_weight > 0, total_weight, 1.0)


@nnx.jit
def train(instance, reasoner, optimizer, metrics, data, target, mask, indices):
    """Run one jitted training micro-step on a single batch.

    Differentiates `loss_fn` (with respect to its first argument, `instance`),
    records the loss in `metrics` and hands the gradients to `optimizer`.
    Nothing is returned; `metrics` and `optimizer` are updated in place.
    """
    step_loss, step_grads = nnx.value_and_grad(loss_fn)(
        instance, reasoner, data, target, mask, indices
    )
    metrics.update(loss=step_loss)
    optimizer.update(instance, step_grads)


def load_reasoner(reasoner_snap: str, vocab_size: int):
    """Build a Reasoner from 'reasoner.json' and restore it from a snapshot.

    The model is constructed from the ``model`` section of 'reasoner.json',
    switched to evaluation mode via ``reasoner.eval(...)`` and then passed to
    ``Snapshot.load`` with ``skip_ema=True``.

    Args:
        reasoner_snap: Snapshot location. Its directory part is given to
            ``Snapshot``; its final path component names the entry to load.
        vocab_size: Vocabulary size forwarded to the Reasoner constructor.

    Returns:
        The result of ``Snapshot.load`` for the constructed reasoner.
    """
    with open('reasoner.json', 'r') as file:
        hyper = json.load(file)['model']

    # Values only needed once the model exists.
    max_len = int(hyper['Max Length'])
    rope_base = float(hyper['RoPE Base'])

    reasoner = model.Reasoner(
        feature=int(hyper['Feature']),
        attn_feature=int(hyper['ATTN Feature']),
        ffn_feature=int(hyper['FFN Feature']),
        num_head=int(hyper['Head Count']),
        decoder_count=int(hyper['Decoder Count']),
        is_causal=True,
        init_scalar=float(hyper['Init Scalar']),
        vocab_size=vocab_size,
        key=jax.random.key(0),
        dtype=jnp.bfloat16
    )
    reasoner.eval(rope_base=rope_base, max_len=max_len)

    # Split "<directory>/<entry>" into the snapshot store and the entry name.
    snap_dir, snap_name = os.path.split(reasoner_snap)
    return Snapshot(snap_dir).load(snap_name, reasoner, skip_ema=True)


def main(reasoner_snap: str):
    """Train the DualTalker on top of a previously trained reasoner.

    Reads hyper-parameters from 'reasoner.json' and 'talker.json', logs the
    talker configuration to the active MLflow run, builds the model, the
    optimizer and the data loader, then runs the training loop. Every 500
    optimizer steps the averaged loss and gradient norms are logged and a
    snapshot is written; a final snapshot is written after the last step.

    The data loader is always stopped on the way out, including when training
    fails. The function returns normally on success so that a surrounding
    ``with mlflow.start_run(...)`` block exits without an exception.

    Relies on the module-level names ``WIKIPEIDA_PATH`` and
    ``FINEWEB_EDU_PATH`` being defined before it is called.

    Args:
        reasoner_snap: Path of the reasoner snapshot, forwarded to
            ``load_reasoner``.
    """
    # Load configuration
    with open('reasoner.json', 'r') as file:
        config = json.load(file)
    latent_feature = int(config['model']['Feature'])

    with open('talker.json', 'r') as file:
        config = json.load(file)
    model_cfg = config['model']
    train_cfg = config['train']
    mlflow.log_params(model_cfg)
    mlflow.log_params(train_cfg)

    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    encoder_count = int(model_cfg['Encoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    context_len = int(train_cfg['Context Length'])
    peak_lr = float(train_cfg['Peak LR'])
    gradient_clipping = float(train_cfg['Grads Clipping'])
    weight_decay = float(train_cfg['Weight Decay'])
    total_steps = int(train_cfg['Total Steps'])
    warmup_steps = int(train_cfg['Warmup Steps'])
    anneal_steps = int(train_cfg['Anneal Steps'])
    batch_size = int(train_cfg['Batch Size'])
    # Gradient accumulation is optional; default to one micro-step per update.
    accu_steps = int(train_cfg.get('Accumulation', 1))

    # Build model instance
    vocab_size = 30000
    key = jax.random.key(0)
    reasoner = load_reasoner(reasoner_snap, vocab_size=vocab_size)
    instance = model.DualTalker(
        feature=feature,
        latent_feature=latent_feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        encoder_count=encoder_count,
        decoder_count=decoder_count,
        init_scalar=init_scalar,
        vocab_size=vocab_size,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    # Build optimizer & learning rate scheduler:
    # linear warm-up -> constant peak -> linear anneal to zero.
    scheduler = optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value=0.0,
                end_value=peak_lr,
                transition_steps=warmup_steps,
            ),
            optax.constant_schedule(value=peak_lr),
            optax.linear_schedule(
                init_value=peak_lr,
                end_value=0.0,
                transition_steps=anneal_steps
            )
        ],
        boundaries=[warmup_steps, total_steps - anneal_steps]
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6
        )
    )

    optimizer = nnx.Optimizer(instance, optax.MultiSteps(optimizer, every_k_schedule=accu_steps), wrt=nnx.Param)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime('snapshot/natural_language/talker/train_%d-%m-%y_%H:%M'))

    # Data loader setup
    train_loader = DataLoader(
        path=[
            WIKIPEIDA_PATH,
            FINEWEB_EDU_PATH
        ],
        ratio=[0.5],
        tokenizer='tokenizer_250725',
        context_len=context_len,
        batch_size=batch_size,
        mem=2,
        pad_token=0,
        dtype=np.int32,
        pattern='*.parquet',
        threads=4
    )

    # Train
    try:
        for step in range(total_steps * accu_steps):
            data, target, mask, indices = next(train_loader)
            train(instance, reasoner, optimizer, metrics, data, target, mask, indices)

            if step % (accu_steps * 500) != 0:
                continue

            # `step` counts micro-steps; everything reported uses optimizer steps.
            opt_step = step // accu_steps
            train_loss = metrics.compute().get('loss')
            logger.info(f'Train loss at step {opt_step}: {train_loss}')
            # NOTE(review): optax.MultiSteps appears to reset `acc_grads` after
            # every applied update, so with accu_steps == 1 these norms may
            # always read as zero, and with accu_steps > 1 they may cover only
            # the first micro-batch of the cycle -- confirm.
            grads = optimizer.opt_state.acc_grads
            metrics.reset()

            # Metric logging is best-effort: a failure here must not abort training.
            try:
                # Convert every value to a plain float so all metrics reach
                # MLflow in the same form.
                metric_dict = {
                    'train loss': float(train_loss),
                    'grad: global': float(optax.global_norm(grads)),
                    'grad: embed': float(optax.global_norm(grads['featurizer'])),
                    'grad: final norm': float(optax.global_norm(grads['norm'])),
                    'learning rate': float(scheduler(opt_step)),
                }

                for i in range(encoder_count):
                    encoder_grads = grads['encoders'][i]
                    metric_dict.update({
                        f'grad: encoder-{i:02d}': float(optax.global_norm(encoder_grads)),
                        f'grad: attention-{i:02d}': float(optax.global_norm(encoder_grads['attention'])),
                    })

                for i in range(decoder_count):
                    decoder_grads = grads['decoders'][i]
                    metric_dict.update({
                        f'grad: decoder-{i:02d}': float(optax.global_norm(decoder_grads)),
                        f'grad: self-attention-{i:02d}': float(optax.global_norm(decoder_grads['self_attention'])),
                        f'grad: cross-attention-{i:02d}': float(optax.global_norm(decoder_grads['cross_attention'])),
                    })

                mlflow.log_metrics(metric_dict, step=opt_step)

            except Exception:
                # Report through the module logger (with traceback) and carry on.
                logger.exception('Logging error at step %d', step)

            snap.save(f'train_{opt_step}', instance)

        snap.save(f'train_{total_steps}', instance)
    finally:
        # Stop the loader even when training fails, so its workers are not
        # left running.
        train_loader.stop()

    # Return normally instead of calling sys.exit(0): SystemExit raised inside
    # the caller's `with mlflow.start_run(...)` block is seen by the context
    # manager as an exception, which records a successful run as FAILED.


if __name__ == '__main__':
    # Script entry point: prepare the environment, open an MLflow run and
    # train the talker from a fixed reasoner snapshot.
    from util import prepare_training_env

    # Dataset locations; main() reads these as module-level globals.
    WIKIPEIDA_PATH = '/path/to/wikitext/dataset'
    FINEWEB_EDU_PATH = '/path/to/fineweb-edu/dataset'

    # Placeholders to be filled in before running.
    tracking_uri = 'YOUR_MLFLOW_TRACKING_URI'
    credential = 'YOUR_MLFLOW_CREDENTIAL'
    preallocation_fraction = '.99'  # passed as the GPU preallocation fraction

    prepare_training_env(tracking_uri, preallocation_fraction, credential)
    mlflow.set_experiment('Talker')

    run_name = datetime.now().strftime('train_%d/%m/%y-%H:%M')
    with mlflow.start_run(run_name=run_name):
        main('snapshot/natural_language/reasoner/sst/sst_12000')
