# %%
import argparse
import sys
from rouge_score import rouge_scorer
# Shared ROUGE scorer (ROUGE-1, ROUGE-2 and ROUGE-L, with stemming enabled)
# used by the evaluation section at the bottom of this script.
scorer = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeL'], use_stemmer=True)
# Must run before the `Util` imports below. The path is relative, so it is
# resolved against the current working directory, not this file's location;
# presumably the script is meant to be launched from its own directory.
sys.path.append('../')
# Command-line parser; its options are registered after the imports below.
parser = argparse.ArgumentParser()

from Util.MESSAGES import *
from Util.smallModelOutput import *
from Util.ModelOutput import *
from Util.constants import *
from Util.datasetFilePath import *
from Util.modelClasses import *


# Command-line selectors for the run. None of them is required and none has a
# default, so an omitted flag is parsed as None.

# Which reference solution variant is split and fed to the prompt.
parser.add_argument(
    '--COTType',
    dest='COTType',
    type=str,
    help='Add COTType',
    choices=['default', 'COT', 'smallerModel'],
)
# Which model backend is queried.
parser.add_argument(
    '--model',
    dest='model',
    type=str,
    help='Add model',
    choices=[
        'GPT4O',
        'GPT4',
        'gpt-35-turbo',
        'llama-2-7b-chat',
        'phi',
        'mistral',
        'gemini',
        'claude',
    ],
)
# Which benchmark the records come from.
parser.add_argument(
    '--dataset',
    dest='dataset',
    type=str,
    help='Add dataset',
    choices=[
        'gsm8k',
        'MATH-Level-1',
        'MATH-Level-2',
        'MATH-Level-3',
        'MATHBENCH',
        'JEEBENCH',
    ],
)

args = parser.parse_args()

import os
# %%
import json

# Run configuration taken straight from the command line. For interactive
# (cell-by-cell) use these can be replaced with literals such as "GPT4" or
# 'gsm8k'.
COT_type, model, dataset = args.COTType, args.model, args.dataset

# Experiment types iterated by both the generation and the scoring loops.
types = ['memorization']
# %%
import time
from tqdm import tqdm
# Gemini and Claude are called through a client object that is built once
# here; LMMmodels stays undefined for every other backend, which is fine
# because only the gemini/claude branches of the generation loop read it.
if model == 'gemini':
    LMMmodels = GeminiModels(gemini_config_vision)
elif model == 'claude':
    LMMmodels = CLAUDEModels(claude_config)

# ---------------------------------------------------------------------------
# Generation: for every record, hand the model the first half of a reference
# solution (once with the "guided" and once with the "general" instruction)
# and store what it returns next to the withheld second half.
# ---------------------------------------------------------------------------

_MATH_LEVELS = ('MATH-Level-1', 'MATH-Level-2', 'MATH-Level-3')

# For each dataset: COT type -> (record key holding the reference solution,
# record key holding the problem statement). A COT type that is absent from a
# dataset's table is not supported for that dataset.
_GSM8K_FIELDS = {
    'COT': ('reasoning_chain', 'question'),
    'default': ('full_answer', 'question'),
    'smallerModel': ('smallerModel_solution', 'question'),
}
_MATH_FIELDS = {
    'COT': ('reasoning_chain', 'problem'),
    'default': ('solution', 'problem'),
    # The smaller-model records are read through 'question', not 'problem'.
    'smallerModel': ('smallerModel_solution', 'question'),
}
_MATHBENCH_FIELDS = {
    'COT': ('reasoning_chain', 'question'),
    'smallerModel': ('smallerModel_solution', 'question'),
}
_JEEBENCH_FIELDS = {
    # 'default' reads 'reasoning_chain' for JEEBENCH.
    'default': ('reasoning_chain', 'question'),
    'smallerModel': ('smallerModel_solution', 'question'),
}
_FIELDS_BY_DATASET = {
    'gsm8k': _GSM8K_FIELDS,
    'MATHBENCH': _MATHBENCH_FIELDS,
    'JEEBENCH': _JEEBENCH_FIELDS,
}
_FIELDS_BY_DATASET.update({level: _MATH_FIELDS for level in _MATH_LEVELS})


