from argparse import ArgumentParser
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, AwqConfig
from scaled_rope.patch import *
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, dispatch_model, infer_auto_device_map
from awq.quantize.quantizer import real_quantize_model_weight
from transformers import AutoTokenizer
from tinychat.demo import gen_params, stream_output
from tinychat.stream_generators import StreamGenerator
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from tinychat.utils.prompt_templates import get_prompter
from awq.quantize.pre_quant import run_awq, apply_awq, apply_awq_ntk, apply_awq_search
from accelerate.utils import (
    OffloadedWeightsLoader,
    check_cuda_p2p_ib_support,
    check_device_map,
    extract_submodules_state_dict,
    find_tied_parameters,
    get_balanced_memory,
    infer_auto_device_map,
    is_mlu_available,
    is_musa_available,
    is_npu_available,
    is_torch_version,
    is_xpu_available,
    load_checkpoint_in_model,
    offload_state_dict,
    parse_flag_from_env,
    retie_parameters,
)
import os
# import sys
# sys.path.append(os.path.join(os.path.dirname(__file__), "llm-awq"))

# Set the CUDA_VISIBLE_DEVICES environment variable so that only GPU 0 is
# exposed to this process. This is an unconditional assignment executed at
# import time, so it replaces any value already present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# torch.cuda.set_device(1)
def _select_model_classes(args):
    """Return ``(model_cls, config_cls)`` matching the ``--custom-model*`` flags.

    The scaled_rope implementations are imported lazily so they are only
    needed when the corresponding flag is set. Without any custom flag the
    transformers ``Auto*`` classes are used.
    """
    if args.custom_model:
        from scaled_rope.modeling_llama_yarn import LlamaForCausalLM
        from scaled_rope.configuration_llama import LlamaConfig
        return LlamaForCausalLM, LlamaConfig
    if args.custom_model_together:
        from scaled_rope.modeling_llama_together_yarn import LlamaForCausalLM
        from scaled_rope.configuration_llama import LlamaConfig
        return LlamaForCausalLM, LlamaConfig
    if args.custom_model_mistral:
        from scaled_rope.modeling_mistral_yarn import MistralForCausalLM
        from scaled_rope.configuration_mistral import MistralConfig
        return MistralForCausalLM, MistralConfig
    return AutoModelForCausalLM, AutoConfig


def _custom_rope_scaling(config, args):
    """Build the ``rope_scaling`` dict requested on the CLI for a custom model.

    Returns ``None`` when no scaling flag is set, in which case the value that
    came with the checkpoint config is left alone.

    Raises:
        ValueError: for ``--dynamic-yarn`` when the original context length is
            given neither on the command line nor in the checkpoint config.
    """
    if args.linear:
        return {"type": "linear", "factor": args.linear}
    if args.dynamic_ntk:
        return {"type": "dynamic", "factor": args.dynamic_ntk}
    if args.part_ntk:
        return {"type": "ntk-by-parts", "factor": args.part_ntk}
    if args.yarn:
        return {
            "type": "yarn",
            "factor": args.yarn,
            "original_max_position_embeddings": args.original_max_position_embeddings,
        }
    if args.dynamic_yarn:
        # The checkpoint may ship without rope_scaling; treat that like an
        # empty dict so every fallback below is guarded the same way.
        existing = getattr(config, "rope_scaling", None) or {}
        if args.original_max_position_embeddings:
            original_length = args.original_max_position_embeddings
        elif "original_max_position_embeddings" in existing:
            original_length = existing["original_max_position_embeddings"]
        else:
            raise ValueError(
                "--dynamic-yarn needs --original-max-position-embeddings because "
                "the model config does not provide "
                "rope_scaling['original_max_position_embeddings']"
            )
        return {
            "type": "dynamic-yarn",
            "factor": args.factor if args.factor else existing.get("factor", 1.0),
            "original_max_position_embeddings": original_length,
            "finetuned": args.finetuned if args.finetuned else existing.get("finetuned", False),
        }
    return None


def _prepare_config(model, args):
    """Load the config for ``model`` and apply every CLI override to it.

    Args:
        model: Model name or path handed to ``from_pretrained``.
        args: Parsed command-line namespace.

    Returns:
        ``(model_cls, config)`` - the model class to instantiate and its
        fully overridden config.

    Raises:
        ValueError: if ``--factor`` is given but the config has no
            ``rope_scaling`` entry to update and no custom scaling flag
            replaces it.
    """
    model_cls, config_cls = _select_model_classes(args)
    # NOTE(review): remote code is only disabled for --custom-model; the
    # together/mistral custom variants still pass trust_remote_code=True.
    config = config_cls.from_pretrained(
        model, trust_remote_code=not args.custom_model)

    if args.max_position_embeddings:
        config.max_position_embeddings = args.max_position_embeddings
    # The KV cache is controlled solely by --no-use-cache.
    config.use_cache = not args.no_use_cache
    if args.sliding_window_attention:
        config.sliding_window = args.sliding_window_attention

    rope_scaling = None
    if args.custom_model or args.custom_model_together or args.custom_model_mistral:
        rope_scaling = _custom_rope_scaling(config, args)
    elif args.rerope:
        # ReRoPE is installed by swapping the stock attention forward pass.
        from transformers.models.llama.modeling_llama import LlamaAttention
        from scaled_rope.LlamaReRoPE import forward_with_rerope
        LlamaAttention.forward = forward_with_rerope

    if rope_scaling is not None:
        # A CLI-selected scaling replaces the checkpoint's entry wholesale
        # (dynamic-yarn already folded args.factor in).
        config.rope_scaling = rope_scaling
    elif args.factor:
        if getattr(config, "rope_scaling", None) is None:
            raise ValueError(
                "--factor was given but the model config has no rope_scaling "
                "entry to update"
            )
        config.rope_scaling["factor"] = args.factor

    return model_cls, config


