# %%
import argparse
import sys
sys.path.append('../')
parser = argparse.ArgumentParser()
from evaluate import load
bertscore = load("bertscore")
from Util.MESSAGES import *
from Util.smallModelOutput import *
from Util.ModelOutput import *
from Util.constants import *
from Util.datasetFilePath import *
from Util.MESSAGES import *
from Util.smallModelOutput import *
from Util.ModelOutput import *
from Util.constants import *
from Util.datasetFilePath import *
from Util.modelClasses import *

import nltk
from nltk.translate.meteor_score import meteor_score
# parser.add_argument('--api_key', dest='api_key', type=str, help='Add api_key')
# parser.add_argument('--endpoint', dest='endpoint', type=str, help='Add endpoint')
# parser.add_argument('--fewShot', dest='fewShot', type=bool, help='Add fewShot')
# parser.add_argument('--promptType', dest='promptType', type=str, help='Add promptType')
# Command-line options: each one is a string restricted to a fixed set of choices.
_cli_options = (
    ('COTType', ['default', 'COT', 'smallerModel']),
    ('model', ['GPT4', 'gpt-35-turbo', 'llama-2-7b-chat', 'phi', 'mistral', 'gemini', 'claude']),
    ('dataset', ['gsm8k', 'MATH-Level-1', 'MATH-Level-2', 'MATH-Level-3', 'MATHBENCH', 'JEEBENCH']),
)
for _option_name, _allowed_values in _cli_options:
    parser.add_argument(
        '--' + _option_name,
        dest=_option_name,
        type=str,
        help='Add ' + _option_name,
        choices=_allowed_values,
    )

args = parser.parse_args()

# print (args.product_id)
# %%
import os
# %%
# COT_type = 'default' # 'COT'
# Bind the three run settings taken from the command line
# (example values: COT_type 'default', model "GPT4", dataset 'gsm8k').
COT_type, model, dataset = args.COTType, args.model, args.dataset

# Echo the settings so the run log shows which configuration produced it.
for _label, _value in (("COT Type: ", COT_type), ("Model: ", model), ("Dataset: ", dataset)):
    print(_label, _value)


# %%
import json


# %%
import os
from openai import AzureOpenAI
        
# %%


# %%
# Keys of the reasoning-chain variants that are evaluated, in processing order.
# The evaluation further down labels 'reasoning_chain' with ground truth "yes"
# (chain is correct) and every other entry with ground truth "no".
types = [
    'reasoning_chain',
    'shuffle_reasoning_steps',
    'delete_reasoning_step',
    'shuffle_numerical_values',
    'replace_half_numerical_values',
    'shuffle_operations',
    'random_reasoning_step',
    'llama-2-7b-chat_incorrect_reasoning',
    'mistral_incorrect_reasoning',
    'phi_incorrect_reasoning',
]


# %%

# %%
import time
from tqdm import tqdm
import time
from tqdm import tqdm
# Gemini and Claude are queried through a client object that is built once
# and passed to every request.
if model == 'gemini':
    LMMmodels = GeminiModels(gemini_config_vision)
elif model == 'claude':
    LMMmodels = CLAUDEModels(claude_config)


def _build_messages(record, chain_type):
    """Return the few-shot prompt for one record, or None if no template applies.

    The template is selected by the global ``dataset`` / ``COT_type`` settings;
    ``record[chain_type]`` is the reasoning chain inserted into the prompt.
    """
    chain = record[chain_type]
    if dataset == 'gsm8k':
        if COT_type in ('COT', 'smallerModel'):
            return getGSM8KFewShot_COT(record['question'], chain)
        if COT_type == 'default':
            return getGSM8KFewShot_NonCOT(record['question'], chain)
    elif dataset in ('MATH-Level-1', 'MATH-Level-2', 'MATH-Level-3'):
        if COT_type == 'COT':
            return getMATHFewShot_COT(record['problem'], chain, dataset)
        if COT_type == 'default':
            return getMATHFewShot_NonCOT(record['problem'], chain, dataset)
        if COT_type == 'smallerModel':
            # smallerModel records store the problem text under 'question'.
            return getMATHFewShot_COT(record['question'], chain, dataset)
    elif dataset == 'MATHBENCH':
        if COT_type in ('COT', 'smallerModel'):
            return getMATHBENCHFewShot_COT(record['question'], chain, record['options'])
    elif dataset == 'JEEBENCH':
        return getJEEBENCHFewShot_COT(record['question'], chain, record['options'])
    return None


