import argparse
import datasets
import gc
import sys
from requests import options
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from tqdm import tqdm
from eval_util import log_results
from model_loader import *
import os
import copy 
from eval_util import *
import logging
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

def _log_file_with_specifiers(args):
    """Return ``args.log_file`` with one suffix per enabled option.

    The suffixes are joined with underscores and inserted between the base
    name and the extension. The path is returned unchanged when no option
    contributes a suffix.
    """
    base_name, ext = os.path.splitext(args.log_file)
    specifiers = []

    if args.naive_quant:
        specifiers.append("naive_rtn")
    # The "awq" tag requires a cache path and is suppressed by --naive_quant.
    if args.awq and not args.naive_quant and args.awq_cache is not None:
        specifiers.append("awq")
    if args.apply_hardmard:
        specifiers.append(f"hardmard_{args.hardmard_layers}")
    if args.quant_activation:
        specifiers.append(f"act_quant_{args.quant_activation_bitwidth}bit")
    # NOTE(review): ``args.yarn`` is not declared by the parser in this file;
    # presumably it is added by ``add_args`` -- confirm. It is deliberately
    # read only after ``args.no_pi`` so the lookup stays lazy.
    if args.no_pi or args.yarn == 1:
        specifiers.append("no_pi")
    if args.individual_channel_up is not None:
        specifiers.append("individual_channel_up")

    if not specifiers:
        return args.log_file
    return f"{base_name}_{'_'.join(specifiers)}{ext}"


def _load_model_and_tokenizer(model_name, args):
    """Load ``model_name`` and its tokenizer via the path selected by ``args``.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    if args.awq:
        # The AWQ loader hands back the tokenizer together with the model.
        model, tokenizer = load_model_and_apply_patches_original_awq(
            model_name, args, args.quant_path, args.awq_cache,
            args.awq_rescale_temp)
        return model, tokenizer

    if args.yarn == 1 or args.no_pi:
        print("load original model directly without quantization and interpolation.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return model, tokenizer

    model = load_model_and_apply_patches_original(model_name, args)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=sys.maxsize, trust_remote_code=True)
    return model, tokenizer


def main(args):
    """Evaluate every requested model on ``args.tasks`` and log the results.

    Args:
        args: Parsed command-line namespace. ``args.model`` is a list of
            lists (one inner list per ``-m`` occurrence); only the first
            entry of each inner list is used as the model identifier.
            ``args.log_file`` is rewritten in place when option suffixes
            are appended to it.

    NOTE(review): ``setup_logger``, ``logger``, ``eval_tasks`` and the
    ``load_model_and_apply_patches_*`` helpers are not defined in this file;
    presumably they come from the wildcard imports of ``eval_util`` and
    ``model_loader`` -- confirm.
    """
    models = [x[0] for x in args.model]

    # Add specifiers to log filename; the literal string "None" means that
    # no log file was requested, so the name is left untouched.
    if args.log_file != "None":
        args.log_file = _log_file_with_specifiers(args)

    setup_logger(args.log_file)

    if args.naive_quant:
        logger.info("Using naive RTN quantization.")

    if args.awq and not args.naive_quant and args.awq_cache is not None:
        logger.info("Using AWQ quantization.")

    if args.apply_hardmard:
        logger.info(f"Applying Hardmard quantization to layers: {args.hardmard_layers}")

    if args.quant_activation:
        logger.info(f"Quantizing activation with bitwidth {args.quant_activation_bitwidth}")

    if args.individual_channel_up is not None:
        logger.info(f"Use fixed channel rescale: {args.individual_channel_up} with scale {args.individual_channel_scale}")

    for model_name in models:
        # Load the model named by this iteration (previously the first
        # model was reloaded on every pass regardless of the loop variable).
        model, tokenizer = _load_model_and_tokenizer(model_name, args)

        tokenizer.pad_token = tokenizer.eos_token

        result = eval_tasks(model, tokenizer, args.tasks)  # Get a baseline
        log_results("result:", result)

        # Drop the references before the next model is loaded so that two
        # models are never kept alive at the same time.
        del model, tokenizer
        gc.collect()



def process_task(value):
    """Split a task specification string into a list of task names.

    Commas and runs of whitespace both act as separators, and empty
    entries are discarded, so ``"wikitext, c4"`` and ``"wikitext c4"``
    yield the same list.
    """
    return [name for chunk in value.split(",") for name in chunk.split()]

if __name__ == "__main__":
    # Silence every warning category for the duration of the run.
    warnings.simplefilter("ignore")
    parser = argparse.ArgumentParser()

    # Each -m occurrence appends one list of values; main() reads only the
    # first value of each occurrence. The option is required because main()
    # iterates over the collected list unconditionally.
    parser.add_argument("-m", "--model", action="append", nargs="+", required=True)

    # Quantization cache / rescaling options.
    parser.add_argument("--quant_path", default=None, type=str)
    parser.add_argument("--awq_cache", default=None, type=str)
    parser.add_argument("--awq_rescale_temp", default=None, type=float)
    parser.add_argument("--beta_point", default=1287, type=int)
    parser.add_argument("--dynamic_with_log_distance", action="store_true")
    parser.add_argument("--exclude_value_proj", action="store_true")
    parser.add_argument("--rescale_attention_all", action="store_true")
    parser.add_argument("--rescale_per_head", action="store_true")
    parser.add_argument("--individual_channel_up", default=None, type=str)
    parser.add_argument("--individual_channel_down", default=None, type=str)
    parser.add_argument("--individual_channel_scale", default=None, type=float)
    parser.add_argument("--individual_channel_value", default=None, type=str)

    # The flag spelling ("recale") is kept as-is: it is part of the CLI.
    parser.add_argument("--recale_specific_layer", type=int, default=None, help="Specify a layer to rescale, e.g., 0 for the first layer. If None, all layers are rescaled.")
    parser.add_argument("--search_result_path", default=None, type=str)
    parser.add_argument("--use_search_result", action="store_true")

    # Evaluation tasks and log destination. The string "None" (not the
    # None object) is the sentinel main() checks for "no log file name".
    parser.add_argument('--tasks', type=process_task, help='Task(s) to perform, such as "wikitext c4"', required=True)
    parser.add_argument('--log_file', '-o', type=str, help='Output path logging info', default="None")

    # Quantization mode switches.
    parser.add_argument("--apply_hardmard", action="store_true")
    parser.add_argument("--naive_quant", action="store_true")
    parser.add_argument("--no_pi", action="store_true")
    parser.add_argument("--quant_activation", action="store_true")
    parser.add_argument("--quant_activation_bitwidth", type=int, default=4)
    parser.add_argument("--hardmard_layers", type=str, default="")

    parser.add_argument("--aggressive-memory", action="store_true")
    parser.add_argument("--hide-progress", action="store_true")
    parser.add_argument("--awq", action="store_true")
    parser.add_argument("--original", action="store_true")

    # NOTE(review): add_args is not defined in this file; presumably it is
    # supplied by the wildcard import from model_loader and registers further
    # options (e.g. the ``yarn`` attribute read by main) -- confirm.
    main(add_args(parser).parse_args())