import argparse
import pprint
import torch
import random
import numpy as np
import os
from datetime import datetime
import logging


from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

# HuggingFace model identifiers accepted by ``--model`` (order is preserved in
# the argparse help output).
supported_models = [
    'meta-llama/Llama-2-7b-hf',
    'meta-llama/Llama-2-13b-hf',
    'meta-llama/Llama-2-70b-hf',
    'meta-llama/Meta-Llama-3-8B',
    'meta-llama/Meta-Llama-3-70B',
    'meta-llama/Meta-Llama-3.1-405B',
    'facebook/opt-125m',
    'facebook/opt-1.3b',
    'mistralai/Mistral-7B-v0.1',
    'Qwen/Qwen2-7B',
    'mistralai/Codestral-22B-v0.1',
    'Qwen/Qwen2-1.5B',
]
# Dataset names accepted by ``--eval_dataset`` and ``--cal_dataset``.
supported_datasets = 'wikitext2 ptb c4'.split()

# Turn off TensorFloat-32 tensor-core math for both matmuls and cuDNN
# convolutions so results are computed in full fp32 precision (avoids
# TF32-induced numerical differences).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Default compute device: the first CUDA GPU when one is visible, else the CPU.
DEV = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def llama_down_proj_groupsize(model, groupsize):
    """Return the activation group size to use for the ``down_proj`` input.

    If ``groupsize`` evenly divides ``model.config.intermediate_size`` it is
    reused unchanged. Otherwise the number of groups along
    ``model.config.hidden_size`` (``hidden_size // groupsize``) is kept fixed,
    and the returned size splits ``intermediate_size`` into that same number
    of groups.

    Args:
        model: Object exposing integer ``config.hidden_size`` and
            ``config.intermediate_size`` attributes.
        groupsize: Activation group size along ``hidden_size``; must be > 1.

    Returns:
        The group size to use along ``intermediate_size``.

    Raises:
        AssertionError: If ``groupsize <= 1``; or, in the fallback path, if
            ``groupsize`` does not evenly divide ``hidden_size`` or the
            resulting group count does not evenly divide ``intermediate_size``.
    """
    assert groupsize > 1, 'groupsize should be greater than 1!'

    intermediate_size = model.config.intermediate_size
    if intermediate_size % groupsize == 0:
        logging.info(f'(Act.) Groupsize = Down_proj Groupsize: {groupsize}')
        return groupsize

    hidden_size = model.config.hidden_size
    # Exact integer division (no float round-trip). If groupsize > hidden_size
    # this yields 0 and the assertion below fires before any division by zero.
    group_num = hidden_size // groupsize
    assert groupsize * group_num == hidden_size, 'Invalid groupsize for llama!'

    down_proj_groupsize = intermediate_size // group_num
    assert down_proj_groupsize * group_num == intermediate_size, 'Invalid groupsize for down_proj!'
    logging.info(f'(Act.) Groupsize: {groupsize}, Down_proj Groupsize: {down_proj_groupsize}')
    return down_proj_groupsize



def set_seed(seed):
    """Seed NumPy's global RNG, PyTorch's default generator and ``random``.

    Args:
        seed: The integer seed applied to every generator.
    """
    for seed_fn in (np.random.seed, torch.random.manual_seed, random.seed):
        seed_fn(seed)

# Dump the log both to console and a log file.
def config_logging(log_file, level=logging.INFO):
    """Route root-logger output to both the console and ``log_file``.

    INFO records are written as the bare message; records at any other level
    are prefixed with ``"<LEVELNAME>: "``.

    ``force=True`` is passed to ``logging.basicConfig`` so this takes effect
    even when the root logger already has handlers (e.g. installed by a
    library at import time). Without it ``basicConfig`` silently does nothing,
    and ``log_file`` would never receive any records. Pre-existing root
    handlers are removed and closed.

    Args:
        log_file: Path of the log file (appended to, UTF-8 encoded).
        level: Minimum level set on the root logger.
    """
    class LogFormatter(logging.Formatter):
        """Choose a format per record level by delegating to fixed formatters.

        Delegating avoids mutating the formatter's private style state on
        every record.
        """
        _info_formatter = logging.Formatter("%(message)s")
        _default_formatter = logging.Formatter("%(levelname)s: %(message)s")

        def format(self, record):
            if record.levelno == logging.INFO:
                return self._info_formatter.format(record)
            return self._default_formatter.format(record)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(LogFormatter())

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setFormatter(LogFormatter())

    logging.basicConfig(level=level, handlers=[console_handler, file_handler], force=True)