def _query_model(messages):
    """Send *messages* to the backend selected by the global ``model`` and return its reply.

    Raises:
        ValueError: if ``model`` is not one of the supported backends.
    """
    if model == 'GPT4':
        return outputGPT(GPT_4_API_ENPOINT, GPT_4_API_KEY, GPT_4_MODEL_NAME, messages)
    if model == 'gpt-35-turbo':
        return outputGPT(GPT_35_API_ENDPOINT, GPT_35_API_KEY, GPT_35_MODEL_NAME, messages)
    if model == 'llama-2-7b-chat':
        return outputModel(llama_2_7b_chat_api_key, llama_2_7b_chat_api_endpoint, llama_2_7b_model_name, messages)
    if model == 'phi':
        return outputModel(phi_api_key, phi_api_endpoint, phi_model_name, messages)
    if model == 'mistral':
        return outputMistrael(mistral_api_endpoint, mistral_api_key, messages)
    if model == 'gemini':
        return outputGemini(LMMmodels, messages)
    if model == 'claude':
        return outputClaude(LMMmodels, messages)
    raise ValueError(f"Unsupported model: {model!r}")


# First pass: query the model for every record of every chain type and save
# the raw replies. Progress is written after each record, so an interrupted
# run resumes from the number of records already saved.
for chain_type in types:
    dataPath, count = getDataPath(dataset, chain_type, COT_type)
    print("Type: ", chain_type)

    if dataPath is None:
        continue
    with open(dataPath, 'r') as f:
        data = json.load(f)
    print("Data size:", min(count, len(data)))
    data_type = []

    saveDataPath = getSavePath(dataset, chain_type, model, COT_type, parsed=False)
    if os.path.exists(saveDataPath):
        with open(saveDataPath, 'r') as f:
            data_type = json.load(f)
    if len(data_type) == count:
        # Every requested record already has a saved reply.
        continue

    # NOTE(review): records that lack `chain_type` are skipped without being
    # saved, so len(data_type) can lag behind the true position in `data` and a
    # resumed run may re-query earlier records. Left unchanged here.
    for d in tqdm(data[len(data_type):count]):
        if chain_type not in d:
            continue
        d['type'] = chain_type

        messages = _build_messages(d, chain_type)
        if messages is None:
            # Previously this surfaced as a NameError on `messages`.
            raise ValueError(
                f"No prompt template for dataset {dataset!r} with COTType {COT_type!r}"
            )

        d['output'] = _query_model(messages)
        data_type.append(d)

        # Persist after every reply so that no paid API call is lost.
        with open(saveDataPath, 'w') as f:
            json.dump(data_type, f, indent=4)
        # Short pause between consecutive requests.
        time.sleep(0.7)


# %%
def parseGPTOutput(output):
    """Split a raw model reply into its labelled fields.

    The reply is processed line by line. The first matching rule wins:

    * a line containing "reasoning chain correct" (any case) sets the verdict;
    * a line containing "final answer" (any case) sets the final answer;
    * a line containing "Corrected reasoning chain" (exact case) restarts the
      corrected chain with the text after the label;
    * any other line is appended to the corrected chain on a new line.

    For labelled lines the value is everything after the FIRST colon, so values
    that contain colons themselves ("Step 1: 2+2=4", "3:45") are kept whole.
    A labelled line without a colon leaves its field unchanged.

    Args:
        output: Raw reply text. A non-string value is treated as an empty
            reply. NOTE(review): presumably the model helpers can return None
            on a failed call; confirm against their implementation.

    Returns:
        dict: ``raw_output`` (the argument, untouched), ``reasoning_chain_correct``,
        ``corrected_reasoning_chain`` and ``final_answer`` (all strings).
    """
    raw_output = output
    text = output if isinstance(output, str) else ''

    corrected_reasoning_chain = ''
    final_answer = ''
    reasoning_chain_correct = ''
    for line in text.split('\n'):
        lowered = line.lower()
        # partition never raises; `colon` is '' when the line has no colon.
        _, colon, value = line.partition(':')
        value = value.strip()
        if 'reasoning chain correct' in lowered:
            if colon:
                reasoning_chain_correct = value
        elif 'final answer' in lowered:
            if colon:
                final_answer = value
        elif 'Corrected reasoning chain' in line:
            if colon:
                corrected_reasoning_chain = value
        else:
            corrected_reasoning_chain += '\n' + line

    return {
        'raw_output': raw_output,
        'reasoning_chain_correct': reasoning_chain_correct,
        'corrected_reasoning_chain': corrected_reasoning_chain,
        'final_answer': final_answer
    }

