import jax
import json
import numpy as np
import optax
import model
import jax.numpy as jnp
from flax import nnx
from dataloader import DataLoader
from snapshot import Snapshot
from datetime import datetime
import mlflow


def loss_fn(instance, data, target, mask, indices):
    """Return the mask-weighted mean next-token cross-entropy of one batch.

    Args:
        instance: model, called as
            ``instance(data, mask=None, q_len=indices, kv_len=indices)``; assumed
            to return logits shaped ``target.shape + (vocab,)`` — TODO confirm.
        data: input token ids fed to the model.
        target: integer class labels, one per logit position.
        mask: per-token weights applied to the cross-entropy; assumed to have
            the same shape as ``target`` (presumably 1 for real tokens and 0 for
            padding — verify against DataLoader).
        indices: forwarded to the model as both ``q_len`` and ``kv_len``.

    Returns:
        Scalar loss: the sum of masked per-token losses divided by the total
        mask weight. The weight is clamped to at least 1, so a fully masked
        batch yields 0 and not NaN.
    """
    logits = instance(data, mask=None, q_len=indices, kv_len=indices)
    token_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target)
    masked_loss = token_loss * mask
    # Normalise by the number of counted tokens, not by batch * seq_len. A plain
    # mean lets padding dilute the loss (and its gradient) in proportion to how
    # much of the batch is padding.
    return jnp.sum(masked_loss) / jnp.maximum(jnp.sum(mask), 1)


@nnx.jit
def train(instance, optimizer, metrics, data, target, mask, indices):
    """Run one jitted training micro-step on a single batch.

    The step differentiates ``loss_fn`` with respect to ``instance`` and records
    the batch loss in ``metrics``. It then passes the gradients to
    ``optimizer.update``. When the optimizer wraps ``optax.MultiSteps``, the
    parameters only change once per accumulation cycle.
    """
    batch_loss, param_grads = nnx.value_and_grad(loss_fn)(instance, data, target, mask, indices)
    metrics.update(loss=batch_loss)
    optimizer.update(instance, param_grads)


@nnx.jit
def evaluate(instance, metrics, data, target, mask, indices):
    """Add the loss of one evaluation batch to ``metrics`` without updating parameters.

    ``loss_fn`` returns a single scalar, so it must not be unpacked. The previous
    ``loss, logits = ...`` raised as soon as this function was traced.

    NOTE(review): the model runs in whatever mode ``instance`` is currently in;
    if it has train-only behaviour (e.g. dropout), switch it to eval mode before
    calling this — TODO confirm.
    """
    loss = loss_fn(instance, data, target, mask, indices)
    metrics.update(loss=loss)


def _grad_norm_dict(grads, decoder_count):
    """Return gradient-norm metrics for ``grads``, keyed by their MLflow metric names.

    Values are left as JAX scalars, not Python floats, so this also works while
    being traced under ``jit``.

    Args:
        grads: gradient pytree. Assumed (TODO confirm against ``model.Transformer``)
            to contain a ``'featurizer'`` subtree and a ``'decoders'`` collection
            indexable by layer number. Each layer is assumed to hold
            ``['attention']['w_k' | 'w_q' | 'w_v']``.
        decoder_count: number of decoder layers to report individually.

    Raises:
        Any lookup error from indexing ``grads`` propagates if its layout differs.
    """
    norms = {
        'grad: global': optax.global_norm(grads),
        'grad: embed': optax.global_norm(grads['featurizer']),
    }
    for i in range(decoder_count):
        decoder = grads['decoders'][i]
        attention = decoder['attention']
        norms[f'grad: decoder-{i:02d}'] = optax.global_norm(decoder)
        norms[f'grad: attention-{i:02d}-k'] = optax.global_norm(attention['w_k'])
        norms[f'grad: attention-{i:02d}-q'] = optax.global_norm(attention['w_q'])
        norms[f'grad: attention-{i:02d}-v'] = optax.global_norm(attention['w_v'])
    return norms


