import os

import argparse
from tqdm import tqdm
from src.llms import get_registed_model
import os
from datasets import load_dataset, Dataset
from src.utils.qa_utils import eval_path_result_w_ans
from src import utils
import json
from multiprocessing import Pool
from functools import partial
from src.qa_prompt_builder import PathGenerationWithAnswerPromptBuilder
from typing import List, Dict, Any, Tuple
import tiktoken


# Markers that delimit one reasoning path inside a prompt.
PATH_START_TOKEN = "<PATH>"
PATH_END_TOKEN = "</PATH>"


# Output format pairing a reasoning path with its answer.
# NOTE(review): not referenced by the code visible in this file — confirm it is still used.
ANS_TEMPLATE = """# Reasoning Path:
{reasoning_path}
# Answer:
{answer}"""

# Plain-text QA prompt (no chat role markers) with {question}, {entities} and
# {reasoning_paths} placeholders.
# NOTE(review): not referenced by the code visible in this file — confirm it is still used.
TRAIN_QA_PROMPT_TEMPLATE = '''
Reasoning path is a sequence of triples in the KG that connects the topic entities in the question to answer entities. Given a question and potentially helpful additional information, please generate some reasoning paths in the KG starting from the topic entities to answer the question.

# Question: 
{question}
# Topic entities:
{entities}
Additional information:
{reasoning_paths}
'''

# Chat-formatted QA prompt using <|im_start|>/<|im_end|> role markers; it ends with an
# open assistant turn so the model continues from there. Placeholders: {question},
# {entities}, {reasoning_paths}.
QA_PROMPT_TEMPLATE = '''<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Reasoning path is a sequence of triples in the KG that connects the topic entities in the question to answer entities. Given a question and potentially helpful additional information, please generate some reasoning paths in the KG starting from the topic entities to answer the question.

# Question: 
{question}
# Topic entities:
{entities}
Additional information:
{reasoning_paths}
<|im_end|>
<|im_start|>assistant
'''

def path_to_string(path: List[List[str]]) -> str:
    """
    Convert a reasoning path (a list of triples) to its arrow-joined string form.

    path: [["entity1", "relation1", "entity2"], ["entity2", "relation2", "entity3"]]
    Convert to: "entity1 -> relation1 -> entity2 -> relation2 -> entity3"

    The first well-formed triple contributes head, relation and tail. Every later
    triple contributes only relation and tail, because its head is expected to repeat
    the previous tail. Entries that are not 3-element triples are skipped. Returns ""
    for an empty or ``None`` path.
    """
    elements: List[str] = []
    for triple in path or []:
        if len(triple) != 3:
            continue  # skip malformed entries
        # Only the first *kept* triple contributes its head entity. Keying this on the
        # loop index (as before) dropped the head whenever a malformed entry came first.
        elements.extend(triple if not elements else triple[1:])
    return " -> ".join(elements)

def get_ground_truth_stats(ground_truth_paths, two_hop_paths, two_hop_scores):
    """Summarise the scores that scored two-hop paths give to the ground-truth paths.

    A two-hop path matches a ground-truth path when their triples are equal element by
    element. Every score attached to a matching two-hop path is collected; duplicates
    in either list contribute once per occurrence.

    Returns:
        ``(max, min, mean)`` of the collected scores, or ``(0, 0, 0)`` if none matched.
    """
    def _key(path):
        # Lists are unhashable, so compare paths as tuples of tuples.
        return tuple(map(tuple, path))

    scores_by_path = {}
    for idx, candidate in enumerate(two_hop_paths):
        scores_by_path.setdefault(_key(candidate), []).append(two_hop_scores[idx])

    hits = [
        score
        for gt_path in ground_truth_paths
        for score in scores_by_path.get(_key(gt_path), [])
    ]
    if not hits:
        return 0, 0, 0
    return max(hits), min(hits), sum(hits) / len(hits)

