import json, random, argparse
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Load the OpenAI API key once at import time. Surrounding whitespace is
# stripped so that a trailing newline in the key file (very common when the
# file is created with an editor or `echo`) does not end up in the credential.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema passed as `response_format` to the chat completion
# call in DepthFirstDriller._generate_single_think.
# NOTE: documented with comments on purpose -- a class docstring on a pydantic
# model is emitted as the schema "description" and would change the request.
class SingleThink(BaseModel):
    # The generated think sequence for a single turn.
    think: str

class DepthFirstDriller:
    """Turn recorded search turns into synthetic "depth-first" think traces.

    For each eligible query the class picks up to seven recorded turns, orders
    them cluster by cluster (the cluster closest to the best-scoring turn goes
    last), and asks an OpenAI chat model to write one think sequence per turn.

    Args:
        path: JSON file with top-level keys ``"turns"`` and ``"user_query"``,
            both indexed by query id.
        max_turns: Stored on the instance. NOTE(review): the trace length used
            by ``generate_sample`` is currently fixed at 7 and does not read
            this value.
        below5_thr: Queries whose ``score["below5"]`` exceeds this are skipped.
        min_samples_thr: Queries with fewer recorded turns are skipped.
        seed: Seed applied to the global ``random`` and ``numpy`` generators.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=10, seed=0):
        # Seeds the *global* generators: trace sampling uses `random` directly.
        random.seed(seed)
        np.random.seed(seed)
        # Explicit encoding: JSON is UTF-8, the platform default may not be.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Load embedding model for clustering
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Cross-turn reflection templates for depth-first exploration.
        # {best_turn} / {worst_turn} are filled with 1-based turn numbers by
        # _render_reflection_example before the text is put into a prompt.
        self.reflection_templates = [
            "Turn {best_turn} showed promise in this direction, while Turn {worst_turn} took me down a different path - my depth-first approach means I need to drill deeper when I find promising leads rather than abandoning them too quickly.",

            "Comparing my persistent exploration, Turn {best_turn} revealed depth in this area, whereas Turn {worst_turn} explored elsewhere - depth-first search means exhausting promising directions thoroughly before strategic pivoting.",

            "My drilling approach shows Turn {best_turn} had substantial depth to explore, while Turn {worst_turn} covered different ground - I should persist in productive directions until they're fully explored before restarting.",

            "Analyzing my deep exploration, Turn {best_turn} demonstrated rich content in this direction, contrasting with Turn {worst_turn} which went elsewhere - depth-first means following promising threads to completion.",

            "My persistent drilling reveals Turn {best_turn} opened a productive vein of information, while Turn {worst_turn} pursued different territory - depth-first strategy requires exhausting good directions before pivoting.",

            "Reflecting on my deep exploration, Turn {best_turn} showed this direction has substantial depth, while Turn {worst_turn} explored different areas - depth-first approach means drilling until exhaustion before strategic restart."
        ]

    def _filter_query_ids(self):
        """Return the query ids that pass the score and sample-count filters.

        A query is dropped when its user-query record is a dict carrying a
        non-None ``score["below5"]`` greater than ``below5_thr``, or when it
        has fewer than ``min_samples_thr`` recorded turns.
        """
        out = []
        for qid in self.turns_by_qid:
            uq = self.user_query_by_qid[qid]
            if isinstance(uq, dict) and "score" in uq and uq["score"].get("below5") is not None:
                if uq["score"]["below5"] > self.below5_thr:
                    continue
            if len(self.turns_by_qid[qid]) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the query text for ``qid``.

        The record is either a dict holding the text under ``"user_query"`` or
        the value itself, which is then returned unchanged.
        """
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Score a turn by its cosine similarity; a missing value counts as 0.0.

        ``rank`` is accepted for interface symmetry but is not used.
        """
        return cos if cos is not None else 0.0

    def _best_and_worst_turns(self, similarity_history):
        """Return 1-based ``(best_turn, worst_turn)`` for earlier turns.

        Args:
            similarity_history: Cosine values of the turns that have already
                happened, in turn order (entries may be None).

        Returns:
            A tuple of two distinct turn numbers, or None when fewer than two
            turns are available or all of them score the same.
        """
        if not similarity_history or len(similarity_history) < 2:
            return None
        scores = [self._performance_score(cos, None) for cos in similarity_history]
        positions = range(len(scores))
        best_idx = max(positions, key=scores.__getitem__)
        worst_idx = min(positions, key=scores.__getitem__)
        # Both pick the first index on ties, so equality means "no contrast".
        if best_idx == worst_idx:
            return None
        return best_idx + 1, worst_idx + 1

    def _render_reflection_example(self, template, similarity_history):
        """Fill a reflection template with best/worst turn numbers.

        Returns None when no best/worst pair can be derived, so callers can
        leave the example out instead of emitting raw ``{...}`` placeholders.
        """
        turns = self._best_and_worst_turns(similarity_history)
        if turns is None:
            return None
        best_turn, worst_turn = turns
        return template.format(best_turn=best_turn, worst_turn=worst_turn)

    def _select_depth_first_search_trace(self, qid, k=7):
        """Pick and order the turns that make up one synthetic trace.

        Each candidate is a tuple
        ``(search_query, think, top_k_results, best_cosine, best_rank)``.

        With fewer than 7 candidates (or fewer than 6 left once the best one
        is set aside) the method falls back to a random sample of at most
        ``k`` candidates. Otherwise the remaining queries are embedded and
        split into 3 k-means clusters; the top two turns of every cluster are
        kept, the best turn is appended to the cluster whose centre is most
        similar to it, and that cluster is placed last in the trace.

        Returns:
            ``(trace, cluster_info)``. ``cluster_info`` maps a search-query
            string to its cluster metadata and is None on the fallback path.
        """
        candidates = []
        for t in self.turns_by_qid[qid]:
            candidates.append((
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank")
            ))

        if len(candidates) < 7:
            # Not enough candidates for clustering, fall back to random selection
            selected = random.sample(candidates, k=min(k, len(candidates)))
            return selected, None

        # Find best performing query
        best_query = max(candidates, key=lambda x: self._performance_score(x[3], x[4]))
        # Value comparison: exact duplicates of the best turn are dropped too.
        remaining_candidates = [c for c in candidates if c != best_query]

        # Need at least 6 remaining candidates for clustering
        if len(remaining_candidates) < 6:
            selected = random.sample(candidates, k=min(k, len(candidates)))
            return selected, None

        # Generate embeddings for clustering
        search_queries = [c[0] for c in remaining_candidates]
        embeddings = self.embedding_model.encode(search_queries)

        # Perform k-means clustering with 3 clusters. The guard above ensures
        # at least 6 candidates, so this is always 3 and no further
        # "too few clusters" check is needed.
        n_clusters = min(3, len(remaining_candidates))

        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)

        # Sample ordered top 2 per cluster
        cluster_samples = [[] for _ in range(n_clusters)]
        cluster_info = {}

        for cluster_id in range(n_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]
            cluster_candidates = [remaining_candidates[i] for i in cluster_indices]

            # Sort by performance and take top 2
            sorted_candidates = sorted(cluster_candidates, key=lambda x: self._performance_score(x[3], x[4]), reverse=True)
            top_2 = sorted_candidates[:2]
            cluster_samples[cluster_id] = top_2

            # Store cluster info for each candidate (keyed by query string)
            for candidate in top_2:
                cluster_info[candidate[0]] = {
                    'cluster_id': cluster_id,
                    'is_cluster_head': candidate == top_2[0]
                }

        # Find closest cluster to best_query
        best_query_embedding = self.embedding_model.encode([best_query[0]])[0]
        cluster_similarities = []
        for i, center in enumerate(kmeans.cluster_centers_):
            similarity = cosine_similarity([best_query_embedding], [center])[0][0]
            cluster_similarities.append((similarity, i))

        # Get closest cluster index
        _, best_cluster_idx = max(cluster_similarities)

        # Add best_query to its closest cluster (after that cluster's top 2)
        cluster_samples[best_cluster_idx].append(best_query)
        cluster_info[best_query[0]] = {
            'cluster_id': best_cluster_idx,
            'is_best_query': True,
            'is_cluster_head': False
        }

        # Create cluster order - random shuffle but ensure best cluster is last
        cluster_order = list(range(n_clusters))
        random.shuffle(cluster_order)

        # Move best cluster to end
        if best_cluster_idx in cluster_order:
            cluster_order.remove(best_cluster_idx)
        cluster_order.append(best_cluster_idx)

        # Build search trace by extending cluster samples in order
        search_trace = []
        for cluster_idx in cluster_order:
            search_trace.extend(cluster_samples[cluster_idx])

        return search_trace[:7], cluster_info

    def _get_cluster_id(self, search_query, cluster_info):
        """Return the cluster id recorded for ``search_query``, or None."""
        if cluster_info and search_query in cluster_info:
            return cluster_info[search_query]['cluster_id']
        return None

    def _format_prev_docs(self, docs):
        """Render documents as a numbered, triple-quoted list.

        Non-string entries are skipped; the numbers keep the original
        1-based positions, so skipped entries leave gaps.
        """
        lines = []
        for i, d in enumerate(docs, 1):
            if not isinstance(d, str):
                continue
            lines.append(f'{i}. """{d}"""')
        return "\n\n".join(lines)

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, cluster_info, current_cluster, previous_cluster, similarity_history=None):
        """Ask the chat model for the think sequence of one turn.

        Args:
            q_text: The user query.
            turn_idx: 0-based index of the turn being generated.
            search_query: The query this turn must end up producing.
            prev_docs: Documents retrieved by the previous turn.
            original_thinking: Recorded thinking, offered as inspiration.
            search_history: ``(query, think)`` pairs of the earlier turns.
            cluster_info: Accepted for interface compatibility; not used here.
            current_cluster: Cluster id of this turn (None if unknown).
            previous_cluster: Cluster id of the previous turn (None if unknown).
            similarity_history: Optional cosine values of the earlier turns,
                used to fill the best/worst turn numbers of a reflection
                example. Without it no reflection example is added.

        Returns:
            The generated think text.

        Raises:
            RuntimeError: If the response carries no parsed payload.
        """
        # Determine the type of turn
        is_restart = previous_cluster is not None and current_cluster != previous_cluster
        is_first_turn = turn_idx == 0

        system_prompt = """You are generating synthetic training data for a depth-first search agent that persistently explores promising directions until exhaustion, then strategically pivots to new areas. This is a DATA GENERATION process where each think sequence represents the internal reasoning of an intelligent search agent that learns through deep drilling and strategic restart patterns.