# %%
import json
from  tqdm import tqdm
# Second pass: replace each saved raw reply with its parsed fields and write
# the result to the "parsed" location for the same chain type.
for chain_type in types:
    raw_path = getSavePath(dataset, chain_type, model, COT_type, parsed=False)
    if not os.path.exists(raw_path):
        # No raw replies were saved for this chain type.
        continue

    with open(raw_path, 'r') as raw_file:
        records = json.load(raw_file)

    for record in tqdm(records):
        record['output'] = parseGPTOutput(record['output'])

    parsed_path = getSavePath(dataset, chain_type, model, COT_type, parsed=True)
    with open(parsed_path, 'w') as parsed_file:
        json.dump(records, parsed_file, indent=4)
        

# %%
import re
def clean_answer(answer):
    """
    Normalise an answer so that two answers can be compared for equality.

    Args:
        answer (str): Answer text to normalise.

    Returns:
        str: The lowercased answer with only the characters a-z and 0-9 kept;
        everything else (spaces, punctuation, signs, decimal points) is removed.
    """
    # Collect the runs of kept characters and glue them back together.
    kept_runs = re.findall(r'[a-z0-9]+', answer.lower())
    return ''.join(kept_runs)

# %%
from sklearn.metrics import f1_score, classification_report, accuracy_score, confusion_matrix


# %%
# Aggregate the parsed replies into mistake-identification and final-answer
# metrics, per chain type and over the whole run.
ground_truth = []            # expected verdict ('yes'/'no') per judged record
predicted = []               # model verdict ('yes'/'no') per judged record
ground_truth_answer = []     # cleaned reference answers, all chain types
predicted_answer = []        # cleaned model answers, all chain types
no_of_reasoning_chain_corrected = 0
no_of_answer_correct_after_reasoning_chain_corrected = 0
reasoning_chain = []         # records whose non-original chain was judged "no"
type_specific_results = []
# Results of earlier runs are appended to; start from an empty log when the
# file does not exist yet instead of failing on the first run.
if os.path.exists('results_GPT_Evalutaion_type_specific.json'):
    with open('results_GPT_Evalutaion_type_specific.json', 'r') as f:
        type_specific_results = json.load(f)

# The record fields holding the question and the reference chain depend on the
# dataset and on whether the smallerModel data is used.
if dataset in ('MATH-Level-1', 'MATH-Level-2', 'MATH-Level-3'):
    if COT_type != 'smallerModel':
        question_key, chain_key = 'problem', 'reasoning_chain'
    else:
        question_key, chain_key = 'question', 'solution'
else:
    question_key, chain_key = 'question', 'reasoning_chain'

for chain_type in types:
    saveDataPath = getSavePath(dataset, chain_type, model, COT_type, parsed=True)
    if not os.path.exists(saveDataPath):
        continue
    with open(saveDataPath, 'r') as f:
        gptData = json.load(f)

    # Only the unmodified chain is expected to be judged correct.
    expected_verdict = 'yes' if chain_type == 'reasoning_chain' else 'no'

    number_of_yes = 0
    number_of_no = 0
    other = 0
    type_ground_truth_answer = []
    type_predicted_answer = []
    for d in gptData:
        verdict = d['output']['reasoning_chain_correct'].lower()
        gold_answer = clean_answer(d['answer'])
        model_answer = clean_answer(d['output']['final_answer'])

        if verdict == 'yes':
            number_of_yes += 1
        elif verdict == 'no':
            number_of_no += 1
        else:
            other += 1

        # Replies with no clear yes/no verdict are excluded from the
        # identification metrics.
        if verdict in ('yes', 'no'):
            ground_truth.append(expected_verdict)
            predicted.append(verdict)

        type_ground_truth_answer.append(gold_answer)
        type_predicted_answer.append(model_answer)

        if verdict == 'no':
            no_of_reasoning_chain_corrected += 1
            if gold_answer == model_answer:
                no_of_answer_correct_after_reasoning_chain_corrected += 1
            if chain_type != 'reasoning_chain':
                # Kept for the text-similarity analysis of corrected chains.
                reasoning_chain.append({
                    'question': d[question_key],
                    'reasoning_chain': d[chain_key],
                    'predicted_reasoning_chain': d['output']['corrected_reasoning_chain'],
                    'answer': gold_answer,
                    'predicted_answer': model_answer,
                    'type': chain_type
                })

    ground_truth_answer.extend(type_ground_truth_answer)
    predicted_answer.extend(type_predicted_answer)

    type_answer_f1 = f1_score(type_ground_truth_answer, type_predicted_answer, average='weighted')
    print(f"Type: {chain_type}, Number of Yes: {number_of_yes}, Number of No: {number_of_no}, Other: {other}")
    print(f"Type: {chain_type}, F1 Score: {type_answer_f1}, Accuracy Score: {accuracy_score(type_ground_truth_answer, type_predicted_answer)}")
    type_specific_results.append({
        'dataset' : dataset,
        'COT_Type' : COT_type,
        'model' : model,
        'type' : chain_type,
        'number_of_yes' : number_of_yes,
        'number_of_no' : number_of_no,
        'other' : other,
        'final_answer_f1' : type_answer_f1
    })