def parser_gen(argv=None):
    """Build the command-line parser, parse the arguments and set up the run.

    Side effects, in order:
      * if ``--lm_eval`` is set, every name in ``--tasks`` is validated against
        the lm-eval-harness task registry (``ValueError`` on an unknown task);
      * ``args.save_name`` defaults to the current local time
        (``YYYYmmdd_HHMMSS``);
      * ``args.save_path`` is set to
        ``./DuoGPTmisc/experiments/<model>/<save_name>`` and the directory is
        created;
      * logging is configured (via ``config_logging``) to write to
        ``<save_path>/<save_name>.log``;
      * the parsed arguments are logged.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        The parsed ``argparse.Namespace``, including the added ``save_path``.
    """
    parser = argparse.ArgumentParser()

    # General Arguments
    parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf',
                        help='Model to load;', choices=supported_models)
    parser.add_argument('--seed', type=int, default=0, help='Random Seed for HuggingFace and PyTorch')
    parser.add_argument('--eval_dataset', type=str, default='wikitext2',
                        help='Dataset for Evaluation (default: wikitext2)', choices=supported_datasets,)
    parser.add_argument('--hf_token', type=str, default=None)
    parser.add_argument('--bsz', type=int, default=32,
                        help='Batch-size for PPL evaluation (default:32)')

    # Calibration Arguments
    parser.add_argument('--nsamples', type=int, default=128,
                        help='Number of calibration data samples.')
    parser.add_argument('--seqlen', type=int, default=2048,
                        help='Number of calibration data sequence length.')
    parser.add_argument('--cal_dataset', type=str, default='wikitext2',
                        help='calibration data samples.', choices=supported_datasets)
    parser.add_argument('--percdamp', type=float, default=.01,
                        help='Percent of the average Hessian diagonal to use for dampening.')

    # NOTE(review): a previous comment said this is the only option turned on
    # "by default", but default=False; presumably it is the only act-order
    # option enabled in the experiment scripts -- confirm.
    parser.add_argument('--act_order', action=argparse.BooleanOptionalAction, default=False,
                        help='act-order in SparseGPT, sort the hessian')

    #! Empirically leads to better performance, but not turned on for this paper.
    parser.add_argument('--block_act', action=argparse.BooleanOptionalAction, default=False,
                        help='blockwise fill with act order in DuoGPT')

    parser.add_argument('--dxxt_permutation', action=argparse.BooleanOptionalAction, default=False,
                        help='Use dxxt instead of H diag for act-order in DuoGPT')

    parser.add_argument('--use_v2', action=argparse.BooleanOptionalAction, default=False,
                        help='enable GPTQv2 asymmetric calibration')
    parser.add_argument('--use_wanda', action=argparse.BooleanOptionalAction, default=False,
                        help='enable Wanda-style pruning')

    # Save/Load Pruned Model Arguments
    parser.add_argument('--load_pmodel_path', type=str, default='./ckpts',
                        help='Load the pruned model from the specified path!')
    parser.add_argument('--save_pmodel_path', type=str, default='./ckpts',
                        help='Save the pruned model to the specified path!')
    parser.add_argument('--save_ckpt', action=argparse.BooleanOptionalAction, default=False,
                        help='Save the pruned model to the specified path!')
    parser.add_argument('--load_ckpt', action=argparse.BooleanOptionalAction, default=False,
                        help='Load the pruned model from the specified path!')
    parser.add_argument('--hf_cache_path', type=str, default='./hf_cache',
                        help="The cache path for the hugging face. Don't forget to set this!")

    # WandB Arguments
    parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--wandb_name', type=str, default='None')
    parser.add_argument('--wandb_id', type=str, default='')
    parser.add_argument('--wandb_project', type=str, default='DuoGPT')
    parser.add_argument('--wandb_dir', type=str, default='./wandb_logs/DuoGPT')

    #! Newly Added Arguments for Pruning
    parser.add_argument("--sparsity", type=float, default=0, help="Target weight sparsity")
    parser.add_argument("--a_sparsity", type=float, default=0, help="Target activation sparsity")
    parser.add_argument("--prunen", type=int, default=0, help="N for N:M pruning. Default 0 is turning off.")
    parser.add_argument("--prunem", type=int, default=0, help="M for N:M pruning. Default 0 is turning off.")
    parser.add_argument("--blocksize", type=int, default=128, help="Blocksize to use for adaptive mask selection.")
    parser.add_argument("--act_blocksize", type=int, default=32, help="Blocksize to use for block act_order.")

    parser.add_argument("--scale_alpha", type=float, default=0.125, help="alpha scaling for helping calibration")

    parser.add_argument('--enable_ap_calibration', action=argparse.BooleanOptionalAction, default=False,
                        help='enable activation sparsity in calibration')

    #! Empirically did not show significant improvement. Left here for legacy code.
    parser.add_argument('--enable_ap_anneal', action=argparse.BooleanOptionalAction, default=False,
                        help='enable activation sparsity annealing in calibration')

    #! Experiments for weight quantization
    parser.add_argument('--w_bits', type=int, default=16,
                        help='Number of bits for weights of the Linear layers')
    parser.add_argument('--group_size', type=int, default=128,
                        help='Group size for quantizing weights of the Linear layers')

    #! Experiments Arguments
    parser.add_argument('--save_name', type=str, default=None,
                        help='The path to save experiment data, '
                             'including quantized models, dumped layer inputs, etc. '
                             'The data will be saved in experiments/[model]/save_name. '
                             'Default: [datetime].')
    parser.add_argument('--capture_layer_io', action=argparse.BooleanOptionalAction, default=False,
                        help='Capture the input and output of the specified decoder layer and dump into a file')
    parser.add_argument('--layer_idx', type=int, default=10, help='Which decoder layer to capture')
    parser.add_argument('--enable_wanda_comparison', action=argparse.BooleanOptionalAction, default=False,
                        help='adjust the default context length to compare with wanda')
    parser.add_argument('--exp_name', type=str, default=None, help='The experiment name')
    parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=False,
                        help='Debug mode, will terminate after the first decoder layer is calibrated.')

    parser.add_argument('--act_distr_catch', action=argparse.BooleanOptionalAction, default=False,
                        help='enable the teal repro to get activation distributions.')

    parser.add_argument('--act_teal', action=argparse.BooleanOptionalAction, default=False,
                        help='enable the th-based pruning.')

    #! LM Eval Arguments
    parser.add_argument("--lm_eval", action="store_true", help="Evaluate the model on LM Eval tasks.")
    parser.add_argument(
        '--tasks',
        nargs='+',
        # default=["gsm8k"],
        # default=["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande", "boolq", "openbookqa"],
        default=["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande"],
    )
    parser.add_argument('--lm_eval_batch_size', type=int, default=32, help='Batch size for evaluating with lm eval harness.')
    parser.add_argument('--distribute', action=argparse.BooleanOptionalAction, default=False,
                        help='Distribute the model on multiple GPUs for evaluation')
    parser.add_argument("--lm_ppl", action="store_true", help="Evaluate the model on LM PPL tasks.")

    args = parser.parse_args(argv)
    if args.lm_eval:
        # Imported lazily: lm_eval is only needed when --lm_eval is requested.
        from lm_eval import tasks
        from lm_eval.tasks import initialize_tasks
        from lm_eval import utils as lm_eval_utils
        initialize_tasks()
        # Build the task matcher once instead of once per requested task.
        valid_tasks = lm_eval_utils.MultiChoice(tasks.ALL_TASKS)
        for task in args.tasks:
            if task not in valid_tasks:
                raise ValueError(f"Invalid task: {task}")

    if args.save_name is None:
        args.save_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = './DuoGPTmisc'  #! set your own saving path.
    # args.model contains a '/', so this nests e.g. experiments/<org>/<name>/<save_name>.
    args.save_path = os.path.join(path, 'experiments', args.model, args.save_name)
    os.makedirs(args.save_path, exist_ok=True)

    config_logging(os.path.join(args.save_path, f'{args.save_name}.log'))

    if args.model in ('facebook/opt-125m', 'facebook/opt-1.3b'):
        logging.warning('Warning: OPT-125M/1.3B is only for debugging purposes!!')

    if args.wandb:
        # NOTE(review): argparse never produces None here (wandb_id defaults to
        # '' and wandb_project to 'DuoGPT'), so this check cannot fail as
        # written; tighten it if a non-empty ID/project is actually required.
        assert args.wandb_id is not None and args.wandb_project is not None, 'WandB ID/project is not provided!'

    logging.info('Arguments: ')
    logging.info(pprint.pformat(vars(args)))
    logging.info('--' * 30)
    return args


