import argparse
import datasets
import gc
import sys
import torch
import warnings
from transformers import AutoTokenizer
from tqdm import tqdm
from model_loader import *
from model_loader_qwen25 import (
    load_model_and_apply_patches_qwen25_awq,
    load_model_and_apply_patches_qwen25_non_awq, 
    add_args as add_qwen_args
)
import os


def compute_perplexity(
    encodings, model, tokenizer, add_start_token: bool = True, device=None, max_length=None, sliding_window=256, truncate=False, aggressive_memory=False, hide_progress=False,
):
    r"""Compute "sliding window" perplexity on a dataset. Validated against the calculations reported in arXiv 2306.15595

    Every sequence in ``encodings["input_ids"]`` is scored on its own. Windows
    of ``max_length`` tokens (one of which is the prepended BOS token when
    ``add_start_token`` is true) start every ``sliding_window`` tokens. A
    window that would run past the end of its sequence is skipped, so
    sequences shorter than one full window contribute nothing. The result is
    ``exp`` of the unweighted mean of the per-window ``model(...).loss``.

    Args:
        encodings: Mapping-like object whose ``"input_ids"`` entry is a
            sequence of token-id lists, one per document. No other key is
            read.
        model: Callable invoked as ``model(input_ids, labels=target_ids)``
            whose result exposes a scalar tensor ``.loss``.
        tokenizer: Only ``bos_token`` / ``bos_token_id`` are read, and only
            when ``add_start_token`` is true.
        add_start_token: Prepend the BOS token to every window.
        device: ``"cpu"``, ``"cuda"`` or ``"gpu"`` (alias of ``"cuda"``).
            ``None`` selects CUDA when available. Only the input tensors are
            moved to this device; the model is left where it is.
        max_length: Window length in tokens, including the BOS token when one
            is added. Required despite the ``None`` default.
        sliding_window: Stride between consecutive window starts.
        truncate: Cut every sequence to a single window and score just that.
        aggressive_memory: Drop references and run the garbage collector and
            the CUDA cache cleanup after every forward pass.
        hide_progress: Disable the progress bar.

    Returns:
        ``{"mean_perplexity": float}``.

    Raises:
        ValueError: ``max_length`` is missing or leaves no room for a token,
            ``device`` is not recognised, or a BOS token is requested but the
            tokenizer has none.
        RuntimeError: No sequence was long enough to fill a single window.
    """
    # Validate up front: the None default used to surface as an opaque
    # TypeError deep inside the window arithmetic.
    if max_length is None:
        raise ValueError("max_length must be set to the evaluation window size in tokens.")
    max_tokenized_len = max_length - 1 if add_start_token else max_length
    if max_tokenized_len < 1:
        raise ValueError(
            f"max_length={max_length} leaves no room for any token per window.")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif device not in ("gpu", "cpu", "cuda"):
        raise ValueError("device should be either gpu or cpu.")
    elif device == "gpu":
        device = "cuda"

    if add_start_token and tokenizer.bos_token is None:
        raise ValueError(
            "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False")

    encoded_texts = encodings["input_ids"]

    if truncate:
        # One window per document: keep only the first window's worth of tokens.
        encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts]
        sliding_window = max_tokenized_len

    nlls = []
    with tqdm(total=len(encoded_texts), disable=hide_progress) as pbar:
        for encoding_index in range(len(encoded_texts)):
            # Slice (rather than index) so the tensor keeps a leading batch axis of 1.
            labels = torch.tensor(encoded_texts[encoding_index:encoding_index + 1])
            seq_len = labels.size(1)

            prev_end_loc = 0
            for begin_loc in range(0, seq_len, sliding_window):
                end_loc = begin_loc + max_tokenized_len
                # Only full windows are scored; a partial tail is skipped.
                if end_loc > seq_len:
                    break

                # Number of tokens this window adds beyond the previous one.
                trg_len = end_loc - prev_end_loc
                input_ids = labels[:, begin_loc:end_loc].to(device)

                if add_start_token:
                    bos_tokens_tensor = torch.tensor(
                        [[tokenizer.bos_token_id]] * input_ids.size(dim=0)).to(device)
                    input_ids = torch.cat(
                        [bos_tokens_tensor, input_ids], dim=1)

                # Everything except the trailing trg_len positions is labelled
                # -100. Presumably the model's loss ignores that label (the
                # usual Hugging Face convention) -- confirm for custom models.
                target_ids = input_ids.clone()
                target_ids[:, :-trg_len] = -100

                with torch.no_grad():
                    outputs = model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs.loss

                if aggressive_memory:
                    outputs = None
                    input_ids = None
                    target_ids = None
                    gc.collect()
                    torch.cuda.empty_cache()

                nlls.append(neg_log_likelihood.cpu())

                # Running perplexity, shown on the progress bar only.
                ppl = float(torch.exp(torch.stack(nlls).mean()))
                pbar.set_postfix(ppl=ppl)

                prev_end_loc = end_loc
                if end_loc == seq_len:
                    break

            pbar.update(1)

    if not nlls:
        raise RuntimeError(
            f"No sequence is long enough to fill one window of {max_tokenized_len} tokens; "
            "lower max_length or filter the dataset by length.")

    ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
    return {"mean_perplexity": ppl}


