import json, random, argparse
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Read the OpenAI API key once, at import time, from a local key file.
# The value is stripped because key files normally end with a newline (and may
# carry stray spaces); without stripping, that whitespace would be passed
# verbatim to OpenAI(api_key=...) when the selector builds its client.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema for one generated reasoning step. It is passed as
# `response_format` to the chat-completions parse call made in
# BestFirstHypothesisSelector._generate_single_think, which returns `.think`.
class SingleThink(BaseModel):
    # Free-text reasoning the model writes for a single turn.
    think: str

class BestFirstHypothesisSelector:
    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load the preprocessed turn data and prepare the generation clients.

        Args:
            path: JSON file holding a ``"turns"`` mapping and a ``"user_query"``
                mapping, both keyed by query id.
            max_turns: Stored on the instance as ``self.max_turns``.
            below5_thr: Queries whose ``score["below5"]`` exceeds this value are
                filtered out.
            min_samples_thr: Queries with fewer recorded turns than this are
                filtered out.
            seed: Seed applied to both ``random`` and ``numpy.random``.
        """
        # Seed the stdlib RNG first, then NumPy's global RNG.
        for seed_rng in (random.seed, np.random.seed):
            seed_rng(seed)

        with open(path, "r") as handle:
            payload = json.load(handle)
        self.turns_by_qid, self.user_query_by_qid = payload["turns"], payload["user_query"]

        # Filtering thresholds must be set before _filter_query_ids() reads them.
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.cursor = 0
        self.query_ids = self._filter_query_ids()

        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Sentence embedding model used to cluster search queries.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Optional phrasing hints; each one is formatted with `prev_turn`.
        templates = []
        templates.append("Turn {prev_turn} explored one hypothesis path - now I'll evaluate the best parent foundation to select the next promising direction from my available options.")
        templates.append("Building on Turn {prev_turn}, I need to assess which hypothesis shows the most potential based on my current best parent and available exploration paths.")
        templates.append("From Turn {prev_turn}'s results, I can compare this with my best parent performance to decide which of the three available directions offers the strongest hypothesis.")
        templates.append("Turn {prev_turn} provided evidence for this path - I'll use my best parent tracking to guide selection among the remaining hypothesis options.")
        templates.append("Following Turn {prev_turn}, I need to reference my best parent foundation and choose from the three directions that remain most promising for exploration.")
        templates.append("Turn {prev_turn}'s exploration helps me update my best parent assessment - now I can select the optimal hypothesis from my available directions.")
        self.hypothesis_templates = templates

    def _filter_query_ids(self):
        """Return the query ids that pass the score and sample-count filters.

        A query is dropped when its user-query record is a dict carrying a
        non-null ``score["below5"]`` greater than ``self.below5_thr``, or when it
        has fewer than ``self.min_samples_thr`` recorded turns.  Ids keep the
        iteration order of ``self.turns_by_qid``.
        """
        def is_eligible(qid):
            record = self.user_query_by_qid[qid]
            if isinstance(record, dict) and "score" in record:
                below5 = record["score"].get("below5")
                if below5 is not None and below5 > self.below5_thr:
                    return False
            return not len(self.turns_by_qid[qid]) < self.min_samples_thr

        return [qid for qid in self.turns_by_qid if is_eligible(qid)]

    def _user_query_text(self, qid):
        """Return the user query for ``qid``.

        When the stored record is a dict with a ``"user_query"`` key, that value
        is returned; otherwise the stored record itself is returned unchanged.
        """
        record = self.user_query_by_qid[qid]
        if isinstance(record, dict) and "user_query" in record:
            return record["user_query"]
        return record

    def _performance_score(self, cos, rank):
        """Return the score used to rank a candidate turn.

        The score is ``cos`` itself, or ``0.0`` when ``cos`` is ``None``.
        ``rank`` is accepted but does not influence the result.
        """
        if cos is None:
            return 0.0
        return cos

    def _select_best_first_trace(self, qid, k=7):
        """Assemble a seven-turn best-first exploration trace for one query.

        Every recorded turn of ``qid`` becomes a candidate tuple
        ``(search_query, think, top_k_results, best_cosine, best_rank)``.
        Candidates other than the lowest- and highest-scoring ones are embedded
        and split into three KMeans clusters; the highest-scoring member of each
        cluster is its head.  The trace is then laid out as:

        1. the lowest-scoring candidate (root),
        2-3. the heads of the two clusters that are not closest to the best
           query, in random order,
        4. the best-scoring not-yet-visited member of those two clusters
           (the second head is repeated only when no such member exists),
        5. the head of the cluster closest to the best query,
        6. a random not-yet-visited member of that cluster (any member when
           all were visited),
        7. the highest-scoring candidate.

        Args:
            qid: Key into ``self.turns_by_qid``.
            k: Sample size used only by the random fallback.

        Returns:
            ``(search_trace, best_parent, cluster_heads, current_path)``.
            ``best_parent`` is the best-scoring candidate among turns 1-6 and
            ``current_path`` holds one label per turn 1-6.  When the structured
            trace cannot be built (fewer than 7 candidates, fewer than 6 middle
            candidates, or an empty cluster) the result is
            ``(random_sample, None, None, None)``.
        """
        n_clusters = 3

        candidates = [
            (
                turn.get("search_query") or "",
                turn.get("think"),
                turn.get("top_k_results") or [],
                turn.get("best_cosine"),
                turn.get("best_rank"),
            )
            for turn in self.turns_by_qid[qid]
        ]

        def score(candidate):
            return self._performance_score(candidate[3], candidate[4])

        def random_fallback():
            # Shared degenerate-case result: an unstructured random sample.
            return random.sample(candidates, k=min(k, len(candidates))), None, None, None

        if len(candidates) < 7:
            return random_fallback()

        worst_query = min(candidates, key=score)
        best_query = max(candidates, key=score)

        # Equality-based exclusion: duplicates of the extremes are dropped too.
        middle_candidates = [c for c in candidates if c != worst_query and c != best_query]
        if len(middle_candidates) < 6:
            return random_fallback()

        # Cluster the middle candidates by the embedding of their search query.
        embeddings = self.embedding_model.encode([c[0] for c in middle_candidates])
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)

        cluster_members = [
            [middle_candidates[i] for i in np.where(cluster_labels == cluster_id)[0]]
            for cluster_id in range(n_clusters)
        ]
        # Duplicate query strings can leave fewer distinct points than clusters,
        # in which case a cluster may receive no members; max() below would
        # raise on it, so treat that like the other degenerate cases.
        if not all(cluster_members):
            return random_fallback()

        cluster_heads = [max(members, key=score) for members in cluster_members]

        # The "best" cluster is the one whose centre is most similar to the best query.
        best_query_embedding = self.embedding_model.encode([best_query[0]])[0]
        cluster_similarities = [
            (cosine_similarity([best_query_embedding], [center])[0][0], idx)
            for idx, center in enumerate(kmeans.cluster_centers_)
        ]
        _, best_cluster_idx = max(cluster_similarities)

        other_clusters = [i for i in range(n_clusters) if i != best_cluster_idx]
        random.shuffle(other_clusters)

        # Turn 1: root of the trace.
        search_trace = [worst_query]
        best_parent = worst_query
        current_path = [0]

        def visit(query, path_label):
            # Record one turn and keep best_parent pointing at the best query so far.
            nonlocal best_parent
            search_trace.append(query)
            current_path.append(path_label)
            if score(query) > score(best_parent):
                best_parent = query

        # Turns 2-3: heads of the two non-best clusters.
        for cluster_idx in other_clusters:
            visit(cluster_heads[cluster_idx], cluster_idx + 1)

        # Turn 4: go one step deeper into the non-best clusters.  Prefer a
        # candidate that has not been visited yet so the trace does not repeat
        # turn 3; only repeat the second head when nothing else is left.
        unexplored = [
            (member, cluster_idx)
            for cluster_idx in other_clusters
            for member in cluster_members[cluster_idx]
            if member not in search_trace
        ]
        if unexplored:
            next_query, cluster_idx = max(unexplored, key=lambda pair: score(pair[0]))
        else:
            cluster_idx = other_clusters[1]
            next_query = cluster_heads[cluster_idx]
        visit(next_query, cluster_idx + 1)

        # Turn 5: enter the best cluster through its head.
        visit(cluster_heads[best_cluster_idx], best_cluster_idx + 1)

        # Turn 6: another member of the best cluster, unvisited if possible.
        best_cluster_queries = cluster_members[best_cluster_idx]
        available_queries = [q for q in best_cluster_queries if q not in search_trace]
        visit(random.choice(available_queries or best_cluster_queries), "best_cluster_member")

        # Turn 7: the best query itself (not reflected in best_parent/current_path).
        search_trace.append(best_query)

        return search_trace, best_parent, cluster_heads, current_path

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered, triple-quoted list.

        Entries are numbered from 1 by their position in ``docs``; non-string
        entries are skipped without renumbering the rest.  Rendered entries are
        separated by a blank line.
        """
        rendered = (
            f'{position}. """{doc}"""'
            for position, doc in enumerate(docs, start=1)
            if isinstance(doc, str)
        )
        return "\n\n".join(rendered)

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, best_parent, cluster_heads):
        """Ask the chat model for the reasoning text that leads to one turn's query.

        A system prompt and a user prompt are assembled from fixed sections; the
        sections chosen depend on whether this is the first turn
        (``turn_idx == 0``), the final turn (``turn_idx == 6``) or a turn in
        between.

        Args:
            q_text: The user query shown in the prompt.
            turn_idx: Zero-based turn index.
            search_query: The query this turn's reasoning must lead to.
            prev_docs: Documents from the previous turn; included after turn 1.
            original_thinking: Reference reasoning; included after turn 1.
            search_history: ``(query, think)`` pairs of earlier turns; only the
                queries are listed in the prompt.
            best_parent: Accepted but not used when building the prompts.
            cluster_heads: Accepted but not used when building the prompts.

        Returns:
            The ``think`` field of the parsed ``SingleThink`` response.
        """
        turn_number = turn_idx + 1
        if turn_idx == 0:
            phase = "first"
        elif turn_idx == 6:
            phase = "final"
        else:
            phase = "middle"

        # ---- system prompt: shared rules + phase guidance + shared closing ----
        shared_rules = """You are generating synthetic training data for a best-first hypothesis selector agent that manages multiple search hypotheses and tracks the best parent foundation for decision making. This agent maintains three potential exploration directions and selects based on accumulated evidence.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Turn 1: Generate initial hypothesis tree with three potential directions
3. Turns 2-6: Reference best parent performance and select from available hypothesis options
4. Turn 7: Acknowledge returning to best parent foundation for optimal result
5. NEVER mention numerical scores, performance metrics, or quantitative measurements

BEST-FIRST HYPOTHESIS PATTERNS:
- Dynamic hypothesis management with three exploration directions
- Best parent tracking to guide future selections
- Evidence-based decision making using accumulated results
- Tree-like exploration structure with parent-child relationships
- Strategic selection from available hypothesis options"""

        phase_guidance = {
            "first": """

INITIAL HYPOTHESIS GENERATION:
- Establish three potential exploration directions for the search tree
- Create foundation for hypothesis management and tracking
- End thinking by identifying the three directions available for exploration
- Reference the hypothesis tree structure for systematic exploration""",
            "middle": """

HYPOTHESIS SELECTION:
- Reference current best parent performance to guide selection
- Compare available hypothesis options based on accumulated evidence
- Select next direction from the three available exploration paths
- End thinking by noting the three directions remaining for exploration
- Update best parent tracking based on current results""",
            "final": """

OPTIMAL HYPOTHESIS SELECTION:
- Acknowledge returning to best parent foundation for final optimization
- Reference accumulated evidence from hypothesis exploration
- Demonstrate completion of tree-based search strategy
- Select optimal result based on best parent tracking""",
        }

        shared_closing = """

SEARCH QUERY EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show hypothesis selection logic connecting to specific search terms
- Frame decisions in terms of tree exploration and parent tracking
- Demonstrate evidence-based selection from available options

Generate think sequences that read like internal monologues of an agent managing multiple hypotheses with best-first selection strategy."""

        system_prompt = shared_rules + phase_guidance[phase] + shared_closing

        # ---- user prompt: collected piecewise, joined once at the end ----
        user_sections = [f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_number}:
- Will produce search query: "{search_query}"
"""]

        if turn_idx > 0:
            user_sections.append(f"- Previous documents: {self._format_prev_docs(prev_docs)}\n")
            user_sections.append(f"- Reference thinking: {original_thinking}\n")

        if search_history:
            user_sections.append("\nPREVIOUS HYPOTHESIS EXPLORATION:\n")
            user_sections.extend(
                f'Turn {number}: "{prev_query}"\n'
                for number, (prev_query, _prev_think) in enumerate(search_history, start=1)
            )

        # After the first turn, a phrasing hint is attached with probability 0.4.
        if turn_idx > 0 and random.random() < 0.4:
            pattern = random.choice(self.hypothesis_templates)
            user_sections.append(f"\nHYPOTHESIS MANAGEMENT PATTERN:\n{pattern.format(prev_turn=turn_idx)}\n")

        phase_requirements = {
            "first": """
REQUIREMENTS:
- Generate initial hypothesis tree with systematic exploration strategy
- Establish three potential directions for search exploration
- Create foundation for best parent tracking and evidence accumulation
- End by identifying the three hypothesis directions available for exploration""",
            "final": """
REQUIREMENTS:
- Acknowledge returning to best parent foundation for optimal selection
- Reference accumulated evidence from hypothesis tree exploration
- Express confidence in final result based on best-first strategy
- Demonstrate completion of evidence-based hypothesis management""",
            "middle": """