def _split_in_half(solution):
    """Return (first_half, second_half) of a sliceable value, cut at len // 2."""
    middle = len(solution) // 2
    return solution[:middle], solution[middle:]


def _build_memorization_prompts(record, dataset_name, cot_type):
    """Split the record's reference solution and build both prompts.

    Returns a tuple (first_half, second_half, guided_message, general_message).

    Raises:
        ValueError: if cot_type is not supported for dataset_name. Previously
            such a combination surfaced as a NameError on an unassigned
            variable.
        KeyError: if the record lacks a field required by the combination.
    """
    fields = _FIELDS_BY_DATASET.get(dataset_name, {}).get(cot_type)
    if fields is None:
        raise ValueError(
            "Unsupported combination: dataset=%r, COTType=%r"
            % (dataset_name, cot_type)
        )
    solution_key, question_key = fields
    first_half, second_half = _split_in_half(record[solution_key])

    if dataset_name == 'gsm8k':
        def make_prompt(mode):
            return getGSM8KMemorizationInstructions(
                record[question_key], first_half, mode)
    elif dataset_name in _MATH_LEVELS:
        def make_prompt(mode):
            # NOTE(review): the caller overwrites record['type'] with the run
            # type before this is reached, so the value passed here is always
            # the run type (e.g. 'memorization'), never a per-problem 'type'
            # from the input file. Kept as-is to leave prompts unchanged;
            # confirm whether that is intended.
            return getMATHMemorizationInstructions(
                record[question_key], record['level'], record['type'],
                first_half, mode)
    elif dataset_name == 'MATHBENCH':
        def make_prompt(mode):
            return getMATHBENCHMemorizationInstructions(
                record[question_key], record['levels'], record['source'],
                first_half, mode)
    else:  # 'JEEBENCH' -- the only remaining key of _FIELDS_BY_DATASET
        def make_prompt(mode):
            return getJEEBENCHMemorizationInstructions(
                record[question_key], record['subject'], first_half, mode)

    return first_half, second_half, make_prompt('guided'), make_prompt('general')


def _query_model(model_name, message):
    """Send one prompt to the selected backend and return the helper's result.

    Endpoint/key constants and the LMMmodels client are looked up only in the
    branch that is taken, so backends that are not configured are never touched.

    Raises:
        ValueError: if model_name is not a known backend (for example None
            because --model was omitted). Previously such a run silently saved
            records without any model output.
    """
    if model_name == 'GPT4O':
        return outputGPT(GPT_4_O_API_ENPOINT, GPT_4_O_API_KEY, GPT_4_O_MODEL_NAME, message)
    if model_name == 'GPT4':
        return outputGPT(GPT_4_API_ENPOINT, GPT_4_API_KEY, GPT_4_MODEL_NAME, message)
    if model_name == 'gpt-35-turbo':
        return outputGPT(GPT_35_API_ENDPOINT, GPT_35_API_KEY, GPT_35_MODEL_NAME, message)
    if model_name == 'llama-2-7b-chat':
        return outputModel(llama_2_7b_chat_api_key, llama_2_7b_chat_api_endpoint, llama_2_7b_model_name, message)
    if model_name == 'phi':
        return outputModel(phi_api_key, phi_api_endpoint, phi_model_name, message)
    if model_name == 'mistral':
        return outputMistrael(mistral_api_endpoint, mistral_api_key, message)
    if model_name == 'gemini':
        return outputGemini(LMMmodels, message)
    if model_name == 'claude':
        return outputClaude(LMMmodels, message)
    raise ValueError("Unsupported model: %r" % (model_name,))


