import json, random, argparse
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    # Load the OpenAI API key from a local key file at import time.
    # Strip surrounding whitespace: a trailing newline in the key file would
    # otherwise become part of the key sent to the API.
    OPENAI_API_KEY = f.read().strip()

class SingleThink(BaseModel):
    """Structured-output schema for one generated reasoning passage.

    Passed as ``response_format`` to the chat-completions ``parse`` call; the
    model's reply is parsed into this model and ``think`` is read back.
    """
    think: str  # the generated "think" text for a single turn

class MultiBeamParallel:
    """Generate synthetic "multi-beam parallel search" reasoning traces.

    Loads preprocessed per-query search turns from a JSON file, selects a
    search trace (up to 7 turns) for a query, and asks an OpenAI chat model to
    write a "think" passage for each turn that leads to that turn's query.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load the dataset, filter eligible queries and create the API client.

        Args:
            path: JSON file with top-level "turns" and "user_query" mappings,
                both keyed by query id.
            max_turns: Stored on the instance; not read elsewhere in this class.
            below5_thr: Queries whose ``score["below5"]`` exceeds this are dropped.
            min_samples_thr: Queries with fewer turns than this are dropped.
            seed: Seed applied to the module-global ``random`` generator.
        """
        random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0  # round-robin position into self.query_ids
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Example phrasings occasionally injected into comparison-turn prompts;
        # {prev_turn} is filled with the 1-based number of the previous turn.
        self.beam_templates = [
            "Turn {prev_turn} explored one beam path - now I need to compare this with alternative beam performance and decide whether to continue this direction or switch to a more promising parallel path.",
            "Building on Turn {prev_turn}, I can evaluate how this beam compares with other parallel hypotheses to determine if this path warrants continued exploration or if switching would be more effective.",
            "From Turn {prev_turn}'s beam results, I need to assess the relative performance against alternative beams and make a strategic decision about resource allocation across parallel searches.",
            "Turn {prev_turn} provided evidence for this beam - I should compare this with other active beams to decide optimal path selection and potential pruning of weaker alternatives.",
            "Following Turn {prev_turn}'s beam exploration, I need to evaluate comparative performance across parallel paths and determine whether to maintain current focus or redirect to stronger beams.",
            "Turn {prev_turn}'s beam path shows certain results - now I can compare this against alternative beam performance to guide selection of the most promising direction for continued exploration."
        ]

    def _filter_query_ids(self):
        """Return the query ids eligible for generation.

        A query is dropped when its user-query record carries a non-None
        ``score["below5"]`` above ``self.below5_thr``, or when it has fewer
        than ``self.min_samples_thr`` turns.
        """
        out = []
        for qid in self.turns_by_qid:
            uq = self.user_query_by_qid[qid]
            if isinstance(uq, dict) and "score" in uq and uq["score"].get("below5") is not None:
                if uq["score"]["below5"] > self.below5_thr:
                    continue
            if len(self.turns_by_qid[qid]) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the user query text for ``qid``.

        The stored record is either the text itself or a dict holding it
        under ``"user_query"``.
        """
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Score a turn by its best cosine similarity (``None`` counts as 0.0).

        ``rank`` is accepted for interface symmetry but currently ignored.
        """
        return cos if cos is not None else 0.0

    def _select_beam_search_trace(self, qid, k=7):
        """Pick up to ``k`` turns of ``qid`` and arrange them as a beam trace.

        If the query has fewer than ``k`` turns, all of them are returned in
        random order and the tree is ``None``. Otherwise the highest-scoring
        turn is placed last, ``k - 1`` other turns are sampled and sorted
        worst-to-best in front of it, and a few early positions may be
        randomly swapped with a better-scoring tree child.

        Args:
            qid: Query id whose turns are sampled.
            k: Trace length. The default of 7 matches the turn roles used by
                the prompts.

        Returns:
            ``(search_trace, beam_tree)``; each trace entry is a tuple
            ``(search_query, think, top_k_results, best_cosine, best_rank)``.
        """
        if k <= 0:
            return [], None

        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]

        def score(c):
            return self._performance_score(c[3], c[4])

        if len(candidates) < k:
            return random.sample(candidates, k=len(candidates)), None

        # Best query goes last. Remove it by index rather than by value so that
        # identical duplicate turns remain available for sampling; value-based
        # filtering dropped every copy and could leave too few candidates.
        best_idx = max(range(len(candidates)), key=lambda i: score(candidates[i]))
        best_query = candidates[best_idx]
        remaining_candidates = candidates[:best_idx] + candidates[best_idx + 1:]

        # Sample k-1 other queries and order them worst to best (stable sort).
        sampled = random.sample(remaining_candidates, k=k - 1)
        search_trace = sorted(sampled, key=score) + [best_query]

        beam_tree = self._construct_beam_tree(search_trace)

        # Perturb the ascending order: a parent position t qualifies only when
        # both of its binary-tree children (2t+1, 2t+2) exist in the trace. It
        # is compared with its LEFT child and, if that child scores higher, the
        # two are swapped with probability 1/2. With k=7 only t=1 (child 3) and
        # t=2 (child 5) qualify. The last position is never a left child here,
        # so the best query always stays at the end.
        n = len(search_trace)
        for t in range(1, 6):
            left_child, right_child = self._get_tree_children(t)
            if right_child >= n:
                continue
            current_perf = score(search_trace[t])
            other_perf = score(search_trace[left_child])
            if current_perf < other_perf and random.choice([True, False]):
                search_trace[t], search_trace[left_child] = search_trace[left_child], search_trace[t]

        return search_trace, beam_tree

    def _construct_beam_tree(self, search_trace):
        """Return a fixed 3-level binary-tree index layout (root, 2, 4 nodes).

        ``search_trace`` is accepted but not used; the layout is constant.
        """
        return {
            'root': 0,
            'levels': {
                0: [0],
                1: [1, 2],
                2: [3, 4, 5, 6]
            }
        }

    def _get_tree_children(self, parent_idx):
        """Return ``(left, right)`` child indices of ``parent_idx`` in an array-backed binary tree."""
        left_child = 2 * parent_idx + 1
        right_child = 2 * parent_idx + 2
        return left_child, right_child

    def _get_sibling_beam(self, turn_idx, search_trace):
        """Return the other level-1 beam for turn index 1 or 2, else ``None``."""
        if turn_idx == 1:
            return search_trace[2] if len(search_trace) > 2 else None
        elif turn_idx == 2:
            return search_trace[1] if len(search_trace) > 1 else None
        return None

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered, triple-quoted list.

        Non-string entries are skipped before numbering, so the numbering
        stays contiguous (1, 2, 3, ...).
        """
        texts = [d for d in docs if isinstance(d, str)]
        return "\n\n".join(f'{i}. """{d}"""' for i, d in enumerate(texts, 1))

    def _build_system_prompt(self, turn_idx, search_trace):
        """Build the system prompt for 0-based turn ``turn_idx``."""
        # Fill {turn_num} in the base template BEFORE appending any
        # data-dependent text. The original code ran str.format over the
        # finished prompt, which raised on any literal "{" or "}" inside an
        # interpolated sibling search query.
        system_prompt = """You are generating synthetic training data for a multi-beam parallel search agent that manages multiple concurrent search hypotheses with systematic pruning and beam comparison. This agent maintains parallel exploration paths and makes strategic decisions about beam selection and resource allocation.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Turn 1: Initialize beam search with parallel hypothesis management strategy
