from dataclasses import dataclass, field
import os
from os.path import isdir, isfile
from pathlib import Path
import sys
from typing import Optional

from transformers import AutoTokenizer


@dataclass
class GptqConfig:
    """Settings for loading a GPTQ-quantized checkpoint.

    Each field carries a ``metadata["help"]`` string; presumably these are
    surfaced by a command-line argument parser elsewhere (TODO confirm).
    """

    # Path to a GPTQ checkpoint file, or a directory that find_gptq_ckpt
    # searches for *.pt / *.safetensors files. None means "not set", hence
    # Optional[str] rather than str.
    ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local GPTQ checkpoint."
        },
    )
    # Quantization bit width.
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    # Quantization group size; -1 means the full row is one group.
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
    # Whether to use the activation-order heuristic. load_gptq_quantized only
    # forwards this to load_quant when it is truthy.
    act_order: bool = field(
        default=True,
        metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
    )


def load_gptq_quantized(model_name, gptq_config: GptqConfig):
    """Load a GPTQ-quantized model and its tokenizer.

    ``load_quant`` is imported from a GPTQ-for-LLaMa checkout located at
    ``<parent of this file's directory>/../repositories/GPTQ-for-LLaMa``.
    That directory is prepended to ``sys.path`` (once) before the import.

    Args:
        model_name: Model name or path. It is passed to both
            ``AutoTokenizer.from_pretrained`` and ``load_quant``.
        gptq_config: Quantization settings. ``ckpt`` is resolved to a concrete
            checkpoint file via ``find_gptq_ckpt``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Exits the process with ``sys.exit(-1)`` if GPTQ-for-LLaMa cannot be
    imported. ``find_gptq_ckpt`` exits with status 1 if no checkpoint is
    found.
    """
    print("Loading GPTQ quantized model...")

    script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa")
    # Guard against prepending the same entry again on every call.
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

    try:
        from llama import load_quant
    except ImportError as e:
        print(f"Error: Failed to load GPTQ-for-LLaMa. {e}")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
        sys.exit(-1)

    # Resolve the checkpoint before loading the tokenizer, so that a missing
    # checkpoint is reported before any tokenizer loading work happens.
    ckpt_path = find_gptq_ckpt(gptq_config)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Only the `fastest-inference-4bit` branch cares about `act_order`, so the
    # keyword is forwarded only when it is enabled. Other branches are called
    # without it.
    extra_kwargs = {"act_order": gptq_config.act_order} if gptq_config.act_order else {}
    model = load_quant(
        model_name,
        ckpt_path,
        gptq_config.wbits,
        gptq_config.groupsize,
        **extra_kwargs,
    )

    return model, tokenizer


def find_gptq_ckpt(gptq_config: GptqConfig):
    """Resolve ``gptq_config.ckpt`` to a concrete checkpoint file path.

    If ``ckpt`` names an existing file, it is returned unchanged. Otherwise
    ``ckpt`` is treated as a directory and searched non-recursively: first for
    ``*.pt`` files, then for ``*.safetensors`` files. The match whose path
    sorts last is returned, taken from the first pattern with any match.

    Args:
        gptq_config: Config whose ``ckpt`` field holds a file or directory path.

    Returns:
        The checkpoint path as a ``str``.

    Exits the process with status 1 if ``ckpt`` is unset (``None``) or if no
    checkpoint file is found.
    """
    if gptq_config.ckpt is None:
        # ckpt defaults to None, and Path(None) would raise an unhandled
        # TypeError. Report it through the same error path as a missing file.
        print("Error: gptq checkpoint not found (no checkpoint path was given)")
        sys.exit(1)

    ckpt = Path(gptq_config.ckpt)
    if ckpt.is_file():
        return gptq_config.ckpt

    for pattern in ("*.pt", "*.safetensors"):
        matches = sorted(ckpt.glob(pattern))
        if matches:
            # Sorting makes the choice deterministic when several files match.
            return str(matches[-1])

    print("Error: gptq checkpoint not found")
    sys.exit(1)
