from tqdm import tqdm
import transformers
from collections import OrderedDict
import torch
from QA_text_dataset_CNN import InferDataset
from transformers import DynamicCache

from peft import LoraConfig

from modeling_icae_multi_span import ICAE, ModelArguments, DataArguments, TrainingArguments
from icecream import ic as print

import json
import random
import os
import numpy as np

device = "cuda"
memory_size = 128
parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments))

model_args, data_args, training_args = parser.parse_args_into_dataclasses()


print("model_args", model_args)
print("data_args", data_args)
print("training_args", training_args)
data_path = data_args.data_path
# Dataset name = second-to-last "/"-separated component of the data path,
# i.e. the parent directory of the data file (e.g. ".../NQ/<file>.jsonl" -> "NQ").
dataset_name = data_path.split("/")[-2]
# os.path.join adds the separator when output_dir lacks a trailing "/";
# plain string concatenation would glue the file name onto the directory name.
save_path = os.path.join(
    training_args.output_dir,
    dataset_name + "_outputs-icae-" + str(training_args.mean_compression_rate) + ".jsonl",
)
print("save_path", save_path)

lora_config = LoraConfig(
    r=model_args.lora_r,
    lora_alpha=32,
    lora_dropout=model_args.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

# Greedy-decoding budget (new tokens) per supported dataset.
MAX_NEW_TOKENS_dict = {
    "NQ": 100,
    "HQA": 32,
    "WikiQA": 32,
}

if dataset_name not in MAX_NEW_TOKENS_dict:
    raise ValueError(
        f"Unsupported dataset {dataset_name!r} (from data_path {data_path!r}); "
        f"expected one of {sorted(MAX_NEW_TOKENS_dict)}"
    )
MAX_NEW_TOKENS = MAX_NEW_TOKENS_dict[dataset_name]


def seed_everything(TORCH_SEED):
    """Seed every random number generator this script touches.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU RNG and all
    CUDA RNGs with the same value. It also forces cuDNN into deterministic,
    non-autotuned mode so repeated runs pick the same kernels.

    Args:
        TORCH_SEED: integer seed applied to every generator.
    """
    # Exported for any child interpreter launched later; the current process's
    # string-hash seed is fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(TORCH_SEED)
    # Disable autotuning: it may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

weight_path = "/path/to/project/baseline2/icae/code/icae_v1/llama-2-7b-chat-finetuned-icae_zeroweight_llama2.pt"
model = ICAE(model_args, training_args, lora_config).to(device)
print(model.mean_compression_rate)
# NOTE: despite its name, this tokenizes '<s>', which in LLaMA tokenizers is
# the BOS marker ('</s>' is EOS). The value is only printed as a sanity check.
eos_id = model.tokenizer('<s>', return_tensors='pt').input_ids[0]
print(eos_id)
# Load on CPU: load_state_dict copies each tensor into the (GPU-resident)
# parameters, so the checkpoint is never duplicated in GPU memory.
state_dict = torch.load(weight_path, map_location="cpu")
# Keep every tensor, and drop non-tensor entries equal to 0.0.
# Presumably these are placeholders in this "zeroweight" checkpoint — TODO confirm.
new_state_dict = OrderedDict(
    (layer_name, weight)
    for layer_name, weight in state_dict.items()
    if isinstance(weight, torch.Tensor) or weight != 0.0
)

# strict=False: the checkpoint is not required to cover every parameter, so
# missing keys are expected. Unexpected keys, however, mean checkpoint names
# did not match the model, and those weights were silently ignored.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.unexpected_keys:
    print("unexpected_keys", load_result.unexpected_keys)


file_path = data_path

# Read the JSONL dataset as UTF-8 (independent of the platform locale).
# Skip blank/whitespace-only lines (e.g. a trailing newline at EOF), which
# json.loads would reject inside the generation loop.
with open(file_path, "r", encoding="utf-8") as f:
    lines = [line for line in f if line.strip()]

# memory_size consecutive token ids starting right after model.vocab_size.
MEMORY_TOKENS = list(range(model.vocab_size, model.vocab_size + memory_size))

model.eval()
print(f'model.mean_compression_rate: {model.mean_compression_rate}')
instructions_map = {
    'base': 'Answer the Question based on the given text. Only give me the answer and do not output any other words.',
    'short': 'Answer the Question based on the given text. Only give me the answer and do not output any other words.'
}

# NOTE: not referenced by the generation loop below. Both entries are also
# currently identical, so the NQ / non-NQ choice has no effect today.
instruction_text = instructions_map['base'] if 'NQ' in data_path else instructions_map['short']


def _build_context(data):
    """Join one example's retrieved passages into a single context string.

    Each entry of ``data['ctxs']`` contributes ``"<title> <text>"`` when it
    has a ``'title'`` key, otherwise just its ``'text'``. Passages are joined
    with newlines.
    """
    documents = [
        document['title'] + ' ' + document['text'] if 'title' in document else document['text']
        for document in data['ctxs']
    ]
    return '\n'.join(documents)


def _greedy_decode(decoder_input_embeddings):
    """Greedily decode up to MAX_NEW_TOKENS tokens with the LoRA adapter disabled.

    Args:
        decoder_input_embeddings: batch-of-one input embeddings (memory slots
            followed by the prompt). They are cast to bfloat16 for the first step.

    Returns:
        The list of generated token ids, excluding the stop token.
    """
    # Hoisted: the embedding module is the same for every decoding step.
    embed_tokens = model.icae.get_base_model().model.embed_tokens
    step_input = decoder_input_embeddings.to(torch.bfloat16)
    past_key_values = None
    generated_ids = []

    for _ in range(MAX_NEW_TOKENS):
        with model.icae.disable_adapter():
            out = model.icae(
                inputs_embeds=step_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        # Batch of one: argmax over the last position's logits yields shape (1,).
        next_token_id = torch.argmax(out.logits[:, -1, :], dim=-1)
        token = next_token_id.item()

        # Stop on EOS or on id 1 (presumably LLaMA's BOS <s> — TODO confirm).
        if token == model.tokenizer.eos_token_id or token == 1:
            break
        # Stop on ids >= 32000. Presumably the base LLaMA-2 vocab size, with
        # higher ids being added pad/memory/special tokens — TODO confirm.
        if token >= 32000:
            break

        # Feed only the new token; earlier positions live in past_key_values.
        step_input = embed_tokens(next_token_id).unsqueeze(1).to(device)
        generated_ids.append(token)

    return generated_ids


# For each example: compress the retrieved context into memory slots, append
# the question wrapped in ft tokens, greedily decode an answer, and write one
# JSON line per example to save_path.
with torch.no_grad():
    with open(save_path, "w", encoding="utf-8") as f:
        for line in tqdm(lines, desc="Generating"):
            data = json.loads(line)
            text_output = model.tokenizer(_build_context(data), truncation=True,
                                          max_length=model.training_args.model_max_length, padding=False,
                                          return_attention_mask=False)

            question_ids = model.tokenizer(
                data['question'], add_special_tokens=False, padding=False)['input_ids']
            prompt_ids = torch.tensor(
                [[model.ft_token_id] + question_ids + [model.ft_token_id]]).to(device)

            input_ids = torch.LongTensor([text_output['input_ids']]).to(device)
            original_num_tokens = input_ids.shape[1]
            # Assumed (num_slots, hidden): shape[0] counts slots, and the
            # unsqueeze(0) below adds the batch dim for the concatenation.
            memory_slots = model._compress(input_ids).to(device)
            compressed_num_tokens = memory_slots.shape[0]
            prompt_embeds = model.tokens_to_embeddings(prompt_ids)

            # Decoder input = [memory slots ; ft question ft] along the sequence dim.
            decoder_input_embeddings = torch.cat(
                (memory_slots.unsqueeze(0), prompt_embeds), dim=1)
            generated_text = model.tokenizer.decode(_greedy_decode(decoder_input_embeddings))

            output_ = {
                "question": data["question"],
                "generation": generated_text,
                "answers": data["answers"],
                "original_num_tokens": original_num_tokens,
                "compressed_num_tokens": compressed_num_tokens,
            }
            f.write(json.dumps(output_) + "\n")