def _bitsandbytes_config(args, compute_dtype):
    """Return a ``BitsAndBytesConfig`` for ``--load-in-8bit/4bit``, else ``None``."""
    if not (args.load_in_8bit or args.load_in_4bit):
        return None
    return BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )


def _from_pretrained_fp16(model_cls, model, config, args, device_map):
    """Instantiate ``model`` as fp16 inside accelerate's ``init_empty_weights``."""
    with init_empty_weights():
        return model_cls.from_pretrained(
            model,
            config=config,
            torch_dtype=torch.float16,
            device_map=device_map,
            trust_remote_code=not args.custom_model)


def _load_awq_results(awq_cache, args):
    """Load the pre-computed AWQ search results.

    The explicit ``awq_cache`` argument wins; ``args.awq_cache`` is used as a
    fallback when it exists.

    Raises:
        ValueError: if no cache path is available from either source.
    """
    cache_path = awq_cache if awq_cache is not None else getattr(args, "awq_cache", None)
    if cache_path is None:
        raise ValueError(
            "AWQ scale/clip results are required: pass awq_cache or set args.awq_cache"
        )
    print(f"Loading pre-computed AWQ cache results from {cache_path} and apply to the model")
    # torch.load unpickles the file - only point this at trusted caches.
    return torch.load(cache_path, map_location="cpu")


def _dispatch_balanced(loaded):
    """Spread ``loaded`` over the visible devices without splitting decoder layers."""
    no_split_module_classes = ["LlamaDecoderLayer"]
    max_memory = get_balanced_memory(
        loaded,
        max_memory=None,
        no_split_module_classes=no_split_module_classes,
        low_zero=False,
    )
    print(f"max_memory: {max_memory}")
    device_map = infer_auto_device_map(
        loaded,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )
    print(f"device_map: {device_map}")
    return dispatch_model(loaded, device_map=device_map)


def load_model(model, args):
    """Load a causal LM with the CLI-selected config overrides.

    Uses bitsandbytes quantization when ``--load-in-8bit`` / ``--load-in-4bit``
    is set, otherwise loads the weights as fp16.

    Args:
        model: Model name or path handed to ``from_pretrained``.
        args: Parsed command-line namespace.

    Returns:
        The loaded model, placed with ``device_map="auto"``.

    Raises:
        ValueError: see ``_prepare_config`` / ``_custom_rope_scaling``.
    """
    model_cls, config = _prepare_config(model, args)
    # NOTE(review): the original carried a note about OOM with transformers
    # > 4.36.2 when the cache is enabled; pass --no-use-cache in that case.

    quantization_config = _bitsandbytes_config(args, torch.float16)
    if quantization_config is not None:
        torch_dtype = None
        config.pretraining_tp = 1
    else:
        torch_dtype = torch.float16

    return model_cls.from_pretrained(
        model,
        torch_dtype=torch_dtype,
        device_map="auto",
        trust_remote_code=not args.custom_model,
        config=config,
        quantization_config=quantization_config,
        use_flash_attention_2=args.flash_attention,
        low_cpu_mem_usage=True,
    )


