import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 
import argparse  
import torch  
import pandas as pd  
from transformers import AutoModel, AutoTokenizer  
from datasets import Dataset, load_from_disk  
import numpy as np  
import pdb  
import json
from tabulate import tabulate
from prompts.bird_prompt_combine.prompt.bird_all import rag_phi_prompt_v1, rag_phi_sql_3_shot, rag_phi_sql_3_shot_cot, rag_phi_sql_3_shot_code
from prompts.wikitq_prompt_combine.prompt.wikitq_all import rag_phi_prompt_v1,rag_llama_prompt_v1, baseline_phi_few_shot_rag, baseline_phi_few_shot_rag_distill, baseline_phi_few_shot_rag_gen_distill
from code_utils import remove_comments

def initialize_model_and_tokenizer(checkpoint, device):
    """Load the tokenizer and model for ``checkpoint``; the model is moved to ``device``.

    Both objects are loaded with ``trust_remote_code=True`` so checkpoints
    that ship their own modeling code can be used.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    remote_ok = {"trust_remote_code": True}
    tok = AutoTokenizer.from_pretrained(checkpoint, **remote_ok)
    mdl = AutoModel.from_pretrained(checkpoint, **remote_ok)
    mdl = mdl.to(device)
    return tok, mdl

def dataframe_to_markdown(df):
    """Render ``df`` as a pipe-style Markdown table.

    Column names form the header row; the DataFrame index is not emitted.

    Args:
        df (pd.DataFrame): Table to render.

    Returns:
        str: The Markdown produced by ``tabulate``.
    """
    return tabulate(
        df,
        headers='keys',
        tablefmt='pipe',
        showindex=False,
    )
  
def get_embeddings(text_or_texts, tokenizer, model, device):
    """Embed a single text or a list of texts with ``model``.

    Args:
        text_or_texts: A string or a list of strings.
        tokenizer: Callable tokenizer invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``; its result must support
            ``.to(device)``.
        model: Model invoked as ``model(**encoded_input)``.
        device: Device the tokenized inputs are moved to.

    Returns:
        A tensor whose first dimension indexes the input texts. If the model
        returns a bare tensor it is used as-is; if it returns a tuple-like
        output (tuple / ``ModelOutput``), its first element is used. A 1-D
        result gets a leading batch dimension added.
    """
    # The tokenizer expects a list, so wrap a lone string.
    texts_to_process = [text_or_texts] if isinstance(text_or_texts, str) else text_or_texts

    with torch.no_grad():  # inference only; no autograd graph needed
        encoded_input = tokenizer(
            texts_to_process, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        outputs = model(**encoded_input)
        # A model may return the embedding tensor directly (batch first).
        # Indexing that with [0] would keep only the first row and silently
        # drop the rest of the batch, so only unwrap tuple-like outputs.
        embeddings = outputs if isinstance(outputs, torch.Tensor) else outputs[0]
        if embeddings.dim() == 1:
            embeddings = embeddings.unsqueeze(0)  # add the missing batch dimension

    del encoded_input, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return freed cached blocks to the GPU
    return embeddings

  
def load_dataset_from_jsonl(file_path):
    """Read a JSON Lines file into a ``datasets.Dataset``.

    The file is parsed by pandas (``lines=True``) and the resulting frame
    is converted with ``Dataset.from_pandas``.
    """
    return Dataset.from_pandas(pd.read_json(file_path, lines=True))
  
def compute_and_store_embeddings(dataset, tokenizer, model, device):
    """Return the result of ``dataset.map`` with an added ``embeddings`` column.

    Each row's ``function_code`` text is embedded on its own with
    ``get_embeddings``; the first (only) row of the result is stored as a
    NumPy array.
    """
    def _embed_row(row):
        vector = get_embeddings(row["function_code"], tokenizer, model, device)
        return {"embeddings": vector.detach().cpu().numpy()[0]}

    return dataset.map(_embed_row)

def release_gpu_memory():
    """Release PyTorch's cached CUDA memory; a no-op when CUDA is unavailable.

    ``torch.cuda.synchronize()`` raises on machines without a usable CUDA
    device, which would crash runs started with ``--device cpu``; hence the
    availability guard. Synchronizing first lets pending kernels finish so
    the blocks they used can actually be released by ``empty_cache``.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


