import os
import sys
import json
import fire
import _pickle as pkl
import argparse
import numpy as np
from tqdm import tqdm
from collections import defaultdict

import torch
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer, AutoModel
from peft import PeftModel

from utils import Prompter, read_relations_for_question, read_condensed_relations_for_question, convert_relation_to_text, load_entities, get_str_hit1


# evaluation
def evaluate(
        instruction=None,
        input=None,
        max_new_tokens=128,
        prompt_template: str = "",  # The prompt template to use, will default to alpaca.
        tokenizer=None,
        model=None,
):
    """Generate a response for one (instruction, input) pair.

    The prompt is built with ``Prompter(prompt_template)``, tokenized, passed to
    ``model.generate`` and the decoded text is post-processed with
    ``Prompter.get_response``.

    Args:
        instruction: Instruction text forwarded to ``Prompter.generate_prompt``.
        input: Optional input/context text forwarded to ``Prompter.generate_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens.
        prompt_template: Template name handed to ``Prompter``.
        tokenizer: Tokenizer used both to encode the prompt and decode the output.
        model: Generative model; its ``config`` supplies the special token ids.

    Returns:
        Whatever ``Prompter.get_response`` extracts from the decoded sequence.
    """
    prompter = Prompter(prompt_template)
    prompt = prompter.generate_prompt(instruction, input)

    encoded = tokenizer(prompt, return_tensors='pt')
    input_ids = encoded["input_ids"].to(model.device)
    # Pass the mask explicitly: when pad and eos ids coincide, generate() cannot
    # infer it from the ids alone and falls back to a warning plus an all-ones mask.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    # NOTE(review): do_sample is not set, so per the transformers documentation
    # temperature/top_p/top_k should have no effect and decoding is greedy.
    # Left as-is to keep outputs unchanged -- confirm this is intended.
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
    )

    # Mirror the special token ids configured on the model.
    generation_config.pad_token_id = model.config.pad_token_id
    generation_config.unk_token_id = model.config.unk_token_id
    generation_config.bos_token_id = model.config.bos_token_id
    generation_config.eos_token_id = model.config.eos_token_id

    # generate (per-step scores are not requested: nothing below reads them)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
        )

    decoded = tokenizer.decode(generation_output.sequences[0])
    return prompter.get_response(decoded)

def _load_webqsp_resources():
    """Read the WebQSP test split and build the question -> relations lookup.

    Returns:
        ``(test_file, question2relation)``. ``question2relation[split]`` maps each
        value stored in the relations pickle to the union of ``set(rel_key)`` over
        every key of that split that shares the value.
    """
    with open('../processed/WebQSP_test.json', encoding='utf-8') as handle:
        test_file = json.load(handle)
    # NOTE(security): unpickling executes arbitrary code; only load trusted files.
    with open('../processed/WebQSP_relations.pkl', 'rb') as handle:
        relations = pkl.load(handle)

    question2relation = defaultdict(dict)
    for split in ['train', 'test']:
        per_split = question2relation[split]
        for rel_key in relations[split]:
            rel_value = relations[split][rel_key]
            # NOTE(review): set(rel_key) presumably expects a tuple/list of relation
            # names; a plain string would be split into characters -- confirm.
            per_split[rel_value] = per_split.get(rel_value, set()) | set(rel_key)
    return test_file, question2relation


def _load_lora_model(base_model, lora_weights, device, load_8bit, cache_path):
    """Load the base LLaMA model, attach a LoRA adapter and put it in eval mode."""
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,  # honour the CLI flag instead of a hard-coded True
            torch_dtype=torch.float16,
            device_map="auto",
            cache_dir=cache_path,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

    # All special ids are pinned to 0, matching the tokenizer setup in main().
    model.config.pad_token_id = 0  # unk
    model.config.unk_token_id = 0
    model.config.bos_token_id = 0
    model.config.eos_token_id = 0

    model.eval()
    return model


def _embed_questions(texts, tokenizer, encoder, device):
    """Return L2-normalised, attention-masked mean-pooled embeddings, one row per text."""
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    # Inference only: avoid building an autograd graph for the encoder pass.
    with torch.no_grad():
        hidden = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    pooled = torch.sum(hidden * attention_mask.unsqueeze(-1), dim=-2)
    pooled = pooled / torch.sum(attention_mask, dim=-1, keepdim=True)
    return pooled / torch.sqrt(torch.sum(pooled * pooled, dim=-1)).unsqueeze(-1)


