import utils
import torch
import model_utils
import data_utils
import transformers
import quant_utils
import rotation_utils
import gptq_utils
import eval_utils
import hadamard_utils

# Tasks evaluated by the lm-eval stage when no task names are supplied.
_DEFAULT_LM_EVAL_TASKS = ('arc_easy', 'arc_challenge', 'piqa', 'winogrande')


def _setup_online_hadamard(model, args):
    """Attach Hadamard settings to the down_proj / o_proj quantization wrappers.

    Only attributes are set here; how the wrappers consume them is decided by
    ``quant_utils`` (not visible from this file).
    """
    qlayers = quant_utils.find_qlayers(model)
    for name in qlayers:
        qlayer = qlayers[name]
        if 'down_proj' in name:
            had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
            qlayer.online_full_had = True
            qlayer.had_K = had_K
            qlayer.K = K
            qlayer.fp32_had = args.fp32_had
        if 'o_proj' in name:
            had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
            qlayer.online_partial_had = True
            qlayer.had_K = had_K
            qlayer.K = K
            # Per-head dimension: hidden size split evenly across the heads.
            qlayer.had_dim = model.config.hidden_size // model.config.num_attention_heads
            qlayer.fp32_had = args.fp32_had


def _quantize_weights(model, args):
    """Load, or compute (GPTQ / RTN), quantized weights and optionally save them.

    Raises:
        AssertionError: if a quantized model is loaded without ``args.rotate``,
            if both load and save paths are given, or if GPTQ is requested
            for a model whose name does not contain "llama".
    """
    save_dict = {}
    if args.load_qmodel_path:  # Load Quantized Rotated Model
        assert args.rotate, "Model should be rotated to load a quantized model!"
        assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
        print("Load quantized model from ", args.load_qmodel_path)
        # NOTE(security): torch.load unpickles the checkpoint, so only load
        # files from a trusted source.
        save_dict = torch.load(args.load_qmodel_path)
        model.load_state_dict(save_dict["model"])
    elif not args.w_rtn:  # GPTQ Weight Quantization
        assert "llama" in args.model, "Only llama is supported for GPTQ!"
        trainloader = data_utils.get_loaders(
            args.cal_dataset, nsamples=args.nsamples,
            seed=args.seed, model=args.model,
            seqlen=model.seqlen, eval_mode=False
        )
        save_dict["w_quantizers"] = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args)
    else:  # RTN Weight Quantization
        save_dict["w_quantizers"] = gptq_utils.rtn_fwrd(model, utils.DEV, args)

    if args.save_qmodel_path:
        save_dict["model"] = model.state_dict()
        torch.save(save_dict, args.save_qmodel_path)


def _configure_input_quantizers(model, args):
    """Configure the input (and v_proj output) quantizer of every ActQuantWrapper."""
    qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
    down_proj_groupsize = -1
    if args.a_groupsize > 0 and "llama" in args.model:
        down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)

    for name in qlayers:
        qlayer = qlayers[name]
        layer_input_bits = args.a_bits
        layer_groupsize = args.a_groupsize

        if 'v_proj' in name and args.v_bits < 16:  # Set the v_proj precision
            qlayer.out_quantizer.configure(bits=args.v_bits,
                                           groupsize=args.v_groupsize,
                                           sym=not args.v_asym,
                                           clip_ratio=args.v_clip_ratio)

        if 'lm_head' in name:  # Skip lm_head quantization
            layer_input_bits = 16

        if 'down_proj' in name:  # Set the down_proj precision
            if args.int8_down_proj:
                layer_input_bits = 8
            layer_groupsize = down_proj_groupsize

        qlayer.quantizer.configure(bits=layer_input_bits,
                                   groupsize=layer_groupsize,
                                   sym=not args.a_asym,
                                   clip_ratio=args.a_clip_ratio)


def _add_key_quantization(model, args):
    """Wrap the RoPE call of every attention layer with the key-quantization wrapper.

    Raises:
        NotImplementedError: if ``args.k_pre_rope`` is set.
    """
    if args.k_pre_rope:
        raise NotImplementedError("Pre-RoPE quantization is not supported yet!")

    rope_function_name = model_utils.get_rope_function_name(model)
    k_quant_config = {'k_bits': args.k_bits, "k_groupsize": args.k_groupsize,
                      "k_sym": not args.k_asym, "k_clip_ratio": args.k_clip_ratio}
    for layer in model_utils.get_layers(model):
        rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
            layer.self_attn,
            rope_function_name,
            config=model.config,
            **k_quant_config)


def _resolve_task_list(tasks):
    """Return the lm-eval task names as a list, falling back to the defaults.

    A comma-separated string is split; surrounding whitespace and empty
    entries are dropped so inputs like ``"arc_easy, piqa"`` work.
    """
    if isinstance(tasks, str):
        names = [part.strip() for part in tasks.split(',') if part.strip()]
    else:
        names = list(tasks) if tasks else []
    return names or list(_DEFAULT_LM_EVAL_TASKS)


