import argparse
from types import SimpleNamespace
import pprint
import torch
import random
import numpy as np
import os
from datetime import datetime
import logging

# Disable TensorFloat-32 tensor cores for cuDNN ops and CUDA matmuls
# (to avoid numerical issues).
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
# Default device: the first CUDA GPU when CUDA is available, otherwise the CPU.
DEV = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def llama_down_proj_groupsize(model, groupsize):
    """Return the group size to use for the down-projection input.

    Reads ``intermediate_size`` (and, if needed, ``hidden_size``) from
    ``model.config``. When ``groupsize`` evenly divides ``intermediate_size``
    it is returned unchanged. Otherwise the number of groups that
    ``groupsize`` produces on ``hidden_size`` is kept and
    ``intermediate_size`` is split into that many equal groups.

    Args:
        model: Object exposing ``config.hidden_size`` and
            ``config.intermediate_size``.
        groupsize: Requested activation group size; must be greater than 1.

    Returns:
        The group size for the down-projection.

    Raises:
        AssertionError: If ``groupsize`` is not greater than 1, does not
            evenly divide ``hidden_size``, or the resulting group count does
            not evenly divide ``intermediate_size``.
    """
    assert groupsize > 1, 'groupsize should be greater than 1!'

    config = model.config
    inter_size = config.intermediate_size

    # Fast path: the requested group size already tiles the intermediate dim.
    if inter_size % groupsize == 0:
        logging.info(f'(Act.) Groupsiz = Down_proj Groupsize: {groupsize}')
        return groupsize

    # Otherwise keep the same *number* of groups as on the hidden dimension.
    hidden = config.hidden_size
    n_groups = int(hidden / groupsize)
    assert groupsize * n_groups == hidden, 'Invalid groupsize for llama!'

    scaled = inter_size // n_groups
    assert scaled * n_groups == inter_size, 'Invalid groupsize for down_proj!'
    logging.info(f'(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {scaled}')
    return scaled

def set_seed(seed):
    """Seed the NumPy, PyTorch and Python ``random`` global generators with ``seed``."""
    for seed_fn in (np.random.seed, torch.random.manual_seed, random.seed):
        seed_fn(seed)

# Dump the log both to console and a log file.
def config_logging(log_file, level=logging.INFO):
    """Configure root logging to write to both the console and ``log_file``.

    INFO records are emitted as the bare message; records of any other level
    are prefixed with ``"<LEVELNAME>: "``.

    Args:
        log_file: Path handed to ``logging.FileHandler``.
        level: Level passed to ``logging.basicConfig`` for the root logger.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so in that case the handlers built here are not installed.
    """
    class LogFormatter(logging.Formatter):
        """Formatter that chooses its layout per record from the level."""

        def __init__(self):
            super().__init__()
            self._plain = logging.Formatter("%(message)s")
            self._tagged = logging.Formatter("%(levelname)s: %(message)s")

        def format(self, record):
            if record.levelno == logging.INFO:
                return self._plain.format(record)
            return self._tagged.format(record)

    sinks = []
    for sink in (logging.StreamHandler(), logging.FileHandler(log_file)):
        sink.setFormatter(LogFormatter())
        sinks.append(sink)

    logging.basicConfig(level=level, handlers=sinks)


# def nested_namespace(args):
#     args_dict = vars(args)
#     nested_dict = {}
    
#     for full_key, value in args_dict.items():
#         key_parts = full_key.split('.')
#         current_level = nested_dict
        
#         for i, part in enumerate(key_parts[:-1]):
#             if part not in current_level:
#                 current_level[part] = {}
#             elif not isinstance(current_level[part], dict):
#                 raise ValueError(
#                     f"name conflict: '{'.'.join(key_parts[:i+1])}' "
#                 )
#             current_level = current_level[part]
        
#         final_key = key_parts[-1]
#         if final_key in current_level:
#             raise ValueError(
#                 f"arg conflict: '{full_key}' "
#             )
#         current_level[final_key] = value
    
#     def dict_to_namespace(d):
#         if isinstance(d, dict):
#             return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
#         return d
    
#     return dict_to_namespace(nested_dict)


