import json, random, argparse
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# The OpenAI API key is read from a local file at import time. Surrounding
# whitespace (e.g. the trailing newline most editors append) is stripped so it
# is not sent as part of the Authorization header built from this key.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

class SingleThink(BaseModel):
    # Structured-output schema passed as `response_format` to the chat
    # completions `parse` call; the reply is read back via `.parsed.think`.
    # NOTE: documented with comments rather than a docstring because pydantic
    # copies a class docstring into the JSON schema that is sent to the API.
    think: str  # generated reasoning text for a single search turn

class WrongDirectionSpecialist:
    """Generate synthetic "wrong direction" search traces with LLM-written think steps.

    For each user query, up to ``max_turns`` logged search turns are sampled and
    ordered by descending ``best_cosine`` so quality declines turn over turn.
    ``gpt-4o-mini`` then writes a think sequence per turn that starts out
    exploring and later recognises it drifted away from the better early turns.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load the preprocessed dataset and create the OpenAI client.

        Args:
            path: JSON file with top-level ``"turns"`` and ``"user_query"``
                mappings keyed by query id.
            max_turns: maximum number of turns sampled for each generated trace.
            below5_thr: queries whose ``score["below5"]`` exceeds this are skipped.
            min_samples_thr: queries with fewer logged turns than this are skipped.
            seed: seed applied to the module-global ``random`` generator.
        """
        # Seeds the *global* RNG, which all sampling below relies on.
        random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # str.format templates; filled with concrete turn numbers before use.
        self.reflection_templates = [
            "Turn {best_turn} was on the right track compared to Turn {worst_turn} - my initial approach was more effective than these recent attempts.",
            "Comparing results, Turn {best_turn} yielded better content than Turn {worst_turn} - I moved away from what was working initially.",
            "Turn {best_turn} was more successful than Turn {worst_turn} - my first direction was superior and I should have maintained it.",
            "Turn {best_turn} outperformed Turn {worst_turn} - there's a pattern of declining effectiveness since my strong start.",
            "Turn {best_turn} was more productive than Turn {worst_turn} - I've been going in the wrong direction after a promising beginning.",
            "Turn {best_turn} delivered better results than Turn {worst_turn} - I abandoned my initially successful strategy."
        ]

    def _filter_query_ids(self):
        """Return the query ids usable for generation, in ``turns_by_qid`` order.

        A query is dropped when its ``score["below5"]`` value exceeds
        ``below5_thr`` or when it has fewer than ``min_samples_thr`` turns.
        Missing or non-dict ``score`` entries are treated as "no score".
        """
        out = []
        for qid, turns in self.turns_by_qid.items():
            uq = self.user_query_by_qid[qid]
            score = uq.get("score") if isinstance(uq, dict) else None
            below5 = score.get("below5") if isinstance(score, dict) else None
            if below5 is not None and below5 > self.below5_thr:
                continue
            if len(turns) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the user query text (either a plain string or ``{"user_query": ...}``)."""
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Score used to order turns: ``best_cosine``, with ``None`` treated as 0.0.

        ``rank`` is accepted but currently unused.
        """
        return cos if cos is not None else 0.0

    def _select_declining_search_trace(self, qid, k=7):
        """Sample up to ``k`` logged turns of ``qid`` and order them best-first.

        Each entry is a ``(search_query, think, top_k_results, best_cosine,
        best_rank)`` tuple. Sorting by descending performance score makes later
        turns look progressively worse. For traces of 5+ turns, 30% of the time
        the entries between the top two and the last one are shuffled so the
        decline is not perfectly monotonic.
        """
        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]

        selected = random.sample(candidates, k=min(k, len(candidates)))
        search_trace = sorted(selected, key=lambda x: self._performance_score(x[3], x[4]), reverse=True)

        # Optional randomization of middle elements (first two and last stay fixed).
        if len(search_trace) >= 5 and random.random() < 0.3:
            middle_start = 2
            middle_end = len(search_trace) - 1
            if middle_end > middle_start:
                middle_section = search_trace[middle_start:middle_end]
                random.shuffle(middle_section)
                search_trace = search_trace[:middle_start] + middle_section + search_trace[middle_end:]

        return search_trace

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered list of triple-quoted blocks.

        Non-string entries are dropped before numbering so the list has no gaps.
        NOTE(review): if ``top_k_results`` ever holds dict records they are all
        dropped here — confirm the data stores plain strings.
        """
        text_docs = [d for d in (docs or []) if isinstance(d, str)]
        return "\n\n".join(f'{i}. """{d}"""' for i, d in enumerate(text_docs, 1))

    def _build_system_prompt(self, turn_idx):
        """Return the system prompt for 0-based ``turn_idx`` with ``{turn_num}`` filled in."""
        system_prompt = """You are generating synthetic training data for a specialist search agent. This is a DATA GENERATION process where each think sequence represents the internal reasoning of an intelligent search agent that learns through deep drilling and strategic restart patterns.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Turn 1: Show genuine curiosity about initial search direction. It should basically be INITIAL DRILLING: "I'll start by diving deep into..." with commitment to thorough exploration