CRITICAL RULES FOR DEPTH-FIRST THINK SEQUENCE GENERATION:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Think sequences must demonstrate PERSISTENT EXPLORATION - drilling deeply into promising directions until exhausted
3. Turn 1 is INITIAL DRILLING: "I'll start by diving deep into..." with commitment to thorough exploration
4. PERSISTENT TURNS: Continue drilling deeper in same direction with incremental refinement and progressive depth
5. RESTART TURNS: Acknowledge exhaustion of previous direction and strategic pivot to completely new area
6. CROSS-CLUSTER DRILLING: Compare how deep exploration in different areas revealed different levels of depth
7. STRATEGIC PERSISTENCE: Balance drilling depth with recognition of when to strategically restart

STRICTLY FORBIDDEN - NEVER MENTION:
- Any numerical scores, performance metrics, or quantitative measurements
- Cluster IDs, technical clustering terms, or algorithmic details
- Rankings, positions, or numerical performance indicators

THE AGENT FOLLOWS DEPTH-FIRST EXPLORATION:
- Depth-first drills deeply into promising directions until thoroughly explored
- Each drilling turn builds incrementally on previous depth in same area
- Documents are valuable for their progressive depth and detailed coverage
- Strategic restarts only occur when current direction is exhausted

DEPTH-FIRST REASONING PATTERNS:
- Express commitment to thorough exploration of promising directions
- Show progressive deepening through incremental refinement
- Reference exhaustive exploration goals before strategic pivoting
- Use language like "drill deeper", "exhaust this direction", "progressive depth", "strategic restart"
- Acknowledge that depth-first prevents premature abandonment of productive paths
- For restart turns: "Having exhausted [previous direction], I need to strategically pivot to [new area]"