def _extract_decomposed_question(qd_output):
    """Pull the decomposed question (4th '"'-separated field) out of the QD output."""
    parts = qd_output.split('"')
    if len(parts) > 3:
        return parts[3]
    # Malformed generation: fall back to the raw text so nearest-template
    # retrieval can still run instead of aborting the whole evaluation.
    return qd_output.strip()


def _rank_passages(query_embedding, passages, passage_encoder, embedding_size, k):
    """Return indices of at most ``k`` passages with the highest inner product to the query."""
    import faiss  # heavy optional dependency, only needed for retrieval

    k = min(k, len(passages))
    if k <= 0:
        return []
    index = faiss.IndexFlatIP(embedding_size)
    index.add(passage_encoder.encode(passages))
    _, hits = index.search(np.expand_dims(query_embedding, 0), k)
    # FAISS pads missing neighbours with -1; never let that index a Python list.
    return [int(hit) for hit in hits[0] if hit >= 0]


def main(
    datasets: str = "WebQSP",
    base_model: str = "decapoda-research/llama-7b-hf",
    cache_path: str = "/data/home/",
    load_8bit: bool = True,
    qd_lora_weights: str = "../Keqing_WebQSP_QD",
    cr_lora_weights: str = "../Keqing_WebQSP_CR",
    max_examples: int = 10,
):
    """Run the decompose -> retrieve -> read pipeline on WebQSP and print Hit@1.

    For each evaluated test item the question-decomposer model produces a
    question template, the template is mapped to relation chains (exact match or
    nearest training template by cosine similarity), matching knowledge-graph
    relations are ranked with FAISS, and the context-reader model answers the
    question from the selected context.

    Args:
        datasets: Dataset name; only ``"WebQSP"`` is implemented.
        base_model: Base LLaMA checkpoint for both LoRA adapters.
        cache_path: Cache directory passed to ``from_pretrained``.
        load_8bit: Whether the CUDA path loads the base model in 8-bit.
        qd_lora_weights: LoRA adapter of the question decomposer.
        cr_lora_weights: LoRA adapter of the context reader.
        max_examples: Number of leading test items to evaluate; ``None`` or a
            value ``<= 0`` evaluates the whole test file.

    Raises:
        ValueError: If ``datasets`` is not a supported dataset name.
    """
    if datasets != 'WebQSP':
        raise ValueError(f"Unsupported dataset {datasets!r}; only 'WebQSP' is implemented.")

    # device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The small encoders are pinned to the first GPU when one exists.
    aux_device = "cuda:0" if device == "cuda" else "cpu"

    # dataset
    test_file, question2relation = _load_webqsp_resources()
    train_lookup = question2relation['train']
    question_templates = list(train_lookup.keys())

    # knowledge graphs
    from sentence_transformers import SentenceTransformer

    WEBQSP_DIR = os.path.realpath('/data/home/datasets/WebQSP')
    DATA_DIR = os.path.join(WEBQSP_DIR, "data")
    FREEBASE_DIR = os.path.join(DATA_DIR, "freebase_2hop")

    entity_names = load_entities(DATA_DIR)

    kg_passage_encoder = SentenceTransformer('facebook-dpr-ctx_encoder-single-nq-base').to(aux_device)
    kg_query_encoder = SentenceTransformer('facebook-dpr-question_encoder-single-nq-base').to(aux_device)
    embedding_size = 768

    # tokenizer
    question_tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    qd_tokenizer = LlamaTokenizer.from_pretrained(base_model)
    # NOTE(review): pad, bos, eos and unk are all set to id 0, so eos is NOT
    # distinct from pad here. Left untouched because the adapters were presumably
    # trained with this exact setup -- confirm before changing.
    qd_tokenizer.pad_token_id = 0
    qd_tokenizer.bos_token_id = 0
    qd_tokenizer.eos_token_id = 0
    qd_tokenizer.unk_token_id = 0
    qd_tokenizer.padding_side = "left"  # Allow batched inference

    # model
    question_encoder = AutoModel.from_pretrained("roberta-base",
                                                 torch_dtype=torch.float32,
                                                 cache_dir=cache_path,
                                                 device_map=aux_device)
    question_encoder.eval()

    qd_model = _load_lora_model(base_model, qd_lora_weights, device, load_8bit, cache_path)  # question_decomposer
    cr_model = _load_lora_model(base_model, cr_lora_weights, device, load_8bit, cache_path)  # context_reader

    # Compare the numeric major version; a plain string comparison would rank
    # "10.x" below "2".
    torch_major = int(torch.__version__.split(".")[0])
    if torch_major >= 2 and sys.platform != "win32":
        qd_model = torch.compile(qd_model)
        cr_model = torch.compile(cr_model)

    # question template embeddings
    question_embeddings = _embed_questions(question_templates, question_tokenizer, question_encoder, aux_device)

    top_k = 20                # context size handed to the reader
    max_bad_passages = 10000  # cap on non-matching passages that get encoded

    if max_examples is None or max_examples <= 0:
        limit = len(test_file)
    else:
        limit = min(max_examples, len(test_file))

    hit1_list = []
    for example in tqdm(test_file[:limit]):
        qd_qid = example['qid']
        question = example['input']

        qd_output = evaluate(example['instruction'], question, max_new_tokens=128, tokenizer=qd_tokenizer, model=qd_model)
        qd_question = _extract_decomposed_question(qd_output)

        # question retrieval and obtain the logical chains
        if qd_question in train_lookup:
            qd_logical_chains = train_lookup[qd_question]
        else:
            # Unknown template: fall back to the most similar training template.
            query_vector = _embed_questions(qd_question, question_tokenizer, question_encoder, aux_device)
            similarity = torch.sum(question_embeddings * query_vector, dim=-1)
            nearest_template = question_templates[int(torch.argmax(similarity))]
            qd_logical_chains = train_lookup[nearest_template]

        print(qd_question, qd_logical_chains)

        # search the context on knowledge graph
        qd_relations = read_relations_for_question(qd_qid, kg_data_path=FREEBASE_DIR)
        qd_condensed_relations = read_condensed_relations_for_question(qd_qid, kg_data_path=FREEBASE_DIR)
        qd_condensed_relations = [tmp[1] for tmp in qd_condensed_relations]
        qd_all_relations = qd_relations + qd_condensed_relations

        qd_good_relations, qd_good_passages = [], []
        qd_bad_relations, qd_bad_passages = [], []
        for relation in qd_all_relations:
            title, text = convert_relation_to_text(relation, entity_names)
            passage = title + ' [SEP] ' + text
            if relation.rel[4:-1] in qd_logical_chains:
                qd_good_passages.append(passage)
                qd_good_relations.append(text)
            else:
                qd_bad_passages.append(passage)
                qd_bad_relations.append(text)

        query_embedding = kg_query_encoder.encode(question)
        if len(qd_good_passages) >= top_k:
            # Enough matching relations: keep only the top_k most similar ones.
            picked = _rank_passages(query_embedding, qd_good_passages, kg_passage_encoder, embedding_size, top_k)
            qd_topk_context = [qd_good_relations[hit] for hit in picked]
        else:
            # Keep every matching relation and top up with the best non-matching ones.
            qd_bad_passages = qd_bad_passages[:max_bad_passages]
            picked = _rank_passages(query_embedding, qd_bad_passages, kg_passage_encoder,
                                    embedding_size, top_k - len(qd_good_relations))
            qd_topk_context = qd_good_relations + [qd_bad_relations[hit] for hit in picked]

        # generate answer conditioned on the context
        cr_instruction = f"Use the following pieces of context to answer the users question.  If you don't know the answer, just say that you don't know, don't try to make up an answer.  \n ----------------\n {qd_topk_context}"
        ground_truth = example['answers']
        cr_output = evaluate(cr_instruction, question, max_new_tokens=128, tokenizer=qd_tokenizer, model=cr_model)

        print(ground_truth, '\n', cr_output)

        # evaluation
        hit1_list.append(get_str_hit1(ground_truth, cr_output))

    print(f'total_hit1@{np.mean(hit1_list)}')

# Script entry point: hand main() to python-fire so it can be driven from the command line.
if __name__ == "__main__":
    fire.Fire(main)