def load_model_awq(model, args, quant_path=None, awq_cache=None, temp=None):
    """Load a model and quantize its weights to 4 bit through the AWQ tooling.

    With ``quant_path`` the quantized checkpoint is loaded directly.
    Otherwise the variant selected in ``args`` (temperature, dynamic log
    distance, per-channel, Hadamard, naive, or plain beta-point AWQ) is
    applied before real quantization, and the model is dispatched across the
    visible devices.

    Args:
        model: Model name or path handed to ``from_pretrained``.
        args: Parsed command-line namespace.
        quant_path: Optional path to an already quantized AWQ checkpoint.
        awq_cache: Path to pre-computed AWQ results; falls back to
            ``args.awq_cache`` when ``None``.
        temp: Optional temperature for AWQ scale/clip; ``None`` or ``0``
            disables the temperature variant.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if the selected variant needs AWQ results but no cache
            path is available, or for the config errors raised by
            ``_prepare_config``.
    """
    model_cls, config = _prepare_config(model, args)
    # No bitsandbytes config is built here because the AWQ path never passes
    # one to from_pretrained; only the config side effect is kept.
    if args.load_in_8bit or args.load_in_4bit:
        config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    loaded = _from_pretrained_fp16(
        model_cls, model, config, args,
        device_map="cpu" if args.apply_hardmard else "auto")

    q_config = {"zero_point": True, "q_group_size": 128}
    loaded.eval()

    if quant_path is not None:
        # Only create the quantized layer skeleton; the checkpoint fills it.
        real_quantize_model_weight(
            loaded, w_bit=4, q_config=q_config, init_only=True)

        print(f"Loading awq model checkpoint directly from {quant_path}")
        loaded = load_checkpoint_and_dispatch(
            loaded, quant_path,
            device_map="auto",
            no_split_module_classes=["LlamaDecoderLayer"]
        )
        return loaded, tokenizer

    # AWQ results are loaded lazily, only by the variants that consume them.
    if temp is not None and temp != 0 and not args.dynamic_with_log_distance:
        awq_results = _load_awq_results(awq_cache, args)
        print(f"Applying AWQ scale and clip with temperature {temp}")
        if args.exclude_value_proj:
            print("Exclude value projection")
            apply_awq(loaded, awq_results, temp=temp, beta_point=args.beta_point, exclude_value_proj=True)
        else:
            print("not exclude value projection")
            apply_awq(loaded, awq_results, temp=temp, beta_point=args.beta_point, args=args)
    elif args.dynamic_with_log_distance:
        awq_results = _load_awq_results(awq_cache, args)
        print("Applying AWQ scale and clip with dynamic_log_distance")
        apply_awq(loaded, awq_results,
                  max_tokenlength=args.original_max_position_embeddings,
                  dynamic_with_log_distance=True)
    elif args.individual_channel_up is not None or args.individual_channel_down is not None:
        awq_results = _load_awq_results(awq_cache, args)
        print(f"Applying AWQ scale and clip with individual channel up {args.individual_channel_up} and down {args.individual_channel_down}")
        apply_awq(loaded, awq_results, beta_point=args.beta_point, args=args)
    elif args.apply_hardmard:
        print("Apply hardmard to activations so no AWQ scale and clip")
        loaded = load_model_and_apply_patches_hadamard(loaded, args)
    elif args.naive_quant:
        print("Using awq infra to quant model to naively nearest neighbor with groups but no AWQ scale and clip")
    else:
        awq_results = _load_awq_results(awq_cache, args)
        print(f"Applying AWQ scale and clip with beta point {args.beta_point}")
        apply_awq(loaded, awq_results, beta_point=args.beta_point, args=args)

    if args.quant_activation:
        print(f"Applying activation quantization with bit width {args.quant_activation_bitwidth}")
        patch_model_with_activation_quant(loaded, num_bits=args.quant_activation_bitwidth, group_size=128)

    real_quantize_model_weight(
        loaded, w_bit=4, q_config=q_config)

    loaded = _dispatch_balanced(loaded)
    return loaded, tokenizer


def load_model_awq_search(model, args, quant_path=None, awq_cache=None):
    """Load the fp16 model and tokenizer used as the base for an AWQ search.

    No AWQ results are applied and no quantization is performed here.

    Args:
        model: Model name or path handed to ``from_pretrained``.
        args: Parsed command-line namespace.
        quant_path: Unused; kept for signature compatibility.
        awq_cache: Unused; kept for signature compatibility.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: see ``_prepare_config`` / ``_custom_rope_scaling``.
    """
    model_cls, config = _prepare_config(model, args)
    # Same as load_model_awq: only the config side effect of the bnb flags
    # is relevant on this path.
    if args.load_in_8bit or args.load_in_4bit:
        config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    loaded = _from_pretrained_fp16(model_cls, model, config, args, device_map="auto")
    return loaded, tokenizer

def apply_awq_search_helper(loaded, awq_cache=None, scale=None, channel=None):
    """Apply AWQ search results to ``loaded``, quantize to 4 bit and dispatch.

    Args:
        loaded: Model to modify; it is switched to eval mode.
        awq_cache: Path to the pre-computed AWQ results (required despite the
            ``None`` default, which is kept for signature compatibility).
        scale: Forwarded to ``apply_awq_search``.
        channel: Forwarded to ``apply_awq_search``.

    Returns:
        The quantized model after ``dispatch_model``.

    Raises:
        ValueError: if ``awq_cache`` is ``None``.
    """
    if awq_cache is None:
        # Fail early with a clear message instead of inside torch.load(None).
        raise ValueError("apply_awq_search_helper requires a path in awq_cache")

    q_config = {"zero_point": True, "q_group_size": 128}
    loaded.eval()

    print(f"Loading pre-computed AWQ cache results from {awq_cache} and apply to the model")
    # torch.load unpickles the file - only point this at trusted caches.
    awq_results = torch.load(awq_cache, map_location="cpu")

    apply_awq_search(loaded, awq_results, scale, channel)

    real_quantize_model_weight(
        loaded, w_bit=4, q_config=q_config)

    # Balanced placement over the visible devices; decoder layers stay whole.
    no_split_module_classes = ["LlamaDecoderLayer"]
    max_memory = get_balanced_memory(
        loaded,
        max_memory=None,
        no_split_module_classes=no_split_module_classes,
        low_zero=False,
    )
    print(f"max_memory: {max_memory}")
    device_map = infer_auto_device_map(
        loaded,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )

    return dispatch_model(
        loaded,
        device_map=device_map,
    )
    
