import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 
import argparse  
import torch  
import pandas as pd  
from transformers import AutoModel, AutoTokenizer  
from datasets import Dataset, load_from_disk  
import numpy as np  
import pdb  
import json
from tabulate import tabulate
from prompts.bird_prompt_combine.prompt.bird_all import rag_phi_prompt_strict, rag_phi_sql_3_shot, rag_phi_sql_3_shot_cot, rag_phi_sql_3_shot_code
  
def initialize_model_and_tokenizer(checkpoint, device):
    """
    Load the tokenizer and the model for a checkpoint and move the model to a device.

    Args:
    - checkpoint: Checkpoint name or path handed to ``from_pretrained``.
    - device: Target passed to the model's ``.to(...)`` call.

    Returns:
    - A ``(tokenizer, model)`` tuple.
    """
    # Both loaders are called with trust_remote_code=True.
    load_options = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **load_options)
    model = AutoModel.from_pretrained(checkpoint, **load_options)
    return tokenizer, model.to(device)

def dataframe_to_markdown(df):
    """
    Render a Pandas DataFrame as a pipe-style Markdown table.

    Parameters:
    - df (pd.DataFrame): The DataFrame to render.

    Returns:
    - str: The table text produced by ``tabulate``.
    """
    # Column names are used as headers and the DataFrame index is not shown.
    table_options = dict(headers='keys', tablefmt='pipe', showindex=False)
    return tabulate(df, **table_options)
  
def get_embeddings(text_or_texts, tokenizer, model, device):
    """
    Embed a single text or a list of texts with the given tokenizer/model pair.

    Args:
    - text_or_texts: A string or a list of strings to embed.
    - tokenizer: Callable that tokenizes a list of strings; its result must support ``.to(device)``.
    - model: Callable invoked with the tokenized input as keyword arguments.
    - device: Device the tokenized input is moved to before the forward pass.

    Returns:
    - ``model(...)[0]``; when that value is 1-D, a leading batch dimension is added.

    NOTE(review): if the model returns a (batch_size, embedding_dim) tensor directly,
    ``[0]`` keeps only the first row - confirm before passing more than one text.
    """
    # The tokenizer always receives a list, so a lone string is wrapped first.
    batch = [text_or_texts] if isinstance(text_or_texts, str) else text_or_texts

    # Gradients are not needed for inference.
    with torch.no_grad():
        model_inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        first_output = model(**model_inputs)[0]
        embeddings = first_output.unsqueeze(0) if first_output.dim() == 1 else first_output

    # Drop the tokenized batch and release cached GPU memory before returning.
    del model_inputs
    torch.cuda.empty_cache()
    return embeddings

  
def load_dataset_from_jsonl(file_path):
    """
    Read a JSON Lines file into a ``Dataset``.

    Args:
    - file_path: Path of the JSONL file to read.

    Returns:
    - The ``Dataset`` built from the parsed DataFrame.
    """
    # lines=True makes pandas read the file as one JSON object per line.
    return Dataset.from_pandas(pd.read_json(file_path, lines=True))
  
def compute_and_store_embeddings(dataset, tokenizer, model, device):
    """
    Add an "embeddings" column computed from each row's "function_code" field.

    Args:
    - dataset: Object exposing ``map``; every row must provide "function_code".
    - tokenizer, model, device: Forwarded unchanged to ``get_embeddings``.

    Returns:
    - The result of ``dataset.map`` with the extra "embeddings" column.
    """
    def attach_embedding(row):
        # One text per call, so only the first row of the result is kept.
        vector = get_embeddings(row["function_code"], tokenizer, model, device)
        return {"embeddings": vector.detach().cpu().numpy()[0]}

    return dataset.map(attach_embedding)

def release_gpu_memory():
    """
    Release cached CUDA memory and wait for outstanding CUDA work to finish.

    Does nothing when CUDA has not been initialized in this process (for example when
    the script runs with ``--device cpu``): ``torch.cuda.synchronize()`` would otherwise
    try to initialize CUDA and raise on machines without a usable GPU.
    """
    if not torch.cuda.is_initialized():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def wrap_up_rag_prompt(prompt_template, rag_results):
    """
    Fill the '[[case_study]]' placeholder of a template with numbered case studies.

    Args:
    - prompt_template: Template text containing the '[[case_study]]' placeholder.
    - rag_results: Iterable of dicts; each one's 'case_study' value (default '') is rendered.

    Returns:
    - The template with the placeholder replaced by the concatenated case-study sections.
    """
    # Sections are numbered from 1 in the order the results are given.
    sections = [
        f"# Case Study {position}\n{entry.get('case_study', '')}\n\n"
        for position, entry in enumerate(rag_results, start=1)
    ]
    return prompt_template.replace('[[case_study]]', "".join(sections))

