import sys
import jax
import json
import numpy as np
import optax
import model
import jax.numpy as jnp
from flax import nnx
from exp_utils import CoconutDataLoaderMultiAr
from snapshot import Snapshot
from datetime import datetime
import time
import mlflow
import os
from util import sample_tokens
import logging

# Configure the root logger at INFO level and create this module's logger,
# which the training loop uses for its periodic loss reports.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def loss_fn(instance, data, target, mask, indices, soft_thinking_step, token_generation_step):
    """Compute the masked cross-entropy loss after a latent and token roll-out.

    The embedded input is extended one position per sample and per step in
    two phases, then a final forward pass is scored against ``target``:

    1. ``soft_thinking_step`` iterations write the output of
       ``instance.latent_reasoning`` at each sample's previous position into
       the input at that sample's current position.
    2. ``token_generation_step - 1`` iterations sample a token from the
       logits at the previous position and write its embedding at the
       current position.

    Every value written during the two phases is wrapped in
    ``jax.lax.stop_gradient``, so the roll-out forward passes themselves are
    not differentiated.

    Args:
        instance: Model exposing ``embed``, ``latent_reasoning`` and
            ``assemble``.
        data: Token ids passed to ``instance.embed``.
        target: Integer labels for the final logits.
        mask: Multiplier applied element-wise to the per-position loss.
        indices: One write position per sample, indexed here as a 1-D array
            of length batch; it is advanced by one after every step.
        soft_thinking_step: Number of latent write-back iterations.
        token_generation_step: Number of token steps; the last one is the
            final scored forward pass, so the loop runs one fewer.

    Returns:
        The mean of the masked per-position loss (masked-out positions still
        count in the denominator, as zeros).
    """
    input_latent = instance.embed(data)
    batch_indices = jnp.arange(indices.shape[0], dtype=jnp.int32)

    def soft_thinking_loop_body(i, carry):
        latent, positions = carry
        hidden = jax.lax.stop_gradient(instance.latent_reasoning(latent, q_len=positions, kv_len=positions))
        # Pair sample b with its own position only. Both index arrays must be
        # shape (batch,): indexing with `positions[:, None]` would broadcast
        # against `batch_indices` to (batch, batch) and overwrite every row at
        # every other sample's position as well.
        updated_latent = latent.at[batch_indices, positions, :].set(hidden[batch_indices, positions - 1, :])
        return updated_latent, positions + 1

    # Initial carry: (input_latent, indices)
    final_latent, final_indices = jax.lax.fori_loop(0, soft_thinking_step, soft_thinking_loop_body, (input_latent, indices))

    def token_generation_loop_body(i, carry):
        latent, positions = carry
        logits = jax.lax.stop_gradient(
            instance.assemble(instance.latent_reasoning(latent, q_len=positions, kv_len=positions))
        )
        # Gather from the actual last valid position
        last_valid_logits = logits[batch_indices, positions - 1]  # Shape: (batch_size, vocab_size)
        last_valid_logits = last_valid_logits[:, None, :]  # Add sequence dimension back: (batch_size, 1, vocab_size)

        output_tokens = sample_tokens(last_valid_logits)
        token_embedding = jax.lax.stop_gradient(instance.embed(output_tokens))
        # Only one token per sample was sampled, so take the single sequence
        # slot explicitly rather than gathering with `positions - 1` (which is
        # out of range for a length-1 axis).
        # NOTE(review): assumes sample_tokens keeps the length-1 sequence axis,
        # i.e. the embedding is (batch, 1, feature) - confirm against util.
        updated_latent = latent.at[batch_indices, positions, :].set(token_embedding[:, 0, :])
        return updated_latent, positions + 1

    final_latent, final_indices = jax.lax.fori_loop(0, token_generation_step - 1, token_generation_loop_body, (final_latent, final_indices))

    # The only forward pass outside stop_gradient: this is what gets scored.
    logits = instance.assemble(instance.latent_reasoning(final_latent, q_len=final_indices, kv_len=final_indices))
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=target) * mask
    return jnp.mean(loss)


