
import os
import datetime
import json
import pandas as pd
from tqdm import tqdm


from args import parse_args
from libs.data_loader import load_jsonl_objects,extract_answer_letter,extract_answer_number,extract_response,extract_answer_yesno
from prompt import *
from libs.utils import compare_answer_with_groundtruth
from Actor_Critic.agent import Agent,actor_agent
import multiprocessing
# Number of input records grouped into each batch; the batches are then
# distributed round-robin across the worker processes in the __main__ block.
DATA_BATCH_SIZE=1


def get_rephrase_response(agent,question,regenerate_response):
    """Ask ``agent`` to rephrase a regenerated response for ``question``.

    The user prompt is built from the ``rephrase_user_prompt`` template and
    sent together with ``rephrase_sys_prompt`` as a single deterministic
    request (temperature 0, one completion, up to 4096 tokens).

    Args:
        agent: Object exposing ``call_agent(sys_prompt=..., user_prompt=..., ...)``.
        question: The question the response answers.
        regenerate_response: The response text to be rephrased.

    Returns:
        Whatever ``agent.call_agent`` returns, unmodified.
    """
    filled_prompt = rephrase_user_prompt.format(
        question=question,
        original_response=regenerate_response,
    )
    return agent.call_agent(
        sys_prompt=rephrase_sys_prompt,
        user_prompt=filled_prompt,
        temperature=0.,
        max_tokens=4096,
        stop=None,
        n=1,
    )


def _regenerate_item(item, log_dir):
    """Regenerate, rephrase, score and persist the answer for one record.

    Returns True when the regenerated answer matches the groundtruth, and
    False when it does not or when the record was already processed.
    """
    import copy  # local import: only needed for the defensive copy of the log

    question = item['question']
    groundtruth = item['groundtruth']
    index = item['index']
    context = item['context']
    feedback = item['feedback']
    original_response = item["single_log"]['messages'][2]['content']

    print(index)

    sol_path = f'{log_dir}/sol_re/{index}_sol.jsonl'
    correct_path = f'{log_dir}/correct_re/{index}_correct.jsonl'
    wrong_path = f'{log_dir}/wrong_re/{index}_wrong.jsonl'

    # The solution file doubles as the "already done" marker so that an
    # interrupted run can be resumed without repeating finished records.
    if os.path.exists(sol_path):
        print(f"Problems {index} exist")
        return False

    # The prompt does not change between retries, so build it once.
    regenerate_user_prompt = user_single_regenerate_prompt.format(
        question=question,
        context=context,
        original_response=original_response,
        feedback=feedback,
        format_prompt=format_prompt_yesno,
    )

    # NOTE(review): this retry loop has no upper bound; with temperature=0 a
    # response that cannot be parsed may well repeat - consider capping it.
    while True:
        regenerate_log = actor_agent.call_agent(
            sys_prompt=sys_single_regenerate_prompt,
            user_prompt=regenerate_user_prompt,
            temperature=0.,
            max_tokens=4096,
            stop=None,
            n=1,
        )
        before_rephrase = extract_response(regenerate_log['messages'][2]['content'])
        if before_rephrase:
            break
        print("agent_1_before_rephrase is None, regenerating...")

    rephrased_log = get_rephrase_response(actor_agent, question, before_rephrase)
    regenerate_response = rephrased_log['messages'][2]['content']

    # Work on copies: the original code aliased item["single_log"] as
    # info['re_log'] and then overwrote its content, which destroyed the
    # original response stored in "single_log".
    info = dict(item)
    info['re_log_regenerate_raw'] = regenerate_log
    info['re_rephrased_log'] = rephrased_log
    re_log = copy.deepcopy(item["single_log"])
    re_log['messages'][2]['content'] = regenerate_response
    info['re_log'] = re_log

    answer = extract_answer_yesno(regenerate_response)
    info['re_answer'] = answer

    # A single groundtruth string is wrapped so it can be star-expanded.
    if isinstance(groundtruth, str):
        groundtruth = [groundtruth]
    is_correct = bool(compare_answer_with_groundtruth(answer, *groundtruth))

    if is_correct:
        verdict_path = correct_path
    else:
        print("wrong: ", index)
        verdict_path = wrong_path
    info['re_correct'] = is_correct

    json_line = json.dumps(info, ensure_ascii=False)
    for path in (sol_path, verdict_path):
        with open(path, 'a', encoding='utf-8') as f_out:
            f_out.write(json_line + '\n')

    return is_correct


