from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
import time
from datetime import datetime
import torch
from hf_generation import my_generate
import argparse
import json
import os
import numpy as np
from util import CKPT
from ast import literal_eval as eval

# Device that the prompt token ids are moved to before generation.
device: str = "cuda"
# Number of times each generation is repeated; reported times and counters
# are divided by this value (per-run averages).
NUM_REPEAT: int = 1

def set_up(args):
    """Seed RNGs and load the tokenizer, draft model, target model and optional acc head.

    NumPy and torch are seeded with ``args.random_seed`` before any model is
    loaded. The draft ("assistant") model is placed on ``cuda:0``; it is
    4-bit quantized when ``args.use_quantization`` is set and bf16 otherwise.
    The target model is always loaded in bf16 with ``device_map='auto'``.

    Args:
        args: parsed command-line namespace. Reads ``do_sample``,
            ``random_seed``, ``use_quantization``, ``model_name``,
            ``assistant_name``, ``num_assistant_tokens_schedule`` and
            ``assist_acc_head_dir``.

    Returns:
        ``(model, assistant_model, tokenizer, assist_acc_head)``.
        ``assist_acc_head`` is ``None`` unless the schedule is ``'ada'``.

    Raises:
        ValueError: if the schedule is ``'ada'`` but no ``assist_acc_head_dir``
            was given.
    """
    if args.do_sample:
        print("do_sample for SpeculativeDecoding")

    use_ada = args.num_assistant_tokens_schedule == 'ada'
    # Fail fast, before the slow model loads, instead of hitting a
    # `None + str` TypeError after both models are already in memory.
    if use_ada and not args.assist_acc_head_dir:
        raise ValueError(
            "--assist_acc_head_dir is required when --num_assistant_tokens_schedule is 'ada'"
        )

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    checkpoint = CKPT[args.model_name]
    assistant_checkpoint = CKPT[args.assistant_name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if args.use_quantization:
        print("use quantization")
        # Only the draft model is quantized: 4-bit weights with double
        # quantization and bf16 compute.
        qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, quantization_config=qconfig, device_map='cuda:0')
    else:
        assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, torch_dtype=torch.bfloat16, device_map='cuda:0')

    assist_acc_head = None
    if use_ada:
        # Imported lazily: only the adaptive schedule needs the wrapper module.
        from wrap_model import WrapModel, WrapModelConfig
        head_dir = args.assist_acc_head_dir
        model_config = WrapModelConfig.from_json_file(os.path.join(head_dir, 'config.json'))
        wrapped_model = WrapModel(assistant_model, num_layers=model_config.num_layers)
        head_path = os.path.join(head_dir, 'assist_acc_head.pth')
        print("Loading from acc_head checkpoint:", head_path)
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted checkpoints. The result is used as an object with `.to()`,
        # which suggests a pickled nn.Module; newer torch releases default to
        # weights_only=True, which may reject that -- confirm with the torch
        # version in use.
        wrapped_model.assist_acc_head = torch.load(head_path)
        assist_acc_head = wrapped_model.assist_acc_head.to("cuda:0")

    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')

    return model, assistant_model, tokenizer, assist_acc_head

def assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=None, oracle_token_num_list=None, args=None):
    """Time speculative (assisted) decoding via ``my_generate``, averaged over ``NUM_REPEAT`` runs.

    Args:
        model: target model passed to ``my_generate``.
        assistant_model: draft model. Its ``max_assistant_tokens`` attribute is
            set before every run: to ``oracle_token_num_list[0]`` for the
            ``'oracle'`` schedule, otherwise to ``None``.
        tokenizer: unused; kept for interface compatibility.
        assist_acc_head: forwarded to ``my_generate`` (``None`` unless the
            schedule is ``'ada'``).
        inputs: keyword arguments spread into ``my_generate``
            (e.g. ``{'input_ids': ...}``).
        max_length: forwarded to ``my_generate``.
        num_assistant_tokens: forwarded to ``my_generate``.
        oracle_token_num_list: forwarded to ``my_generate``; required (non-empty)
            for the ``'oracle'`` schedule.
        args: run configuration. Defaults to the module-level ``args`` assigned
            by the ``__main__`` block. Reads ``num_assistant_tokens_schedule``,
            ``do_sample``, ``stop_threshold`` and ``bound``.

    Returns:
        ``(outputs, avg_mismatched_tokens, avg_LM_call, assisted_time)``.
        ``outputs`` is from the last run; the other three are per-run averages,
        with ``assisted_time`` in wall-clock seconds.

    Raises:
        RuntimeError: if ``args`` is not given and no module-level ``args`` exists.
        ValueError: if ``NUM_REPEAT`` is less than 1.
    """
    if args is None:
        # Backward compatibility: this function historically read the global
        # `args` that the `__main__` block assigns.
        args = globals().get("args")
        if args is None:
            raise RuntimeError("assist() needs `args`: pass args=... explicitly")
    if NUM_REPEAT < 1:
        raise ValueError(f"NUM_REPEAT must be >= 1, got {NUM_REPEAT}")

    schedule = args.num_assistant_tokens_schedule
    total_mismatched_tokens = 0
    total_num_LM_call = 0

    # CUDA work is queued asynchronously; drain the (current-device) queue so
    # the timed interval covers only this generation work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    before = time.perf_counter()
    for _ in range(NUM_REPEAT):
        # Reset every repetition, presumably because my_generate reads/updates
        # this attribute -- TODO confirm against hf_generation.
        if schedule == 'oracle':
            assistant_model.max_assistant_tokens = oracle_token_num_list[0]
        else:
            assistant_model.max_assistant_tokens = None
        outputs, num_mismatched_tokens, num_LM_call = my_generate(
            model=model, **inputs, assistant_model=assistant_model,
            max_length=max_length, num_assistant_tokens_schedule=schedule,
            num_assistant_tokens=num_assistant_tokens, do_sample=args.do_sample,
            oracle_token_num_list=oracle_token_num_list, assist_acc_head=assist_acc_head,
            stop_threshold=args.stop_threshold, bound=args.bound)
        total_mismatched_tokens += num_mismatched_tokens
        total_num_LM_call += num_LM_call
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    after = time.perf_counter()

    assisted_time = (after - before) / NUM_REPEAT
    avg_mismatched_tokens = total_mismatched_tokens / NUM_REPEAT
    avg_LM_call = total_num_LM_call / NUM_REPEAT
    print("assisted time: {:.2f}".format(assisted_time))
    print("avg_mismatched_tokens: {:.2f}".format(avg_mismatched_tokens))
    print("avg_LM_call: {:.2f}".format(avg_LM_call))

    return outputs, avg_mismatched_tokens, avg_LM_call, assisted_time

def nonassist(model, tokenizer, inputs, max_length, args=None):
    """Time plain ``model.generate`` on the target model, averaged over ``NUM_REPEAT`` runs.

    Args:
        model: target model whose ``generate`` is timed.
        tokenizer: unused; kept for interface compatibility.
        inputs: keyword arguments spread into ``generate`` (e.g. ``{'input_ids': ...}``).
        max_length: forwarded to ``generate``.
        args: run configuration (only ``do_sample`` is read). Defaults to the
            module-level ``args`` assigned by the ``__main__`` block.

    Returns:
        ``(outputs, nonassisted_time)``: outputs of the last run and the mean
        wall-clock seconds per run.

    Raises:
        RuntimeError: if ``args`` is not given and no module-level ``args`` exists.
        ValueError: if ``NUM_REPEAT`` is less than 1.
    """
    if args is None:
        # Backward compatibility with the original reliance on the global `args`.
        args = globals().get("args")
        if args is None:
            raise RuntimeError("nonassist() needs `args`: pass args=... explicitly")
    if NUM_REPEAT < 1:
        raise ValueError(f"NUM_REPEAT must be >= 1, got {NUM_REPEAT}")

    # Drain queued CUDA work so the interval measures only generation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    before = time.perf_counter()
    for _ in range(NUM_REPEAT):
        outputs = model.generate(**inputs, max_length=max_length, do_sample=args.do_sample)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    after = time.perf_counter()

    nonassisted_time = (after - before) / NUM_REPEAT
    print("nonassisted time {:.2f}".format(nonassisted_time))
    return outputs, nonassisted_time

def SLM_only(assistant_model, tokenizer, inputs, max_length, args=None):
    """Time plain ``generate`` on the draft (small) model, averaged over ``NUM_REPEAT`` runs.

    Args:
        assistant_model: draft model whose ``generate`` is timed.
        tokenizer: unused; kept for interface compatibility.
        inputs: keyword arguments spread into ``generate`` (e.g. ``{'input_ids': ...}``).
        max_length: forwarded to ``generate``.
        args: run configuration (only ``do_sample`` is read). Defaults to the
            module-level ``args`` assigned by the ``__main__`` block.

    Returns:
        ``(outputs, SLM_only_time)``: outputs of the last run and the mean
        wall-clock seconds per run.

    Raises:
        RuntimeError: if ``args`` is not given and no module-level ``args`` exists.
        ValueError: if ``NUM_REPEAT`` is less than 1.
    """
    if args is None:
        # Backward compatibility with the original reliance on the global `args`.
        args = globals().get("args")
        if args is None:
            raise RuntimeError("SLM_only() needs `args`: pass args=... explicitly")
    if NUM_REPEAT < 1:
        raise ValueError(f"NUM_REPEAT must be >= 1, got {NUM_REPEAT}")

    # Drain queued CUDA work so the interval measures only generation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    before = time.perf_counter()
    for _ in range(NUM_REPEAT):
        outputs = assistant_model.generate(**inputs, max_length=max_length, do_sample=args.do_sample)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    after = time.perf_counter()

    SLM_only_time = (after - before) / NUM_REPEAT
    print("SLM_only time {:.2f}".format(SLM_only_time))
    return outputs, SLM_only_time


