import torch
from any_precision import AnyPrecisionForCausalLM_3456
from any_precision import AnyPrecisionForCausalLM_3456_tri
from any_precision import AnyPrecisionForCausalLM_3456_whole
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
from tqdm import tqdm
import os
from any_precision.evaluate.helpers import dataloader
from any_precision.evaluate.helpers.utils import vprint, logprint
from filelock import Timeout, FileLock
from functools import partial

def _registry_entry(packed_path, model_tag):
    """Build one registry entry: a (path resolver, model loader) pair.

    The path resolver accepts a precision argument but always returns
    `packed_path`. The loader opens the packed checkpoint with the 3/4/5/6-bit
    any-precision wrapper, switches it to eval mode and moves it to CUDA.
    """

    def resolve_path(prec):
        return packed_path

    def load(path, mode):
        quantized = AnyPrecisionForCausalLM_3456.from_quantized(
            path,
            precisions=[3, 4, 5, 6],
            mode=mode,
            prefill_as_decode=True,
            model=model_tag,
        )
        return quantized.eval().cuda()

    return resolve_path, load


# Maps a model name to its (path resolver, model loader) pair.
name_to_pathNtype = {
    "llama3": _registry_entry(
        "/work/cache/packed/anyprec-(Meta-Llama-3-8B)-w8_orig3-gc1-c4_s100_blk512",
        "llama3",
    ),
    "phi": _registry_entry(
        "/work/cache/packed/anyprec-(phi-3-medium)-w8_orig3-gc1-c4_s100_blk512",
        "phi",
    ),
}

def get_tokenizer_type(model_name):
    """Return the tokenizer family tag found in `model_name`, or None.

    Matching is case-insensitive and by substring; 'llama3' is checked
    before 'phi', so a name containing both resolves to 'llama3'.
    """
    lowered = model_name.lower()
    for family in ('llama3', 'phi'):
        if family in lowered:
            return family
    return None

@torch.no_grad()
def auto_model_load(model_name, device='cuda', dtype=torch.float16, verbose=True, precision=3, mode="orig"):
    """Load the tokenizer and model registered under `model_name`.

    Args:
        model_name: key into `name_to_pathNtype` selecting the checkpoint path
            resolver and the loader callable.
        device: currently unused by this function; kept for backward compatibility.
        dtype: currently unused by this function; kept for backward compatibility.
        verbose: whether to print progress
        precision: passed to the path resolver and, when the model supports it,
            to `model.set_precision`.
        mode: forwarded unchanged to the loader callable.

    Returns:
        (tokenizer_type, tokenizer, model) tuple. `tokenizer_type` is None when
        `get_tokenizer_type` does not recognise `model_name`.

    Raises:
        KeyError: if `model_name` is not registered in `name_to_pathNtype`.
    """
    logprint(verbose, "Loading tokenizer and model...")

    try:
        path_lambda, model_type = name_to_pathNtype[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model name {model_name!r}; expected one of {sorted(name_to_pathNtype)}"
        ) from None
    path = path_lambda(precision)

    tokenizer = AutoTokenizer.from_pretrained(path)
    model = model_type(path, mode)

    # Best-effort precision setup: models without these hooks are still usable,
    # but the failure is reported instead of being silently discarded, and
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        model.set_precision(precision)
        model.setMotherLayer()
    except Exception as exc:
        logprint(verbose, f"Skipping precision setup ({type(exc).__name__}: {exc})")

    logprint(verbose, f"{model_name} model loaded to device: {model.device}")

    tokenizer_type = get_tokenizer_type(model_name)

    if tokenizer_type is None:
        logprint(verbose, f"Unknown tokenizer type for {model_name}. Cannot use cached input tokens.")

    return tokenizer_type, tokenizer, model