def get_regenerate(batched_input_data,rank,log_dir,correct_count):
    """Worker entry point: regenerate answers for every record in the given batches.

    For each record the actor agent is asked to regenerate its answer using
    the stored feedback, the result is rephrased, the yes/no answer is
    extracted and compared with the groundtruth, and the enriched record is
    appended to ``sol_re/`` plus either ``correct_re/`` or ``wrong_re/``
    under ``log_dir``. Records whose solution file already exists are skipped.

    Args:
        batched_input_data: Iterable of batches; each batch is a sequence of
            record dicts with the keys ``question``, ``groundtruth``,
            ``index``, ``context``, ``feedback`` and ``single_log``.
        rank: Worker index; used as the progress-bar label/position and as
            the key under which the result is stored in ``correct_count``.
        log_dir: Directory under which the output sub-directories are created.
        correct_count: Mapping shared between workers; ``correct_count[rank]``
            is set to the number of newly processed records answered correctly.
    """
    # Create the output directories once instead of once per record.
    for sub_dir in ('sol_re', 'correct_re', 'wrong_re'):
        os.makedirs(f'{log_dir}/{sub_dir}', exist_ok=True)

    correct = 0
    for batch in tqdm(batched_input_data, desc=str(rank), position=rank):
        # Process every record of the batch, not just the first one, so that
        # nothing is silently dropped when the batch size is larger than 1.
        for item in batch:
            if _regenerate_item(item, log_dir):
                correct += 1

    correct_count[rank] = correct

    





if __name__ == '__main__':
    # Script entry point: load the feedback records, fan them out over worker
    # processes running get_regenerate, then report accuracy count and timing.
    args = parse_args()
    print(args)

    inputfile="/Users/zwj/multi-agent-alignment/problem_solving_pubmed/logs/actor_critic/eval/gpt-4o-mini-2024-07-18_gpt-4o-mini-2024-07-18/sol_round_0/judgement-ft:gpt-4o-mini-2024-07-18:merty::Au8mqTjo/judgement/feedback-gpt-4o-mini-2024-07-18/feedback.jsonl"
    print(inputfile)
    # Outputs go to a sibling "regenerate-<model>" directory next to the input file.
    log_dir = inputfile.replace("feedback.jsonl",f"regenerate-{actor_agent.model}")
    os.makedirs(log_dir, exist_ok=True)
    print(log_dir)

    input_datas=load_jsonl_objects(inputfile)
    batched_dataset = [input_datas[i : i + DATA_BATCH_SIZE] for i in range(0, len(input_datas), DATA_BATCH_SIZE)]
    # Never start more workers than there are batches to hand out.
    num_processes = min(32, len(batched_dataset))
    processes = []
    manager = multiprocessing.Manager()
    correct_count = manager.dict()  # rank -> number of correct answers from that worker
    # NOTE(review): the three values below are read from args but are not
    # passed to the workers in this script - confirm whether they are needed.
    prompt_type=args.prompt_type
    max_tokens=args.max_tokens
    temperature=args.temperature
    print(len(batched_dataset))
    print(actor_agent.model)
    start_time = datetime.datetime.now()
    for i in range(num_processes):
        # Strided slicing gives worker i the batches i, i+num_processes, ...
        p = multiprocessing.Process(target=get_regenerate, args=( batched_dataset[i :: num_processes],i,log_dir,correct_count))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # A worker that died never stored its count, so the total would silently
    # undercount; make such failures visible.
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        print('workers exited abnormally:', failed_ranks)

    total_correct = sum(correct_count.values())
    print(total_correct)
    end_time = datetime.datetime.now()
    elapsed_minutes = (end_time-start_time).total_seconds()/60
    # The original printed the tuple (minutes, 2) because round() was missing.
    print('time cost:', round(elapsed_minutes, 2), ' mins')
