import json, random, argparse
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Read the OpenAI API key once at import time. strip() removes surrounding
# whitespace (typically the trailing newline editors add), which would
# otherwise be sent as part of the key.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema for one generated think passage: passed as
# `response_format` to the chat-completion parse call so the reply is parsed
# into an object with a single `think` string.
# NOTE: documented with comments rather than a docstring on purpose — pydantic
# puts a class docstring into the JSON schema "description", which would change
# the schema sent to the model.
class SingleThink(BaseModel):
    think: str  # the generated reasoning text for one turn

class BreadthFirstExplorer:
    """Generate synthetic "breadth-first explorer" search traces.

    For each eligible query id, a trace of up to 7 turns is assembled from the
    logged search turns: up to 6 cluster-representative exploration queries
    followed by the best-scoring query as the exploitation turn. gpt-4o-mini
    then writes a "think" passage for every turn.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load the preprocessed data and set up the OpenAI and embedding clients.

        Args:
            path: JSON file with top-level "turns" and "user_query" mappings,
                both keyed by query id.
            max_turns: stored on the instance. NOTE(review): not read anywhere
                in this class; the 7-turn structure is hard-coded.
            below5_thr: query ids whose user_query["score"]["below5"] exceeds
                this value are dropped.
            min_samples_thr: query ids with fewer logged turns than this are
                dropped.
            seed: seed for the global ``random`` and ``numpy.random`` generators.
        """
        random.seed(seed)
        np.random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Load embedding model for clustering
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Cross-turn reflection examples for breadth-first exploration. They are
        # inserted into the prompt verbatim: the {best_turn}/{worst_turn}
        # placeholders are NOT filled in and reach the model as literal text.
        self.reflection_templates = [
            "Turn {best_turn} provided solid coverage of one conceptual area, while Turn {worst_turn} covered different territory - this systematic exploration across multiple search queries helps ensure I don't miss important perspectives before focusing deeper.",

            "Comparing my systematic exploration, Turn {best_turn} yielded comprehensive results for its search queries, whereas Turn {worst_turn} explored a different conceptual space - I'm trying to do a breadth-first approach ensuring I survey all major areas before exploitation.",

            "My turn-by-turn exploration shows Turn {best_turn} effectively covered its assigned territory, while Turn {worst_turn} systematically explored different conceptual ground - this methodical coverage prevents premature convergence.",

            "Analyzing my systematic coverage, Turn {best_turn} thoroughly explored one area while Turn {worst_turn} systematically covered different conceptual space - breadth-first search ensures comprehensive exploration before deep diving.",

            "My methodical exploration reveals Turn {best_turn} provided good coverage of its search query, contrasting with Turn {worst_turn} which systematically explored different territory - this ensures I survey all major conceptual areas first.",

            "Reflecting on systematic exploration, Turn {best_turn} comprehensively covered one conceptual search query while Turn {worst_turn} methodically explored different semantic space - this breadth-first strategy prevents missing important areas."
        ]

    def _filter_query_ids(self):
        """Return the query ids eligible for generation, in data order.

        A query id is dropped when its user_query entry carries a
        score["below5"] value above ``self.below5_thr``, or when it has fewer
        than ``self.min_samples_thr`` logged turns.
        """
        out = []
        for qid in self.turns_by_qid:
            uq = self.user_query_by_qid[qid]
            if isinstance(uq, dict) and "score" in uq and uq["score"].get("below5") is not None:
                if uq["score"]["below5"] > self.below5_thr:
                    continue
            if len(self.turns_by_qid[qid]) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the user query text for ``qid``.

        The entry is either a dict holding it under "user_query" or the value
        itself.
        """
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Score a logged turn by its best cosine similarity (None counts as 0.0).

        ``rank`` is accepted but currently ignored.
        """
        return cos if cos is not None else 0.0

    def _select_clustered_search_trace(self, qid, k=7):
        """Pick up to ``k`` logged turns of ``qid`` to form a search trace.

        With enough candidates, the best-scoring turn is held out as the final
        (exploitation) query. The remaining turns are embedded and clustered
        into ``k - 1`` k-means clusters. One random member of each non-empty
        cluster becomes an exploration query, and these are shuffled.
        Otherwise a random sample of up to ``k`` candidates is returned.

        Returns:
            ``(search_trace, cluster_info)``. ``search_trace`` is a list of
            ``(search_query, think, top_k_results, best_cosine, best_rank)``
            tuples. ``cluster_info`` maps search-query text to metadata dicts,
            or is None when the random fallback was used.
        """
        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]

        def random_fallback():
            # Not enough material for clustering: plain random selection.
            return random.sample(candidates, k=min(k, len(candidates))), None

        if len(candidates) < k:
            return random_fallback()

        # Hold out the best performing query for the exploitation turn. Every
        # identical turn is removed so it cannot also appear as exploration.
        best_query = max(candidates, key=lambda c: self._performance_score(c[3], c[4]))
        remaining_candidates = [c for c in candidates if c != best_query]

        n_explore = k - 1
        if len(remaining_candidates) < n_explore:
            return random_fallback()

        n_clusters = min(n_explore, len(remaining_candidates))
        if n_clusters < 2:
            return random_fallback()

        embeddings = self.embedding_model.encode([c[0] for c in remaining_candidates])
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)

        # One randomly chosen representative ("head") per cluster.
        cluster_heads = []
        head_embeddings = []
        head_cluster_ids = []
        cluster_info = {}
        for cluster_id in range(n_clusters):
            member_indices = np.flatnonzero(cluster_labels == cluster_id).tolist()
            if not member_indices:
                # k-means can leave a cluster empty when there are fewer
                # distinct embeddings than clusters (e.g. repeated queries).
                continue
            head_idx = random.choice(member_indices)
            head = remaining_candidates[head_idx]
            cluster_heads.append(head)
            head_embeddings.append(embeddings[head_idx])
            head_cluster_ids.append(cluster_id)
            cluster_info[head[0]] = {
                'cluster_id': cluster_id,
                'cluster_size': len(member_indices),
                'is_cluster_head': True
            }

        exploration_trace = cluster_heads[:]
        random.shuffle(exploration_trace)

        # The exploitation query builds on the cluster head most similar to it.
        # Head embeddings are reused from the clustering pass instead of
        # being encoded again.
        best_query_embedding = self.embedding_model.encode([best_query[0]])[0]
        similarities = cosine_similarity([best_query_embedding], np.asarray(head_embeddings))[0]
        closest_idx = int(np.argmax(similarities))
        closest_cluster_head = cluster_heads[closest_idx]
        closest_cluster_id = head_cluster_ids[closest_idx]
        # 1-indexed turn at which that head appears in the shuffled trace.
        closest_turn_position = exploration_trace.index(closest_cluster_head) + 1

        cluster_info[best_query[0]] = {
            'cluster_id': closest_cluster_id,
            'cluster_size': 1,
            'is_best_query': True,
            'is_cluster_head': False,
            'exploits_turn': closest_turn_position  # Store which turn this exploits
        }

        # Final search trace: exploration heads + 1 exploitation query.
        return exploration_trace + [best_query], cluster_info

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered, triple-quoted list.

        Non-string entries are skipped. Numbering is consecutive over the
        documents actually shown.
        """
        texts = [d for d in docs if isinstance(d, str)]
        return "\n\n".join(f'{i}. """{d}"""' for i, d in enumerate(texts, 1))

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, cluster_info, explored_clusters):
        """Ask gpt-4o-mini for the think passage that should lead to ``search_query``.

        Args:
            q_text: the user query text.
            turn_idx: 0-based turn index. Index 6 (turn 7) is treated as the
                exploitation turn.
            search_query: the query this turn must end up producing.
            prev_docs: retrieved results of the previous turn (string entries
                are shown).
            original_thinking: the logged think text for this turn, shown as
                inspiration from turn 2 on.
            search_history: ``(query, generated_think)`` pairs of earlier turns.
            cluster_info: query-text -> cluster metadata mapping, or None.
            explored_clusters: accepted for interface compatibility; not used.

        Returns:
            The generated think text.

        Raises:
            RuntimeError: if the response could not be parsed into ``SingleThink``.
        """
        # Determine if this is exploration or exploitation phase
        is_exploitation_turn = turn_idx == 6  # Last turn (index 6 = turn 7)

        system_prompt = """You are generating synthetic training data for a breadth-first search agent that systematically explores different conceptual areas before exploitation. This is a DATA GENERATION process where each think sequence represents the internal reasoning of an intelligent search agent that learns through systematic cluster exploration and cross-turn analysis.