@torch.no_grad()
def evaluate_ppl(model, tokenizer, testcases, verbose=True, chunk_size=2048, tokenizer_type=None):
    """Compute perplexity of `model` on each test set in `testcases`.

    Args:
        model: model to evaluate
        tokenizer: tokenizer to use
        testcases: testcases names to evaluate on, passed on to dataloader.get_loaders
        verbose: whether to print progress
        chunk_size: the size of the chunks into which the test set is split
        tokenizer_type: tag used to key the cached input tokens (for example the
            value returned by `get_tokenizer_type`); a falsy value disables caching.

    Returns:
        A dictionary of perplexity scores, with keys being the testcases names and
        values being the perplexity scores. When `model.get_effective_bits()` is
        available, its value is stored under the key "<testcase>_eb".

    Raises:
        RuntimeError: if a test set holds fewer than `chunk_size` tokens.

    Note that the perplexity scores are calculated over non-overlapping chunks of the
    test set; tokens after the last full chunk are not evaluated.
    """
    results = {}

    for testcase_name in testcases:
        vprint(verbose, f"---------------------- {testcase_name} ----------------------")
        # Best-effort: not every model exposes a computation counter.
        try:
            model.clear_comp_count()
            logprint(verbose, "<<<< Resetting comp count >>>>")
        except Exception as exc:
            logprint(verbose, f"Comp count not reset ({type(exc).__name__}: {exc})")

        input_tokens = load_input_tokens(tokenizer_type, testcase_name, tokenizer, verbose)

        # Keep the returned object so this works whether `.to` moves the
        # tensors in place or returns a moved copy.
        input_tokens = input_tokens.to("cuda:0")

        logprint(verbose, "Calculating perplexity...")

        seq_len = input_tokens.input_ids.size(1)
        nsamples = seq_len // chunk_size  # floor(seq_len / chunk_size)
        if nsamples == 0:
            # torch.stack on an empty list would fail with a less helpful message.
            raise RuntimeError(
                f"Test set '{testcase_name}' has {seq_len} tokens, fewer than "
                f"chunk_size={chunk_size}; cannot compute perplexity."
            )

        neg_log_likelihoods = []
        for i in tqdm(range(nsamples), disable=not verbose):
            begin_loc = i * chunk_size

            input_ids = input_tokens.input_ids[:, begin_loc:begin_loc + chunk_size]

            # Gradients are already disabled by the decorator on this function.
            outputs = model(input_ids, labels=input_ids)
            neg_log_likelihoods.append(outputs.loss)

        # Perplexity is exp of the mean per-chunk loss.
        ppl = torch.exp(torch.stack(neg_log_likelihoods).mean())
        logprint(verbose, f"Perplexity: {ppl.item()}")

        results[f"{testcase_name}"] = ppl.item()
        # Best-effort: only some models report effective bits.
        try:
            eb = model.get_effective_bits()
            results[f"{testcase_name}_eb"] = eb
            logprint(verbose, f"effective bits: {eb}")
        except Exception as exc:
            logprint(verbose, f"Effective bits unavailable ({type(exc).__name__}: {exc})")

    return results

# Root directory under which the tokenized test sets are cached.
current_dir = "/work/cache"


def load_input_tokens(tokenizer_type, testcase_name, tokenizer, verbose):
    """Return the tokenized test set for `testcase_name`, using the on-disk cache.

    A cache file is read or written only when `tokenizer_type` is truthy.
    Otherwise the test set is fetched through `dataloader.get_loaders` and
    tokenized on every call.
    """
    cache_file = f"{current_dir}/input_tokens_cache/dataloader-{tokenizer_type}-{testcase_name}-test.pt"

    # Fast path: a cache entry already exists for this tokenizer/test set.
    if tokenizer_type and os.path.exists(cache_file):
        logprint(verbose, f"Loading cached input tokens from {cache_file}...")
        return torch.load(cache_file)

    logprint(verbose, "Loading test set...")
    raw_text = dataloader.get_loaders(testcase_name)

    logprint(verbose, "Tokenizing test set...")
    encoded = tokenizer(raw_text, return_tensors='pt')

    # Without a tokenizer type there is no cache key, so nothing is stored.
    if not tokenizer_type:
        return encoded

    logprint(verbose, f"Caching input tokens to {cache_file}...")
    # The cache directory may not exist yet.
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    torch.save(encoded, cache_file)
    return encoded


import json
import argparse
import subprocess
import sys

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--output_file', type=str, default='pp_results.json')
# NOTE(review): `required=True` means the default below is never used.
parser.add_argument("--model_name", type=str, default="llama3", required=True)
parser.add_argument("--wbits", type=int, default=3)
parser.add_argument("--mode", type=str, default="orig")
parser.add_argument("--suffix", type=str, default="")
parser.add_argument("--targ_setup", type=str, default="")
parser.add_argument("--ignore", action="store_true", default=False)

args = parser.parse_args()
# A non-empty suffix is appended to the result key as an extra comma-separated field.
suffix = args.suffix
if suffix != "":
    suffix = "," + suffix

# Guards the shared results file; a negative timeout waits indefinitely.
lock_path = "36_pp_results.lock"
lock = FileLock(lock_path, timeout=-1)

models = [(args.model_name, args.wbits, args.mode)]

datasets = ['wikitext2', 'c4_new']

