import datasets
import datetime
import deepspeed
import inspect
import json
import os
import time
import torch
import torch.distributed as dist

from ds_src.infer import infer
from ds_src.initialize import initialize
from ds_src.logging import get_logger
from ds_src.models import (
    MoLoSConfig,
    MoSLEConfig,
)
from ds_src.SVDLLM import (
    eff_eval,
    ppl_eval,
    qa_load_balance_eval,
)
from ds_src.train import train
from ds_src.utils import (
    build_dirs,
    build_replications,
    get_args,
    get_path,
    get_yaml_config,
    handle_non_serializable_object,
    set_seed,
    set_version,
)
from lm_eval import simple_evaluate
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import TaskManager

if __name__ == '__main__':
    # Parse the command-line arguments.
    args = get_args()

    # Pin this process to the CUDA device whose index is its local rank.
    device = torch.device(device=args.local_rank)
    torch.cuda.set_device(device=device)

    # Bring up the DeepSpeed distributed environment.
    deepspeed.init_distributed()

    # Optional debugging aid: enable autograd anomaly detection to check
    # whether the gradients are normal.
    # torch.autograd.set_detect_anomaly(
    #     mode=True,
    #     check_nan=True,
    # )

    # Optional: use a longer timeout for distributed training.
    # dist.init_process_group(
    #     backend='nccl',
    #     timeout=datetime.timedelta(seconds=7200),
    # )

    # Make every process agree on the same version.
    args = set_version(args=args)

    # Load the YAML configuration that belongs to this version.
    config = get_yaml_config(config_path=args.config, version=args.version)

    if args.local_rank == 0:
        # Create the output sub-directories used by the run.
        outputs_dir = config['outputs_dir']
        replications_dir = os.path.join(outputs_dir, 'replications')
        output_dirs = {
            'cache': os.path.join(outputs_dir, 'caches'),
            'checkpoint': os.path.join(outputs_dir, 'checkpoints'),
            'log': os.path.join(outputs_dir, 'logs'),
            'plot': os.path.join(outputs_dir, 'plots'),
            'replication': replications_dir,
        }
        build_dirs(path_dict=output_dirs)

        # Collect every configuration file so the run can be reproduced.
        model_config_path = args.model_config
        if model_config_path is None:
            # No explicit model configuration was given: fall back to the
            # source file that defines the configuration class of the
            # selected model type. Other model types keep `None`.
            if args.model_type == 'molos':
                model_config_path = inspect.getfile(object=MoLoSConfig)
            elif args.model_type == 'mosle':
                model_config_path = inspect.getfile(object=MoSLEConfig)

        build_replications(
            path_dict={
                'ds_config': args.deepspeed_config,
                'lora_config': args.lora_config,
                'model_config': model_config_path,
                'task_config': args.config,
            },
            outputs_dir=replications_dir,
        )

    # Initialize the multi-process logger.
    # Only local rank 0 builds the output directories, so the other ranks can
    # presumably get here before those directories exist. They retry once per
    # second until the logger can be created.
    while True:
        try:
            logger = get_logger(
                outputs_dir=config['outputs_dir'],
                root_path=get_path(
                    source_file=__file__,
                    return_dir=True,
                ),
            )
        except FileNotFoundError:
            if args.local_rank == 0:
                # Local rank 0 has already built the directories itself, so a
                # missing path is a real error. Re-raise the active exception
                # (instead of a fresh, empty one) to keep its message and
                # traceback.
                raise

            time.sleep(1)
        else:
            break

    # Tag passed as the `source` of every log record emitted by this script.
    source = f'{get_path(source_file=__file__)}.{__name__}'

    # Record the run version and the parsed arguments, in that order.
    for startup_message in (
            f'Version: {args.version}',
            f'Arguments: {args}',
    ):
        logger.log(message=startup_message, source=source)

    # Apply the seed given by the configuration.
    set_seed(seed=config['seed'])

    match args.mode:
        case 'chat' | 'infer':
            # Initialize the necessary components.
            model, tokenizer, _, dataloader, _, _ = initialize(
                args=args,
                config=config,
            )

            # Infer the model.
            infer(
                model=model,
                tokenizer=tokenizer,
                dataloader=dataloader,
                config=config,
                is_chat=True if args.mode == 'chat' else False,
            )
        # If the current model is a member of the Gemma family, the `batch_size` should be set to half of its original value.
        case 'test':
            datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

            # Initialize the necessary components.
            model, tokenizer, _, _, _, _ = initialize(
                args=args,
                config=config,
            )

            model.eval()
            model = model.to(dtype=torch.bfloat16)

            sequence_length = 2048
            if hasattr(model, 'molos_config'):
                sequence_length = model.molos_config.sequence_length

            results = {}

            # _results = eff_eval(
            #     model=model,
            #     tokenizer=tokenizer,
            #     dataset='wikitext2',
            #     sequence_length=512,
            #     generated_sequence_length=64,
            #     batch_size=16,
            #     device=device,
            # )
            # results.update(_results)

            LM = HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                device=device,
                batch_size=16,
            )
            task_manager = TaskManager()

            _results = simple_evaluate(
                model=LM,
                tasks=[
                    'arc_challenge',
                    'arc_easy',
                    'boolq',
                    'hellaswag',
                    'logiqa',
                    'mmlu',
                    'openbookqa',
                    'piqa',
                    'sciq',
                    'winogender',
                ],
                num_fewshot=0,
                device=device,
                task_manager=task_manager,
                confirm_run_unsafe_code=True,
            )
            results.update(_results['results'])

            _results, chosen_experts_counts = ppl_eval(
                model=model,
                tokenizer=tokenizer,
                datasets=[
                    'c4',
                    'wikitext2',
                ],
                sequence_length=sequence_length,
                batch_size=8,
                device=device,
            )
            results.update(_results)

            results_str = json.dumps(
                obj=results,
                ensure_ascii=False,
                indent=4,
                default=handle_non_serializable_object,
                sort_keys=True,
            )

            logger.log(
                message=results_str,
                source=source,
            )

            try:
                for dataset, chosen_expert_count in \
                        chosen_experts_counts.items():
                    logger.log(
                        message=
                        f'The expert count results of \"{dataset}\" dataset.',
                        source=source,
                    )
                    for layer_idx in range(len(chosen_expert_count) - 1):
                        logger.log(
                            message=
                            f'Layer {layer_idx}: {chosen_expert_count[layer_idx]}',
                            source=source,
                        )

                    logger.log(
                        message=f'Total: {chosen_expert_count[-1]}',
                        source=source,
                    )
            except:
                logger.log(
                    message='The expert count results are not available.',
                    level='warning',
                    source=source,
                )

            chosen_experts_counts = qa_load_balance_eval(
                model=model,
                tokenizer=tokenizer,
                datasets=['arc_challenge'],
                sequence_length=sequence_length,
                batch_size=8,
                device=device,
            )

            try:
                for dataset, chosen_expert_count in \
                        chosen_experts_counts.items():
                    logger.log(
                        message=
                        f'The expert count results of \"{dataset}\" dataset.',
                        source=source,
                    )
                    for layer_idx in range(len(chosen_expert_count) - 1):
                        logger.log(
                            message=
                            f'Layer {layer_idx}: {chosen_expert_count[layer_idx]}',
                            source=source,
                        )

                    logger.log(
                        message=f'Total: {chosen_expert_count[-1]}',
                        source=source,
                    )
            except:
                logger.log(
                    message='The expert count results are not available.',
                    level='warning',
                    source=source,
                )
        case 'train':
            # Initialize the necessary components.
            model, _, optimizer, dataloader, scheduler, client_state = \
                initialize(
                    args=args,
                    config=config,
                )

            trained_epochs = -1
            trained_steps = -1
            finishing_this_epoch = False

            if client_state:
                trained_epochs = client_state['trained_epochs']
                trained_steps = client_state['trained_steps']
                finishing_this_epoch = client_state['finishing_this_epoch']

            # Train the model.
            train(
                model=model,
                optimizer=optimizer,
                dataloader=dataloader,
                config=config,
                scheduler=scheduler,
                trained_epochs=trained_epochs,
                trained_steps=trained_steps,
                finishing_this_epoch=finishing_this_epoch,
            )
        case _:
            message = 'The mode is not recognized.'
            logger.log(
                message=message,
                level='error',
                source=source,
            )

            raise ValueError(message)
