from utils import api_util
import argparse
import json
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer


def evaluate_parameter_f1(gen_dict: dict, ref_dict: dict):
    """Return the mean per-session F1 of generated API parameters.

    Both arguments map a session id to ``{api_name: {param_name: value}}``.
    For every reference session, each generated parameter of an API that also
    appears in the reference earns:

    * full credit (1.0) when its name (compared case-insensitively) exists in
      the reference parameters of that API and the values are equal;
    * half credit (0.5) when only the name matches.

    Precision divides the total credit by the number of generated parameters
    of the session (over all generated APIs); recall divides it by the number
    of reference parameters of the session.

    Sessions missing from ``gen_dict`` score 0. ``gen_dict`` is not modified.

    Returns:
        The F1 averaged over all sessions of ``ref_dict``, or 0.0 when
        ``ref_dict`` is empty.
    """
    session_f1 = []

    for session_id, ref_actions in ref_dict.items():
        # Read-only lookup: do not insert placeholders into the caller's dict.
        gen_actions = gen_dict.get(session_id, {})

        ref_count = sum(len(params) for params in ref_actions.values())
        gen_count = sum(len(params) for params in gen_actions.values())

        full_match = 0
        half_match = 0
        for action, params in gen_actions.items():
            if action not in ref_actions:
                continue
            # Lower-case both sides so name matching is symmetric.
            golden = {name.lower(): value for name, value in ref_actions[action].items()}
            for name, value in params.items():
                name = name.lower()
                if name not in golden:
                    continue
                if value == golden[name]:
                    full_match += 1
                else:
                    half_match += 1

        # Score once per session, after all generated actions were compared.
        credit = 0.5 * half_match + full_match
        recall = credit / ref_count if ref_count else 0.0
        precision = credit / gen_count if gen_count else 0.0
        if recall + precision:
            session_f1.append((2 * recall * precision) / (recall + precision))
        else:
            session_f1.append(0.0)

    return sum(session_f1) / len(session_f1) if session_f1 else 0.0

def evaluate_action_em(gen_dict: dict, ref_dict: dict):
    """Score predicted API names against reference API names per session.

    Both arguments map a session id to a collection of API names (``evaluate``
    passes sets). Sessions absent from ``gen_dict`` count as having no
    predictions; ``gen_dict`` is not modified.

    Returns:
        A tuple ``(em, f1)``:

        * ``em`` -- number of predicted API names found in the reference,
          divided by the total number of reference API names (micro-averaged
          over all sessions);
        * ``f1`` -- per-session F1 between predicted and reference names,
          averaged over the sessions of ``ref_dict``.

        Both values are 0.0 when there is nothing to score. This includes an
        empty ``gen_dict``, which previously returned a bare ``0`` and broke
        tuple unpacking in the caller.
    """
    hits_total = 0
    ref_total = 0
    session_f1 = []

    for session_id, ref_actions in ref_dict.items():
        gen_actions = gen_dict.get(session_id, set())
        hits = sum(1 for action in gen_actions if action in ref_actions)

        hits_total += hits
        ref_total += len(ref_actions)

        if hits:
            recall = hits / len(ref_actions)
            precision = hits / len(gen_actions)
            session_f1.append((2 * recall * precision) / (recall + precision))
        else:
            session_f1.append(0.0)

    em = hits_total / ref_total if ref_total else 0.0
    f1 = sum(session_f1) / len(session_f1) if session_f1 else 0.0
    return em, f1

def calculate_bleu(reference, candidate):
    """Return the smoothed sentence-level BLEU of ``candidate`` vs ``reference``.

    Both strings are lower-cased and tokenized with ``nltk.word_tokenize``.
    The score is computed by ``sentence_bleu`` with NLTK's
    ``SmoothingFunction().method1``, against a single reference.
    """
    references = [nltk.word_tokenize(reference.lower())]
    hypothesis = nltk.word_tokenize(candidate.lower())

    smoother = SmoothingFunction().method1
    return sentence_bleu(references, hypothesis, smoothing_function=smoother)

def calculate_rouge(reference, candidate):
    """Return ROUGE-1, ROUGE-2 and ROUGE-L scores of ``candidate`` vs ``reference``.

    A ``rouge_scorer.RougeScorer`` with stemming enabled is built per call.
    The result is whatever ``RougeScorer.score`` returns, keyed by
    ``'rouge1'``, ``'rouge2'`` and ``'rougeL'``.
    """
    metric_names = ['rouge1', 'rouge2', 'rougeL']
    return rouge_scorer.RougeScorer(metric_names, use_stemmer=True).score(reference, candidate)

