import json
import torch
import os
import logging
import csv
import tqdm
import collections
import re
from llm import *
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import prompts_tomi
import argparse

# Restrict the visible GPUs to devices 2 and 3; main() sizes vLLM tensor
# parallelism from torch.cuda.device_count().
os.environ.update(CUDA_VISIBLE_DEVICES="2,3")

def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real command
            line, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Run direct generation for various datasets and models.")

    parser.add_argument(
        '--dataset_name',
        type=str,
        required=True,
        choices=['test', 'hi_tom', 'tomi', 'gpqa', 'math500', 'aime', 'amc', 'livecode', 'nq', 'triviaqa', 'hotpotqa', '2wiki', 'musique', 'bamboogle', 'medmcqa', 'pubhealth'],
        help="Name of the dataset to use."
    )

    parser.add_argument(
        '--split',
        type=str,
        choices=['test', 'diamond', 'main', 'extended'],
        help="Dataset split to use."
    )

    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help="Path to the pre-trained model."
    )

    parser.add_argument(
        '--model_name',
        type=str,
        required=True,
        help="Name of the pre-trained model."
    )

    parser.add_argument(
        '--temperature',
        type=float,
        default=0.0,
        help="Sampling temperature."
    )

    # The old help text promised a model-based default, but default=1.2 means
    # a value is always set; document the real default instead.
    parser.add_argument(
        '--repetition_penalty',
        type=float,
        default=1.2,
        help="Repetition penalty applied during generation (default: 1.2)."
    )

    # Likewise, the default is a fixed 4096, not model/dataset dependent.
    parser.add_argument(
        '--max_tokens',
        type=int,
        default=4096,
        help="Maximum number of tokens to generate per call (default: 4096)."
    )

    parser.add_argument('--num_probs', '-n', type=int, default=300)
    parser.add_argument('--method', type=str, default='baseline')
    # --wandb/--tags/--category/--project/--entity are parsed but not read by
    # main() in this file.
    parser.add_argument('--wandb', type=int, default=0)
    parser.add_argument('--tags', type=str, default="debug")
    parser.add_argument('--category', type=str, default='all')
    parser.add_argument('--project', type=str, default=None)
    parser.add_argument('--entity', type=str, default=None)

    return parser.parse_args(argv)

# Instruction prefix for the selector call in main() that asks the model which
# story sentence number's state prompt it needs (reply expected after "Answer:").
hint = (
    "The following reasoning has encountered difficulties. "
    "Now each story statement sequence number has a current state prompt. "
    "Please select the required prompt sequence number based on the following "
    "story plot and reasoning trajectory. "
    "Don't answer the question and only give the sequence number. "
    "The answer format is **Answer:**"
)

def split_story_by_lines(story_text):
    """Break a story into its individual non-blank lines.

    The text is split on newline characters; every piece is stripped of
    surrounding whitespace, and pieces that end up empty are dropped.

    Args:
        story_text: The full story as a single string.

    Returns:
        A list of the stripped, non-empty lines in their original order.
    """
    trimmed = (piece.strip() for piece in story_text.strip().split('\n'))
    return [piece for piece in trimmed if piece]

def _resolve_data_path(dataset_name, split):
    """Return the JSON data file path for ``dataset_name``.

    ``split`` is only interpolated for the per-split datasets (math500, gpqa,
    aime, amc, livecode); for those a ``None`` split yields ``None.json``.

    Raises:
        ValueError: if ``dataset_name`` has no known data file (this includes
            ``'tomi'``, which the CLI accepts but which has no path here).
    """
    per_split_dirs = {
        'math500': 'MATH500',
        'gpqa': 'GPQA',
        'aime': 'AIME',
        'amc': 'AMC',
        'livecode': 'LiveCodeBench',
    }
    qa_datasets = {'nq', 'triviaqa', 'hotpotqa', 'musique', 'bamboogle', '2wiki', 'medmcqa', 'pubhealth'}
    if dataset_name in per_split_dirs:
        return f'./data/{per_split_dirs[dataset_name]}/{split}.json'
    if dataset_name in qa_datasets:
        return f'./data/QA_Datasets/{dataset_name}.json'
    if dataset_name == 'hi_tom':
        return './data/Hi-ToM_data.json'
    if dataset_name == 'test':
        return './data/hi-tom_test.json'
    raise ValueError(f"Unsupported dataset_name: {dataset_name}")


