import json
import os
import time
import random
from tqdm import tqdm
import sys
import argparse
import re
from typing import List, Dict, Tuple, Optional, Callable
import torch
import numpy as np
import asyncio
import aiohttp
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from openai import AsyncOpenAI, OpenAI
from scripts.difficulty_filter import AnswerExtract, BioProBenchEval, ChemCoTEval
from prompt import system_prompts, crossover_template, prefix_prompt, breakpoint_prompt, mutation_prompt_1, mutation_prompt_2, mutation_prompt_3
import warnings

class Thought:
    """One chain-of-thought candidate together with its evaluation state.

    The raw model output is split around the ``<think>`` / ``</think>`` tags:
    the reasoning between the tags is stored in ``text`` and whatever follows
    the last closing tag is stored in ``pred``. When the tags are absent both
    fields end up holding the whole (stripped) input.
    """

    def __init__(self, text: str, model_name: Optional[str] = None):
        # Reasoning body: the part before the first "</think>", minus anything
        # up to and including the last "<think>" inside that part.
        self.text = text.split("</think>")[0].split("<think>")[-1].strip()
        # Final prediction: everything after the last "</think>".
        self.pred = text.split("</think>")[-1].strip()
        self.model_name = model_name  # Source model ID
        self.fitness_score = None  # Set later through save_fitness / the GA
        self.thoughts = []  # [prefix, suffix] once the thought has been split
        # Initialised here so the attribute can be read before
        # save_breakpoints() has been called.
        self.breakpoints = None
        # ==== Fields used by the novelty / local-competition scoring ====
        self.behavior_vector = None
        self.local_score = None
        self.novelty_score = None

    def save_thoughts(self, thought_list: List[str]):
        """Store the split pieces of the thought."""
        self.thoughts = thought_list

    def save_breakpoints(self, breakpoint: str):
        """Store the breakpoint text associated with this thought."""
        self.breakpoints = breakpoint

    def save_fitness(self, fitness_score):
        """Store the fitness value computed for this thought."""
        self.fitness_score = fitness_score