def _summarize_accuracies(results):
    """Pick ``acc_norm`` (else ``acc``) per task and append their mean as ``acc_avg``."""
    metric_vals = {}
    for task, result in results.items():
        # Try to get acc_norm first, then acc
        if 'acc_norm' in result:
            metric_vals[task] = round(result['acc_norm'], 4)
        elif 'acc' in result:
            metric_vals[task] = round(result['acc'], 4)
        else:
            print(f"Warning: No accuracy metric found for task {task}")

    if metric_vals:
        metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals), 4)
    return metric_vals


def main():
    """Run the rotate -> quantize -> evaluate pipeline driven by CLI arguments.

    Steps, each gated by the parsed arguments:
      1. optionally rotate the model and add the activation-quantization wrappers;
      2. quantize (or load quantized) weights when ``w_bits < 16``;
      3. configure activation / value / key quantizers;
      4. evaluate perplexity on ``args.eval_dataset``;
      5. optionally run zero-shot lm-eval (0.3.0 API) tasks.

    Results are printed and, when ``args.wandb`` is set, logged to wandb.
    """
    args = utils.parser_gen()
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, entity=args.wandb_id)
        wandb.config.update(args)

    transformers.set_seed(args.seed)
    model = model_utils.get_model(args.model, args.hf_token)
    model.eval()

    # Rotate the weights
    if args.rotate:
        rotation_utils.fuse_layer_norms(model)
        rotation_utils.rotate_model(model, args)
        # `verbos` is the keyword spelling this call has always used;
        # presumably it matches the helper's signature -- verify before renaming.
        utils.cleanup_memory(verbos=True)

    # Add Activation Wrapper to the model; the rest of the code assumes it is
    # present whether or not the model was rotated.
    quant_utils.add_actquant(model)
    if args.rotate:
        _setup_online_hadamard(model, args)

    if args.w_bits < 16:
        _quantize_weights(model, args)

    # Add Input Quantization
    if args.a_bits < 16 or args.v_bits < 16:
        _configure_input_quantizers(model, args)

    if args.k_bits < 16:
        _add_key_quantization(model, args)

    # Evaluating on dataset
    testloader = data_utils.get_loaders(
        args.eval_dataset,
        seed=args.seed,
        model=args.model,
        seqlen=model.seqlen,
        hf_token=args.hf_token,
        eval_mode=True
    )

    dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
    if args.wandb:
        wandb.log({'ppl/{}'.format(args.eval_dataset.upper()): dataset_ppl})

    if not args.lm_eval:
        return

    # Import lm_eval utils (compatible with 0.3.0). These stay local so that
    # lm_eval is only required when --lm_eval is requested. `torch` is
    # deliberately NOT re-imported here: a function-level `import torch`
    # would make the name local to main() and break earlier uses of it.
    from lm_eval import evaluator
    from lm_eval.base import BaseLM

    class PreloadedHFLM(BaseLM):
        """lm-eval adapter around an already-loaded model and tokenizer."""

        def __init__(self, model, tokenizer, batch_size=1):
            super().__init__()
            self.model = model
            self.tokenizer = tokenizer
            self._batch_size = batch_size
            # Device of the first parameter, captured at construction time.
            self._device = next(model.parameters()).device

        @property
        def eot_token_id(self):
            return self.tokenizer.eos_token_id

        @property
        def max_length(self):
            # Fall back to 2048 when the config does not expose the attribute.
            try:
                return self.model.config.max_position_embeddings
            except AttributeError:
                return 2048

        @property
        def vocab_size(self):
            return self.tokenizer.vocab_size

        @property
        def model_type(self):
            return self.model.config.model_type

        @property
        def device(self):
            return self._device

        def tok_encode(self, string: str):
            return self.tokenizer.encode(string, add_special_tokens=False)

        def tok_decode(self, tokens):
            return self.tokenizer.decode(tokens)

        def _model_call(self, inps):
            # First element of the model output, computed without gradients.
            with torch.no_grad():
                return self.model(inps)[0]

        def _model_generate(self, context, max_length, eos_token_id):
            return self.model.generate(
                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
            )

        @property
        def batch_size(self):
            return self._batch_size

        @batch_size.setter
        def batch_size(self, batch_size):
            self._batch_size = batch_size

        @property
        def max_gen_toks(self):
            return 256

    if args.distribute:
        utils.distribute_model(model)
    else:
        model.to(utils.DEV)

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token)
    hflm = PreloadedHFLM(model=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size)

    # Use lm-eval 0.3.0 API
    print(f"\nRunning 0-shot evaluation for tasks: {args.tasks}")

    eval_results = evaluator.simple_evaluate(
        model=hflm,
        tasks=_resolve_task_list(args.tasks),
        batch_size=args.lm_eval_batch_size,
        no_cache=True,
        num_fewshot=0  # Zero-shot evaluation
    )

    # Extract results from lm-eval 0.3.0 format
    metric_vals = _summarize_accuracies(eval_results['results'])

    print("\nLM-Eval Results:")
    print("="*50)
    for task, acc in metric_vals.items():
        print(f"{task}: {acc}")
    print("="*50)

    if args.wandb:
        wandb.log(metric_vals)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