@nnx.jit(static_argnames=['soft_thinking_step', 'token_generation_step'])
def train(instance, optimizer, metrics, data, target, mask, indices, soft_thinking_step, token_generation_step):
    """Run one jitted training step: loss, gradients, metric and optimizer update.

    ``loss_fn`` is evaluated and differentiated through ``nnx.value_and_grad``;
    the loss is recorded in ``metrics`` and the gradients are handed to
    ``optimizer``. ``soft_thinking_step`` and ``token_generation_step`` are
    static under jit. Nothing is returned; ``metrics`` and ``optimizer`` are
    updated in place.
    """
    batch = (data, target, mask, indices)
    loss_value, gradients = nnx.value_and_grad(loss_fn)(
        instance, *batch, soft_thinking_step, token_generation_step
    )
    metrics.update(loss=loss_value)
    optimizer.update(instance, gradients)


def main(init_snap: str=None, size: str = 'large'):
    """Post-train the transformer on the CFG data and exit the process.

    Reads ``./configs/transformer_cfg.json``, builds the model, the optimizer
    (gradient clipping + AdamW wrapped in ``optax.MultiSteps``) and the data
    loader, then runs the training loop. Every 100 optimizer steps the loss
    and gradient norms are logged to MLflow; every 200 a snapshot is saved.

    Relies on the module globals ``SOFT_THINK_STEP``, ``TOKEN_GEN_STEP`` and
    ``PATH_TO_CFG_DATA``, which this file assigns in its ``__main__`` guard.

    Args:
        init_snap: Optional snapshot path. Its directory is opened as the
            snapshot store and its basename is the entry loaded into the model.
        size: Suffix selecting the ``model_<size>`` config section. Any value
            other than ``'large'`` halves the accumulation steps (never below
            one) and doubles the batch size.

    This function does not return: it ends with ``sys.exit(0)``.
    """
    # Load configuration
    with open('./configs/transformer_cfg.json', 'r') as file:
        config = json.load(file)

    model_cfg = config[f'model_{size}']
    posttrain_cfg = config['cfg_posttrain']
    mlflow.log_params(model_cfg)
    mlflow.log_params(posttrain_cfg)

    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    context_len = int(posttrain_cfg['Context Length'])
    peak_lr = float(posttrain_cfg['Peak LR'])
    gradient_clipping = float(posttrain_cfg['Grads Clipping'])
    weight_decay = float(posttrain_cfg['Weight Decay'])
    total_steps = int(posttrain_cfg['Total Steps'])
    warmup_steps = int(posttrain_cfg['Warmup Steps'])
    batch_size = int(posttrain_cfg['Batch Size'])
    accu_steps = int(posttrain_cfg.get('Accumulation', 1))

    if size != 'large':
        # Trade accumulation for batch size. Floor at one: halving an
        # accumulation of 1 used to give 0, which made the training loop
        # below run zero iterations.
        accu_steps = max(1, accu_steps // 2)
        batch_size *= 2

    # Build model instance
    key = jax.random.key(0)
    instance = model.Transformer(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=32,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    if init_snap:
        init_store = Snapshot(os.path.dirname(init_snap))
        instance = init_store.load(os.path.basename(init_snap), instance)

    # Build optimizer & learning rate scheduler
    scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6
        )
    )

    optimizer = nnx.Optimizer(instance, optax.MultiSteps(optimizer, every_k_schedule=accu_steps), wrt=nnx.Param)
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime(f'snapshot/cfg_completion/transformer/cfg3f_coco_s{SOFT_THINK_STEP}ar{TOKEN_GEN_STEP}_{size}_%d-%m-%y_%H:%M'))

    # Data loader setup
    train_loader = CoconutDataLoaderMultiAr(
        path=[PATH_TO_CFG_DATA],
        pattern='cfg_train*.parquet',
        pad_token=0,
        mem=128,
        batch_size=batch_size,
        context_len=context_len,
        threads=2,
        dtype=np.int32,
        soft_thinking_steps=SOFT_THINK_STEP + TOKEN_GEN_STEP - 1
    )

    # Train. The loader is stopped even when training raises, so its worker
    # threads are not left running behind a failed run.
    try:
        for step in range(total_steps * accu_steps):
            data, target, mask, indices = next(train_loader)
            train(instance, optimizer, metrics, data, target, mask, indices, SOFT_THINK_STEP, TOKEN_GEN_STEP)

            if step % (accu_steps * 100) != 0:
                continue

            # `step` counts micro-batches; `opt_step` counts optimizer updates.
            opt_step = step // accu_steps
            train_loss = metrics.compute().get('loss')
            logger.info(f'Train loss at {opt_step}: {train_loss}')
            # Gradients currently accumulated inside the MultiSteps state.
            grads = optimizer.opt_state.acc_grads
            metrics.reset()

            # Metric logging is best-effort: a failure is reported and
            # training continues. Values are passed as plain floats, which is
            # the type MLflow documents for metrics.
            try:
                mlflow.log_metrics({
                    'train loss': float(train_loss),
                    'grad: global': float(optax.global_norm(grads)),
                    'grad: embed': float(optax.global_norm(grads['featurizer'])),
                    'grad: final norm': float(optax.global_norm(grads['norm'])),
                    'learning rate': float(scheduler(opt_step)),
                }, step=opt_step)

                for i in range(decoder_count):
                    decoder_grads = grads['decoders'][i]
                    attention_grads = decoder_grads['attention']
                    mlflow.log_metrics({
                        f'grad: decoder-{i:02d}': float(optax.global_norm(decoder_grads)),
                        f'grad: attention-{i:02d}-k': float(optax.global_norm(attention_grads['w_k'])),
                        f'grad: attention-{i:02d}-q': float(optax.global_norm(attention_grads['w_q'])),
                        f'grad: attention-{i:02d}-v': float(optax.global_norm(attention_grads['w_v']))
                    }, step=opt_step)

            except Exception:
                logger.exception('Logging error at step %s', step)

            if step % (accu_steps * 200) == 0:
                snap.save(f'cfg_coco_{opt_step}', instance)

        snap.save(f'cfg_{total_steps}', instance)
    finally:
        train_loader.stop()

    time.sleep(10)
    # NOTE(review): sys.exit raises SystemExit through any enclosing
    # `with mlflow.start_run(...)` in the caller - confirm the run is still
    # recorded as finished rather than failed.
    sys.exit(0)



