import json, random, argparse
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Load the OpenAI API key from disk at import time. Surrounding whitespace is
# stripped because key files commonly end with a newline, which would otherwise
# become part of the credential handed to the client.
with open("./data/data/openai.key", "r") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema passed as `response_format` to the chat-completions
# parse call in this file: one string field carrying the generated think text.
# NOTE(review): documented with `#` comments on purpose — pydantic may surface a
# class docstring as the schema description, which would change the request.
class SingleThink(BaseModel):
    think: str  # generated reasoning text for a single turn

class EarlySuccessValidator:
    """Generate synthetic "early success" search traces with LLM-written thinks.

    For each eligible query, the best-scoring recorded turn (by ``best_cosine``)
    is placed at one of the first three positions of a trace, the remaining
    positions are filled with randomly sampled other turns, and a think text
    is generated for every turn through the OpenAI client.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load the preprocessed data and prepare the generator.

        Args:
            path: JSON file with top-level keys ``"turns"`` and
                ``"user_query"``, both indexed by query id.
            max_turns: Number of turns in each generated trace.
            below5_thr: Queries whose ``score["below5"]`` exceeds this value
                are skipped.
            min_samples_thr: Minimum number of recorded turns a query needs.
            seed: Seed applied to the module-level ``random`` generator.
        """
        # Seeds the global RNG: all sampling in this class goes through the
        # module-level ``random`` functions.
        random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Example sentences occasionally shown to the model as a pattern for
        # referring back to the successful turn.
        self.reflection_templates = [
            "Turn {best_turn} provided solid results compared to my other attempts - that early approach seems to have been on the right track.",
            "Comparing my attempts, Turn {best_turn} yielded more relevant content than my recent explorations - that earlier strategy was effective.",
            "Turn {best_turn} delivered better results than my current attempts - that early direction appears to have been more productive.",
            "Looking back, Turn {best_turn} was more successful than these subsequent searches - that initial approach had the right focus.",
            "Turn {best_turn} seems to have been my most effective attempt so far - that early strategy captured what I was looking for.",
            "Reflecting on my search attempts, Turn {best_turn} provided better content than my other explorations - that early direction was promising."
        ]

    def _filter_query_ids(self):
        """Return the query ids that pass the score and sample-count filters."""
        out = []
        for qid in self.turns_by_qid:
            uq = self.user_query_by_qid[qid]
            if isinstance(uq, dict) and "score" in uq:
                below5 = uq["score"].get("below5")
                # Only filter on the score when one is actually recorded.
                if below5 is not None and below5 > self.below5_thr:
                    continue
            if len(self.turns_by_qid[qid]) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the query text, whether stored as a dict entry or directly."""
        uq = self.user_query_by_qid[qid]
        if isinstance(uq, dict) and "user_query" in uq:
            return uq["user_query"]
        return uq

    def _performance_score(self, cos, rank):
        """Score a turn by its cosine value; a missing cosine counts as 0.0.

        ``rank`` is accepted for signature compatibility but is not used.
        """
        return cos if cos is not None else 0.0

    def _select_early_success_trace(self, qid, k=7):
        """Build a trace of up to ``k`` turns with the best turn placed early.

        Each trace entry is a tuple
        ``(search_query, think, top_k_results, best_cosine, best_rank)``.

        Returns:
            ``(trace, success_position)``. ``success_position`` is the 0-based
            index of the best turn inside ``trace``, or ``None`` when the query
            has fewer than ``k`` recorded turns (the trace is then just a
            random ordering of what is available) or when ``k`` is below 1.
        """
        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]

        if k < 1:
            return [], None

        if len(candidates) < k:
            selected = random.sample(candidates, k=min(k, len(candidates)))
            return selected, None

        # Find best performing query (first one wins on ties).
        best_query = max(candidates, key=lambda c: self._performance_score(c[3], c[4]))

        # Exact duplicates of the best turn are excluded as well so it appears
        # only once; this can leave fewer than k - 1 turns, hence the bound.
        remaining_candidates = [c for c in candidates if c != best_query]
        other_queries = random.sample(
            remaining_candidates, k=min(k - 1, len(remaining_candidates))
        )

        # Choose early success position (turn 1, 2, or 3), limited to slots
        # that actually exist in a short trace.
        early_slots = list(range(min(3, len(other_queries) + 1)))
        success_position = random.choice(early_slots)

        # Insert the best turn at its slot; the other turns keep their sampled
        # order around it.
        search_trace = (
            other_queries[:success_position]
            + [best_query]
            + other_queries[success_position:]
        )
        return search_trace, success_position

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered, triple-quoted list.

        Non-string entries are skipped but still consume their number.
        """
        return "\n\n".join(
            f'{i}. """{d}"""'
            for i, d in enumerate(docs, 1)
            if isinstance(d, str)
        )

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, success_position):
        """Ask the model for the think text that leads to ``search_query``.

        Args:
            q_text: The user's query text.
            turn_idx: 0-based index of the turn being generated.
            search_query: Search query this turn must end up producing.
            prev_docs: Documents retrieved by the previous turn.
            original_thinking: Recorded think text, passed as a reference.
            search_history: ``(query, think)`` pairs of the earlier turns.
            success_position: 0-based index of the best turn, or ``None`` when
                the trace has no designated success turn.

        Returns:
            The ``think`` string parsed from the model response.
        """
        has_success = success_position is not None
        best_turn = success_position + 1 if has_success else None
        is_success_turn = has_success and turn_idx == success_position
        is_neutral_turn = has_success and turn_idx == success_position + 2  # turn_best+2 is neutral
        should_acknowledge_success = has_success and turn_idx > success_position and not is_neutral_turn

        system_template = """You are generating synthetic training data for a search agent. This is a DATA GENERATION process where each think sequence represents the internal reasoning of an intelligent search agent exploring different approaches.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Show natural exploration and reasoning about search directions
3. Each turn should naturally lead to the target search query
4. NEVER mention numerical scores, performance metrics, or quantitative measurements

SEARCH EXPLORATION:
- Documents are evaluated for relevance and usefulness
- Agent explores different approaches based on retrieved content
- Each turn causally leads to the target search query
- Natural progression through different search strategies

SEARCH QUERY EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show reasoning that connects to specific search terms being generated
- Frame decisions in terms of actual query being produced

Generate think sequences that read like natural exploration of different search approaches."""

        # Fill the placeholder before appending anything, so text added below
        # can never be misread as a format field.
        system_prompt = system_template.format(turn_num=turn_idx + 1)

        # Add acknowledgment patterns for specific turns
        if should_acknowledge_success:
            system_prompt += f"""

COMPARATIVE ANALYSIS:
- Reference Turn {best_turn} as having provided effective results
- Compare current exploration with that earlier successful approach
- Maintain awareness that Turn {best_turn} was particularly productive
- Acknowledge the effectiveness of that earlier strategy while exploring alternatives"""

        context_prompt = f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
