import json
import time

import pandas as pd
import numpy as np
from . import common
# import common
from .drop_eval import DropEval
from .gpqa_eval import GPQAEval,GPQAPipelineEvalSimple
# from .humaneval_eval import HumanEval
from .math_eval import MathEval,MATHPipelineSTServer
from .mgsm_eval import MGSMEval
from .mmlu_eval import MMLUEval, MMLUPipelineEval,MMLUPipelineEvalSimple,MMLUPipelineEvalSimple2,MMLUPipelineSTServer
from .hellaswag_eval import HellaSwagPipelineSTServer
from .boolq_pipeline import BoolQPipelineSTServer
from .wino_pipeline import WinoPipelineSTServer
from .piqa_pipeline import PIQAPipelineSTServer
from .arc_pipeline import ARCPipelineSTServer
# from .mmlu_eval_st import MMLUPipelineSTServer
from .gsm_eval import GSMPipelineSTServer
import time
import multiprocessing
import os
import subprocess
from .csqa_pipeline import CSQAPipelineSTServer
from .socialqa_pipeline import SocialQAPipelineSTServer
from .obqa_pipeline import OBQAPipelineSTServer
# todo 8 boolq 

# 2 math 1 hellaswag 3 boolq 5 piqa
# 6 math refine
# 6 hellaswag 7 hellaswag 

# todo ablation: learn to refine answer think

# from .sampler.chat_completion_sampler import (
#     OPENAI_SYSTEM_MESSAGE_API,
#     OPENAI_SYSTEM_MESSAGE_CHATGPT,
#     ChatCompletionSampler,
# )
# from .sampler.local_sampler import LocalSampler
from .sampler.server_sampler import ServerSampler
# from .sampler.llama_sampler import LlamaCompletionSampler
import pickle
import os

# from .sampler.claude_sampler import ClaudeCompletionSampler, CLAUDE_SYSTEM_MESSAGE_LMSYS


def _write_jsonl(path, records):
    """Write every item of *records* to *path* as one JSON document per line."""
    with open(path, 'w') as jsonl_file:
        for record in records:
            jsonl_file.write(json.dumps(record) + '\n')


def _write_scored_outputs(path, questions, answers, scores):
    """Dump aligned question/answer/score triples to *path* as JSON lines.

    The row count is driven by ``answers``; ``questions`` and ``scores`` are
    indexed positionally, so a shorter sequence raises ``IndexError``.
    """
    _write_jsonl(
        path,
        (
            {'Question': questions[j], 'Answer': answer, 'Score': scores[j]}
            for j, answer in enumerate(answers)
        ),
    )


def _write_zero_shot_outputs(eval_obj, base_path, iteration):
    """Write the test and train outputs of the latest evaluation for *iteration*."""
    _write_scored_outputs(
        f'{base_path}/iter{iteration}_zs_test.jsonl',
        eval_obj.sorted_testing_examples,
        eval_obj.test_outputs,
        eval_obj.test_scores,
    )
    _write_scored_outputs(
        f'{base_path}/iter{iteration}_zs_train.jsonl',
        eval_obj.sorted_training_examples,
        eval_obj.train_outputs,
        eval_obj.train_scores,
    )


def _write_many_shot_outputs(eval_obj, base_path, iteration):
    """Write the many-shot train and test outputs of the latest evaluation."""
    _write_scored_outputs(
        f'{base_path}/iter{iteration}_ms_train.jsonl',
        eval_obj.sorted_many_shot_training_examples,
        eval_obj.many_shot_train_outputs,
        eval_obj.many_shot_train_scores,
    )
    _write_scored_outputs(
        f'{base_path}/iter{iteration}_ms_test.jsonl',
        eval_obj.sorted_many_shot_testing_examples,
        eval_obj.many_shot_test_outputs,
        eval_obj.many_shot_test_scores,
    )


def _print_score_histories(eval_obj):
    """Print the five score-history lists tracked on *eval_obj*."""
    print(eval_obj.zero_shot_train_score_history)
    print(eval_obj.zero_shot_test_score_history)
    print(eval_obj.many_shot_train_score_history)
    print(eval_obj.many_shot_test_score_history)
    print(eval_obj.sft_train_score_history)


def _dump_results(eval_obj, path):
    """Overwrite *path* with the score histories and set sizes of *eval_obj*."""
    with open(path, 'w') as json_file:
        json.dump(
            [
                eval_obj.zero_shot_train_score_history,
                eval_obj.zero_shot_test_score_history,
                eval_obj.many_shot_train_score_history,
                eval_obj.many_shot_test_score_history,
                eval_obj.sft_train_score_history,
                eval_obj.success_set_sizes,
                eval_obj.instruct_set_sizes,
            ],
            json_file,
        )