def wrap_up_rag_prompt(prompt_template, rag_results):
    """Substitute retrieved case studies into the ``[[case_study]]`` placeholder.

    Every result yields a section ``# Case Study N`` (numbered from 1)
    followed by its ``case_study`` value (empty string when the key is
    absent) and a blank line.

    Args:
        prompt_template (str): Template containing ``[[case_study]]``.
        rag_results (iterable of dict): Retrieved neighbours.

    Returns:
        str: The template with every ``[[case_study]]`` occurrence replaced.
    """
    sections = [
        f"# Case Study {number}\n{result.get('case_study', '')}\n\n"
        for number, result in enumerate(rag_results, start=1)
    ]
    return prompt_template.replace('[[case_study]]', "".join(sections))

def wrap_up_rag_few_shot_prompt(prompt_template, rag_results):
    """Substitute retrieved examples into the ``[[few_shot_examples]]`` placeholder.

    Each result is rendered as ``# Example N`` (numbered from 1) followed by
    its data path, data overview, question and ``final_code`` (converted with
    ``str``) inside a python code fence. Missing keys render as empty strings.

    Args:
        prompt_template (str): Template containing ``[[few_shot_examples]]``.
        rag_results (iterable of dict): Retrieved neighbours.

    Returns:
        str: The filled-in prompt.
    """
    def _render(number, result):
        code = str(result.get('final_code', ''))
        return (
            f"# Example {number}\n"
            f"# Data Overview at the path {result.get('data_path', '')} (first ten rows):\n"
            f"## first ten rows:\n{result.get('data_overview', '')}\n...\n\n"
            f"# Question: {result.get('question', '')}\n\n"
            f"#Code:\n```python\n{code}\n```"
            "\n\n"
        )

    examples = "".join(_render(n, r) for n, r in enumerate(rag_results, start=1))
    return prompt_template.replace('[[few_shot_examples]]', examples)



def wrap_up_rag_few_shot_gen_prompt(prompt_template, rag_results):
    """Fill ``[[few_shot_examples]]`` with retrieved examples whose code has comments stripped.

    Each result is rendered as ``# Example N`` (numbered from 1) followed by
    its data path, data overview, question and ``final_code`` inside a python
    code fence. ``final_code`` is converted with ``str`` and, when non-empty,
    passed through ``remove_comments``. Missing keys render as empty strings.

    Args:
        prompt_template (str): Template containing ``[[few_shot_examples]]``.
        rag_results (iterable of dict): Retrieved neighbours.

    Returns:
        str: The filled-in prompt.
    """
    rendered = []
    for idx, result in enumerate(rag_results, start=1):
        code = str(result.get('final_code', ''))
        if code:  # an empty string is left as-is
            code = remove_comments(code)
        rendered.append(
            f"# Example {idx}\n"
            f"# Data Overview at the path {result.get('data_path', '')} (first ten rows):\n"
            f"## first ten rows:\n{result.get('data_overview', '')}\n...\n\n"
            f"# Question: {result.get('question', '')}\n\n"
            f"#Code:\n```python\n{code}\n```"
            "\n\n"
        )
    return prompt_template.replace('[[few_shot_examples]]', "".join(rendered))
     

