import jax
import json
import numpy as np
import optax
import model
import jax.numpy as jnp
from flax import nnx
from exp_utils import CFGDataLoader
from snapshot import Snapshot
from datetime import datetime
import mlflow
import logging
import os

# Module logger for training progress; the root logger is configured below so
# that INFO-level messages from this script are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def loss_fn(instance, data, target, mask, indices):
    """Masked cosine-distance loss between predicted and EMA-target embeddings.

    Args:
        instance: the model; must expose ``embed_rec``, ``embed_ema`` and ``reason``.
        data: model input, embedded with ``embed_rec``.
        target: prediction target, embedded with ``embed_ema``.
        mask: per-position weights multiplied into the cosine distance
            (presumably a 0/1 padding mask from CFGDataLoader -- confirm).
        indices: forwarded to ``reason`` as both ``q_len`` and ``kv_len``.

    Returns:
        Scalar ``4 * nansum(cosine_distance * mask) / sum(mask)``. NaN distance
        terms are dropped from the numerator but their mask still counts in the
        denominator. When ``mask`` sums to 0 the result is 0 with a zero
        gradient, instead of the NaN that a plain 0/0 would produce.
    """
    data = instance.embed_rec(data)      # `embed_rec` must be called before `embed_ema`
    target = instance.embed_ema(target)  # to trigger EMA accumulation at the correct time
    predict = instance.reason(data, mask=None, q_len=indices, kv_len=indices)

    weighted = jnp.nansum(optax.cosine_distance(predict, target) * mask)
    denom = jnp.sum(mask)

    # A fully masked batch would yield 0/0 = NaN, and its NaN gradient would be
    # accumulated by MultiSteps and poison the optimizer state. The inner `where`
    # keeps the untaken branch finite so its gradient is not NaN either.
    has_tokens = denom > 0
    safe_denom = jnp.where(has_tokens, denom, 1)
    return jnp.where(has_tokens, 4 * weighted / safe_denom, 0.0)


@nnx.jit
def train(instance, optimizer, metrics, data, target, mask, indices):
    """Run one jit-compiled micro-step.

    Evaluates ``loss_fn`` and its gradient with respect to ``instance``, feeds
    the loss into ``metrics`` and hands the gradient to ``optimizer``. The
    optimizer is whatever transform it was built with, e.g. an accumulating
    one. Nothing is returned; all results are carried by the mutated objects.
    """
    loss, grads = nnx.value_and_grad(loss_fn)(instance, data, target, mask, indices)
    metrics.update(loss=loss)
    optimizer.update(instance, grads)


# Model-size argument -> section name in ./configs/reasoner_cfg.json.
_MODEL_CONFIG_KEYS = {
    'large': 'model_large',
    'middle': 'model_middle',
    'small': 'model_small',
}


def _load_config(size):
    """Read the model section for ``size`` and the ``sst`` training section.

    Returns:
        A ``(model_config, sst_config)`` pair of raw config dicts.

    Raises:
        ValueError: if ``size`` is not 'large', 'middle' or 'small'.
    """
    with open('./configs/reasoner_cfg.json', 'r') as file:
        all_config = json.load(file)

    try:
        model_key = _MODEL_CONFIG_KEYS[size]
    except KeyError:
        raise ValueError(f'Unsupported model size: {size}') from None

    return all_config[model_key], all_config['sst']


def _build_scheduler(peak_lr, warmup_steps, total_steps, anneal_steps):
    """Linear warmup 0 -> peak_lr, constant, then linear anneal to 0 over the last ``anneal_steps``."""
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value=0.0,
                end_value=peak_lr,
                transition_steps=warmup_steps,
            ),
            optax.constant_schedule(value=peak_lr),
            optax.linear_schedule(
                init_value=peak_lr,
                end_value=0.0,
                transition_steps=anneal_steps,
            ),
        ],
        boundaries=[warmup_steps, total_steps - anneal_steps],
    )


