from transformers import AutoTokenizer
import torch
import json
import os
from openai import OpenAI
from tqdm import tqdm
import re
from vllm import LLM, SamplingParams
import argparse
from kialo.ds import OpinionCounting, OpinionMatching, PolarityCheck
from kialo.tokenizing import processing_message
from global_vars import *
import openai
from utils import *
import random

def create_prompt(data, tokenizer):
    """Render an opinion-counting prompt for every item in ``data``.

    Each item becomes a three-turn conversation:

    * system: ``PROMPT.opinion_counting`` formatted with ``item["topic"]``;
    * user: ``item["concatenated_opinions"]``;
    * assistant: a pre-filled turn with the fixed text ``"Your answer: "``.

    How the conversation is turned into text depends on
    ``tokenizer.name_or_path`` (case-insensitive):

    * names containing "falcon" get a plain-text concatenation, because the
      chat template did not work when the Falcon baseline was produced;
    * names containing "qwen" use ``apply_chat_template`` with
      ``enable_thinking=True``;
    * everything else uses ``apply_chat_template`` without extra options.

    Templated prompts are rendered with ``tokenize=False`` and
    ``add_generation_prompt=False``.

    Args:
        data: iterable of dicts providing ``"topic"`` and
            ``"concatenated_opinions"``.
        tokenizer: Hugging Face tokenizer used for ``apply_chat_template``.

    Returns:
        ``(prompts, conversations)``: the rendered prompt strings and the raw
        message lists they were built from, both in input order.
    """
    prompts = []
    conversations = []
    for entry in data:
        system_text = PROMPT.opinion_counting.format(topic=entry["topic"])
        user_text = entry["concatenated_opinions"]
        # Pre-filled assistant turn. The original note says it should be
        # emptied for test-time prompts -- TODO(review): confirm intent.
        conversation = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": "Your answer: "},
        ]
        conversations.append(conversation)

        model_id = tokenizer.name_or_path.lower()
        if "falcon" in model_id:
            # Plain-text layout kept for comparability with the Falcon baseline.
            prompts.append(
                system_text + "\n" + "Sentence: " + user_text + "\n" + conversation[2]["content"]
            )
            continue

        template_kwargs = {"tokenize": False, "add_generation_prompt": False}
        if "qwen" in model_id:
            template_kwargs["enable_thinking"] = True
        prompts.append(tokenizer.apply_chat_template(conversation, **template_kwargs))
    return prompts, conversations


def VLLM_generation(model, prompts, temperature, eos_token, max_tokens=2048):
    """Generate one completion per prompt with a vLLM ``LLM`` instance.

    Args:
        model: a ``vllm.LLM`` engine.
        prompts: list of already-rendered prompt strings.
        temperature: sampling temperature passed to ``SamplingParams``.
        eos_token: stop string, normally ``tokenizer.eos_token``. It may be
            ``None`` or empty for some tokenizers; in that case no extra stop
            string is configured.
        max_tokens: maximum number of new tokens per completion.

    Returns:
        List with the text of the first output of each request, in the same
        order as ``prompts``.
    """
    # vLLM validates that ``stop`` holds no empty/falsy entries, so a tokenizer
    # without an EOS string (eos_token=None) must not produce ``[None]``.
    stop = [eos_token] if eos_token else None
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
    )
    with torch.no_grad():
        request_outputs = model.generate(
            prompts=prompts,
            sampling_params=sampling_params,
        )
    return [request_output.outputs[0].text for request_output in request_outputs]

def chat_gpt_generation(prompts, model="gpt-4o-mini", temperature=0.8):
    """Query an OpenAI chat model once per conversation in ``prompts``.

    Args:
        prompts: iterable of chat message lists
            (``[{"role": ..., "content": ...}, ...]``); each is passed
            unchanged as ``messages``.
        model: OpenAI model name (default ``"gpt-4o-mini"``).
        temperature: sampling temperature (default 0.8).

    Returns:
        List of the first choice's message content for each prompt, in
        input order. API errors are not caught and propagate to the caller.
    """
    # A hard-coded empty key can never authenticate. Read it from the
    # environment instead; fall back to "" so behavior is unchanged when the
    # variable is unset.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
    out_response = []
    for prompt in tqdm(prompts):
        completion = client.chat.completions.create(
            model=model,
            messages=prompt,
            temperature=temperature,
        )
        out_response.append(completion.choices[0].message.content)
    return out_response