REFLECTIVE REASONING PATTERNS (for turns 2+):
- Analyze how deeper drilling revealed progressively richer content
- Compare depth of exploration across different directions using content analysis
- Identify why persistent drilling reveals incremental improvements through content analysis
- Create drilling progress signals: "This deeper exploration revealed..." or "Progressive drilling shows..."
- When continuing drilling: "Building on [previous depth], I can drill even deeper into [specific aspect]"
- When documents show progressive depth: "This direction continues to yield deeper insights with each refinement"
- For restart: "Having thoroughly exhausted [direction], I need to strategically restart with [new approach]"

DETAILED DEPTH-FIRST EXPLORATION ANALYSIS:
- Use specific turn references for drilling progression: "Turn 2 drilled into [aspect], Turn 3 went even deeper into [sub-aspect]"
- Provide progressive depth analysis: "Turn 2 revealed surface-level information while Turn 3 uncovered deeper technical details"
- Trace drilling progression: "My persistent drilling moved from [surface] to [deeper layer] to ensure complete exhaustion"
- Acknowledge drilling completeness: "I've now thoroughly drilled through [X] layers of this direction"
- Show strategic restart recognition: "Based on exhaustive drilling, [direction] has been fully explored - time for strategic pivot"

SEARCH QUERY EVOLUTION FOR DEPTH-FIRST:
- The think sequence must construct reasoning that naturally leads to the target search query
- Show drilling logic that causally connects to the specific search terms being produced
- Frame drilling decisions in terms of the actual query being generated for progressive depth
- For restart turns: justify why this specific query represents a strategic pivot after exhaustive exploration

