import argparse
import datasets
import gc
import sys
import torch
import warnings
from transformers import AutoTokenizer
from tqdm import tqdm
from model_loader import *
import os
import copy 

def compute_perplexity(
    encodings, model, tokenizer, add_start_token: bool = True, device=None, max_length=None, sliding_window=256, truncate=False, aggressive_memory=False, hide_progress=False,
):
    r"""Compute "sliding window" perplexity on a dataset.

    Validated against the calculations reported in arXiv 2306.15595.

    Every document is scanned with windows of ``max_length`` tokens (one of
    which is the prepended BOS token when ``add_start_token`` is set) that
    advance by ``sliding_window`` tokens. A window that would run past the end
    of a document is skipped, so documents shorter than one window contribute
    nothing. The result is ``exp`` of the unweighted mean of the per-window
    losses returned by the model.

    Args:
        encodings: Mapping whose ``"input_ids"`` entry holds one token-id
            sequence per document.
        model: Callable invoked as ``model(input_ids, labels=target_ids)``;
            the returned object must expose a ``loss`` tensor. This function
            does not move the model; only the inputs are sent to ``device``.
        tokenizer: Only ``bos_token`` / ``bos_token_id`` are read, and only
            when ``add_start_token`` is True.
        add_start_token: Prepend the BOS token to every window.
        device: ``"cpu"``, ``"cuda"`` or ``"gpu"`` (alias of ``"cuda"``).
            ``None`` picks CUDA when available, otherwise CPU.
        max_length: Window size in tokens, including the BOS token if one is
            added. Required despite the ``None`` default.
        sliding_window: Stride between consecutive window starts.
        truncate: Cut every document to a single window and score only that.
        aggressive_memory: Drop per-window tensors, run the garbage collector
            and empty the CUDA cache after every forward pass.
        hide_progress: Disable the tqdm progress bar.

    Returns:
        ``{"mean_perplexity": float}``.

    Raises:
        ValueError: If ``max_length`` is missing or leaves no room for real
            tokens, or if no document is long enough to fill one window.
        AssertionError: If ``device`` is not a recognised name, or if
            ``add_start_token`` is set and the tokenizer has no BOS token.
    """
    if device is not None:
        assert device in ["gpu", "cpu",
                          "cuda"], "device should be either gpu or cpu."
        if device == "gpu":
            device = "cuda"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fail early with a readable message instead of a TypeError from `None - 1`.
    if max_length is None:
        raise ValueError(
            "max_length must be provided: it defines the evaluation window size.")

    if add_start_token:
        # leave room for <BOS> token to be added:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    if max_tokenized_len < 1:
        raise ValueError(
            f"max_length={max_length} leaves no room for input tokens "
            f"(add_start_token={add_start_token}).")

    encoded_texts = encodings["input_ids"]

    if truncate:
        # One window per document: keep only the leading tokens and make the
        # stride equal to the window so the inner loop runs at most once.
        encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts]
        sliding_window = max_tokenized_len

    nlls = []
    with tqdm(total=len(encoded_texts), disable=hide_progress) as pbar:
        for encoding_index in range(len(encoded_texts)):
            # Slicing (rather than indexing) keeps a leading batch dimension of 1.
            labels = torch.tensor(
                encoded_texts[encoding_index:encoding_index + 1])
            seq_len = labels.size(1)

            prev_end_loc = 0
            for begin_loc in range(0, seq_len, sliding_window):
                end_loc = begin_loc + max_tokenized_len

                # Only full-size windows are scored; a partial tail is skipped.
                if end_loc > seq_len:
                    break

                # Number of tokens not yet scored by a previous window.
                trg_len = end_loc - prev_end_loc
                input_ids = labels[:, begin_loc:end_loc].to(device)

                if add_start_token:
                    bos_tokens_tensor = torch.tensor(
                        [[tokenizer.bos_token_id]] * input_ids.size(dim=0)).to(device)
                    input_ids = torch.cat(
                        [bos_tokens_tensor, input_ids], dim=1)

                # Everything except the last `trg_len` positions is context
                # only; -100 marks those labels as ignored by the loss.
                target_ids = input_ids.clone()
                target_ids[:, :-trg_len] = -100

                with torch.no_grad():
                    outputs = model(input_ids, labels=target_ids)
                    neg_log_likelihood = outputs.loss

                if aggressive_memory:
                    outputs = None
                    input_ids = None
                    target_ids = None
                    gc.collect()
                    torch.cuda.empty_cache()

                nlls.append(neg_log_likelihood.cpu())

                # The running value is display-only; skip it when hidden.
                if not hide_progress:
                    pbar.set_postfix(
                        ppl=float(torch.exp(torch.stack(nlls).mean())))

                prev_end_loc = end_loc
                if end_loc == seq_len:
                    break

            pbar.update(1)

    if not nlls:
        raise ValueError(
            f"No document contains at least {max_tokenized_len} tokens, so no "
            "window could be evaluated; lower max_length or filter the dataset.")

    ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
    return {"mean_perplexity": ppl}


