from llmfoundry.callbacks import HuggingFaceCheckpointer

import contextlib
import copy
import logging
import math
import os
import re
import shutil
import tempfile
import time
import warnings
from multiprocessing.context import SpawnProcess
from pathlib import Path
from typing import Any, Optional, Sequence, Union
import json

import numpy as np
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.devices import Device
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
    dist,
    format_name_with_dist_and_time,
    maybe_create_remote_uploader_downloader_from_uri,
    parse_uri,
)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.exceptions import StoragePermissionError
from llmfoundry.utils.huggingface_hub_utils import \
    edit_files_for_hf_compatibility

# transformer_engine is an optional dependency; remember whether it imported
# so callers can decide whether its export context is available.
try:
    import transformer_engine.pytorch as te
    is_te_imported = True
except ModuleNotFoundError:
    is_te_imported = False

log = logging.getLogger(__name__)

# Export the checkpointer defined in this module in addition to the
# re-exported base class, so star-imports expose both.
__all__ = ['HuggingFaceCheckpointer', 'HQHFCheckpointer']

# Case-insensitive match for file names starting with "license", optionally
# followed by an alphabetic extension.
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)

from contextlib import contextmanager

def find_hq_config(model):
    """Locate an ``hq_config`` attribute on ``model`` or on an object it wraps.

    Wrapped objects are followed through the ``module`` and ``model``
    attributes, depth-first, with ``module`` explored before ``model``. Both
    attributes are searched on every object, so a config reachable only through
    ``model`` is still found when a ``module`` attribute also exists. Objects
    that were already visited are skipped, so reference cycles terminate.

    Args:
        model: The object to inspect; may be ``None``.

    Returns:
        The first non-``None`` ``hq_config`` encountered, or ``None`` when no
        reachable object has one.
    """
    pending = [model]
    visited = set()
    while pending:
        current = pending.pop()
        if current is None or id(current) in visited:
            continue
        visited.add(id(current))

        config = getattr(current, 'hq_config', None)
        if config is not None:
            return config

        # The stack is LIFO: push ``model`` first so ``module`` is popped (and
        # therefore searched) first.
        for attr_name in ('model', 'module'):
            child = getattr(current, attr_name, None)
            if child is not None:
                pending.append(child)
    return None


