import jax
import json
import numpy as np
import optax
import model
import jax.numpy as jnp
from flax import nnx
from dataloader import DataLoader
from snapshot import Snapshot
from datetime import datetime
import mlflow
import os


def loss_fn(instance, data, target, mask, indices):
    """Masked cosine-distance loss between predictions and EMA targets.

    Args:
        instance: the model; must expose `embed_rec`, `embed_ema` and `reason`.
        data: inputs embedded with `instance.embed_rec`.
        target: targets embedded with `instance.embed_ema`.
        mask: per-position weights multiplied element-wise with the cosine
            distances (presumably 1 for real tokens and 0 for padding --
            TODO confirm against the DataLoader).
        indices: forwarded to `reason` as both `q_len` and `kv_len`.

    Returns:
        Scalar `4 * nansum(distance * mask) / sum(mask)`. When `mask` sums to
        zero, the result is 0 instead of NaN.
    """
    # Order matters: `embed_rec` must be called before `embed_ema` so that
    # the EMA accumulation is triggered at the correct time.
    data = instance.embed_rec(data)
    target = instance.embed_ema(target)
    # `reason` is called with `mask=None`; the loss mask is applied below.
    predict = instance.reason(data, mask=None, q_len=indices, kv_len=indices)

    distance = optax.cosine_distance(predict, target)
    numerator = jnp.nansum(distance * mask)
    denominator = jnp.sum(mask)
    # Guard against an all-masked batch: 0 / 0 would yield a NaN loss and NaN
    # gradients. Substituting 1 gives a finite loss (0) there and leaves every
    # batch with a positive mask sum unchanged.
    safe_denominator = jnp.where(denominator > 0, denominator, 1)
    return 4 * numerator / safe_denominator