CRITICAL RULES FOR BREADTH-FIRST THINK SEQUENCE GENERATION:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Think sequences must demonstrate SYSTEMATIC EXPLORATION - methodically covering different conceptual areas
3. Turn 1 is SYSTEMATIC START: "To systematically explore this topic, I'll start by examining..." 
4. Turns 2-6 are CLUSTER EXPLORATION: Reflect on previous cluster results, then systematically move to new conceptual area
5. Turn 7 is EXPLOITATION: After systematic exploration, focus on the most promising direction found
6. CROSS-CLUSTER ANALYSIS: Compare how different conceptual areas yielded different types of content
7. SYSTEMATIC COVERAGE: Acknowledge methodical exploration strategy and comprehensive coverage goals

STRICTLY FORBIDDEN - NEVER MENTION:
- Any numerical scores, performance metrics, or quantitative measurements
- Cluster IDs, technical clustering terms, or algorithmic details
- Rankings, positions, or numerical performance indicators

THE AGENT FOLLOWS SYSTEMATIC EXPLORATION:
- Breadth-first explores different conceptual areas methodically
- Each exploration turn covers a different aspect/approach systematically  
- Documents are valuable for their coverage of different conceptual spaces
- Systematic coverage ensures comprehensive understanding before exploitation

BREADTH-FIRST REASONING PATTERNS:
- Express systematic coverage strategy for different conceptual areas
- Show methodical progression through different approaches/perspectives
- Reference comprehensive exploration goals before focusing deeper
- Use language like "systematically explore", "methodically cover", "comprehensive survey"
- Acknowledge that breadth-first prevents premature convergence
- For exploitation turn: "After systematically exploring multiple areas, X shows the most promise"