def llm_response_extraction(output):
    """Ask gpt-4o-mini to pull the numeric rating out of a free-form response.

    Args:
        output: raw text produced by the model under evaluation; it is
            inserted into the extraction prompt.

    Returns:
        The extractor's reply as a string. The prompt asks for only an
        integer from 1 to 5, but the reply is not validated here. Returns
        ``None`` if the API call fails; the error is printed.
    """
    prompt = "You are a helpful assistant that extracts numerical output from model responses. Given the following model response, please identify and return the rating value the model has assigned to the sentence. The rating should be an integer number from 1 to 5. Only return the number itself, without any explanation or extra words.\nModel response:\n{model_output}\n\nRating:"

    prompt = prompt.format(model_output=output)
    # A hard-coded empty key can never authenticate. Read it from the
    # environment instead; fall back to "" so behavior is unchanged when the
    # variable is unset.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant for answering questions."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1  # low temperature for a more deterministic answer
        )
        return completion.choices[0].message.content

    except Exception as e:
        # Best-effort: one failed extraction should not abort a batch.
        print(f"Error occurred: {e}")
        return None


def _free_vllm_gpu_memory():
    """Best-effort release of GPU memory held by a vLLM engine that was just dropped.

    Call this after deleting the last reference to the ``LLM`` instance. If
    the installed vLLM exposes its parallel-state teardown helpers, they are
    called first. Then the garbage collector runs and the CUDA allocator
    cache is emptied.
    """
    import gc

    try:
        from vllm.distributed.parallel_state import (
            destroy_distributed_environment,
            destroy_model_parallel,
        )
    except ImportError:
        # Module layout differs across vLLM releases; fall back to GC + cache flush.
        pass
    else:
        destroy_model_parallel()
        destroy_distributed_environment()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main(model_name, task_name, tensor_parallel_size=4):
    """Evaluate one local checkpoint on a Kialo task using vLLM.

    Steps:

    1. Load the dataset for ``task_name``.
    2. Render every item with ``processing_message``.
    3. Sample completions at temperature 0.8.
    4. Score them with ``ds.evaluation``.
    5. Write the per-item records to ``<task_name>/<model basename>.json``.

    Args:
        model_name: Hugging Face hub id or local path, used for both the
            tokenizer and the vLLM engine.
        task_name: one of ``"opinion_counting"``, ``"opinion_matching"``,
            ``"polarity_check"``.
        tensor_parallel_size: ``tensor_parallel_size`` passed to ``vllm.LLM``.

    Returns:
        The ``results`` object returned by ``ds.evaluation``.

    Raises:
        ValueError: if ``task_name`` is not a known task. This is checked
            before any model is loaded.
    """
    set_seed()
    task_classes = {
        "opinion_counting": OpinionCounting,
        "opinion_matching": OpinionMatching,
        "polarity_check": PolarityCheck,
    }
    if task_name not in task_classes:
        raise ValueError(
            f"Unknown task_name {task_name!r}; expected one of {sorted(task_classes)}"
        )

    save_root = os.path.join("", task_name)
    os.makedirs(save_root, exist_ok=True)

    ds = task_classes[task_name]()
    ds.get_dataset()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LLM(
        model=model_name,
        enable_lora=False,  # even if enabled, no adapter would be loaded here
        max_model_len=2048,
        gpu_memory_utilization=0.95,
        trust_remote_code=True,
        tensor_parallel_size=tensor_parallel_size,
        dtype="float16",
    )
    try:
        prompts = [processing_message(item, tokenizer) for item in ds.ds]
        out_response = VLLM_generation(model, prompts, 0.8, tokenizer.eos_token)
    finally:
        # The __main__ loop builds many engines in one process, each reserving
        # 95% of GPU memory. Release this one before the next is created.
        del model
        _free_vllm_gpu_memory()

    results, saved = ds.evaluation(out_response)
    out_path = os.path.join(save_root, f"{model_name.split('/')[-1]}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(saved, f)
    return results

def main_chat_gpt(model_name, task_name):
    """Evaluate an OpenAI chat model on a Kialo task.

    Each dataset item's ``"dialog"`` is converted with ``reformulate_dialog``
    and sent once at temperature 0.8. A failed request is printed and
    recorded as ``None``, so there is exactly one output per item for
    ``ds.evaluation``. Per-item records are written to
    ``<task_name>/<model_name>.json``.

    Args:
        model_name: OpenAI model name, e.g. ``"gpt-4o-mini"``.
        task_name: one of ``"opinion_counting"``, ``"opinion_matching"``,
            ``"polarity_check"``.

    Returns:
        The ``results`` object returned by ``ds.evaluation``.

    Raises:
        ValueError: if ``task_name`` is not a known task.
    """
    set_seed()
    task_classes = {
        "opinion_counting": OpinionCounting,
        "opinion_matching": OpinionMatching,
        "polarity_check": PolarityCheck,
    }
    if task_name not in task_classes:
        raise ValueError(
            f"Unknown task_name {task_name!r}; expected one of {sorted(task_classes)}"
        )

    save_root = os.path.join("", task_name)
    os.makedirs(save_root, exist_ok=True)

    ds = task_classes[task_name]()
    ds.get_dataset()

    # A hard-coded empty key can never authenticate. Read it from the
    # environment instead; fall back to "" so behavior is unchanged when the
    # variable is unset.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))
    outs = []
    for item in tqdm(ds.ds):
        reformulated_message = reformulate_dialog(item["dialog"])
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=reformulated_message,
                temperature=0.8,
            )
            outs.append(completion.choices[0].message.content)
        except Exception as e:
            # Best-effort per item: keep one (None) entry so outputs stay
            # index-aligned with ds.ds.
            print(f"Error occurred: {e}")
            outs.append(None)

    results, saved = ds.evaluation(outs)
    with open(os.path.join(save_root, f"{model_name}.json"), "w", encoding="utf-8") as f:
        json.dump(saved, f)
    return results

