import torch
from transformers import AutoModelForCausalLM
from BT_MoE.models.hf.deepseek import DeepSeekMoEBTMoE as AutoBTMoEHFModel
from BT_MoE.core.quantize import *


def main(
    model_id: str = "deepseek-ai/deepseek-moe-16b-base",
    save_dir: str = "",
    device: str = "cuda",
) -> None:
    """Compress a pretrained causal LM with BT-MoE and save the result.

    Builds a ``BaseCompressConfig``, loads ``model_id`` in float16 through
    ``AutoModelForCausalLM.from_pretrained``, runs
    ``AutoBTMoEHFModel.compress_model`` on it and finally calls
    ``AutoBTMoEHFModel.save_compressed``.

    Calling ``main()`` without arguments reproduces the original hard-coded
    behaviour of this script.

    Args:
        model_id: Hugging Face model identifier or local path handed to
            ``from_pretrained``. The compressor class imported by this file
            is the DeepSeek-MoE wrapper, so other architectures are
            presumably unsupported -- TODO confirm.
        save_dir: Destination passed to ``save_compressed``.
            NOTE(review): the default is an empty string, which looks like a
            placeholder; set a real directory before running.
        device: Device string forwarded to ``compress_model``.
    """
    # BaseCompressConfig is presumably supplied by the star import from
    # BT_MoE.core.quantize -- TODO confirm.
    compress_config = BaseCompressConfig(
        # quantization config
        nbits=3,
        group_size=64,
        quant_scale=False,
        quant_zero=False,
        axis=1,
        # compensator config
        iter=10,
        sparse_rank=16,
        dense_rank=512,
        rank_strategy=None,
        compensator_dtype="int3",
    )

    # trust_remote_code=True lets transformers execute the modeling code that
    # ships with the checkpoint repository; only use it with trusted sources.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    AutoBTMoEHFModel.compress_model(
        model,
        compress_config=compress_config,
        device=device,
    )
    AutoBTMoEHFModel.save_compressed(model, save_dir)




# Run the compression pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