def main(config_path='./configs/transformer.json', dataset_paths=None):
    """Pretrain the Transformer described by ``config_path``, logging to the active MLflow run.

    Every 200 optimizer updates, the loop logs train loss, learning rate and
    gradient norms, and saves a snapshot. A final snapshot is saved when
    training completes.

    Args:
        config_path: JSON config file with ``'model'`` and ``'pretrain'`` sections.
        dataset_paths: dataset locations handed to ``DataLoader``. ``None``
            (the default) uses the module-level ``C4_DATASET`` and ``WIKITEXT``,
            which the script entry point defines.
    """
    # Load configuration
    with open(config_path, 'r', encoding='utf-8') as file:
        config = json.load(file)
    model_cfg = config['model']
    pretrain_cfg = config['pretrain']
    mlflow.log_params(model_cfg)
    mlflow.log_params(pretrain_cfg)

    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    context_len = int(pretrain_cfg['Context Length'])
    peak_lr = float(pretrain_cfg['Peak LR'])
    gradient_clipping = float(pretrain_cfg['Grads Clipping'])
    weight_decay = float(pretrain_cfg['Weight Decay'])
    total_steps = int(pretrain_cfg['Total Steps'])
    warmup_steps = int(pretrain_cfg['Warmup Steps'])
    anneal_steps = int(pretrain_cfg['Anneal Steps'])
    batch_size = int(pretrain_cfg['Batch Size'])
    # Gradient accumulation is optional; without it every micro-step is an update.
    accu_steps = int(pretrain_cfg.get('Accumulation', 1))

    # Build model instance
    key = jax.random.key(0)
    instance = model.Transformer(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=30000,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    # Build optimizer & learning rate scheduler:
    # linear warmup -> constant peak -> linear anneal to 0 over the last anneal_steps.
    scheduler = optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value=0.0,
                end_value=peak_lr,
                transition_steps=warmup_steps,
            ),
            optax.constant_schedule(value=peak_lr),
            optax.linear_schedule(
                init_value=peak_lr,
                end_value=0.0,
                transition_steps=anneal_steps
            )
        ],
        boundaries=[warmup_steps, total_steps - anneal_steps]
    )

    tx = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6
        )
    )

    optimizer = nnx.Optimizer(instance, optax.MultiSteps(tx, every_k_schedule=accu_steps), wrt=nnx.Param)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime('snapshot/natural_language/base_transformer/pretrain_%d-%m-%y_%H:%M'))

    @nnx.jit
    def compute_grad_norms(net, data, target, mask, indices):
        # Recompute one micro-batch gradient purely for monitoring. The norms are
        # reduced to scalars inside jit, so no parameter-sized buffer outlives
        # this call.
        grads = nnx.grad(loss_fn)(net, data, target, mask, indices)
        return _grad_norm_dict(grads, decoder_count)

    # Data loader setup
    if dataset_paths is None:
        dataset_paths = [C4_DATASET, WIKITEXT]
    train_loader = DataLoader(
        path=list(dataset_paths),
        ratio=[0.95],
        tokenizer='tokenizer_250725',
        context_len=context_len,
        batch_size=batch_size,
        mem=8,
        pad_token=0,
        dtype=np.int32,
        pattern='train_*.parquet',
        threads=8
    )

    # Train. `step` counts micro-batches; `step // accu_steps` counts optimizer updates.
    log_every_updates = 200
    log_interval = log_every_updates * accu_steps
    try:
        for step in range(total_steps * accu_steps):
            data, target, mask, indices = next(train_loader)
            update_step = step // accu_steps
            is_log_step = step % log_interval == 0

            grad_norms = None
            if is_log_step:
                # Measure gradients on the parameters *before* this step's update.
                # optimizer.opt_state.acc_grads cannot be used: optax.MultiSteps
                # zeroes it whenever it applies an update, so with accu_steps == 1
                # it always read as zero.
                # NOTE(review): if the model keeps dropout randomness in state on
                # `instance`, this extra call advances it — confirm acceptable.
                try:
                    grad_norms = compute_grad_norms(instance, data, target, mask, indices)
                except Exception as e:
                    print(f'Logging error at step {step}: {e}\n')

            train(instance, optimizer, metrics, data, target, mask, indices)

            if is_log_step:
                train_loss = metrics.compute().get('loss')
                metrics.reset()

                try:
                    metric_dict = {
                        'train loss': float(train_loss),
                        'learning rate': float(scheduler(update_step)),
                    }
                    if grad_norms is not None:
                        metric_dict.update({name: float(value) for name, value in grad_norms.items()})
                    mlflow.log_metrics(metric_dict, step=update_step)
                except Exception as e:
                    print(f'Logging error at step {step}: {e}\n')

                snap.save(f'pretrain_{update_step}', instance)
    finally:
        # Stop loader threads even if training raises.
        train_loader.stop()

    snap.save(f'pretrain_{total_steps}', instance)


if __name__ == '__main__':
    from util import prepare_training_env

    # Dataset locations; main() reads these as module-level globals.
    C4_DATASET = 'please/download/c4/dataset'
    WIKITEXT = 'please/download/wikitext/dataset'

    # Placeholders: fill these in before running.
    credential = 'YOUR_MLFLOW_CREDENTIAL'
    tracking_uri = 'YOUR_MLFLOW_TRACKING_URI'
    preallocation = '.99'  # handed to prepare_training_env as a string
    prepare_training_env(tracking_uri, preallocation, credential)

    mlflow.set_experiment('Transformer')
    run_name = datetime.now().strftime('pretrain_%d/%m/%y-%H:%M')
    with mlflow.start_run(run_name=run_name):
        main()