def is_qwen_model(model_path):
    """Check if the model is a Qwen model.

    The model's configuration is loaded and its ``architectures`` list is
    checked for ``Qwen2ForCausalLM`` or ``QwenLMHeadModel``. When the
    configuration cannot be loaded, or lists no architectures, the decision
    falls back to a case-insensitive search for ``"qwen"`` in ``model_path``.

    Args:
        model_path: Local directory or hub identifier of the model.

    Returns:
        ``True`` when the model is recognised as Qwen, ``False`` otherwise.
    """
    try:
        from transformers import AutoConfig
        # NOTE: trust_remote_code=True allows the model repository to run its
        # own Python code while loading; only pass trusted model paths.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    except Exception:
        # Best-effort detection: any loading failure (missing package, no
        # network, unknown path) falls through to the name check below.
        # KeyboardInterrupt / SystemExit are deliberately not swallowed.
        config = None

    architectures = getattr(config, "architectures", None)
    if architectures:
        return any(
            name in architectures
            for name in ("Qwen2ForCausalLM", "QwenLMHeadModel")
        )
    # Fallback: check model path
    return "qwen" in str(model_path).lower()

def _token_lengths(min_tokens, max_tokens, tokens_step):
    """Return the list of context lengths to evaluate.

    With a truthy ``tokens_step`` the lengths run from ``min_tokens`` to
    ``max_tokens`` inclusive in increments of ``tokens_step``; otherwise they
    start at ``min_tokens`` and double while staying within ``max_tokens``.
    """
    if tokens_step:
        return list(range(min_tokens, max_tokens + 1, tokens_step))

    lengths = [min_tokens]
    # Doubling a non-positive value never grows, so without this guard the
    # loop would not terminate; such a start is evaluated on its own.
    if min_tokens > 0:
        while lengths[-1] * 2 <= max_tokens:
            lengths.append(lengths[-1] * 2)
    return lengths


def _result_csv_path(args):
    """Build the result CSV path inside ``args.output_file`` from the run settings."""
    # NOTE(review): the 'govreport' prefix is fixed regardless of --dataset.
    pieces = [args.output_file, '/govreport']
    if args.awq:
        pieces.append('_awq_int4')
    if args.yarn is not None:
        pieces.append('_yarn' + '_' + str(int(args.yarn)))
    if args.ntk is not None:
        pieces.append('_ntk' + '_' + str(args.ntk))
    if args.awq_rescale_temp is not None:
        pieces.append('_temp_' + str(args.awq_rescale_temp))
    if args.individual_channel_up is not None:
        pieces.append('_Individual_channel_up')
    if args.individual_channel_down is not None:
        pieces.append('_Individual_channel_down')
    pieces.append('_beta_point_' + str(args.beta_point))
    pieces.append(args.comments)
    pieces.append('.csv')
    return ''.join(pieces)


