import argparse

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)

from ds_src.models import (
    MoLoSConfig,
    MoLoSLlamaForCausalLM,
)
from ds_src.utils import get_json_config


def get_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options (all values are strings):
        -m / --model: model name or path (required). The value ``molos``
            (case-insensitive) selects the MoLoS model in the main script.
        -bm / --base_model: base model name or path, used only for MoLoS.
        -mc / --molos_config: path to the MoLoS config file, used only for
            MoLoS.

    Returns:
        argparse.Namespace: parsed options exposing ``model``,
        ``base_model`` and ``molos_config`` (the last two are ``None`` when
        omitted).
    """
    # (flags, extra keyword arguments) for every option; all are typed str.
    option_specs = (
        (('-m', '--model'), {
            'required': True,
            'help': 'The model name or path.',
        }),
        (('-bm', '--base_model'), {
            'help': 'The base model name or path. (only for MoLoS)',
        }),
        (('-mc', '--molos_config'), {
            'help': 'The path to the MoLoS config file. (only for MoLoS)',
        }),
    )

    cli_parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli_parser.add_argument(*flags, type=str, **options)

    return cli_parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    print(f'Start loading the {args.model} model.')

    if args.model.lower() != 'molos':
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=args.model)
    else:
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=args.base_model, )

        molos_config = MoLoSConfig()

        try:
            molos_config_values = get_json_config(
                config_path=args.molos_config)
            molos_config = MoLoSConfig(**molos_config_values)
        except Exception as e:
            print(f'Failed to load the MoLoS config file with {e}. '
                  f'Using the default MoLoS config instead.')

        model = MoLoSLlamaForCausalLM(
            config=config,
            molos_config=molos_config,
        )

    print(f'The {args.model} model loaded.')

    total_params = 0
    all_experts_params = 0
    selected_experts_params = 0
    for param_name, param in model.named_parameters():
        param_number = param.numel()

        if param.__class__.__name__ == 'Params4bit':
            param_number *= 2

        if 'experts' in param_name:
            all_experts_params += param_number

        total_params += param_number

    print(f'The model name: {args.model}')
    if args.model.lower() == 'molos':
        print(f'The base model name: {args.base_model}')

    print(f'Total parameters: {total_params:,}')

    if args.model.lower() == 'molos':
        experts_number = model.molos_config.ex_num
        selected_experts_number = model.molos_config.selected_ex_num

        selected_experts_params = int(all_experts_params / experts_number *
                                      selected_experts_number)

        print(f'All experts parameters: {all_experts_params:,}')
        print(f'Selected experts parameters: {selected_experts_params:,}')
        print(
            f'Total parameters without experts: {(total_params - all_experts_params):,}'
        )
