#!/usr/bin/env python
"""
Main script for ARQ with real activations
Modified from main.py to properly pass dataloader to rotation
"""
import utils
import torch
import model_utils
import data_utils
import transformers
import quant_utils
import rotation_utils
import gptq_utils
import eval_utils
import hadamard_utils
import argparse
from arq_rotation_utils_real_acts import add_arq_args, integrate_arq_rotation_real_acts

def get_parser():
    """Build the command-line parser for this script.

    Every option is declared in a single ordered table and registered in
    that order; the parser is then passed through ``add_arq_args`` so the
    ARQ-specific options are appended.

    Returns:
        The parser object returned by ``add_arq_args``.
    """
    option_table = (
        # Model and data
        ('--model', dict(type=str, help='model name')),
        ('--hf_token', dict(type=str, default=None)),
        ('--seed', dict(type=int, default=0,
                        help='Seed for sampling the calibration data.')),
        ('--nsamples', dict(type=int, default=128,
                            help='Number of calibration samples.')),
        ('--bsz', dict(type=int, default=1,
                       help='Batch size for calibration.')),
        ('--cal_dataset', dict(type=str, default='wikitext2',
                               choices=['wikitext2', 'c4', 'ptb'],
                               help='Calibration dataset')),
        ('--eval_dataset', dict(type=str, default='wikitext2',
                                choices=['wikitext2', 'c4', 'ptb'],
                                help='Evaluation dataset')),

        # Rotation
        ('--rotate', dict(action='store_true', help='Rotate the weights')),
        ('--rotate_mode', dict(type=str, default='hadamard',
                               choices=['hadamard', 'random', 'arq'])),
        ('--fp32_had', dict(action='store_true',
                            help='Apply Hadamard in FP32 (FP16 by default)')),

        # Weight quantization
        ('--w_bits', dict(type=int, default=16, help='Weight bits')),
        ('--w_asym', dict(action='store_true',
                          help='Asymmetric weight quantization')),
        ('--w_clip', dict(action='store_true',
                          help='Clipping weight quantization')),
        ('--w_rtn', dict(action='store_true',
                         help='RTN weight quantization')),
        ('--w_groupsize', dict(type=int, default=-1,
                               help='Weight groupsize')),
        ('--percdamp', dict(type=float, default=0.01,
                            help='GPTQ percdamp')),
        ('--act_order', dict(action='store_true',
                             help='GPTQ activation order')),
        ('--int8_down_proj', dict(action='store_true',
                                  help='Use INT8 for down_proj')),

        # Activation quantization
        ('--a_bits', dict(type=int, default=16, help='Activation bits')),
        ('--a_asym', dict(action='store_true',
                          help='Asymmetric activation quantization')),
        ('--a_clip_ratio', dict(type=float, default=1.0,
                                help='Activation clipping ratio')),
        ('--a_groupsize', dict(type=int, default=-1,
                               help='Activation groupsize')),

        # KV cache quantization
        ('--k_bits', dict(type=int, default=16, help='Key bits')),
        ('--k_asym', dict(action='store_true',
                          help='Asymmetric key quantization')),
        ('--k_pre_rope', dict(
            action='store_true',
            help='Pre-RoPE quantization for K-cache (not supported yet)')),
        ('--k_clip_ratio', dict(type=float, default=1.0,
                                help='Key clipping ratio')),
        ('--k_groupsize', dict(type=int, default=-1, help='Key groupsize')),
        ('--v_bits', dict(type=int, default=16, help='Value bits')),
        ('--v_asym', dict(action='store_true',
                          help='Asymmetric value quantization')),
        ('--v_clip_ratio', dict(type=float, default=1.0,
                                help='Value clipping ratio')),
        ('--v_groupsize', dict(type=int, default=-1,
                               help='Value groupsize')),

        # Save/load quantized model
        ('--save_qmodel_path', dict(type=str, default=None)),
        ('--load_qmodel_path', dict(type=str, default=None)),

        # Evaluation
        ('--seqlen', dict(type=int, default=2048)),
        ('--eval_ppl', dict(action='store_true',
                            help='Evaluate perplexity')),
        ('--lm_eval', dict(action='store_true',
                           help='Evaluate on LM tasks')),
        ('--tasks', dict(nargs='+', default=None,
                         help='Tasks for lm-eval')),
        ('--wandb', dict(action='store_true', help='Use wandb')),
        ('--wandb_project', dict(type=str, default='quarot')),
        ('--wandb_id', dict(type=str, default=None)),
        ('--capture_layer_io', dict(action='store_true',
                                    help='Capture layer I/O')),
        ('--layer_idx', dict(type=int, default=-1,
                             help='Layer index to capture')),
    )

    cli = argparse.ArgumentParser()
    for option_name, option_kwargs in option_table:
        cli.add_argument(option_name, **option_kwargs)

    # ARQ-specific options are appended last by the external helper.
    return add_arq_args(cli)