def main(args):
    """Load a model, evaluate its perplexity over a range of context lengths, and optionally save a CSV.

    Steps:
      1. Load the first model named in ``args.model`` (Qwen or Llama path,
         with or without AWQ) together with its tokenizer.
      2. Load a pre-tokenized dataset (``args.tokenized``) or tokenize
         ``args.dataset``; with ``args.save_tokenized`` the tokenized data is
         written to disk and the function returns without evaluating.
      3. Run ``compute_perplexity`` once per context length.
      4. If ``args.output_file`` is set, append the results to a CSV file
         created inside that directory.

    Args:
        args: Parsed command-line namespace. Besides the options declared in
            this file it reads ``args.yarn`` and ``args.ntk``, which are
            presumably registered by ``add_args`` from ``model_loader``.
    """
    models = [x[0] for x in args.model]
    model_path = models[0]

    # Check if this is a Qwen model
    if is_qwen_model(model_path):
        print(f"Detected Qwen model: {model_path}")
        if args.awq:
            print("Loading Qwen model with AWQ quantization")
            loaded, tokenizer = load_model_and_apply_patches_qwen25_awq(
                model_path, args, args.quant_path, args.awq_cache, args.awq_rescale_temp)
        else:
            print("Loading Qwen model without quantization")
            loaded, tokenizer = load_model_and_apply_patches_qwen25_non_awq(model_path, args)
    elif args.original:
        # Original Llama model loading logic
        print("using original yarn infrastructure")
        if args.awq:
            print("loading awq model with original yarn infrastructure")
            loaded, tokenizer = load_model_and_apply_patches_original_awq(
                model_path, args, args.quant_path, args.awq_cache, args.awq_rescale_temp)
        else:
            loaded = load_model_and_apply_patches_original(model_path, args)
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, model_max_length=sys.maxsize, trust_remote_code=True)
    elif args.awq:
        print("loading awq model with hardware acceleration")
        loaded, tokenizer = load_model_and_apply_patches_awq(
            model_path, args, None, args.quant_path, args.awq_cache, args.awq_rescale_temp)
    else:
        loaded = load_model_and_apply_patches(model_path, args)
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, model_max_length=sys.maxsize, trust_remote_code=True)

    tokenizer.pad_token = tokenizer.eos_token

    if args.tokenized:
        try:
            input_texts = datasets.load_from_disk(args.tokenized)
        except Exception:
            # Not a dataset saved on disk: treat the value as a dataset name.
            input_texts = datasets.load_dataset(
                args.tokenized, name=args.subset, split=args.split)
    else:
        input_texts = datasets.load_dataset(
            args.dataset, name=args.subset, split=args.split)

        def tokenize(example):
            tokenized = tokenizer(
                example[args.feature],
                add_special_tokens=False,
                padding=True,
                truncation=False,
                max_length=sys.maxsize,
                return_attention_mask=True,
            )
            example["input_ids"] = tokenized["input_ids"]
            example["attention_mask"] = tokenized["attention_mask"]
            example["tokenized_len"] = len(tokenized["input_ids"])
            return example

        input_texts = input_texts.map(tokenize)
        if args.save_tokenized:
            input_texts.save_to_disk(args.save_tokenized)
            print(f"Saved tokenized dataset to {args.save_tokenized}")
            return

    if args.dataset_min_tokens:
        input_texts = input_texts.filter(
            lambda x: x["tokenized_len"] >= args.dataset_min_tokens)
    if args.samples:
        # Keep every fifth example, at most `args.samples` of them.
        input_texts = input_texts[:args.samples*5-1:5]

    tokens = _token_lengths(args.min_tokens, args.max_tokens, args.tokens_step)

    results = []
    # NOTE(review): only the first model is ever loaded above; every entry of
    # `models` is scored with that same `loaded` model and merely labels its row.
    for model in tqdm(models, desc="Model", leave=False, disable=args.hide_progress):
        torch.cuda.empty_cache()

        result = [model]
        for max_length in tokens:
            ppl = compute_perplexity(model=loaded, tokenizer=tokenizer, encodings=input_texts,
                                     add_start_token=tokenizer.bos_token is not None, max_length=max_length,
                                     sliding_window=args.sliding_window, truncate=args.truncate,
                                     aggressive_memory=args.aggressive_memory, hide_progress=args.hide_progress)['mean_perplexity']
            print(f"{model}: {max_length}={ppl}")
            result.append(ppl)

        results.append(result)

    if args.output_file:
        # Despite its name, --output-file is used as a directory; the CSV file
        # name inside it is derived from the run settings.
        os.makedirs(args.output_file, exist_ok=True)
        output_file_name = _result_csv_path(args)
        # Append mode: repeated runs add a new header line plus result rows.
        with open(output_file_name, "a", encoding="utf-8") as f:
            f.write(f",{','.join([str(x) for x in tokens])}\n")
            for result in results:
                f.write(f"{','.join([str(x) for x in result])}\n")