print("Correct / Incorrect Result")
print(f"F1 Score: {f1_score(ground_truth, predicted, average='weighted')}")
print(f"Classification Report:\n {classification_report(ground_truth, predicted)}")
print(f"Confusion Matrix:\n {confusion_matrix(ground_truth, predicted)}")
print(f"Accuracy Score: {accuracy_score(ground_truth, predicted)}")
print('\n')
print("Answer Result")
print(f"F1 Score: {f1_score(ground_truth_answer, predicted_answer, average='weighted')}")
print(f"Accuracy Score: {accuracy_score(ground_truth_answer, predicted_answer)}")
print(f"Confusion Matrix:\n {confusion_matrix(ground_truth_answer, predicted_answer)}")
print(f"No of reasoning chain corrected (Identification = no): {no_of_reasoning_chain_corrected}")
print(f"No of correct answer after reasoning chain is changed: {no_of_answer_correct_after_reasoning_chain_corrected}")
# The ratio is undefined when no chain was judged "no"; report nan instead of
# dividing by zero.
if no_of_reasoning_chain_corrected:
    recovery_ratio = no_of_answer_correct_after_reasoning_chain_corrected / no_of_reasoning_chain_corrected
else:
    recovery_ratio = float('nan')
print(f"RATIO: {recovery_ratio}")

number_of_correct_final_answers = sum(
    1 for gold, guess in zip(ground_truth_answer, predicted_answer) if gold == guess
)
number_of_incorrect_final_answers = len(ground_truth_answer) - number_of_correct_final_answers
total_yes = predicted.count('yes')
total_no = predicted.count('no')

with open('results_GPT_Evalutaion_type_specific.json', 'w') as f:
    json.dump(type_specific_results, f, indent=4)
# Same first-run fallback for the run-level results log.
results = []
if os.path.exists('results_GPT_Evaluation.json'):
    with open('results_GPT_Evaluation.json', 'r') as f:
        results = json.load(f)

    

# %%
from rouge_score import rouge_scorer
scorer = rouge_scorer.RougeScorer(['rouge2', 'rougeL'], use_stemmer=True)


def _ratio_or_none(numerator, denominator):
    """Return numerator / denominator, or None when the denominator is zero."""
    if denominator == 0:
        return None
    return numerator / denominator


# Compare each model-corrected chain with its reference chain, split by
# whether the model's final answer matched the reference answer.
global_meteor_correct = []
global_meteor_incorrect = []
global_rouge_correct = []
global_rouge_incorrect = []
global_rouge_sentence_incorrect = []
global_rouge_sentence_correct = []
average_length_correct = 0      # running token total, divided by the count below
average_length_incorrect = 0    # running token total, divided by the count below
# BERTScore accumulation is not performed in this script, so these totals stay
# 0 and the averages recorded below are 0 whenever they are defined.
average_bert_score_correct = 0
average_bert_score_incorrect = 0