REFLECTIVE REASONING PATTERNS (for turns 2+):
- Analyze how different conceptual areas revealed different content types
- Compare systematic coverage across different clusters/approaches using content analysis
- Identify why comprehensive exploration reveals different perspectives through content analysis  
- Create systematic coverage signals: "This systematic approach revealed..." or "Methodical exploration shows..."
- When moving to new areas: "Having explored [previous area], I'll now systematically examine [new area]"
- When documents reveal different perspectives: "This area covers different aspects than my previous systematic exploration"
- For exploitation: "My systematic exploration reveals [area] has the richest content for deeper investigation"

DETAILED SYSTEMATIC EXPLORATION ANALYSIS:
- Use specific turn references for systematic coverage: "Turn 2 systematically covered [area], while Turn 3 explored [different area]"
- Provide conceptual area coverage analysis: "Turn 2 revealed technical aspects while Turn 3 systematically covered practical applications"
- Trace systematic exploration progression: "My methodical exploration moved from [area1] to [area2] to ensure comprehensive coverage"  
- Acknowledge systematic coverage completeness: "I've now systematically explored [X] different conceptual areas"
- Show methodical learning: "Based on systematic exploration, [area] requires deeper investigation"

SEARCH QUERY EVOLUTION FOR BREADTH-FIRST:
- The think sequence must construct reasoning that naturally leads to the target search query
- Show systematic logic that causally connects to the specific search terms being produced
- Frame exploration decisions in terms of the actual query being generated for systematic coverage
- For exploitation turn: justify why this specific query represents the best direction after systematic exploration

Generate think sequences that read like internal monologues of a systematic explorer conducting comprehensive surveys before focused investigation."""

        # Build context for systematic exploration
        context_prompt = f"""Generate 1 systematic exploration thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