Generate think sequences that read like internal monologues of a persistent explorer who drills deeply before strategic pivoting."""

        if is_first_turn:
            turn_type = "Initial Drilling"
        elif is_restart:
            turn_type = "Strategic Restart"
        else:
            turn_type = "Progressive Drilling"

        # Build context for depth-first exploration
        context_prompt = f"""Generate 1 depth-first drilling thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
- Turn type: {turn_type}
"""

        if turn_idx > 0:
            context_prompt += f"- Documents retrieved from previous drilling: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference original thinking (for inspiration): {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS DRILLING HISTORY:\n"
            for i, (prev_query, prev_think) in enumerate(search_history):
                context_prompt += f"Turn {i + 1}: Query=\"{prev_query}\", Thinking=\"{prev_think}\"\n"

        # Add depth-first exploration examples. The random draws happen in the
        # same order whether or not an example can be rendered, so seeded runs
        # stay reproducible.
        if turn_idx >= 2 and random.random() < 0.5:
            template = random.choice(self.reflection_templates)
            example = self._render_reflection_example(template, similarity_history)
            # Only add the example when the turn placeholders could be filled;
            # a raw "{best_turn}" in the prompt would be meaningless.
            if example is not None:
                context_prompt += f"\nDEPTH-FIRST DRILLING EXAMPLE (use similar pattern):\n{example}\n"

        if is_restart:
            context_prompt += """