def wrap_up_rag_few_shot_prompt(prompt_template, rag_results):
    """
    Fill the '[[few_shot_examples]]' placeholder with numbered examples built from 'final_code'.

    Args:
    - prompt_template: Template text containing the '[[few_shot_examples]]' placeholder.
    - rag_results: Iterable of dicts; the keys 'final_code', 'question', 'evidence' and
      'data_overview' are read, each defaulting to ''.

    Returns:
    - The template with the placeholder replaced by the concatenated examples.
    """
    rendered = []
    for position, entry in enumerate(rag_results, start=1):
        sql_text = str(entry.get('final_code', ''))
        overview = entry.get('data_overview', '')
        # The evidence is appended to the question, separated by one space.
        asked = " ".join((entry.get('question', ''), entry.get('evidence', '')))
        body = (
            f"# The schema and data overviews:\n{overview}\n...\n\n"
            f"# Question: {asked}\n\n"
            f"# SQLite:\n```sqlite\n{sql_text}\n```"
        )
        rendered.append(f"# Example {position}\n{body}\n\n")

    return prompt_template.replace('[[few_shot_examples]]', "".join(rendered))

def wrap_up_rag_few_shot_gt(prompt_template, rag_results):
    """
    Fill the '[[few_shot_examples]]' placeholder with numbered examples built from 'sql'.

    Args:
    - prompt_template: Template text containing the '[[few_shot_examples]]' placeholder.
    - rag_results: Iterable of dicts; the keys 'sql', 'question', 'evidence' and
      'data_overview' are read, each defaulting to ''.

    Returns:
    - The template with the placeholder replaced by the concatenated examples.
    """
    rendered = []
    for position, entry in enumerate(rag_results, start=1):
        # Only the 'sql' key supplies the query text; a 'final_code' key is not consulted.
        sql_text = str(entry.get('sql', ''))
        overview = entry.get('data_overview', '')
        # The evidence is appended to the question, separated by one space.
        asked = " ".join((entry.get('question', ''), entry.get('evidence', '')))
        body = (
            f"# The schema and data overviews:\n{overview}\n...\n\n"
            f"# Question: {asked}\n\n"
            f"# SQLite:\n```sqlite\n{sql_text}\n```"
        )
        rendered.append(f"# Example {position}\n{body}\n\n")

    return prompt_template.replace('[[few_shot_examples]]', "".join(rendered))
        