3. Turns 2-3: Build objectively on Turn 1 results and whether its going in the right direction
4. Turns 4+: Recognize declining effectiveness and wrong direction patterns
5. NEVER mention numerical scores, performance metrics, or quantitative measurements

THE AGENT LEARNS FAILURE RECOGNITION:
- Documents are evaluated for relevance and usefulness
- Progressive recognition that moving away from initial success was problematic
- Each turn causally leads to the target search query

SEARCH QUERY EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show reasoning that connects to specific search terms being generated
- Frame decisions in terms of actual query being produced"""

        if 1 <= turn_idx < 3:  # Turns 2-3
            system_prompt += """
DETAILED ANALYSIS (Turns 2-3):
- Write nice reflection of previous turn's search queries and their retrieved documents and how they performed, the difference between turns, how the change in search query led to different kinds of retrieved documents
- Discuss more along the lines of exploration such that you want to explore hence you're creating the subsequent search query
"""

        if turn_idx >= 3:  # Turns 4+
            system_prompt += """

WRONG-DIRECTION PATTERNS (Turns 4+):
- Use language like "I realize now", "I'm starting to see", "this isn't working"
- Acknowledge that wrong directions waste time and effort
- "I'm moving in the wrong direction compared to my initial success"

FAILURE RECOGNITION (Turns 4+):
- Compare declining effectiveness using content analysis
- "This approach isn't working..." or "I'm seeing declining results..."
- "These results are weaker than my initial approach"
- "I've been going the wrong way since my promising start"
- Reference specific earlier turns: "Turn 1 was more successful than what I'm achieving now"

DETAILED ANALYSIS (Turns 4+):
- "Turn 1 delivered better results than Turn X"
- "My initial approach uncovered relevant content while recent turns yield weaker results"
- "I can see how I moved away from what was working in Turn 1"
- "I'm recognizing a clear pattern of declining effectiveness"
- "I need to learn from abandoning successful approaches" """

        system_prompt += """