def _build_transform(scheduler, gradient_clipping, weight_decay, accu_steps):
    """Global-norm clipping + AdamW, applied once every ``accu_steps`` accumulated micro-batches."""
    transform = optax.chain(
        optax.clip_by_global_norm(gradient_clipping),
        optax.adamw(
            learning_rate=scheduler,
            weight_decay=weight_decay,
            b1=0.9,
            b2=0.95,
            eps=1e-6,
        ),
    )
    return optax.MultiSteps(transform, every_k_schedule=accu_steps)


def _collect_metrics(train_loss, grads, learning_rate, decoder_count):
    """Build the MLflow metric dict: loss, LR and global gradient norms per component."""
    metric_dict = {
        'train loss': float(train_loss),
        'grad: global': float(optax.global_norm(grads)),
        'grad: embed': float(optax.global_norm(grads['featurizer'])),
        'learning rate': float(learning_rate),
    }
    for i in range(decoder_count):
        decoder_grads = grads['decoders'][i]
        attention_grads = decoder_grads['attention']
        metric_dict.update({
            f'grad: decoder-{i:02d}': float(optax.global_norm(decoder_grads)),
            f'grad: attention-{i:02d}-k': float(optax.global_norm(attention_grads['w_k'])),
            f'grad: attention-{i:02d}-q': float(optax.global_norm(attention_grads['w_q'])),
            f'grad: attention-{i:02d}-v': float(optax.global_norm(attention_grads['w_v'])),
        })
    return metric_dict