3. Turns 2-3: Beam expansion with acknowledgment of parallel alternatives
4. Turns 4-6: Beam comparison and strategic switching decisions
5. Turn 7: Final convergence on single best surviving path
6. NEVER mention numerical scores, performance metrics, or quantitative measurements

MULTI-BEAM PARALLEL PATTERNS:
- Concurrent management of multiple search hypotheses
- Strategic resource allocation across parallel beams
- Systematic pruning of underperforming paths
- Comparative evaluation between alternative beams
- Dynamic switching between promising directions""".format(turn_num=turn_idx + 1)

        if turn_idx in [1, 2]:  # expansion turns (turns 2-3)
            sibling_beam = self._get_sibling_beam(turn_idx, search_trace)
            sibling_query = sibling_beam[0] if sibling_beam else "alternative beam"
            system_prompt += f"""

BEAM EXPANSION:
- Acknowledge parallel beam "{sibling_query}" as alternative hypothesis
- Explain pursuing this direction first while maintaining awareness of alternatives
- Reference concurrent exploration strategy with multiple active paths
- Show beam management decision making for parallel resource allocation"""
        elif turn_idx in [3, 4, 5]:  # comparison turns (turns 4-6)
            system_prompt += """

BEAM COMPARISON:
- Compare current path performance with alternative beam options
- Evaluate relative effectiveness across parallel search directions
- Make strategic decisions about continuing current beam vs switching
- Reference beam pruning and selection criteria based on accumulated evidence
- Show comparative analysis between concurrent hypotheses"""
        elif turn_idx == 6:  # final turn (turn 7)
            system_prompt += """

BEAM CONVERGENCE:
- Acknowledge concentrating resources on single best surviving path
- Reference systematic pruning process leading to optimal beam selection
- Demonstrate completion of parallel search with convergence decision
- Express confidence in final beam selection based on comparative evaluation"""

        system_prompt += """

SEARCH QUERY EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show beam selection logic connecting to specific search terms
- Frame decisions in terms of parallel hypothesis management and comparison
- Demonstrate strategic resource allocation across multiple concurrent paths