def build_qa_prompt(data, two_hop_k=100):
    """Build the QA prompt for one sample and locate its ground-truth paths in it.

    The first ``two_hop_k`` entries of ``data["two_hop_paths"]`` are kept. If the sample
    has ground-truth paths, they are inserted at roughly evenly spaced positions among
    those two-hop paths. Each path is rendered with ``path_to_string`` and wrapped in
    ``PATH_START_TOKEN``/``PATH_END_TOKEN``, one per line. ``QA_PROMPT_TEMPLATE`` is
    then filled with the question, the comma-joined topic entities and this path block.

    Args:
        data: one sample; reads ``question``, ``ground_truth_paths``, ``two_hop_paths``,
            ``two_hop_paths_scores`` and ``q_entity``.
        two_hop_k: number of leading two-hop paths to keep.

    Returns:
        ``(prompt, ground_truth_idx, scores)`` where:

        * ``ground_truth_idx`` holds ``(start, end)`` character offsets into ``prompt``
          of the first whole path entry equal to each non-empty ground-truth path.
        * ``scores`` is ``[max, min, mean]`` from ``get_ground_truth_stats``, computed
          over the untruncated two-hop list, or ``[0, 0, 0]`` when there are no
          ground-truth paths.
    """
    import re

    question = data["question"]
    ground_truth_paths = data["ground_truth_paths"]
    two_hop_paths = data["two_hop_paths"]
    two_hop_scores = data["two_hop_paths_scores"]

    max_score, min_score, avg_score = 0, 0, 0
    # Slicing already clamps to the available length; list() guarantees a fresh list,
    # so the insertions below never mutate the sample's own data.
    prompt_paths = list(two_hop_paths[:two_hop_k])

    if ground_truth_paths:
        # Stats are taken over all two-hop paths, not just the kept top-k.
        max_score, min_score, avg_score = get_ground_truth_stats(
            ground_truth_paths, two_hop_paths, two_hop_scores
        )
        # Spread the ground-truth paths evenly through the two-hop list.
        total_positions = len(prompt_paths) + len(ground_truth_paths)
        step = total_positions // (len(ground_truth_paths) + 1)
        for idx, gt_path in enumerate(ground_truth_paths):
            prompt_paths.insert(min(step * (idx + 1), len(prompt_paths)), gt_path)

    paths_block = "\n".join(
        f"{PATH_START_TOKEN}{path_to_string(p)}{PATH_END_TOKEN}" for p in prompt_paths
    )

    # Swap QA_PROMPT_TEMPLATE for your own template if needed; it must contain the
    # {question}, {entities} and {reasoning_paths} placeholders.
    values = {
        "question": question,
        "entities": ", ".join(data["q_entity"]),
        "reasoning_paths": paths_block,
    }
    # Substitute all placeholders in a single pass, so placeholder-like text inside the
    # question or an entity name is not substituted a second time.
    prompt = re.sub(
        r"\{(question|entities|reasoning_paths)\}",
        lambda m: values[m.group(1)],
        QA_PROMPT_TEMPLATE,
    )

    ground_truth_idx = []
    for gt_path in ground_truth_paths:
        gt_path_str = path_to_string(gt_path)
        if not gt_path_str:
            continue  # an empty string would spuriously "match" at offset 0
        # Match a whole delimited entry, so a ground-truth path is not found as the
        # prefix of a longer path or inside the question text.
        pos = prompt.find(f"{PATH_START_TOKEN}{gt_path_str}{PATH_END_TOKEN}")
        if pos == -1:
            continue
        start = pos + len(PATH_START_TOKEN)
        ground_truth_idx.append((start, start + len(gt_path_str)))

    scores = [max_score, min_score, avg_score]
    return prompt, ground_truth_idx, scores

def merge_rule_result(qa_dataset, rule_dataset, n_proc=1, filter_empty=False):
    """Attach rule-generated paths to every QA sample by matching on ``id``.

    Args:
        qa_dataset: dataset with an ``id`` column that supports ``map``/``filter``
            (presumably a Hugging Face ``datasets.Dataset`` — confirm at call sites).
        rule_dataset: iterable of records with ``id``, ``prediction`` and ``ground_paths``.
        n_proc: ``num_proc`` forwarded to ``map`` and ``filter``.
        filter_empty: if True, drop samples whose ``ground_paths`` is empty.

    Returns:
        The dataset with the columns ``predicted_paths`` (taken from the rule record's
        ``prediction``) and ``ground_paths`` set. A sample whose id has no rule record
        gets empty lists for both instead of raising ``KeyError``; with
        ``filter_empty`` such samples are then dropped.
    """
    # If an id occurs more than once in rule_dataset, the last record wins.
    question_to_rule = {
        rule["id"]: {
            "predicted_paths": rule["prediction"],
            "ground_paths": rule["ground_paths"],
        }
        for rule in rule_dataset
    }

    def find_rule(sample):
        rule = question_to_rule.get(sample["id"])
        if rule is None:
            sample["predicted_paths"] = []
            sample["ground_paths"] = []
        else:
            sample["predicted_paths"] = rule["predicted_paths"]
            sample["ground_paths"] = rule["ground_paths"]
        return sample

    qa_dataset = qa_dataset.map(find_rule, num_proc=n_proc)
    if filter_empty:
        qa_dataset = qa_dataset.filter(
            lambda x: len(x["ground_paths"]) > 0, num_proc=n_proc
        )
    return qa_dataset