def main(init_snap: str = None, size: str = 'large', data_path: str = None,
         log_interval: int = 200, save_interval: int = 100):
    """Train a ``model.Reasoner`` on CFG completion data with the SST recipe.

    Hyper-parameters come from ``./configs/reasoner_cfg.json``: the model
    section chosen by ``size`` plus the shared ``sst`` section.

    Args:
        init_snap: optional ``<dir>/<name>`` snapshot to initialize weights from.
        size: 'large', 'middle' or 'small'. Non-large sizes halve the gradient
            accumulation steps (never below 1) and double the micro-batch size.
        data_path: directory holding ``cfg_train_*.parquet``. Defaults to the
            module-level ``PATH_TO_CFG_DATA`` set by the script entry point.
        log_interval: log loss / gradient norms / LR every this many optimizer steps.
        save_interval: save a snapshot every this many optimizer steps.

    Raises:
        ValueError: unsupported ``size``, ``Accumulation`` < 1, or no data path.
    """
    # Load configuration
    model_cfg, sst_cfg = _load_config(size)

    feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])

    context_len = int(sst_cfg['Context Length'])
    peak_lr = float(sst_cfg['Peak LR'])
    gradient_clipping = float(sst_cfg['Grads Clipping'])
    weight_decay = float(sst_cfg['Weight Decay'])
    total_steps = int(sst_cfg['Total Steps'])
    warmup_steps = int(sst_cfg['Warmup Steps'])
    batch_size = int(sst_cfg['Batch Size'])
    anneal_steps = int(sst_cfg['Anneal Steps'])
    accu_steps = int(sst_cfg.get('Accumulation', 1))
    if accu_steps < 1:
        raise ValueError(f'Accumulation must be >= 1, got {accu_steps}')

    if size != 'large':
        # Double the micro-batch and halve accumulation, clamped at 1. Plain
        # truncation turned the default accu_steps == 1 into 0, so the training
        # loop ran zero iterations (only the untrained final snapshot was
        # saved) and MultiSteps / ema_interval received 0.
        accu_steps = max(1, accu_steps // 2)
        batch_size *= 2

    if data_path is None:
        # The script entry point sets this module-level global. Resolve it now so
        # a missing path fails before the (expensive) model is built.
        data_path = globals().get('PATH_TO_CFG_DATA')
        if data_path is None:
            raise ValueError('data_path must be given when PATH_TO_CFG_DATA is not set')

    # Build model instance
    key = jax.random.key(0)
    instance = model.Reasoner(
        feature=feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=32,
        ema_interval=accu_steps,
        key=key,
        dtype=jnp.bfloat16
    )
    instance.train(
        rope_base=rope_base,
        max_len=max_len
    )

    if init_snap is not None:  # Continue training from constructed initialization
        init_snapshot = Snapshot(os.path.dirname(init_snap))
        instance = init_snapshot.load(os.path.basename(init_snap), instance)
        instance.reset_embed_ema(accu_steps)  # Reset EMA cache with loaded embedding values

    # Build optimizer & learning rate scheduler
    scheduler = _build_scheduler(peak_lr, warmup_steps, total_steps, anneal_steps)
    optimizer = nnx.Optimizer(
        instance,
        _build_transform(scheduler, gradient_clipping, weight_decay, accu_steps),
        wrt=nnx.Param,
    )
    metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))
    snap = Snapshot(path=datetime.now().strftime(f'snapshot/cfg_completion/reasoner/sst_cfg3f_{size}_%d-%m-%y_%H:%M'))

    # Data loader setup
    train_loader = CFGDataLoader(
        path=[data_path],
        pattern='cfg_train_*.parquet',
        tokenizer='tokenizer_250725',
        pad_token=0,
        mem=128,
        batch_size=batch_size,
        context_len=context_len,
        threads=1,
        dtype=np.int32,
    )

    # Train: one loop iteration per micro-batch, one optimizer step every accu_steps.
    try:
        for step in range(total_steps * accu_steps):
            data, target, mask, indices = next(train_loader)
            train(instance, optimizer, metrics, data, target, mask, indices)

            opt_step, mini_step = divmod(step, accu_steps)
            if mini_step != 0:
                continue

            if opt_step % log_interval == 0:
                train_loss = metrics.compute().get('loss')
                logger.info('Train loss at step %d: %s', opt_step, train_loss)
                # NOTE(review): this runs right after mini-step 0 of the window.
                # If MultiSteps zeroes `acc_grads` after each applied update, this
                # holds one micro-batch of gradient (accu_steps > 1) or all zeros
                # (accu_steps == 1) -- confirm before trusting grad-norm plots.
                grads = optimizer.opt_state.acc_grads
                metrics.reset()

                try:
                    mlflow.log_metrics(
                        _collect_metrics(train_loss, grads, scheduler(opt_step), decoder_count),
                        step=opt_step,
                    )
                except Exception:
                    # Metric logging is best-effort; never abort training over it.
                    logger.warning('Logging error at step %d', step, exc_info=True)

            # Independent of the logging cadence (it used to be nested inside the
            # log branch, so the 100-step interval was never honoured).
            if opt_step % save_interval == 0:
                snap.save(f'sst_cfg3f_{opt_step}', instance)
    finally:
        # Always stop the loader's worker threads, even if training raises.
        train_loader.stop()

    snap.save(f'sst_cfg3f_{total_steps}', instance)


def parse_args(argv=None):
    """Parse the command line for this training script.

    Args:
        argv: argument list without the program name; ``None`` reads ``sys.argv``.

    Returns:
        Namespace with ``size`` ('large' | 'middle' | 'small') and ``init_snap``.
        ``init_snap`` defaults to ``snapshot/cfg_completion/reasoner/init_<size>``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='SST training of the Reasoner on CFG data.')
    parser.add_argument('size', choices=('large', 'middle', 'small'),
                        help='model size section of configs/reasoner_cfg.json')
    parser.add_argument('--init-snap', dest='init_snap', default=None,
                        help='snapshot to initialize from '
                             '(default: snapshot/cfg_completion/reasoner/init_<size>)')
    args = parser.parse_args(argv)
    if args.init_snap is None:
        args.init_snap = f'snapshot/cfg_completion/reasoner/init_{args.size}'
    return args


if __name__ == '__main__':
    PATH_TO_CFG_DATA = './data'  # read by main() when no data_path is passed
    cli_args = parse_args()
    main(cli_args.init_snap, cli_args.size)