def rag_for_dataset(data_dataset_path, output_file_path, query_embeddings_dataset, tokenizer, model, device):
    """
    Augment every record of a JSONL file with its 3 nearest query examples and RAG prompts.

    Each non-blank input line is parsed as JSON, its 'question' is embedded, and the
    nearest entries of ``query_embeddings_dataset`` are looked up in the "embeddings"
    index. These keys are added to the record before it is written to the output file:
    - 'rag_result': retrieved examples with their 'final_code', 'case_study' and 'score'.
    - 'rag_prompt': ``rag_phi_prompt_strict`` filled with the retrieved case studies.
    - 'rag_few_prompt_cot': ``rag_phi_sql_3_shot_cot`` filled with the 'final_code' examples.
    - 'rag_few_prompt_code': ``rag_phi_sql_3_shot_code`` filled with the ground-truth
      'sql' examples.

    Args:
    - data_dataset_path: Path of the input JSONL file; every record needs a 'question' key.
    - output_file_path: Path of the JSONL file to write (overwritten if it exists).
    - query_embeddings_dataset: Dataset with an index named "embeddings" and the columns
      'question_id', 'question', 'evidence', 'data_overview', 'final_code', 'case_study', 'sql'.
    - tokenizer, model, device: Forwarded unchanged to ``get_embeddings``.
    """
    with open(data_dataset_path, 'r', encoding='utf-8') as file, open(output_file_path, 'w', encoding='utf-8') as outfile:
        for line in file:
            # Blank lines (e.g. a trailing newline) are skipped rather than fed to json.loads.
            if not line.strip():
                continue
            data = json.loads(line)
            question = data['question']

            # Embed the question and look up its nearest neighbours in the index.
            question_embedding = get_embeddings([question], tokenizer, model, device).cpu().detach().numpy()
            scores, samples = query_embeddings_dataset.get_nearest_examples("embeddings", question_embedding, k=3)
            del question_embedding
            release_gpu_memory()

            sample_ids = samples['question_id']
            sample_questions = samples['question']
            sample_evidences = samples['evidence']
            sample_data_overviews = samples['data_overview']

            # Examples carrying the 'final_code' and 'case_study' columns.
            rag_result = [
                {
                    'question_id': str(sample_id),
                    'question': str(sample_question),
                    'evidence': str(sample_evidence),
                    'data_overview': str(sample_data_overview),
                    'final_code': str(sample_final_code),
                    'case_study': str(sample_case_study),
                    'score': str(sample_score),
                }
                for sample_id, sample_question, sample_evidence, sample_data_overview,
                sample_final_code, sample_case_study, sample_score in zip(
                    sample_ids, sample_questions, sample_evidences, sample_data_overviews,
                    samples['final_code'], samples['case_study'], scores,
                )
            ]

            # Ground-truth examples. The SQL is stored under 'sql' because that is the key
            # wrap_up_rag_few_shot_gt reads; handing it rag_result instead would render
            # every example with an empty SQL block.
            rag_result_gt = [
                {
                    'question_id': str(sample_id),
                    'question': str(sample_question),
                    'evidence': str(sample_evidence),
                    'data_overview': str(sample_data_overview),
                    'sql': str(sample_sql),
                    'score': str(sample_score),
                }
                for sample_id, sample_question, sample_evidence, sample_data_overview,
                sample_sql, sample_score in zip(
                    sample_ids, sample_questions, sample_evidences, sample_data_overviews,
                    samples['sql'], scores,
                )
            ]

            data['rag_result'] = rag_result
            data['rag_prompt'] = wrap_up_rag_prompt(prompt_template=rag_phi_prompt_strict, rag_results=rag_result)
            data['rag_few_prompt_cot'] = wrap_up_rag_few_shot_prompt(prompt_template=rag_phi_sql_3_shot_cot, rag_results=rag_result)
            data['rag_few_prompt_code'] = wrap_up_rag_few_shot_gt(prompt_template=rag_phi_sql_3_shot_code, rag_results=rag_result_gt)

            # One JSON object per output line, keeping non-ASCII characters readable.
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
            release_gpu_memory()  # Clear GPU memory
            
  
def main(args):
    """
    Run the pipeline: load the model, index the query dataset, write the RAG output.

    Args:
    - args: Parsed arguments providing ``device``, ``checkpoint``, ``query_dataset_path``,
      ``data_dataset_path`` and ``output_file_path``.
    """
    device = torch.device(args.device)
    tokenizer, model = initialize_model_and_tokenizer(args.checkpoint, device)

    def embed_question(row):
        # One question per call, so only the first row of the result is kept.
        vector = get_embeddings(row["question"], tokenizer, model, device)
        return {"embeddings": vector.detach().cpu().numpy()[0]}

    # Embed every query question, then index the vectors for nearest-neighbour search.
    indexed_queries = load_dataset_from_jsonl(args.query_dataset_path).map(embed_question)
    indexed_queries.add_faiss_index(column="embeddings")

    rag_for_dataset(
        data_dataset_path=args.data_dataset_path,
        output_file_path=args.output_file_path,
        query_embeddings_dataset=indexed_queries,
        tokenizer=tokenizer,
        model=model,
        device=device,
    )
  
    
  
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    parser = argparse.ArgumentParser(description="Process datasets and compute embeddings.")
    parser.add_argument("--checkpoint", type=str, default="Salesforce/codet5p-110m-embedding", help="Model checkpoint for embeddings.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on, e.g., 'cuda' or 'cpu'.")
    parser.add_argument("--data_dataset_path", type=str, required=True, help="Path to the data JSONL file.")
    parser.add_argument("--query_dataset_path", type=str, required=True, help="Path to the query JSONL file.")
    # The help text used to repeat the query-file description; this argument is the file written to.
    parser.add_argument("--output_file_path", type=str, required=True, help="Path to the output JSONL file.")

    args = parser.parse_args()
    main(args)