if __name__ == '__main__':
    # Usage: <script> [model_size] [init_snapshot]
    from util import prepare_training_env
    # Module-level settings read by main().
    PATH_TO_CFG_DATA = './data'
    # Both arguments are optional and fall back to main()'s own defaults
    # instead of raising IndexError when omitted.
    model_size = sys.argv[1] if len(sys.argv) > 1 else 'large'
    init_snapshot = sys.argv[2] if len(sys.argv) > 2 else None
    SOFT_THINK_STEP = 4
    TOKEN_GEN_STEP = 4

    # Placeholders: replace with the real MLflow credential and tracking URI.
    mlflow_credential = 'YOUR_MLFLOW_CREDENTIALS'
    mlflow_tracking_uri = 'YOUR_MLFLOW_TRACKING_URI'
    gpu_preallocation_fraction = '.99'
    prepare_training_env(mlflow_tracking_uri, gpu_preallocation_fraction, mlflow_credential)
    mlflow.set_experiment('Transformer')
    with mlflow.start_run(run_name=datetime.now().strftime(f'cfg3f_coco_s{SOFT_THINK_STEP}ar{TOKEN_GEN_STEP}_{model_size}_%d/%m/%y-%H:%M')):
        mlflow.set_tag('train_data', 'cfg3f')
        main(init_snap=init_snapshot, size=model_size)
