import argparse
import gc
import math
import os
import time
import numpy as np
import torch

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def get_args():
    """Parse and return this script's command-line options.

    Returns:
        argparse.Namespace with ``local_rank``, ``name`` (required),
        ``batch_size``, ``benchmark``, ``greedy``, ``top_k``, ``top_p`` and
        ``dtype`` (``"float16"`` or ``"int8"``).
    """
    p = argparse.ArgumentParser()
    # Accepted so distributed launchers can pass it; the rank this script uses
    # is read from the LOCAL_RANK environment variable instead.
    p.add_argument("--local_rank", required=False, type=int, help="used by dist launchers")
    p.add_argument("--name", type=str, help="Name path", required=True)
    p.add_argument("--batch_size", default=1, type=int, help="batch size")
    p.add_argument("--benchmark", action="store_true", help="additionally run benchmark")
    # Sampling options. NOTE(review): the forward-only path in this script does
    # not appear to read these; confirm before relying on them.
    p.add_argument("--greedy", action="store_true")
    p.add_argument("--top-k", type=int, default=0)
    p.add_argument("--top-p", type=float, default=0.0)
    p.add_argument(
        "--dtype",
        type=str,
        help="float16 or int8",
        choices=["int8", "float16"],
        default="float16",
    )
    parsed = p.parse_args()
    return parsed


# Reference time for the "Start to ready" / "Start to finish" benchmark figures.
t_start = time.time()

# Token count shown in the start-up banner; this script runs a single forward
# pass, i.e. one new token per prompt.
num_tokens = 1

args = get_args()

# The rank comes from the LOCAL_RANK environment variable (default 0 when it is
# absent); the parsed --local_rank option is not consulted here.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
rank = local_rank
# Reported in the start-up banner.
world_size = torch.cuda.device_count()


def print_rank0(*msg):
    """Forward ``msg`` to ``print`` only when the module-level ``rank`` is 0.

    Every other rank prints nothing, so multi-process runs log once.
    """
    if rank == 0:
        print(*msg)