def rag_for_dataset(data_dataset_path, output_file_path, query_embeddings_dataset, tokenizer, model, device, top_k=3):
    """Retrieve nearest examples for every question and write RAG-augmented records.

    Reads ``data_dataset_path`` as JSON Lines (blank lines are skipped). For
    each record it embeds ``question``, queries the FAISS index on the
    ``embeddings`` column of ``query_embeddings_dataset`` for ``top_k``
    neighbours, and writes the record to ``output_file_path`` (JSON Lines)
    with four added keys:

    * ``rag_result``: list of retrieved rows with ``question_id``,
      ``question``, ``data_path``, ``final_code``, ``case_study``,
      ``data_overview`` and the retrieval ``score`` (as a string);
    * ``rag_prompt``: ``rag_llama_prompt_v1`` filled with the case studies;
    * ``rag_few_shot_prompt``: ``baseline_phi_few_shot_rag_distill`` filled
      with the retrieved examples;
    * ``rag_few_shot_gen_prompt``: ``baseline_phi_few_shot_rag_gen_distill``
      filled with the retrieved examples, code comments stripped.

    Args:
        top_k (int): Number of neighbours to retrieve per question (default 3).
    """
    # Columns copied from each retrieved neighbour, in output order.
    retrieved_fields = ("question_id", "question", "data_path", "final_code", "case_study", "data_overview")

    with open(data_dataset_path, 'r', encoding='utf-8') as file, open(output_file_path, 'w', encoding='utf-8') as outfile:
        for line in file:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL input
            data = json.loads(line)
            question = data['question']

            # Embed the question and look up its nearest neighbours.
            question_embedding = get_embeddings([question], tokenizer, model, device).cpu().detach().numpy()
            scores, samples = query_embeddings_dataset.get_nearest_examples("embeddings", question_embedding, k=top_k)
            del question_embedding

            columns = [samples[field] for field in retrieved_fields]
            rag_result = [
                {**dict(zip(retrieved_fields, values)), 'score': str(score)}
                for *values, score in zip(*columns, scores)
            ]

            data['rag_result'] = rag_result
            data['rag_prompt'] = wrap_up_rag_prompt(prompt_template=rag_llama_prompt_v1, rag_results=rag_result)
            data['rag_few_shot_prompt'] = wrap_up_rag_few_shot_prompt(prompt_template=baseline_phi_few_shot_rag_distill, rag_results=rag_result)
            # The "gen" prompt uses its dedicated builder, which strips
            # comments from the example code, matching the key and template.
            data['rag_few_shot_gen_prompt'] = wrap_up_rag_few_shot_gen_prompt(prompt_template=baseline_phi_few_shot_rag_gen_distill, rag_results=rag_result)

            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
            release_gpu_memory()
            
  
def main(args):
    """Index the query set with FAISS, then run retrieval over the data set.

    Loads the embedding model on ``args.device``, reads
    ``args.query_dataset_path`` (JSON Lines), embeds each row's ``question``
    into an ``embeddings`` column, builds a FAISS index on that column, and
    passes everything to ``rag_for_dataset``, which writes augmented records
    for ``args.data_dataset_path`` to ``args.output_file_path``.
    """
    device = torch.device(args.device)
    tokenizer, model = initialize_model_and_tokenizer(args.checkpoint, device)

    def embed_question(row):
        vec = get_embeddings(row["question"], tokenizer, model, device)
        return {"embeddings": vec.detach().cpu().numpy()[0]}

    query_dataset = load_dataset_from_jsonl(args.query_dataset_path)
    indexed = query_dataset.map(embed_question)
    indexed.add_faiss_index(column="embeddings")

    rag_for_dataset(
        data_dataset_path=args.data_dataset_path,
        output_file_path=args.output_file_path,
        query_embeddings_dataset=indexed,
        tokenizer=tokenizer,
        model=model,
        device=device,
    )
  
    
  
if __name__ == "__main__":
    # Command-line entry point: parse paths/model settings and run the pipeline.
    parser = argparse.ArgumentParser(description="Process datasets and compute embeddings.")
    parser.add_argument("--checkpoint", type=str, default="Salesforce/codet5p-110m-embedding", help="Model checkpoint for embeddings.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on, e.g., 'cuda' or 'cpu'.")
    parser.add_argument("--data_dataset_path", type=str, required=True, help="Path to the data JSONL file.")
    parser.add_argument("--query_dataset_path", type=str, required=True, help="Path to the query JSONL file.")
    # This is where augmented records are written, not the query input.
    parser.add_argument("--output_file_path", type=str, required=True, help="Path to the output JSONL file.")

    args = parser.parse_args()
    main(args)