def cleanup_memory(verbos=True) -> None:
    """Run garbage collection and release cached CUDA memory.

    When CUDA is available, ``torch.cuda.empty_cache()`` is called. If
    ``verbos`` is true, the change in total reserved memory (summed over all
    visible GPUs) is then logged, together with the name of the calling
    function.

    Args:
        verbos: Whether to log the before/after reserved-memory summary
            (only has an effect when CUDA is available).
    """
    import gc
    import inspect

    # Only the immediate caller's name is needed. Stepping one frame back is
    # much cheaper than inspect.stack(), which builds a FrameInfo (including
    # source context lines) for every frame on the stack.
    caller_name = ''
    frame = inspect.currentframe()
    try:
        # currentframe() may return None on Python implementations without
        # stack-frame support.
        if frame is not None and frame.f_back is not None:
            caller_name = f' (from {frame.f_back.f_code.co_name})'
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame

    def total_reserved_mem() -> int:
        # device_count() is 0 without CUDA, so this is simply 0 on CPU-only hosts.
        return sum(torch.cuda.memory_reserved(device=i) for i in range(torch.cuda.device_count()))

    memory_before = total_reserved_mem()

    # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        memory_after = total_reserved_mem()
        if verbos:
            logging.info(
                f"GPU memory{caller_name}: {memory_before / (1024 ** 3):.2f} -> {memory_after / (1024 ** 3):.2f} GB"
                f" ({(memory_after - memory_before) / (1024 ** 3):.2f} GB)"
            )

