import re
import json
import time
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from IPython import embed


def extract_instruction(conversation):
    """Split a conversation into its prior context and its final user query.

    The second-to-last turn is taken as the final query; every turn before
    it is rendered into the context string.

    Args:
        conversation: Sequence of turn dicts, each read through its
            ``'role'`` and ``'content'`` keys.

    Returns:
        A ``(context, last_query)`` tuple. ``context`` is the literal string
        ``"None"`` when the earlier turns render to nothing but whitespace
        (including when there are no earlier turns); otherwise it is a
        newline followed by one ``"<Role>: <content>"`` entry per earlier
        turn, entries separated by newlines and the whole text stripped.
        ``last_query`` is the ``'content'`` of the second-to-last turn.

    Raises:
        IndexError: If the conversation has fewer than two turns.
    """
    # Join once instead of growing a string with ``+=`` inside the loop.
    rendered = "\n".join(
        f"{turn['role'].capitalize()}: {turn['content']}"
        for turn in conversation[:-2]  # every turn except the last pair
    ).strip()

    last_query = conversation[-2]['content']  # last user query
    context = f"\n{rendered}" if rendered else "None"
    return context, last_query

def create_analysis_prompt(context, query):
    """Build the two-message chat (system + user) sent to the task analyzer.

    The system message holds the analyzer rules, five worked examples and the
    required JSON reply format. The user message holds the conversation that
    should be analyzed.

    Args:
        context: Text inserted after ``"Previous Context: "`` in the user
            message.
        query: Text inserted after ``"Final Query: "`` in the user message.

    Returns:
        A list of two dicts with ``"role"`` and ``"content"`` keys: the
        system message followed by the user message.
    """
    # Plain (non-f) literal: the text has no placeholders, so the JSON braces
    # are written as-is. Examples 3 and 4 end their "Previous Context:" line
    # with a space; that line break is spelled as an explicit \n escape so the
    # otherwise invisible trailing space cannot be stripped by an editor.
    instructions = """You are a task analyzer that examines conversations and queries to identify task categories and generate relevant preference questions. Your analysis helps evaluate if responses meet user requirements.

RULES:
1. Analyze both the conversation context (if any) and final query to understand the complete task
2. Category should be brief (2-4 words) but precisely describe the task type
3. Generate 4-7 preference questions that:
   - Focus on whether the response directly addresses user's specific request
   - Check if all parts of the user's query are fulfilled
   - Verify if the response provides exactly what was asked for
   - Consider any context from previous conversation turns
4. Questions should help verify if a response completely satisfies what the user asked for

Example 1:
Previous Context: None
Final Query: Explain deposition
{
    "category": "concept explanation",
    "preference_questions": [
        "Does the response provide a clear explanation of what deposition is?",
        "Is the explanation complete and comprehensive?",
        "Does the response focus solely on explaining deposition?",
        "Is the explanation presented in an understandable way?"
    ]
}

Example 2:
Previous Context: None
Final Query: Write a python program which accept a command line param as question and send it to server via HTTP get method
{
    "category": "python code implementation",
    "preference_questions": [
        "Does the code accept a command line parameter as specified?",
        "Does the code implement the HTTP GET method functionality?",
        "Does the code send the parameter to a server as requested?",
        "Is the provided solution a complete, working Python program?",
        "Does the code perform exactly the two required functions (parameter handling and HTTP GET)?"
    ]
}

Example 3:
Previous Context: \nUser: I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature
Assistant: "Can you walk us through the process you would use to gather requirements and create a specification for a new feature that our users have been requesting, such as a personalized nutrition plan feature?"
Final Query: Give me a few more
{
    "category": "interview preparation",
    "preference_questions": [
        "Does the response provide multiple additional examples as requested?",
        "Are the additional questions specifically about product speccing?",
        "Are the questions relevant to new feature development as originally asked?",
        "Does the response actually provide 'a few more' examples beyond the first one?"
    ]
}

Example 4:
Previous Context: \nUser: Are you familiar with the SAS programming language?
Assistant: Yes, I am familiar with the SAS programming language. SAS is a statistical software suite that is widely used in data analysis, data management, and business intelligence applications. It has its own programming language, which is used to manipulate and analyze data, generate reports, and automate data-driven workflows. The language is known for its ease of use and powerful capabilities, and it is commonly used in industries such as healthcare, finance, and government.
Final Query: Can you write a SAS proc that imports a dataset from a CSV saved on my local machine?
{
    "category": "SAS code implementation",
    "preference_questions": [
        "Does the code specifically address importing from a CSV file?",
        "Does the code handle local machine file access as requested?",
        "Is the solution provided in SAS PROC format as asked?",
        "Does the code accomplish the specific task of dataset importing?",
        "Is the provided solution complete and executable?"
    ]
}

Example 5:
Previous Context:
User: How to load image here ?
Assistant: To load an image using Python, you can use the `PIL` (Python Imaging Library) package. You can install it by running `pip install pillow` in your command line. Once you have PIL installed, you can use the `Image.open()` function to open an image file and assign it to a variable. Here is an example of how you can open an image called "example.jpg" and display it:
[code snippet]
User: Do you know that 2+2 = 4 ?
Assistant: Yes, 2+2 = 4 is a basic mathematical fact. It is one of the first concepts that children learn when they start learning arithmetic. The equation 2+2=4 is a statement of equality, where the left-hand side (2+2) and the right-hand side (4) are equal in value. This is a basic mathematical principle that forms the foundation of all mathematical operations.
Final Query: Do you also know that 3+2 = 6 ?
{
    "category": "mathematical verification",
    "preference_questions": [
        "Does the response directly address whether 3+2 = 6?",
        "Does the response clearly state if this equation is correct or incorrect?",
        "Is the response answering the specific question asked?",
        "Does the response provide the correct mathematical fact?"
    ]
}

IMPORTANT: You MUST respond ONLY with a valid JSON object in exactly this format:
{
    "category": "brief task description",
    "preference_questions": [
        "specific question 1?",
        "specific question 2?",
        "specific question 3?",
        "..."
    ]
}
Do not include any other text or explanation."""

    # Only this message varies between calls.
    request = (
        "Now analyze this new conversation and final query:\n"
        "\n"
        f"Previous Context: {context}\n"
        f"Final Query: {query}"
    )

    return [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request},
    ]