def add_args(parser: ArgumentParser):
    """Register the model-loading and RoPE-scaling options on ``parser``.

    The options are declared in a single ordered table so the order shown in
    ``--help`` stays stable. Returns the same parser to allow chaining.
    """
    switch = {"action": "store_true"}
    real = {"type": float}
    whole = {"type": int}
    text = {"type": str}

    option_table = (
        ("--dynamic-linear", switch),
        ("--dynamic-ntk", real),
        ("--dynamic-part-ntk", switch),
        ("--dynamic-yarn", switch),
        ("--ntk", {"default": None, **real}),
        ("--part-ntk", real),
        ("--linear", real),
        ("--yarn", {"default": None, **real}),
        ("--rerope", real),
        ("--factor", real),
        ("--load-in-8bit", switch),
        ("--load-in-4bit", switch),
        (
            "--quant_weight_bit_width",
            {**text, "default": 'int8',
             "help": 'choose from [“float8”,“int8”,“int4”,“int2”]'},
        ),
        (
            "--quant_activation_bit_width",
            {**text, "default": 'int8',
             "help": 'choose from [“float8”,“int8”]'},
        ),
        ("--finetuned", switch),
        ("--gpt-neox-max-length", whole),
        ("--adapter", text),
        ("--max-position-embeddings", whole),
        ("--original-max-position-embeddings", whole),
        ("--sliding-window-attention", whole),
        ("--custom-model", switch),
        ("--custom-model-together", switch),
        ("--custom-model-mistral", switch),
        ("--flash-attention", switch),
        ("--no-use-cache", switch),
    )
    for option, settings in option_table:
        parser.add_argument(option, **settings)
    return parser


def _patch_rotary_embeddings(model, args):
    """Apply the rotary-embedding patch selected in ``args`` to a stock model.

    At most one scaling method is applied; the flags are checked in a fixed
    priority order. GPT-NeoX models additionally get their maximum sequence
    length extended first.

    Raises:
        ValueError: for a GPT-NeoX model without ``--gpt-neox-max-length``.
        RuntimeError: if the selected method does not support the model's
            architecture.
    """
    architectures = model.config.architectures
    is_neox = "GPTNeoXForCausalLM" in architectures
    is_llama = "LlamaForCausalLM" in architectures

    def unsupported(method):
        return RuntimeError(
            f"Unsupported architecture {architectures} for {method}")

    if is_neox:
        if args.gpt_neox_max_length is None:
            raise ValueError("--gpt-neox-max-length is required for GPTNeoX models")
        patch_gptneox_for_longer_sequences(model, args.gpt_neox_max_length)

    # NOTE(review): the yarn variants receive args.original_max_position_embeddings
    # as given (possibly None); only ReRoPE falls back to 4096 - confirm intended.
    if args.dynamic_linear:
        if is_neox:
            patch_gptneox_for_scaled_rotary_embeddings(model)
        elif is_llama:
            patch_llama_for_dynamic_scaled_rotary_embeddings(model)
        else:
            raise unsupported("dyanmic linear")
    elif args.dynamic_ntk:
        if not is_llama:
            raise unsupported("dyanmic ntk")
        patch_llama_for_dynamic_scaled_rotary_embeddings(
            model, ntk=args.dynamic_ntk)
    elif args.dynamic_part_ntk:
        if is_llama:
            patch_llama_for_dynamic_part_ntk_rotary_embeddings(
                model, args.finetuned)
        elif "RWForCausalLM" in architectures:
            patch_falcon_for_dynamic_part_ntk_rotary_embeddings(model)
        else:
            raise unsupported("dyanmic part ntk")
    elif args.dynamic_yarn:
        if not is_llama:
            raise unsupported("dyanmic yarn")
        patch_llama_for_dynamic_yarn_rotary_embeddings(
            model, args.original_max_position_embeddings, args.finetuned)
    elif args.ntk:
        if is_neox:
            patch_gptneox_for_ntk_scaled_rotary_embeddings(model, args.ntk)
        elif is_llama:
            patch_llama_for_ntk_scaled_rotary_embeddings(model, args.ntk)
        else:
            raise unsupported("ntk")
    elif args.linear:
        if not is_llama:
            raise unsupported("linear")
        patch_llama_for_linear_scaled_rotary_embeddings(
            model, scale=args.linear)
    elif args.part_ntk:
        if not is_llama:
            raise unsupported("part ntk")
        patch_llama_for_part_ntk_scaled_rotary_embeddings(
            model, scale=args.part_ntk)
    elif args.yarn:
        if not is_llama:
            raise unsupported("YaRN")
        patch_llama_for_yarn_scaled_rotary_embeddings(
            model, scale=args.yarn,
            original_max_position_embeddings=args.original_max_position_embeddings)
    elif args.rerope:
        if not is_llama:
            raise unsupported("ReRoPE")
        patch_llama_for_rerope(
            model,
            training_length=args.original_max_position_embeddings or 4096,
            window=args.rerope)


def apply_patches(model, args):
    """Apply the CLI-selected RoPE patch and optional PEFT adapter to ``model``.

    Models loaded through one of the ``--custom-model*`` classes are not
    patched here; the scaling method for those is set on the config at load
    time. When ``--adapter`` is given, the adapter is loaded with PEFT and
    merged into the base weights.

    Args:
        model: A loaded model exposing ``config.architectures``.
        args: Parsed command-line namespace.

    Returns:
        The patched model (a new object when an adapter was merged).

    Raises:
        ValueError, RuntimeError: see ``_patch_rotary_embeddings``.
    """
    uses_custom_model = (
        args.custom_model or args.custom_model_together or args.custom_model_mistral
    )
    if not uses_custom_model:
        _patch_rotary_embeddings(model, args)

    if args.adapter:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, args.adapter)
        model = model.merge_and_unload()

    return model