def _build_eval(eval_name, sampler, equality_sampler):
    """Instantiate the pipeline registered under *eval_name*.

    Raises:
        ValueError: if *eval_name* is not one of the known pipeline names.
    """
    if eval_name == 'mmlu_pipeline':
        return MMLUPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'math_pipeline':
        return MATHPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'gsm_pipeline':
        return GSMPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'hellaswag_pipeline':
        return HellaSwagPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'wino_pipeline':
        return WinoPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'boolq_pipeline':
        return BoolQPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'piqa_pipeline':
        return PIQAPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'arc_pipeline':
        return ARCPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'obqa_pipeline':
        return OBQAPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'csqa_pipeline':
        return CSQAPipelineSTServer(sampler, equality_sampler)
    elif eval_name == 'socialqa_pipeline':
        return SocialQAPipelineSTServer(sampler, equality_sampler)
    # The exception used to be constructed but never raised, so an unknown
    # name silently produced None and failed much later.
    raise ValueError(f"Unrecoginized eval type: {eval_name}")


def _run_training(sft_data_path, model_load_path, model_save_path, log_path):
    """Run ``./optmization.py`` through ``accelerate launch`` and stream its output.

    Raises:
        subprocess.CalledProcessError: if the training process exits non-zero.
    """
    command = [
        "accelerate", "launch",
        "./optmization.py",
        "--sft_data", sft_data_path,
        "--model_save_path", model_save_path,
        "--model_load_path", model_load_path,
        "--logging_path", log_path,
    ]
    print(command)
    # The context manager closes the pipe and waits for the child on exit.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    ) as process:
        for line in process.stdout:
            print(line, end='')
    # Fail fast: continuing would restart the sampler on a checkpoint that the
    # failed run may never have written.
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)


def main():
    """Run the iterative generate / fine-tune / re-evaluate loop.

    Starts a ``ServerSampler`` on the base model, evaluates the
    ``math_pipeline`` eval once, then for ``num_iterations`` rounds:
    generates training data, stops the sampler, runs the training script on
    the dumped ``iter<i>_sft.jsonl``, restarts the sampler on the new
    checkpoint, re-evaluates, and writes all outputs and score histories
    under ``base_path``.
    """
    base_model_path = "./models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa"
    num_iterations = 50

    multiprocessing.set_start_method('spawn')
    sampler = ServerSampler(
        model=base_model_path, port=8500, device=0, num_processes=8
    )
    # No separate equality-checking sampler is configured in this run.
    equality_sampler = None

    evals = {
        eval_name: _build_eval(eval_name, sampler, equality_sampler)
        for eval_name in ["math_pipeline"]
    }
    print(evals)

    for eval_obj in evals.values():
        model_load_path = base_model_path
        base_path = './R-COT/logs/math_strict3_ours_long2000'
        model_save_path = base_path + '/iter0'
        log_path = base_path
        os.makedirs(log_path, exist_ok=True)

        # Baseline evaluation before any training round.
        eval_obj.evaluation()
        _write_zero_shot_outputs(eval_obj, base_path, 0)

        for i in range(num_iterations):
            eval_obj.data_generation_simplified_triple(zero_shot=False)
            _write_jsonl(
                f'{base_path}/iter{i}_reflection.jsonl',
                (
                    {
                        'Dialoge': eval_obj.training_data[-1][j],
                        'Score': eval_obj.training_data[2][j],
                    }
                    for j in range(len(eval_obj.training_data[0]))
                ),
            )

            # Stop the sampler(s) before launching the training process.
            sampler.kill_process()
            if equality_sampler is not None:
                equality_sampler.kill_process()

            sft_data_path = f'{base_path}/iter{i}_sft.jsonl'
            _write_jsonl(sft_data_path, eval_obj.sft_data)
            _run_training(sft_data_path, model_load_path, model_save_path, log_path)

            time.sleep(10)
            _print_score_histories(eval_obj)
            _dump_results(eval_obj, base_path + '/111_result.json')
            _write_jsonl(f'{base_path}/iter{i}.jsonl', eval_obj.instruction_list)

            # Restart the sampler on the checkpoint that was just written.
            sampler.start_process(model_save_path)
            if equality_sampler is not None:
                equality_sampler.start_process(base_model_path)

            eval_obj.evaluation()
            _write_zero_shot_outputs(eval_obj, base_path, i + 1)
            _write_many_shot_outputs(eval_obj, base_path, i + 1)

            # The next round loads this round's checkpoint and saves to a new one.
            model_load_path = model_save_path
            model_save_path = f'{base_path}/iter{i + 1}'

            _print_score_histories(eval_obj)
            for _ in range(3):
                print('*' * 30)
            _dump_results(eval_obj, base_path + '/111_result.json')
            with open(f'{base_path}/iter{i}.json', 'w') as json_file:
                json.dump(eval_obj.instruction_list, json_file)


# Run the pipeline only when this module is executed as the entry point,
# not when it is imported.
if __name__ == "__main__":
    main()