def main(args):
    """Search, channel by channel, for a rescaling factor that does not hurt perplexity.

    For every model given on the command line, every channel index in
    ``range(128)`` and every candidate scale, the model is reloaded, patched
    through ``apply_awq_search_helper`` and evaluated with
    ``compute_perplexity`` at each context length in ``tokens``. The summed
    perplexities are compared with the summed hard-coded reference values.
    The chosen scale per channel is appended to
    ``best_channel_scales_<model>.txt`` and the full list is written to
    ``search_channel_result/searched_channel_result_<model>.json``.

    Args:
        args: Parsed command-line namespace. This function reads ``model``,
            ``tokenized``, ``dataset``, ``subset``, ``split``, ``feature``,
            ``save_tokenized``, ``dataset_min_tokens``, ``samples``,
            ``awq_cache``, ``sliding_window``, ``truncate``,
            ``aggressive_memory`` and ``hide_progress``; the whole namespace is
            also forwarded to the model loader.

    Raises:
        ValueError: If no model was supplied.
    """
    import json  # local import: the file-level import block does not provide json

    if not args.model:
        raise ValueError("At least one -m/--model must be provided.")
    models = [x[0] for x in args.model]

    # The tokenizer of the first model is used for every model in the list.
    tokenizer = AutoTokenizer.from_pretrained(models[0], use_fast=False)
    tokenizer.pad_token = tokenizer.eos_token

    if args.tokenized:
        try:
            input_texts = datasets.load_from_disk(args.tokenized)
        except Exception:
            # Deliberate best-effort fallback: the value is not a local dataset
            # directory, so try it as a dataset name. `Exception` (not a bare
            # except) lets Ctrl-C and SystemExit propagate.
            input_texts = datasets.load_dataset(
                args.tokenized, name=args.subset, split=args.split)
    else:
        input_texts = datasets.load_dataset(
            args.dataset, name=args.subset, split=args.split)

        def tokenize(example):
            """Attach token ids, attention mask and token count to one example."""
            tokenized = tokenizer(
                example[args.feature],
                add_special_tokens=False,
                padding=True,
                truncation=False,
                max_length=sys.maxsize,
                return_attention_mask=True,
            )
            example["input_ids"] = tokenized["input_ids"]
            example["attention_mask"] = tokenized["attention_mask"]
            example["tokenized_len"] = len(tokenized["input_ids"])
            return example

        input_texts = input_texts.map(tokenize)
        if args.save_tokenized:
            # Tokenize-only mode: persist the dataset and stop.
            input_texts.save_to_disk(args.save_tokenized)
            print(f"Saved tokenized dataset to {args.save_tokenized}")
            return

    if args.dataset_min_tokens:
        input_texts = input_texts.filter(
            lambda x: x["tokenized_len"] >= args.dataset_min_tokens)
    if args.samples:
        # Every 5th document, `args.samples` documents in total.
        input_texts = input_texts[:args.samples*5-1:5]

    # Context lengths evaluated for every (channel, scale) candidate.
    tokens = [640, 1408, 2944, 5248, 9088, 16768]
    # NOTE(review): presumably the perplexities of the unmodified model at each
    # length in `tokens`; the same values are applied to every model — confirm.
    # Alternative reference values left by the original author (provenance not
    # recorded in this file):
    # ppl_reference = [8.054335594177246, 7.26130485534668, 6.434703826904297, 6.272468566894531, 6.263601303100586, 6.389873027801514]
    ppl_reference = [6.087226390838623, 5.631160259246826, 5.1676740646362305, 5.0342302322387695, 5.051265239715576, 5.055654048919678]
    ppl_reference_sum = sum(ppl_reference)
    candidate_scales = [1.5, 2.0, 2.5]
    output_dir = "search_channel_result"

    for model in models:
        # Last path component of the model identifier, used in logs and filenames.
        model_name = str(model).rsplit('/', 1)[-1]
        # Results are kept per model so each JSON file describes one model only.
        channel_result = []

        for channel in tqdm(range(128), desc="Channel", leave=False, disable=False):
            print(f"search for channel number{channel}")
            # A scale is accepted only if its summed perplexity is not above
            # the reference sum (difference <= 0); best_scale stays 0 when no
            # candidate qualifies.
            # NOTE(review): confirm this threshold is intended rather than
            # simply taking the minimum over all candidates.
            min_ppl_difference = 0
            best_scale = 0
            for scale in tqdm(candidate_scales, desc="scale", leave=False, disable=False):
                print(f"search for scale {scale}")
                torch.cuda.empty_cache()

                # load model and apply rescaling
                loaded, _ = load_model_and_apply_patches_original_awq_search(
                    model, args, args.awq_cache)
                loaded = apply_awq_search_helper(loaded, args.awq_cache, scale, channel)

                ppl_result = []
                for max_length in tokens:
                    ppl = compute_perplexity(model=loaded, tokenizer=tokenizer, encodings=input_texts,
                                            add_start_token=tokenizer.bos_token is not None, max_length=max_length,
                                            sliding_window=args.sliding_window, truncate=args.truncate,
                                            aggressive_memory=args.aggressive_memory, hide_progress=args.hide_progress)['mean_perplexity']
                    print(f"{model_name},channel = {channel}, scale = {scale}: {max_length}={ppl}")
                    ppl_result.append(ppl)

                # Drop this candidate before the next one is loaded; otherwise
                # the old model stays referenced while the new one is built.
                del loaded
                gc.collect()
                torch.cuda.empty_cache()

                ppl_result_sum = sum(ppl_result)
                print('ppl_result_sum: ', ppl_result_sum)
                ppl_sum_difference = ppl_result_sum - ppl_reference_sum
                if min_ppl_difference >= ppl_sum_difference:
                    best_scale = scale
                    min_ppl_difference = ppl_sum_difference

            print(f"best scale for channel {channel} is {best_scale}")
            # Append mode so partial progress survives an interrupted search.
            with open(f"best_channel_scales_{model_name}.txt", "a") as file:
                file.write(f"best scale for channel {channel} is {best_scale} with ppl_difference {min_ppl_difference}\n")
            channel_result.append({best_scale: min_ppl_difference})

        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"searched_channel_result_{model_name}.json")
        with open(output_file, "w") as f:
            json.dump(channel_result, f, indent=4)
        print(f"Saved channel_result to {output_file}")

