import argparse
import csv
import json
import os

# Relative order of the next two imports is kept as in the original so that
# any names exported by the star import are shadowed exactly as before.
from reLLM import *
from openai import OpenAI

# Benchmark domain directory used by each supported case study (--CS value).
_CASE_STUDY_DOMAINS = {1: 'mrt', 2: 'comp', 3: 'ineq'}


def _build_arg_parser():
    """Build the command-line parser for the case-study runner."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--target_model', type=str, default="gemini-2.0-flash-lite")
    parser.add_argument('--target_api_key', type=str)
    parser.add_argument('--target_base_url', type=str)
    parser.add_argument('--perception_model', type=str, default="qwen-max")
    parser.add_argument('--perception_api_key', type=str)
    parser.add_argument('--perception_base_url', type=str)
    parser.add_argument('--api_timeout', type=int, default=60)
    parser.add_argument('--esl_name', type=str)
    parser.add_argument('--inter_level', type=int, default=2)  # interpretation level: 1 or 2
    parser.add_argument('--target_memory', type=int, default=0)  # if the target LLM has memory
    parser.add_argument('--CS', type=int, default=2)  # case study number
    # NOTE: any value given for --domain is overwritten from --CS in main():
    # mrt (CS=1), comp (CS=2), ineq (CS=3).
    parser.add_argument('--domain', type=str, default="")
    # mrt: Smoking; comp: None; ineq: factorization, interval, endpoints
    parser.add_argument('--sub_domain', type=str, default="")
    parser.add_argument('--output', type=str, default="./outputs")
    return parser


def _read_mrt_rows(context_csv):
    """Return ``(context, severity)`` for every row of *context_csv*.

    The context is taken from column 1 and the severity from the last column
    of each row, including the first row of the file.
    """
    with open(context_csv, mode="r", encoding='utf-8') as file:
        return [(row[1], row[-1]) for row in csv.reader(file)]


def _read_inequality_rows(csv_path):
    """Return ``(task_id, context)`` for rows of *csv_path* whose column 3 is ``'0'``.

    The task id is column 0 and the context is column 1.
    """
    with open(csv_path, mode='r') as file:
        return [(row[0], row[1]) for row in csv.reader(file) if row[3] == '0']


def _run_case(api_config, args, context, esl_root_path, task_id, log_path, domain_vec):
    """Create a ``ReLLM`` for one context, point it at *log_path* and run its test.

    Returns whatever ``ReLLM.run_test`` returns.
    """
    runner = ReLLM(api_config, args, context, esl_root_path, args.CS)
    runner.set_task_id(task_id)
    runner.set_output_log(log_path)
    return runner.run_test(domain_vec, context, args.CS, if_print=True)


def main(argv=None):
    """Parse the command line and run the selected case study.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: with status 2 (via ``argparse``) when ``--CS`` is not one
            of the supported case-study numbers.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Reject unsupported case studies with a non-zero exit status instead of
    # reporting the problem and then exiting "successfully".
    if args.CS not in _CASE_STUDY_DOMAINS:
        parser.error("Currently we only support three values of CS")
    args.domain = _CASE_STUDY_DOMAINS[args.CS]

    # The model names, API keys and base URLs can also be set in code here by
    # assigning to args.target_model, args.target_api_key, args.target_base_url,
    # args.perception_model, args.perception_api_key, args.perception_base_url.

    api_config = {'target_model_name': args.target_model, 'target_api_key': args.target_api_key,
                  'target_base_url': args.target_base_url, 'target_max_tokens': 4096,
                  'perception_model_name': args.perception_model, 'perception_api_key': args.perception_api_key,
                  'perception_base_url': args.perception_base_url, 'perception_max_tokens': 4096}

    case_study = args.CS
    domain = args.domain
    sub_domain = args.sub_domain
    # Built once and never rebound, so `domain` stays a plain string for the
    # path/log f-strings below.
    domain_vec = [domain, sub_domain]

    log_dir = f"{args.output}/Case_study_{case_study}_{domain}/{sub_domain}"
    log_suffix = (f"target_model_{args.target_model}_perception_model_{args.perception_model}"
                  f"_interLevel_{args.inter_level}.txt")

    if case_study == 1:
        rows = _read_mrt_rows(f'./benchmarks/{domain}/data/{sub_domain}.csv')
        esl_root_path = f'./benchmarks/{domain}/esl/{sub_domain}.json'

        # Rows up to and including index `context_id_test` are skipped
        # (row 0 is presumably the CSV header -- verify against the data);
        # task ids start counting from `context_id_test`.
        context_id_test = 0
        for task_id, (context, score) in enumerate(rows[context_id_test + 1:], start=context_id_test):
            _run_case(api_config, args, context, esl_root_path, task_id,
                      f"{log_dir}/Context_{task_id}_severity_{score}_{log_suffix}", domain_vec)
            # NOTE(review): the script has always stopped after the first
            # context; remove this `break` to run the whole benchmark.
            break

    elif case_study == 2:  # comparison case study
        esl_root_path = f'./benchmarks/{domain}/esl/comparison.json'
        with open(f"./benchmarks/{domain}/data/comparison.json", "r") as f:
            problems_comparison = json.load(f)

        for problem in problems_comparison:
            task_id = problem['id']
            comp_a = problem['comparison_number_1']
            comp_b = problem['comparison_number_2']
            context = f"{comp_a} and {comp_b}, which one is greater?"
            _run_case(api_config, args, context, esl_root_path, task_id,
                      f"{log_dir}/Context_{task_id}_{log_suffix}", domain_vec)
            # NOTE(review): stops after the first problem, as before.
            break

    else:  # case_study == 3: inequality dataset
        csv_path = f"./benchmarks/{domain}/data/inequality_result_{args.target_model}_all.csv"
        esl_root_path = f'./benchmarks/{domain}/esl/inequality_{sub_domain}.json'

        for task_id, context in _read_inequality_rows(csv_path):
            # Pass the [domain, sub_domain] pair like the other case studies;
            # the previous code built that pair but then passed the bare
            # domain string to run_test.
            _run_case(api_config, args, context, esl_root_path, task_id,
                      f"{log_dir}/Context_{task_id}_{log_suffix}", domain_vec)
            # NOTE(review): stops after the first context, as before.
            break


if __name__ == '__main__':
    main()