def _oracle_token_num_list(tokens, reference_tokens, max_length):
    """Split a sequence into run lengths between positions where two token lists disagree.

    Entry 0 is the index of the first mismatch. Each following entry is the
    number of positions strictly between consecutive mismatches. The final
    entry is ``max_length - last_mismatch - 1``, or ``max_length - 1`` when the
    lists never disagree. Only the overlapping prefix (``zip``) is compared.

    >>> _oracle_token_num_list([1, 2, 3, 4], [1, 9, 3, 4], 10)
    [1, 8]
    >>> _oracle_token_num_list([5, 6], [5, 6], 4)
    [3]
    """
    counts = []
    previous = None
    for index, (token, reference) in enumerate(zip(tokens, reference_tokens)):
        if token == reference:
            continue
        counts.append(index if previous is None else index - previous - 1)
        previous = index
    counts.append(max_length - 1 if previous is None else max_length - previous - 1)
    return counts


def run(model, assistant_model, tokenizer, assist_acc_head, args, item):
    """Benchmark one data item under ``args.num_assistant_tokens_schedule``.

    ``item['prefix']`` (and, for the ``'oracle'`` schedule, ``item['tokens']``
    and ``item['greedy_7b']``) are string literals parsed with
    ``ast.literal_eval``, which this module imports under the name ``eval``.

    - ``'oracle'``: assisted decoding driven by per-step draft lengths
      derived from where ``item['tokens']`` and ``item['greedy_7b']`` differ
      (greedy mode only).
    - ``'constant'`` / ``'heuristic'`` / ``'ada'``: assisted decoding
      (``'ada'`` passes ``num_assistant_tokens=None``).
    - ``'none'``: no assisted run; times plain target and draft generation instead.

    Returns:
        9-tuple ``(assisted_time, nonassisted_time, speed_up, SLM_only_time,
        num_mismatched_tokens, num_LM_call, generated_length,
        generated_length_nonassist, generated_length_SLM_only)``. Metrics not
        measured under the chosen schedule are ``-1``.

    Raises:
        NotImplementedError: for ``'oracle'`` with ``args.do_sample``.
        ValueError: for an unknown schedule.
    """
    schedule = args.num_assistant_tokens_schedule
    # Parse the stored prompt once (eval is ast.literal_eval here).
    prefix = eval(item['prefix'])
    len_prefix = len(prefix)
    inputs = {'input_ids': torch.LongTensor([prefix]).to(device)}
    max_length = args.max_length
    print("max_length:", max_length)

    if schedule == 'oracle':
        if args.do_sample:
            raise NotImplementedError('Not supported in stochastic mode.')
        oracle_token_num_list = _oracle_token_num_list(
            eval(item['tokens']), eval(item['greedy_7b']), max_length)
        print('oracle_token_num_list:', oracle_token_num_list)

        res_a, num_mismatched_tokens, num_LM_call, assisted_time = assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=oracle_token_num_list[0], oracle_token_num_list=oracle_token_num_list)

    elif schedule in ('constant', 'heuristic', 'ada'):
        if schedule == 'ada':
            num_assistant_tokens = None
        else:
            num_assistant_tokens = args.num_assistant_tokens
            print("num_assistant_tokens:", num_assistant_tokens)

        res_a, num_mismatched_tokens, num_LM_call, assisted_time = assist(model, assistant_model, tokenizer, assist_acc_head, inputs, max_length, num_assistant_tokens=num_assistant_tokens)
    elif schedule == 'none':
        # No assisted run: report the -1 "not measured" sentinel.
        res_a = None
        num_mismatched_tokens = -1
        num_LM_call = -1
        assisted_time = -1
    else:
        raise ValueError(f"{args.num_assistant_tokens_schedule} not supported")

    if schedule == 'none':
        res_b, nonassisted_time = nonassist(model, tokenizer, inputs, max_length)
        generated_length_nonassist = len(res_b[0]) - len_prefix

        res_c, SLM_only_time = SLM_only(assistant_model, tokenizer, inputs, max_length)
        generated_length_SLM_only = len(res_c[0]) - len_prefix
    else:
        nonassisted_time = -1
        generated_length_nonassist = -1
        SLM_only_time = -1
        generated_length_SLM_only = -1

    # Deprecated metric (not saved by the driver). Only meaningful when both
    # timings were actually measured; otherwise use the -1 sentinel instead of
    # dividing sentinel values.
    if nonassisted_time > 0 and assisted_time > 0:
        speed_up = nonassisted_time / assisted_time
    else:
        speed_up = -1

    # Previously 'none' mode produced the meaningless `1 - len_prefix` here.
    generated_length = -1 if res_a is None else len(res_a[0]) - len_prefix
    print("generated_length: {:.2f}".format(generated_length))

    return assisted_time, nonassisted_time, speed_up, SLM_only_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_nonassist, generated_length_SLM_only