# `reasoning_chain` was filled one chain type at a time, in `types` order, so a
# single pass visits the entries in the same order as a per-type scan would.
for reasoning in reasoning_chain:
    reference = reasoning['reasoning_chain']
    candidate = reasoning['predicted_reasoning_chain']
    candidate_tokens = nltk.word_tokenize(candidate)

    # Score each pair once and reuse the values everywhere they are recorded.
    meteor = meteor_score([nltk.word_tokenize(reference)], candidate_tokens)
    rouge_l = scorer.score(reference, candidate)['rougeL'].fmeasure
    sentence_record = {
        'reasoning_chain': reference,
        'predicted_reasoning_chain': candidate,
        'meteor' : meteor,
        'rouge' : rouge_l,
        'model' : model
    }

    if reasoning['answer'] == reasoning['predicted_answer']:
        global_meteor_correct.append(meteor)
        global_rouge_correct.append(rouge_l)
        global_rouge_sentence_correct.append(sentence_record)
        average_length_correct += len(candidate_tokens)
    else:
        global_meteor_incorrect.append(meteor)
        global_rouge_incorrect.append(rouge_l)
        global_rouge_sentence_incorrect.append(sentence_record)
        average_length_incorrect += len(candidate_tokens)

if(len(global_meteor_correct) != 0):
    print(f"Global Meteor Score for Correct Answers: {sum(global_meteor_correct) / len(global_meteor_correct)}")

if(len(global_meteor_incorrect) != 0):
    print(f"Global Meteor Score for Incorrect Answers: {sum(global_meteor_incorrect) / len(global_meteor_incorrect)}")

if(len(global_rouge_correct) != 0):
    print(f"Global Rouge Score for Correct Answers: {sum(global_rouge_correct) / len(global_rouge_correct)}")

if(len(global_rouge_incorrect) != 0):
    print(f"Global Rouge Score for Incorrect Answers: {sum(global_rouge_incorrect) / len(global_rouge_incorrect)}")


# Averages over an empty group are stored as None (JSON null) rather than
# aborting the run with a ZeroDivisionError before the results are written.
results.append({
    'dataset' : dataset,
    'COT_Type' : COT_type,
    'model' : model,
    'Mistake Identification (F1 Score)' : f1_score(ground_truth, predicted, average='weighted'),
    'Final Answer (F1 Score)' : f1_score(ground_truth_answer, predicted_answer, average='weighted'),
    'Recovery' : _ratio_or_none(no_of_answer_correct_after_reasoning_chain_corrected, no_of_reasoning_chain_corrected),
    'Meteor_correct' : _ratio_or_none(sum(global_meteor_correct), len(global_meteor_correct)),
    'Meteor_incorrect' : _ratio_or_none(sum(global_meteor_incorrect), len(global_meteor_incorrect)),
    'Rouge Correct' : _ratio_or_none(sum(global_rouge_correct), len(global_rouge_correct)),
    'Rouge Incorrect' : _ratio_or_none(sum(global_rouge_incorrect), len(global_rouge_incorrect)),
    'Correct SUM' : len(global_meteor_correct),
    'InCorrect SUM' : len(global_meteor_incorrect),
    'bert_score_correct' : _ratio_or_none(average_bert_score_correct, len(global_rouge_sentence_correct)),
    'bert_score_incorrect' : _ratio_or_none(average_bert_score_incorrect, len(global_rouge_sentence_incorrect)),
    'Average Length of Correct Reasoning Chain' : _ratio_or_none(average_length_correct, len(global_meteor_correct)),
    'Average Length of Incorrect Reasoning Chain' : _ratio_or_none(average_length_incorrect, len(global_meteor_incorrect)),
    'Count of Correct Reasoning Chain' : len(global_meteor_correct),
    'Count of Incorrect Reasoning Chain' : len(global_meteor_incorrect),
    'Total Yes' : total_yes,
    'Total No' : total_no,
    'Correct Final Answers' : number_of_correct_final_answers,
    'Incorrect Final Answers' : number_of_incorrect_final_answers
})
with open('results_GPT_Evaluation.json', 'w') as f:
    json.dump(results,f, indent=4)

# sort by rouge score
global_rouge_sentence_correct = sorted(global_rouge_sentence_correct, key=lambda x: x['rouge'], reverse=True)
global_rouge_sentence_incorrect = sorted(global_rouge_sentence_incorrect, key=lambda x: x['rouge'], reverse=True)
if(len(global_rouge_sentence_correct) != 0):
    print(f"Average Length Correct: {average_length_correct / len(global_rouge_sentence_correct)}")
if(len(global_rouge_sentence_incorrect) != 0):
    print(f"Average Length Incorrect: {average_length_incorrect / len(global_rouge_sentence_incorrect)}")
with open('global_rouge_sentence_correct.json', 'w') as f:
    json.dump(global_rouge_sentence_correct, f, indent=4)
with open('global_rouge_sentence_incorrect.json', 'w') as f:
    json.dump(global_rouge_sentence_incorrect, f, indent=4)