# Handle on the stock constructor so the NTK variant below can delegate to it.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None, scale=8):
    """Drop-in ``LlamaRotaryEmbedding.__init__`` with an NTK-scaled base.

    The rotary base is multiplied by ``scale ** (dim / (dim - 2))`` and the
    maximum position count is fixed at 16384 - the ``max_position_embeddings``
    argument passed by the caller is ignored - before the original
    constructor is invoked.
    """
    ntk_base = base * scale ** (dim / (dim - 2))
    old_init(self, dim, 16384, ntk_base, device)
        
def load_model_and_apply_patches(model, args):
    """Load ``model`` through the Auto classes, optionally with NTK-scaled RoPE.

    When ``--load-in-4bit`` / ``--load-in-8bit`` is set a bitsandbytes config
    is passed to ``from_pretrained``. When ``args.ntk`` is truthy the
    ``LlamaRotaryEmbedding`` constructor is replaced process-wide by
    ``ntk_scaled_init`` before the model is built. The KV cache is always
    disabled on the config.

    Args:
        model: Model name or path.
        args: Parsed command-line namespace.

    Returns:
        The loaded model.
    """
    wants_bnb = args.load_in_4bit or args.load_in_8bit
    quantization_config = (
        BitsAndBytesConfig(
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        if wants_bnb
        else None
    )

    # Apply the NTK-scaled init patch.
    # NOTE(review): args.ntk only acts as an on/off switch here; its numeric
    # value is not forwarded, so ntk_scaled_init runs with its default scale.
    if not args.ntk:
        print('Running without interpolation')
    else:
        print('Applying ntk interpolation')
        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    # Note: To avoid OOM after huggingface transformers 4.36.2
    config.use_cache = False
    return transformers.AutoModelForCausalLM.from_pretrained(
        model,
        config=config,
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    
def load_model_and_apply_patches_awq(model_path, args, config=None,
                                     quant_path=None, awq_cache=None, temp=None):
    """Load a model with optional NTK-scaled RoPE and 4-bit AWQ weights.

    With ``quant_path`` the quantized checkpoint is loaded directly, the same
    way ``load_model_awq`` does it. Otherwise the AWQ results from
    ``awq_cache`` are applied with ``apply_awq_ntk``, the weights are
    quantized, and the model is dispatched across the visible devices.

    Args:
        model_path: Model name or path.
        args: Parsed command-line namespace (reads ``ntk`` and ``beta_point``).
        config: Optional pre-built config; loaded from ``model_path`` if ``None``.
        quant_path: Optional path to an already quantized AWQ checkpoint.
        awq_cache: Path to pre-computed AWQ results; required when
            ``quant_path`` is not given.
        temp: Temperature forwarded to ``apply_awq_ntk``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if neither ``quant_path`` nor ``awq_cache`` is provided.
    """
    # Apply the NTK-scaled init patch (process-wide) before building the model.
    if args.ntk:
        print('Applying ntk interpolation')
        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
    else:
        print('Running without interpolation')

    if config is None:
        config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                     torch_dtype=torch.float16)
    q_config = {"zero_point": True, "q_group_size": 128}
    no_split_module_classes = ["LlamaDecoderLayer"]
    model.eval()

    if quant_path is not None:
        # The checkpoint supplies the final weights, so only the quantized
        # layer skeleton is created; applying AWQ first would be overwritten.
        real_quantize_model_weight(
            model, w_bit=4, q_config=q_config, init_only=True)
        print(f"Loading awq model checkpoint directly from {quant_path}")
        model = load_checkpoint_and_dispatch(
            model, quant_path,
            device_map="auto",
            no_split_module_classes=no_split_module_classes,
        )
        return model, tokenizer

    if awq_cache is None:
        raise ValueError(
            "load_model_and_apply_patches_awq needs awq_cache when quant_path is not given"
        )

    print(f"Loading pre-computed AWQ cache results from {awq_cache} and apply to the model")
    # torch.load unpickles the file - only point this at trusted caches.
    awq_results = torch.load(awq_cache, map_location="cpu")

    print(f"Applying AWQ scale and clip with temperature {temp}")
    apply_awq_ntk(model, awq_results, temp=temp, beta_point=args.beta_point)

    real_quantize_model_weight(
        model, w_bit=4, q_config=q_config, init_only=False)

    # Balanced placement over the visible devices; decoder layers stay whole.
    max_memory = get_balanced_memory(
        model,
        max_memory=None,
        no_split_module_classes=no_split_module_classes,
        low_zero=False,
    )
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )
    model = dispatch_model(
        model,
        device_map=device_map,
    )

    return model, tokenizer
    
def load_model_and_apply_patches_original(model, args):
    """Load ``model`` with ``load_model`` and run ``apply_patches`` on the result."""
    loaded = load_model(model, args)
    return apply_patches(loaded, args)
    
def load_model_and_apply_patches_original_awq(model, args, quant_path, awq_cache, temp):
    """Load an AWQ-quantized model, then apply the CLI-selected patches to it.

    ``load_model_awq`` returns a ``(model, tokenizer)`` pair; only the model
    is passed through ``apply_patches`` and the tokenizer is handed back
    unchanged.

    Returns:
        ``(patched_model, tokenizer)``.
    """
    loaded, tokenizer = load_model_awq(model, args, quant_path, awq_cache, temp)
    return apply_patches(loaded, args), tokenizer