def get_output_file(path, force=False):
    """Open the predictions JSONL file, resuming from earlier results when possible.

    Args:
        path: location of the predictions file.
        force: if True, truncate any existing file instead of resuming.

    Returns:
        ``(fout, processed_ids)``. When the file is missing or ``force`` is set, ``fout``
        is opened for writing (truncating) and ``processed_ids`` is empty. Otherwise each
        non-blank line is parsed as JSON, its ``"id"`` is collected in file order, and
        ``fout`` is opened for appending. The caller must close ``fout``.

    Raises:
        ValueError: if a non-blank line is not valid JSON (chained to the decode error).
        KeyError: if a parsed record has no ``"id"``.
    """
    if force or not os.path.exists(path):
        return open(path, "w"), []

    processed_results = []
    with open(path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # Blank lines carry no record. Skipping them lets a run resume from a
                # file that ends with an extra newline.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"Error in line {line_no} of {path}: {line!r}") from err
            processed_results.append(record["id"])
    return open(path, "a"), processed_results

def prediction(data, processed_list, input_builder, model, reason_errors_LLM):
    """Generate reasoning paths for one sample.

    Args:
        data: sample with at least ``question``, ``answer`` and ``id``, plus whatever
            ``input_builder.process_input`` and ``build_qa_prompt`` read.
        processed_list: ids already in the output file; such samples are skipped.
        input_builder: prompt builder. Only its trie, ground paths and
            ``PATH_START_TOKEN``/``PATH_END_TOKEN`` are used here.
        model: generation model exposing ``tokenizer`` and ``generate_sentence``.
        reason_errors_LLM: auxiliary model forwarded to ``model.generate_sentence``.

    Returns:
        ``(result, status)``. ``result`` is the JSON-serialisable record, or ``None``.
        ``status`` is one of ``"already processed"``, ``"error, no trie"``,
        ``"error,no prediction"`` or ``"success"``.
    """
    question = data["question"]
    answer = data["answer"]
    sample_id = data["id"]
    if sample_id in processed_list:
        return None, "already processed"

    # The builder's own prompt is discarded; only its ground paths and trie are used.
    _, ground_paths, trie = input_builder.process_input(data)
    if trie is None:
        return None, "error, no trie"
    start_token_ids = model.tokenizer.convert_tokens_to_ids(
        input_builder.PATH_START_TOKEN
    )
    end_token_ids = model.tokenizer.convert_tokens_to_ids(input_builder.PATH_END_TOKEN)

    # The prompt actually sent to the model is built from QA_PROMPT_TEMPLATE.
    prompt, ground_truth_idx, scores = build_qa_prompt(data, two_hop_k=10)

    generated = model.generate_sentence(
        question,
        prompt,
        ground_truth_idx,
        scores,
        trie,
        start_token_ids=start_token_ids,
        end_token_ids=end_token_ids,
        enable_constrained_by_default=False,
        reason_errors_LLM=reason_errors_LLM,
    )
    if generated is None:
        return None, "error,no prediction"
    result = {
        "id": sample_id,
        "question": question,
        "prediction": generated,
        "ground_truth": answer,
        "ground_truth_paths": ground_paths,
        "input": prompt,
    }
    return result, "success"


def main(args, LLM, reason_errors_LLM):
    """Run path generation over one dataset split, save the predictions and evaluate them.

    Results go to
    ``<predict_path>/<d>[_undirected]/<model_name>/<split>/<post_fix>/predictions.jsonl``
    (one JSON record per successful sample), next to an ``args.txt`` dump of ``args``.
    Unless ``args.force`` is set, ids already in that file are skipped. After inference,
    ``eval_path_result_w_ans`` is run on the predictions file.

    Args:
        args: parsed command-line namespace.
        LLM: model class, instantiated as ``LLM(args)``.
        reason_errors_LLM: auxiliary model class, instantiated as
            ``reason_errors_LLM(args)`` and passed to ``prediction``.
    """
    input_file = os.path.join(args.data_path, args.d)
    dataset = load_dataset(input_file, split=args.split)
    # NOTE(review): args.generation_mode and args.k are not declared by the parser in the
    # __main__ block; presumably LLM.add_args registers them — confirm.
    post_fix = f"{args.prefix}{args.prompt_mode}-{args.generation_mode}-k{args.k}-index_len{args.index_path_length}"
    if args.add_rule:
        rule_postfix = args.rule_path.replace("/", "_").replace(".", "_")
        rule_dataset = utils.load_jsonl(args.rule_path)
        dataset = merge_rule_result(dataset, rule_dataset, args.n, args.filter_empty)
        post_fix += "_" + rule_postfix
    data_name = args.d + "_undirected" if args.undirected else args.d
    output_dir = os.path.join(args.predict_path, data_name, args.model_name, args.split, post_fix)
    print("Save results to: ", output_dir)
    os.makedirs(output_dir, exist_ok=True)

    model = LLM(args)
    # Instantiate under a new name instead of rebinding the class argument.
    reason_model = reason_errors_LLM(args)

    print("Prepare pipline for inference...")
    model.prepare_for_inference()

    input_builder = PathGenerationWithAnswerPromptBuilder(model.tokenizer, args.prompt_mode, index_path_length=args.index_path_length, undirected=args.undirected, add_rule=args.add_rule)

    # Save args file
    with open(os.path.join(output_dir, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    output_path = os.path.join(output_dir, 'predictions.jsonl')
    fout, processed_list = get_output_file(output_path, force=args.force)

    # Per-status counters: [already processed, no trie, no prediction, success].
    cases_statics = [0, 0, 0, 0]
    status_slot = {
        "already processed": 0,
        "error, no trie": 1,
        "error,no prediction": 2,
        "success": 3,
    }

    # One callable for both the sequential and the multi-process path, so that pool
    # workers also receive the auxiliary model (it was previously missing there).
    predict = partial(
        prediction,
        processed_list=processed_list,
        input_builder=input_builder,
        model=model,
        reason_errors_LLM=reason_model,
    )

    def handle(sample_id, res, information):
        """Count one sample's status and append its record, if any, to the output."""
        if information:
            print(information)
        if information in status_slot:
            cases_statics[status_slot[information]] += 1
        if res is None:
            print("None result for: ", sample_id)
            return
        if args.debug:
            print(json.dumps(res))
        fout.write(json.dumps(res) + "\n")
        fout.flush()

    try:
        if args.n > 1:
            with Pool(args.n) as pool:
                # imap yields results in input order, so zipping with the dataset pairs
                # each (result, status) tuple with the sample it came from.
                outputs = pool.imap(predict, dataset)
                for data, (res, information) in tqdm(zip(dataset, outputs), total=len(dataset)):
                    handle(data["id"], res, information)
        else:
            for data in tqdm(dataset):
                res, information = predict(data)
                handle(data["id"], res, information)
    finally:
        fout.close()

    print(
        f"already processed: {cases_statics[0]}, no trie: {cases_statics[1]}, "
        f"no prediction: {cases_statics[2]}, success: {cases_statics[3]}"
    )
    eval_path_result_w_ans(output_path)
    
if __name__ == "__main__":
    # Command-line options for data location, output layout and run control.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--data_path', type=str, default='rmanluo')
    argparser.add_argument('--d', '-d', type=str, default='RoG-webqsp')
    # Split string passed to load_dataset; presumably datasets slice syntax
    # (e.g. first 100 test rows) — confirm.
    argparser.add_argument('--split', type=str, default='test[:100]')
    argparser.add_argument('--index_path_length', type=int, default=2)
    argparser.add_argument('--predict_path', type=str, default='results/GenPaths')
    argparser.add_argument('--model_name', type=str, help="model_name for save results", default='gcr-Llama-2-7b-chat-hf')
    argparser.add_argument('--force', action='store_true', help="force to overwrite the results")
    argparser.add_argument("--n", type=int, default=1, help="number of processes")
    # True only for the case-insensitive string "true"; any other value yields False.
    argparser.add_argument("--undirected", type=lambda x: (str(x).lower() == 'true'), default=False)
    argparser.add_argument("--debug", action="store_true", help="print debug information")
    argparser.add_argument("--prompt_mode", type=str, default="zero-shot", choices=["zero-shot", "mcq-zero-shot", "few-shot"])
    argparser.add_argument("--filter_empty", action="store_true")
    argparser.add_argument("--add_rule", action="store_true")
    argparser.add_argument(
        "--rule_path",
        type=str,
        default="results/gen_rule_path/webqsp_undirected/Llama-2-7b-chat-hf_align-spectoken-joint/test/predictions_3_False.jsonl",
    )
    argparser.add_argument("--prefix", type=str, default="")

    # First pass: parse only the options declared so far, ignoring model-specific flags,
    # so that args.model_name is available to look up the model class.
    args, _  = argparser.parse_known_args()
    
    # Let the selected model class register its own command-line options.
    LLM = get_registed_model(args.model_name)
    LLM.add_args(argparser)
    
    # Auxiliary model class, fixed to the "gpt-4o-mini" registry entry. Its add_args
    # call is disabled, so it is built from the same args namespace as the main model.
    reason_errors_LLM = get_registed_model("gpt-4o-mini")
    # reason_errors_LLM.add_args(argparser)
    # Second pass: full parse including model-specific options; unrecognised flags
    # now cause an argparse error.
    args = argparser.parse_args()
    
    main(args, LLM, reason_errors_LLM)