import os
import random

import numpy as np
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
from transformers import AutoTokenizer, TextGenerationPipeline

# Model identifier passed to the tokenizer and to AutoGPTQForCausalLM.from_pretrained.
pretrained_model_dir = "gpt2-xl"
# Directory the quantized model is saved to and later reloaded from.
# NOTE(review): the name says "gpt2-large" while the source model above is
# "gpt2-xl" -- presumably a leftover; confirm before renaming the output path.
quantized_model_dir = "gpt2-large-4bit-128g"



def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build calibration samples from WikiText-2 and return the tokenized test split.

    Args:
        nsamples: Number of random token windows to draw from the train split.
        seed: Value used to seed Python's ``random``, NumPy and torch RNGs.
        seqlen: Length, in tokens, of every sampled window.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"`` and
            expected to return an object exposing ``input_ids``.

    Returns:
        A tuple ``(samples, testenc)``. ``samples`` is a list of ``nsamples``
        dicts, each holding ``"input_ids"`` (the slice
        ``input_ids[:, start:start + seqlen]`` of the tokenized train split)
        and an all-ones ``"attention_mask"`` of the same shape. ``testenc`` is
        the tokenizer output for the whole test split.
    """
    # Seed all three RNGs so the sampled windows are reproducible.
    for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_rng(seed)

    split_names = ("train", "test")
    # Fetch both raw splits first, then tokenize each as one long document
    # whose rows are joined by blank lines.
    raw_splits = [load_dataset("wikitext", "wikitext-2-raw-v1", split=name) for name in split_names]
    train_tokens, testenc = [
        tokenizer("\n\n".join(raw["text"]), return_tensors="pt") for raw in raw_splits
    ]

    def _sample_window():
        # randint is inclusive on both ends; the start is chosen so that the
        # window [start, start + seqlen) stays inside the token sequence.
        start = random.randint(0, train_tokens.input_ids.shape[1] - seqlen - 1)
        window = train_tokens.input_ids[:, start : start + seqlen]
        return {"input_ids": window, "attention_mask": torch.ones_like(window)}

    return [_sample_window() for _ in range(nsamples)], testenc


def main():
    """Quantize the pretrained model to 4-bit GPTQ, save it, reload it and generate text.

    Steps, in order:
      1. Load the tokenizer for ``pretrained_model_dir``.
      2. Load the unquantized model together with the quantization config.
      3. Read the model's sequence length from its config (2048 if absent).
      4. Build 128 WikiText-2 calibration samples and quantize the model.
      5. Save the quantized model to ``quantized_model_dir`` (twice: once with
         the library's default settings, once with ``use_safetensors=True``).
      6. Reload the quantized model onto ``cuda:0`` and print two generations
         for the prompt ``"test is"``.

    Requires a CUDA device, since the reloaded model and inputs are placed on
    ``cuda:0``.
    """
    device = "cuda:0"

    # Try the slow tokenizer first and fall back to the fast one if it cannot
    # be loaded. Catch ``Exception`` rather than using a bare ``except`` so
    # that KeyboardInterrupt / SystemExit are not swallowed by the fallback.
    try:
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=False)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

    quantize_config = BaseQuantizeConfig(
        bits=4,  # quantize weights to 4 bits
        group_size=128,
        desc_act=False,
    )

    # Load the un-quantized model; the config is attached for the later quantize() call.
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

    # Different architectures store the context length under different config
    # keys; take the first one present, in this priority order.
    model_config = model.config.to_dict()
    for key in ("max_position_embeddings", "seq_length", "n_positions"):
        if key in model_config:
            model.seqlen = model_config[key]
            break
    else:
        print("can't get model's sequence length from model config, will set to 2048.")
        model.seqlen = 2048

    # Only the calibration samples are needed here; the tokenized test split
    # returned alongside them is not used by this script.
    traindataset, _ = get_wikitext2(128, 0, model.seqlen, tokenizer)

    model.quantize(traindataset, use_triton=False)

    # Save with the library's default settings ...
    model.save_quantized(quantized_model_dir)

    # ... and again explicitly in safetensors format.
    model.save_quantized(quantized_model_dir, use_safetensors=True)

    # Reload the quantized weights directly onto the GPU.
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=device, use_triton=False)

    # Smoke test 1: call model.generate directly.
    prompt = tokenizer("test is", return_tensors="pt").to(device)
    print(tokenizer.decode(model.generate(**prompt)[0]))

    # Smoke test 2: same prompt through a transformers text-generation pipeline.
    pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)
    print(pipeline("test is")[0]["generated_text"])


if __name__ == "__main__":
    import logging

    # Configure the root logger at INFO level with a timestamped
    # "<time> <level> [<logger name>] <message>" layout before running.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    main()