Generate think sequences that read like internal monologues of an agent learning to recognize failure patterns."""
        # Only {turn_num} appears as a placeholder; the other sections contain no
        # braces, so str.format is safe here.
        return system_prompt.format(turn_num=turn_idx + 1)

    def _build_context_prompt(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history):
        """Return the user-message prompt describing the current turn and its history."""
        context_prompt = f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
"""

        if turn_idx > 0:
            context_prompt += f"- Previous documents: {self._format_prev_docs(prev_docs)}\n"
            # Skip the line when no think was logged instead of emitting "None".
            if original_thinking:
                context_prompt += f"- Reference thinking: {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS ATTEMPTS:\n"
            for i, (prev_query, _prev_think) in enumerate(search_history, 1):
                context_prompt += f"Turn {i}: \"{prev_query}\"\n"

        # Later turns sometimes get a concrete reflection example. The trace is
        # sorted best-first, so Turn 1 is the strongest and the most recent
        # previous turn (number turn_idx) is usually the weakest so far.
        if turn_idx >= 3 and random.random() < 0.4:
            example = random.choice(self.reflection_templates).format(best_turn=1, worst_turn=turn_idx)
            context_prompt += f"\nEXAMPLE PATTERN:\n{example}\n"

        # Turn-specific requirements
        if turn_idx == 0:
            context_prompt += """
REQUIREMENTS:
- Show genuine curiosity about this search direction
- Express reasonable belief this approach might work
- Natural starting point for exploration
- Realistic hope for helpful information"""
        elif turn_idx <= 2:
            context_prompt += """
REQUIREMENTS:
- Reference Turn 1 results objectively
- Build on Turn 1's findings analytically
- Explain reasoning based on Turn 1 content
- Maintain analytical tone without frustration
- Logical progression from previous results"""
        else:
            context_prompt += """
REQUIREMENTS:
- Acknowledge current results differ from Turn 1
- Note factual differences in document content
- Observe declining search effectiveness patterns
- Reference Turn 1 when comparing performance
- Introduce measured recognition of wrong direction
- Recognize moving away from initial approach was counterproductive"""

        context_prompt += f'\nGenerate reasoning that naturally leads to search query: "{search_query}"'
        return context_prompt

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history):
        """Ask the LLM for the think text of one turn.

        Args:
            q_text: the user query text.
            turn_idx: 0-based index of the turn being generated.
            search_query: the search query this think must lead to.
            prev_docs: retrieved documents of the previous turn.
            original_thinking: the logged think for this turn (may be None).
            search_history: list of ``(query, think)`` pairs for earlier turns.

        Returns:
            The generated think string.

        Raises:
            RuntimeError: if the response carries no parsed output (e.g. a refusal).
        """
        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": self._build_system_prompt(turn_idx)},
                {"role": "user", "content": self._build_context_prompt(
                    q_text, turn_idx, search_query, prev_docs, original_thinking, search_history
                )},
            ],
            response_format=SingleThink
        )

        message = response.choices[0].message
        if message.parsed is None:
            refusal = getattr(message, "refusal", None)
            raise RuntimeError(f"No parsed think for turn {turn_idx + 1} (refusal={refusal!r})")
        return message.parsed.think

    def generate_sample(self):
        """Generate one synthetic trace for the next query id (round-robin).

        Returns:
            A dict with ``query_id``, ``user_query``, ``sequence`` (alternating
            think / search_query entries) and ``behavior_type``; ``None`` when no
            query id passed the filters.
        """
        if not self.query_ids:
            return None
        # Wrap the cursor so any starting index maps onto an available query.
        self.cursor = self.cursor % len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace = self._select_declining_search_trace(qid, k=self.max_turns)

        synthetic_sequence = []
        search_history = []

        for t, (s_t, th_t, R_t, cos_t, rank_t) in enumerate(search_trace[:self.max_turns]):
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "wrong_direction_specialist"
        }
    
if __name__=='__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = WrongDirectionSpecialist()
    # generate_sample() wraps the cursor modulo the number of usable query ids.
    obj.cursor = args.start_index
    output_file = f'generated_samples_wrong_direction_{args.start_index}_{args.end_index}.jsonl'

    # Append mode: rerunning the same range adds to the file instead of truncating it.
    with open(output_file, 'a', encoding='utf-8') as f:
        pbar = trange(args.start_index, args.end_index)
        for i in pbar:
            try:
                result = obj.generate_sample()
            except Exception as e:
                # Best-effort generation: report the failure and move on, but let
                # KeyboardInterrupt / SystemExit stop the run.
                pbar.write(f"[{i}] sample generation failed: {type(e).__name__}: {e}", file=sys.stderr)
                continue
            if result is None:
                # No query id passed the filters, so every further call would be None too.
                pbar.write("No usable query ids; nothing to generate.", file=sys.stderr)
                break
            f.write(json.dumps(result) + '\n')
            f.flush()