# Errors that mean "this model reply is not a usable analysis object":
# invalid JSON (json.JSONDecodeError is a ValueError), a missing key, a
# non-dict top-level value, or pathologically deep nesting.
_ANALYSIS_PARSE_ERRORS = (ValueError, KeyError, TypeError, RecursionError)


def _parse_analysis(response):
    """Decode one model reply into ``(category, preference_questions)``.

    Raises one of ``_ANALYSIS_PARSE_ERRORS`` when the reply is not JSON or
    lacks either expected key.
    """
    analysis = json.loads(response)
    return analysis['category'], analysis['preference_questions']


def main():
    """Generate a task category and preference questions per conversation.

    Reads a JSONL file of conversations, prompts the model once per kept
    conversation, re-samples every reply that cannot be parsed (up to
    ``max_retries`` extra attempts per conversation), and writes the parsed
    results, sorted by ``idx``, to ``save_path`` as JSON.

    ``idx`` is the 1-based position of a conversation among the kept
    (non-skipped) conversations, not its line number in the input file.

    Raises:
        ValueError: If ``dataset_name`` has no configured data path.
    """
    # TODO: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python src/customize.py
    dataset_name = "sharegpt_v3" # TODO: sharegpt_v3 / ultrafeedback
    max_retries = 10

    model_path = "path/to/Meta-Llama-3.1-8B-Instruct"
    if dataset_name == "sharegpt_v3":
        data_path = "path/to/sharegpt_V3_format_4k_filtered.jsonl"
    elif dataset_name == "ultrafeedback":
        data_path = "path/to/ultrafeedback.jsonl"
    else:
        # Fail with a clear message instead of a NameError on `data_path`.
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    save_path = f"path/to/AIR/outputs/{dataset_name}_fine-grained.json"

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    sampling_params = SamplingParams(
        temperature=0.7,
        max_tokens=1024,
    )
    llm = LLM(
        model=model_path,
        tensor_parallel_size=torch.cuda.device_count(),
    )

    # prompts[i] is the rendered prompt for conversations[i].
    prompts = []
    conversations = []

    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            conv = json.loads(line)
            # Keep only conversations made of complete turn pairs; an empty
            # conversation has no final query to extract.
            if len(conv) < 2 or len(conv) % 2 != 0:
                continue

            context, query = extract_instruction(conv)
            prompt = tokenizer.apply_chat_template(
                create_analysis_prompt(context, query),
                tokenize=False,
                add_generation_prompt=True
            )
            conversations.append(conv)
            prompts.append(prompt)

    # prompts = prompts[:100] # TODO

    outputs = llm.generate(prompts, sampling_params)
    results = []
    pending = []  # cases whose reply has not parsed yet

    for i, output in enumerate(outputs):
        response = output.outputs[0].text.strip()
        try:
            category, questions = _parse_analysis(response)
        except _ANALYSIS_PARSE_ERRORS as e:
            print(f"Error processing output {i + 1}: {e}")
            # Only the index is stored: the retry reuses prompts[i], so it is
            # always built from this conversation's own context and query.
            pending.append({'index': i, 'retry_count': 0})
            continue

        results.append({
            'idx': i + 1,
            'category': category,
            'preference_questions': questions,
            'conversations': conversations[i],
        })

    # Handle retry cases. Sampling uses temperature > 0, so re-submitting the
    # same prompt can yield a different, parseable reply.
    exhausted = []  # cases that still failed after `max_retries` retries
    while pending:
        print(f"\nRetrying {len(pending)} failed cases...")
        for case in pending:
            case['retry_count'] += 1

        retry_outputs = llm.generate(
            [prompts[case['index']] for case in pending],
            sampling_params,
        )

        still_pending = []
        for case, output in zip(pending, retry_outputs):
            response = output.outputs[0].text.strip()
            try:
                category, questions = _parse_analysis(response)
            except _ANALYSIS_PARSE_ERRORS as e:
                print(f"Error in retry {case['retry_count']} for case {case['index'] + 1}: {e}")
                if case['retry_count'] < max_retries:
                    still_pending.append(case)
                else:
                    exhausted.append(case)
                continue

            results.append({
                'idx': case['index'] + 1,
                'category': category,
                'preference_questions': questions,
                'conversations': conversations[case['index']],
                'retry_count': case['retry_count']
            })
            print(f"Successfully processed case {case['index'] + 1} after {case['retry_count']} retries")

        pending = still_pending

    sorted_results = sorted(results, key=lambda x: x['idx'])
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(sorted_results, f, indent=4)

    print(f"\nAnalysis complete. Results saved to {save_path}")
    if exhausted:
        print(f"Warning: {len(exhausted)} cases still failed after {max_retries} retries")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()