def load_model_and_apply_patches_original_awq_search(model, args, awq_cache):
    """Load the AWQ-search base model, then apply the CLI-selected patches.

    ``load_model_awq_search`` returns a ``(model, tokenizer)`` pair; only the
    model is passed through ``apply_patches``. ``awq_cache`` is forwarded by
    keyword so it does not land in the ``quant_path`` slot.

    Returns:
        ``(patched_model, tokenizer)``.
    """
    loaded, tokenizer = load_model_awq_search(model, args, awq_cache=awq_cache)
    return apply_patches(loaded, args), tokenizer

# def load_model_and_apply_patches_hadamard(model, args):
#     """
#     Load a model and apply Hadamard transform to the input layer.
#     This is useful for models that require Hadamard encoding.
#     """
#     # Move model to CPU for patching
#     model_cpu = model.cpu()
#     # Apply Hadamard transform to the input layer
#     for name, module in model.named_modules():
#         # print("name:", name)
#         # if (("q_proj" in name) or ("k_proj" in name) or ("v_proj" in name)) and isinstance(module, nn.Linear):
#         # if "o_proj" in name and isinstance(module, nn.Linear):
#         # if "up_proj" in name and isinstance(module, nn.Linear):
#         if "gate_proj" in name and isinstance(module, nn.Linear):
#         # if isinstance(module, nn.Linear):
#             print(f"Applying Hadamard transform to {name}")
#             # print("inside loop name:", name)
#             parent = model
#             subnames = name.split(".")
#             for subname in subnames[:-1]:
#                 parent = getattr(parent, subname)
#             orig = getattr(parent, subnames[-1])
#             hadamard_layer = HadamardLinear(orig.in_features, orig.out_features, orig.bias is not None, args.quant_activation).to('cpu')
#             hadamard_layer.copy_weights_from(orig)
#             setattr(parent, subnames[-1], hadamard_layer)
#     model_gpu = model_cpu.cuda()
#     return model_gpu

def load_model_and_apply_patches_hadamard(model, args, inverse_transform_layers=False):
    """Replace selected ``nn.Linear`` layers of ``model`` with ``HadamardLinear``.

    A layer is selected when its qualified name contains any of the entries
    parsed from ``args.hardmard_layers``. The replacement is done on CPU and
    the model is moved to CUDA afterwards.

    Args:
        model: Model whose linear layers are replaced in place.
        args: Namespace providing ``hardmard_layers`` and ``quant_activation``.
        inverse_transform_layers: Forwarded to ``HadamardLinear`` as
            ``inverse_transform``.

    Returns:
        The patched model after ``.cuda()``.
    """
    model_cpu = model.cpu()
    hadamard_layers = parse_hadamard_layers(args.hardmard_layers)

    # Collect the targets before touching the tree: modules are swapped out
    # below, and that should not happen while named_modules() is mid-walk.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if any(layer_name in name for layer_name in hadamard_layers)
        and isinstance(module, nn.Linear)
    ]

    for name, orig in targets:
        print(f"Applying Hadamard transform to {name}")

        # Walk down to the module that owns the layer so it can be re-assigned.
        *parent_path, leaf = name.split(".")
        parent = model
        for subname in parent_path:
            parent = getattr(parent, subname)

        hadamard_layer = HadamardLinear(
            orig.in_features,
            orig.out_features,
            orig.bias is not None,
            activation_quant=args.quant_activation,
            inverse_transform=inverse_transform_layers
        ).to('cpu')

        hadamard_layer.copy_weights_from(orig)
        setattr(parent, leaf, hadamard_layer)

    return model_cpu.cuda()

def patch_model_with_activation_quant(model, num_bits=4, group_size=128):
    """Make every ``nn.Linear`` in *model* quantize its input before running.

    Each linear layer gets its own ``ActivationQuantizer`` and an
    instance-level ``forward`` override that feeds the quantized input to
    the layer's previous ``forward``. The model is patched in place and
    nothing is returned.

    Args:
        model: Module tree to patch.
        num_bits: Bit width passed to each ``ActivationQuantizer``.
        group_size: Group size passed to each ``ActivationQuantizer``.
    """

    def _quantize_then_call(inner_forward, quantizer):
        # A factory gives every layer its own bindings; a plain closure in
        # the loop would late-bind to the last layer visited.
        def new_forward(x):
            return inner_forward(quantizer(x))

        return new_forward

    for layer in model.modules():
        if not isinstance(layer, nn.Linear):
            continue
        layer.forward = _quantize_then_call(
            layer.forward,
            ActivationQuantizer(num_bits=num_bits, group_size=group_size),
        )

from scipy.linalg import hadamard
# class HadamardLinear(nn.Module):
#     def __init__(self, in_features, out_features, bias=True):
#         super().__init__()
#         assert (in_features & (in_features - 1)) == 0, "in_features must be a power of 2"
#         self.linear = nn.Linear(in_features, out_features, bias)
#         H = hadamard(in_features).astype('float16')
#         self.register_buffer('H', torch.tensor(H, dtype=torch.float16))
#         self.n = in_features