Generate think sequences that read like internal monologues of an agent managing multiple parallel search beams with strategic pruning decisions."""
        return system_prompt

    def _build_context_prompt(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, search_trace):
        """Build the user-message prompt for 0-based turn ``turn_idx``.

        On comparison turns this may consume the global ``random`` generator
        to inject an example beam-comparison phrasing (40% of the time).
        """
        is_first_turn = turn_idx == 0
        is_expansion_turn = turn_idx in [1, 2]  # Turns 2-3
        is_comparison_turn = turn_idx in [3, 4, 5]  # Turns 4-6
        is_final_turn = turn_idx == 6

        context_prompt = f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
"""

        if turn_idx > 0:
            context_prompt += f"- Previous documents: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference thinking: {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS BEAM EXPLORATION:\n"
            for i, (prev_query, _prev_think) in enumerate(search_history):
                context_prompt += f"Turn {i + 1}: \"{prev_query}\"\n"

        # Occasionally add an example beam-comparison phrasing.
        if is_comparison_turn and random.random() < 0.4:
            template = random.choice(self.beam_templates)
            context_prompt += f"\nBEAM COMPARISON PATTERN:\n{template.format(prev_turn=turn_idx)}\n"

        # Turn-specific requirements
        if is_first_turn:
            context_prompt += """
REQUIREMENTS:
- Initialize beam search strategy with parallel hypothesis management
- Establish multiple concurrent exploration paths for systematic evaluation
- Create foundation for beam comparison and pruning decisions
- Express commitment to managing multiple parallel search directions"""
        elif is_expansion_turn:
            sibling_beam = self._get_sibling_beam(turn_idx, search_trace)
            sibling_query = sibling_beam[0] if sibling_beam else "alternative beam"
            context_prompt += f"""
REQUIREMENTS:
- Acknowledge parallel beam "{sibling_query}" as concurrent alternative hypothesis
- Explain decision to pursue this specific direction first
- Reference beam management strategy with multiple active exploration paths
- Show awareness of resource allocation across parallel search beams"""
        elif is_comparison_turn:
            context_prompt += """
REQUIREMENTS:
- Compare current beam path with alternative beam performance
- Evaluate relative effectiveness of parallel search directions
- Make strategic decision about continuing current beam vs switching to alternatives
- Reference systematic beam comparison and potential pruning decisions"""
        elif is_final_turn:
            context_prompt += """
REQUIREMENTS:
- Acknowledge convergence on single best surviving beam path
- Reference systematic pruning and comparison process leading to this selection
- Express confidence in final beam choice based on parallel evaluation
- Demonstrate completion of multi-beam parallel search strategy"""

        context_prompt += f'\nGenerate reasoning that naturally leads to search query: "{search_query}" through beam management.'
        return context_prompt

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, search_trace):
        """Ask the model for one turn's "think" passage and return its text.

        Args:
            q_text: The user query.
            turn_idx: 0-based turn index; selects the turn role in the prompts.
            search_query: The search query this think must lead to.
            prev_docs: Documents returned by the previous turn.
            original_thinking: The source turn's own "think" text, used as a reference.
            search_history: ``(query, think)`` pairs of already-generated turns.
            search_trace: The full selected trace (used to find sibling beams).

        Raises:
            RuntimeError: if the response contains no parsed ``SingleThink``
                (e.g. the model refused).
        """
        system_prompt = self._build_system_prompt(turn_idx, search_trace)
        context_prompt = self._build_context_prompt(
            q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, search_trace
        )

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": context_prompt},
            ],
            response_format=SingleThink,
        )

        message = response.choices[0].message
        if message.parsed is None:
            raise RuntimeError(
                f"model returned no parsed think (refusal: {getattr(message, 'refusal', None)!r})"
            )
        return message.parsed.think

    def generate_sample(self):
        """Generate one synthetic sample for the next query id (round-robin).

        Returns:
            A dict with ``query_id``, ``user_query``, ``sequence`` (alternating
            "think" and "search_query" entries) and ``behavior_type``. Returns
            ``None`` when no query ids are eligible or the selected trace is empty.
        """
        if not self.query_ids:
            return None
        self.cursor = self.cursor % len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace, beam_tree = self._select_beam_search_trace(qid, k=7)

        # An empty trace would otherwise yield a sample with an empty sequence.
        if not search_trace:
            return None

        synthetic_sequence = []
        search_history = []

        for t in range(min(7, len(search_trace))):
            s_t, th_t, R_t, cos_t, rank_t = search_trace[t]
            # The previous turn's retrieved documents feed this turn's prompt.
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history,
                search_trace=search_trace
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "multi_beam_parallel"
        }
    
if __name__ == '__main__':
    # Script entry point: generate samples for [start_index, end_index) and
    # append them as JSON lines to a per-range output file.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = MultiBeamParallel()
    obj.cursor = args.start_index
    output_file = f'generated_samples_multi_beam_parallel_{args.start_index}_{args.end_index}.jsonl'

    with open(output_file, 'a', encoding='utf-8') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
            except Exception as e:
                # Best-effort: skip a failed sample but report it. Catching
                # Exception (not a bare except) lets Ctrl-C stop the run.
                print(f"[index {i}] sample generation failed: {e!r}", file=sys.stderr)
                continue
            if result:
                f.write(json.dumps(result) + '\n')
                f.flush()