print_rank0(f"Using {world_size} gpus")
model_name = args.name
print_rank0(f"Loading model {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# XXX: can't automatically derive dtype via config's `from_pretrained`,
# so the two BLOOM checkpoints listed below get bfloat16 and every other
# model name gets float16.
dtype = torch.bfloat16 if model_name in ["bigscience/bloom", "bigscience/bigscience-small-testing"] else torch.float16
# Disabled debugging alternatives, kept for reference:
#dtype=torch.float16
#print("OVERRIDING DTYPE TO BE", dtype)
# print(get_max_memory_per_gpu_dict())

infer_dtype = args.dtype
if infer_dtype == "int8":
    # Only recorded here: the int8 branch below passes `load_in_8bit` to
    # `from_pretrained` instead of `torch_dtype`, so this value is not used for loading.
    dtype = torch.int8

# Keyword arguments for `from_pretrained`.
# NOTE(review): max_memory names GPUs 0 and 1 explicitly (8GiB each) no matter
# what `world_size` reports; confirm this matches the target machine.
kwargs = dict(
    device_map='auto',#"balanced_low_0",
    max_memory={0:"8GiB", 1:"8GiB", 'cpu':"390GiB"}
)

# int8 -> quantized load; otherwise load weights in the dtype chosen above.
if infer_dtype == "int8":
    print_rank0("Using `load_in_8bit=True` to use quanitized model")
    kwargs["load_in_8bit"] = True
else:
    kwargs["torch_dtype"] = dtype

print("LOADING", model_name, "KWARGS", kwargs)
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
# Report the device placement produced by device_map='auto'.
print("MODEL", model.hf_device_map)

# Loading is finished: the benchmark report's "Start to ready to generate"
# is t_ready - t_start. Only defined (and only needed) when --benchmark is set.
if args.benchmark:
    t_ready = time.time()

def get_activation(name):
    """Build a forward hook that prints ``name`` and saves last-position activations.

    The returned ``hook(module, inputs, output)`` always prints ``name``. Then:

    * if ``name`` contains ``"attn"``: with ``k = output[0].size(1) - 1``, saves
      ``output[0][:, k]`` via ``np.save(name, ...)`` (numpy appends ``.npy``);
    * if ``name`` contains ``"mlp"``: with ``k = output[0].size(0) - 1``, saves
      ``inputs[0][:, k]`` to ``inp_<name>.npy`` and ``output[0][k]`` to
      ``out_<name>.npy``.

    Both branches run if ``name`` contains both substrings; any other name
    (e.g. one containing only ``"m_coef"``) is printed but nothing is saved.
    Files are written to the current working directory.

    NOTE(review): the two branches index ``output`` differently. The attn branch
    treats ``output[0]`` as (batch, seq, ...), while the mlp branch treats it as
    (seq, ...), as if ``output`` were a tensor and ``[0]`` picked batch row 0.
    Confirm this against the hooked modules' return types.
    """

    def hook(module, inputs, output):
        print(name)
        saves_attn = "attn" in name
        saves_mlp = "mlp" in name
        if saves_attn:
            last = output[0].size(1) - 1
            np.save(name, output[0][:, last].detach().cpu().float().numpy())
        if saves_mlp:
            last = output[0].size(0) - 1
            np.save(f'inp_{name}.npy', inputs[0][:, last].detach().cpu().float().numpy())
            np.save(f'out_{name}.npy', output[0][last].detach().cpu().float().numpy())

    return hook

# Per-layer activation hooks (built by get_activation) are currently disabled.
# To capture them on a BLOOM-style model, re-enable:
#for i in range(len(model.transformer.h)):
#    model.transformer.h[i].mlp.register_forward_hook(get_activation("mlp_" + str(i)))
#    model.transformer.h[i].self_attention.register_forward_hook(get_activation("attn_" + str(i)))
print("NO ADDED HOOKS")

### Generate

print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

input_sentences = [
    """Q: On the floor, I see a silver keychain, a red pair of sunglasses, a gold sheet of paper, a black dog leash, and a blue cat toy. What color is the keychain?\nA: Silver\nQ: On the table, you see a brown sheet of paper, a red fidget spinner, a blue pair of sunglasses, a teal dog leash, and a gold cup. What color is the sheet of paper?\nA:""",]

if len(input_sentences) < args.batch_size:
    # Repeat the prompt list so it holds at least batch_size entries.
    input_sentences *= math.ceil(args.batch_size / len(input_sentences))

# Extra kwargs for the single forward pass: also return attentions and hidden
# states (generate() saves both to disk) and run without use_cache.
generate_kwargs = {"output_attentions": True, "output_hidden_states": True, "use_cache": False}
# Earlier `model.generate`-style settings, kept for reference:
# generate_kwargs = dict(max_new_tokens=num_tokens, use_cache=False, do_sample=False)
# generate_kwargs = dict(min_length=num_tokens, max_length=num_tokens, do_sample=False)

print_rank0(f"Generate args {generate_kwargs}")
inputs = input_sentences[: args.batch_size]
print("INPUTS", inputs)

def generate():
    """Run one forward pass over ``inputs`` and decode each row's next token.

    Despite its name, this does not call ``model.generate``. It:

    * tokenizes the module-level ``inputs`` into one padded batch;
    * moves the tensors to ``cuda:0``;
    * runs a single forward pass with ``generate_kwargs``;
    * dumps the attentions, hidden states and logits to ``.npy`` files in the
      current directory.

    Each row's next token is the arg-max of the logits at the last position.

    Returns:
        A list of ``(input_text, decoded_next_token, num_new_tokens)`` tuples,
        one per batch row. ``num_new_tokens`` is always 1, because exactly one
        forward step is taken.
    """
    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda:0")

    # Call the module rather than `.forward` so nn.Module.__call__ runs any
    # registered hooks on the top-level model.
    outputs = model(**input_tokens, **generate_kwargs)
    # np.save stacks each list into a single array, so its entries must share a shape.
    np.save('bloom_attns_inp2.npy', [o.detach().cpu().float().numpy() for o in outputs.attentions])
    np.save('bloom_hs_inp2.npy', [o.detach().cpu().float().numpy() for o in outputs.hidden_states])
    np.save("bloom_logits_inp2.npy", outputs.logits.detach().cpu().float().numpy())

    # Greedy pick at the final position of every row; logits assumed to be
    # (batch, seq, vocab), giving one token id per row.
    # NOTE(review): this assumes left padding, so position -1 is a real token in
    # every row. With right padding, shorter prompts would read a pad position.
    next_tokens = outputs.logits[:, -1].argmax(-1)
    # Shape (batch, 1): one single-token sequence per row, so batch_decode
    # returns one string per prompt instead of merging all rows into one.
    decoded = tokenizer.batch_decode(next_tokens.unsqueeze(-1), skip_special_tokens=True)
    print("OUTPUTS", decoded)
    # A single forward step yields exactly one new token per row. The old
    # `output_len - input_len` arithmetic went negative and broke the benchmark.
    total_new_tokens = [1] * len(decoded)
    return list(zip(inputs, decoded, total_new_tokens))

# First (timed) run; t_generate_span feeds the benchmark report.
with torch.no_grad():
    print_rank0("*** Running generate")
    t_generate_start = time.time()
    generated = generate()
    t_generate_span = time.time() - t_generate_start
    separator = "-" * 60
    for prompt, completion, _ in generated:
        print_rank0(f"{separator}\nin={prompt}\nout={completion}\n")


### Benchmark

if args.benchmark:
    # Clear the CUDA cache and collect garbage left by the first run before timing.
    torch.cuda.empty_cache()
    gc.collect()

    print_rank0("*** Running benchmark")
    # Single warm-up pass, then wait for queued CUDA work to finish.
    with torch.no_grad():
        generate()
        torch.cuda.synchronize()

    # Timed cycles.
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    with torch.no_grad():
        for _ in range(cycles):
            generated = generate()
            total_new_tokens_generated += sum(new_tokens for _, _, new_tokens in generated)
        # Make sure all CUDA work is included in the wall-clock delta.
        torch.cuda.synchronize()
    elapsed = time.time() - t0

    # Seconds per generated token (a latency, despite the "Throughput" label).
    # Guard against a non-positive count, which would otherwise raise
    # ZeroDivisionError or print a negative value.
    if total_new_tokens_generated > 0:
        secs_per_token = elapsed / total_new_tokens_generated
    else:
        secs_per_token = float("nan")
    # t_generate_span timed ONE generate() call, so pair it with one call's
    # token count rather than the total over all cycles.
    tokens_per_call = total_new_tokens_generated // cycles
    print_rank0(
        f"""
*** Performance stats:
Throughput per token including tokenize: {secs_per_token*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {tokens_per_call} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
"""
    )