#     def forward(self, x):
#         # Apply Hadamard transform to input
#         x = torch.matmul(x, self.H)
#         # Linear layer
#         x = self.linear(x)
#         return x

#     def copy_weights_from(self, orig_linear):
#         # Transform weights: W' = W @ H / n
#         with torch.no_grad():
#             H_inv = self.H.t() / self.n  # H is its own transpose/inverse, normalize by n
#             self.linear.weight.copy_(torch.matmul(orig_linear.weight, H_inv))
#             if orig_linear.bias is not None:
#                 self.linear.bias.copy_(orig_linear.bias)
#          # Free memory
#         del H_inv


# class HadamardLinear(nn.Module):
#     def __init__(self, in_features, out_features, bias=True):
#         super().__init__()
#         self.orig_in_features = in_features
#         self.padded_in_features = 1 << (in_features - 1).bit_length()  # next power of 2
#         self.linear = nn.Linear(self.padded_in_features, out_features, bias)
#         # Only keep Hadamard matrix as float16 to save memory
#         H = hadamard(self.padded_in_features).astype('float16')
#         self.register_buffer('H', torch.tensor(H))
#         self.n = self.padded_in_features

#     def forward(self, x):
#         # Pad input only if needed
#         if x.shape[-1] < self.padded_in_features:
#             pad_width = self.padded_in_features - x.shape[-1]
#             x = torch.nn.functional.pad(x, (0, pad_width))
#         # Apply Hadamard transform and linear layer
#         x = torch.matmul(x, self.H)
#         x = self.linear(x)
#         return x

#     def copy_weights_from(self, orig_linear):
#         # Pad weights only if needed
#         weight = orig_linear.weight
#         if weight.shape[1] < self.padded_in_features:
#             pad_width = self.padded_in_features - weight.shape[1]
#             weight = torch.nn.functional.pad(weight, (0, pad_width))
#         H_inv = self.H.t() / self.n
#         with torch.no_grad():
#             self.linear.weight.copy_(torch.matmul(weight, H_inv))
#             if orig_linear.bias is not None:
#                 self.linear.bias.copy_(orig_linear.bias)
#         # Free memory
#         del weight, H_inv


from functools import lru_cache

# @lru_cache(maxsize=8)
# def get_hadamard_matrix(size):
#     # Use float32 for faster CPU matmul, cast to float16 for storage if needed
#     return torch.tensor(hadamard(size), dtype=torch.float32)

# class HadamardLinear(nn.Module):
#     def __init__(self, in_features, out_features, bias=True, activation_quant=False):
#         super().__init__()
#         self.orig_in_features = in_features
#         self.padded_in_features = 1 << (in_features - 1).bit_length()
#         self.linear = nn.Linear(self.padded_in_features, out_features, bias)
#         # Only keep Hadamard matrix on CPU
#         H = get_hadamard_matrix(self.padded_in_features)
#         self.register_buffer('H', H.to('cpu'), persistent=False)
#         self.n = self.padded_in_features
#         self.act_quant = ActivationQuantizer(num_bits=4, group_size=128)
#         self.activation_quant = activation_quant
        
#     def forward(self, x):
#         if x.shape[-1] < self.padded_in_features:
#             pad_width = self.padded_in_features - x.shape[-1]
#             x = torch.nn.functional.pad(x, (0, pad_width))
#         # Move H to device only for computation
#         H = self.H.to(x.device)
#         x = x.to(torch.float32) @ H
#         if self.activation_quant:
#             x = self.act_quant(x)
#         x = self.linear(x.to(torch.float16))
#         return x

#     def copy_weights_from(self, orig_linear):
#         weight = orig_linear.weight
#         if weight.shape[1] < self.padded_in_features:
#             pad_width = self.padded_in_features - weight.shape[1]
#             weight = torch.nn.functional.pad(weight, (0, pad_width))
#         H_inv = get_hadamard_matrix(self.padded_in_features).t() / self.n
#         with torch.no_grad():
#             new_weight = (weight.to(torch.float32) @ H_inv).to(self.linear.weight.dtype)
#             self.linear.weight.copy_(new_weight)
#             if orig_linear.bias is not None:
#                 self.linear.bias.copy_(orig_linear.bias)
#         # Optionally delete H_inv to free CPU RAM
#         del weight, H_inv
        
        

@lru_cache(maxsize=8)
def get_hadamard_matrix(size):
    """Return the ``size`` x ``size`` Hadamard matrix scaled by ``1/sqrt(size)``.

    The matrix comes from ``scipy.linalg.hadamard`` and is returned as a
    float32 tensor. Results are memoized (up to 8 sizes), so repeated calls
    with the same ``size`` hand back the very same tensor object; callers
    must not modify it in place.
    """
    norm = size ** 0.5
    scaled = hadamard(size) / norm
    return torch.tensor(scaled, dtype=torch.float32)

