import json, random, argparse
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Read the OpenAI API key once at import time. strip() removes the trailing
# newline/whitespace that key files usually end with; otherwise it would be
# sent verbatim as part of the Authorization header.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema passed as ``response_format`` to the OpenAI
# ``beta.chat.completions.parse`` call: the model must return exactly one
# ``think`` string. Documented with comments only, because pydantic uses a
# model docstring as the JSON-schema description that is sent to the API.
class SingleThink(BaseModel):
    think: str  # generated reasoning text for a single turn

class ExploitationHeavyValidator:
    """Generate synthetic "exploitation-heavy" search traces with LLM-written thoughts.

    Each sample takes one user query from a preprocessed dataset and selects
    some of its recorded search turns. The turns are ordered so that
    ``best_rank`` improves (decreases) across the trace. ``gpt-4o-mini`` then
    writes a "think" passage that leads to each turn's search query.

    Args:
        path: JSON file with top-level ``"turns"`` and ``"user_query"``
            mappings, both keyed by query id.
        max_turns: Maximum number of turns emitted per generated sample.
        below5_thr: Queries whose ``user_query[qid]["score"]["below5"]``
            exceeds this value are dropped.
        min_samples_thr: Queries with fewer recorded turns than this are dropped.
        seed: Seed for the module-level ``random`` generator (global state).
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        # NOTE: seeds the *global* random module, affecting any other user of it.
        random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        # Round-robin position into self.query_ids; advanced by generate_sample().
        self.cursor = 0
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Example phrasings randomly injected into the prompt for turns 2+;
        # ``{prev_turn}`` is filled with the 1-based number of the previous turn.
        self.improvement_templates = [
            "This refinement builds on Turn {prev_turn}'s foundation - I'm seeing incremental progress in the right direction with each focused adjustment.",
            "Building on Turn {prev_turn}, this approach shows continued improvement - the sustained focus on this semantic path is yielding better results.",
            "This search extends Turn {prev_turn}'s productive direction - the consistent exploitation of this approach continues to deliver enhanced results.",
            "Following Turn {prev_turn}'s trajectory, this refined approach demonstrates steady improvement - staying focused on this path is proving effective.",
            "This iteration advances Turn {prev_turn}'s successful framework - the concentrated effort on this direction shows progressive enhancement.",
            "Expanding on Turn {prev_turn}, this focused refinement continues the improvement pattern - sustained exploitation is generating better outcomes."
        ]

    def _filter_query_ids(self):
        """Return the query ids that pass the quality and sample-count filters.

        A query id is kept when all of the following hold:

        - it has an entry in ``self.user_query_by_qid``;
        - its ``score["below5"]`` (when present) does not exceed ``self.below5_thr``;
        - it has at least ``self.min_samples_thr`` recorded turns.

        Order follows iteration order of ``self.turns_by_qid``.
        """
        kept = []
        for qid, turns in self.turns_by_qid.items():
            if qid not in self.user_query_by_qid:
                # No user query text to prompt with; skip rather than crash __init__.
                continue
            uq = self.user_query_by_qid[qid]
            score = uq.get("score") if isinstance(uq, dict) else None
            below5 = score.get("below5") if isinstance(score, dict) else None
            if below5 is not None and below5 > self.below5_thr:
                continue
            if len(turns) < self.min_samples_thr:
                continue
            kept.append(qid)
        return kept

    def _user_query_text(self, qid):
        """Return the user query text for ``qid``.

        Entries may be a dict with a ``"user_query"`` key or the raw value itself.
        """
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Return ``cos``, or 0.0 when it is None; ``rank`` is ignored.

        Not called by the other methods of this class.
        """
        return cos if cos is not None else 0.0

    def _select_improving_search_trace(self, qid, k=7):
        """Pick up to ``k`` turns for ``qid``, ordered from worst to best rank.

        Each returned item is a tuple
        ``(search_query, think, top_k_results, best_cosine, best_rank)``;
        a lower ``best_rank`` is better.

        When at least ``k`` turns carry a rank:

        - turns from the two best rank values are reserved for the tail;
        - earlier slots are filled by random sampling from the remaining ranks;
        - the selection is sorted by rank descending, so the trace improves
          monotonically and ends on the best selected turn.

        Otherwise a random, unordered sample of up to ``k`` turns (ranked or
        not) is returned.
        """
        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]
        ranked = [c for c in candidates if c[4] is not None]

        # Compare against k (previously a hard-coded 7 that ignored the argument).
        if len(ranked) < k:
            return random.sample(candidates, k=min(k, len(candidates)))

        rank_groups = {}
        for c in ranked:
            rank_groups.setdefault(c[4], []).append(c)

        # The two lowest (best) rank values are reserved for the end of the trace.
        best_ranks = sorted(rank_groups)[:2]
        reserved_for_end = [c for r in best_ranks for c in rank_groups[r]]
        others = [c for r, group in rank_groups.items() if r not in best_ranks for c in group]

        selected = []
        head_slots = max(k - 2, 0)  # guard: random.sample rejects negative sizes
        if others:
            selected.extend(random.sample(others, min(head_slots, len(others))))

        tail_slots = k - len(selected)
        if reserved_for_end and tail_slots > 0:
            selected.extend(random.sample(reserved_for_end, min(tail_slots, len(reserved_for_end))))

        # Top up from anything not yet chosen if still short of k.
        if len(selected) < k:
            leftovers = [c for c in ranked if c not in selected]
            if leftovers:
                selected.extend(random.sample(leftovers, min(k - len(selected), len(leftovers))))

        # Descending rank == worst -> best. Every reserved (best-two) rank is
        # lower than every other rank, so reserved turns land at the end. The
        # previous code sorted the last two ascending, which made the final
        # turn *worse* than the one before it.
        selected.sort(key=lambda c: c[4], reverse=True)
        return selected[:k]

    def _format_prev_docs(self, docs):
        """Render string documents as a numbered, triple-quoted list.

        Non-string entries are skipped; numbering keeps each document's
        original 1-based position, so gaps can appear.
        """
        lines = []
        for i, d in enumerate(docs, 1):
            if not isinstance(d, str):
                continue
            lines.append(f'{i}. """{d}"""')
        return "\n\n".join(lines)

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history):
        """Ask the model for one "think" passage that leads to ``search_query``.

        Args:
            q_text: The user's query text.
            turn_idx: 0-based turn index; shown to the model as ``turn_idx + 1``.
            search_query: The query this turn must arrive at.
            prev_docs: Documents retrieved by the previous turn (used for turns 2+).
            original_thinking: The recorded think text for this turn (used for turns 2+).
            search_history: ``(query, think)`` pairs from earlier turns; only the
                queries are shown to the model.

        Returns:
            The generated think text.

        Raises:
            RuntimeError: if the response has no parsed ``SingleThink``
                (e.g. the model refused).
        """
        is_first_turn = turn_idx == 0

        system_prompt = """You are generating synthetic training data for an exploitation-heavy validator agent that demonstrates sustained focus and incremental improvement. This agent learns to maintain concentrated effort on promising directions while showing progressive refinement.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Turn 1: Establish focused search direction with clear commitment
3. Turns 2+: Show incremental refinement and improvement acknowledgment
4. Demonstrate sustained exploitation of single semantic path
5. NEVER mention numerical scores, performance metrics, or quantitative measurements

EXPLOITATION-HEAVY PATTERNS:
- Sustained focus on single promising direction without abandoning approach
- Incremental refinement building systematically on previous results
- Recognition of improvement patterns through content analysis
- Commitment to exploiting successful approaches rather than exploring alternatives
- Progressive enhancement through concentrated effort

SEARCH EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show refinement logic that connects to specific search terms being generated
- Frame decisions in terms of incremental improvement on established direction
- Demonstrate sustained exploitation rather than exploratory pivoting"""

        if not is_first_turn:
            system_prompt += """

IMPROVEMENT RECOGNITION (Turns 2+):
- Acknowledge progressive improvement from previous turns
- Reference sustained focus on single semantic direction
- Show recognition of incremental enhancement patterns
- "Building on Turn X's foundation, this refinement..."
- "This approach extends the productive direction from Turn X..."
- "Following Turn X's trajectory, this iteration shows continued improvement..."
- "The sustained focus on this path continues to yield better results..."
- Demonstrate commitment to exploiting rather than exploring"""

        system_prompt += """

Generate think sequences that read like internal monologues of an agent committed to sustained exploitation and incremental improvement."""

        context_prompt = f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