@nnx.jit
def train(instance, optimizer, metrics, data, target, mask, indices):
    """Run one jitted training step.

    Differentiates `loss_fn` with respect to `instance`. It records the loss
    in `metrics` and then hands the gradients to `optimizer.update`.
    Nothing is returned; all state changes happen in place.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(
        instance, data, target, mask, indices
    )
    metrics.update(loss=loss)
    optimizer.update(instance, grads)


def _collect_metrics(train_loss, grads, learning_rate, decoder_count):
    """Build the MLflow metric dict: loss, learning rate and gradient norms."""
    metric_dict = {
        'train loss': float(train_loss),
        'grad: global': float(optax.global_norm(grads)),
        'grad: embed': float(optax.global_norm(grads['featurizer'])),
        'learning rate': float(learning_rate),
    }

    for i in range(decoder_count):
        decoder_grads = grads['decoders'][i]
        attention_grads = decoder_grads['attention']
        metric_dict.update({
            f'grad: decoder-{i:02d}': float(optax.global_norm(decoder_grads)),
            f'grad: attention-{i:02d}-k': float(optax.global_norm(attention_grads['w_k'])),
            f'grad: attention-{i:02d}-q': float(optax.global_norm(attention_grads['w_q'])),
            f'grad: attention-{i:02d}-v': float(optax.global_norm(attention_grads['w_v'])),
        })

    return metric_dict


def main(init_snap: str = None, size: str = 'large', log_interval: int = 200):
    """Train the Reasoner using the settings in `configs/reasoner.json`.

    Args:
        init_snap: optional `<snapshot dir>/<snapshot name>` path of an
            initialization to continue from. When given, the model is loaded
            from it and its EMA embedding cache is reset.
        size: size tag embedded in the output snapshot directory name.
        log_interval: number of optimizer updates between metric logs and
            snapshot saves. Each update consumes `accu_steps` batches.

    Raises:
        ValueError: if `log_interval` is smaller than 1.

    Dataset locations are read from the module-level names `WIKIPEIDA_PATH`
    and `FINEWEB_EDU_PATH`, which are set in the `__main__` guard.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    # Load configuration
    with open('configs/reasoner.json', 'r', encoding='utf-8') as file:
        config = json.load(file)

    model_cfg = config['model']
    sst_cfg = config['sst']

    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    context_len = int(sst_cfg['Context Length'])
    peak_lr = float(sst_cfg['Peak LR'])
    gradient_clipping = float(sst_cfg['Grads Clipping'])
    weight_decay = float(sst_cfg['Weight Decay'])
    total_steps = int(sst_cfg['Total Steps'])
    warmup_steps = int(sst_cfg['Warmup Steps'])
    batch_size = int(sst_cfg['Batch Size'])
    # Gradient accumulation is optional; default to one batch per update.
    accu_steps = int(sst_cfg.get('Accumulation', 1))

    # Build model instance
    key = jax.random.key(0)
    instance = model.Reasoner(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=30000,
        ema_interval=accu_steps,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    if init_snap is not None:  # Continue training from constructed initialization
        init_snapshot = Snapshot(os.path.dirname(init_snap))
        instance = init_snapshot.load(os.path.basename(init_snap), instance)
        instance.reset_embed_ema(accu_steps)  # Reset EMA cache with loaded embedding values

    # Build optimizer & learning rate scheduler.
    # optax's `decay_steps` is the total schedule length, including warmup.
    scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6
        )
    )

    # MultiSteps applies the inner optimizer once every `accu_steps` calls.
    optimizer = nnx.Optimizer(instance, optax.MultiSteps(optimizer, every_k_schedule=accu_steps), wrt=nnx.Param)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime(f'snapshot/natural_language/reasoner/sst_{size}_%d-%m-%y_%H:%M'))

    # Data loader setup
    train_loader = DataLoader(
        path=[
            WIKIPEIDA_PATH,
            FINEWEB_EDU_PATH
        ],
        ratio=[0.5],
        workers=[2, 2],
        tokenizer='tokenizer_250725',
        context_len=context_len,
        batch_size=batch_size,
        mem=2,
        pad_token=0,
        dtype=np.int32,
        pattern='*.parquet',
        start_method='spawn'
    )

    # Train. `step` counts batches; `step // accu_steps` counts optimizer updates.
    log_every = accu_steps * log_interval
    try:
        for step in range(total_steps * accu_steps):
            data, target, mask, indices = next(train_loader)
            train(instance, optimizer, metrics, data, target, mask, indices)

            if step % log_every != 0:
                continue

            update_step = step // accu_steps
            train_loss = metrics.compute().get('loss')
            metrics.reset()

            # Logging is best-effort: a failure here must not stop training.
            try:
                # NOTE(review): `acc_grads` is optax MultiSteps' accumulation
                # buffer. At this point (step % accu_steps == 0) it likely holds
                # only this cycle's first batch. With accu_steps == 1, MultiSteps
                # resets it to zeros after every update, so these norms may not
                # reflect the applied gradient -- confirm.
                grads = optimizer.opt_state.acc_grads
                metric_dict = _collect_metrics(
                    train_loss, grads, scheduler(update_step), decoder_count
                )
                mlflow.log_metrics(metric_dict, step=update_step)
            except Exception as e:
                print(f'Logging error at step {step}: {e}\n')

            snap.save(f'sst_{update_step}', instance)

        snap.save(f'sst_{total_steps}', instance)
    finally:
        # Always shut the loader down, even if training raises, so that its
        # workers are not left running.
        train_loader.shutdown()


if __name__ == '__main__':
    # Script entry point: prepare the environment and MLflow, then train.
    from util import prepare_training_env

    # Dataset locations. `main` reads these module-level names directly.
    WIKIPEIDA_PATH = '/path/to/wikitext/dataset'
    FINEWEB_EDU_PATH = '/path/to/fineweb-edu/dataset'
    model_size = 'large'

    # Placeholders: replace them with real values before running.
    mlflow_credential = 'YOUR_MLFLOW_CREDENTIAL'
    mlflow_tracking_uri = 'YOUR_MLFLOW_TRACKING_URI'
    gpu_preallocation_fraction = '.99'

    prepare_training_env(mlflow_tracking_uri, gpu_preallocation_fraction, mlflow_credential)
    mlflow.set_experiment('JEPA-Reasoner')

    run_name = datetime.now().strftime(f'sst_{model_size}_%d/%m/%y-%H:%M')
    with mlflow.start_run(run_name=run_name):
        main(init_snap='snapshot/natural_language/reasoner/init', size=model_size)