class HadamardLinear(nn.Module):
    """Linear layer that multiplies its input by a normalized Hadamard matrix.

    The input feature dimension is zero-padded up to the next power of two,
    multiplied by the matrix from ``get_hadamard_matrix``, optionally passed
    through an ``ActivationQuantizer`` (4 bits, group size 128), and then fed
    to an inner ``nn.Linear``. ``copy_weights_from`` stores ``W @ H.T`` in
    the inner layer so the pair reproduces the original layer's output (up
    to rounding) when quantization is off.

    Args:
        in_features: Input width of the layer being replaced.
        out_features: Output width.
        bias: Whether the inner linear layer has a bias.
        activation_quant: If true, quantize the transformed activations.
        inverse_transform: Stored on the instance; ``forward`` does not
            currently use it.
    """

    def __init__(self, in_features, out_features, bias=True, activation_quant=False, inverse_transform=False):
        super().__init__()
        self.orig_in_features = in_features
        # Smallest power of two >= in_features.
        self.padded_in_features = 1 << (in_features - 1).bit_length()
        self.linear = nn.Linear(self.padded_in_features, out_features, bias)

        H = get_hadamard_matrix(self.padded_in_features)
        # Non-persistent: H is excluded from the state dict.
        self.register_buffer('H', H.to('cpu'), persistent=False)

        self.act_quant = ActivationQuantizer(num_bits=4, group_size=128)
        self.activation_quant = activation_quant
        self.inverse_transform = inverse_transform

    def _linear_input_dtype(self):
        """Pick the dtype the inner layer should receive.

        Uses the inner layer's floating-point weight dtype when it has one,
        so fp32/bf16 weights no longer hit a dtype mismatch. Falls back to
        float16 (the previous hard-coded cast) when ``self.linear`` exposes
        no floating ``weight``, e.g. if it was swapped for a packed
        quantized layer.
        """
        weight = getattr(self.linear, "weight", None)
        if torch.is_tensor(weight) and weight.is_floating_point():
            return weight.dtype
        return torch.float16

    def forward(self, x):
        """Pad, Hadamard-transform, optionally quantize, then apply the linear."""
        orig_dim = x.shape[-1]
        if orig_dim < self.padded_in_features:
            # Zero-pad the feature axis up to the Hadamard size.
            x = F.pad(x, (0, self.padded_in_features - orig_dim))

        H = self.H.to(x.device)
        # The transform itself is always carried out in float32.
        x_transformed = torch.matmul(x.to(torch.float32), H)

        if self.activation_quant:
            x_transformed = self.act_quant(x_transformed)

        return self.linear(x_transformed.to(self._linear_input_dtype()))

    def copy_weights_from(self, orig_linear):
        """Load ``orig_linear``'s parameters, folding the transform into the weight.

        The weight's input axis is zero-padded to the padded width and
        multiplied by ``H.T`` in float32; the bias (if any) is copied as is.
        """
        with torch.no_grad():
            weight = orig_linear.weight
            in_dim = weight.shape[1]
            if in_dim < self.padded_in_features:
                weight = F.pad(weight, (0, self.padded_in_features - in_dim))

            # Align devices so this also works when the source layer and
            # the H buffer live on different devices.
            H_inv = self.H.t().to(weight.device)
            new_weight = torch.matmul(weight.to(torch.float32), H_inv)
            # copy_ converts to the destination's dtype and device.
            self.linear.weight.copy_(new_weight)

            if orig_linear.bias is not None:
                self.linear.bias.copy_(orig_linear.bias)

def parse_hadamard_layers(layer_arg):
    """
    Split a comma-separated string of layer names into a clean list.

    Whitespace around each name is stripped and empty entries are dropped.
    A falsy argument (``None`` or ``""``) yields an empty list.
    """
    if not layer_arg:
        return []
    trimmed = (piece.strip() for piece in layer_arg.split(","))
    return [piece for piece in trimmed if piece]
        
class ActivationQuantizer(nn.Module):
    """Simulated (quantize-then-dequantize) symmetric round-to-nearest quantizer.

    The last dimension of the input is split into consecutive groups of
    ``group_size`` features (the final group may be shorter). Each row of
    each group gets its own scale ``max(|x|) / qmax``; values are rounded to
    integers in ``[-qmax, qmax]`` and multiplied back by the scale, so the
    output has the same shape and dtype as the input.

    Args:
        num_bits: Bit width; must be 4 or 8.
        group_size: Number of features sharing one scale.
    """

    def __init__(self, num_bits=8, group_size=128):
        super().__init__()
        assert num_bits in [4, 8], "Only 4 or 8 bit quantization supported"
        self.num_bits = num_bits
        self.group_size = group_size

    def forward(self, x):
        """Return *x* quantized group-wise and dequantized back."""
        if x.numel() == 0:
            # Nothing to quantize; also avoids an ambiguous reshape.
            return x.clone()

        orig_shape = x.shape
        num_features = orig_shape[-1]
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, num_features)  # [batch*seq, features]
        qmax = 2 ** (self.num_bits - 1) - 1
        xq = torch.empty_like(x_flat)

        for start in range(0, num_features, self.group_size):
            end = min(start + self.group_size, num_features)
            xg = x_flat[:, start:end]
            # RTN: symmetric quantization
            xmax = xg.abs().max(dim=1, keepdim=True)[0] + 1e-8
            scale = xmax / qmax
            # In float16 the 1e-8 guard rounds to zero, so an all-zero group
            # would give scale == 0 and NaN output; use 1 for those rows
            # (their values quantize to 0 either way).
            scale = torch.where(scale == 0, torch.ones_like(scale), scale)
            xg_int = torch.round(xg / scale).clamp(-qmax, qmax)
            xq[:, start:end] = xg_int * scale

        return xq.reshape(orig_shape)
    