"""

        if not is_first_turn:
            context_prompt += f"- Previous documents: {self._format_prev_docs(prev_docs)}\n"
            context_prompt += f"- Reference thinking: {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS FOCUSED ATTEMPTS:\n"
            for turn_no, (prev_query, _prev_think) in enumerate(search_history, 1):
                context_prompt += f"Turn {turn_no}: \"{prev_query}\"\n"

        # Add an improvement-pattern example half of the time for turns 2+.
        if turn_idx > 0 and random.random() < 0.5:
            template = random.choice(self.improvement_templates)
            context_prompt += f"\nIMPROVEMENT PATTERN EXAMPLE:\n{template.format(prev_turn=turn_idx)}\n"

        # Turn-specific requirements
        if is_first_turn:
            context_prompt += """
REQUIREMENTS:
- Establish focused search direction with clear commitment
- Show determination to pursue this specific approach thoroughly
- Express intention for sustained exploitation rather than broad exploration
- Demonstrate concentrated effort on single semantic path"""
        else:
            context_prompt += f"""
REQUIREMENTS:
- Acknowledge improvement and refinement from previous turns
- Show sustained focus on single semantic direction established earlier
- Reference incremental progress and continued commitment to exploitation
- Build systematically on Turn {turn_idx}'s foundation
- Demonstrate recognition of improvement patterns through content analysis
- Express confidence in sustained exploitation approach"""

        context_prompt += f'\nGenerate reasoning that naturally leads to search query: "{search_query}" through incremental refinement.'

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt.format(turn_num=turn_idx + 1)},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )

        message = response.choices[0].message
        if message.parsed is None:
            # parse() can yield no structured result (e.g. a refusal); fail with a
            # clear error instead of an AttributeError on None.
            raise RuntimeError(
                f"No parsed think returned for turn {turn_idx + 1}; "
                f"refusal={getattr(message, 'refusal', None)!r}"
            )
        return message.parsed.think

    def generate_sample(self):
        """Build one synthetic sample for the next query id, in round-robin order.

        Returns:
            None when no query ids survived filtering. Otherwise a dict with:

            - ``query_id`` and ``user_query``;
            - ``behavior_type`` (always ``"exploitation_heavy_validator"``);
            - ``sequence``, which alternates ``think`` entries (generated text
              under the ``"messages"`` key) with ``search_query`` entries
              copied from the selected trace.

        Makes one OpenAI request per turn (at most ``self.max_turns``); API
        errors propagate to the caller.
        """
        if not self.query_ids:
            return None
        self.cursor %= len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        # Honour the configured max_turns (previously hard-coded to 7).
        n_turns = self.max_turns
        search_trace = self._select_improving_search_trace(qid, k=n_turns)

        synthetic_sequence = []
        search_history = []

        for t, (s_t, th_t, R_t, cos_t, rank_t) in enumerate(search_trace[:n_turns]):
            # Documents retrieved by the previous turn feed this turn's prompt.
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })

            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "exploitation_heavy_validator"
        }
    
if __name__ == '__main__':
    # Driver: append samples [start_index, end_index) to a JSONL file, one
    # JSON object per line. Failed samples are reported on stderr and skipped.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = ExploitationHeavyValidator()
    # generate_sample() wraps the cursor modulo the number of query ids.
    obj.cursor = args.start_index
    output_file = f'generated_samples_exploitation_heavy_{args.start_index}_{args.end_index}.jsonl'

    with open(output_file, 'a', encoding='utf-8') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
            except Exception as exc:  # best-effort batch: log and move on
                print(f"[index {i}] sample generation failed: {exc!r}", file=sys.stderr)
                continue
            if result is None:
                # No query ids survived filtering; every later call would also be None.
                print("No query ids available after filtering; stopping.", file=sys.stderr)
                break
            f.write(json.dumps(result) + '\n')
            f.flush()