import copy
import json
import torch
import numpy as np
from typing import Union

import megatron
from megatron.core import parallel_state
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec
from megatron.core.transformer.spec_utils import import_module
from megatron.core.transformer.moe.moe_utils import reduce_aux_losses_tracker_across_ranks, clear_aux_losses_tracker
from megatron.core.num_microbatches_calculator import get_num_microbatches

from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Build the GPT model described by the global Megatron arguments.

    When ``args.use_legacy_models`` is set the legacy
    ``megatron.legacy.model.GPTModel`` is returned; otherwise the Megatron-Core
    ``GPTModel`` is built. Both variants are constructed with
    ``parallel_output=False``.

    Args:
        pre_process (bool, optional): Whether this pipeline stage computes the
            input embeddings. Defaults to True.
        post_process (bool, optional): Whether this pipeline stage computes the
            output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The constructed model.
    """
    args = get_args()
    te_enabled = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')

    # Experimental: when a YAML config is supplied, read the transformer
    # config from its "language_model" section instead of the CLI args.
    config = (
        core_transformer_config_from_args(args)
        if args.yaml_cfg is None
        else core_transformer_config_from_yaml(args, "language_model")
    )

    if args.use_legacy_models:
        return megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process,
        )

    # An explicit --spec wins; otherwise pick the TE or local layer spec.
    if args.spec is not None:
        layer_spec = import_module(args.spec)
    else:
        build_spec = (
            get_gpt_layer_with_transformer_engine_spec
            if te_enabled
            else get_gpt_layer_local_spec
        )
        layer_spec = build_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

    return GPTModel(
        config=config,
        transformer_layer_spec=layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=False,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,
    )

def swap_letters_and_numbers(char):
    """Convert an answer label into a zero-based option index.

    Integers are passed through unchanged. The letters 'A'-'E' map to 0-4 and
    the digit strings '1'-'4' map to 0-3. Any other value raises ``ValueError``.
    """
    if isinstance(char, int):
        return char

    # Letter and digit labels never share a key, so one merged table gives the
    # same result as checking the two groups one after the other.
    label_to_index = {letter: idx for idx, letter in enumerate("ABCDE")}
    label_to_index.update({digit: idx for idx, digit in enumerate("1234")})

    not_found = object()
    index = label_to_index.get(char, not_found)
    if index is not_found:
        raise ValueError(f"Unsupported character: '{char}'")
    return index

def to_jsonable(obj):
    """Recursively convert ``obj`` into plain Python objects that ``json`` can encode.

    Conversions:
      - ``torch.Tensor``: detached, moved to CPU and turned into (nested) lists
        via ``tolist()``.
      - ``np.ndarray``: ``tolist()``. Object-dtype arrays are recursed into, so
        tensors or NumPy scalars stored inside them are converted as well.
      - NumPy scalars (any ``np.generic``, e.g. ``np.int64``, ``np.float32``,
        ``np.bool_``): the matching Python scalar via ``item()``.
      - ``dict``: a new dict with converted values. NumPy-scalar keys are
        converted too, because ``json`` only accepts str/int/float/bool/None keys.
      - ``list`` / ``tuple``: a new list with converted elements.
      - Anything else is returned unchanged.

    The input object is never mutated.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        as_list = obj.tolist()
        # Numeric arrays already yield plain Python values; object arrays may
        # still hold tensors or NumPy scalars that need converting.
        return to_jsonable(as_list) if obj.dtype == object else as_list
    # np.generic covers every NumPy scalar type (including np.bool_),
    # not only integers and floats.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): to_jsonable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    return obj

def track_eval_moe_metrics(task, iterations):
    """Reduce, optionally dump, print and reset the validation MoE metrics.

    Steps:
      1. Reduce the validation aux-loss tracker across ranks.
      2. If ``args.additional_metrics`` is set, dump the tracker to
         ``logs/{task}_moe_log_tracker.json``. The ``logs`` directory is
         created if it does not exist.
      3. Scale every tracked ``values`` tensor by
         ``1 / (get_num_microbatches() * iterations)``.
      4. Clear the validation tracker.
      5. Print a banner with the mean of each scaled metric (only when at least
         one metric was tracked).

    Args:
        task: Name used in the JSON dump filename.
        iterations: Number of evaluation iterations that fed the tracker. It is
            part of the normalisation factor, so it must be non-zero.

    Returns:
        The mean of the scaled ``val_active_experts`` metric as a Python float,
        or ``None`` if that metric was not tracked (e.g. the tracker is empty).
    """
    import os

    args = get_args()

    # Aux loss logging: aggregate the validation trackers across ranks first.
    reduce_aux_losses_tracker_across_ranks(validation=True)

    tracker = parallel_state.get_val_moe_layer_wise_logging_tracker()
    if args.additional_metrics:
        # Serialise before opening the file so that a conversion error cannot
        # leave a truncated JSON file behind. to_jsonable builds new containers
        # and does not mutate the tracker, so no deepcopy is needed.
        json_data = to_jsonable(tracker)
        os.makedirs("logs", exist_ok=True)
        with open(f"logs/{task}_moe_log_tracker.json", "w") as f:
            json.dump(json_data, f, indent=2, sort_keys=True)

    loss_scale = 1 / (get_num_microbatches() * iterations)

    # The multiplication produces new tensors, so the values kept here are
    # independent of the tracker state that is cleared right after.
    aux_losses = {name: entry['values'].float() * loss_scale for name, entry in tracker.items()}

    clear_aux_losses_tracker(validation=True)

    # ---- Evaluation printout ----
    if aux_losses:
        metrics_str = " | ".join(
            f"{name}: {loss_list.mean().item():.4f}"
            for name, loss_list in aux_losses.items()
        )

        # Banner text (Polish): "MoE metrics after evaluation".
        string = f" Metryki MoE po ewaluacji | {metrics_str}"
        length = len(string) + 1

        print('-' * length)
        print(string)
        print('-' * length)

    # Avoid a KeyError when the metric was not tracked (including the empty
    # tracker case handled by the guard above).
    active_experts = aux_losses.get('val_active_experts')
    if active_experts is None:
        return None
    return active_experts.mean().item()