# read previous results (under the lock, so a file that another run is
# rewriting is never parsed half-written)
with lock:
    if os.path.exists(args.output_file):
        with open(args.output_file) as f:
            all_results = json.load(f)
    else:
        all_results = {}

new_results = {}  # results that are newly calculated, to be printed at the end

total_tests_to_run = {}  # tasks to be run will be stored here
skipped_models = []  # models that are skipped will be stored here

# Check which models/testcases need to be run
# This is done first so that we know how many tasks there are in total,
# and thus we can print the progress
for model_name, prec, mode in models:
    model_jobs = {'to_print': [], 'ppl': []}

    # Results are stored as all_results[key]['ppl'][dataset], so a dataset
    # counts as done only when it has its own entry there.
    result_key = f"{model_name},{prec},{mode}{suffix}"
    existing_ppl = (all_results.get(result_key) or {}).get('ppl') or {}
    datasets_with_results = [dataset for dataset in datasets if dataset in existing_ppl]

    if not args.ignore:
        model_jobs['ppl'] = [testcase for testcase in datasets if testcase not in datasets_with_results]
        if not model_jobs['ppl']:
            # All results of the target model/testcases and model/tasks combination exist, skip
            skipped_models.append(model_name)
            continue
        if datasets_with_results:
            model_jobs['to_print'].append(f"Skipping datasets: "
                                          f"{datasets_with_results} because results already exist")
    else:
        # --ignore: re-run every dataset regardless of stored results.
        model_jobs['ppl'] = list(datasets)

    model_jobs['to_print'].append(f"Running datasets: {model_jobs['ppl']}")
    total_tests_to_run[(model_name, prec, mode)] = model_jobs

total_ppl_job_count = sum(len(model_tasks['ppl']) for model_tasks in total_tests_to_run.values())
if skipped_models:
    print(f">> {len(skipped_models)} models will be skipped because all dataset results already exist.")

# Run all tasks
for i, (model_name, prec, mode) in enumerate(total_tests_to_run):
    model_jobs = total_tests_to_run[(model_name, prec, mode)]
    to_print = model_jobs['to_print']
    datasets_to_evaluate = model_jobs['ppl']
    # Key under which this run's results are stored in the results file.
    result_key = f"{model_name},{prec},{mode}{suffix}"
    print("==================================================")
    print(f" Model: {model_name}, Prec: {prec}, Mode: {mode}, Suffix: {suffix}")
    print(f"Progress: {i + 1}/{len(total_tests_to_run)}")
    print("==================================================")

    for line in to_print:
        print('>> ' + line)

    ppl_results = {}

    if args.targ_setup != "":
        # Run the setup command and load the model while holding the setup
        # lock. The context manager releases the lock even if the command or
        # the model load raises.
        # NOTE(review): presumably the setup command prepares state that the
        # loader reads, hence the shared critical section - confirm.
        setup_lock_path = "pp_setup.lock"
        setup_lock = FileLock(setup_lock_path, timeout=-1)
        call_arr = [sys.executable, *args.targ_setup.split(" ")]
        print("Acquring lock...")
        with setup_lock:
            print(f"Calling with {call_arr}")
            subprocess.call(call_arr)
            tokenizer_type, tokenizer, model = auto_model_load(model_name, precision=prec, mode=mode)
        print("Released lock.")
    else:
        tokenizer_type, tokenizer, model = auto_model_load(model_name, precision=prec, mode=mode)

    # Run evaluation
    if datasets_to_evaluate:
        ppl_results = evaluate_ppl(model, tokenizer, datasets_to_evaluate, verbose=True,
                                   chunk_size=2048, tokenizer_type=tokenizer_type)
    # Update new results
    new_results[result_key] = {}
    if ppl_results:
        new_results[result_key]['ppl'] = ppl_results

    with lock:
        # Re-read the file inside the lock so results written by other runs
        # in the meantime are merged rather than overwritten.
        if os.path.exists(args.output_file):
            with open(args.output_file) as f:
                all_results = json.load(f)
        else:
            all_results = {}

        # Update all results
        if ppl_results:
            all_results.setdefault(result_key, {}).setdefault('ppl', {}).update(ppl_results)

        # save results
        with open(args.output_file, 'w') as f:
            all_results = dict(sorted(all_results.items()))  # sort by key
            json.dump(all_results, f, indent=4)

    print()

    del model  # clear memory
    torch.cuda.empty_cache()


print("---------------------- All Results ----------------------")
# print new results as formatted json
print(json.dumps(new_results, indent=4))

# A non-zero exit status signals that nothing was evaluated in this run.
if not total_tests_to_run:
    sys.exit(1)
