import argparse
import json
import pdb
import jsonlines
import os
from pathlib import Path

from vllm import LLM, SamplingParams
import sys
import cl_llm.eval.utils as util
# Effectively "no upper bound" when used as a slice end index.
MAX_INT: int = sys.maxsize
# Placeholder string for an answer that could not be extracted.
INVALID_ANS: str = "[invalid]"

# Collects records for completions that lack a "final answer:" marker.
invalid_outputs: list = list()
def remove_boxed(s):
    r"""Strip a surrounding ``\boxed{...}`` wrapper from *s*.

    Args:
        s: A string such as ``"\boxed{42}"``. Any other value (including
            ``None``) is accepted and yields ``None``.

    Returns:
        The text between the leading ``\boxed{`` and the final ``}``, or
        ``None`` when *s* is not a string of exactly that form.
    """
    left = "\\boxed{"
    # Explicit checks rather than ``assert`` inside a bare ``except``: asserts
    # are stripped under ``python -O``, which would return a wrong slice.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]

def process_results(doc, completion, answer):
    r"""Grade one completion against its reference answer.

    The prediction is the text after the last "final answer:" marker (matched
    case-insensitively), cut at the first ".\n", stripped, with one trailing
    period removed, and then compared with ``util.is_equiv``.

    Args:
        doc: The prompt of this sample; only recorded for diagnostics.
        completion: The generated text.
        answer: The reference answer (``None`` if it could not be extracted
            upstream).

    Returns:
        True if the prediction is judged equivalent to *answer*, else False.
        A completion without the marker returns False, and a record with keys
        ``'question'``, ``'output'`` and ``'answer'`` is appended to the
        module-level ``invalid_outputs`` list.
    """
    import re

    # Match the marker case-insensitively but keep the prediction's original
    # casing. Lower-casing the whole completion also lower-cased the extracted
    # answer (e.g. \Delta -> \delta, \text{C} -> \text{c}) while the reference
    # answer kept its case, so such answers could never match.
    split_ans = re.split(r"final answer:", completion, flags=re.IGNORECASE)
    if len(split_ans) <= 1:
        invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
        return False

    # Keep only the first sentence/line of what follows the last marker.
    extract_ans = split_ans[-1].strip().split('.\n')[0].strip()
    if extract_ans.endswith('.'):
        extract_ans = extract_ans[:-1]
    extract_ans = extract_ans.strip()
    return bool(util.is_equiv(extract_ans, answer))
def batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive chunks of at most *batch_size* items.

    Every chunk except possibly the last has exactly *batch_size* items; the
    last chunk holds the remainder. An empty input yields no chunks.

    Args:
        data_list: A sliceable sequence (e.g. a list of prompts).
        batch_size: Maximum chunk length; must be positive.

    Returns:
        A list of slices of *data_list*, in their original order.

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Plain fixed-size chunking: the remainder gets its own batch instead of
    # being merged into (and oversizing) the previous one.
    return [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]

def test_hendrycks_math(model, tokenizer, data_path, start, end, batch_size, scores, process_idx, output_dir, tensor_parallel_size=1):
    """Generate and grade vLLM completions for a slice of a MATH-style dataset.

    Reads a JSON-lines file whose records have ``instruction`` and ``output``
    fields and keeps records ``[start:end]``. It generates greedily
    (temperature 0) with vLLM and grades each completion with
    ``process_results``. Per-sample records are written to
    ``<output_dir>/math_results_<process_idx>.jsonl`` and the accuracy is
    printed.

    Args:
        model: Model path/name passed to ``vllm.LLM``.
        tokenizer: Tokenizer path/name passed to ``vllm.LLM`` (may be None).
        data_path: Path to the JSON-lines dataset.
        start: Slice start applied to the dataset.
        end: Slice end (exclusive) applied to the dataset.
        batch_size: Number of prompts per ``llm.generate`` call.
        scores: Shared list receiving the accuracy at ``scores[process_idx]``,
            or None in single-process mode.
        process_idx: Worker index, also used as the visible GPU id; None in
            single-process mode.
        output_dir: Directory for the results file (created if missing).
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
    """
    # Pin the worker to its GPU before vLLM initialises CUDA. Compare against
    # None: worker 0 is falsy and would otherwise keep every GPU visible.
    if process_idx is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(process_idx)

    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)

    hendrycks_math_ins = []
    hendrycks_math_answers = []
    # The dataset is only read, so open it read-only ("r+" fails on
    # read-only files).
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            hendrycks_math_ins.append(problem_prompt.format(instruction=item["instruction"]))
            # Reference answer: presumably the last \boxed{...} of the solution.
            hendrycks_math_answers.append(remove_boxed(util.last_boxed_only_string(item['output'])))

    print('total length ===', len(hendrycks_math_ins))
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('lenght ====', len(hendrycks_math_ins))
    batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)

    # Stop as soon as the model starts a new instruction/response turn.
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size, tokenizer=tokenizer)

    res_completions = []
    for batch in batch_hendrycks_math_ins:
        if not isinstance(batch, list):
            batch = [batch]
        if not batch:
            # Nothing to generate for an empty batch.
            continue
        completions = llm.generate(batch, sampling_params)
        # Assumes one output per prompt, in prompt order, so completions line
        # up with hendrycks_math_ins / hendrycks_math_answers below.
        res_completions.extend(output.outputs[0].text for output in completions)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_path = output_dir / f"math_results_{process_idx}.jsonl"
    results = []
    output_data = []
    for idx, (prompt, completion, prompt_answer) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)):
        res = process_results(prompt, completion, prompt_answer)
        results.append(res)
        output_data.append({
            "idx": idx,
            "prompt": prompt,
            "completion": completion,
            "answer": prompt_answer,
            "result": res
        })

    with jsonlines.open(file_path, 'w') as f:
        f.write_all(output_data)

    # An empty [start:end] slice would otherwise raise ZeroDivisionError.
    acc = sum(results) / len(results) if results else 0.0
    if scores is not None:
        scores[process_idx] = acc

    print(f'Process {process_idx} finished Acc={acc}: Used {len(results)}/{len(hendrycks_math_ins)} samples) [{start}, {end}[')

