
import os
import datetime
import json
import pandas as pd
from tqdm import tqdm
import re

from args import parse_args
from libs.data_loader import load_jsonl_objects,extract_response,extract_answer_yesno
from prompt_critic import *
from libs.utils import compare_answer_with_groundtruth
from Actor_Critic.agent import Agent,judge_agent
import multiprocessing
# Number of input records grouped into one batch; the script entry point
# slices the loaded records into batches of this size before distributing
# the batches across worker processes.
DATA_BATCH_SIZE=1


def extract_answer_True_False(input_string):
    """Extract the critic's True/False verdict from a response string.

    Searches for the first ``Decision: <value>`` or ``Opinion: <value>``
    marker, matched case-insensitively, and returns the value normalised to
    exactly ``'True'`` or ``'False'``. Code in this file compares the result
    against those exact strings, so a reply such as ``"Decision: true"`` must
    not come back as ``'true'``.

    Args:
        input_string: Raw text of the critic's reply.

    Returns:
        ``'True'`` or ``'False'`` when a marker is found; otherwise
        ``input_string`` unchanged.
    """
    ANSWER_PATTERN_YESNO = r"(?i)(Decision|Opinion)\s*:\s*(True|False|true|false)"
    match = re.search(ANSWER_PATTERN_YESNO, input_string)
    if match is None:
        return input_string
    # The (?i) flag lets any casing through ("true", "TRUE", ...); fold it to
    # the canonical capitalised form.
    return match.group(2).capitalize()


def generate_critic_judegement(info,judge_agent):
    """Ask the critic agent to judge one response and return its verdict.

    Args:
        info: Record providing ``question``, ``context`` and the response
            under review at ``single_log['messages'][2]['content']``.
        judge_agent: Object exposing ``call_agent(...)``; the value it
            returns is indexed as ``['messages'][2]['content']`` to obtain
            the critic's reply.

    Returns:
        The verdict text that ``extract_answer_True_False`` pulls out of the
        critic's reply, or the whole reply when no verdict marker is found.
    """
    question = info['question']
    original_response = info['single_log']['messages'][2]['content']
    context = info['context']

    user_prompt = user_critic_prompt.format(
        question=question,
        context=context,
        original_response=original_response,
    )
    judgement_agent_log = judge_agent.call_agent(
        sys_prompt=sys_critic_prompt,
        user_prompt=user_prompt,
        temperature=0.,
        max_tokens=4096,
        stop=None,
        n=1,
    )
    judgement_agent_response = judgement_agent_log['messages'][2]['content']
    return extract_answer_True_False(judgement_agent_response)




def _confusion_label(score, judgement):
    """Return the confusion label for a score/verdict pair, or None.

    ``PT``: score true, verdict ``'True'``; ``PF``: score true, verdict
    ``'False'``; ``NT``: score false, verdict ``'False'``; ``NF``: score
    false, verdict ``'True'``. Any other combination yields ``None``.
    """
    # Equality (not identity) is deliberate so that 1/0 scores are treated
    # like True/False.
    if score == True and judgement == 'True':  # noqa: E712
        return "PT"
    if score == True and judgement == 'False':  # noqa: E712
        return "PF"
    if score == False and judgement == 'False':  # noqa: E712
        return "NT"
    if score == False and judgement == 'True':  # noqa: E712
        return "NF"
    return None


