import torch
import transformers
def main():
    """Load a sparsity-wrapped causal LM, optionally quantize it, and print its perplexity.

    The model family is chosen by the ``MODEL_VERSION`` environment variable
    (``MY_LLAMA`` or ``MY_QWEN``); ``MY_LLAMA`` is used when the variable is
    not set. Command-line options come from ``utils.parser_gen()``.

    Raises:
        ValueError: If ``MODEL_VERSION`` holds an unsupported value.
    """
    import os

    # Use setdefault rather than plain assignment: an unconditional assignment
    # made the MY_QWEN branch below unreachable. The variable is still set
    # before the project modules are imported.
    # NOTE(review): presumably those modules read MODEL_VERSION at import
    # time - confirm.
    os.environ.setdefault('MODEL_VERSION', 'MY_LLAMA')

    import data_utils
    import eval_utils
    import utils

    args = utils.parser_gen()
    transformers.set_seed(args.seed)

    # Pick the wrapper class; everything after this is shared by both families.
    model_version = os.environ['MODEL_VERSION']
    if model_version == "MY_LLAMA":
        from sparsity_utils import SrLlamaForCausalLM as model_cls
    elif model_version == "MY_QWEN":
        from sparsity_utils import SrQwen2ForCausalLM as model_cls
    else:
        # Fail with a clear message instead of a NameError on `model` later.
        raise ValueError(
            f"Unsupported MODEL_VERSION {model_version!r}; "
            "expected 'MY_LLAMA' or 'MY_QWEN'"
        )

    model = model_cls.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
    ).cuda()

    model.act_wrapper()
    model.eval()

    import main_utils

    # Weight quantization is only added for bit-widths below 16.
    if args.w_bits < 16:
        main_utils.add_weight_quantization(model, args)

    # Input quantization is added when either a_bits or v_bits is below 16.
    if args.a_bits < 16 or args.v_bits < 16:
        main_utils.add_input_quantization(model, args)

    # Evaluating on dataset
    model.seqlen = 2048
    testloader = data_utils.get_loaders(
        args.eval_dataset,
        seed=args.seed,
        model=args.model,
        seqlen=model.seqlen,
        hf_token=args.hf_token,
        eval_mode=True,
    )

    dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)

    print(f'PPL {args.eval_dataset.upper()}: {dataset_ppl:.2f}')

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