def _get_calibration_loader(model, args):
    """Return the calibration dataloader described by the CLI arguments.

    Shared by the ARQ rotation path and the GPTQ path so both request the
    loader with the same dataset, sample count, seed and sequence length.
    """
    return data_utils.get_loaders(
        args.cal_dataset, nsamples=args.nsamples,
        seed=args.seed, model=args.model,
        seqlen=model.seqlen, hf_token=args.hf_token, eval_mode=False
    )


def _select_accuracy(task_results):
    """Pick the accuracy to report for one task.

    Prefers ``acc_norm`` and falls back to ``acc``; returns ``None`` when the
    task's result dict carries neither key.
    """
    for metric_name in ('acc_norm', 'acc'):
        if metric_name in task_results:
            return task_results[metric_name]
    return None


def _run_lm_eval(model, args):
    """Run 0-shot lm-eval on the already-loaded model and print accuracies.

    Args:
        model: The in-memory model to evaluate.
        args: Parsed CLI namespace; reads ``model``, ``hf_token``, ``bsz`` and
            ``tasks``.

    Returns:
        A list of ``(task_name, accuracy)`` tuples, one per requested task
        for which an accuracy metric was found.
    """
    # Imported lazily so lm-eval is only required when --lm_eval is used.
    # The `lm_eval.base.BaseLM` import path is the one this script targets
    # (the original notes compatibility with lm-eval 0.3.0).
    from lm_eval import evaluator
    from lm_eval.base import BaseLM

    class PreloadedHFLM(BaseLM):
        """BaseLM adapter around a model/tokenizer pair that already exists."""

        def __init__(self, model, tokenizer, batch_size=1):
            super().__init__()
            self.model = model
            self.tokenizer = tokenizer
            self._batch_size = batch_size
            self._device = next(model.parameters()).device

        @property
        def eot_token_id(self):
            return self.tokenizer.eos_token_id

        @property
        def max_length(self):
            # Fall back to 2048 only when the config lacks the attribute,
            # instead of swallowing every possible exception.
            try:
                return self.model.config.max_position_embeddings
            except AttributeError:
                return 2048

        @property
        def vocab_size(self):
            return self.tokenizer.vocab_size

        @property
        def model_type(self):
            return self.model.config.model_type

        @property
        def device(self):
            return self._device

        def tok_encode(self, string: str):
            return self.tokenizer.encode(string, add_special_tokens=False)

        def tok_decode(self, tokens):
            return self.tokenizer.decode(tokens)

        def _model_call(self, inps):
            with torch.no_grad():
                return self.model(inps)[0]

        def _model_generate(self, context, max_length, eos_token_id):
            return self.model.generate(
                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
            )

        @property
        def batch_size(self):
            return self._batch_size

        @batch_size.setter
        def batch_size(self, batch_size):
            self._batch_size = batch_size

        @property
        def max_gen_toks(self):
            return 256

    # Create tokenizer and model wrapper
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, use_fast=False, use_auth_token=args.hf_token)
    hflm = PreloadedHFLM(model=model, tokenizer=tokenizer, batch_size=args.bsz)

    # Resolve the task list first so the banner shows what actually runs
    # (args.tasks is None when the defaults are used).
    task_list = args.tasks if args.tasks else ['hellaswag', 'arc_easy', 'arc_challenge']
    print(f"\nRunning 0-shot evaluation for tasks: {task_list}")

    # NOTE(review): the wrapper above is built with args.bsz while this call
    # passes batch_size=32 -- confirm which value is intended.
    eval_results = evaluator.simple_evaluate(
        model=hflm,
        tasks=task_list,
        batch_size=32,
        num_fewshot=0,
        no_cache=True
    )

    results = eval_results['results']

    print("\nLM-Eval Results:")
    print("="*50)

    accuracies = []
    for task_name in task_list:
        task_results = results.get(task_name)
        acc = None if task_results is None else _select_accuracy(task_results)
        if acc is None:
            # Covers both a missing task entry and an entry without
            # acc_norm/acc, so no undefined or stale value is reported.
            print(f"Warning: No accuracy metric found for task {task_name}")
            continue
        print(f"{task_name}: {acc:.4f}")
        accuracies.append((task_name, acc))
    return accuracies