if __name__ == "__main__":
    # Command-line entry point: declare this script's options, let `add_args`
    # register the remaining ones, then run the evaluation.
    warnings.simplefilter("ignore")
    parser = argparse.ArgumentParser()
    # main() indexes args.model unconditionally, so the option is mandatory;
    # omitting it used to fail with "'NoneType' object is not iterable".
    parser.add_argument("-m", "--model", action="append", nargs="+", required=True)

    # Quantization / rescaling options (forwarded to the model loaders via `args`).
    parser.add_argument("--quant_path", default=None, type=str)
    parser.add_argument("--awq_cache", default=None, type=str)
    parser.add_argument("--awq_rescale_temp", default=None, type=float)
    parser.add_argument("--beta_point", default=1440, type=int)
    # Free-form suffix appended to the result CSV file name.
    parser.add_argument("--comments", default='', type=str)
    parser.add_argument("--dynamic_with_log_distance", action="store_true")
    parser.add_argument("--exclude_value_proj", action="store_true")
    parser.add_argument("--rescale_attention_all", action="store_true")
    parser.add_argument("--rescale_per_head", action="store_true")
    parser.add_argument("--individual_channel_up", default=None, type=str)
    parser.add_argument("--individual_channel_down", default=None, type=str)
    parser.add_argument("--individual_channel_scale", default=None, type=float)
    parser.add_argument("--individual_channel_value", default=None, type=str)
    parser.add_argument("--scale_invert", action="store_true")

    # Flag spelling is kept as-is because it is part of the command-line interface.
    parser.add_argument("--recale_specific_layer", type=int, default=None, help="Specify a layer to rescale, e.g., 0 for the first layer. If None, all layers are rescaled.")
    parser.add_argument("--search_result_path", default=None, type=str)
    parser.add_argument("--use_search_result", action="store_true")

    parser.add_argument("--apply_hardmard", action="store_true")
    parser.add_argument("--naive_quant", action="store_true")
    parser.add_argument("--no_pi", action="store_true")
    parser.add_argument("--quant_activation", action="store_true")
    parser.add_argument("--quant_activation_bitwidth", type=int, default=4)
    parser.add_argument("--hardmard_layers", type=str, default="")

    # Dataset selection and evaluation schedule.
    parser.add_argument("-d", "--dataset", type=str)
    parser.add_argument("-s", "--subset", type=str)
    parser.add_argument("-f", "--feature", type=str)
    parser.add_argument("--max-tokens", type=int, default=8192)
    parser.add_argument("--min-tokens", type=int, default=256)
    parser.add_argument("--dataset-min-tokens", type=int)
    # A step of 0 switches main() to a doubling schedule between min and max tokens.
    parser.add_argument("--tokens-step", type=int, default=8)
    parser.add_argument("--sliding-window", type=int, default=256)
    parser.add_argument("--truncate", action="store_true")
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--samples", type=int)
    parser.add_argument("--save-tokenized", type=str)
    parser.add_argument("--tokenized", type=str)
    # Used by main() as a directory in which the result CSV is created.
    parser.add_argument("--output-file", type=str)
    parser.add_argument("--aggressive-memory", action="store_true")
    parser.add_argument("--hide-progress", action="store_true")
    parser.add_argument("--awq", action="store_true")
    parser.add_argument("--original", action="store_true")

    # NOTE(review): `add_qwen_args` is imported at the top of the file but never
    # applied here; confirm whether the Qwen loaders need their own options.
    main(add_args(parser).parse_args())