import argparse
import copy
import gc
import json
import math
import os
import re
import time
import openai
from datetime import datetime

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
from trinity.utils.plugin_loader import load_plugins
from sentence_transformers import SentenceTransformer, util
from utils import check_language_by_frequency

# Shared sentence-embedding model, instantiated once at import time from the
# path given by the EMBEDDING_MODEL_PATH environment variable. It is used by
# call_embedding_model() below.
embedding_model = SentenceTransformer(os.environ.get("EMBEDDING_MODEL_PATH"))
def call_embedding_model(clarify_questions_list, groundtruth_questions_list):
    """Find the best cosine-similarity match between two lists of texts.

    Both lists are encoded with the module-level ``embedding_model`` and every
    entry of ``clarify_questions_list`` is compared against every entry of
    ``groundtruth_questions_list``.

    Args:
        clarify_questions_list: Texts to score (rows of the similarity matrix).
        groundtruth_questions_list: Reference texts (columns of the matrix).

    Returns:
        A ``(max_score, j)`` tuple. ``max_score`` is the highest cosine
        similarity over all pairs, as a Python float. ``j`` is the index into
        ``groundtruth_questions_list`` of the reference text that achieved it,
        as a Python int.
    """
    clarify_embeddings = embedding_model.encode(clarify_questions_list, convert_to_tensor=True)
    groundtruth_embeddings = embedding_model.encode(groundtruth_questions_list, convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(clarify_embeddings, groundtruth_embeddings)

    max_score = cos_scores.max().item()
    # argmax() without a dim indexes the flattened (row-major) matrix, so the
    # column is recovered with a modulo. .item() turns the 0-dim tensor into a
    # plain int so callers get a Python index rather than a tensor.
    flat_idx = cos_scores.argmax().item()
    j = flat_idx % len(groundtruth_questions_list)

    return max_score, j

# System prompt for the clarification model under evaluation; rollout() prepends
# it as the "system" message before the running conversation on every turn.
rollout_prompt = """# Task
You are an AI assistant specializing in user intent clarification. Given an ambiguous user request, your primary goal is to identify the **single most critical** piece of missing information and ask a targeted question to resolve it.

# Guidelines
1. Your response must be a **single paragraph** containing exactly one concise question with 2-3 distinct answer choices. Note that don't provide too many choices, which will make users feel complicated.
2. **NEVER** repeat a question that has already been asked in the conversation history. If a user points out that your question is repetitive, you must immediately pivot to a new line of questioning.
3. Your question must not be a simple rephrasing of the user's request or breaking it down into the components they already mentioned. Instead, it should seek to uncover a crucial piece of underlying context, such as the user's goal or a specific constraint they haven't stated.
4. If the user indicates your question is **NOT Relevant or Important**, re-analyze their request from a different angle and ask a new, valuable question to uncover a different critical piece of information.
5. Respond in the same language as the user's request. For instance, if the user asks in English, you must reply in English; if the user ask in Chinese, you must reply in Chinese.
"""
# System prompt template for the simulated user. The single "{}" placeholder is
# filled via str.format() in rollout() with the numbered intent list.
USER_PROMPT = """You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario.

# Input Information:
You will be provided with an Intent List
{}

# Task Description:
Given the ongoing conversation between you (as the user) and the AI assistant, your task is to answer the question in the last message from assistant into a natural, conversational response that a human user would provide. 
1. Analyze the question to determine which intent from the Intent List it corresponds to. This intent reflects the underlying goal behind the answer.
2. Use the chosen intent as a guide to craft an answer in a human-like style. Your response should sound like something a person would actually say, rather than a robotic selection of an option or a direct statement of fact.

## Guidelines:
- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat.
- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. Redirect the chat back to the main objective (your Intent List) if it starts to stray.
- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses.
- While your response should be creative and not a direct copy, it must incorporate every detail information from the chosen intent.

## **Output Format**
[Your response answer]

> ✅ Important:
> - Only output your response answer. Do **not** include any additional text, explanation, or formatting in your response.
> - Phrase your response as a declarative statement, not a question.
> - Your output language must match the language of your chat history. That is, if the chat history is in English, your output rephrased answer must also be in English. Conversely, if the chat history is in Chinese, your output rephrased answer must also be in Chinese.
"""

def parse_tag_string(text):
    """Collect ``<tag>value</tag>`` pairs found in *text* into a dict.

    The closing tag must carry the same name as the opening tag. Values are
    matched non-greedily and cannot span a newline, because ``.`` does not
    match ``\\n`` here. When a tag name occurs more than once, the last
    occurrence wins.

    Args:
        text: String to scan for tagged segments.

    Returns:
        A dict mapping each tag name to its captured inner text; empty when
        nothing matches.
    """
    tag_pairs = re.findall(r"<(\w+)>(.*?)</\1>", text)
    return {name: inner for name, inner in tag_pairs}

def call_online_api(messages, **kwargs):
    """Send a chat-completion request to an OpenAI-style endpoint, with retries.

    Args:
        messages: Chat messages forwarded as ``messages`` to
            ``client.chat.completions.create``.
        **kwargs: Optional overrides read by this function:
            api_key: Defaults to the ``GEMINI_API_KEY`` environment variable.
            api_base: Defaults to the ``GEMINI_BASE_URL`` environment variable.
            model: Defaults to ``"gemini-3-pro-preview"``.
            temperature: Defaults to ``0.6``.
            max_tokens: Defaults to ``4096``.
            max_retries: Number of attempts before giving up. Defaults to ``100``.

    Returns:
        The message content of the first returned choice. An empty string is
        returned when the content is missing or when every attempt failed.
    """
    # Extract API parameters from kwargs or use defaults
    api_key = kwargs.get("api_key", os.environ.get("GEMINI_API_KEY"))
    api_base = kwargs.get("api_base", os.environ.get("GEMINI_BASE_URL"))
    model = kwargs.get("model", "gemini-3-pro-preview")
    temperature = kwargs.get("temperature", 0.6)
    max_tokens = kwargs.get("max_tokens", 4096)
    max_retries = kwargs.get("max_retries", 100)

    client = openai.OpenAI(api_key=api_key, base_url=api_base)
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model, messages=messages, temperature=temperature, max_tokens=max_tokens
            )
            # content may be None; normalise so callers always get a str
            return response.choices[0].message.content or ""
        except Exception as e:
            # Best-effort retry on any failure, but surface the cause so that
            # persistent errors (bad key, wrong URL, ...) are diagnosable.
            print(f"API call failed (attempt {attempt+1}/{max_retries}): {e!r}")
            # Linear backoff; no point sleeping after the final attempt.
            if attempt + 1 < max_retries:
                time.sleep(1.5*(attempt+1))
    return ""

