import os
import tempfile
from typing import Callable, Optional

from transformers import AutoConfig, LlamaConfig, LlamaForCausalLM

from mergekit.architecture import get_architecture_info
from mergekit.config import MergeConfiguration
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
from mergekit.merge import MergeOptions, run_merge


def run_and_check_merge(
    config: MergeConfiguration,
    check_nan: bool = True,
    check_tensors: bool = True,
    validate: Optional[Callable[[str], None]] = None,
    index_json_name: Optional[str] = None,
) -> None:
    """Run a merge into a temporary directory and sanity-check its output.

    The merge is executed with default ``MergeOptions``. The output directory
    is temporary and is removed when this function returns, so ``validate``
    must do all of its work before returning.

    Args:
        config: Merge configuration passed to ``run_merge``.
        check_nan: If true, load every tensor listed in the output index and
            assert that none of them contains NaN.
        check_tensors: If true, load the output ``config.json`` and verify
            that every weight reported by the architecture info is present
            in the output tensor index.
        validate: Optional callback invoked with the output directory path
            after the built-in checks have passed.
        index_json_name: File name of the index expected in the output
            directory. Defaults to ``"model.safetensors.index.json"``.

    Raises:
        AssertionError: If the index file or ``config.json`` is missing, or
            if any output tensor contains NaN.
        RuntimeError: If a weight expected by the architecture is missing
            from the output index.
    """
    if index_json_name is None:
        index_json_name = "model.safetensors.index.json"

    with tempfile.TemporaryDirectory() as tmpdir:
        run_merge(config, out_path=tmpdir, options=MergeOptions())
        assert os.path.exists(
            os.path.join(tmpdir, index_json_name)
        ), "No index file for merge"
        assert os.path.exists(
            os.path.join(tmpdir, "config.json")
        ), "No config json produced by merge"

        if check_nan:
            # check for NaN in output
            loader = LazyTensorLoader.from_disk(tmpdir, lazy_unpickle=False)
            tensor_paths = loader.index.tensor_paths
            # Visit tensors ordered by the path they map to in the index
            # (presumably so tensors from the same shard are read together).
            for tensor_name in sorted(tensor_paths, key=lambda k: tensor_paths[k]):
                tensor = loader.get_tensor(tensor_name)
                has_nan = tensor.view(-1).isnan().any()
                assert not has_nan, "Output contains NaN"

        if check_tensors:
            # Use a separate name so the ``config`` argument is not shadowed
            # by the Hugging Face model config loaded from the output.
            model_config = AutoConfig.from_pretrained(tmpdir)
            arch_info = get_architecture_info(model_config)

            index = ShardedTensorIndex.from_disk(tmpdir)
            for weight_info in arch_info.all_weights(model_config):
                if weight_info.name not in index.tensor_paths:
                    # Report the weight that is actually missing; this must
                    # not rely on variables from the NaN check above, which
                    # may not have run.
                    raise RuntimeError(
                        f"Output missing tensor {weight_info.name}"
                    )

        if validate:
            validate(tmpdir)


def make_picollama(path: str, vocab_size: int = 64):
    """Build a very small Llama causal LM and save it to ``path``.

    The model is instantiated directly from a ``LlamaConfig`` with 2 hidden
    layers, hidden size 32, intermediate size 48 and 16 attention heads, then
    written with ``save_pretrained`` using safetensors serialization.

    Args:
        path: Destination directory for the saved model.
        vocab_size: Vocabulary size to configure the model with.

    Returns:
        ``path`` converted to ``str``.
    """
    tiny_config = LlamaConfig(
        num_hidden_layers=2,
        num_attention_heads=16,
        intermediate_size=48,
        hidden_size=32,
        vocab_size=vocab_size,
    )
    LlamaForCausalLM(tiny_config).save_pretrained(path, safe_serialization=True)
    return str(path)
