import os
import time
import torch
import logging

from lmms_eval.models import get_model

from transformers import AutoProcessor, AutoTokenizer
from vlmq.models import get_process_model
from vlmq.quantization.utils import utils, quant_utils, data_utils
from vlmq.quantization.gptq import gptq_utils, gptaq_utils
from vlmq.quantization.vlmq import vlmq_utils
import vlmq.utils.toolbox as toolbox
 

def main() -> None:
    """Quantize the weights of a multimodal model and save the result.

    Steps, in order:

    1. Parse CLI arguments via ``utils.parser_gen()`` and optionally start a
       wandb run.
    2. Build the lmms_eval model wrapper. When ``args.model_args`` contains
       ``'72B'`` or ``'32B'`` the model is created on CPU; otherwise on CUDA.
    3. For every method except ``'rtn'``, wrap the model in its process-model
       class and build the calibration dataloader.
    4. Add the activation-quant wrappers, move the model to CPU and clean up
       memory.
    5. If ``args.w_bits < 16``: load a quantized checkpoint or run the
       selected method (``rtn``, ``gptq``, ``gptaq``, ``vlmq``), then unwrap
       the model and save it, the arguments and the auxiliary files under
       ``args.save_path``. With ``args.w_bits >= 16`` nothing is quantized or
       saved.

    Raises:
        AssertionError: if ``args.load_qmodel_path`` is given without
            ``args.rotate`` or together with ``args.save_qmodel_path``.
        ValueError: if ``args.method`` is not a supported method.
    """
    args = utils.parser_gen()
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, entity=args.wandb_id)
        wandb.config.update(args)

    # Prepare the lmms_eval model wrapper. The 72B / 32B checkpoints are
    # created on CPU instead of CUDA.
    load_kwargs = {'batch_size': args.batch_size, 'device': 'cuda'}
    if '72B' in args.model_args or '32B' in args.model_args:
        load_kwargs.update({'device': 'cpu', 'device_map': {"": "cpu"}})
    lm = get_model(args.model).create_from_arg_string(args.model_args, load_kwargs)

    # RTN needs no calibration data, so the process model is only built for
    # the other methods.
    dataloader = None
    if args.method != 'rtn':
        # Preprocess the MLLM here, used to generate calib data
        process_model = get_process_model(args.model)(
            model=lm._model,
            tokenizer=lm._tokenizer,
            processor=getattr(lm, 'processor', None),
        )
        process_model.model.eval()

        # Prepare calib data; llava_onevision has a dedicated builder.
        if args.model == 'llava_onevision':
            build_calib = data_utils.get_multimodal_calib_dataset_llava
        else:
            build_calib = data_utils.get_multimodal_calib_dataset
        dataloader = build_calib(model=process_model, dev=utils.DEV, args=args)

        # Drop the wrapper inside this branch: `process_model` is only bound
        # here, so deleting it unconditionally would raise UnboundLocalError
        # for method == 'rtn'.
        del process_model

    model = lm._model
    model.eval()

    # Add fake wrapper
    quant_utils.add_actquant(model)

    # Offload model to CPU (optional)
    model.cpu()
    utils.cleanup_memory(verbos=True)

    # Nothing to quantize or save at 16 bits and above.
    if args.w_bits >= 16:
        return

    # Quantization method entry
    save_dict = {}
    if args.load_qmodel_path:  # Load Quantized Rotated Model
        assert args.rotate, "Model should be rotated to load a quantized model!"
        assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
        print("Load quantized model from ", args.load_qmodel_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        save_dict = torch.load(args.load_qmodel_path)
        model.load_state_dict(save_dict["model"], strict=False)
    elif args.method == 'rtn':
        save_dict["w_quantizers"] = gptq_utils.rtn_fwrd(model, utils.DEV, args)
    else:
        # Calibration-based methods share the same call signature.
        calibrated_methods = {
            'gptq': gptq_utils.gptq_fwrd,
            'gptaq': gptaq_utils.gptaq_fwrd,
            'vlmq': vlmq_utils.vlmq_fwrd,
        }
        if args.method not in calibrated_methods:
            raise ValueError(f"Unsupported quantization method: {args.method}")
        quantize = calibrated_methods[args.method]
        save_dict["w_quantizers"] = quantize(model, dataloader, utils.DEV, args)

    # Unwrap actquant before saving
    quant_utils.unwrap_actquant(model)

    # Save quantized model together with the run arguments
    toolbox.save_args(args=args, tgt_path=f'{args.save_path}/logs')
    toolbox.save_model(model=model, tgt_path=args.save_path)

    # The original checkpoint path is the value of the first comma-separated
    # entry of model_args (the text after its last '=').
    org_path = args.model_args.split(',')[0].split('=')[-1]
    toolbox.copy_auxiliary_file(org_path=org_path, tgt_path=args.save_path)
            

    
if __name__ == "__main__":
    # Script entry point: run the pipeline and report wall-clock time in hours.
    # `time` is already imported at module level.
    started_at = time.time()
    main()
    elapsed_hours = (time.time() - started_at) / 3600
    print(f"Total quantization time: {elapsed_hours} hours")
