import argparse
import os
import torch
import json
import re

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from datasets import load_dataset   
from datetime import datetime

from memgpt.database.utils.utils_database import load_current_database, DatabaseManager, DatabaseLookupError


# Running count of database lookups that returned usable data. It is
# incremented inside generate_dblookup and printed at the end of the run.
success = 0
# When true, a successful lookup is spliced back into the prompt wrapped in the
# <|db_entity|> ... <|db_end|> special tokens.
USE_SPECIAL_TOKENS = True

def load_args():
    """Parse the command-line options for a generation run.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in order, so ``--help`` lists them in the order
    written here.

    Returns:
        argparse.Namespace: The options parsed from ``sys.argv``.
    """
    option_table = (
        # Paths and model selection.
        ("--save-dir", {"type": str, "required": True}),
        ("--database-path", {"type": str, "default": "", "help": "Path to the database file."}),
        ("--eval_dataset_path", {"type": str, "default": "", "help": "Path to the eval dataset file."}),
        ("--model", {"type": str, "default": "gpt2"}),
        ("--sentence-model", {"type": str, "default": "sentence-transformers/all-mpnet-base-v2"}),
        # Dataset and runtime configuration.
        ("--dataset", {"type": str, "default": "wikipedia"}),
        ("--cache-dir", {"type": str, "default": None}),
        ("--num-samples", {"type": int, "default": 10}),
        ("--world-size", {"type": int, "default": 4}),
        # Sampling parameters; a temperature of 0 means greedy decoding.
        ("--temperature", {"type": float, "default": 0}),
        ("--top-p", {"type": float, "default": 0.9}),
        ("--max-new-tokens", {"type": int, "default": None}),
        ("--repetition-penalty", {"type": float, "default": 1.5}),
        ("--seed", {"type": int, "default": 42}),
        # Database lookup behaviour.
        ("--enable_dblookup", {"action": "store_true", "help": "Enable database lookup for entity and relationship extraction."}),
        ("--threshold", {"type": float, "default": 0.7, "help": "Threshold for top-k retrieval."}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def get_loggings(args):
    """Return the output file path for this run, creating its directory.

    A postfix describing the sampling configuration is printed for reference
    only; it is not part of the returned file name, which is
    ``<save_dir>/<dataset>_<model name>.jsonl``.

    Args:
        args: Parsed command-line options (as returned by ``load_args``).

    Returns:
        str: Path of the file the results should be written to.
    """
    run_postfix = f"t{args.temperature}_p{args.top_p}_s{args.seed}_rep{args.repetition_penalty}_th{args.threshold}_len{args.max_new_tokens}"
    print(f"Logging postfix: {run_postfix}")

    output_path = os.path.join(args.save_dir, f"{args.dataset}_{get_model_name(args)}.jsonl")
    parent_dir = os.path.dirname(output_path)
    os.makedirs(parent_dir, exist_ok=True)
    return output_path
    
def get_model_name(args):
    """Derive a short model tag from ``args.model``.

    The model path is split on ``/`` after stripping trailing slashes. If the
    last component contains ``"checkpoint"`` and has a parent component, the
    tag is ``<parent>_ckpt<suffix>`` where ``<suffix>`` is the text after the
    last ``-`` of the checkpoint directory name. Otherwise the last component
    is used as-is, which also covers a bare ``checkpoint-N`` path with no
    parent directory (previously an ``IndexError``).

    Args:
        args: Object exposing ``model`` (str) and ``enable_dblookup`` (bool).

    Returns:
        str: The model tag, with ``_dblookup`` appended when
        ``args.enable_dblookup`` is truthy.
    """
    path_parts = args.model.rstrip('/').split('/')
    leaf = path_parts[-1]

    if "checkpoint" in leaf and len(path_parts) >= 2:
        checkpoint_suffix = leaf.split('-')[-1]
        model_name = path_parts[-2] + "_ckpt" + checkpoint_suffix
    else:
        model_name = leaf

    if args.enable_dblookup:
        model_name += "_dblookup"
    return model_name

def normalize_db_format(text):
    """Make every DB special token be followed by exactly one space.

    Any run of whitespace (including none) directly after ``<|db_entity|>``,
    ``<|db_relationship|>``, ``<|db_return|>`` or ``<|db_end|>`` is replaced
    by a single space.

    Args:
        text (str): Text that may contain DB special tokens.

    Returns:
        str: The text with normalized spacing after each token.
    """
    rewrites = (
        (r'<\|db_entity\|>\s*', '<|db_entity|> '),
        (r'<\|db_relationship\|>\s*', '<|db_relationship|> '),
        (r'<\|db_return\|>\s*', '<|db_return|> '),
        (r'<\|db_end\|>\s*', '<|db_end|> '),
    )
    for token_pattern, spaced_token in rewrites:
        text = re.sub(token_pattern, spaced_token, text)
    return text


def generate_response(prompts):
    """Generate a continuation for ``prompts`` and report whether decoding is done.

    Relies on the module-level ``tokenizer``, ``llm``, ``sampling_params`` and
    ``args`` objects that the script entry point creates before calling this.

    Args:
        prompts (str): The prompt text to continue.

    Returns:
        tuple[str, bool]: ``(continuation, is_finished)``. ``continuation`` is
        the generated text with DB-token spacing normalized. ``is_finished``
        is True when the prompt alone already has at least
        ``args.max_new_tokens`` tokens (in which case the continuation is
        ``""``), when prompt plus generated tokens exceed that limit, or when
        the last generated token id is the tokenizer's EOS id, BOS id, or the
        id of ``"<s>"``.
    """
    prompt_ids = tokenizer.encode(prompts)

    # The prompt alone already uses up the token budget: nothing to generate.
    if len(prompt_ids) >= args.max_new_tokens:
        return "", True

    response = llm.generate(prompts=prompts,
                            sampling_params=sampling_params,
                            use_tqdm=False)
    generated_ids = response[0].outputs[0].token_ids

    # Build a new list rather than extending the encoded prompt in place.
    full_ids = prompt_ids + list(generated_ids)
    is_finished = len(full_ids) > args.max_new_tokens

    output_text = normalize_db_format(
        tokenizer.decode(full_ids, clean_up_tokenization_spaces=True)
    )
    if prompts in output_text:
        # Cut at the FIRST occurrence only. Taking the piece after the last
        # occurrence would silently drop generated text whenever the model
        # repeats the prompt string inside its continuation.
        output_text = output_text.split(prompts, 1)[1]
    else:
        # Decoding/normalization changed the prompt text, so it cannot be
        # stripped by string match; decode the generated tokens on their own.
        output_text = normalize_db_format(
            tokenizer.decode(generated_ids, clean_up_tokenization_spaces=True)
        )

    # Check if the last generated token is an EOS, BOS, or "<s>" token.
    last_token = generated_ids[-1] if generated_ids else None
    special_tokens = (tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids("<s>"))
    is_finished = is_finished or (last_token in special_tokens)

    return output_text, is_finished


def extract_non_dblookup_segments(text, use_special_tokens=True):
    """
    Return the pieces of ``text`` that lie outside DB lookup spans.

    A lookup span is ``<|db_entity|> ... <|db_end|>`` when
    ``use_special_tokens`` is true, otherwise ``[dblookup(...) -> ...]``.
    Spans are matched non-greedily and may cross line breaks.

    Args:
        text (str): The input text containing DB lookup spans
        use_special_tokens (bool): Whether to match <|db_*|> tokens (True) or the [dblookup()] format (False)

    Returns:
        list: The text segments before, between and after the lookup spans, in
        order. Segments that are empty or whitespace-only are omitted.
    """
    lookup_pattern = (
        r"<\|db_entity\|>.*?<\|db_end\|>"
        if use_special_tokens
        else r"\[dblookup\(.*?\) ->.*?\]"
    )

    # Neither pattern has a capturing group, so splitting on it yields exactly
    # the text around the lookup spans (or the whole text if there are none).
    pieces = re.split(lookup_pattern, text, flags=re.DOTALL)
    return [piece for piece in pieces if piece.strip()]
 
 
# Script entry point: parse options, optionally load the lookup database, load
# the tokenizer and vLLM engine, rebuild each eval example segment by segment
# with model-generated database lookups, then write the results and print
# lookup statistics.
if __name__=="__main__":
    success = 0
    # NOTE(review): total_time and total_words are assigned here but never read
    # anywhere in this file.
    total_time = []
    total_words = []

    args = load_args()

    if args.enable_dblookup:
        # Lookup is only kept enabled when the model path contains "dwiki" and
        # does not contain "plain".
        if not ("dwiki" in args.model and "plain" not in args.model):
            print(f"Database lookup can only be enabled for models trained on the dwiki dataset and with dblookup patterns, but not for {args.model}")
            args.enable_dblookup = False
        if "new" in args.model:
            # NOTE(review): USE_SPECIAL_TOKENS is already True at module level,
            # so this assignment does not change its value.
            USE_SPECIAL_TOKENS = True
            print("Using special tokens for dblookup.")

    logging_file = get_loggings(args)
    # Resume check: if a previous run already wrote at least num_samples
    # examples to the output file, there is nothing left to do.
    if os.path.exists(logging_file):
        # with open(logging_file, "r") as f:
        #     data = [json.loads(line) for line in f]
        data = load_dataset("json", data_files=logging_file, field="examples", split="train")
        if len(data) >= args.num_samples:
            print(f"Already generated {len(data)} samples. Exiting.")
            exit()

    # NOTE(review): db_manager is only bound when lookup is enabled, but
    # generate_dblookup below uses it unconditionally; with lookup disabled the
    # first lookup attempt would raise NameError.
    if args.enable_dblookup:
        # Fall back to the "./database" location when no usable path was given.
        if not args.database_path or not os.path.exists(args.database_path):
            db_manager = load_current_database("./database")
        else:
            db_manager = DatabaseManager()
            db_manager.load_database(args.database_path)

        print(f"Loaded database {db_manager}")

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, legacy=False)
    # NOTE(review): --max-new-tokens defaults to None in load_args, in which
    # case `args.max_new_tokens+48` raises TypeError; the flag is effectively
    # required.
    llm = LLM(model=args.model, tensor_parallel_size=args.world_size, max_model_len=args.max_new_tokens+48, gpu_memory_utilization=0.85, dtype=torch.bfloat16, seed=args.seed, tokenizer=args.model)

    # with open(args.eval_dataset_path, "r") as f:
    #     entity_lst = f.readlines()
    eval_dataset = load_dataset("json", data_files=args.eval_dataset_path, field="examples", split="train")
    # Keep only the first num_samples examples (or all of them if fewer).
    eval_dataset = eval_dataset.select(range(min(args.num_samples, len(eval_dataset))))
    # split with dblookup

    # Stop generation when sampling end or start tokens
    # NOTE(review): this list is discarded by the reassignment just below, so
    # these ids are never passed to SamplingParams — confirm which is intended.
    stop_token_ids = [tokenizer.eos_token_id, tokenizer.bos_token_id, tokenizer.convert_tokens_to_ids("<s>")]

    # Defaults used when database lookup is disabled.
    stop_token_ids = []
    include_stop_str_in_output = True
    skip_special_tokens = True
    logit_bias = {}

    if args.enable_dblookup:
        # Stop on the last token id of " ->" and of "->", and on <|db_return|>.
        stop_token_ids += [tokenizer.encode(" ->")[-1], tokenizer.encode("->")[-1], tokenizer.convert_tokens_to_ids("<|db_return|>")]
        include_stop_str_in_output = False
        skip_special_tokens = False

        entity_token_id = tokenizer.convert_tokens_to_ids("<|db_entity|>")
        relationship_token_id = tokenizer.convert_tokens_to_ids("<|db_relationship|>")
        return_token_id = tokenizer.convert_tokens_to_ids("<|db_return|>")
        end_token_id = tokenizer.convert_tokens_to_ids("<|db_end|>")
        # Positive logit biases for the DB special tokens (presumably to make
        # them more likely to be sampled — confirm against the vLLM docs).
        logit_bias = {entity_token_id: 5.0, relationship_token_id: 2.0, return_token_id: 2.0, end_token_id: 2.0}

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_new_tokens,
        seed=args.seed,
        repetition_penalty = args.repetition_penalty,
        stop_token_ids=stop_token_ids,
        include_stop_str_in_output=include_stop_str_in_output,
        skip_special_tokens=skip_special_tokens,
        logit_bias=logit_bias,
        # spaces_between_special_tokens=spaces_between_special
    )

    # def generate_dblookup(annotation_text):
    def generate_dblookup(example):
        """Rebuild one example's text with model-generated database lookups.

        The example's "annotated_text" is split into the segments outside its
        [dblookup(...)] annotations. After every segment except the last, the
        model is prompted with the text so far plus "<|db_entity|>", and the
        generated query is sent to the database. The lookup is appended to the
        prompt only when the database returned usable data.

        Args:
            example: A dataset row with an "annotated_text" field.

        Returns:
            The same row with the rebuilt text stored under "dynamic_text".
        """
        global success

        annotation_text = example["annotated_text"]
        # The annotations are always parsed in the [dblookup()] format here,
        # independent of USE_SPECIAL_TOKENS.
        segments = extract_non_dblookup_segments(annotation_text, use_special_tokens=False)
        prompt = ""
        for i, seg in enumerate(segments):
            try:
                prompt += seg

                # Nothing follows the final segment, so no lookup is generated.
                if i == len(segments) - 1:
                    break

                print(f"Prompt: {prompt}")
                response, is_finished = generate_response(prompt + "<|db_entity|>")
                print(f"Response: {response}")

                # Generation ended: keep the raw response and stop early,
                # dropping any remaining segments.
                if is_finished:
                    prompt += response
                    break

                # NOTE(review): the retriever is re-initialized on every
                # iteration; whether that is required is not visible here.
                db_manager.init_topk_retriever(default_threshold=args.threshold)
                return_value = db_manager.retrieve_from_database("<|db_entity|>" + response)
                print(f"Return value: {return_value}")

                # A "No relevant data found" or "unknown" result is treated as
                # a miss; anything else counts as a successful lookup.
                if "No relevant data found" in return_value or "unknown" in return_value:
                    return_value = None
                else:
                    success += 1

                if USE_SPECIAL_TOKENS:
                    # Only add the dblookup call if there is a successful return value;
                    # otherwise the prompt is just stripped of surrounding whitespace.
                    prompt = prompt + "<|db_entity|>" + response + return_value + "<|db_end|>" if return_value else prompt.strip()

            except DatabaseLookupError as e:
                # A failed lookup is logged and the loop moves on to the next
                # segment without adding a lookup for this one.
                print(f"Database lookup error: {e}")

        print("*"*20)
        print(f"[Model]: {prompt}")
        # return prompt
        example["dynamic_text"] = prompt
        return example

    eval_dataset = eval_dataset.map(generate_dblookup)

    # The output is one JSON document with an "examples" field (the field the
    # resume check above reads back), not line-delimited JSON, despite the
    # .jsonl file name.
    with open(logging_file, "w") as f:
        json.dump({
            "examples": eval_dataset.to_list(),
            "metadata": {
                "source_datasets": [args.eval_dataset_path],
                "last_modified": datetime.now().isoformat()
            }
        }, f, indent=4)

    print(f"saved to {logging_file}")

    print("Failure Statistics:")
    print(DatabaseLookupError.get_failure_statistics())
    print(f"Success times: {success}")

    # Success rate = successes / (successes + summed failure counts), or 0
    # when no lookup was attempted at all.
    num_failures = sum(DatabaseLookupError.get_failure_statistics().values())
    success_rate = success / (success + num_failures) if success + num_failures > 0 else 0
    print(f"Success rate: {success_rate}")