from openai import OpenAI
import openai

import re
import os
import time
import datetime
import json
import pandas as pd

from args import parse_args
from prompt import *
from libs.data_loader import load_jsonl_objects
from Actor_Critic.agent import Agent, critic_agent
import multiprocessing
import os
from tqdm import tqdm
# Number of input records grouped into each batch when the main block chunks
# the loaded dataset before distributing it across worker processes.
DATA_BATCH_SIZE=1

def get_feedback(batched_input_data,rank,log_dir):
    """Ask the critic agent for feedback on each problem in this worker's shard.

    For every record, a prompt is built from the record's question, context and
    original response, sent to ``critic_agent``, and the record plus the
    critic's reply is written to ``{log_dir}/feedback/{index}_feedback.jsonl``.
    Records whose output file already exists are skipped, so re-running the
    script only processes the problems that are still missing.

    Args:
        batched_input_data: Iterable of batches; each batch is a sequence of
            record dicts providing ``index``, ``question``, ``context`` and
            ``single_log['messages'][2]['content']``.
        rank: Integer worker index; used as the progress-bar label and position.
        log_dir: Directory under which the ``feedback/`` folder is created.

    Returns:
        None. Results are written to disk.
    """
    # Loop-invariant setup: output directory and prompt templates.
    feedback_dir = f'{log_dir}/feedback/'
    os.makedirs(os.path.dirname(feedback_dir), exist_ok=True)
    critic_users_promt = critic_users_promt_single_textgrad_pubmed
    critic_sys_promt = critic_sys_prompt_single_textgrad_pubmed

    for batch in tqdm(batched_input_data, desc=str(rank), position=rank):
        # Handle every record in the batch, not just the first one, so that
        # batch sizes larger than 1 do not silently drop problems.
        for item in batch:
            index = item['index']
            feedback_path = f'{log_dir}/feedback/{index}_feedback.jsonl'

            if os.path.exists(feedback_path):
                print(f"Problems {index} feedback exist")
                continue

            print(index)
            question = item['question']
            # Message index 2 of the earlier log is used as the response to
            # critique (presumably the assistant turn -- confirm against the
            # agent's log format).
            original_response = item['single_log']['messages'][2]['content']
            context = item['context']

            user_prompt = critic_users_promt.format(
                question=question,
                context=context,
                original_response=original_response,
            )
            feedback_agent_log = critic_agent.call_agent(
                sys_prompt=critic_sys_promt,
                user_prompt=user_prompt,
                temperature=0.,
                max_tokens=4096,
                stop=None,
                n=1,
            )

            feedback = item.copy()
            feedback['feedback'] = feedback_agent_log['messages'][2]['content']
            feedback["feedback_agent_log"] = feedback_agent_log

            # Serialize before touching the filesystem and publish the file
            # with an atomic rename: the exists-check above treats any file at
            # feedback_path as "done", so a partial or empty file left by a
            # failure would otherwise be skipped on every later run.
            serialized = json.dumps(feedback) + '\n'
            tmp_path = f'{feedback_path}.{os.getpid()}.tmp'
            with open(tmp_path, 'w', encoding='utf-8') as f_feedback:
                f_feedback.write(serialized)
            os.replace(tmp_path, feedback_path)
    


if __name__ == '__main__':
    # Entry point: load the judged-False records, split them into batches,
    # distribute the batches round-robin over worker processes running
    # get_feedback, and report the elapsed wall-clock time.
    args = parse_args()
    # The parsed args are only echoed here; nothing else in this block reads them.
    print(args)
    print(critic_agent.model)

    inputfile="/Users/zwj/multi-agent-alignment/problem_solving_pubmed/logs/actor_critic/eval/gpt-4o-mini-2024-07-18_gpt-4o-mini-2024-07-18/sol_round_0/judgement-ft:gpt-4o-mini-2024-07-18:merty::Au8mqTjo/judgement/False.jsonl"
    print(inputfile)
    # Outputs go to a sibling "feedback-<critic model>" directory next to the input file.
    log_dir = inputfile.replace('False.jsonl', f'feedback-{critic_agent.model}')
    os.makedirs(log_dir, exist_ok=True)

    input_datas=load_jsonl_objects(inputfile)
    num_processes = 32
    batched_dataset = [
        input_datas[i : i + DATA_BATCH_SIZE]
        for i in range(0, len(input_datas), DATA_BATCH_SIZE)
    ]
    print(len(batched_dataset))

    # Never start more workers than there are batches to hand out.
    worker_count = min(num_processes, len(batched_dataset))

    start_time = datetime.datetime.now()
    processes = []
    for rank in range(worker_count):
        # Worker `rank` takes every worker_count-th batch starting at `rank`.
        p = multiprocessing.Process(
            target=get_feedback,
            args=(batched_dataset[rank::worker_count], rank, log_dir),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Report workers that died, since a crashed worker leaves the rest of its
    # shard unprocessed; re-running the script picks up the missing problems.
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        print('workers exited abnormally (rank, exitcode):', failed)

    end_time = datetime.datetime.now()
    elapsed_minutes = (end_time - start_time).total_seconds() / 60
    # The original printed the tuple (minutes, 2) because round() was missing.
    print('time cost:', round(elapsed_minutes, 2), ' mins')
