
from genai import Client, Credentials
import datetime
import pytz
import logging
import json
import copy
from dotenv import load_dotenv
from genai.text.generation import CreateExecutionOptions
from genai.schema import (
    DecodingMethod,
    LengthPenalty,
    ModerationParameters,
    ModerationStigma,
    TextGenerationParameters,
    TextGenerationReturnOptions,
)

# tqdm draws the per-prompt progress bar in the generation loops below;
# print an install hint and re-raise so the script stops if it is missing.
try:
    from tqdm.auto import tqdm
except ImportError:
    print("Please install tqdm to run this example.")
    raise

# Load variables from a local .env file (python-dotenv), then build the shared
# API client from credentials taken from the environment.
load_dotenv()
client = Client(credentials=Credentials.from_env())

# Send DEBUG-level records of the "bampy" logger to the file bampy.log.
# NOTE(review): the SDK is imported as `genai` above; confirm it actually logs
# under the name "bampy", otherwise this file will stay empty.
logging.getLogger("bampy").setLevel(logging.DEBUG)
fh = logging.FileHandler('bampy.log')
fh.setLevel(logging.DEBUG)
logging.getLogger("bampy").addHandler(fh)

# parameters = TextGenerationParameters(
#     max_new_tokens=250,
#     min_new_tokens=1,
#     random_seed=5,
#     temperature=0.05,
#     top_k=50,
#     top_p=1.0,
#     decoding_method=DecodingMethod.GREEDY,
#     length_penalty=LengthPenalty(start_index=5, decay_factor=1.5),
#     return_options=TextGenerationReturnOptions(
#         # if ordered is False, you can use return_options to retrieve the corresponding prompt
#         input_text=True,
#     ),
# )

# Generation settings shared by every model/prompt combination in this script:
# greedy decoding, between 1 and 250 newly generated tokens.
parameters = TextGenerationParameters(
    max_new_tokens=250,
    min_new_tokens=1,
    decoding_method=DecodingMethod.GREEDY,
    # length_penalty=LengthPenalty(start_index=5, decay_factor=1.5),
    return_options=TextGenerationReturnOptions(
        # Ask the service to echo the prompt back; the generation loops store it
        # as result.input_text next to each answer. (If ordered is False, this is
        # also how a response can be matched to its prompt.)
        input_text=True,
    ),
)

def _build_prompts(testingUnits):
    """Build the five prompt variants for every testing unit.

    Variants (keys of the returned dict, in this order):
      * ``prompt_0`` - question only, answered from the model's own knowledge.
      * ``prompt_1`` - question plus ``question1_context1``.
      * ``prompt_2`` - question plus ``question1_context2``.
      * ``prompt_3`` - question plus both contexts (only context1 when the
        unit's ``samepassage`` value is ``'Same'``).
      * ``prompt_4`` - same contexts as ``prompt_3`` but with an instruction
        that asks the model to reflect contradictory information.

    Args:
        testingUnits: list of dicts with the keys ``question1``,
            ``question1_context1``, ``question1_context2`` and ``samepassage``.

    Returns:
        dict mapping the variant name to a list of prompt strings, one per
        unit, in the same order as ``testingUnits``.

    Raises:
        KeyError: if a unit lacks one of the keys listed above.
    """
    # The instruction texts are part of the experiment; keep them verbatim.
    instr_knowledge = """Provide a short answer for the following question based on your internal knowledge.\n\n"""
    instr_context = """Provide a short answer for the following question based on the given context.\n\n"""
    instr_context_careful = """Provide a short answer for the following question based on the given context. Carefully investigate the given context and provide a concise response that reflects the comprehensive view of the context, even if the answer contains contradictory information reflecting the heterogeneous nature of the context. \n\n"""

    prompt_all = {
        'prompt_0': [],
        'prompt_1': [],
        'prompt_2': [],
        'prompt_3': [],
        'prompt_4': [],
    }

    for unit in testingUnits:
        question_line = f"""Question: {unit['question1']}\n"""
        context1_line = f"""Context: {unit['question1_context1']}\n"""
        context2_line = f"""Context: {unit['question1_context2']}\n"""
        if unit['samepassage'] == 'Same':
            # Both pieces of evidence come from the same passage: show it once.
            merged_context = context1_line
        else:
            # The separator ("\n ") between the two contexts is kept exactly as
            # in the original prompts so results stay comparable.
            merged_context = f"""Context: {unit['question1_context1']}\n {unit['question1_context2']}\n"""

        prompt_all['prompt_0'].append(instr_knowledge + question_line)
        prompt_all['prompt_1'].append(instr_context + question_line + context1_line)
        prompt_all['prompt_2'].append(instr_context + question_line + context2_line)
        prompt_all['prompt_3'].append(instr_context + question_line + merged_context)
        prompt_all['prompt_4'].append(instr_context_careful + question_line + merged_context)

    return prompt_all


