import concurrent.futures
import json
import os
import pdb
import time
from pprint import pprint

import numpy as np
import openai
import pandas as pd
from tqdm import tqdm
# Input/output locations. The three input files are read in the __main__ block
# as JSON Lines (one JSON object per line).
PROMPT_FILE = "../deep_research_bench/data/prompt_data/query.jsonl"  # loaded as main_data_json
TRAJECTORY_FILE = "/data/processed_data.jsonl"  # loaded as processed_data ('id' and 'trajectory' are read)
QUERY_FILE = "/data/raw_query.jsonl"  # loaded as query_data_json
OUTPUT_PATH = "/data"  # directory that receives train_raw_wotree.jsonl

# Seeds NumPy's global RNG. Note that the sampling in split_data passes
# RANDOM_STATE explicitly rather than relying on this global seed.
np.random.seed(42) # for reproducibility

# Stratified-sampling configuration used by split_data.
TOTAL_SUBSET_SIZE = 50  # number of rows to draw across all strata
N_BINS = 4  # number of quantile bins requested for 'trajectory_size'
RANDOM_STATE = 42  # random_state passed to DataFrame.sample for every stratum
STRATIFY_COLS = ['topic', 'language', 'trajectory_bin']  # columns that define a stratum

# Thread count for the ThreadPoolExecutor that runs transform_data_format.
MAX_WORKERS = 3

# Prompt template for the "assistant" turn. It has a single positional `{}`
# placeholder, filled via str.format in transform_data_format with the inquiry
# text followed by its lettered options; the model is asked to rewrite that
# multiple-choice question as one flowing question addressed to the user.
ASSISTANT_PROMPT = """Transform the given question stem and multiple choices into a natural, flowing paragraph that presents all options seamlessly within the text. Do not use explicit option labels (A, B, C, etc.). Instead, integrate the choices organically into grammatically correct sentences that enhance readability and user comprehension. The output should read as coherent prose rather than a formatted multiple-choice question.

# Original Question
{}

> ✅ Important:
> - **Adopt an AI Assistant Persona:** Your output must be a direct question *to the user*. Frame it as if you are an assistant asking for their instruction or choice. The tone must be helpful and inquisitive, not descriptive or declarative.
> - **Be Direct and Concise:** While being helpful, maintain a focused tone. Omit unnecessary conversational filler or pleasantries (e.g., "Hello," "To help us better...") that do not contribute to the core task of presenting the choice.
> - Preserve all details from the original question and options without any omissions. 
> - Your output language must match the language of the Original Question. That is, if the Original Question is in English, your output must also be in English. Conversely, if the Original Question is in Chinese, your output must also be in Chinese.
"""

# Prompt template for the simulated "user" turn. It has three named
# placeholders filled via str.format in transform_data_format:
#   {question}    - the inquiry plus all lettered options
#   {answer}      - the lettered options that appear in the selected choices
#   {intent_list} - one "- <intent>" line per entry of 'missing_intent'
USER_PROMPT = """You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario.

# Input Information:
You will be provided with:
1. Question
{question}
2. Answer
{answer}
3. Intent List
{intent_list}

# Task Description:
Given a question, you already have a correct answer. Now your task is to rephrase the given answer into a natural, conversational response that a human user would provide. 
1. Analyze the question and the answer to determine which intent from the Intent List it corresponds to. This intent reflects the underlying goal behind the answer.
2. Use the chosen intent as a guide, craft a new response that conveys the meaning of the original answer in a human-like style. Your response should sound like something a person would actually say, rather than a robotic selection of an option or a direct statement of fact.

## Guidelines:
- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat.
- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. Redirect the chat back to the main objective (your Intent List) if it starts to stray.
- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses.
- While your response should be creative and not a direct copy, it must incorporate every detail information from the chosen intent.

## **Output Format**
[Your rephrased answer]

> ✅ Important:
> - Only output your rephrased answer. Do **not** include any additional text, explanation, or formatting in your response.
> - Phrase your response as a declarative statement, not a question.
> - Your output language must match the language of the Question. That is, if the Question is in English, your output rephrased answer must also be in English. Conversely, if the Question is in Chinese, your output rephrased answer must also be in Chinese.
"""