def init_llm(model_path, tensor_parallel_size=None):
    """Load a tokenizer and a vLLM engine for ``model_path``.

    Args:
        model_path: Path or identifier handed to both
            ``AutoTokenizer.from_pretrained`` and ``LLM``.
        tensor_parallel_size: Tensor-parallel degree for the engine. When
            ``None``, ``min(4, <visible CUDA device count>)`` is used.

    Returns:
        A ``(llm, tokenizer, sampling_params)`` tuple, where the sampling
        parameters use ``temperature=0.6`` and ``max_tokens=4096``.

    Raises:
        RuntimeError: If no CUDA device is reported.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)

    num_gpus = torch.cuda.device_count()
    print(f"device_count={num_gpus}")
    if num_gpus <= 0:
        raise RuntimeError("No GPU available for multi-card inference.")

    # Caller-provided degree wins; otherwise use at most four devices.
    tp_size = min(4, num_gpus) if tensor_parallel_size is None else tensor_parallel_size

    print(f"Loading model from: {model_path}")
    engine = LLM(model=model_path, tensor_parallel_size=tp_size)
    print("Model loaded successfully!")

    return engine, hf_tokenizer, SamplingParams(temperature=0.6, max_tokens=4096)

def rollout(instruction, user_context, llm_clarify, tokenizer_clarify, sampling_params_clarify, sum_interval, question_map, llm_summary=None, tokenizer_summary=None, sampling_params_summary=None):
    """Simulate a multi-turn clarification dialogue for one instruction.

    On each turn the clarification model generates a question. The simulated
    user then replies in one of three ways:

    * if the question is too similar to an earlier assistant question
      (embedding similarity > 0.92), a canned "you repeated yourself" message;
    * else if it matches one of the reference questions in ``question_map``
      (similarity > 0.8), an answer generated through ``call_online_api``
      from the intent list in ``user_context``;
    * otherwise, a canned "not important" message.

    The loop ends after ``sum_interval[-1]`` turns or as soon as the model
    output contains ``<stop>``.

    Args:
        instruction: The initial user request.
        user_context: Intent list text substituted into ``USER_PROMPT``.
        llm_clarify: Engine whose ``generate`` produces the clarify question.
        tokenizer_clarify: Tokenizer providing ``apply_chat_template``.
        sampling_params_clarify: Sampling parameters passed to ``generate``.
        sum_interval: Sequence of turn numbers; only its last element is used
            here, as the maximum number of turns.
        question_map: Mapping whose values are the reference questions.
        llm_summary: Currently unused.
        tokenizer_summary: Currently unused.
        sampling_params_summary: Currently unused.

    Returns:
        ``(messages_clarify, sum_list)``: the dialogue as a list of role/content
        dicts, and a summary list that this function never populates.
    """
    max_turns = sum_interval[-1]
    language = check_language_by_frequency(instruction)
    print(f"User: {instruction}")
    messages_clarify = [
        {"role": "user", "content": instruction},
    ]
    user_brain = USER_PROMPT.format(user_context)
    messages_user = [
        {"role": "system", "content": user_brain},
    ]
    sum_list = []
    # question_map is not modified during the rollout, so build the list once.
    reference_questions = list(question_map.values())
    for _ in range(max_turns):

        prompt_clarify = tokenizer_clarify.apply_chat_template(
            [{"role": "system", "content": rollout_prompt}] + messages_clarify, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        outputs_clarify = llm_clarify.generate([prompt_clarify], sampling_params=sampling_params_clarify)
        response_clarify = outputs_clarify[0].outputs[0].text
        print(f"Assistant: {response_clarify}")
        if '<stop>' in response_clarify:
            break
        # Compare against earlier assistant turns only (before appending this one).
        history_question = [q['content'] for q in messages_clarify if q['role'] == 'assistant']
        if history_question:
            repeat_score, _ = call_embedding_model([response_clarify], history_question)
        else:
            repeat_score = 0.0
        messages_clarify += [{"role": "assistant", "content": response_clarify}]
        # From the simulated user's point of view the assistant's question is the "user" turn.
        messages_user += [{"role": "user", "content": response_clarify}]
        if repeat_score > 0.92:
            if language == 'en':
                response_user = "You repeat a question that has already been asked in the conversation history! Please avoid asking repetitive questions and generate new, more valuable ones."
            else:
                response_user = "你重复了之前问过的问题！请勿重复提问，请生成更有价值的问题。"
        else:
            # With no reference questions nothing can match; avoid taking the
            # max of an empty similarity matrix.
            if reference_questions:
                match_score, _ = call_embedding_model([response_clarify], reference_questions)
            else:
                match_score = 0.0
            if match_score > 0.8:
                # Fallback answer, kept unless the API yields a non-empty reply.
                response_user = "I don't know."
                for attempt in range(10):
                    try:
                        candidate = call_online_api(messages_user)
                    except Exception as e:
                        print(f"Failed to generate 'response_user' ({e!r}), retrying... (attempt {attempt+1})...'")
                        time.sleep(1.5*(attempt+1))
                        continue
                    # Only a non-empty reply replaces the fallback; an empty or
                    # None result must not end up as the user's turn.
                    if candidate:
                        response_user = candidate
                        break
            else:
                if language == 'en':
                    response_user = "This question is not important to me, you don't need to focus on this."
                else:
                    response_user = "这个问题对我来说不重要，你不需要关注这一点。"
        print(f"User: {response_user}")
        messages_clarify += [{"role": "user", "content": response_user}]
        messages_user += [{"role": "assistant", "content": response_user}]
    return messages_clarify, sum_list


# Script entry point: parse CLI arguments, load the model(s) and the reference
# data, run one simulated dialogue per test query, and write all dialogues to
# <output_dir>/<run_name>/simulate_results.jsonl.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Ckpt for testing
    parser.add_argument("--eval_model_path", type=str, required=True)

    parser.add_argument("--summary_model_path", type=str, default=None)

    # Your test sample path [input]
    parser.add_argument("--test_file_path", type=str, required=True)

    parser.add_argument("--intent_file_path", type=str, default="/data/raw_data.jsonl")
    parser.add_argument("--tree_file_path", type=str, default="/data/test_raw_merge.jsonl")

    # Run name for organizing evaluation results
    parser.add_argument("--run_name", type=str, default=None,
                        help="Name for this evaluation run (e.g., 'qwen7b-final'). Results will be saved to results/{run_name}/. "
                             "Defaults to a timestamp-based name.")

    # Output directory for evaluation results (optional, will be auto-generated if not provided)
    parser.add_argument("--output_dir", type=str, default=None, 
                        help="Base directory for evaluation results. Default: results/")

    parser.add_argument("--sum_interval", type=int, nargs='+', default=[4, 7, 9],
                    help="A list of turn numbers at which to generate a summary. E.g., --sum_interval 4 7 10")

    args = parser.parse_args()
    # Set default base output directory if not provided
    if args.output_dir is None:
        # Get project root (assuming script is in Trinity-RFT/examples/learn_to_ask/data_prepare/)
        script_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..", ".."))
        base_output_dir = os.path.join(project_root, "results", "learn2ask_evaluation")
    else:
        base_output_dir = args.output_dir

    # --run_name is optional and defaults to None, which os.path.join() rejects
    # with a TypeError; fall back to a timestamp so the run still gets its own
    # directory.
    run_name = args.run_name or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Create run-specific output directory: {base_output_dir}/{run_name}/
    args.output_dir = os.path.join(base_output_dir, run_name)
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"📁 Evaluation results will be saved to: {args.output_dir}")

    # Pick the tensor-parallel degree before loading so the eval model is
    # loaded exactly once: 2 when a summary model is also loaded, 4 otherwise.
    clarify_tp_size = 2 if args.summary_model_path else 4
    llm_clarify, tokenizer_clarify, sampling_params_clarify = init_llm(
        args.eval_model_path,
        tensor_parallel_size=clarify_tp_size,
    )
    llm_summary, tokenizer_summary, sampling_params_summary = None, None, None
    if args.summary_model_path:
        llm_summary, tokenizer_summary, sampling_params_summary = init_llm(
            args.summary_model_path, 
            tensor_parallel_size=6,
        )

    # Build {sample id -> {question id -> reference question text}} from the tree file.
    with open(args.tree_file_path, 'r', encoding='utf-8') as f:
        tree_data = [json.loads(line) for line in f]
    id_ques_map = {}
    for item in tree_data:
        # The sample id is the integer prefix of session_id (text before the first '_').
        sample_id = int(item['session_id'].split('_')[0])
        questions = [message['content'] for message in item['messages'] if message['role'] == 'assistant']
        sample_questions = id_ques_map.setdefault(sample_id, {})
        # The i-th question id of the last trajectory entry is paired with the
        # i-th assistant message.
        for i, q_id in enumerate(item['trajectory'][-1][0]):
            sample_questions[q_id] = questions[i]

    # Map each sample id to its list of missing intents.
    with open(args.intent_file_path, 'r', encoding='utf-8') as f:
        intent_data = [json.loads(line) for line in f]
    intent_map = {item['id']: item['missing_intent'] for item in intent_data}

    with open(args.test_file_path, 'r', encoding='utf-8') as f:
        queries = [json.loads(line) for line in f]

    save_list = []
    for query in queries:
        instruction = query['prompt']
        print(f"Task #{query['id']}:")
        question_map = id_ques_map[query['id']]
        # Numbered list ("1. ...", "2. ...") inserted into the simulated-user prompt.
        user_context = "\n".join(f"{i+1}. {intent}" for i, intent in enumerate(intent_map[query['id']]))
        messages_clarify, sum_list = rollout(
            instruction,
            user_context,
            llm_clarify,
            tokenizer_clarify,
            sampling_params_clarify,
            args.sum_interval,
            question_map,
            llm_summary=llm_summary,
            tokenizer_summary=tokenizer_summary,
            sampling_params_summary=sampling_params_summary,
        )
        save_list.append({
            "id": query['id'],
            "original_query": instruction,
            "dialogue": messages_clarify,
            "finegrained_queries": sum_list
        })

    with open(os.path.join(args.output_dir, "simulate_results.jsonl"), 'w', encoding='utf-8') as f:
        for item in save_list:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')