def evaluate_response(gen_dict: dict, ref_dict: dict):
    """Average BLEU and ROUGE F-measures of generated answers against references.

    Both arguments map a session id to an answer string. Only sessions present
    in both dicts are scored; sessions missing from ``gen_dict`` are skipped.

    Returns:
        ``(bleu, rouge1, rouge2, rougeL)``: each a mean over the scored
        sessions (ROUGE values are F-measures). All four are 0.0 when no
        session could be scored, instead of raising ``ZeroDivisionError``.
    """
    bleu_scores = []
    rouge1_scores = []
    rouge2_scores = []
    rougel_scores = []

    for session_id, ref_text in ref_dict.items():
        if session_id not in gen_dict:
            continue

        gen_text = gen_dict[session_id]
        bleu = calculate_bleu(ref_text, gen_text)
        rouge_scores = calculate_rouge(ref_text, gen_text)

        bleu_scores.append(bleu)
        rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
        rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
        rougel_scores.append(rouge_scores['rougeL'].fmeasure)

    def _mean(values):
        # Guard against an empty overlap between generated and reference ids.
        return sum(values) / len(values) if values else 0.0

    return (
        _mean(bleu_scores),
        _mean(rouge1_scores),
        _mean(rouge2_scores),
        _mean(rougel_scores),
    )


def _resolve_files(json_config: dict):
    """Return ``(ref_file, gen_file)`` paths selected by the loaded config.

    Raises:
        ValueError: if ``current_file.split`` is neither ``"train"`` nor ``"test"``.
    """
    current = json_config["current_file"]
    split = current['split']
    prefix = current["prefix"]

    if split == "train":
        ref_file = json_config["train_file"][prefix]
    elif split == "test":
        ref_file = json_config["test_file"][prefix]
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unsupported split {split!r}; expected 'train' or 'test'.")

    gen_file = json_config["final"]['traj'].format(
        policy_aka=json_config['policy']['aka'],
        core_name=json_config['core']['name'],
        core_memory=json_config['core']['memory'],
        split=split,
        prefix=prefix,
    )
    return ref_file, gen_file


def _iter_jsonl(path):
    """Yield every non-blank line of the JSON Lines file at ``path``, parsed."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def _load_references(ref_file):
    """Build per-session reference maps from the reference JSONL file.

    Returns ``(ref2action, ref2parameter, ref2response)``, keyed by
    ``session_id``: a set of API names, ``{api_name: parameters}`` and the
    record's ``answer`` respectively.
    """
    ref2action = {}
    ref2parameter = {}
    ref2response = {}

    for item in _iter_jsonl(ref_file):
        query_id = item['session_id']
        for action in item['golden']:
            api_name = action['api_info']['api_name']
            ref2action.setdefault(query_id, set()).add(api_name)
            # A repeated API within a session keeps only its last parameters.
            ref2parameter.setdefault(query_id, {})[api_name] = action['parameters']
        ref2response[query_id] = item['answer']

    return ref2action, ref2parameter, ref2response


def _load_generations(gen_file):
    """Build per-session generated maps from the trajectory JSONL file.

    Returns ``(gen2action, gen2parameter, gen2response, record_count)``. The
    parameters of the i-th entry in ``answer.api_traj`` are taken from the
    i-th entry of ``answer.parameter_traj``.
    """
    gen2action = {}
    gen2parameter = {}
    gen2response = {}
    record_count = 0

    for item in _iter_jsonl(gen_file):
        record_count += 1
        query_id = item['session_id']
        answer = item['answer']
        for idx, action in enumerate(answer['api_traj']):
            parameter = answer['parameter_traj'][idx]
            api_name = action['api_info']['api_name']
            gen2action.setdefault(query_id, set()).add(api_name)
            # A repeated API within a session keeps only its last parameters.
            gen2parameter.setdefault(query_id, {})[api_name] = parameter
        gen2response[query_id] = answer['final_answer']

    return gen2action, gen2parameter, gen2response, record_count


def evaluate(args):
    """Load the config, reference and generated trajectories, then print metrics.

    ``args.config_file`` must point to a JSON config. The function prints the
    generated-record count, per-map session counts, then the action EM/F1
    (``evaluate_action_em``) and the parameter F1 (``evaluate_parameter_f1``).
    """
    with open(args.config_file, encoding="utf-8") as f:
        json_config = json.load(f)

    ref_file, gen_file = _resolve_files(json_config)
    ref2action, ref2parameter, ref2response = _load_references(ref_file)
    gen2action, gen2parameter, gen2response, record_count = _load_generations(gen_file)

    print(record_count)
    print(len(gen2action), len(ref2action))
    print(len(gen2parameter), len(ref2parameter))

    # NOTE(review): ref2response / gen2response are loaded but not scored;
    # evaluate_response is not called here — confirm whether it should be.
    em, f1 = evaluate_action_em(gen2action, ref2action)
    print(em, f1)

    f1 = evaluate_parameter_f1(gen2parameter, ref2parameter)
    print(f1)



if __name__ == "__main__":
    # Command-line entry point: parse the config path and run evaluate().
    parser = argparse.ArgumentParser(description="auto evaluation and generation")
    parser.add_argument(
        "-c",
        "--config-file",
        type=str,
        # NOTE(review): evaluate() reads this file with json.load, so the
        # default .yaml file must contain JSON-compatible content — confirm.
        default="conf/auto_config.yaml",
        help="The file of configuration.",
    )

    args = parser.parse_args()
    evaluate(args)
