import json, random, argparse
from adaptkeybert import KeyBERT
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Load the OpenAI API key once at import time from a local key file.
with open("./data/data/openai.key", "r") as f:
    # Strip surrounding whitespace: key files usually end with a newline, and
    # a key containing "\n" is not the credential the user intended to send.
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema handed to the OpenAI parse helper as `response_format`
# (see `_generate_single_think`). Documented with `#` comments rather than a
# docstring so the model's generated schema is left untouched.
class SingleThink(BaseModel):
    think: str  # the generated reasoning text for a single turn

class AdaptiveContextLearnerACL:
    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=5, top_n=4, below5_thr=1.0, min_samples_thr=10, seed=0):
        """Load the preprocessed dataset and set up the keyword and LLM clients.

        Args:
            path: JSON file whose top-level object has a "turns" entry and a
                "user_query" entry, both indexed by query id.
            max_turns: Stored on the instance as ``self.max_turns``.
            top_n: Number of keywords requested from KeyBERT per search query.
            below5_thr: Queries whose ``score["below5"]`` exceeds this value
                are excluded by ``_filter_query_ids``.
            min_samples_thr: Queries with fewer recorded turns than this are
                excluded by ``_filter_query_ids``.
            seed: Seed applied to the global ``random`` module.
        """
        # NOTE: this seeds the process-wide RNG, not a private generator.
        random.seed(seed)
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.top_n = top_n
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        # Must run after the thresholds above are set: the filter reads them.
        self.query_ids = self._filter_query_ids()
        # Index into self.query_ids of the next query to generate.
        self.cursor = 0
        self.kw = KeyBERT()
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        
        # Randomized cross-turn reflection templates
        self.reflection_templates = [
            "Turn {best_turn} clearly outperformed Turn {worst_turn} because the search query in Turn {best_turn} used more specific terminology like '{best_keywords}', while Turn {worst_turn}'s broader approach with '{worst_keywords}' yielded less relevant documents.",
            
            "Comparing the results, Turn {best_turn} achieved superior document relevance compared to Turn {worst_turn}, likely because Turn {best_turn} incorporated targeted terms like '{best_keywords}' that matched the user's specific needs, whereas Turn {worst_turn}'s strategy with '{worst_keywords}' was too general.",
            
            "The document quality in Turn {best_turn} significantly exceeded that of Turn {worst_turn} - this suggests that the precise keyword combination in Turn {best_turn} ('{best_keywords}') resonated better with relevant sources than the approach taken in Turn {worst_turn} with '{worst_keywords}'.",
            
            "Reflecting on the search progression, Turn {best_turn} produced the most relevant documents so far, contrasting sharply with Turn {worst_turn}'s results; this performance gap appears linked to Turn {best_turn}'s strategic use of '{best_keywords}' versus the less effective '{worst_keywords}' approach in Turn {worst_turn}.",
            
            "Turn {best_turn} stands out as the most successful search attempt compared to Turn {worst_turn}, with the key difference being Turn {best_turn}'s targeted keyword strategy ('{best_keywords}') proving more effective than Turn {worst_turn}'s broader terminology ('{worst_keywords}').",
            
            "Analyzing the search pattern, Turn {best_turn} delivered superior results relative to Turn {worst_turn}, indicating that the focused keyword approach in Turn {best_turn} using '{best_keywords}' was more aligned with document content than Turn {worst_turn}'s strategy with '{worst_keywords}'."
        ]

    def _filter_query_ids(self):
        """Return the query ids that pass the score and sample-count filters.

        A query is dropped when its user-query record is a dict carrying a
        non-None ``score["below5"]`` above ``self.below5_thr``, or when it has
        fewer than ``self.min_samples_thr`` recorded turns. Order follows the
        iteration order of ``self.turns_by_qid``.
        """

        def passes(qid):
            record = self.user_query_by_qid[qid]
            if isinstance(record, dict) and "score" in record:
                below5 = record["score"].get("below5")
                if below5 is not None and below5 > self.below5_thr:
                    return False
            return len(self.turns_by_qid[qid]) >= self.min_samples_thr

        return [qid for qid in self.turns_by_qid if passes(qid)]

    def _user_query_text(self, qid):
        """Return the user query for ``qid``.

        When the stored record is a dict with a "user_query" key, that value
        is returned; otherwise the stored record itself is returned unchanged.
        """
        record = self.user_query_by_qid[qid]
        if isinstance(record, dict) and "user_query" in record:
            return record["user_query"]
        return record

    def _select_search_trace(self, qid, k=7):
        """Randomly pick up to ``k`` recorded turns for ``qid``.

        Each turn is returned as a tuple
        ``(search_query, think, top_k_results, best_cosine, best_rank)``;
        a missing/falsy search query becomes ``""`` and missing/falsy results
        become ``[]``. The selection is sampled without replacement and then
        shuffled, so the returned order is random.
        """
        pool = [
            (
                turn.get("search_query") or "",
                turn.get("think"),
                turn.get("top_k_results") or [],
                turn.get("best_cosine"),
                turn.get("best_rank"),
            )
            for turn in self.turns_by_qid[qid]
        ]
        sample_size = min(k, len(pool))
        picked = random.sample(pool, k=sample_size)
        random.shuffle(picked)
        return picked

    def _extract_keywords_from_query(self, query_text):
        """Return up to ``self.top_n`` keywords extracted from ``query_text``.

        Results from ``self.kw.extract_keywords`` that are lists/tuples are
        reduced to their first element; any other item is kept as-is.
        """
        extracted = self.kw.extract_keywords(query_text, top_n=self.top_n)
        keywords = []
        for item in extracted:
            if isinstance(item, (list, tuple)):
                keywords.append(item[0])
            else:
                keywords.append(item)
        return keywords

    def _format_prev_docs(self, docs):
        """Render documents as a numbered, triple-quoted list for the prompt.

        Non-string entries are skipped, but numbering follows each document's
        1-based position in ``docs``, so skipped entries leave gaps. Entries
        are separated by a blank line.
        """
        numbered = [
            f'{position}. """{doc}"""'
            for position, doc in enumerate(docs, 1)
            if isinstance(doc, str)
        ]
        return "\n\n".join(numbered)

    def _generate_single_think(self, q_text, turn_idx, search_query, keywords, prev_docs, original_thinking, search_history):
        """Ask the LLM for one synthetic "think" passage leading to ``search_query``.

        Builds a fixed system prompt (parameterized only by the 1-based turn
        number) and a per-turn user prompt, then requests a structured
        ``SingleThink`` response from ``gpt-4o``.

        Args:
            q_text: The user's original query, inserted into the prompt.
            turn_idx: 0-based turn index; shown to the model as ``turn_idx + 1``.
            search_query: The search query this turn's reasoning must arrive at.
            keywords: Keywords offered to the model; interpolated as-is.
            prev_docs: Documents from the previous turn; only included when
                ``turn_idx > 0`` and formatted via ``_format_prev_docs``.
            original_thinking: Reference reasoning, only included when
                ``turn_idx > 0``.
            search_history: Iterable of ``(query, think)`` pairs for earlier
                turns, listed in order as "Turn 1", "Turn 2", ...

        Returns:
            The ``think`` attribute of the parsed response message.
        """
        # str.format() is applied to this prompt below, so `{turn_num}` is the
        # only brace-delimited field it may contain.
        system_prompt = """You are generating synthetic training data for a self-reflective multi-turn search agent. This is a DATA GENERATION process where each think sequence represents the internal reasoning of an intelligent search agent that learns ONLY from analyzing retrieved document content.

CRITICAL RULES FOR THINK SEQUENCE GENERATION:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Think sequences must be COHERENT with previous turns, showing logical progression and self-learning
3. Turn 1 is INITIAL EXPLORATION: "Given the user query, here's how I'll explore this topic..." or a variation of it
4. Turn 2+ are REFLECTIVE + ADAPTIVE: Start with 1-2 long in-depth sentences reflecting on previous retrieved results, then plan next search in detail
5. CROSS-TURN ANALYSIS: For turns 3+, occasionally compare current results with previous turns using specific turn numbers when patterns emerge
6. PERFORMANCE TREND AWARENESS: Acknowledge which previous turns yielded better/worse documents and analyze WHY through keyword/strategy comparison
7. STRATEGIC LEARNING: Reference specific search terms that worked well in previous turns and avoid strategies from poorly performing turns

STRICTLY FORBIDDEN - NEVER MENTION:
- Any numerical scores, float values, or quantitative measurements given as part of turn meta-data
- Ranks, positions, or numerical performance indicators also must not be mentioned
- Any external evaluation metrics the agent cannot observe and thus shouldn't be distilled during datagen, all reflections must be made textually

THE AGENT CAN ONLY SEE AND REASON ABOUT:
- The user's original query
- The actual text content of documents retrieved in previous turns
- The search queries it previously used
- NOTHING ELSE - no numerical feedback

PSEUDO-REWARD GENERATION THROUGH TEXTUAL ANALYSIS:
The agent must evaluate its own performance by comparing document content:
- "The documents I retrieved contain specific details about [topic], which directly answers the user's question"
- "These results discuss broader concepts but miss the specific focus the user needs"
- "I'm getting more targeted information - these documents mention the exact terms from the query"
- "This search direction isn't working - the content covers related but different topics"
- "The previous turn's documents were more helpful because they contained concrete examples of [specific topic]"

REFLECTIVE REASONING PATTERNS (for turns 2+):
- Analyze document content quality and topical relevance from previous turn through textual comparison
- Compare retrieved content usefulness across turns using specific content observations
- Identify WHY document relevance changed through content analysis: keyword precision, topic focus, terminology matches
- Create internal reward signals through content assessment: "This worked because the documents contained..." or "This failed because the content focused on..."
- When documents seem more relevant: "I think I'm getting more relevant documents because they contain specific information about..."
- When documents seem less relevant: "These results are drifting away from what I need because they focus on..."
- Reference specific previous search strategies by analyzing their document content outcomes
- Show learning from document content patterns and strategic course correction
- For turns 3+, you can reflect back to specific turn search queries like "Turn X resulted in more favorable documents maybe due to the use of [keyword1] and [keyword2]"
- Each individual think sequence can be anywhere from 3-6 sentences with detailed observations and understanding with at times more than 120-200 words

HIGHLY DETAILED TURN-WISE ANALYSIS REQUIREMENTS:
- Use specific turn references when appropriate: "Turn 2 seems to have resulted in better documents because...", "Comparing Turn 1 and Turn 3 results..."
- Provide concrete document content comparisons: "The documents in Turn 2 contained specific examples of X, while Turn 1 only provided general principles about Y"
- Trace search evolution patterns: "Moving from the broad approach in Turn 1 to the specific focus in Turn 2 improved document relevance"
- Acknowledge performance trajectories: "I notice the document quality has been improving since Turn 1"
- Reference keyword effectiveness across turns: "The inclusion of [keyword] in Turn 2 yielded better results than the [other keyword] strategy in Turn 1"
- Show strategic learning: "Based on Turn 2's success with specific terminology, I should incorporate similar precision in future searches"

SEARCH QUERY EVOLUTION RULES:
- For the ones where keywords are provided use them strategically based on what document content analysis revealed
- Show semantic navigation toward better document retrieval through content terminology analysis
- Demonstrate vocabulary adaptation based on specific terms and phrases found in retrieved documents
- Balance query specificity vs. breadth based on observed document content relevance patterns

Generate think sequences that read like detailed internal monologues of an intelligent agent learning to search through pure textual analysis and content comparison of retrieved documents."""

        # Build context based only on previous turns that have occurred
        context_prompt = f"""Generate 1 coherent thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
- Keywords available to strategically incorporate: {keywords}
"""

        # The first turn has no previous retrieval to reflect on.
        if turn_idx > 0:
            context_prompt += f"- Documents retrieved from previous turn: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference original thinking (for inspiration): {original_thinking}\n"
        
        if search_history:
            context_prompt += "\nPREVIOUS SEARCH HISTORY (for cross-turn analysis):\n"
            for i, (prev_query, prev_think) in enumerate(search_history):
                context_prompt += f"Turn {i + 1}: Query=\"{prev_query}\", Thinking=\"{prev_think}\"\n"

        # Randomly add cross-turn reflection examples if we have enough history
        # NOTE(review): the chosen template is inserted verbatim, so its
        # `{best_turn}` / `{worst_keywords}` style placeholders reach the model
        # unfilled — confirm this is intended as a pattern example.
        if turn_idx >= 2 and random.random() < 0.5:
            context_prompt += f"\nCROSS-TURN REFLECTION EXAMPLE (use similar pattern when relevant):\n{random.choice(self.reflection_templates)}\n"

        context_prompt += f"""
GENERATION REQUIREMENTS:
- Turn 1: Initial exploration strategy based on user query understanding
- Turn 2+: Start with detailed document analysis reflection from previous turn, then adaptive search planning
- Turn 3+: Occasionally compare with previous turns using specific turn numbers when natural
- Create pseudo-rewards through self-assessment: analyze document relevance, topic alignment, information completeness
- Reference previous turns that yielded better document matches when relevant
- Demonstrate learning from document content patterns and retrieval outcomes
- Use keywords strategically based on observed document quality and topic coverage
- Generate realistic agent reasoning that leads naturally to the target search query through pure self-reflection

Generate a single think sequence that demonstrates natural learning progression with contextual awareness of previous search attempts."""

        # Structured-output request: the reply is parsed into a SingleThink.
        response = self.client.beta.chat.completions.parse(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt.format(turn_num=turn_idx + 1)},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )
        
        # NOTE(review): if `parsed` can be None (e.g. on a model refusal), this
        # raises AttributeError — confirm against the OpenAI SDK behavior.
        return response.choices[0].message.parsed.think

    def generate_sample(self):
        """Build one synthetic multi-turn sample for the query at ``self.cursor``.

        Returns ``None`` when there are no eligible query ids. Otherwise the
        cursor wraps to 0 once it runs past the end of ``self.query_ids``, is
        advanced by one, and up to 7 randomly selected turns are turned into
        alternating "think" / "search_query" entries.

        Returns:
            dict with "query_id", "user_query" and "sequence", or ``None``.
        """
        if not self.query_ids:
            return None
        if self.cursor >= len(self.query_ids):
            self.cursor = 0
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        user_query = self._user_query_text(qid)
        trace = self._select_search_trace(qid, k=7)

        sequence = []
        history = []  # (search_query, think) pairs of the turns generated so far
        previous_docs = []  # results of the preceding turn; empty for the first turn

        for turn_idx, turn in enumerate(trace[:7]):
            query, reference_think, results, cosine, rank = turn

            # Keywords are only extracted for the first five turns.
            if turn_idx <= 4:
                keywords = self._extract_keywords_from_query(query)
            else:
                keywords = []

            think = self._generate_single_think(
                q_text=user_query,
                turn_idx=turn_idx,
                search_query=query,
                keywords=keywords,
                prev_docs=previous_docs,
                original_thinking=reference_think,
                search_history=history,
            )

            think_entry = {"tag": "think", "messages": think}
            query_entry = {
                "tag": "search_query",
                "text": query,
                "top_k_results": results,
                "best_cosine": cosine,
                "best_rank": rank,
            }
            sequence.extend([think_entry, query_entry])

            # Expose this turn to the following ones.
            history.append((query, think))
            previous_docs = results

        return {"query_id": qid, "user_query": user_query, "sequence": sequence}
    
if __name__=='__main__':
    # CLI entry point: generate one sample per index in [start_index, end_index)
    # and append them, one JSON object per line, to a range-specific JSONL file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    # A negative start would silently index query_ids from the end, and an
    # inverted range would silently produce nothing; reject both up front.
    if args.start_index < 0:
        parser.error('--start_index must be non-negative')
    if args.end_index < args.start_index:
        parser.error('--end_index must be greater than or equal to --start_index')

    obj = AdaptiveContextLearnerACL()
    # Position the generator at the first query of this shard.
    obj.cursor = args.start_index
    output_file = f'generated_samples_{args.start_index}_{args.end_index}.jsonl'

    # Append mode so an interrupted run can be resumed into the same file.
    with open(output_file, 'a') as f:
        for _ in trange(args.start_index, args.end_index):
            result = obj.generate_sample()
            if result is None:
                # No eligible queries: every further call would also return
                # None, so stop instead of writing "null" lines to the dataset.
                break
            f.write(json.dumps(result) + '\n')
            # Flush per sample so completed work survives a crash mid-run.
            f.flush()