import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results
from vllm import LLM, SamplingParams
from openai import OpenAI
import jsonlines
from vllm_inject import sequence, sample_output, model_runner, llm_engine

# Single-character tokens (besides digit runs) that count as part of an equation side.
_EQUATION_SYMBOLS = frozenset(",$€+-x*/")


def _is_equation_token(token):
    """Return True if `token` is a run of digits or exactly one equation symbol."""
    # Set membership (not substring search in a string) so that the empty token
    # '' and multi-character pieces such as ',$' are not mistaken for symbols.
    return token.isdigit() or token in _EQUATION_SYMBOLS


def get_equation_lhs_rhs_indices(tokens):
    """
    Returns two lists of indices, one for tokens in the LHS of equations and one for those in the RHS.

    For every '=' token, the scan walks outwards in both directions and collects
    the indices of the contiguous tokens that are digit strings or one of the
    symbols , $ € + - x * /. The scan stops at the first other token (including
    another '=') or at either end of the list.

    Args:
        tokens: list of str

    Returns:
        (lhs_idx, rhs_idx): two lists of int. Within one equation the LHS
        indices are appended right-to-left (closest to '=' first) and the RHS
        indices left-to-right. Equations are handled in order of appearance, so
        in chains such as `1 = 1 = 1` an index can appear in both lists.
    """
    lhs_idx, rhs_idx = [], []
    n_tokens = len(tokens)

    for equal_idx, token in enumerate(tokens):
        if token != '=':
            continue

        # go left until it's no longer a number or symbol
        left_idx = equal_idx - 1
        while left_idx >= 0 and _is_equation_token(tokens[left_idx]):
            lhs_idx.append(left_idx)
            left_idx -= 1

        # go right until it's no longer a number or symbol
        right_idx = equal_idx + 1
        while right_idx < n_tokens and _is_equation_token(tokens[right_idx]):
            rhs_idx.append(right_idx)
            right_idx += 1

    return lhs_idx, rhs_idx


"""@torch.inference_mode()
def main_dev():
    # load model
    model, tokenizer = load_dexperts_model_and_tokenizer(
        base_model_name_or_path='meta-llama/Llama-2-13b-hf',
        expert_model_name_or_path='meta-llama/Llama-2-7b-chat-hf',
        chat_response_prefix='Answer:'
    )

    # load dataset
    gsm_df = pd.read_json('data/eval/gsm/test.jsonl', lines=True)

    # construct prompts
    prompt_prefix = "Answer the following question.\n\n"
    prompts = [prompt_prefix + 'Question: ' + row['question'].strip() + '\nAnswer:' for _, row in gsm_df.iterrows()]

    # get token probabilities
    batch_size = 16
    all_results = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[i: i + batch_size]
        batch_inputs = tokenizer(batch_prompts, return_tensors='pt', padding='longest')
        input_ids = batch_inputs.input_ids.cuda()
        attention_mask = batch_inputs.attention_mask.cuda()
        _, results = model.generate(
            origin_input_ids=batch_prompts,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=512,
            do_sample=False,
            return_logits_for_analysis=True
        )

        # flatten batch results into a list of results for each prompt
        results = flatten_batch_results(results)
        shortened_results = summarize_results(results)
        all_results.extend(shortened_results)

    torch.save(all_results, 'analysis/pkl/gsm_analysis.pkl')"""


# Point the OpenAI client at vLLM's API server instead of the default OpenAI
# endpoint; "EMPTY" is passed as the API key.
openai_api_base = "http://localhost:8000/v1"
openai_api_key = "EMPTY"
client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)


@torch.inference_mode()
def main(data_path='/xx/data/eval/gsm/test.jsonl', batch_size=2):
    """
    Generate completions for the first batch of GSM prompts with the base model and print them.

    This is a work-in-progress analysis entry point: it loads the base model,
    builds one prompt per dataset record, runs generation for the first batch
    only, and prints the raw vLLM outputs.

    Args:
        data_path: path to a JSON-lines file whose records have a 'question' field
        batch_size: number of prompts sent to the model per `generate` call
    """
    # load model
    base_model = LLM(model="meta-llama/Llama-2-13b-hf")

    # load dataset
    with jsonlines.open(data_path, 'r') as f:
        gsm_df = list(f)

    # construct prompts
    prompt_prefix = "Answer the following question.\n\n"
    prompts = [
        prompt_prefix + 'Question: ' + row['question'].strip() + '\nAnswer:'
        for row in gsm_df
    ]

    # get token probabilities
    sampling_params = SamplingParams(temperature=1.0, top_p=0.95)
    for start in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[start: start + batch_size]
        base_output = base_model.generate(batch_prompts, sampling_params)
        print(base_output)

        # Touch `logits_list` so a missing attribute fails right here.
        # NOTE(review): presumably attached to the outputs by the `vllm_inject`
        # patches imported at the top of the file -- confirm.
        _ = base_output[0].outputs[0].logits_list

        # NOTE: only the first batch is processed while the pipeline is under
        # development.
        # TODO: run the expert / anti-expert models on the same batch, flatten
        # and summarize the per-prompt results, and save them (the earlier
        # draft wrote to 'analysis/pkl/gsm_analysis.pkl').
        break

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