def main(args):
    """Run the rotate / quantize / evaluate pipeline described by ``args``.

    Steps, in order:
      1. Optionally initialise wandb, seed, and load the model.
      2. If ``--rotate``: fuse layer norms, rotate the model (passing the
         calibration dataloader when ``rotate_mode == 'arq'``), add the
         activation-quant wrappers and set the online-Hadamard attributes
         on ``down_proj`` / ``o_proj`` layers. Otherwise only add the
         wrappers.
      3. If ``w_bits < 16``: load a saved quantized model, or quantize the
         weights with GPTQ or RTN, and optionally save the result.
      4. Configure activation / V quantizers and the K-cache wrapper.
      5. Evaluate perplexity and/or lm-eval tasks (perplexity is the default
         when neither is requested).

    Args:
        args: Namespace produced by ``get_parser().parse_args()``.

    Raises:
        AssertionError: On an invalid load/save/GPTQ flag combination.
        NotImplementedError: If ``--k_pre_rope`` is used with ``k_bits < 16``.
    """
    if args.wandb:
        import wandb
        wandb.init(project=args.wandb_project, entity=args.wandb_id)
        wandb.config.update(args)

    transformers.set_seed(args.seed)
    model = model_utils.get_model(args.model, args.hf_token)
    model.eval()

    # Get calibration dataloader early if using ARQ
    trainloader = None
    if args.rotate and args.rotate_mode == 'arq':
        trainloader = _get_calibration_loader(model, args)

    # Rotate the weights
    if args.rotate:
        rotation_utils.fuse_layer_norms(model)
        rotation_utils.rotate_model(model, args, dataloader=trainloader)
        utils.cleanup_memory(verbos=True)

        quant_utils.add_actquant(model)  # Add Activation Wrapper to the model
        qlayers = quant_utils.find_qlayers(model)
        for name in qlayers:
            if 'down_proj' in name:
                had_K, K = hadamard_utils.get_hadK(model.config.intermediate_size)
                qlayers[name].online_full_had = True
                qlayers[name].had_K = had_K
                qlayers[name].K = K
                qlayers[name].fp32_had = args.fp32_had
            if 'o_proj' in name:
                had_K, K = hadamard_utils.get_hadK(model.config.num_attention_heads)
                qlayers[name].online_partial_had = True
                qlayers[name].had_K = had_K
                qlayers[name].K = K
                qlayers[name].had_dim = model.config.hidden_size//model.config.num_attention_heads
                qlayers[name].fp32_had = args.fp32_had
    else:
        # Add Activation Wrapper to the model as the rest of the code assumes it is present
        quant_utils.add_actquant(model)

    if args.w_bits < 16:
        save_dict = {}
        if args.load_qmodel_path:  # Load Quantized Rotated Model
            assert args.rotate, "Model should be rotated to load a quantized model!"
            assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
            print("Load quantized model from ", args.load_qmodel_path)
            # Uses the module-level `torch`; a function-local `import torch`
            # later in this body would make this line raise UnboundLocalError.
            # NOTE(review): torch.load may unpickle arbitrary objects -- only
            # load checkpoints from trusted sources.
            save_dict = torch.load(args.load_qmodel_path)
            model.load_state_dict(save_dict["model"])

        elif not args.w_rtn:  # GPTQ Weight Quantization
            assert "llama" in args.model, "Only llama is supported for GPTQ!"

            # Reuse the ARQ calibration loader when it was already created.
            if trainloader is None:
                trainloader = _get_calibration_loader(model, args)
            quantizers = gptq_utils.gptq_fwrd(model, trainloader, utils.DEV, args)
            save_dict["w_quantizers"] = quantizers
        else:  # RTN Weight Quantization
            quantizers = gptq_utils.rtn_fwrd(model, utils.DEV, args)
            save_dict["w_quantizers"] = quantizers

        if args.save_qmodel_path:
            save_dict["model"] = model.state_dict()
            torch.save(save_dict, args.save_qmodel_path)

    # Add Input Quantization
    if args.a_bits < 16 or args.v_bits < 16:
        qlayers = quant_utils.find_qlayers(model, layers=[quant_utils.ActQuantWrapper])
        down_proj_groupsize = -1
        if args.a_groupsize > 0 and "llama" in args.model:
            down_proj_groupsize = utils.llama_down_proj_groupsize(model, args.a_groupsize)

        for name in qlayers:
            layer_input_bits = args.a_bits
            layer_groupsize = args.a_groupsize
            layer_a_sym = not(args.a_asym)
            layer_a_clip = args.a_clip_ratio

            if 'v_proj' in name and args.v_bits < 16:  # Set the v_proj precision
                qlayers[name].out_quantizer.configure(bits=args.v_bits,
                                                      groupsize=args.v_groupsize,
                                                      sym=not(args.v_asym),
                                                      clip_ratio=args.v_clip_ratio)

            if 'lm_head' in name:  # Skip lm_head quantization
                layer_input_bits = 16

            if 'down_proj' in name:  # Set the down_proj precision
                layer_groupsize = down_proj_groupsize

            qlayers[name].quantizer.configure(bits=layer_input_bits,
                                              groupsize=layer_groupsize,
                                              sym=layer_a_sym,
                                              clip_ratio=layer_a_clip)

    # Add KV cache quantization
    if args.k_bits < 16:
        if args.k_pre_rope:
            raise NotImplementedError("Pre-RoPE quantization is not supported yet!")
        else:
            rope_function_name = model_utils.get_rope_function_name(model)
            layers = model_utils.get_layers(model)
            k_quant_config = {'k_bits': args.k_bits, "k_groupsize": args.k_groupsize,
                              "k_sym": not(args.k_asym), "k_clip_ratio": args.k_clip_ratio}
            for layer in layers:
                rotation_utils.add_qk_rotation_wrapper_after_function_call_in_forward(
                    layer.self_attn,
                    rope_function_name,
                    config=model.config,
                    **k_quant_config)

    # Default to perplexity evaluation when no evaluation was requested.
    if not args.lm_eval and not args.eval_ppl:
        args.eval_ppl = True

    if args.eval_ppl:
        print("Evaluating perplexity...")
        testloader = data_utils.get_loaders(
            args.eval_dataset, seed=args.seed, model=args.model,
            seqlen=model.seqlen, hf_token=args.hf_token, eval_mode=True
        )
        ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
        # Label the number with the dataset that was actually evaluated; the
        # default dataset keeps the original "WikiText-2" wording.
        ppl_labels = {'wikitext2': 'WikiText-2', 'c4': 'C4', 'ptb': 'PTB'}
        ppl_label = ppl_labels.get(args.eval_dataset, args.eval_dataset)
        print(f"{ppl_label} PPL: {ppl:.3f}")

    # Evaluate on LM tasks
    if args.lm_eval:
        task_accuracies = _run_lm_eval(model, args)
        if args.wandb:
            for task_name, acc in task_accuracies:
                wandb.log({task_name: acc})


if __name__ == "__main__":
    # Build the CLI and parse the command line.
    parser = get_parser()
    args = parser.parse_args()

    # When ARQ rotation is selected, pass the parsed args through the ARQ
    # integration hook before running the pipeline.
    if args.rotate_mode == "arq":
        integrate_arq_rotation_real_acts(args)

    main(args)