REQUIREMENTS:
- Reference best parent performance to guide current selection
- Compare available hypothesis options based on accumulated evidence
- Select from the three exploration directions using best-first strategy
- End by noting the three directions remaining for future exploration
- Update hypothesis tracking based on current exploration results""",
        }
        user_sections.append(phase_requirements[phase])
        user_sections.append(f'\nGenerate reasoning that naturally leads to search query: "{search_query}" through hypothesis selection.')

        context_prompt = "".join(user_sections)

        # The turn number placeholder in the system prompt is filled in here.
        messages = [
            {"role": "system", "content": system_prompt.format(turn_num=turn_number)},
            {"role": "user", "content": context_prompt},
        ]
        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=messages,
            response_format=SingleThink
        )

        parsed = response.choices[0].message.parsed
        return parsed.think

    def generate_sample(self):
        """Generate one synthetic sample for the query at the current cursor.

        The cursor wraps around ``self.query_ids`` and is advanced by one.  For
        each of the first seven turns of the selected trace a reasoning text is
        generated, followed by the turn's search query record.

        Returns:
            A dict with ``query_id``, ``user_query``, ``sequence`` and
            ``behavior_type``, or ``None`` when there are no eligible queries
            (or no trace was produced).
        """
        if not self.query_ids:
            return None

        position = self.cursor % len(self.query_ids)
        qid = self.query_ids[position]
        self.cursor = position + 1

        q_text = self._user_query_text(qid)
        search_trace, best_parent, cluster_heads, _current_path = self._select_best_first_trace(qid, k=7)
        if search_trace is None:
            return None

        sequence = []
        history = []
        docs_from_previous_turn = []  # the first turn has no earlier results

        for turn_idx, (query, reference_think, results, cosine, rank) in enumerate(search_trace[:7]):
            think_text = self._generate_single_think(
                q_text=q_text,
                turn_idx=turn_idx,
                search_query=query,
                prev_docs=docs_from_previous_turn,
                original_thinking=reference_think,
                search_history=history,
                best_parent=best_parent,
                cluster_heads=cluster_heads
            )

            sequence.extend((
                {"tag": "think", "messages": think_text},
                {
                    "tag": "search_query",
                    "text": query,
                    "top_k_results": results,
                    "best_cosine": cosine,
                    "best_rank": rank
                },
            ))

            history.append((query, think_text))
            docs_from_previous_turn = results

        sample = {"query_id": qid, "user_query": q_text}
        sample["sequence"] = sequence
        sample["behavior_type"] = "best_first_hypothesis_selector"
        return sample
    
if __name__=='__main__':
    # Script entry point: generate samples for cursor positions
    # [start_index, end_index) and append them to a JSONL file, one per line.
    # Local imports keep the error reporting below self-contained.
    import sys
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = BestFirstHypothesisSelector()
    obj.cursor = args.start_index
    output_file = f'generated_samples_best_first_hypothesis_{args.start_index}_{args.end_index}.jsonl'

    with open(output_file, 'a') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
                if result:
                    f.write(json.dumps(result) + '\n')
                    # Flush per sample so finished work survives an interrupted run.
                    f.flush()
            except Exception:
                # Still best-effort (a failed sample is skipped), but the failure
                # is reported instead of vanishing.  Catching Exception rather
                # than everything lets Ctrl-C and SystemExit stop the run.
                print(f"[sample {i}] generation failed, skipping:", file=sys.stderr)
                traceback.print_exc()
                continue