class GeneticAlgorithm:
    """
    A genetic algorithm optimizer for a single problem.
    """
    def __init__(self, 
        task_name: str=None, 
        metadata: dict=None,
        deepseek_api_base: str=None,
        deepseek_api_key: str=None,
        qwen_api_base: str=None,
        qwen_api_key: str=None,
        openai_api_base: str=None,
        openai_api_key: str=None,
        deepseek_tokenizer_path: str=None,
        qwen_tokenizer_path: str=None,
    ):
        """
        client: OpenAI client instance (async)
        model_name: Local vLLM model path
        tokenizer_path: Tokenizer path
        """
        # ===== metadata =====
        self.task_name = task_name
        # TODO
        if self.task_name not in ['ChemCoTDataset', 'BioProBench']:
            raise ValueError(f"Unknown task: {self.task_name}")
        
        self.metadata = None if not metadata else {
            k: v for k, v in metadata.items() if k not in [
                'raw_cot', 'output'
            ]
        }
        self.query = self.metadata['query']
        self.answer = self.metadata['answer'] if 'answer' in self.metadata else self.metadata['struct_cot']

        # example: self.client[Thought.model_name]
        self.client = {
            'deepseek-r1': AsyncOpenAI(
                api_key=deepseek_api_key,
                base_url=deepseek_api_base
            ),
            'qwen3-32b': AsyncOpenAI(
                api_key=qwen_api_key,
                base_url=qwen_api_base
            ),
            'qwen3-235b-a22b-thinking-2507': AsyncOpenAI(
                api_key=qwen_api_key,
                base_url=qwen_api_base
            ),
            'gpt-5': AsyncOpenAI(
                api_key=openai_api_key,
                base_url=openai_api_base
            ),
        }

        self.tokenizer_path = {
            'deepseek-r1': deepseek_tokenizer_path,
            'qwen3-32b': qwen_tokenizer_path,
            'qwen3-235b-a22b-thinking-2507': qwen_tokenizer_path,
        }
        # ===== model and prompt for crossover step =====
        self.model_names = {
            'prefix': 'gpt-5',
            'breakpoint': 'gpt-5',
            'crossover': qwen_tokenizer_path,
        }
        self.crossover_tokenizer = AutoTokenizer.from_pretrained(self.model_names['crossover'])
        n_device = torch.cuda.device_count()
        self.prefix_prompt = prefix_prompt
        self.breakpoint_prompt = breakpoint_prompt

        # ===== model and prompt for mutation step =====
        self.mutation_prompt = (mutation_prompt_1, mutation_prompt_2, mutation_prompt_3)

        # ===== population =====
        self.population: List[Thought] = []

        # ===== NSD algorithm ====
        self.embedder = embedder_model

    async def fetch_chat_completion(self, chat_client, messages, model_name, temperature=1.0):
        for i in range(3):
            try:
                response = await chat_client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=1.0,
                )
                return response.choices[0].message.content
            except Exception as e:
                print(e)
                continue
        return None

    def extract_result(self, text, name):
        m = re.search(r"\[RESULT_START\](.*?)\[RESULT_END\]", text, re.S)
        if not m:
            raise ValueError(f"No result found in {name}")
        return m.group(1).strip()
    
    def _split_with_marker(self, thought: Thought, marker: str) -> Tuple[str, str]:
        """
        Perform a single split on the full text using the prefix_split (equivalent to the marker).
        
        Args:
            thought (Thought): The Thought object containing the full text to be split.
            marker (str): The marker string used for splitting.

        Returns:
            Tuple[str, str]: A tuple containing the prefix and suffix after splitting.

        Raises:
            AssertionError: If the thought's fitness score has not been calculated.
            ValueError: If the marker is not found in the thought's text and the fitness score is below 1.
        """
        idx = thought.text.find(marker)
        assert thought.fitness_score is not None, "thought fitness_score should be calculated"
        if idx == -1 and thought.fitness_score[0] >= 1:
            return thought.text, ""
        elif idx == -1:
            raise ValueError(f"No marker found in thought")
        else:
            prefix = thought.text[:idx].strip()
            suffix = thought.text[idx:].strip()
            return prefix, suffix
    
    async def crossover(self, thought_current: Thought, thought_external: Thought):
        """
        crossover
        """
        # ===== 构造输入消息 =====
        breakpoint_message = [
            {"role": "user", "content": self.breakpoint_prompt.format(
                query=self.query, thought_current=thought_current.text,
                thought_external=thought_external.text, answer=self.answer)}
        ]

        # ===== 并发请求 prefix_split 和 breakpoints =====
        breakpoints = await self.fetch_chat_completion(self.client[self.model_names['breakpoint']], breakpoint_message, self.model_names['breakpoint'])
        breakpoints = self.extract_result(breakpoints, "breakpoints")

        # ===== 构造prefix =====
        assert len(thought_current.thoughts) == 2, "thought_current should have 2 thoughts"
        prefix, suffix = thought_current.thoughts
        ## ===== 拼接crossover prompt =====
        crossover_prompt = crossover_template.format(
            prefix=prefix, breakpoint=breakpoints
        )

        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path[thought_current.model_name])

        message = tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompts['ChemCoTDataset']},
                {"role": "user", "content": self.query},
                {"role": "assistant", "content": crossover_prompt, "prefix": True,"partial": True},
            ], tokenize=False, add_generation_prompt=False, enable_thinking=False
        )
        
        outputs = self.client[thought_current.model_name].completions.create(
            model=thought_current.model_name,
            prompt=message,
            max_tokens=16384,
            temperature=0.6,
            top_p=0.95,
        )
        outputs = outputs.choices[0].text
        
        assert "<think>" not in prefix, "crossover output should not contain <think>"
        child_thought = "<think>\n" + prefix + "\n\n" + outputs
        return child_thought
    
    async def mutation(self, thought: Thought, sets="system1"):
        """
        mutation: Mutation
        sets = {"system1", "system2", "system3"}, which are related to overlength and accuracy:
        system1: The CoT answer has issues.
        system2: The CoT answer is completely correct but too long.
        system3: The CoT answer is completely correct but too short.
        """
        async def fetch_completion(client, messages, model_name):
            response = await client.completions.create(
                model=model_name,
                prompt=messages,
                max_tokens=20480,
                temperature=0.6,
                top_p=0.95
            )
            return response.choices[0].text
        

        if sets == "system1":
            # Think about where the problem lies in the CoT
            mutation_message_2 = [
                {"role": "user", "content": self.mutation_prompt[1].format(query=self.query, thought_current=thought.text, answer=self.answer)}
            ]
            suggestions = self.extract_result(await fetch_chat_completion(self.client[thought.model_name], mutation_message_2, thought.model_name), "mutation-suggestions")

            # mutation
            mutation_message_1 = [
                {"role": "system", "content": system_prompts['ChemCoTDataset']},
                {"role": "user", "content": self.query + "\n\nTips:\n" + suggestions}
            ]
            new_outputs = await fetch_chat_completion(self.client[thought.model_name], mutation_message_1, thought.model_name)

            return Thought(text=new_outputs, model_name=thought.model_name) # After Thought processing, .text does not contain information like </think>
        
        elif sets in ["system2", "system3"]:
            new_outputs = thought.text
        
        if sets in ["system2"]:
            # Simply use mutation_prompt_1 to simplify the CoT
            mutation_message = [
                {"role": "system", "content": system_prompts['ChemCoTDataset']},
                {"role": "user", "content": self.mutation_prompt[0].format(query=self.query, thought_current=new_outputs)}
            ]
        elif sets == "system3":
            # Simply use mutation_prompt_3 to enrich the CoT
            mutation_message = [
                {"role": "system", "content": system_prompts['ChemCoTDataset']},
                {"role": "user", "content": self.mutation_prompt[2].format(query=self.query, thought_current=new_outputs)}
            ]

        outputs = await fetch_chat_completion(self.client[thought.model_name], mutation_message, thought.model_name)

        child_thought = self.extract_result(outputs, "mutation-result")
        return child_thought

    async def initialize_population(self, seed_pairs: List[Tuple[str, str]]):
        """
        Initialize the population with several seed texts; generate a Thought object for each element.
        """
        self.population = [
            Thought(text=t, model_name=model_name) for t, model_name in seed_pairs
        ]
        # ===== Initialize behavior vectors =====
        behavior_vectors = self.compute_behavior_vector(self.population)
        for idx, behavior_vector in enumerate(behavior_vectors):
            self.population[idx].behavior_vector = behavior_vector
            # ===== Initialize fitness =====
            self.population[idx].fitness_score = self._default_fitness_fn(self.population[idx])

        # ===== split thought =====
        for r in range(5):
            prefix_messages = [[
                {"role": "user", "content": self.prefix_prompt.format(query=self.query, thought_current=th.text, answer=self.answer)}
            ] for th in self.population if len(th.thoughts) == 0]
            prefix_ids = [idx for idx, th in enumerate(self.population) if len(th.thoughts) == 0]
            if len(prefix_ids) == 0:
                break
            prefix_tasks = [asyncio.create_task(self.fetch_chat_completion(self.client[self.model_names['prefix']], prefix_message, self.model_names['prefix'], temperature=1.0)) for prefix_message in prefix_messages]
            prefix_splits = await asyncio.gather(*prefix_tasks)
            prefix_splits = [self.extract_result(ps, "prefix_split") for ps in prefix_splits]
            for j, prefix_split in zip(prefix_ids, prefix_splits):
                try:
                    prefix, suffix = self._split_with_marker(self.population[j], prefix_split)
                    assert prefix != self.population[j].text, "split failed"
                    self.population[j].thoughts = [prefix, suffix]
                except:
                    if r == 4:
                        warnings.warn(f"No marker found. Returning the complete thought.text, which may lead to unexpected results.", UserWarning)
                        self.population[j].thoughts = [self.population[j].text, ""]
                    else:
                        continue
    
    async def update_population(self):
        """
        The new data may not contain information and needs to be supplemented.
        """
        # ===== Initialize behavior vectors =====
        wo_bv = [(idx, th) for idx, th in enumerate(self.population) if th.behavior_vector is None]
        if wo_bv:
            behavior_vectors = self.compute_behavior_vector([th for idx, th in wo_bv])
            for idx, behavior_vector in zip([idx for idx, th in wo_bv], behavior_vectors):
                self.population[idx].behavior_vector = behavior_vector
        # ===== Initialize fitness =====
        for i, th in enumerate(self.population):
            self.population[i].fitness_score = self._default_fitness_fn(th)
        # ===== split thought =====
        for r in range(5):
            prefix_messages = [[
                {"role": "user", "content": self.prefix_prompt.format(query=self.query, thought_current=th.text, answer=self.answer)}
            ] for th in self.population if len(th.thoughts) == 0]
            prefix_ids = [idx for idx, th in enumerate(self.population) if len(th.thoughts) == 0]
            if len(prefix_ids) == 0:
                break
            prefix_tasks = [asyncio.create_task(self.fetch_chat_completion(self.client[self.model_names['prefix']], prefix_message, self.model_names['prefix'], temperature=1.0)) for prefix_message in prefix_messages]
            prefix_splits = await asyncio.gather(*prefix_tasks)
            prefix_splits = [self.extract_result(ps, "prefix_split") for ps in prefix_splits]
            for j, prefix_split in zip(prefix_ids, prefix_splits):
                try:
                    prefix, suffix = self._split_with_marker(self.population[j], prefix_split)
                    assert prefix != self.population[j].text, "split failed"
                    self.population[j].thoughts = [prefix, suffix]
                except:
                    if r == 4:
                        warnings.warn(f"No marker found. Returning the complete thought.text, which may lead to unexpected results.", UserWarning)
                        self.population[j].thoughts = [self.population[j].text, ""]
                    else:
                        continue
    
    def _default_fitness_fn(self, thought: Thought) -> float:
        """
        Fitness (placeholder implementation):
        - Returns 0.0 by default; users can pass a custom fitness_fn when calling loop()
        """
        # ===== accuracy =====
        eval_data = self.metadata
        if self.task_name == 'ChemCoTDataset':
            evaluator = ChemCoTEval([eval_data])
        elif self.task_name == 'BioProBench':
            evaluator = BioProBenchEval([eval_data])
        else:
            # TODO
            raise ValueError(f"Unknown task: {self.task_name}")
        
        func = getattr(evaluator, eval_data['task'], None)
        if func:
            accuracy = func(idx=0, preds=thought.pred)
        else:
            raise ValueError(f"Unknown task: {eval_data['task']}")
        
        # ===== token length =====
        tokens = len(self.crossover_tokenizer.encode(
            thought.text.split("</think>")[0].split("<think>")[-1].strip()
        ))
        overlength = int(tokens > 4333) + 2 * int(tokens < 1024) # 4333 and 1024 are statistical values; 0: ok; 1: too long; 2: too short
        return accuracy, overlength
    
    # ============ NSD Algorithm ============
    def compute_behavior_vector(self, thoughts, normalize: bool = True):
        """
        Use the vLLM embedding model self.embedder to compute the CoT embedding.
        Supports a single Thought or List[Thought].
        - When normalize=True, return a vector with mean 0 and variance 1.
        - Write the results back to thought.behavior_vector.
        Returns:
            If the input is a Thought -> np.ndarray
            If the input is a List[Thought] -> List[np.ndarray]
        """
        # Unify into a list
        single_input = False
        if isinstance(thoughts, Thought):
            thoughts = [thoughts]
            single_input = True

        # Check which ones still don't have behavior_vector
        texts_to_embed = []
        idx_to_embed = []
        for i, th in enumerate(thoughts):
            if not hasattr(th, "behavior_vector") or th.behavior_vector is None:
                text = th.text if th.text is not None else ""
                texts_to_embed.append(text)
                idx_to_embed.append(i)

        if texts_to_embed:
            # Call vLLM embedding
            outputs = self.embedder.embed(texts_to_embed, use_tqdm=False)  # Assuming self.embedder is vllm.LLM(task="embed")
            emb_list = [torch.tensor(o.outputs.embedding, dtype=torch.float32).numpy() for o in outputs]

            # Normalize
            if normalize:
                for j in range(len(emb_list)):
                    e = emb_list[j]
                    emb_list[j] = (e - e.mean()) / (e.std() + 1e-9)

            # Backfill to corresponding Thoughts
            for idx, emb in zip(idx_to_embed, emb_list):
                thoughts[idx].behavior_vector = emb

        # Output
        if single_input:
            return thoughts[0].behavior_vector
        else:
            return [th.behavior_vector for th in thoughts]

    
    def _pareto_front(self, population: List[Thought]) -> List[Thought]:
        """
        In a binary space (novelty_score, local_score), calculate the non-dominated set (Pareto front).
        Returns a list (no specific order).
        """
        front = []
        for a in population:
            dominated = False
            for b in population:
                if b is a:
                    continue
                # b dominates a if and only if b is at least as good in both dimensions and at least one dimension is strictly better
                if (b.novelty_score >= a.novelty_score and b.local_score >= a.local_score) and \
                   (b.novelty_score > a.novelty_score or b.local_score > a.local_score):
                    dominated = True
                    break
            if not dominated:
                front.append(a)
        return front
    
    def update_nl_scores(self, k_nn: int = 5):
        """
        Only calculate and update the behavior vector, novelty_score, and local_score for each Thought in self.population.
        """
        n = len(self.population)
        if n == 0:
            return

        # 1) Compute and cache behavior vectors
        vectors = []
        for i, th in enumerate(self.population):
            vec = self.compute_behavior_vector(th)
            if self.population[i].behavior_vector is None:
                self.population[i].behavior_vector = vec
            vectors.append(vec)
        vectors = np.vstack(vectors)

        # 2) Compute distance matrix
        dist_mat = np.linalg.norm(vectors[:, None, :] - vectors[None, :, :], axis=-1)
        np.fill_diagonal(dist_mat, np.inf)

        k = min(k_nn, n - 1) if n > 1 else 1
        novelty_scores = np.zeros(n, dtype=np.float32)
        local_scores = np.zeros(n, dtype=np.float32)

        fitnesses = np.array(
            [
                th.fitness_score if (hasattr(th, "fitness_score") and th.fitness_score is not None) else 0.0
                for th in self.population
            ],
            dtype=np.float32,
        )

        for i in range(n):
            if self.population[i].novelty_score and self.population[i].local_score:
                continue
            
            neigh_idx = np.argsort(dist_mat[i])[:k]
            if k > 0:
                novelty_scores[i] = float(dist_mat[i, neigh_idx].mean())
                diffs = fitnesses[i] - fitnesses[neigh_idx]
                pos = np.maximum(diffs, 0.0)
                local_scores[i] = float(pos.mean())
            else:
                novelty_scores[i] = 0.0
                local_scores[i] = 0.0

            self.population[i].novelty_score = novelty_scores[i]
            self.population[i].local_score = local_scores[i]

    def select_next_generation(self, target_size: Optional[int] = None):
        """
        Based on existing novelty_score/local_score (need to call update_nl_scores first),
        select the next generation using the NSLC strategy, return selected list.
        """
        n = len(self.population)
        if n == 0:
            return []
        if target_size is None:
            target_size = max(2, n // 2)

        # 1) Pareto front
        pareto = self._pareto_front(self.population)
        selected = list(pareto)

        # 2) Add remaining to selected if needed
        if len(selected) < target_size:
            remaining = [p for p in self.population if p not in selected]
            remaining_sorted = sorted(
                remaining,
                key=lambda x: (x.novelty_score + x.local_score),
                reverse=True,
            )
            needed = target_size - len(selected)
            to_add = remaining_sorted[:needed]
            # If there are not enough elements to add, allow resampling from the selected ones
            if len(to_add) < needed and selected:
                still_needed = needed - len(to_add)
                selected_sorted = sorted(
                    selected, key=lambda x: (x.novelty_score + x.local_score), reverse=True
                )
                extra = []
                idx = 0
                for _ in range(still_needed):
                    extra.append(selected_sorted[idx % len(selected_sorted)])
                    idx += 1
                to_add.extend(extra)
            selected.extend(to_add)

        # 3) If Pareto front is too large, trim
        if len(selected) > target_size:
            selected = sorted(selected, key=lambda x: x.local_score, reverse=True)[:target_size]

        return selected
    

    def _default_selection_fn(self, candidate_pool: List[Thought], num_pairs: int = 1, epsilon: float = 1e-6) -> List[Tuple[Thought, Thought]]:
        """
        批量选择 parent pairs。
        Args:
            candidate_pool: Candidate parent pool (with novelty_score, local_score computed)
            num_pairs: Number of parent pairs to return
            epsilon: Small constant to prevent zero probability
        Returns:
            List[(parent1, parent2)] of length num_pairs (may contain duplicates if pool is small)
        Strategy:
            - Prefer local_score; fallback to novelty_score if all local_score are zero
            - Try non-replacement sampling to reduce duplicates; allow replacement or conditional sampling if pool is small
        """
        if not candidate_pool:
            raise ValueError("candidate_pool is empty")
        n = len(candidate_pool)
        if num_pairs <= 0:
            return []

        # Construct sampling weights
        local_vals = np.array([getattr(c, "local_score", 0.0) for c in candidate_pool], dtype=np.float64)
        novel_vals = np.array([getattr(c, "novelty_score", 0.0) for c in candidate_pool], dtype=np.float64)

        # Select main weight
        if np.all(local_vals <= 0.0 + 1e-12):
            base = novel_vals + epsilon
        else:
            base = local_vals + epsilon

        # Normalize probabilities
        base_sum = float(base.sum())
        if base_sum <= 0:
            probs = np.ones(n, dtype=np.float64) / n
        else:
            probs = base / base_sum

        pairs: List[Tuple[Thought, Thought]] = []
        used_pairs: Set[FrozenSet[int]] = set()
        max_attempts = 8  # Maximum number of retry attempts for duplicate pairs

        # Fast non-replacement sampling when pool size is sufficient
        if n >= 2 * num_pairs:
            # First draw 2*num_pairs indices without replacement based on probs
            # numpy supports replace=False with p
            idxs = np.random.choice(n, size=2 * num_pairs, replace=False, p=probs)
            # Shuffle to randomize pair order
            np.random.shuffle(idxs)
            for i in range(num_pairs):
                a, b = int(idxs[2 * i]), int(idxs[2 * i + 1])
                pairs.append((candidate_pool[a], candidate_pool[b]))
            return pairs

        # If pool size is insufficient for non-replacement sampling, use a mixed strategy
        # Prefer non-replacement sampling when possible, then fallback to conditional sampling
        if n >= num_pairs:
            # Draw first parents without replacement (if possible)
            first_idxs = np.random.choice(n, size=num_pairs, replace=False, p=probs)
            for fi in first_idxs:
                # Construct conditional probability distribution for second parent:
                # Prefer not selecting fi itself (if possible), but allow self-selection if only one element remains
                if n == 1:
                    si = fi
                else:
                    probs2 = probs.copy()
                    probs2[fi] = 0.0
                    ssum = probs2.sum()
                    if ssum <= 0:
                        # If all probabilities are zero except fi, sample uniformly from remaining elements
                        others = [i for i in range(n) if i != fi]
                        si = np.random.choice(others)
                    else:
                        probs2 = probs2 / ssum
                        si = int(np.random.choice(n, p=probs2))
                # Duplicate check
                pair_key = frozenset({int(fi), int(si)})
                attempts = 0
                while pair_key in used_pairs and attempts < max_attempts:
                    # Retry sampling second parent (if possible)
                    if n == 1:
                        break
                    probs2 = probs.copy()
                    probs2[int(fi)] = 0.0
                    probs2 = probs2 / (probs2.sum() + 1e-12)
                    si = int(np.random.choice(n, p=probs2))
                    pair_key = frozenset({int(fi), int(si)})
                    attempts += 1
                used_pairs.add(pair_key)
                pairs.append((candidate_pool[int(fi)], candidate_pool[int(si)]))
            return pairs

        # If pool size is still too small, fallback to replacement sampling
        # Last resort: sample pairs with replacement (pool size is small)
        # Each pair is sampled independently, with preference to avoid p1==p2 (if possible)
        # and to minimize duplicate pairs
        for _ in range(num_pairs):
            attempts = 0
            while True:
                idx1 = int(np.random.choice(n, p=probs))
                if n == 1:
                    idx2 = idx1
                else:
                    probs2 = probs.copy()
                    probs2[idx1] = 0.0
                    if probs2.sum() <= 0:
                        # fallback uniform among others
                        others = [i for i in range(n) if i != idx1]
                        if others:
                            idx2 = int(np.random.choice(others))
                        else:
                            idx2 = idx1
                    else:
                        probs2 = probs2 / probs2.sum()
                        idx2 = int(np.random.choice(n, p=probs2))
                pair_key = frozenset({idx1, idx2})
                if pair_key not in used_pairs or attempts >= max_attempts:
                    used_pairs.add(pair_key)
                    pairs.append((candidate_pool[idx1], candidate_pool[idx2]))
                    break
                attempts += 1

        return pairs
    
    async def _make_child(self, p1: Thought, p2: Thought) -> Thought:
        try:
            # Crossover and mutation logic
            if p1.fitness_score[0] < p2.fitness_score[0]:
                p1, p2 = p2, p1

            # Select crossover or mutation based on fitness
            if p1.fitness_score[0] == 1:
                if random.random() < 0.75:
                    if p1.fitness_score[1] == 2:  # too short
                        child_text = await self.mutation(p1, sets="system3")
                    else:  # too long
                        child_text = await self.mutation(p1, sets="system2")
                    child_text = child_text.split("</think>")[0].split("<think>")[-1].strip()
                    used_pred = p1.pred
                    child_text = f"<think>\n{child_text}\n</think>\n\n{used_pred}"
                else:
                    child_text = await self.crossover(p2, p1)
                    assert "</think>" in child_text
            else:
                if random.random() < 0.4:
                    child_text = await self.crossover(p1, p2)
                    assert "</think>" in child_text
                else:
                    child_text = await self.mutation(p1, sets="system1") # Thought processed, thought.text does not contain </think>
                    used_pred = child_text.pred
                    child_text = await self.mutation(child_text, sets="system2")
                    child_text = child_text.split("</think>")[0].split("<think>")[-1].strip()
                    child_text = f"<think>\n{child_text}\n</think>\n\n{used_pred}"
        except Exception as e:
            print(e)
            return None

        assert child_text.count("</think>") == 1 and child_text.count("<think>") == 1, 'child_text error'
        child = Thought(text=child_text, model_name='qwen3-32b')
        return child
    
    async def loop(
        self,
        seed_pairs: List[Tuple[str, str]],
        n_generations: int = 5,
        population_size: int = 6,
        k_nn: int = 1,
    ):
        """
        Main genetic algorithm loop.
        - seed_pairs: Initial population (text, model_name).
        - n_generations: Number of iterations.
        - population_size: Target population size for each generation.
        - k_nn: Number of neighbors when calculating novelty/local scores.
        - fitness_fn: Custom fitness function (overrides the default _default_fitness_fn).
        """
        tqdm_bar = tqdm(range(n_generations), total=n_generations, desc=f"Generation 0/{n_generations}: Initializing")
        # 1) initialize population
        await self.initialize_population(seed_pairs)

        # 2) update vector representations and fitness scores
        self.update_nl_scores(k_nn=k_nn)

        for gen in range(n_generations):
            tqdm_bar.set_description(f"Generation {gen+1}/{n_generations}: selecting parents")
            # 3) select next generation parents
            parents_pool = self.select_next_generation(target_size=population_size // 2)

            tqdm_bar.set_description(f"Generation {gen+1}/{n_generations}: crossover and mutation")
            # 4) crossover and mutation to generate children
            num_children = population_size - len(parents_pool)
            children: List[Thought] = [None] * num_children
            if num_children > 0:
                parent_pairs = self._default_selection_fn(parents_pool, num_pairs=num_children)
                # parent_pairs is List[(p1,p2)], length == num_children (try to)
                # parallel generate children
                for _ in range(3):
                    none_indices = [idx for idx, child in enumerate(children) if child is None]
                    if not none_indices:
                        break
                    async with asyncio.TaskGroup() as tg:
                        retry_tasks = [tg.create_task(self._make_child(*parent_pairs[idx])) for idx in none_indices]
                        retry_results = await asyncio.gather(*retry_tasks)
                        for idx, result in zip(none_indices, retry_results):
                            children[idx] = result
            
            tqdm_bar.set_description(f"Generation {gen+1}/{n_generations}: merging parents and children")
            # 5) merge parents and children to form new population
            self.population = parents_pool + children

            tqdm_bar.set_description(f"Generation {gen+1}/{n_generations}: updating population")
            # 6) update vector representations and fitness scores
            for i in range(5):
                try:
                    await self.update_population()
                    break
                except Exception as e:
                    print(e)
            self.update_nl_scores(k_nn=k_nn)

            tqdm_bar.set_description(f"Generation {gen+1}/{n_generations}: printing statistics")
            # 7) print statistics
            fitness_vals = [th.fitness_score[0] + 0.25 * (2 - th.fitness_score[1]) for th in self.population]
            avg_fit = float(np.mean(fitness_vals))
            max_fit = float(np.max(fitness_vals))
            print(f"Avg fitness: {avg_fit:.4f}, Max fitness: {max_fit:.4f}")

            # 8) convergence condition: if max fitness is 1.0 (or other threshold) exit early
            if max_fit > 1.0:
                print("Optimal solution reached, exiting early.")
                break

        # 9) return the best individual
        fitness_vals = [th.fitness_score[0] + 0.25 * (2 - th.fitness_score[1]) for th in self.population]
        max_thoughts = [th for th in self.population if th.fitness_score[0] + 0.25 * (2 - th.fitness_score[1]) == max(fitness_vals)]
        sorted_thoughts = sorted(max_thoughts, key=lambda x: x.fitness_score[1], reverse=True)
        
        return sorted_thoughts[0]

async def run_multiple_ga(all_data, all_seed_pairs, cfg, n_generations=5, population_size=6, k_nn=1):
    """
    Run one independent GeneticAlgorithm per problem, all concurrently.

    all_data: per-problem records; each one is handed to GeneticAlgorithm
        as `metadata` (the original note describes this as List[ChemCoTDataset]).
    all_seed_pairs: List[List[Tuple[str,str]]]
        For each problem, the initial population of seed pairs (thought, model_name).
    cfg: Dict of keyword arguments forwarded unchanged to GeneticAlgorithm(...).
    n_generations, population_size, k_nn: forwarded to every `ga.loop(...)` call.

    Returns the list of `ga.loop(...)` results, in the same order as `all_data`.
    """
    loop_kwargs = dict(n_generations=n_generations, population_size=population_size, k_nn=k_nn)

    # A fresh GA instance is built for every problem and scheduled right away,
    # so the searches overlap instead of running one after another.
    pending = [
        asyncio.create_task(GeneticAlgorithm(**cfg, metadata=problem).loop(seeds, **loop_kwargs))
        for problem, seeds in zip(all_data, all_seed_pairs)
    ]

    # Wait for every GA to finish and collect the results.
    return await asyncio.gather(*pending)

# ================= CLI =================
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the GA runner.

    Required options: --task_name, --data_path, --save_path and
    --embedder_model_path. The API base/key options default to the
    upper-cased environment variable of the same name (for example
    --deepseek_api_key falls back to DEEPSEEK_API_KEY), read at the moment
    the parser is built.
    """
    parser = argparse.ArgumentParser(description="Run GA with externalized config")

    # Task selection and data locations.
    parser.add_argument("--task_name", type=str, required=True, choices=["ChemCoTDataset", "BioProBench"])
    for flag, help_text in (("--data_path", "input data path"), ("--save_path", "output jsonl path")):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_batches", type=int, default=0, help="only run first N batches, 0 for no limit")

    # API endpoints and keys, one base/key pair per provider.
    for provider in ("deepseek", "qwen", "openai"):
        for suffix in ("api_base", "api_key"):
            option = f"{provider}_{suffix}"
            parser.add_argument(f"--{option}", type=str, default=os.getenv(option.upper()))

    # Tokenizer paths.
    for provider in ("deepseek", "qwen"):
        parser.add_argument(f"--{provider}_tokenizer_path", type=str, default=None)

    # Embedding (vectorization) model.
    parser.add_argument("--embedder_model_path", type=str, required=True)

    return parser


if __name__ == '__main__':
    # Script entry point: run the GA over every problem in --data_path and
    # stream one JSON line per finished problem to --save_path. An existing
    # --save_path is treated as a checkpoint, and the run resumes after its
    # last saved line.
    import inspect

    parser = build_parser()
    args = parser.parse_args()

    # GeneticAlgorithm.__init__ declares only a subset of the CLI options
    # (it has no data_path / save_path / batch_size / ... parameters and no
    # **kwargs), so forwarding vars(args) wholesale raises TypeError. Keep only
    # the names the constructor actually accepts; `metadata` is supplied per
    # problem by run_multiple_ga.
    ga_params = inspect.signature(GeneticAlgorithm.__init__).parameters
    cfg = {k: v for k, v in vars(args).items() if k in ga_params and k != "metadata"}

    # qwen3-embedding-8b
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    embedder_model = LLM(model=args.embedder_model_path, task="embed", tensor_parallel_size=1, gpu_memory_utilization=0.9)

    # Load the input problems.
    with open(args.data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Resume support: make sure the output directory exists. A bare file name
    # has an empty dirname, which os.makedirs rejects, so only create it when
    # there is a directory component.
    save_path = args.save_path
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Read back whatever a previous run already produced (blank lines ignored).
    saved_data = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            saved_data = [json.loads(line) for line in f if line.strip()]
    begin_idx = len(saved_data)

    # The checkpoint must be a prefix of the input, matched by 'id'.
    # Explicit exceptions are used because `assert` is stripped under `python -O`.
    if begin_idx > len(data):
        raise ValueError(
            f"{save_path} holds {begin_idx} records but the input only has {len(data)}"
        )
    for i, d in enumerate(saved_data):
        if data[i]['id'] != d['id']:
            raise ValueError(
                f"Checkpoint mismatch at line {i + 1}: saved id {d['id']!r} != input id {data[i]['id']!r}"
            )
        data[i]['output'] = d['output']

    # Honour the CLI batching options (defaults: batch_size=1, no batch limit).
    batch_size = max(1, args.batch_size)
    batch_range = range(begin_idx, len(data), batch_size)
    if args.max_batches > 0:
        batch_range = batch_range[:args.max_batches]

    with open(save_path, "w", encoding="utf-8") as f:
        # Rewrite the already-finished records first so the file stays a
        # complete, normalized prefix of the results.
        for d in saved_data:
            f.write(json.dumps(d, ensure_ascii=False) + "\n")
        f.flush()

        tqdm_bar = tqdm(batch_range, total=len(batch_range))
        for idx in tqdm_bar:
            batch_items = data[idx:idx+batch_size]
            seed_pairs_batch = [d['seed_pairs'] for d in batch_items]
            # Everything except the seeds is handed to the GA as problem metadata.
            datas = [{k: v for k, v in d.items() if k != 'seed_pairs'} for d in batch_items]
            tqdm_bar.set_description(f"{idx+1}/{len(data)}")
            outputs = asyncio.run(run_multiple_ga(datas, seed_pairs_batch, cfg))

            for item, output in zip(batch_items, outputs):
                if output is None:
                    raise RuntimeError("Error: output is None")
                item["output"] = f"<think>\n{output.text}\n</think>\n\n{output.pred}"
                item["fitness_score"] = output.fitness_score
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
            # Push each finished batch to disk so a crash loses at most the
            # batch currently in flight.
            f.flush()