def parser_gen(argv=None):
    """Build the command-line parser, parse and validate the arguments.

    Besides parsing, this function has side effects: it creates
    ``<save_path>/logs``, configures logging to write to
    ``<save_path>/logs/quant_log.log`` (via ``config_logging``) and logs the
    parsed arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        AssertionError: If an unsupported combination of options is given
            (``grad_from``/``grad_acton`` mismatch, ``--k_pre_rope``, or
            ``--wandb`` without an id and project).
    """
    parser = argparse.ArgumentParser()

    # General Arguments
    parser.add_argument("--model", type=str, default="qwen2_vl",
                        help="Model class name to specify a ModelClass.")
    parser.add_argument("--model_args", type=str, default=None,
                        help="Model arguments to init a model instance.")
    parser.add_argument("--batch_size", type=str, default=1, metavar="auto|auto:N|N",
                        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to use (e.g. cuda, cuda:0, cpu)")
    parser.add_argument("--use_flash_attention_2", action="store_true")
    parser.add_argument('--seed', type=int, default=42,
                        help='Random Seed for HuggingFace and PyTorch')

    # Multimodel Calib Data Arguments
    parser.add_argument("--calib_data", default="coco", choices=["coco", "wikitext"])
    parser.add_argument("--n_samples", default=10, type=int)
    parser.add_argument("--seqlen", default=512, type=int,
                        help="Sequence length for the calibration data. This is used to truncate the input sequences.")
    parser.add_argument("--data_path", default="your/data/path", type=str)
    parser.add_argument("--image_folder", default="your/image/folder", type=str)
    parser.add_argument("--interleave_format", action="store_true")
    parser.add_argument("--few_shot_format", action="store_true")
    parser.add_argument("--text_data_path", default="", type=str)

    # Rotation Arguments
    parser.add_argument('--rotate', action=argparse.BooleanOptionalAction, default=False,
                        help="Rotate the moodel. This will include online rotation for down-projection and out-projection. Note that this does not apply rotation to the K/Q and they will be rotated if we want to quantize the Keys")
    parser.add_argument('--rotate_mode', type=str, default='hadamard', choices=['hadamard', 'random'])
    parser.add_argument('--rotation_seed', type=int, default=-1,
                        help='Random Seed for generating random matrix')
    parser.add_argument('--fp32_had', action=argparse.BooleanOptionalAction, default=False,
                        help='Apply Hadamard rotation in FP32 (default: False)')

    # Activation Quantization Arguments
    parser.add_argument('--a_bits', type=int, default=16,
                        help='''Number of bits for inputs of the Linear layers. This will be
                        for all the linear layers in the model (including down-projection and out-projection)''')
    parser.add_argument('--a_groupsize', type=int, default=-1,
                        help='Groupsize for activation quantization. Note that this should be the same as w_groupsize')
    parser.add_argument('--a_asym', action=argparse.BooleanOptionalAction, default=False,
                        help='ASymmetric Activation quantization (default: False)')
    parser.add_argument('--a_clip_ratio', type=float, default=1.0,
                        help='Clip ratio for activation quantization. new_max = max * clip_ratio')
    parser.add_argument('--enable_aq_calibration', action=argparse.BooleanOptionalAction, default=False,
                        help='Enable activation quantization in GPTQ(v2) (default: False)')

    # Weight Quantization Arguments
    parser.add_argument('--method', type=str, default='gptq',
                        help='''Quantization method to use.''')
    parser.add_argument('--w_bits', type=int, default=16,
                        help='Number of bits for weights of the Linear layers')
    parser.add_argument('--w_groupsize', type=int, default=-1,
                        help='Groupsize for weight quantization. Note that this should be the same as a_groupsize')
    parser.add_argument('--w_asym', action=argparse.BooleanOptionalAction, default=False,
                        help='ASymmetric weight quantization (default: False)')
    parser.add_argument('--w_clip', action=argparse.BooleanOptionalAction, default=False,
                        help='''Clipping the weight quantization! 
                        We do not support arguments for clipping and we find the best clip ratio during the weight quantization''')
    parser.add_argument('--percdamp', type=float, default=0.01,
                        help='Percent of the average Hessian diagonal to use for dampening.')
    parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False,
                        help='act-order in GPTQ(v2)')
    parser.add_argument('--static_groups', action=argparse.BooleanOptionalAction, default=False,
                        help='static groups in GPTQ(v2)')

    # General Quantization Arguments
    parser.add_argument('--int8_down_proj', action=argparse.BooleanOptionalAction, default=False,
                        help='Use INT8 for Down Projection! If this set, both weights and activations of this layer will be in INT8')

    # KV-Cache Quantization Arguments
    parser.add_argument('--v_bits', type=int, default=16,
                        help='''Number of bits for V-cache quantization. 
                        Note that quantizing the V-cache does not need any other rotation''')
    parser.add_argument('--v_groupsize', type=int, default=-1)
    parser.add_argument('--v_asym', action=argparse.BooleanOptionalAction, default=False,
                        help='ASymmetric V-cache quantization')
    parser.add_argument('--v_clip_ratio', type=float, default=1.0,
                        help='Clip ratio for v-cache quantization. new_max = max * clip_ratio')

    parser.add_argument('--k_bits', type=int, default=16,
                        help='''Number of bits for K-cache quantization. 
                        Note that quantizing the K-cache needs another rotation for the keys/queries''')
    parser.add_argument('--k_groupsize', type=int, default=-1)
    parser.add_argument('--k_asym', action=argparse.BooleanOptionalAction, default=False,
                        help='ASymmetric K-cache quantization')
    parser.add_argument('--k_pre_rope', action=argparse.BooleanOptionalAction, default=False,
                        help='Pre-RoPE quantization for K-cache (not Supported yet!)')
    parser.add_argument('--k_clip_ratio', type=float, default=1.0,
                        help='Clip ratio for k-cache quantization. new_max = max * clip_ratio')

    # gradient based token importance: vlmqv4
    parser.add_argument('--grad_from', type=str, default='block_out',
                        help='''Gradient from which layer to compute the token importance.
                        Options: block_out, attn_out, o_proj_in. Default: block_out''')
    parser.add_argument('--grad_acton', type=str, default='all',
                        help='''Which layers to apply the gradient-based token importance.
                        Options: qkv, qkvo, all.
                        Default: all. If you want to apply the gradient on all layers, use "all"''')
    parser.add_argument('--grad_clip', action=argparse.BooleanOptionalAction, default=False,
                        help='''Clip the gradient-based token importance. 
                        This is used to avoid numerical issues when computing the token scores. Default: False''')
    parser.add_argument('--grad_norm', type=str, default='l1',
                        choices=['l1', 'l2'],
                        help='''Norm type to use for the gradient-based token importance.
                        Options: l1, l2. Default: l1''')
    parser.add_argument('--grad_temperature', type=float, default=1.0,
                        help='''Temperature to use for the gradient-based token importance.
                        This is used to scale the token scores at the start of the quantization. Default: 1.0''')
    parser.add_argument('--grad_start_idx', type=int, default=1,
                        help='''Start index for the gradient-based token importance.
                        This is used to skip the first few tokens in the sequence. Default: 1''')
    parser.add_argument('--grad_clip_times', type=float, default=10.0,
                        help='''grad clip times''')
    parser.add_argument('--grad_clip_high_only', action=argparse.BooleanOptionalAction, default=False,
                        help='''grad clip high''')
    parser.add_argument('--residual_alpha', type=float, default=0.25,
                        help='''Residual alpha for the gradient-based token importance.
                        This is used to scale the residuals in the model. Default: 0.25''')
    parser.add_argument('--random_drop_ratio', type=float, default=-1.0,
                        help='''Random drop ratio for the gradient-based token importance.''')

    # Save/Load Quantized Model Arguments
    parser.add_argument('--load_qmodel_path', type=str, default=None,
                        help='Load the quantized model from the specified path!')
    parser.add_argument('--save_qmodel_path', type=str, default="your/save/path",
                        help='Save the quantized model to the specified path!')
    # The code below reads args.save_path, so it must exist as a real argument.
    parser.add_argument('--save_path', type=str, default="./outputs",
                        help='Output directory for this run; logs are written to <save_path>/logs')

    # WandB Arguments
    parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--wandb_id', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default=None)

    # Experiments Arguments
    parser.add_argument('--capture_layer_io', action=argparse.BooleanOptionalAction, default=False,
                        help='Capture the input and output of the specified decoder layer and dump into a file')
    parser.add_argument('--layer_idx', type=int, default=10, help='Which decoder layer to capture')

    args = parser.parse_args(argv)

    # Only some grad_from / grad_acton combinations are accepted.
    if args.grad_from == 'attn_out':
        assert args.grad_acton != 'all', "grad_acton must not be 'all' when grad_from is 'attn_out'!"
    if args.grad_from == 'o_proj_in':
        assert args.grad_acton == 'qkv', "grad_acton must be 'qkv' when grad_from is 'o_proj_in'!"

    # Creating <save_path>/logs also creates <save_path> itself.
    log_dir = os.path.join(args.save_path, 'logs')
    os.makedirs(log_dir, exist_ok=True)

    config_logging(os.path.join(log_dir, 'quant_log.log'))

    # NOTE: a_groupsize == w_groupsize is not enforced here (the check was disabled).
    assert not args.k_pre_rope, 'Pre-RoPE quantization is not supported yet!'

    if args.wandb:
        assert args.wandb_id is not None and args.wandb_project is not None, 'WandB ID/project is not provided!'

    logging.info('Arguments: ')
    logging.info(pprint.pformat(vars(args)))
    logging.info('--' * 30)
    return args

def cleanup_memory(verbos=True) -> None:
    """Run GC and clear GPU memory.

    Always runs ``gc.collect()``. When CUDA is available it also calls
    ``torch.cuda.empty_cache()`` and, if ``verbos`` is true, logs the total
    reserved memory summed over all CUDA devices before and after the
    cleanup, tagged with the name of the calling function.

    Args:
        verbos: Whether to log the before/after reserved GPU memory.
    """
    import gc
    import inspect

    def total_reserved_mem() -> int:
        return sum(torch.cuda.memory_reserved(device=i) for i in range(torch.cuda.device_count()))

    cuda_available = torch.cuda.is_available()
    # The "before" figure is only needed for the log line.
    memory_before = total_reserved_mem() if (cuda_available and verbos) else 0

    # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
    gc.collect()

    if not cuda_available:
        return

    torch.cuda.empty_cache()
    if not verbos:
        return

    memory_after = total_reserved_mem()

    # Look up only the immediate caller's frame instead of materialising the
    # whole stack with inspect.stack(); tolerate a missing caller frame.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    caller_name = f' (from {caller.f_code.co_name})' if caller is not None else ''
    del frame, caller  # drop frame references promptly to avoid reference cycles

    logging.info(
        f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} GB"
        f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)"
    )