if __name__ == "__main__":
    # Script entry point: build the CLI, let `add_args` (star-imported from
    # model_loader) extend it, then run the search.
    warnings.simplefilter("ignore")
    parser = argparse.ArgumentParser()
    # Each -m occurrence appends a list of values; main() takes the first value
    # of each list as a model identifier. Required because main() builds the
    # tokenizer from the first model and cannot run without one.
    parser.add_argument("-m", "--model", action="append", nargs="+", required=True)

    # --awq_cache is forwarded by main() to the loader and the rescaling helper.
    # --quant_path is not read in this file; presumably consumed through `args`
    # by the model loader — TODO confirm.
    parser.add_argument("--quant_path", default=None, type=str)
    parser.add_argument("--awq_cache", default=None, type=str)

    # Dataset selection: dataset name, config/subset name, and the text column
    # that gets tokenized.
    parser.add_argument("-d", "--dataset", type=str)
    parser.add_argument("-s", "--subset", type=str)
    parser.add_argument("-f", "--feature", type=str)
    # --max-tokens, --min-tokens and --tokens-step are not read by main() in
    # this file (the evaluated lengths are hard-coded there).
    parser.add_argument("--max-tokens", type=int, default=8192)
    parser.add_argument("--min-tokens", type=int, default=256)
    # Keep only documents with at least this many tokens.
    parser.add_argument("--dataset-min-tokens", type=int)
    parser.add_argument("--tokens-step", type=int, default=8)
    # Stride and truncation settings forwarded to compute_perplexity.
    parser.add_argument("--sliding-window", type=int, default=256)
    parser.add_argument("--truncate", action="store_true")
    parser.add_argument("--split", type=str, default="test")
    # Number of documents to evaluate (main() takes every 5th document).
    parser.add_argument("--samples", type=int)
    # --save-tokenized: tokenize, save to this path and exit.
    # --tokenized: load an already tokenized dataset instead of tokenizing.
    parser.add_argument("--save-tokenized", type=str)
    parser.add_argument("--tokenized", type=str)
    # --output-file is not read by main() in this file.
    parser.add_argument("--output-file", type=str)
    parser.add_argument("--aggressive-memory", action="store_true")
    parser.add_argument("--hide-progress", action="store_true")
    # --awq / --original are not read in this file; presumably consumed through
    # `args` by the model loader — TODO confirm.
    parser.add_argument("--awq", action="store_true")
    parser.add_argument("--original", action="store_true")
    main(add_args(parser).parse_args())