import logging
import argparse
from commonsense_constraint import evaluate_commonsense_constraints
# from hard_constraint_human import evaluate_hard_constraints
from preference_eval import evaluate_preference
from utils import load_json_file, validate_json
from tqdm import tqdm
import json


# Command-line options for this evaluation run.
#   --test_path  : subdirectory of ../results/ that holds the generated plans
#                  (query_<i>_result.json) and receives the log/preference files.
#   --data       : dataset name; queries are read from ../data/<data>.json.
#   --preference : additionally score preferences and write preference_<i>.json.
parser = argparse.ArgumentParser(
    description='Evaluate constraints on queries.')
parser.add_argument('--test_path', type=str,
                    default='test_20240913203758',
                    help='Results subdirectory under ../results/ containing the plans to evaluate')
parser.add_argument('--data', type=str,
                    default='medium',
                    help='Dataset name; queries are loaded from ../data/<data>.json')
parser.add_argument('--preference', action='store_true',
                    help='Also evaluate preferences and write preference_<idx>.json files')
args = parser.parse_args()

# The "human" dataset is checked by a dedicated hard-constraint module; every
# other dataset uses the generic one. Both provide evaluate_hard_constraints.
if args.data == "human":
    from hard_constraint_human import evaluate_hard_constraints
else:
    from hard_constraint import evaluate_hard_constraints

# The dataset name also serves as the suffix of the evaluation log file.
target_difficult = args.data

# Display names for the commonsense checks. They are zipped positionally
# against each query's commonsense results, so the order is assumed to match
# the order in which evaluate_commonsense_constraints reports them.
commonsense_constraints_name = [
    "Is_intercity_transport_correct",
    "Is_attractions_correct",
    "Is_hotels_correct",
    "Is_restaurants_correct",
    "Is_transport_correct",
    "Is_time_correct",
    "Is_space_correct",
]

# Queries come from ../data/; plans, logs and outputs live in ../results/<test_path>/.
query_file = f'../data/{args.data}.json'
results_file = f'../results/{args.test_path}/'

# filemode='w' truncates any log from a previous run. The results directory
# must already exist, because the log file handler does not create it.
logging.basicConfig(
    filename=f'{results_file}evaluation_{target_difficult}.log',
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

schema_file_path = './output_schema.json'
schema = load_json_file(schema_file_path)
query_data = load_json_file(query_file)

# Index-aligned per-query accumulators, filled by the loading loop.
symbolic_input_list = []
plan_json_list = []
delivery_rate = 0  # count of valid plans; later divided by data_num

data_num = 0  # number of queries processed
query_index = []

# Load the generated plan for every query. Each query is recorded even when its
# plan is missing or invalid (the plan is then stored as []). This keeps
# position k in symbolic_input_list, plan_json_list and query_index referring
# to the same query.
for i, query in tqdm(enumerate(query_data), total=len(query_data)):
    # NOTE: optional difficulty filter, currently disabled:
    # if query['difficulty'] != target_difficult and target_difficult != 'all':
    #     continue
    data_num += 1
    query_index.append(i)
    symbolic_input_list.append(query)
    test_plan_path = results_file + 'query_{}_result.json'.format(i)
    try:
        test_plan = load_json_file(test_plan_path)
        if not isinstance(test_plan, dict):
            raise TypeError("Plan type error")
        validate_json(test_plan, schema)
    except Exception as e:
        # Best effort: any failure while loading, type-checking or validating
        # a plan marks that query as undelivered instead of aborting the run.
        plan_json_list.append([])
        logging.warning(f'processing query {i}: {e}')
    else:
        plan_json_list.append(test_plan)
        delivery_rate += 1

logging.info(f'Data Len: {data_num}')
# Fraction of queries with a valid plan; 0.0 (not ZeroDivisionError) when
# there are no queries.
delivery_rate = delivery_rate / data_num if data_num else 0.0


# Both evaluators receive the same index-aligned lists of queries and plans
# (plans that failed to load appear as []). Each call is unpacked into two
# aggregate accuracies and a per-query result sequence.
(cs_macro_accuracy,
 cs_micro_accuracy,
 com_results) = evaluate_commonsense_constraints(symbolic_input_list, plan_json_list)
(hard_macro_accuracy,
 hard_micro_accuracy,
 hard_results) = evaluate_hard_constraints(symbolic_input_list, plan_json_list)


# Count queries whose plan satisfies every hard AND every commonsense check.
# For each failing query, log which individual checks did pass.
pass_rate = 0  # raw count here; converted to a percentage at the end
for idx, (hard_result, com_result) in enumerate(zip(hard_results, com_results)):
    if all(hard_result) and all(com_result):
        logging.info(f'Pass Index: {query_index[idx]}')
        pass_rate += 1
        continue

    filtered_constraints_concat = ', '.join(
        name for name, ok in zip(commonsense_constraints_name, com_result) if ok)
    # Positions of the hard constraints that held for this query.
    passed_hard = [j for j, ok in enumerate(hard_result) if ok]
    logging.warning('Query {}. Correct Commonsense: {}  |   Correct Logics: {}'.format(
        query_index[idx], filtered_constraints_concat, passed_hard))

if args.preference:
    # Preference metric names; each is written as -1 when a plan is not scored.
    func_name_list = ['convenient_transport', 'convenient_restaurant', 'near_poi', 'less_walk', 'meal_cost_ratio', 'accommodation_cost_ratio',
                      'attraction_cost_ratio', 'total_cost', 'attraction_satisfaction', 'attraction_count', 'indoor_attraction_ratio', 'popular_attraction_ratio']
    for idx, (_, com_result) in enumerate(zip(hard_results, com_results)):
        plan = plan_json_list[idx]
        # Scoring is gated on the commonsense results only (hard-constraint
        # results are not consulted). Plans that failed to load are stored as
        # [] and are never scored, even if all(com_result) is vacuously True.
        if all(com_result) and isinstance(plan, dict):
            preference_json = evaluate_preference(symbolic_input_list[idx], plan)
            preference_json['status'] = True
        else:
            preference_json = {func_name: -1 for func_name in func_name_list}
            preference_json['status'] = False

        with open(results_file + 'preference_{}.json'.format(query_index[idx]), 'w', encoding='utf-8') as f:
            json.dump(preference_json, f, ensure_ascii=False, indent=4)


# Convert the pass count into a percentage. An empty query set yields 0.0
# instead of raising ZeroDivisionError.
pass_rate = 100 * pass_rate / data_num if data_num else 0.0
logging.info('Delivery Rate: {:.4}%'.format(delivery_rate * 100))
# NOTE(review): the accuracies below are logged with a '%' sign but, unlike
# delivery_rate, are not scaled by 100 here. Presumably the evaluators already
# return percentages; confirm.
logging.info('hard_constraint_micro: {:.4}%, hard_constraint_macro: {:.4}%'.format(
    hard_micro_accuracy, hard_macro_accuracy))
logging.info('commonsense_constraint_micro: {:.4}%, commonsense_constraint_macro: {:.4}%'.format(
    cs_micro_accuracy, cs_macro_accuracy))
logging.info('Final Pass Rate {:.4}%'.format(pass_rate))