def _store_gold_answer(record, prompt_template):
    """Write the reference answer matching ``prompt_template`` into ``record``.

    ``prompt_1`` only sees context1, so its gold answer is answer1;
    ``prompt_2`` only sees context2, so its gold answer is answer2; the
    remaining variants (0, 3, 4) get both answers joined with ``' | '``.
    """
    gold_key = 'goldAnswer_' + prompt_template
    if prompt_template in ('prompt_0', 'prompt_3', 'prompt_4'):
        record[gold_key] = record['question1_answer1'] + ' | ' + record['question1_answer2']
    elif prompt_template == 'prompt_1':
        record[gold_key] = record['question1_answer1']
    elif prompt_template == 'prompt_2':
        record[gold_key] = record['question1_answer2']


def _generate_and_save(testingUnits, models, output_dir):
    """Query every model with every prompt variant and dump one JSON per model.

    For each model a deep copy of ``testingUnits`` is enriched with
    ``ModelAnswer_<variant>``, ``ModelInput_<variant>`` and
    ``goldAnswer_<variant>`` and written to
    ``<output_dir>/LLM_Answers_Model_<model name>.json``.
    """
    import os  # local import: the module header does not import os

    prompt_all = _build_prompts(testingUnits)

    for model in models:
        # Model ids look like "<vendor>/<name>"; the file name uses <name> only.
        model_name = str(model).split('/')[1]
        # Work on a copy so results of one model never leak into the next.
        model_results = copy.deepcopy(testingUnits)

        for prompt_template, prompts in prompt_all.items():
            print(f"Testing model {model_name} with prompt {prompt_template}")
            responses = client.text.generation.create(
                model_id=model,
                inputs=prompts,
                # ordered=True: responses arrive in prompt order, so the
                # enumerate index below identifies the testing unit.
                execution_options=CreateExecutionOptions(ordered=True),
                parameters=parameters,
            )
            for idx, response in tqdm(
                enumerate(responses),
                total=len(prompts),
                desc="Progress",
                unit="input",
            ):
                result = response.results[0]
                record = model_results[idx]
                record["ModelAnswer" + "_" + prompt_template] = result.generated_text
                record["ModelInput" + "_" + prompt_template] = result.input_text
                _store_gold_answer(record, prompt_template)

        # Save the prompting results of this model to its own JSON file.
        output_path = os.path.join(output_dir, 'LLM_Answers_Model_' + model_name + '.json')
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(model_results, f, indent=4)


def generateAnswers_bam_models(testingUnits, models=None, output_dir=None):
    """Run the full model comparison and save the answers of each model.

    Args:
        testingUnits: list of testing-unit dicts as produced by
            ``load_testingdata``.
        models: optional iterable of model ids (``"<vendor>/<name>"``).
            Defaults to the nine models of the original experiment.
        output_dir: optional directory for the result files. Defaults to the
            original ``promptExp/LLM_Answers`` directory.

    Returns:
        None. One ``LLM_Answers_Model_<name>.json`` file is written per model.
    """
    if models is None:
        models = [
            'mistralai/mixtral-8x7b-instruct-v0-1',
            'mistralai/mistral-7b-instruct-v0-2',
            'google/flan-ul2',
            'meta-llama/llama-2-70b-chat',
            'ibm/granite-13b-chat-v2',
            'ibm/granite-13b-instruct-v2',
            'ibm/granite-13b-lab-incubation',
            'meta-llama/llama-2-7b-chat',
            'meta-llama/llama-2-13b-chat',
        ]
    if output_dir is None:
        output_dir = '/Users/yhou/git/wikiEvidenceVeracity/data/annotation/promptExp/LLM_Answers'
    _generate_and_save(testingUnits, models, output_dir)


def generateAnswers_bam_models_for_annotation(testingUnits, models=None, output_dir=None):
    """Run the reduced model set whose answers are saved for annotation.

    Identical to ``generateAnswers_bam_models`` except for the defaults: three
    models and the ``promptExp/LLM_Answers_Annotation`` output directory.

    Args:
        testingUnits: list of testing-unit dicts as produced by
            ``load_testingdata``.
        models: optional iterable of model ids; defaults to the three models
            of the original annotation run.
        output_dir: optional directory for the result files.

    Returns:
        None. One ``LLM_Answers_Model_<name>.json`` file is written per model.
    """
    if models is None:
        models = [
            'mistralai/mistral-7b-instruct-v0-2',
            'mistralai/mixtral-8x7b-instruct-v0-1',
            'meta-llama/llama-2-70b-chat',
        ]
    if output_dir is None:
        output_dir = '/Users/yhou/git/wikiEvidenceVeracity/data/annotation/promptExp/LLM_Answers_Annotation'
    _generate_and_save(testingUnits, models, output_dir)

    