class HQHFCheckpointer(HuggingFaceCheckpointer):
    """HuggingFace checkpointer that also stores ``hq_config`` and whitening weights.

    In addition to the files written by ``save_pretrained``, every checkpoint
    directory receives ``whitenings.pt`` (the parameters whose name contains
    ``'whitening'``) and, when ``find_hq_config`` finds one on the trained
    model, ``hq_config.json``.
    """

    def _get_hf_model(self, state: State):
        """Gather the full model weights and rebuild an unsharded model copy.

        Every rank gathers the state dict; only global rank 0 builds the new
        model instance from it.

        Args:
            state (State): The training state.

        Returns:
            A ``(model, tokenizer, whitening_state_dict)`` tuple. On global
            rank 0, ``model`` is the rebuilt model and ``tokenizer`` the
            tokenizer, both as returned by ``transform_model_and_tokenizer``,
            and ``whitening_state_dict`` maps every parameter name containing
            ``'whitening'`` to its data. On all other ranks ``model`` and
            ``whitening_state_dict`` are ``None`` and ``tokenizer`` is the
            tokenizer of the trained model, untransformed.
        """
        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        log.debug('Gathering state dict')

        # Pick the underlying HF model, the module to gather the state dict
        # from, and the tokenizer, depending on how the model is wrapped.
        if state.is_model_ddp:
            original_model: PreTrainedModel = state.model.module.model
            state_dict_model = state.model.module.model
            original_tokenizer = state.model.module.tokenizer
        elif isinstance(state.model.model, FSDP):
            original_model: PreTrainedModel = state.model.model.module
            state_dict_model = state.model.model
            original_tokenizer = state.model.tokenizer
        else:
            original_model: PreTrainedModel = state.model.model
            state_dict_model = state.model.model
            original_tokenizer = state.model.tokenizer

        cpu_offload = True

        # Add hook to move tensors to cpu to avoid CUDA OOM
        def tensor_hook(
            module: nn.Module,
            state_dict: dict[str, Any],
            prefix: str,
            *args: Any,
        ) -> dict[str, Any]:
            for fqn in state_dict.keys():
                tensor = state_dict[fqn]
                if isinstance(tensor, DTensor):
                    # full_tensor() is called on every rank, even though only
                    # rank 0 keeps the result.
                    tensor = tensor.full_tensor()  # type: ignore
                    if dist.get_global_rank() == 0:
                        # Offload any DTensors to CPU
                        if cpu_offload:
                            tensor = tensor.cpu()
                        state_dict[fqn] = tensor
                    else:
                        state_dict[fqn] = None

                if isinstance(state_dict[fqn], torch.Tensor):
                    state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
                del tensor
            if dist.get_global_rank() != 0:
                state_dict = {}
            return state_dict

        hooks = []
        try:
            for _, module in state_dict_model.named_modules():
                hooks.append(module._register_state_dict_hook(tensor_hook))

            state_dict = get_model_state_dict(
                state_dict_model,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=cpu_offload,
                ),
            )
        finally:
            # Always detach the hooks, even if gathering fails, so they do not
            # stay registered on the training model.
            for hook in hooks:
                hook.remove()

        new_model_instance = None  # Need this for pyright because variable could be unbound
        whitening_state_dict = None
        if dist.get_global_rank() == 0:
            log.debug('Saving Hugging Face checkpoint in global rank 0')

            # Transform HF config before building 2nd model copy
            new_config = self.transform_config(
                original_config=original_model.config,
            )

            log.debug('Creating new model instance')

            # First create the model instance on meta device to avoid the
            # initialization cost.
            with init_empty_weights():
                if self.using_peft:
                    active_adapter = original_model.active_adapter
                    base_model = original_model.get_base_model()
                    new_base_model_instance = type(base_model)(new_config)

                    new_model_instance = type(original_model)(
                        new_base_model_instance,
                        original_model.peft_config[active_adapter],
                    )
                    del new_base_model_instance
                else:
                    new_model_instance = type(original_model)(new_config)
                    if new_model_instance.generation_config is not None:
                        new_model_instance.generation_config.update(
                            **original_model.generation_config.to_dict(),
                        )

            new_model_instance.load_state_dict(state_dict, assign=True)
            del state_dict

            # Collect the whitening parameters before the model is transformed.
            whitening_state_dict = {
                name: param.data
                for name, param in new_model_instance.named_parameters()
                if 'whitening' in name
            }

            # Transform the model and tokenizer before saving
            new_model_instance, original_tokenizer = self.transform_model_and_tokenizer(
                new_model_instance,
                original_tokenizer,
            )

            # Ensure that the pretrained model name is correctly set on the saved HF checkpoint.
            if self.pretrained_model_name is not None:
                new_model_instance.name_or_path = self.pretrained_model_name
                if self.using_peft:
                    new_model_instance.base_model.name_or_path = self.pretrained_model_name
                    for k in new_model_instance.peft_config.keys():
                        new_model_instance.peft_config[
                            k
                        ].base_model_name_or_path = self.pretrained_model_name

            log.debug('Saving Hugging Face checkpoint to disk')

        return new_model_instance, original_tokenizer, whitening_state_dict

    def _save_checkpoint(
        self,
        state: State,
        logger: Logger,
        upload_to_save_folder: bool,
        register_to_mlflow: bool,
    ):
        """Save a HuggingFace formatted checkpoint.

        Besides the HuggingFace files, global rank 0 writes ``whitenings.pt``
        and, if the model exposes an ``hq_config``, ``hq_config.json`` into
        the checkpoint directory. All file writes happen on global rank 0
        only; the other ranks only take part in gathering the weights and in
        the barriers.

        Args:
            state (State): The training state.
            logger (Logger): The logger.
            upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
            register_to_mlflow (bool): Whether to register the model to MLFlow
        """
        del logger  # unused

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr,
            ),
            state.run_name,
            state.timestamp,
        )

        is_rank_zero = dist.get_global_rank() == 0

        # Use a temporary directory if save_dir is remote. Only rank 0 writes
        # files, so only rank 0 creates (and later removes) the directory.
        use_temp_dir = self.remote_ud is not None
        temp_save_dir = tempfile.mkdtemp(
        ) if use_temp_dir and is_rank_zero else save_dir

        new_model_instance, original_tokenizer, whitening_state_dict = self._get_hf_model(
            state,
        )

        if is_rank_zero:
            os.makedirs(temp_save_dir, exist_ok=True)

            hq_config = find_hq_config(state.model)
            if hq_config is not None:
                hq_config_path = os.path.join(temp_save_dir, 'hq_config.json')
                with open(hq_config_path, 'w') as f:
                    json.dump(hq_config, f)

            # whitening_state_dict is None on the other ranks; writing it from
            # there could overwrite rank 0's file with None.
            whitening_path = os.path.join(temp_save_dir, 'whitenings.pt')
            torch.save(whitening_state_dict, whitening_path)

        dist.barrier()

        if is_rank_zero:
            assert new_model_instance is not None
            if upload_to_save_folder:
                # This context manager casts the TE extra state in io.BytesIO format to tensor format
                # Needed for proper hf ckpt saving.
                if is_te_imported and state.precision == Precision.AMP_FP8:
                    context_manager = te.onnx_export(True)
                else:
                    context_manager = contextlib.nullcontext()
                with context_manager:
                    new_model_instance.save_pretrained(
                        temp_save_dir,
                        max_shard_size='1GB',
                    )
                if original_tokenizer is not None:
                    assert isinstance(
                        original_tokenizer,
                        PreTrainedTokenizerBase,
                    )
                    original_tokenizer.save_pretrained(temp_save_dir)

                # Only need to edit files for MPT because it has custom code
                if new_model_instance.config.model_type == 'mpt':
                    log.debug('Editing MPT files for HuggingFace compatibility')
                    edit_files_for_hf_compatibility(
                        temp_save_dir,
                        self.flatten_imports,
                    )

                if self.remote_ud is not None:
                    # Upload every file in the checkpoint directory, including
                    # the extra hq_config / whitening files written above.
                    for filename in os.listdir(temp_save_dir):
                        remote_file_name = os.path.join(save_dir, filename)
                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
                            remote_file_name,
                        )
                        log.info(
                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
                        )
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=remote_file_name,
                            file_path=Path(
                                os.path.join(temp_save_dir, filename),
                            ),
                            overwrite=self.overwrite,
                        )

        dist.barrier()

        if is_rank_zero:
            assert new_model_instance is not None
            if self.using_peft:
                model_name = self.mlflow_logging_config.get('metadata', {}).get(
                    'pretrained_model_name',
                    None,
                )
                if model_name is not None:
                    new_model_instance.name_or_path = model_name
                    new_model_instance.model.name_or_path = model_name
                    new_model_instance.base_model.name_or_path = model_name
            if register_to_mlflow:
                self._register_hf_model(
                    temp_save_dir,
                    original_tokenizer,
                    use_temp_dir,
                    new_model_instance,
                )
            else:
                # Clean up the temporary directory if we don't need to register to mlflow.
                if use_temp_dir:
                    shutil.rmtree(temp_save_dir)
        dist.barrier()