def call_online_api(messages, **kwargs):
    """Send a chat-completion request to an OpenAI-style endpoint, retrying on failure.

    Args:
        messages: Chat messages, passed unchanged as ``messages`` to
            ``client.chat.completions.create``.
        **kwargs: Optional overrides:
            api_key: API key; defaults to the ``OPENAI_API_KEY`` environment variable.
            api_base: Base URL; defaults to the ``OPENAI_BASE_URL`` environment variable.
            model: Model name (default ``"gpt-4.1-2025-04-14"``).
            temperature: Sampling temperature (default ``0.6``).
            max_tokens: Completion token limit (default ``4096``).
            max_retries: Maximum number of attempts (default ``100``).

    Returns:
        The message content of the first returned choice, or ``""`` when every
        attempt failed (best-effort: the error is reported, never raised).
    """
    api_key = kwargs.get("api_key", os.environ.get("OPENAI_API_KEY"))
    api_base = kwargs.get("api_base", os.environ.get("OPENAI_BASE_URL"))
    model = kwargs.get("model", "gpt-4.1-2025-04-14")
    temperature = kwargs.get("temperature", 0.6)
    max_tokens = kwargs.get("max_tokens", 4096)
    max_retries = kwargs.get("max_retries", 100)

    client = openai.OpenAI(api_key=api_key, base_url=api_base)
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model, messages=messages, temperature=temperature, max_tokens=max_tokens
            )
            return response.choices[0].message.content
        except Exception as e:
            # Deliberately broad: any failure (network, rate limit, malformed
            # response) is retried, but the cause is reported so that repeated
            # failures can be diagnosed.
            is_last_attempt = attempt + 1 >= max_retries
            outcome = "giving up" if is_last_attempt else "retrying..."
            print(
                f"API call failed on attempt {attempt + 1}/{max_retries} "
                f"({type(e).__name__}: {e}); {outcome}"
            )
            if not is_last_attempt:
                # Linear back-off; no point sleeping once no attempt remains.
                time.sleep(1.5 * (attempt + 1))
    return ""