for run_type in types:
    dataPath, count = getDataPathMemorization(dataset, run_type, COT_type)
    print("Type: ", run_type)
    print(dataPath, count)
    if dataPath is None:
        continue
    with open(dataPath, 'r') as f:
        data = json.load(f)
    print("Data size:", min(count, len(data)))

    # Resume support: records already written by an earlier run are reloaded
    # and only the remaining slice of the input is processed.
    saveDataPath = getSavePath(dataset, run_type, model, COT_type, parsed=False)
    data_type = []
    if os.path.exists(saveDataPath):
        print("File already exists: ", saveDataPath)
        with open(saveDataPath, 'r') as f:
            data_type = json.load(f)
    print((saveDataPath), len(data_type))
    if len(data_type) == count:
        continue

    for d in tqdm(data[len(data_type):count]):
        # temp aliases d: the input record itself is extended and then saved.
        temp = d
        temp['type'] = run_type

        firstAnswer, secondAnswer, guidedMessage, generalMessage = \
            _build_memorization_prompts(d, dataset, COT_type)
        temp['firstAnswer'] = firstAnswer
        temp['secondAnswer'] = secondAnswer

        temp['guidedOutput'] = _query_model(model, guidedMessage)
        time.sleep(0.3)  # short pause between consecutive API calls
        temp['generalOutput'] = _query_model(model, generalMessage)
        data_type.append(temp)

        # Rewrite the whole file after every record so an interrupted run can
        # pick up where it stopped.
        with open(saveDataPath, 'w') as f:
            json.dump(data_type, f, indent=4)
        time.sleep(0.3)
# %%
# Running sums of ROUGE F-measures over all successfully scored records; they
# are divided by `count` further down to obtain the averages.
avg_rouge_1_guided = 0
avg_rouge_1_general = 0
avg_rouge_2_guided = 0
avg_rouge_2_general = 0
avg_rouge_L_guided = 0
avg_rouge_L_general = 0
count = 0  # number of records that contributed to the sums
for run_type in types:
    saveDataPath = getSavePath(dataset, run_type, model, COT_type, parsed=False)
    if not os.path.exists(saveDataPath):
        print("File not found: ", saveDataPath)
        continue
    with open(saveDataPath, 'r') as f:
        gptData = json.load(f)

    for d in gptData:
        # Score each output once (all three ROUGE variants come back from a
        # single call) and only touch the accumulators after BOTH scorings
        # succeeded, so a failing record can never leave a partial
        # contribution in the sums.
        try:
            # RougeScorer.score takes (target, prediction). The F-measure is
            # the same in either order because swapping the arguments merely
            # exchanges precision and recall.
            general_scores = scorer.score(d['secondAnswer'], d['generalOutput'])
            guided_scores = scorer.score(d['secondAnswer'], d['guidedOutput'])
        except Exception as e:
            # Best effort: report the offending record and keep going.
            print("Error: ")
            print(d)
            print(e)
            continue

        avg_rouge_1_general += general_scores['rouge1'].fmeasure
        avg_rouge_1_guided += guided_scores['rouge1'].fmeasure

        avg_rouge_2_general += general_scores['rouge2'].fmeasure
        avg_rouge_2_guided += guided_scores['rouge2'].fmeasure

        avg_rouge_L_general += general_scores['rougeL'].fmeasure
        avg_rouge_L_guided += guided_scores['rougeL'].fmeasure

        count += 1
        
# Without a single scored record there is nothing to average; stop with a
# readable message (exit status 1) instead of a ZeroDivisionError.
if count == 0:
    raise SystemExit(
        "No records could be scored for dataset=%r, COTType=%r, model=%r; "
        "nothing was added to the results file." % (dataset, COT_type, model)
    )

# Average F-measures, computed once and used for both the console report and
# the persisted summary entry.
run_summary = {
    'dataset' : dataset,
    'COT_Type' : COT_type,
    'model' : model,
    'guided_R1' : avg_rouge_1_guided/count,
    'guided_R2' : avg_rouge_2_guided/count,
    'guided_Rl' : avg_rouge_L_guided/count,
    'general_R1' : avg_rouge_1_general/count,
    'general_R2' : avg_rouge_2_general/count,
    'general_Rl' : avg_rouge_L_general/count,
}

print("Guided: ")
print("Rouge1: ", run_summary['guided_R1'])
print("Rouge2: ", run_summary['guided_R2'])
print("RougeL: ", run_summary['guided_Rl'])

print("General: ")
print("Rouge1: ", run_summary['general_R1'])
print("Rouge2: ", run_summary['general_R2'])
print("RougeL: ", run_summary['general_Rl'])

# Append this run to the cumulative results list. On the very first run the
# file does not exist yet, so start from an empty list rather than failing.
results_path = 'results_GPT_memorization.json'
if os.path.exists(results_path):
    with open(results_path, 'r') as f:
        results = json.load(f)
else:
    results = []
results.append(run_summary)
with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)
    