- Exploration phase: {"Exploitation (Focus)" if is_exploitation_turn else "Exploration (Coverage)"}
"""

        if turn_idx > 0:
            context_prompt += f"- Documents retrieved from previous systematic exploration: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference original thinking (for inspiration): {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS SYSTEMATIC EXPLORATION HISTORY:\n"
            for i, (prev_query, prev_think) in enumerate(search_history):
                phase = "Exploitation" if i == 6 else "Exploration"
                context_prompt += f"Turn {i + 1} ({phase}): Query=\"{prev_query}\", Thinking=\"{prev_think}\"\n"

        # Add systematic exploration examples
        if turn_idx >= 2 and random.random() < 0.5:
            context_prompt += f"\nSYSTEMATIC EXPLORATION EXAMPLE (use similar pattern):\n{random.choice(self.reflection_templates)}\n"

        if is_exploitation_turn:
            # Find which turn this exploitation is based on. It is only known
            # for clustered traces; random-fallback traces have no cluster_info.
            exploits_turn = None
            if cluster_info and search_query in cluster_info:
                exploits_turn = cluster_info[search_query].get('exploits_turn')

            if exploits_turn is not None:
                context_prompt += f"""
EXPLOITATION REQUIREMENTS:
- Acknowledge completion of systematic exploration across multiple conceptual areas
- Justify why this specific search query represents the most promising direction
- SPECIFICALLY mention that Turn {exploits_turn} was the most valuable and you're going to be building on that specific turn's approach
- Frame as focused investigation after comprehensive survey across the different search turns
- You can combine learnings from similar turns, however, Turn {exploits_turn} must be referenced as the primary direction being exploited
"""
            else:
                # Without a known source turn, avoid asking the model to cite "Turn None".
                context_prompt += """
EXPLOITATION REQUIREMENTS:
- Acknowledge completion of systematic exploration across multiple conceptual areas
- Justify why this specific search query represents the most promising direction
- Frame as focused investigation after comprehensive survey across the different search turns
"""
        else:
            context_prompt += """
EXPLORATION REQUIREMENTS:
- Show systematic coverage strategy for different conceptual areas
- If turn 2+: Note difference from previous exploration areas
- Demonstrate methodical progression through different approaches
- Frame current search as part of comprehensive survey strategy
"""

        context_prompt += f"""
Generate a single think sequence that demonstrates systematic exploration reasoning that naturally leads to producing exactly "{search_query}" while maintaining breadth-first exploration narrative."""

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt.format(turn_num=turn_idx + 1)},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )

        message = response.choices[0].message
        if message.parsed is None:
            # e.g. a refusal; surface it instead of an opaque AttributeError.
            raise RuntimeError(
                f"No parsed think for turn {turn_idx + 1}: refusal={getattr(message, 'refusal', None)!r}"
            )
        return message.parsed.think

    def generate_sample(self):
        """Build one synthetic trace for the next query id, cycling round-robin.

        Returns:
            A dict with "query_id", "user_query", "sequence" (alternating think
            and search_query entries), "behavior_type" and "cluster_info", or
            None when no query id passed filtering.
        """
        if not self.query_ids:
            return None
        self.cursor %= len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace, cluster_info = self._select_clustered_search_trace(qid, k=7)

        synthetic_sequence = []
        search_history = []
        explored_clusters = []

        for t, (s_t, th_t, R_t, cos_t, rank_t) in enumerate(search_trace[:7]):
            # Track cluster exploration; the turn index stands in when unknown.
            if cluster_info and s_t in cluster_info:
                explored_clusters.append(cluster_info[s_t]['cluster_id'])
            else:
                explored_clusters.append(t)

            # Documents retrieved by the previous turn give the model context.
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history,
                cluster_info=cluster_info,
                explored_clusters=explored_clusters
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            # Later turns see this turn's query and generated reasoning.
            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "breadth_first_explorer",
            "cluster_info": cluster_info
        }
    
if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = BreadthFirstExplorer()
    obj.cursor = args.start_index
    output_file = f'generated_samples_breadth_first_{args.start_index}_{args.end_index}.jsonl'

    # Append mode so a restarted range adds to, rather than truncates, earlier output.
    with open(output_file, 'a') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
                if result is None:
                    # Only happens when no query id passed filtering; every
                    # further call would return None as well.
                    print("No eligible query ids; stopping.", file=sys.stderr)
                    break
                line = json.dumps(result)
            except Exception as e:
                # Best-effort batch: skip a failed sample (API error, bad data)
                # but report it. KeyboardInterrupt still stops the run.
                print(f"[sample {i}] skipped: {type(e).__name__}: {e}", file=sys.stderr)
                continue
            f.write(line + '\n')
            f.flush()