# Entry-point script for running reasoning experiments on a dataset.
from data import data
import os, sys, copy
import tqdm, json
import numpy as np
import pickle
import argparse, time
from agent import agent
from utils.logger import logger
from utils.logic.fol import parse_fol
from utils.logic.datalog import Datalog

def get_args(argv=None):
    """Parse the command-line options for an experiment run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers using ``get_args()`` behave exactly as before.

    Returns:
        argparse.Namespace with ``path``, ``dataset_name``, ``scoring_method``,
        ``experiment_name``, ``llm_name``, ``retrieve_method``,
        ``exact_match`` (kept as a raw string such as ``'True'``) and
        ``max_steps`` (int).

    Raises:
        SystemExit: if a required argument is missing or a value fails to
            parse (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(description="Run experiments for the dataset")
    parser.add_argument('path', help='Path to directory containing dataset')
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name")
    parser.add_argument("--scoring_method", type=str, default="typed resolution", help="Method to use for scoring the options")
    # Required: argparse never applies a default to a required option, so the
    # previous default="res" was dead and has been removed.
    parser.add_argument("--experiment_name", type=str, required=True, help="Name for saving the results of the experiment")
    parser.add_argument("--llm_name", type=str, default="gemini", help="Name of the LLM model to use")
    parser.add_argument("--retrieve_method", type=str, default="gpt3", help="Method to use for retrieving the rules")
    # Parsed as a string rather than a bool; the caller interprets the value.
    parser.add_argument("--exact_match", type=str, default='True', help="whether the queries are exact match or not")
    parser.add_argument("--max_steps", type=int, default=3, help="Maximum steps allowed for reasoning")
    return parser.parse_args(argv)


# Maps each --scoring_method value to the name of its reasoner class in the
# `agent` module. Names are resolved with getattr only for the selected
# method, so (as with the original if/elif chain) unselected classes are
# never touched.
_REASONER_CLASS_NAMES = {
    "typed resolution": "Typed_Resolution",
    "zero shot CoT": "ZSCoT",
    "few shot CoT": "FSCoT",
    "zero shot CoT RAG": "ZSRAG",
    "few shot CoT RAG": "FSRAG",
}

# Number of examples evaluated per run (dataset indices 0 .. N-1).
# NOTE(review): assumes the dataset holds at least this many examples; the
# size of data.Dataset is not visible here — confirm or make configurable.
_NUM_EXAMPLES = 400


def _is_correct(answer, selected_option, kb, dataset_name):
    """Return whether ``selected_option`` counts as a correct prediction.

    For ``'Recipe-MPR'`` the prediction is correct only when ``answer`` is
    contained in ``selected_option`` AND no other key of ``kb`` is, so a
    prediction naming several options is marked wrong. For every other
    dataset the check is a case-insensitive substring test of
    ``str(answer)`` within ``str(selected_option)``.
    """
    if dataset_name == 'Recipe-MPR':
        incorrect_options = [option for option in list(kb.keys()) if option != answer]
        answer_selected = answer in selected_option
        other_selected = any(option in selected_option for option in incorrect_options)
        return answer_selected and not other_selected
    return str(answer).lower() in str(selected_option).lower()


def main():
    """Run the configured reasoner over the dataset and report accuracy.

    Parses the command line, builds the reasoner for ``--scoring_method``,
    evaluates the first ``_NUM_EXAMPLES`` dataset examples (logging each
    query, answer and selected option), then prints and logs the overall
    accuracy together with the correct/incorrect counts.

    Raises:
        ValueError: if ``--scoring_method`` is not a supported method.
    """
    args = get_args()

    dataset_name = args.dataset_name
    path = args.path
    scoring_method = args.scoring_method
    llm_name = args.llm_name
    retrieve_method = args.retrieve_method
    # Only the exact string 'True' enables exact matching.
    exact_match = args.exact_match == 'True'
    max_steps = args.max_steps

    try:
        reasoner_class_name = _REASONER_CLASS_NAMES[scoring_method]
    except KeyError:
        raise ValueError("Scoring method not supported") from None
    reasoner = getattr(agent, reasoner_class_name)(retrieve_method, dataset_name)

    log = logger(args.experiment_name)
    log("Dataset: ", dataset_name)
    log("Path: ", path)
    log("Scoring Method: ", scoring_method)
    log("LLM Name: ", llm_name)
    log("Max Steps: ", max_steps)
    log("Retrieve Method: ", retrieve_method)
    log("Exact Match: ", exact_match)

    dataset = data.Dataset(dataset_name, exact_match, path, log)
    outcomes = []

    for i in tqdm.tqdm(range(_NUM_EXAMPLES)):
        log("Example no: ", i)

        query, answer, kb = dataset(i)

        # tqdm.write prints above the progress bar instead of corrupting it,
        # which a bare print() inside a tqdm loop does.
        tqdm.tqdm.write(str(query))
        log("Query: ", query)
        log("Answer: ", answer)

        selected_option = reasoner(query, kb, llm_name, log, max_steps)
        log("Selected option:", selected_option)

        correctness = _is_correct(answer, selected_option, kb, dataset_name)
        outcomes.append(correctness)
        if not correctness:
            log("Incorrect Answer", answer)

        tqdm.tqdm.write(f"Accuracy so far:  {np.mean(outcomes)}")

    accuracy = np.mean(outcomes)
    n_correct = np.sum(outcomes)
    n_incorrect = np.sum(np.logical_not(outcomes))
    print(f"Accuracy: {accuracy}")
    print(f"Correct: {n_correct}")
    print(f"Incorrect: {n_incorrect}")
    log("Accuracy: ", accuracy)
    log("Correct: ", n_correct)
    log("Incorrect: ", n_incorrect)


if __name__ == "__main__":
    main()