if __name__ == "__main__":
    # CLI entry point: evaluate every local (vLLM) model, then every hosted
    # API model, on the selected task, and print one summary line per model.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_name", type=str, default=TASK_NAME.POLARITY_CHECK)
    args = parser.parse_args()

    # Local checkpoints, evaluated by `main` through vLLM.
    models = [
        MODEL_PATHS.LLAMA_3_1_8B_INSTRUCT,
        MODEL_PATHS.FALCON_3_7B_INSTRUCT,
        MODEL_PATHS.QWEN_3_8B,
        MODEL_PATHS.QWEN_3_32B,
        MODEL_PATHS.QWEN_2_5_7B_INSTRUCT,
        MODEL_PATHS.DEEPSEEK_R1_DISTILL_LLAMA_8B,
        MODEL_PATHS.DEEPSEEK_R1_DISTILL_QWEN_7B,
        MODEL_PATHS.DEEPSEEK_R1_DISTILL_QWEN_32B,
        MODEL_PATHS.QWEN_2_5_32B,
        MODEL_PATHS.QWEN_QWQ_32B,
    ]
    # Hosted OpenAI models, evaluated by `main_chat_gpt`.
    api_models = ["gpt-4o-mini", "gpt-4o"]

    # Results keyed by short model name: the last path component for local
    # models, the API model name for hosted ones.
    out = {}
    for model_path in models:
        short_name = model_path.split("/")[-1]
        out[short_name] = main(model_path, args.task_name)
    for api_model in api_models:
        out[api_model] = main_chat_gpt(api_model, args.task_name)

    for name, result in out.items():
        print(f"model: {name} got result: {result}")
        