def parse_args():
    """Parse command-line options for the MATH evaluation script.

    Returns:
        argparse.Namespace with fields model, data_file, start, end,
        batch_size, tensor_parallel_size, tokenizer, world_size and
        output_dir.
    """
    parser = argparse.ArgumentParser()
    # Required: callers use string methods on it (args.model.endswith), and the
    # previous int default of 0 crashed there with AttributeError.
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument("--data_file", type=str, default='data/MATH_test.jsonl', help="data path")
    parser.add_argument("--start", type=int, default=0, help="start index")
    parser.add_argument("--end", type=int, default=MAX_INT, help="end index (exclusive)")
    parser.add_argument("--batch_size", type=int, default=50, help="prompts per generate call")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="tensor_parallel_size")

    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument("--world_size", type=int, default=1)
    # Was ``required=None`` (a typo); None means "derive a default location".
    parser.add_argument("--output_dir", type=str, default=None)

    return parser.parse_args()


import multiprocessing

if __name__ == "__main__":
    args = parse_args()

    merged_suffix = "_merged"
    if args.model.endswith(merged_suffix) and args.tokenizer is None:
        # A "<base>_merged" checkpoint uses the tokenizer stored at "<base>".
        # Strip only the trailing suffix: str.replace would also remove
        # "_merged" from parent directories in the path.
        args.tokenizer = args.model[:-len(merged_suffix)]
    if args.output_dir is None:
        # Default: $WS_PATH/exp_results/<model name without "_merged">/math
        ws = os.environ.get('WS_PATH')
        if ws is None:
            sys.exit("WS_PATH must be set when --output_dir is not given")
        args.output_dir = Path(ws) / "exp_results" / Path(args.model).name.replace("_merged", "") / "math"

    if args.world_size > 1:
        # One worker per GPU, each scoring a contiguous shard and storing its
        # accuracy in a shared list.
        manager = multiprocessing.Manager()
        scores = manager.list([0] * args.world_size)

        # NOTE(review): the shard plan assumes exactly 5000 examples
        # (presumably the MATH test split) and ignores --start/--end.
        test_size = 5000
        process_data_size = test_size // args.world_size
        processes = []
        for process_idx in range(args.world_size):
            start_idx = process_idx * process_data_size
            # The last worker also takes the remainder of the division.
            end_idx = test_size if process_idx == args.world_size - 1 else start_idx + process_data_size
            print(f"Process {process_idx} will process {start_idx} to {end_idx}")
            p = multiprocessing.Process(target=test_hendrycks_math, args=(args.model, args.tokenizer, args.data_file, start_idx, end_idx, args.batch_size, scores, process_idx, args.output_dir))
            processes.append(p)

        for p in processes:
            p.start()

        # Wait for all processes to finish
        print('Waiting for everyone to finish')
        for p in processes:
            p.join()

        # A crashed worker leaves its score at the initial 0, which would
        # silently corrupt the average; fail loudly instead.
        failed = [idx for idx, p in enumerate(processes) if p.exitcode != 0]
        if failed:
            sys.exit(f"Worker process(es) {failed} failed; partial scores: {list(scores)}")

        # Print the result (unweighted mean of per-worker accuracies)
        print("Scores:", scores)
        print(f"Final score: {sum(scores) / len(scores)}")
    else:
        test_hendrycks_math(model=args.model, tokenizer=args.tokenizer, data_path=args.data_file, start=args.start,
                            end=args.end, batch_size=args.batch_size,
                            scores=None,
                            process_idx=None,
                            output_dir=args.output_dir,
                            tensor_parallel_size=args.tensor_parallel_size)