def _extract_answer(text, fallback):
    """Return the stripped text after the last ``"Answer:"`` in ``text``, else ``fallback``."""
    if "Answer:" in text:
        return text.split("Answer:")[-1].strip()
    return fallback


def _generate_text(llm, prompt, sampling_params):
    """Run one vLLM generation and return the text of its first completion."""
    return llm.generate(prompt, sampling_params=sampling_params)[0].outputs[0].text


def main():
    """Evaluate a model on Hi-ToM with stop-word-triggered state-hint re-prompting.

    For each dataset item whose ``sample_id + 1`` does not exceed ``--num_probs``:

    1. Generate a baseline completion from the chat-templated Hi-ToM prompt
       and extract ``answer_1``.
    2. While the latest completion contains any word in ``stop_word`` (at most
       ``top`` rounds): ask the model which story sentence number it needs,
       look up that sentence's state prompt in ``prompt_path``, and regenerate
       from the original prompt + the reasoning prefix ending at the matched
       stop word + the state prompt, extracting ``answer_2``.
    3. Ask the same model to judge ``answer_1`` (and ``answer_2`` if any
       re-prompt round ran) against the gold answer.

    Per-item rows are written to ``multistep_DsQwen32b_alternative_hitom.csv``
    and details are logged to a file named after ``--model_name``.

    Overall accuracy is printed and logged. Each item is scored on
    ``answer_2`` when it was re-prompted, else on ``answer_1``.
    """
    args = parse_args()
    model_realname = args.model_name.replace("/", "-")
    logging.basicConfig(
        filename=f"log_{model_realname}-alternative_multistep_HiToM.log",
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    category_results = {}  # "count_<story_length>_<question_order>" -> {"correct": int, "total": int}
    category_percents = {}  # "<story_length>_<question_order>" -> accuracy so far
    correctNum = 0
    totalNum = 0
    # Baseline (an1) vs. re-prompted (an2) correctness. These are only counted
    # for items where at least one re-prompt round ran.
    corr_an1_corr_an2 = 0
    corr_an1_wrong_an2 = 0
    wrong_an1_wrong_an2 = 0
    wrong_an1_corr_an2 = 0
    top = 3  # maximum number of re-prompt rounds per item
    dataset_name = args.dataset_name
    split = args.split
    model_path = args.model_path
    temperature = args.temperature
    repetition_penalty = args.repetition_penalty
    # Honour --max_tokens everywhere. Previously it was overwritten with 4096
    # for every item, which silently ignored the CLI flag.
    max_tokens = args.max_tokens

    print("\n------------------------")
    print("    EVALUATING HiToM      ")
    print("------------------------")
    print(f"EVAL MODEL: {args.model_name}")
    print(f"DATA: {args.dataset_name}")
    print(f"METHOD: {args.method}")
    print(f"N = {args.num_probs}")
    print("------------------------\n")
    logging.info("------------------------")
    logging.info("    EVALUATING HiToM      ")
    logging.info("------------------------")
    logging.info(f"EVAL MODEL: {args.model_name}")
    logging.info(f"DATA: {args.dataset_name}")
    logging.info(f"METHOD: {args.method}")
    logging.info(f"N = {args.num_probs}")

    # Model-based fallback; only used if --repetition_penalty arrives as None.
    if repetition_penalty is None:
        lowered_path = model_path.lower()
        repetition_penalty = 1.05 if ('qwq' in lowered_path or 'deepseek' in lowered_path or 'sky-t1' in lowered_path) else 1.0

    data_path = _resolve_data_path(dataset_name, split)
    # Entries in pros["data"] carry a "sample_id" plus "prompt_<i>" state
    # descriptions, looked up below by the selected sentence index.
    prompt_path = './prompts/hi_tom/dsqw32.json'

    # Load the model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    llm = LLM(
        model=model_path,
        tensor_parallel_size=torch.cuda.device_count(),
        max_model_len=30000,
        gpu_memory_utilization=0.95,
    )

    # Reasoning "pivot" words. Any occurrence (plain substring match) in a
    # completion triggers a re-prompt round.
    stop_word = ["instead", "alternative", "another", "alternatively", "however", "while", "yet", "though", "rather", "otherwise", "on the other hand"]
    option_pattern = re.compile(r'[A-O]\. ([^,]+)')
    # The settings are identical for every call, so build each SamplingParams once.
    generation_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Judge calls are greedy and do not pass a repetition_penalty.
    judge_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=0.0,
    )

    # Load both inputs before truncating the output CSV, so a missing or
    # broken input file does not wipe a previous result file.
    with open(prompt_path, 'r', encoding='utf-8') as p_in:
        pros = json.load(p_in)
    with open(data_path, encoding='utf-8') as f_in:
        data = json.load(f_in)

    with open('multistep_DsQwen32b_alternative_hitom.csv', mode='w', newline='', encoding='utf-8') as outp:
        writer = csv.writer(outp)
        writer.writerow(['num', 'token1', 'token2', 'judge1', 'judge2', 'step'])
        for item in tqdm.tqdm(data["data"]):
            story_length = item["story_length"]
            question_order = item["question_order"]
            sample_id = item["sample_id"]
            story = item["story"]
            question = item["question"]
            choices = item["choices"]
            answer = item["answer"]
            options = option_pattern.findall(choices)
            reprompted = False  # True once any re-prompt round has run for this item
            reprompt_chars = 0  # total characters over all re-prompted completions
            count = 0  # number of re-prompt rounds run for this item

            # Stops at the first item past --num_probs.
            # NOTE(review): assumes items are ordered by sample_id; confirm.
            if (sample_id + 1) > args.num_probs:
                break

            # Fills containers_0..containers_14. Raises IndexError if fewer
            # than 15 options were parsed, exactly like the explicit argument list did.
            containers = {f"containers_{i}": options[i] for i in range(15)}
            baselinePrompt = prompts_tomi.Hi_tom_baselinePrompt.format(story=story, question=question, **containers)
            prompts = tokenizer.apply_chat_template(
                [{"role": "user", "content": baselinePrompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            sentence_list = split_story_by_lines(story)
            last = len(sentence_list) - 1  # index of the final story sentence

            output_text = _generate_text(llm, prompts, generation_params)
            len_1 = len(output_text)  # character count (CSV column 'token1')
            answer_1 = _extract_answer(output_text, "Overthinking")
            logging.info(f"Answer_1: {answer_1}")
            len_2 = len_1
            for round_idx in range(top):
                matched = [word for word in stop_word if word in output_text]
                for word in matched:
                    logging.info(f"WRONG: {word}")
                if not matched:
                    break
                reprompted = True
                # First match in stop_word list order, not earliest in the text.
                stops = matched[0]
                logging.info(f"Number: {round_idx}")
                count += 1
                # Reasoning prefix up to and including the first occurrence of the stop word.
                question_part = output_text.split(stops)[0] + stops

                # Selector call (raw text, no chat template): which state prompt is needed?
                prompt_find = hint + "\nStory:\n" + story + "\n" + question_part
                selector_text = _generate_text(llm, prompt_find, generation_params)
                logging.info(f"output_2: {selector_text}")
                answer_xu = _extract_answer(selector_text, "overthinking")
                if re.search(r'\d+', answer_xu):
                    number_strings = re.findall(r'-?\d+', answer_xu)
                else:
                    number_strings = [last]
                integers = int(number_strings[0])
                logging.info(f"select_hint: {integers}")
                integers -= 1  # reply is treated as a 1-based sentence number
                logging.info(f"sample_id:{sample_id}")
                logging.info(f"int:{integers}")
                logging.info(f"last:{last}")
                # Out-of-range selections fall back to the final sentence's state.
                if integers >= last or integers < 0:
                    integers = last
                sample_id_prompt = next((entry[f"prompt_{integers}"] for entry in pros["data"] if entry["sample_id"] == sample_id), "Prompt_2 not found!")
                # Continue from the ORIGINAL prompt (not the previous re-prompt)
                # plus the current reasoning prefix and the selected state description.
                prompt_2 = prompts + question_part + f". Wait, perhaps I'm over thinking. This is the state description after story sequence number {integers + 1}. Here I can use some of the following information: \n" + sample_id_prompt
                output_text = _generate_text(llm, prompt_2, generation_params)
                reprompt_chars += len(output_text)
                answer_2 = _extract_answer(output_text, "Overthinking")
                logging.info(f"Answer_res: {answer_2}")

            # The whitespace inside this literal is part of the judge prompt; keep it as-is.
            prompt_an1 = f"""\
        [Question: {question}.]

        ***[Response Answer: {answer_1}]***

        ***[Correct Answer: {answer}]***

        Only based on the ***[Correct Answer]***, judge whether the ***[Response Answer]*** is correct. If the two answers are same or have the difference in format only, output 'True' only. Else 'False' only. Don't speak other words. 
        """
            judge_an1 = _generate_text(llm, prompt_an1, judge_params)
            logging.info(f"Answer_1: {answer_1}")
            logging.info(f"Answer: {answer}")
            logging.info(f"judge_an1: {judge_an1}")
            correct_an1 = "True" in judge_an1

            if reprompted:
                logging.info(f"count: {count}")
                len_2 = reprompt_chars // count  # mean characters per re-prompted completion
                # The whitespace inside this literal is part of the judge prompt; keep it as-is.
                prompt_an2 = f"""\
                    [Question: {question}.]

                    ***[Response Answer: {answer_2}]***

                    ***[Correct Answer: {answer}]***

                    Only based on the ***[Correct Answer]***, judge whether the ***[Response Answer]*** is correct. If the two answers are same or have the difference in format only, output 'True' only. Else 'False' only. Don't speak other words. 
                    """
                judge_an2 = _generate_text(llm, prompt_an2, judge_params)
                correct_an2 = "True" in judge_an2
                if correct_an1 and correct_an2:
                    corr_an1_corr_an2 += 1
                elif correct_an1:
                    corr_an1_wrong_an2 += 1
                elif correct_an2:
                    wrong_an1_corr_an2 += 1
                else:
                    wrong_an1_wrong_an2 += 1
                logging.info(f"Answer_1: {answer_1}")
                logging.info(f"Answer_2: {answer_2}")
                logging.info(f"Answer: {answer}")
                logging.info(f"correct_an1: {correct_an1}")
                logging.info(f"correct_an2: {correct_an2}")
                logging.info(f"corr_an1_corr_an2: {corr_an1_corr_an2}")
                logging.info(f"corr_an1_wrong_an2: {corr_an1_wrong_an2}")
                logging.info(f"wrong_an1_wrong_an2: {wrong_an1_wrong_an2}")
                logging.info(f"wrong_an1_corr_an2: {wrong_an1_corr_an2}")

            logging.info(f"num: {sample_id + 1}")
            logging.info(f"len_1: {len_1}")
            logging.info(f"len_2: {len_2}")
            logging.info(f"correct_an1: {correct_an1}")

            # The item is scored on the re-prompted answer when one exists.
            if reprompted:
                writer.writerow([sample_id + 1, len_1, len_2, correct_an1, correct_an2, count])
                logging.info(f"correct_an2: {correct_an2}")
                correct = correct_an2
            else:
                writer.writerow([sample_id + 1, len_1, len_2, correct_an1, "None", count])
                logging.info("correct_an2: None")
                correct = correct_an1
            if correct:
                correctNum += 1
            totalNum += 1

            print(f"\n### Correct: {correct} ###\n")

            logging.info(f"Index: {sample_id}")
            logging.info(f"Story: {story}")
            logging.info(f"Question: {question}")
            logging.info(f"Label: {answer}")
            logging.info(f"**********Correct**********: {correct}")
            logging.info(f"Story_Type: {story_length}")
            logging.info(f"Question_Type: {question_order}")
            logging.info("-------------------------------------------")
            logging.info("-------------------------------------------")

            # Per-category running accuracy.
            category_key = str(story_length) + "_" + str(question_order)
            stats = category_results.setdefault("count_" + category_key, {"correct": 0, "total": 0})
            if correct:
                stats["correct"] += 1
            stats["total"] += 1
            category_percents[category_key] = stats["correct"] / stats["total"]

    # Guard against ZeroDivisionError when nothing was evaluated (e.g. --num_probs 0).
    accuracy = correctNum / totalNum if totalNum else 0.0
    print(correctNum)
    print(totalNum)
    print(f"Accuracy: {accuracy*100:.3f}%")

    logging.info(f"correctNum: {correctNum}")
    logging.info(f"corr_an1_corr_an2: {corr_an1_corr_an2}")
    logging.info(f"corr_an1_wrong_an2: {corr_an1_wrong_an2}")
    logging.info(f"wrong_an1_wrong_an2: {wrong_an1_wrong_an2}")
    logging.info(f"wrong_an1_corr_an2: {wrong_an1_corr_an2}")
    logging.info(f"category_results: {category_results}")
    logging.info(f"category_percents: {category_percents}")
    logging.info(f"Accuracy: {accuracy*100:.3f}%")

    print("\n------------------------")
    print("         RESULTS        ")
    print("------------------------")
    print(f"ACCURACY: {accuracy:.2%}")
    print("------------------------\n")
    

if __name__ == '__main__':
    # Script entry point; importing this module does not start an evaluation.
    main()