def split_data(main_data_json, processed_data, query_data_json, trajectory_list):
    """Draw a stratified subset of prompts and split the query samples by it.

    The prompt records are joined with their trajectories on ``'id'``, binned by
    trajectory length into quantile bins, and a subset of about
    ``TOTAL_SUBSET_SIZE`` rows is sampled proportionally across the strata
    defined by ``STRATIFY_COLS`` (largest-remainder rounding). Diagnostic tables
    are printed along the way.

    Args:
        main_data_json: Prompt records; each needs ``'id'`` plus the non-bin
            columns of ``STRATIFY_COLS``.
        processed_data: Records with ``'id'`` and (optionally) ``'trajectory'``.
            The records themselves are not modified.
        query_data_json: Query samples; each gets a ``"trajectory"`` key assigned
            in place. Samples with a falsy ``"id"`` are skipped.
        trajectory_list: Sequence indexed in parallel with ``query_data_json``;
            element ``i`` is attached to sample ``i``.
            NOTE(review): must be at least as long as ``query_data_json``,
            otherwise an IndexError is raised — confirm with the data producer.

    Returns:
        ``(training_set, testing_set)``: query samples whose id starts with
        ``"<sampled id>_"`` for some sampled prompt id, and all remaining ones.

    Raises:
        ValueError: If the join of prompt and trajectory data is empty.
    """
    df_main = pd.DataFrame(main_data_json)

    # Build the trajectory table from fresh records so that the caller's
    # dicts in processed_data are left untouched.
    extra_records = []
    for item in processed_data:
        trajectory = item.get('trajectory', [])
        extra_records.append({
            'id': item.get('id'),
            'trajectory_size': len(trajectory),
            'trajectory': trajectory,
        })
    df_extra = pd.DataFrame(extra_records, columns=['id', 'trajectory_size', 'trajectory'])

    merged_df = pd.merge(df_main, df_extra, on='id')
    if merged_df.empty:
        raise ValueError("No rows left after joining prompt data with trajectory data on 'id'.")

    # Ask qcut for integer bin codes and build the labels ourselves: with
    # duplicates='drop' the number of bins can shrink, and a fixed label list
    # would then make qcut raise. If every size is identical qcut yields NaN
    # codes, which are folded into the first bin.
    bin_codes = pd.qcut(
        merged_df['trajectory_size'],
        q=N_BINS,
        labels=False,
        duplicates='drop'
    )
    bin_codes = bin_codes.fillna(0).astype(int)
    merged_df['trajectory_bin'] = bin_codes.map(lambda code: f'Q{code + 1}')
    print(merged_df['trajectory_bin'].value_counts().sort_index())
    print("\n")

    # Proportional allocation per stratum, rounded down ...
    original_distribution = merged_df.groupby(STRATIFY_COLS).size().reset_index(name='original_count')
    total_original_size = len(merged_df)
    original_distribution['proportion'] = original_distribution['original_count'] / total_original_size
    original_distribution['target_float'] = original_distribution['proportion'] * TOTAL_SUBSET_SIZE
    original_distribution['target_int'] = original_distribution['target_float'].astype(int)

    # ... then hand the leftover slots to the strata with the largest
    # fractional parts (largest-remainder method).
    remainder = int(TOTAL_SUBSET_SIZE - original_distribution['target_int'].sum())
    if remainder > 0:
        original_distribution['remainder_val'] = (
            original_distribution['target_float'] - original_distribution['target_int']
        )
        top_remainder_indices = original_distribution.nlargest(remainder, 'remainder_val').index
        original_distribution.loc[top_remainder_indices, 'target_int'] += 1

    # Sampling is without replacement, so a stratum can never contribute more
    # rows than it holds (only relevant when TOTAL_SUBSET_SIZE exceeds the data).
    original_distribution['target_int'] = (
        original_distribution[['target_int', 'original_count']].min(axis=1)
    )

    sampled_subsets = []
    grouped_df = merged_df.groupby(STRATIFY_COLS)
    for _, row in original_distribution.iterrows():
        num_to_sample = int(row['target_int'])
        if num_to_sample <= 0:
            continue
        group_keys = tuple(row[col] for col in STRATIFY_COLS)
        group = grouped_df.get_group(group_keys)
        sampled_subsets.append(
            group.sample(n=num_to_sample, replace=False, random_state=RANDOM_STATE)
        )

    if sampled_subsets:
        final_subset_df = pd.concat(sampled_subsets).sort_values(by='id').reset_index(drop=True)
    else:
        # pd.concat rejects an empty list; keep the schema with zero rows.
        final_subset_df = merged_df.iloc[0:0].copy()

    # Diagnostic: requested vs. actually sampled count per stratum.
    final_distribution = final_subset_df.groupby(STRATIFY_COLS).size().reset_index(name='final_count')
    comparison_df = pd.merge(
        original_distribution[STRATIFY_COLS + ['original_count', 'target_int']],
        final_distribution,
        on=STRATIFY_COLS,
        how='left'
    )
    comparison_df['final_count'] = comparison_df['final_count'].fillna(0).astype(int)
    print(comparison_df.to_string())
    print("\n")

    # Diagnostic: trajectory-size statistics, labelled with the real row counts.
    stats_comparison = pd.DataFrame({
        f'Original ({len(merged_df)})': merged_df['trajectory_size'].describe(),
        f'Subset ({len(final_subset_df)})': final_subset_df['trajectory_size'].describe()
    })
    print(stats_comparison.round(2))
    print("\n")

    output_cols = list(df_main.columns) + ['trajectory']
    final_subset_list = final_subset_df[output_cols].to_dict('records')

    pprint([{'id': item['id'], 'topic': item['topic'], 'language': item['language'], 'trajectory_len': len(item['trajectory'])} for item in final_subset_list])

    # str.startswith accepts a tuple, so all prefixes are tested in one call.
    id_prefixes = tuple({f"{item['id']}_" for item in final_subset_list})

    training_set = []
    testing_set = []
    for i, sample in enumerate(query_data_json):
        sample_id = sample.get("id")
        if not sample_id:
            continue

        sample["trajectory"] = trajectory_list[i]
        # str() keeps the prefix test valid for non-string ids.
        if str(sample_id).startswith(id_prefixes):
            training_set.append(sample)
        else:
            testing_set.append(sample)
    return training_set, testing_set

