import os
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
import pandas as pd
# Largest native int; used as an open-ended upper bound for list slicing
# (e.g. the default `end` index, meaning "through the last example").
MAX_INT = sys.maxsize
from huggingface_hub import login

def is_number(s):
    """Report whether ``s`` can be interpreted as a number.

    Two checks are tried in order:

    1. ``float(s)`` succeeds (covers ints, decimals, exponents, "nan", ...).
    2. ``unicodedata.numeric(s)`` succeeds (covers a single Unicode numeric
       character such as a vulgar fraction or a CJK numeral).

    Returns:
        True if either check succeeds, False otherwise.
    """
    import unicodedata

    try:
        float(s)
    except ValueError:
        # Not a float literal; fall through to the Unicode numeric check.
        pass
    else:
        return True

    try:
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        # TypeError: not a single character; ValueError: no numeric value.
        return False
    return True


def extract_answer_number(completion):
    """Extract the chosen option letter from a model completion.

    Only the text after the LAST occurrence of ``'The answer is: '`` is
    inspected. Within it, the first standalone uppercase letter followed by
    a dot (e.g. ``"D. Paris"`` or a bare trailing ``"D."``) is the answer.

    Args:
        completion: Generated text from the model.

    Returns:
        The single uppercase option letter (e.g. ``'D'``), or ``None`` when
        the marker is missing or no ``"<LETTER>."`` pattern follows it.
    """
    marker = 'The answer is: '
    if marker not in completion:
        return None

    tail = completion.rsplit(marker, 1)[-1].strip()
    # The dot must be followed by whitespace OR the end of the text: `tail`
    # is stripped, so an answer that ends the completion ("... D.") has no
    # trailing whitespace left for a plain `\s` to match.
    match = re.search(r'\b([A-Z])\.(?:\s|$)', tail)
    return match.group(1) if match else None

def batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive batches of at most ``batch_size``.

    Every batch except possibly the last holds exactly ``batch_size`` items;
    the last one holds the remainder. Order is preserved and no item is
    dropped or duplicated.

    Args:
        data_list: Sequence to split (anything supporting ``len`` and slicing).
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A list of slices of ``data_list``. Empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Fixed-stride slicing keeps every batch within `batch_size`; folding the
    # remainder into the final full batch could nearly double its size.
    return [
        data_list[offset:offset + batch_size]
        for offset in range(0, len(data_list), batch_size)
    ]


def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Generate completions for a JSONL benchmark, score them, and dump a CSV.

    Each JSONL record must provide a ``"query"`` (the question) and a
    ``"response"`` containing ``'#### '`` followed by the reference answer.
    The reference answer is the text after ``'#### '`` up to the first
    period, stripped. Prompts are generated with greedy decoding, the
    predicted option letter is extracted with ``extract_answer_number``, and
    the prompts, raw completions and references are written to
    ``f'{args.outdir}.csv'``.

    NOTE: besides its parameters, this function reads the module-level
    ``args`` namespace (``args.seq_len`` for ``max_tokens`` and
    ``args.outdir`` for the output file), so ``args`` must be set before
    calling it.

    Args:
        model: Model name or path handed to ``vllm.LLM``.
        data_path: Path to the JSONL data file.
        start: Index of the first example to evaluate (inclusive).
        end: Index one past the last example to evaluate.
        batch_size: Number of prompts sent to ``llm.generate`` per call.
        tensor_parallel_size: Tensor-parallel degree handed to ``vllm.LLM``.
    """
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    )
    print('promt =====', problem_prompt)

    gsm8k_ins = []
    gsm8k_answers = []
    # The file is only read, so open it read-only; "r+" would needlessly
    # fail on datasets stored without write permission.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            gsm8k_ins.append(problem_prompt.format(instruction=item["query"]))
            # Reference answer: text after '#### ', cut at the first period.
            temp_ans = item['response'].split('#### ')[1]
            gsm8k_answers.append(temp_ans.split('.')[0].strip())

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('lenght ====', len(gsm8k_ins))
    batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response",'<|end_of_text|>','</s>','<s>']
    # temperature=0 -> greedy decoding.
    sampling_params = SamplingParams(temperature=0, top_p=1.0, max_tokens=args.seq_len, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)

    print('#### Model was loaded !!!! ####')

    # Generation pass: iterate over batches only. Answers are per-example,
    # not per-batch, so they are paired with completions in the scoring
    # pass below instead.
    res_completions = []
    for batch in batch_gsm8k_ins:
        if not isinstance(batch, list):
            batch = [batch]
        if not batch:
            # Nothing to generate for an empty batch (e.g. empty slice).
            continue
        completions = llm.generate(batch, sampling_params)
        res_completions.extend(output.outputs[0].text for output in completions)

    # Scoring pass: a missing/unparseable prediction counts as wrong and is
    # tracked separately as an invalid output.
    result = []
    invalid_outputs = []
    gt = []
    for prompt, completion, prompt_answer in zip(gsm8k_ins, res_completions, gsm8k_answers):
        y_pred = extract_answer_number(completion)
        gt.append(prompt_answer)
        if y_pred is not None:
            result.append(y_pred == prompt_answer)
        else:
            result.append(False)
            invalid_outputs.append(
                {'question': prompt, 'output': completion, 'answer': prompt_answer}
            )
    # Guard the empty case (empty file or empty [start:end] slice).
    acc = sum(result) / len(result) if result else 0.0

    df = pd.DataFrame({'query': gsm8k_ins, 'answer': res_completions, 'GT': gt})
    df.to_csv(f'{args.outdir}.csv', index=False)

    print('='*50)
    print('Model:', args.outdir)
    print('Generation was finished')
    print('evaluated ====', len(result), ', invalid ====', len(invalid_outputs), ', acc ====', acc)
    print('='*50)


def parse_args():
    """Parse the evaluation script's command-line options.

    Returns:
        An ``argparse.Namespace`` with ``model``, ``outdir``, ``data_file``,
        ``start``, ``end``, ``batch_size``, ``seq_len`` and
        ``tensor_parallel_size``.
    """
    # (flag, type, default) for every supported option.
    option_specs = (
        ("--model", str, None),                 # model path
        ("--outdir", str, 'output_answer'),     # output CSV name, without the ".csv" suffix
        ("--data_file", str, ''),               # data path
        ("--start", int, 0),                    # start index
        ("--end", int, MAX_INT),                # end index
        ("--batch_size", int, 400),             # batch_size
        ("--seq_len", int, 512),                # max number of generated tokens
        ("--tensor_parallel_size", int, 4),     # tensor_parallel_size
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser.parse_args()
if __name__ == "__main__":
    # `args` must stay a module-level name: gsm8k_test reads args.seq_len
    # and args.outdir from it directly.
    args = parse_args()
    run_options = {
        "model": args.model,
        "data_path": args.data_file,
        "start": args.start,
        "end": args.end,
        "batch_size": args.batch_size,
        "tensor_parallel_size": args.tensor_parallel_size,
    }
    gsm8k_test(**run_options)

    print(f'{args.model} evaluation was finished')