import torch

def load_model(model_type, checkpoint_path, device, precision, use_tp, drafter_checkpoint_path=None, rank_group=None, group=None, use_tp_draft=False, rank_group_draft=None, group_draft=None):
    """Build a Transformer, load its weights, and return it in eval mode.

    The model is instantiated on the ``meta`` device, populated from
    ``checkpoint_path`` (and optionally a drafter checkpoint), optionally
    wrapped for tensor parallelism, and finally moved to ``device`` with
    dtype ``precision``.

    Args:
        model_type: Model family selector. Only ``"standard"`` is accepted;
            anything else raises ``ValueError``.
        checkpoint_path: Path object for the main checkpoint. The name of its
            parent directory is passed to ``Transformer.from_name``. If the
            path string contains ``"int8"``, the model is converted with
            ``WeightOnlyInt8QuantHandler`` before the weights are loaded.
        device: Target device handed to ``model.to``.
        precision: Target dtype handed to ``model.to``.
        use_tp: When true, ``apply_tp`` is run on the model with
            ``rank_group`` and ``group``.
        drafter_checkpoint_path: Optional drafter checkpoint. Only valid for
            model types whose name contains ``"longspec"`` or ``"eagle"``.
        rank_group: Forwarded to ``apply_tp``.
        group: Forwarded to ``apply_tp`` as ``group``.
        use_tp_draft: When true, tensor parallelism is requested for the
            draft model. This is a no-op (apart from the log line) for model
            types that are neither longspec nor eagle.
        rank_group_draft: Forwarded to ``apply_tp_eagle``.
        group_draft: Forwarded to ``apply_tp_eagle``.

    Returns:
        The loaded model, after ``.eval()``.

    Raises:
        ValueError: If ``model_type`` is not supported, or if a drafter
            checkpoint is supplied for a model type that has no drafter.
        NotImplementedError: If ``use_tp_draft`` is set for a longspec model.
    """
    if model_type != "standard":
        raise ValueError(f"Invalid model type: {model_type}")

    uses_glide = "longspec" in model_type
    uses_eagle = "eagle" in model_type

    # Fail fast: reject an unusable drafter checkpoint before building the
    # model or reading any checkpoint from disk, instead of after.
    if drafter_checkpoint_path is not None and not (uses_glide or uses_eagle):
        raise ValueError(f"Invalid model type: {model_type} with drafter checkpoint path: {drafter_checkpoint_path}")

    from spec_benchmark.Engine.models.standard_model import Transformer

    # Construct on the meta device; the weights are attached below via
    # load_state_dict(..., assign=True).
    with torch.device('meta'):
        # NOTE: this first branch cannot be taken while only "standard" passes
        # the model-type check above; it is kept for drafter-based variants.
        if model_type in ["longspec", "longspec_tree", "eagle", "eagle_chain"]:
            model = Transformer.from_name(checkpoint_path.parent.name, drafter_checkpoint_path.parent.name)
        else:
            model = Transformer.from_name(checkpoint_path.parent.name)

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from spec_benchmark.Engine.utils import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    # Checkpoints whose path contains "stories" keep the state dict under a
    # "model" key.
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    # strict=False: key mismatches between checkpoint and model are tolerated.
    model.load_state_dict(checkpoint, assign=True, strict=False)

    if drafter_checkpoint_path is not None:
        drafter_checkpoint = torch.load(str(drafter_checkpoint_path), mmap=True, weights_only=True)
        # The early validation guarantees one of the two drafter kinds here;
        # "longspec" takes precedence, as before.
        drafter_module = model.glide if uses_glide else model.eagle
        drafter_module.load_state_dict(drafter_checkpoint, assign=True)

    if use_tp:
        from spec_benchmark.Engine.utils import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model, rank_group, group=group)

    if use_tp_draft:
        from spec_benchmark.Engine.utils import apply_tp_eagle
        print("Applying tensor parallel to draft model ...")
        if uses_glide:
            # No tensor-parallel path exists for the longspec (glide) drafter.
            raise NotImplementedError("LongSpec is not supported for tensor parallel yet.")
        elif uses_eagle:
            apply_tp_eagle(model.eagle, rank_group_draft, group_draft)

    model = model.to(device=device, dtype=precision)
    return model.eval()