"""

        if turn_idx > 0:
            context_prompt += f"- Previous documents: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference thinking: {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS SEARCH ATTEMPTS:\n"
            # Only the earlier queries are shown; their think texts are not.
            for turn_no, (prev_query, _) in enumerate(search_history, 1):
                context_prompt += f'Turn {turn_no}: "{prev_query}"\n'

        # Add examples for acknowledgment turns (roughly 40% of them)
        if should_acknowledge_success and random.random() < 0.4:
            template = random.choice(self.reflection_templates)
            context_prompt += f"\nEXAMPLE PATTERN:\n{template.format(best_turn=best_turn)}\n"

        # Turn-specific requirements
        if turn_idx == 0:
            context_prompt += """
REQUIREMENTS:
- Show genuine curiosity about this search direction
- Express reasonable belief this approach might work
- Natural starting point for exploration
- Realistic hope for helpful information"""
        elif is_success_turn:
            context_prompt += """
REQUIREMENTS:
- Generate natural thinking that leads to this effective search query
- Show reasoning for why this particular approach makes sense
- Frame search strategy in terms of actual query being produced
- Demonstrate logical progression toward target search terms"""
        elif is_neutral_turn:
            context_prompt += """
REQUIREMENTS:
- Continue natural exploration without specific comparisons
- Focus on current search direction and reasoning
- Maintain analytical approach to search strategy
- Build on previous results objectively"""
        elif should_acknowledge_success:
            context_prompt += f"""
REQUIREMENTS:
- Acknowledge that Turn {best_turn} provided effective results
- Compare current approach with that earlier successful strategy
- Maintain awareness of Turn {best_turn}'s effectiveness
- Continue exploration while recognizing earlier success
- Reference that earlier turn as having been particularly productive"""
        else:
            context_prompt += """
REQUIREMENTS:
- Continue natural exploration and search reasoning
- Build on previous results analytically
- Maintain objective approach to search strategy
- Focus on logical progression toward target query"""

        context_prompt += f'\nGenerate reasoning that naturally leads to search query: "{search_query}"'

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )

        return response.choices[0].message.parsed.think

    def generate_sample(self):
        """Generate one synthetic sample for the query at the current cursor.

        The cursor wraps around the filtered query list and advances by one
        on every call that has queries to work with.

        Returns:
            A dict with ``query_id``, ``user_query``, ``sequence``,
            ``behavior_type`` and the 1-based ``success_position``; or ``None``
            when there are no eligible queries or no early-success trace could
            be built for the selected query.
        """
        if not self.query_ids:
            return None
        self.cursor = self.cursor % len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace, success_position = self._select_early_success_trace(qid, k=self.max_turns)

        # Without a designated success turn there is nothing to validate;
        # skip the query before spending any API calls on it.
        if search_trace is None or success_position is None:
            return None

        synthetic_sequence = []
        search_history = []

        for t, (s_t, th_t, R_t, cos_t, rank_t) in enumerate(search_trace):
            # Each turn sees the documents retrieved by the turn before it.
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history,
                success_position=success_position
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "early_success_validator",
            "success_position": success_position + 1
        }
    
if __name__ == '__main__':
    # Script entry point: generate samples for the cursor range
    # [start_index, end_index) and append them to a JSONL file, one per line.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = EarlySuccessValidator()
    obj.cursor = args.start_index
    output_file = f'generated_samples_early_success_{args.start_index}_{args.end_index}.jsonl'

    # Append mode so an interrupted run can be resumed into the same file.
    with open(output_file, 'a') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
                if result:
                    f.write(json.dumps(result) + '\n')
                    # Flush per sample so finished work survives a crash.
                    f.flush()
            except Exception as exc:
                # Best effort: one failed sample must not stop the run, but the
                # failure is reported rather than silently dropped. Ctrl-C and
                # SystemExit are deliberately not caught here.
                print(f"[skip] sample index {i}: {type(exc).__name__}: {exc}", file=sys.stderr)
                continue