def get_regenerate(batched_input_data,rank,log_dir,correct_count,pt_count,pf_count,nf_count,nt_count):
    """Judge every record handed to this worker and store the results.

    For each record the module-level ``judge_agent`` is called with the
    critic prompts, the verdict is extracted from its reply, and the record
    (extended with ``judgement``, ``judgement_agent_log`` and ``label``) is
    appended as one JSON line to ``{log_dir}/ALL/{index}_judgement.jsonl``
    and to either ``{log_dir}/True/{index}_True.jsonl`` or
    ``{log_dir}/False/{index}_False.jsonl``.

    A record whose ``ALL`` file already exists is skipped and is therefore
    not included in the counts reported by this call.

    Args:
        batched_input_data: Iterable of batches; each batch is a sequence of
            records with keys ``index``, ``question``, ``context``, ``score``
            and ``single_log``.
        rank: Worker index; used as the progress-bar position and as the key
            under which the counts are stored.
        log_dir: Base output directory.
        correct_count: Mapping that receives ``PT + NT`` at key ``rank``.
        pt_count: Mapping that receives the ``PT`` count at key ``rank``.
        pf_count: Mapping that receives the ``PF`` count at key ``rank``.
        nf_count: Mapping that receives the ``NF`` count at key ``rank``.
        nt_count: Mapping that receives the ``NT`` count at key ``rank``.
    """
    tallies = {"PT": 0, "PF": 0, "NT": 0, "NF": 0}

    ALL_dict=f'{log_dir}/ALL/'
    True_dict= f'{log_dir}/True/'
    False_dict= f'{log_dir}/False/'
    # The output directories do not depend on the record, so create them once.
    for directory in (ALL_dict, True_dict, False_dict):
        os.makedirs(os.path.dirname(directory), exist_ok=True)

    for batch in tqdm(batched_input_data, desc=str(rank), position=rank):
        # Handle every record of the batch, not just the first one, so no
        # input is dropped if batches ever hold more than one record.
        for item in batch:
            index = item['index']
            print(index)
            question = item['question']
            original_response = item['single_log']['messages'][2]['content']
            context = item['context']
            score = item['score']

            ALL_path = f'{log_dir}/ALL/{index}_judgement.jsonl'
            True_path = f'{log_dir}/True/{index}_True.jsonl'
            False_path = f'{log_dir}/False/{index}_False.jsonl'
            print(ALL_path)

            # The ALL file doubles as the "already processed" marker.
            if os.path.exists(ALL_path):
                print(f"Problems {index} exist")
                continue

            user_prompt = user_critic_prompt.format(
                question=question,
                context=context,
                original_response=original_response,
            )
            judgement_agent_log = judge_agent.call_agent(
                sys_prompt=sys_critic_prompt,
                user_prompt=user_prompt,
                temperature=0.,
                max_tokens=4096,
                stop=None,
                n=1,
            )
            judgement_agent_response = judgement_agent_log['messages'][2]['content']
            judgement = extract_answer_True_False(judgement_agent_response)

            info = item.copy()
            info['judgement'] = judgement
            info['judgement_agent_log'] = judgement_agent_log

            label = _confusion_label(score, judgement)
            if label is None:
                # Unclassifiable pair: keep the raw values for inspection.
                label = f"{judgement}_{score}"
            else:
                tallies[label] += 1
            info['label'] = label

            # Serialise once; the files are UTF-8, so non-ASCII text is
            # written verbatim rather than as \u escapes.
            json_line = json.dumps(info, ensure_ascii=False) + '\n'
            verdict_path = True_path if judgement == 'True' else False_path
            with open(ALL_path, 'a', encoding='utf-8') as f_ALL:
                f_ALL.write(json_line)
            with open(verdict_path, 'a', encoding='utf-8') as f_verdict:
                f_verdict.write(json_line)

    pt_count[rank] = tallies["PT"]
    pf_count[rank] = tallies["PF"]
    nf_count[rank] = tallies["NF"]
    nt_count[rank] = tallies["NT"]
    correct_count[rank] = tallies["PT"] + tallies["NT"]



if __name__ == '__main__':
    # Script entry point: load the saved solutions, fan the records out to
    # worker processes running get_regenerate, then print aggregated counts.
    args = parse_args()
    print(args)
    print(judge_agent.model)
    inputfile = "/Users/zwj/multi-agent-alignment/problem_solving_pubmed/logs/actor_critic/eval/gpt-4o-mini-2024-07-18_gpt-4o-mini-2024-07-18/sol_round_0/sol.jsonl"
    print(inputfile)
    # Results go into a sibling directory of the input file, named after the
    # judge model.
    log_dir = inputfile.replace('sol.jsonl', f'judgement-{judge_agent.model}/')

    os.makedirs(log_dir, exist_ok=True)

    input_datas = load_jsonl_objects(inputfile)
    batched_dataset = [
        input_datas[i : i + DATA_BATCH_SIZE]
        for i in range(0, len(input_datas), DATA_BATCH_SIZE)
    ]
    # Never start more workers than there are batches to hand out.
    num_processes = min(32, len(batched_dataset))

    # Manager dicts let each worker publish its counts under its own rank.
    manager = multiprocessing.Manager()
    correct_count = manager.dict()
    pt_count = manager.dict()
    pf_count = manager.dict()
    nf_count = manager.dict()
    nt_count = manager.dict()

    print(len(batched_dataset))
    start_time = datetime.datetime.now()
    processes = []
    for i in range(num_processes):
        # Worker i takes every num_processes-th batch, starting at batch i.
        p = multiprocessing.Process(
            target=get_regenerate,
            args=(batched_dataset[i :: num_processes], i, log_dir,
                  correct_count, pt_count, pf_count, nf_count, nt_count),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        if p.exitcode != 0:
            # A worker that died never stored its counts; make that visible.
            print(f"Warning: {p.name} exited with code {p.exitcode}; its counts may be missing")

    total_correct = sum(correct_count.values())
    print("PT: ", sum(pt_count.values()))
    print("PF: ", sum(pf_count.values()))
    print("NT: ", sum(nt_count.values()))
    print("NF: ", sum(nf_count.values()))
    print("Total Predict correct: ", total_correct)
    end_time = datetime.datetime.now()
    elapsed_minutes = (end_time - start_time).total_seconds() / 60
    print('time cost:', round(elapsed_minutes, 2), ' mins')
