import json, random, argparse
import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from tqdm import trange
from openai import OpenAI
import os

# Read the OpenAI API key from a local file. Surrounding whitespace is stripped:
# a trailing newline (as most editors write one) would otherwise be sent as part
# of the Authorization header value and rejected.
with open("./data/data/openai.key", "r", encoding="utf-8") as f:
    OPENAI_API_KEY = f.read().strip()

# Structured-output schema for the OpenAI `parse` call in _generate_single_think:
# the model must return a JSON object with a single string field `think`.
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a class docstring into the generated JSON schema as its
# "description", which would change the schema sent to the API.
class SingleThink(BaseModel):
    think: str  # generated internal-monologue text for one turn

class GreedyHillClimber:
    """Generate synthetic "greedy hill climber" search traces with LLM-written thinks.

    For each eligible query id, the logged search turns are re-ordered into a
    hill-climbing trajectory: start at the worst-scoring query, greedily step
    through nearest neighbours in sentence-embedding space, and finish with one
    of the two best-scoring queries. gpt-4o-mini then writes the internal
    monologue ("think") that leads to each search query in the trajectory.
    """

    def __init__(self, path="./data/data/preprocessed/data.json", max_turns=7, below5_thr=1.0, min_samples_thr=11, seed=0):
        """Load turn data and set up the embedding model and OpenAI client.

        Args:
            path: JSON file with top-level ``"turns"`` and ``"user_query"``
                mappings, both keyed by query id.
            max_turns: number of turns in a generated trace; the last turn is
                the final-optimization turn.
            below5_thr: query ids whose ``score["below5"]`` exceeds this value
                are skipped.
            min_samples_thr: query ids with fewer logged turns than this are
                skipped.
            seed: seed applied to the global ``random`` and ``numpy.random``
                state.
        """
        random.seed(seed)
        np.random.seed(seed)
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.turns_by_qid = data["turns"]
        self.user_query_by_qid = data["user_query"]
        self.max_turns = max_turns
        self.below5_thr = below5_thr
        self.min_samples_thr = min_samples_thr
        self.query_ids = self._filter_query_ids()
        self.cursor = 0  # round-robin position into self.query_ids
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Sentence-embedding model used to measure query-to-query distance.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Optional style hints; one is randomly injected into some prompts.
        self.hill_climbing_templates = [
            "Turn {prev_turn} explored one neighbor, now I'll examine adjacent options to find better local improvements in this search space.",
            "Building on Turn {prev_turn}'s position, I'll evaluate nearby neighbors to continue the hill climbing optimization process.",
            "From Turn {prev_turn}'s location in the search space, I can explore neighboring queries to find the steepest improvement gradient.",
            "Turn {prev_turn} established my current position - now I'll systematically check nearby alternatives for better local optimization.",
            "Following Turn {prev_turn}'s neighbor exploration, I'll continue climbing toward better solutions in this semantic neighborhood.",
            "Turn {prev_turn} provided a stepping stone - I'll now evaluate the next set of neighbors for continued improvement."
        ]

    def _filter_query_ids(self):
        """Return the query ids eligible for generation, in ``turns_by_qid`` order.

        A query id is skipped when it has no entry in ``user_query_by_qid``,
        when its ``score["below5"]`` is present and above ``below5_thr``, or
        when it has fewer than ``min_samples_thr`` logged turns.
        """
        out = []
        for qid, turns in self.turns_by_qid.items():
            if qid not in self.user_query_by_qid:
                # Without a user query the sample cannot be built at all.
                continue
            uq = self.user_query_by_qid[qid]
            score = uq.get("score") if isinstance(uq, dict) else None
            below5 = score.get("below5") if isinstance(score, dict) else None
            if below5 is not None and below5 > self.below5_thr:
                continue
            if len(turns) < self.min_samples_thr:
                continue
            out.append(qid)
        return out

    def _user_query_text(self, qid):
        """Return the user query text for ``qid``.

        Entries may be either a dict carrying a ``"user_query"`` key or the raw
        value itself, which is returned unchanged.
        """
        uq = self.user_query_by_qid[qid]
        return uq["user_query"] if isinstance(uq, dict) and "user_query" in uq else uq

    def _performance_score(self, cos, rank):
        """Score a logged turn by its best cosine similarity (``None`` -> 0.0).

        ``rank`` is accepted for future use but currently ignored.
        """
        return cos if cos is not None else 0.0

    def _select_hill_climbing_trace(self, qid, k=7):
        """Build a ``k``-turn hill-climbing trace from the logged turns of ``qid``.

        Each trace element is ``(search_query, think, top_k_results,
        best_cosine, best_rank)``. The two best-scoring candidates are held back
        for the final turn. The climb starts at the worst remaining candidate
        and takes ``k - 2`` greedy steps. Each step moves to the best-scoring
        unvisited candidate among the 3 nearest neighbours by embedding cosine
        distance; if all of them were visited, it moves to the closest one. The
        trace ends with a random pick of the two best candidates.

        Returns:
            ``(trace, distance_matrix, remaining)``. When there are fewer than
            ``k`` candidates (or ``k < 3``, too short for start + final), the
            trace is a random sample of up to ``k`` candidates and the other
            two values are ``None``.
        """
        n_neighbors = 3  # neighbourhood size examined at each climbing step
        candidates = [
            (
                t.get("search_query") or "",
                t.get("think"),
                t.get("top_k_results") or [],
                t.get("best_cosine"),
                t.get("best_rank"),
            )
            for t in self.turns_by_qid[qid]
        ]

        def score(c):
            return self._performance_score(c[3], c[4])

        if k < 3 or len(candidates) < k:
            return random.sample(candidates, k=min(max(k, 0), len(candidates))), None, None

        # Hold back the two best queries for the final optimization turn.
        ranked = sorted(candidates, key=score, reverse=True)
        top_2_best, remaining = ranked[:2], ranked[2:]

        embeddings = self.embedding_model.encode([c[0] for c in remaining])
        distance_matrix = cosine_distances(embeddings)

        # Start climbing from the worst-performing remaining query.
        current_idx = min(range(len(remaining)), key=lambda i: score(remaining[i]))
        search_trace = [remaining[current_idx]]
        visited = {current_idx}

        for _ in range(k - 2):
            # Nearest neighbours by embedding distance, excluding the current point.
            neighbors = [
                int(i) for i in np.argsort(distance_matrix[current_idx]) if i != current_idx
            ][:n_neighbors]
            unvisited = [i for i in neighbors if i not in visited]
            if unvisited:
                # Steepest ascent: move to the best-scoring unvisited neighbour.
                next_idx = max(unvisited, key=lambda i: score(remaining[i]))
            else:
                # All neighbours already visited: step to the closest one. The
                # resulting repeat is flagged later by _check_cycle_at_turn.
                next_idx = neighbors[0] if neighbors else current_idx
            search_trace.append(remaining[next_idx])
            visited.add(next_idx)
            current_idx = next_idx

        search_trace.append(random.choice(top_2_best))
        return search_trace, distance_matrix, remaining

    def _check_cycle_at_turn(self, turn_idx, search_trace):
        """Return True if the query at ``turn_idx`` repeats an earlier query's text.

        Turns 0 and 1 are never reported as cycles.
        """
        if turn_idx < 2:
            return False
        current_query = search_trace[turn_idx][0]
        return any(step[0] == current_query for step in search_trace[:turn_idx])

    def _format_prev_docs(self, docs):
        """Format string documents as a numbered, triple-quoted list.

        Non-string entries are dropped before numbering, so the numbering is
        contiguous. ``None`` or empty input yields ``""``.
        """
        texts = [d for d in (docs or []) if isinstance(d, str)]
        return "\n\n".join(f'{i}. """{d}"""' for i, d in enumerate(texts, 1))

    def _generate_single_think(self, q_text, turn_idx, search_query, prev_docs, original_thinking, search_history, search_trace):
        """Ask gpt-4o-mini for the think text that leads to ``search_query``.

        Args:
            q_text: the user query.
            turn_idx: 0-based turn index; ``self.max_turns - 1`` is the final turn.
            search_query: the query this turn must end up producing.
            prev_docs: results of the previous turn, shown from turn 1 onward.
            original_thinking: the logged think for this query, used as a reference.
            search_history: ``(query, think)`` pairs of the turns already generated.
            search_trace: the full trace, used for cycle detection.

        Returns:
            The generated think string. Performs a network call.

        Raises:
            ValueError: if the model returns no parsed output (e.g. a refusal).
        """
        final_turn_idx = self.max_turns - 1
        is_first_turn = turn_idx == 0
        is_final_turn = turn_idx == final_turn_idx
        # The final turn is always framed as the concluding optimization step,
        # even if its query text happens to repeat an earlier one; otherwise
        # the system prompt and the requirements would give conflicting framing.
        cycle_detected = (
            not is_final_turn
            and bool(search_trace)
            and self._check_cycle_at_turn(turn_idx, search_trace)
        )

        system_prompt = """You are generating synthetic training data for a greedy hill climber agent that performs local optimization through neighbor exploration. This agent systematically evaluates nearby alternatives to find incremental improvements.

CRITICAL RULES:
1. This think sequence corresponds to turn {turn_num} and MUST produce the search query for this specific turn
2. Turn 1: Initialize hill climbing with local optimization strategy
3. Turns 2-{last_explore_turn}: Explore neighbors and select best local improvements
4. Turn {final_turn}: Final optimization selecting from best available options
5. NEVER mention numerical scores, performance metrics, or quantitative measurements

HILL CLIMBING PATTERNS:
- Local optimization through systematic neighbor evaluation
- Incremental improvement by selecting best nearby alternatives
- Recognition of search space neighborhoods and adjacency
- Steepest ascent strategy choosing best local improvements
- Cycle awareness when revisiting previous positions"""

        if cycle_detected:
            system_prompt += """

CYCLE RECOGNITION:
- Acknowledge revisiting previous search territory
- Express concern about potential cycling behavior
- Reference need to escape local neighborhoods
- Show awareness of returning to explored areas"""

        if is_final_turn:
            system_prompt += """

FINAL OPTIMIZATION:
- Acknowledge selecting from best available queries for optimal result without claiming that it will be the best but rather in the tone that it may be the best candidate
- Reference systematic exploration leading to final selection
- Express confidence in reaching optimal local solution
- Demonstrate completion of hill climbing process"""
        elif not is_first_turn and not cycle_detected:
            system_prompt += """

NEIGHBOR EXPLORATION:
- Explain selecting a neighbor for continued improvement
- Reference local optimization and incremental progress
- Show awareness of exploring semantic neighborhoods
- Demonstrate steepest ascent decision making"""

        system_prompt += """

SEARCH QUERY EVOLUTION:
- Think sequence must naturally lead to producing exactly the target search query
- Show neighbor selection logic connecting to specific search terms
- Frame decisions in terms of local optimization and improvement
- Demonstrate systematic evaluation of alternatives

Generate think sequences that read like internal monologues of an agent performing systematic local optimization."""

        # Only the placeholders in the base prompt contain braces; the appended
        # sections above have none, so .format() is safe on the whole string.
        system_prompt = system_prompt.format(
            turn_num=turn_idx + 1,
            last_explore_turn=final_turn_idx,
            final_turn=self.max_turns,
        )

        context_prompt = f"""Generate thinking sequence for this turn:

USER QUERY: {q_text}

TURN {turn_idx + 1}:
- Will produce search query: "{search_query}"
"""

        if turn_idx > 0:
            context_prompt += f"- Previous documents: {self._format_prev_docs(prev_docs)}\n"
            if original_thinking:
                # Skip when absent so the literal "None" never reaches the prompt.
                context_prompt += f"- Reference thinking: {original_thinking}\n"

        if search_history:
            context_prompt += "\nPREVIOUS HILL CLIMBING STEPS:\n"
            for turn_no, (prev_query, _) in enumerate(search_history, 1):
                context_prompt += f"Turn {turn_no}: \"{prev_query}\"\n"

        # Occasionally add a hill-climbing phrasing hint (40% of non-first turns).
        if turn_idx > 0 and random.random() < 0.4:
            template = random.choice(self.hill_climbing_templates)
            context_prompt += f"\nHILL CLIMBING PATTERN:\n{template.format(prev_turn=turn_idx)}\n"

        # Turn-specific requirements
        if is_first_turn:
            context_prompt += """
REQUIREMENTS:
- Initialize hill climbing strategy with local optimization focus
- Express commitment to systematic neighbor evaluation
- Show intention to find incremental improvements through local search
- Establish starting position for steepest ascent optimization"""
        elif is_final_turn:
            context_prompt += """
REQUIREMENTS:
- Acknowledge selecting from best available options for final optimization
- Reference systematic neighbor exploration leading to this selection
- Express confidence in reaching optimal solution through hill climbing
- Demonstrate completion of local optimization process"""
        elif cycle_detected:
            context_prompt += """
REQUIREMENTS:
- Acknowledge returning to previously explored search territory
- Express concern about potential cycling in the search space
- Reference need to find alternative neighbors or escape local area
- Show awareness of revisiting previous positions"""
        else:
            context_prompt += """
REQUIREMENTS:
- Explain neighbor evaluation and selection of best local improvement
- Reference incremental progress through systematic optimization
- Show awareness of exploring semantic neighborhoods for better alternatives
- Demonstrate steepest ascent decision making process"""

        context_prompt += f'\nGenerate reasoning that naturally leads to search query: "{search_query}" through local optimization.'

        response = self.client.beta.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": context_prompt}
            ],
            response_format=SingleThink
        )

        message = response.choices[0].message
        if message.parsed is None:
            raise ValueError(f"model returned no parsed think (refusal={message.refusal!r})")
        return message.parsed.think

    def generate_sample(self):
        """Generate one synthetic trace for the next query id, round-robin.

        Returns:
            A dict with ``query_id``, ``user_query``, ``sequence`` (alternating
            think and search_query entries) and ``behavior_type``. Returns
            ``None`` if there are no eligible query ids or the chosen id has no
            candidate turns.
        """
        if not self.query_ids:
            return None
        self.cursor %= len(self.query_ids)
        qid = self.query_ids[self.cursor]
        self.cursor += 1

        q_text = self._user_query_text(qid)
        search_trace, _, _ = self._select_hill_climbing_trace(qid, k=self.max_turns)
        if not search_trace:
            return None

        synthetic_sequence = []
        search_history = []
        for t, (s_t, th_t, R_t, cos_t, rank_t) in enumerate(search_trace[: self.max_turns]):
            # Documents "seen" before this turn are the previous turn's results.
            prev_docs = search_trace[t - 1][2] if t > 0 else []

            think_t = self._generate_single_think(
                q_text=q_text,
                turn_idx=t,
                search_query=s_t,
                prev_docs=prev_docs,
                original_thinking=th_t,
                search_history=search_history,
                search_trace=search_trace
            )

            synthetic_sequence.append({"tag": "think", "messages": think_t})
            synthetic_sequence.append({
                "tag": "search_query",
                "text": s_t,
                "top_k_results": R_t,
                "best_cosine": cos_t,
                "best_rank": rank_t
            })
            search_history.append((s_t, think_t))

        return {
            "query_id": qid,
            "user_query": q_text,
            "sequence": synthetic_sequence,
            "behavior_type": "greedy_hill_climber"
        }
    
if __name__ == '__main__':
    import sys
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument('--start_index', type=int, required=True, help='Starting index for generation')
    parser.add_argument('--end_index', type=int, required=True, help='Ending index for generation')
    args = parser.parse_args()

    obj = GreedyHillClimber()
    # generate_sample() walks the eligible query ids round-robin from this
    # cursor, wrapping modulo their count.
    obj.cursor = args.start_index
    output_file = f'generated_samples_greedy_hill_climber_{args.start_index}_{args.end_index}.jsonl'

    # Append mode: existing output for this shard is extended, not truncated.
    with open(output_file, 'a', encoding='utf-8') as f:
        for i in trange(args.start_index, args.end_index):
            try:
                result = obj.generate_sample()
                if result:
                    f.write(json.dumps(result) + '\n')
                    f.flush()
            except Exception:
                # Best effort: one failed sample (API error, malformed record)
                # must not abort the shard, but the failure must be visible.
                # KeyboardInterrupt/SystemExit are no longer swallowed, so the
                # run can be stopped with Ctrl+C.
                print(
                    f"[greedy_hill_climber] sample {i} failed, skipping:\n{traceback.format_exc()}",
                    file=sys.stderr,
                )