import torch
import transformers
import itertools
from dataclasses import dataclass

from argparse_dataclass import ArgumentParser
from auto_gptq import BaseQuantizeConfig
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, LlamaConfig

from quick_extend.dataset.redpajama import RedPajamaDataset
from quick_extend.models.load_model import ModelConfig, load_model


@dataclass
class QuantizeConfig(ModelConfig):
    """Command-line options for the GPTQ quantization script.

    Extends ``ModelConfig`` (presumably the options consumed by
    ``load_model``) with the destination for the quantized model.
    """

    # Destination passed to ``model.save_quantized``; None means "not given".
    # NOTE(review): annotated ``str`` but defaults to None. Left unchanged
    # because argparse_dataclass inspects this annotation at runtime to build
    # the CLI option, so editing it could change argument parsing.
    output_path: str = None


def parse_args():
    """Parse command-line options into a ``QuantizeConfig``.

    The parsed config is printed so the run's settings appear in the log.

    Returns:
        QuantizeConfig: the parsed options.

    Exits via ``parser.error`` (argparse's usage error, exit status 2) when
    ``output_path`` is missing or empty. The script therefore fails before
    the expensive model load and quantization, not afterwards when the
    result cannot be saved.
    """
    parser = ArgumentParser(QuantizeConfig)
    config = parser.parse_args()
    print(config)
    # output_path defaults to None in QuantizeConfig. Without this check,
    # the failure would only surface inside save_quantized after quantization.
    if not config.output_path:
        parser.error("output_path is required: where to save the quantized model")
    return config


def main(*, seq_len: int = 4096, num_calibration_examples: int = 256):
    """Quantize the configured causal LM with GPTQ and save it as safetensors.

    Loads the model and tokenizer described by the command-line config and
    builds calibration examples from ``RedPajamaDataset``. It then runs
    ``model.quantize`` on them and writes the result to ``config.output_path``.

    Args:
        seq_len: Second argument passed to ``RedPajamaDataset``. Presumably
            the token length of each example; confirm against the dataset.
        num_calibration_examples: Maximum number of dataset items used as
            calibration examples.

    Raises:
        ValueError: If the dataset yields no items, because quantization
            needs at least one calibration example.
    """
    config = parse_args()

    model, tokenizer = load_model(config, for_gptq=True, for_training=False)

    dataset = RedPajamaDataset(tokenizer, seq_len)
    # Assumes each dataset item is indexable and item[0] is the input_ids
    # tensor; TODO confirm against RedPajamaDataset. The all-ones mask marks
    # every token of the example as attended.
    examples = [
        {"input_ids": item[0], "attention_mask": torch.ones_like(item[0])}
        for item in itertools.islice(dataset, num_calibration_examples)
    ]
    if not examples:
        raise ValueError("RedPajamaDataset yielded no calibration examples")

    # The calibration examples must be a list of dicts whose only keys are
    # "input_ids" and "attention_mask".
    model.quantize(examples)

    # Save the quantized model in safetensors format.
    model.save_quantized(config.output_path, use_safetensors=True)


if __name__ == "__main__":
    # Run the quantization only when executed as a script, not on import.
    main()