def parse_args(argv=None):
    """Parse (and print) the benchmark's command-line options.

    Args:
        argv: list of argument strings to parse; ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the options defined below.

    Exits (via ``parser.error``) when ``--num_assistant_tokens_schedule ada``
    is given without ``--assist_acc_head_dir``, or when ``--n_begin`` is negative.
    """
    parser = argparse.ArgumentParser(description='benchmark performance')

    parser.add_argument('--model_name', type=str, default="70b", choices=["7b", "13b", "70b"])
    parser.add_argument('--assistant_name', type=str, default="7b", choices=["13b", "7b"])
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--use_quantization', action='store_true')
    parser.add_argument('--do_sample', action='store_true')

    parser.add_argument('--num_assistant_tokens', type=int, default=5)
    parser.add_argument('--num_assistant_tokens_schedule', type=str, default="constant", choices=['constant', 'heuristic', 'oracle', 'ada', 'none'])
    parser.add_argument('--assist_acc_head_dir', type=str, default=None)
    parser.add_argument('--data_path', type=str, default='./alpaca_data/dev1k.json')
    parser.add_argument('--save_path', type=str, default='./test_results')
    parser.add_argument('--random_seed', type=int, default=47)
    parser.add_argument('--stop_threshold', type=float, default=None)
    parser.add_argument('--bound', nargs='+', type=int, default=None)

    parser.add_argument('--n_begin', type=int, default=0)
    parser.add_argument('--n_end', type=int, default=None)

    args = parser.parse_args(argv)

    # Catch invalid combinations up front rather than after loading models.
    if args.num_assistant_tokens_schedule == 'ada' and not args.assist_acc_head_dir:
        parser.error("--assist_acc_head_dir is required when --num_assistant_tokens_schedule is 'ada'")
    if args.n_begin < 0:
        # A negative start would slice from the end and mislabel item ids.
        parser.error("--n_begin must be >= 0")

    print(args)

    return args


if __name__ == "__main__":
    # NOTE: `args` must stay a module-level name: assist/nonassist/SLM_only
    # fall back to (or, in older versions, rely on) this global.
    args = parse_args()
    # Close the data file promptly instead of leaking the handle.
    with open(args.data_path, 'r', encoding='utf-8') as data_file:
        data = json.load(data_file)
    if args.n_end is None:
        args.n_end = len(data)
    args.n_end = min(len(data), args.n_end)

    os.makedirs(args.save_path, exist_ok=True)

    model, assistant_model, tokenizer, assist_acc_head = set_up(args)

    results = []

    for offset, item in enumerate(data[args.n_begin:args.n_end]):
        item_id = offset + args.n_begin
        print("---------------------------------")
        print(f"data {item_id}")
        before = time.time()

        # speed_up is deprecated and intentionally not saved.
        assisted_time, nonassisted_time, _speed_up, SLM_only_time, num_mismatched_tokens, num_LM_call, generated_length, generated_length_nonassist, generated_length_SLM_only = run(model, assistant_model, tokenizer, assist_acc_head, args, item)
        item.update({
            'id': item_id,
            'spec_time': assisted_time,
            'target_time': nonassisted_time,
            'draft_time': SLM_only_time,
            'num_mismatched_tokens': num_mismatched_tokens,
            'num_LM_call': num_LM_call,
            'generated_length': generated_length,
            'generated_length_target': generated_length_nonassist,
            'generated_length_draft': generated_length_SLM_only,
        })
        results.append(item)

        after = time.time()
        print("total time: {:.2f}".format(after - before))

    save_file = os.path.join(args.save_path, f"results_{args.n_begin}to{args.n_end}.json")
    with open(save_file, 'w') as f:
        json.dump(results, f, indent=2)