def transform_data_format(original_data: dict) -> dict:
    """Turn one query sample into a multi-turn clarification dialogue.

    The dialogue starts with the sample's ``"simple_query"`` as a user turn.
    For every entry of ``"missing_details"`` an assistant question is generated
    through ``call_online_api`` from the inquiry and its lettered options; when
    the matching entry of ``"choices"`` is non-empty, a user answer is generated
    as well. The dialogue always ends with an assistant ``<stop>`` turn.

    Args:
        original_data: Sample dict. Keys read: ``"id"``, ``"topic"``,
            ``"trajectory"``, ``"simple_query"``, ``"missing_details"`` (dicts
            with ``"inquiry"`` and ``"options"``), ``"choices"`` (parallel to
            ``"missing_details"``) and ``"missing_intent"``.

    Returns:
        Dict with keys ``"session_id"``, ``"topic"``, ``"trajectory"`` and
        ``"messages"`` (list of ``{"role", "content"}`` dicts).
    """
    messages = [{
        "role": "user",
        "content": original_data.get("simple_query")
    }]

    details = original_data.get("missing_details") or []
    # A missing or short "choices" list is padded with None so that every
    # detail still yields its assistant question; such details simply get no
    # user answer, exactly like details whose choices entry is empty.
    per_detail_choices = list(original_data.get("choices") or [])
    per_detail_choices.extend([None] * (len(details) - len(per_detail_choices)))

    # Identical for every detail, so build it once.
    intent_list_str = "\n".join(
        f"- {intent}" for intent in original_data.get("missing_intent") or []
    )

    for detail, choices in zip(details, per_detail_choices):
        inquiry = detail.get("inquiry", "")
        options = detail.get("options", [])

        # One "A. option" line per option, lettered from 'A'.
        option_lines = [
            f"{chr(ord('A') + i)}. {option}\n" for i, option in enumerate(options)
        ]
        assistant_content = f"{inquiry}\n" + "".join(option_lines)

        assistant_prompt = ASSISTANT_PROMPT.format(assistant_content)
        refined_question = call_online_api([{"role": "user", "content": assistant_prompt}])
        messages.append({
            "role": "assistant",
            "content": refined_question
        })
        print("assistant: ", refined_question)

        if not choices:
            continue

        # The answer keeps the original letters of the selected options.
        choices_content = "".join(
            line for line, option in zip(option_lines, options) if option in choices
        )
        user_prompt = USER_PROMPT.format(
            question=assistant_content, answer=choices_content, intent_list=intent_list_str
        )
        refined_answer = call_online_api([{"role": "user", "content": user_prompt}])
        messages.append({
            "role": "user",
            "content": refined_answer,
        })
        print("user: ", refined_answer)

    # Sentinel assistant turn that closes the dialogue.
    messages.append({
        "role": "assistant",
        "content": '<stop>'
    })

    transformed_data = {
        "session_id": original_data.get("id"),
        "topic": original_data.get("topic"),
        "trajectory": original_data.get("trajectory"),
        "messages": messages
    }

    print(transformed_data)
    return transformed_data

def _load_jsonl(path):
    """Read a JSON Lines file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        # A blank line (e.g. a trailing newline at end of file) is not valid
        # JSON and would make json.loads raise.
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    main_data_json = _load_jsonl(PROMPT_FILE)
    processed_data = _load_jsonl(TRAJECTORY_FILE)
    query_data_json = _load_jsonl(QUERY_FILE)

    # Flatten every record's trajectory into one list; split_data indexes this
    # list in parallel with query_data_json.
    trajectory_list = [
        step for sample in processed_data for step in sample['trajectory']
    ]

    # NOTE(review): testing_set is computed but not used or written anywhere
    # in this script — confirm whether it should be exported as well.
    training_set, testing_set = split_data(main_data_json, processed_data, query_data_json, trajectory_list)

    # The transformation is dominated by API calls, so threads overlap the waits.
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        results_iterator = executor.map(transform_data_format, training_set)
        raw_train_list = list(tqdm(results_iterator, total=len(training_set), desc="Processing training set"))

    with open(OUTPUT_PATH+"/train_raw_wotree.jsonl", 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in raw_train_list
        )