STRATEGIC RESTART REQUIREMENTS:
- Acknowledge complete exhaustion of previous exploration direction
- Justify why strategic pivot is necessary after thorough drilling
- Frame this search query as beginning entirely new direction
- Show recognition that previous direction has been fully explored
"""
        elif is_first_turn:
            context_prompt += """
INITIAL DRILLING REQUIREMENTS:
- Show commitment to deep exploration of this direction
- Express intention to drill progressively deeper
- Frame current search as beginning of thorough exploration
- Demonstrate depth-first strategy of persistent drilling
"""
        else:
            context_prompt += """
PROGRESSIVE DRILLING REQUIREMENTS:
- Build incrementally on previous drilling in same direction
- Show progressive deepening through refinement
- Demonstrate continued commitment to thorough exploration
- Frame current search as deeper dive into same area
"""

        context_prompt += f"""
Generate a single think sequence that demonstrates depth-first drilling reasoning that naturally leads to producing exactly "{search_query}" while maintaining persistent exploration narrative."""

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt.format(turn_num=turn_idx + 1)},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )

        message = response.choices[0].message
        parsed = message.parsed
        if parsed is None:
            # Fail with a clear message instead of an opaque AttributeError.
            raise RuntimeError(
                f"No parsed think sequence for turn {turn_idx + 1}: "
                f"{getattr(message, 'refusal', None)!r}"
            )
        return parsed.think

    def generate_sample(self):
        """Generate one synthetic sample for the query at the current cursor.

        The cursor wraps around the filtered query list and is advanced by
        one on every call.

        Returns:
            None when no query passed the filters, otherwise a dict with
            ``query_id``, ``user_query``, ``sequence`` (alternating ``think``
            and ``search_query`` entries), ``behavior_type`` and
            ``cluster_info``.
        """
        if not self.query_ids:
            return None
        self.cursor = self.cursor % len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace, cluster_info = self._select_depth_first_search_trace(qid, k=7)

        synthetic_sequence = []
        search_history = []
        similarity_history = []
        current_cluster = None
        previous_cluster = None

        T = min(7, len(search_trace))

        for t in range(T):
            s_t, th_t, R_t, cos_t, rank_t = search_trace[t]

            # Track cluster changes for restart detection
            previous_cluster = current_cluster
            current_cluster = self._get_cluster_id(s_t, cluster_info)

            # Get previous turn documents for context
            prev_docs = search_trace[t-1][2] if t > 0 else []

            # Generate think sequence for this specific turn. Only the scores
            # of *earlier* turns are passed: the think precedes this search.
            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history,
                cluster_info=cluster_info,
                current_cluster=current_cluster,
                previous_cluster=previous_cluster,
                similarity_history=list(similarity_history)
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            # Add this turn to history for future turns
            search_history.append((s_t, think_t))
            similarity_history.append(cos_t)

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "depth_first_driller",
            "cluster_info": cluster_info
        }
    
if __name__=='__main__':
    # CLI entry point: generate samples for indices [start_index, end_index)
    # and append them, one JSON object per line, to a range-specific file.
    import sys  # local import: only the script entry point needs it

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = DepthFirstDriller()
    # generate_sample wraps the cursor around the filtered query list.
    obj.cursor = args.start_index
    output_file = f'generated_samples_depth_first_{args.start_index}_{args.end_index}.jsonl'

    with open(output_file, 'a', encoding='utf-8') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
                line = json.dumps(result) if result is not None else None
            except Exception as exc:
                # Best effort: skip the failing sample, but say so. Catching
                # Exception (not a bare except) keeps Ctrl-C and SystemExit
                # working.
                print(f"[index {i}] sample generation failed: {exc!r}", file=sys.stderr)
                continue
            if line is None:
                # No query passed the filters; every further call would also
                # return None, so stop instead of writing "null" lines.
                print("No eligible queries after filtering; stopping.", file=sys.stderr)
                break
            f.write(line + '\n')
            f.flush()