def load_testingdata(path='/Users/yhou/git/wikiEvidenceVeracity/data/annotation/AnnotationResults_all_valid_March_11.json'):
    """Load the annotation file and flatten it into a list of testing units.

    Every annotation yields one unit for ``question1`` and, when the
    annotation also has a ``question2``, a second unit that shares the same
    contexts and metadata but carries question2 and its two answers (stored
    under the ``question1*`` keys so downstream code handles both alike).

    Args:
        path: JSON file holding a list of annotations; each annotation has
            ``title``, ``url`` and an ``annotationResult`` dict. Defaults to
            the March 11 annotation export used by the original script.

    Returns:
        list of dicts with the keys ``question1``, the four
        ``taxonomy-contradictType*`` labels (``"unknown"`` when absent),
        ``samepassage`` (``'Different'`` when absent), ``question1_context1``,
        ``question1_context2``, ``question1_answer1``, ``question1_answer2``,
        ``title``, ``url`` and a running ``question_ID`` starting at 1.

    Raises:
        KeyError: if an annotation lacks a mandatory field.
    """
    # (key in the testing unit, key in the annotation result)
    taxonomy_fields = (
        ('taxonomy-contradictTypeI', 'Contradict type I'),
        ('taxonomy-contradictTypeII', 'Contradict type II'),
        ('taxonomy-contradictTypeIII', 'Contradict type III'),
        ('taxonomy-contradictTypeIV', 'Contradict type IV'),
    )
    # (key in the testing unit, preferred source key, fallback source key)
    context_fields = (
        ('question1_context1', 'paragraphA_information_standalone', 'paragraphA_information'),
        ('question1_context2', 'paragraphB_information_standalone', 'paragraphB_information'),
    )

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(path, encoding='utf-8') as f:
        annotationResults = json.load(f)

    testingUnits = []
    question_ID = 1
    for annotation in annotationResults:
        result = annotation['annotationResult']

        testingUnit = {'question1': result['question1']}
        for unit_key, source_key in taxonomy_fields:
            testingUnit[unit_key] = result.get(source_key, "unknown")
        testingUnit['samepassage'] = result.get('wikitag_label_samepassage', 'Different')

        # Prefer the "standalone" rewrite of a paragraph when the annotator
        # provided one; the raw paragraph is only read when it is missing.
        for unit_key, standalone_key, raw_key in context_fields:
            if standalone_key in result:
                testingUnit[unit_key] = result[standalone_key]
            else:
                testingUnit[unit_key] = result[raw_key]

        testingUnit['question1_answer1'] = result['question1_answer1']
        testingUnit['question1_answer2'] = result['question1_answer2']
        testingUnit['title'] = annotation['title']
        testingUnit['url'] = annotation['url']
        testingUnit['question_ID'] = question_ID
        question_ID += 1

        if 'question2' in result:
            testingUnit2 = copy.deepcopy(testingUnit)
            testingUnit2['question1'] = result['question2']
            testingUnit2['question1_answer1'] = result['question2_answer1']
            testingUnit2['question1_answer2'] = result['question2_answer2']
            testingUnit2['question_ID'] = question_ID
            question_ID += 1
            # NOTE(review): the question2 unit is appended before its
            # question1 unit, so IDs are not monotonic in the returned list.
            # Kept as-is because previously saved results follow this order.
            testingUnits.append(testingUnit2)
        testingUnits.append(testingUnit)
    return testingUnits

def prepareLSAnnotationFile(
    input_path='/Users/yhou/git/wikiEvidenceVeracity/data/annotation/promptExp/testingResult.json',
    output_path='/Users/yhou/git/wikiEvidenceVeracity/data/annotation/promptExp/testingResult_LS.json',
):
    """Rewrite a result file so that no key contains a hyphen.

    For every key of every unit the output holds two entries: the value under
    the key with ``-`` replaced by ``_``, and ``<new key>_oldname`` holding
    the original key so the renaming can be traced back.
    (Presumably "LS" stands for the annotation tool that consumes the file;
    confirm with the author.)

    Args:
        input_path: JSON file containing a list of dicts. Defaults to the
            original ``testingResult.json`` location.
        output_path: destination JSON file. Defaults to the original
            ``testingResult_LS.json`` location.

    Returns:
        None. The converted list is written to ``output_path``.
    """
    with open(input_path, encoding='utf-8') as f:
        testingUnits = json.load(f)

    testingUnits_new = []
    for unit in testingUnits:
        newunit = {}
        for key, value in unit.items():
            # replace() is a no-op for keys without a hyphen, so a single
            # code path covers both renamed and unchanged keys.
            new_key = key.replace('-', '_')
            newunit[new_key] = value
            newunit[new_key + '_oldname'] = key
        testingUnits_new.append(newunit)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(testingUnits_new, f, indent=4)
    
# Script entry point: load the testing units, report how many there are, then
# run the full model comparison. The annotation-only run is kept disabled.
if __name__ == "__main__":  
    testingUnits = load_testingdata()
    print("question size", len(testingUnits))
    generateAnswers_bam_models(testingUnits)
    # generateAnswers_bam_models_for_annotation(testingUnits)