def distribute_model(model, no_split_module_classes=None) -> None:
    """Distribute ``model`` across the available GPUs with Accelerate, in place.

    A balanced per-device memory budget is computed. A device map is then
    inferred that never splits a module whose class name appears in
    ``no_split_module_classes``, and the model is dispatched with that map
    (buffers offloaded, offload directory ``"offload"``). Cached GPU memory is
    released afterwards.

    Args:
        model: The model to distribute (presumably a HuggingFace
            ``PreTrainedModel`` -- TODO confirm for other callers).
        no_split_module_classes: Class names that must stay on a single
            device. When ``None`` (default), the model's own
            ``_no_split_modules`` attribute is used if set (HuggingFace models
            typically define it, e.g. the decoder-layer class). Otherwise it
            falls back to ``['LlamaDecoderLayer']``.
    """
    if no_split_module_classes is None:
        # Derive from the model so non-Llama architectures (OPT, Mistral,
        # Qwen2, ...) also keep their decoder layers whole.
        no_split_module_classes = getattr(model, '_no_split_modules', None) or ['LlamaDecoderLayer']
    no_split_module_classes = list(no_split_module_classes)

    max_memory = get_balanced_memory(
        model,
        no_split_module_classes=no_split_module_classes,
    )

    device_map = infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
    )

    logging.info(f'Device map: {device_map}')
    dispatch_model(
        model,
        device_map=device_map,
        offload_buffers=True,
        offload_dir="offload",
        state_dict=model.state_dict(),
    )

    cleanup_memory()