import re
# Use TRL's GRPO trainer as the RL base class when it is installed.
try:
    from trl import GRPOTrainer as RLTrainer
except ImportError:
    # Placeholder base class so this module can still be imported without trl.
    # Constructing it (or calling it through super().__init__) fails with an
    # explicit ImportError.
    class RLTrainer:
        def __init__(self, *args, **kwargs):
            raise ImportError("trl is not installed, so RLTrainer cannot be used.")
from datasets import Dataset
import openai
import traceback
import numpy as np
from args import Args as CALM_ARGS

import os
from os import path
import importlib
from heuristic import HeuristicPolicy
from utils import extract_function_from_string, dedent, extract_first_double_braced, extract_idea_description, get_code, idea_distance
import ray

# Absolute path of the directory containing this file; used below as the root
# for the 'configs' and 'calm_saved' directories.
cfp = path.abspath(path.dirname(__file__))

class Prompt:
    """A user prompt together with bookkeeping about how it has been used.

    The operator that produced the prompt is not stored; it is recovered on
    demand by looking for the fixed opening sentence each prompt builder
    emits. Instances hash and compare by their prompt text, and also compare
    equal to a plain string holding the same text.
    """

    def __init__(self, prompt):
        self.prompt = prompt
        # `get_code` may hand back a single item or a list; always keep a list.
        extracted = get_code(prompt)
        self.base_codes = extracted if isinstance(extracted, list) else [extracted]
        self.n_calls = 0
        self.last_used_epoch = 0
        self.trials = {}

        self.feasible_algo_generated = False

    @property
    def op(self):
        """Name of the operator this prompt belongs to.

        Checked in a fixed order: injection, crossover, simplify, create, then
        the three replacement flavours. Anything else is 'initialization'.
        """
        ordered_flags = (
            ("is_injection", "injection"),
            ("is_crossover", "crossover"),
            ("is_simplification", "simplify"),
            ("is_creation", "create"),
        )
        for flag_name, label in ordered_flags:
            if getattr(self, flag_name):
                return label

        lead_in = 'For the following algorithm, identify '
        replacement_targets = (
            ('a fixed, instance-independent decision rule', 'replacement_ins'),
            ('a key hyper-parameter expressed as either a constant literal or a stationary variable', 'replacement_hyp'),
            ('a fragment that assigns equal or near-equal credits to multiple elements', 'replacement_crd'),
        )
        for target, label in replacement_targets:
            if (lead_in + target) in self.prompt:
                return label
        return 'initialization'

    @property
    def is_creation(self):
        """True when the prompt contains the creation opening sentence."""
        marker = 'Be very creative and inventive. Generate an efficient algorithm following the template below'
        return marker in self.prompt

    @property
    def is_injection(self):
        """True when the prompt contains the injection opening sentences."""
        marker = "Inject a novel, meaningful component into the following algorithm. The component may be self-devised or inspired by ideas from other domains or problems."
        return marker in self.prompt

    @property
    def is_crossover(self):
        """True when the prompt contains the crossover opening sentence."""
        marker = "Please generate a new algorithm that is motivated by the following algorithms but performs better on any same instance"
        return marker in self.prompt

    @property
    def is_replacement(self):
        """True when the prompt contains the shared replacement lead-in."""
        marker = "For the following algorithm, identify"
        return marker in self.prompt

    @property
    def is_simplification(self):
        """True when the prompt contains the simplification opening sentence."""
        marker = "Please create a simplified and more elegant version of an algorithm by distilling and refining the core ideas"
        return marker in self.prompt

    def __hash__(self):
        return hash(self.prompt)

    def __eq__(self, other):
        # Equality is defined on the prompt text, whether the other side is a
        # Prompt or a raw string; any other type never matches.
        if isinstance(other, Prompt):
            return self.prompt == other.prompt
        if isinstance(other, str):
            return self.prompt == other
        return False
    
@ray.remote
def evaluate_code(algo, instances):
    """Run one candidate algorithm on `instances` as a Ray remote task.

    Returns whatever `algo.run_one_episode_sync(instances)` returns. If that
    call raises anything at all, the traceback is printed and the string
    'code_bug' is returned instead, so one failing candidate does not fail
    the whole batch.
    """
    try:
        outcome = algo.run_one_episode_sync(instances)
    except BaseException:
        # Catch-all on purpose: the candidate code is machine-generated.
        print(traceback.format_exc())
        return 'code_bug'
    return outcome


class Trainer(RLTrainer):
    def __init__(self,
            calm_args: CALM_ARGS,
            model,
            is_qwen3=True,
            **kwargs
        ):
        """Set up the problem, the seed population and the first prompt batch.

        Args:
            calm_args: Run configuration. Read here: problem_name, template_fp,
                n_prompts, population_size, log_name, api_key, base_url.
            model: Either a string (treated as the model name for an
                OpenAI-compatible API, "online" mode) or a model object that is
                forwarded to the RL trainer base class.
            is_qwen3: When true, '/no_think' is appended to the system prompt.
            **kwargs: Forwarded to the base-class constructor in the
                non-online mode only.

        Raises:
            AssertionError: In online mode, if api_key or base_url is empty.
        """
        self.problem_name = calm_args.problem_name
        self.calm_args = calm_args
        self.is_qwen3 = is_qwen3
        algorithm_template_fp = self.calm_args.template_fp if path.exists(self.calm_args.template_fp) else path.join(cfp, 'configs', self.problem_name, 'template.py')
        # Always define the attribute: prompt_creation asserts on
        # `self.algorithm_template is not None`, which only works if the
        # attribute exists even when no template file is found.
        self.algorithm_template = None
        if path.exists(algorithm_template_fp):
            with open(algorithm_template_fp, 'r') as template_file:
                self.algorithm_template = template_file.read()
        self.problem = importlib.import_module(f'problems.{self.problem_name}')
        self.env = self.problem.Environment()
        self.instances = self.env.training_dataset()
        self.model = model

        self.algos = None
        self.seed_algos = []
        self.age_stuck = 0

        self.n_prompts = calm_args.n_prompts
        self.population_size = calm_args.population_size

        self.save_dir = path.join(cfp, 'calm_saved', self.problem_name, self.calm_args.log_name)
        os.makedirs(self.save_dir, exist_ok=True)
        # Despite the name, this is the path of the log *file*.
        self.log_dir = path.join(self.save_dir, 'output.log')

        self.used_prompts = []
        self.injected_components = []
        self.messages = []
        self.log_step = 0

        self.train_epoch = 0

        # Must run before the base-class constructor: it creates
        # self.train_dataset, which is passed to super().__init__ below.
        self.prepare_dataset()

        self.online = isinstance(self.model, str)
        if self.online:
            assert self.calm_args.api_key != '' and self.calm_args.base_url != '', 'To use API-based CALM, please provide your OpenAI API key and base URL in the config file.'
            self.client = openai.OpenAI(
                api_key=self.calm_args.api_key,
                base_url=self.calm_args.base_url,
            )
            assert self.client is not None
        else:
            # The base class is only initialised for local models; in online
            # mode the trainer is used purely through query_all/reward_func.
            super().__init__(
                model=self.model,
                train_dataset=self.train_dataset,
                reward_funcs=self.reward_func,
                **kwargs
            )
    
    @property
    def best_perf(self):
        if len(self.algos) > 0:
            return np.max([a.perf for a in self.algos])
        return -float('inf')

    def log_info(self, log_msg):
        with open(self.log_dir, 'a+') as fp:
            fp.write(f"<Epoch {self.train_epoch} / Step {self.log_step}> " + str(log_msg) + '\n')
        
    def min_performance_distance(self, algo, population):
        d = np.min([abs(algo.perf - a.perf) for a in population])
        scale = abs(population[0].perf - population[-1].perf)
        return d / scale
    
    def prepare_dataset(self):
        """Build the prompt batch for the next epoch and store it on the trainer.

        Steps, in order:
          1. On first use, load the seed population via
             get_algos_and_performance and remember it in self.seed_algos.
          2. Possibly "collapse" the population (keep only the best algorithm
             plus the seeds) when the stuck counter is high enough.
          3. Sort the population by performance, best first, and take the top
             `population_size` as the parent pool.
          4. Randomly split `n_prompts` between simplification, injection,
             replacement and crossover, then generate de-duplicated prompts.
          5. If nothing was generated, fall back to a single creation prompt.

        Results are written to self.train_dataset and self.messages, and
        self.train_epoch is incremented. Nothing is returned.
        """

        dataset_dict = {"prompt": []}

        if self.algos is None:
            self.algos = self.get_algos_and_performance()
            for init_algo in self.algos:
                self.seed_algos.append(init_algo)

        # A negative configured threshold means "derive it from max_steps".
        max_stuck_threshold = self.calm_args.max_steps // 20 if self.calm_args.max_stuck_threshold < 0 else self.calm_args.max_stuck_threshold
        # Collapse either stochastically (more likely the longer we are stuck)
        # or deterministically once the threshold is reached.
        if np.random.random() < self.calm_args.speed_collapse * self.age_stuck or self.age_stuck >= max_stuck_threshold:
            self.log_info(f"\n Collapse after stucking for {self.age_stuck} rounds \n")
            self.age_stuck = 0
            if len(self.algos) > 0:
                self.used_prompts = []
                best_algo = self.algos[np.argmax([a.perf for a in self.algos])]
                self.log_info(f"During collapse, the best algo with perf {best_algo.perf} has been kept")
                self.algos = [best_algo]
                for seed_algo in self.seed_algos:
                    if seed_algo not in self.algos:
                        self.algos.append(seed_algo)

        sorted_indices = np.argsort([a.perf for a in self.algos])[::-1]
        self.algos = [self.algos[i] for i in sorted_indices]

        # Parent pool; a prefix of self.algos, so indices into it are also
        # valid indices into self.algos.
        algos_head = self.algos[:self.population_size]

        curr_used_prompts = []

        # ---------------------------------------------------------------------------- #
        #                                Prepare dataset                               #
        # ---------------------------------------------------------------------------- #
        ub_simplification = self.calm_args.ub_simplification
        ub_injection = self.calm_args.ub_injection
        ub_replacement = self.calm_args.ub_replacement
        ub_crossover = self.calm_args.ub_crossover

        # The configured upper bounds act as sampling weights; the sampled
        # counts then replace them as the per-operator quotas for this epoch.
        n_ops = np.zeros(4, dtype=int)
        p_ops = np.array([ub_simplification, ub_injection, ub_replacement, ub_crossover]).astype(float)
        if len(self.algos) < 2:
            p_ops[-1] = 0
        if len(self.algos) < self.population_size and ub_injection > 0:
            # Population not full yet: make injection as likely as the most
            # likely operator.
            p_ops[1] = np.max(p_ops)
        total_weight = np.sum(p_ops)
        if total_weight > 0:
            p_ops /= total_weight
            sample_res = np.random.choice(4, size=self.n_prompts, p=p_ops, replace=True)
            for i_op in sample_res:
                n_ops[i_op] += 1
        # else: every operator is disabled; all quotas stay at zero (instead of
        # sampling with NaN probabilities) and the creation fallback applies.
        ub_simplification, ub_injection, ub_replacement, ub_crossover = n_ops
        self.log_info(f'UB of OPs: simplification - {ub_simplification}, injection - {ub_injection}, replacement - {ub_replacement}, crossover - {ub_crossover}')

        # Rank-based selection: the parent at rank r is picked with
        # probability proportional to 1/r.
        rank = 1 + np.arange(len(algos_head))
        p = 1 / rank
        p /= np.sum(p)
        # ------------------ Simplification, Injection, Replacement ----------------- #
        if len(algos_head) > 0:
            for upper_bound, prompt_template in zip([ub_simplification, ub_injection, ub_replacement], [self.prompt_simplification, self.prompt_injection, self.prompt_replacement]):
                n_new_prompts = 0
                n_trial = 0
                while n_trial <= 1000 and n_new_prompts < upper_bound:
                    n_trial += 1
                    indices = [np.random.choice(len(algos_head), p=p)]
                    algos_for_prompt = [algos_head[i] for i in indices]
                    prompt = Prompt(prompt_template(algos_for_prompt))
                    if prompt not in self.used_prompts:
                        self.used_prompts.append(prompt)
                    # Only prompts not yet in this epoch's batch are added.
                    if prompt not in curr_used_prompts:
                        curr_used_prompts.append(prompt)
                    else:
                        continue
                    dataset_dict['prompt'].append([
                        {'role': 'assistant', 'content': self.system_prompt},
                        {'role': 'user', 'content': prompt.prompt}
                    ])
                    n_new_prompts += 1

            # --------------------------------- Crossover -------------------------------- #
            n_trial = 0
            n_new_crossover = 0
            while len(self.algos) >= 2 and n_trial < 1000 and n_new_crossover < ub_crossover:
                n_trial += 1
                algo_0_idx = np.random.choice(len(algos_head), p=p)
                algo_0 = algos_head[algo_0_idx]
                algo_1_idx = None
                log_msg = None
                if np.random.random() <= .5:
                    # Performance-based
                    algo_1_idx = np.random.choice(len(algos_head), p=p)
                    if algo_0 == algos_head[algo_1_idx]:
                        continue
                    log_msg = f"Crossover driven by performance: {algo_0.sid} (Rank {rank[algo_0_idx]}) x {self.algos[algo_1_idx].sid} (Rank {rank[algo_1_idx]})"
                else:
                    # Diversity-based
                    # Distances are negated so that the largest idea distance
                    # gets rank 1 and therefore the highest probability.
                    distances = [- idea_distance(base_idea=algo_0.idea, new_idea=algo_1.idea) for algo_1 in self.algos]
                    distance_rank = np.argsort(np.argsort(distances)) + 1
                    distance_based_p = 1 / distance_rank
                    distance_based_p /= np.sum(distance_based_p)
                    algo_1_idx = np.random.choice(len(self.algos), p=distance_based_p)
                    if distances[algo_1_idx] == 0:
                        continue
                    log_msg = f"Crossover driven by diversity: {algo_0.sid} (Rank {rank[algo_0_idx]}) x {self.algos[algo_1_idx].sid}, Distance: {distances[algo_1_idx]}"
                prompt = Prompt(self.prompt_crossover([algo_0, self.algos[algo_1_idx]]))
                if prompt not in self.used_prompts:
                    self.used_prompts.append(prompt)
                if prompt not in curr_used_prompts:
                    curr_used_prompts.append(prompt)
                else:
                    continue
                dataset_dict['prompt'].append([
                    {'role': 'assistant', 'content': self.system_prompt},
                    {'role': 'user', 'content': prompt.prompt}
                ])
                self.log_info(log_msg)
                n_new_crossover += 1

        # --------------------------------- Creation --------------------------------- #
        # Deliberately outside the `if len(algos_head) > 0` block: an empty
        # population is the main situation in which no prompt can be built
        # from existing algorithms, and the batch must not be left empty.
        if len(dataset_dict['prompt']) == 0:
            self.log_info('No prompts have been added, add creation')
            creation_prompt = Prompt(self.prompt_creation)
            dataset_dict['prompt'].append([
                {'role': 'assistant', 'content': self.system_prompt},
                {'role': 'user', 'content': creation_prompt.prompt}
            ])
            if creation_prompt not in self.used_prompts:
                self.used_prompts.append(creation_prompt)

        self.train_dataset = Dataset.from_dict(dataset_dict)
        self.messages = dataset_dict['prompt']
        self.train_epoch += 1

    def update_lr(self):
        """Learning-rate hook; currently does nothing.

        A cosine schedule is kept below, disabled, for reference.
        """
        # new_lr = .5 * self.calm_args.lr * (1.0 + np.cos(self.log_step / self.calm_args.max_steps * np.pi))
        # for param_group in self.optimizer.param_groups:
            # param_group['lr'] = new_lr
        return None
        
    def query_all(self):
        """
        Sequentially send each prompt in self.train_dataset to the LLM via self.client,
        then run self.reward_func on each reply.

        Returns:
            contents: List[str]         — the raw assistant replies
            rewards_list: List[List]    — the reward_func outputs (one list per prompt)
        """
        all_messages    = [entry['prompt'] for entry in self.train_dataset]
        all_contents    = []

        # 1) Query LLM for each prompt
        for messages in all_messages:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
            )
            all_contents.append(resp.choices[0].message.content)

        # 2) Wrap into the shape reward_func expects
        completions_batch = [[{'content': c}] for c in all_contents]

        # 3) Call reward_func once on the whole batch
        rewards = self.reward_func(all_messages, completions_batch)

        return all_contents, rewards
    
    def reward_func(self, prompts, completions, **kwargs):
        """Score a batch of completions and add the valid algorithms to the population.

        Each completion is parsed for an idea sentence and a code block, turned
        into a HeuristicPolicy, evaluated on self.instances through Ray, and
        rewarded relative to the algorithm(s) quoted in its prompt.

        Args:
            prompts: One conversation per sample; the last message's 'content'
                must be the text of a prompt recorded in self.used_prompts.
            completions: One list per sample whose first item has the reply
                text under 'content'.
            **kwargs: Accepted and ignored.

        Returns:
            A list with one reward per (prompt, completion) pair, in order.
            Malformed replies, crashing code and code mentioning 'random' get
            the corresponding fixed rewards from calm_args. Otherwise the
            reward is 1 + relative gain when the new algorithm beats its best
            base algorithm, 0.0 on a tie, and a scaled fraction of
            reward_random_algorithm when it is worse. Prompts with operator
            'initialization', or whose base algorithms are no longer in the
            population, get 0.0.

        Side effects: increments self.log_step (and self.age_stuck when the
        population is full), appends new algorithms to self.algos, resets
        self.age_stuck and saves the code on a new best, records injected
        component names, and writes to the log.

        Raises:
            ValueError: If a prompt's text is not found in self.used_prompts.
        """
        self.log_step += 1
        if len(self.algos) >= self.population_size:
            self.age_stuck += 1
            self.log_info(f'Stuck counter += 1, arrives at {self.age_stuck}')
        self.log_info(f"=== Step {self.log_step} (Epoch {self.train_epoch}) ===")
        res = []
        curr_algos = []
        curr_prompts = []
        curr_responses = []
        curr_algos_reward_idx = []
        idea_start = 'The idea of the algorithm is to'
        # Every algorithm created in this step shares the same name.
        algo_name = f'creation({self.log_step})'

        # ---- Pass 1: parse replies; malformed ones are rewarded immediately ---- #
        for curr_algo_i, (prompt, completion) in enumerate(zip(prompts, completions)):
            prompt = prompt[-1]['content']
            # Swap the raw text for the recorded Prompt object (Prompt compares
            # equal to its text).
            prompt_idx = self.used_prompts.index(prompt)
            prompt = self.used_prompts[prompt_idx]
            curr_prompts.append(prompt)

            response = completion[0]['content']
            curr_responses.append(response)
            idea = extract_first_double_braced(response)
            if idea is None:
                idea = extract_idea_description(response)
            if idea is None or not idea.startswith(idea_start) or len(idea) - len(idea_start) <= 10:
                res.append(self.calm_args.reward_idea_not_exist)
                continue

            code = get_code(response)
            if code is None:
                # Reward for not enclosing the code in a Python block
                res.append(self.calm_args.reward_code_not_exist)
                continue

            step_func = extract_function_from_string(code)
            if step_func is None:
                # Reward for no function found
                res.append(self.calm_args.reward_function_not_exist)
                continue

            algo = HeuristicPolicy(step_func, name=algo_name, problem_name=self.problem_name)
            algo.code = code
            algo.idea = idea
            algo.name = algo_name
            algo.birth = self.log_step
            algo.response = response
            algo.parent_prompt_type = prompt.op
            # Placeholder; filled in once the evaluation result is known.
            res.append(None)
            curr_algos.append(algo)
            curr_algos_reward_idx.append(curr_algo_i)

        # ---- Pass 2: evaluate all parsed algorithms in parallel, then score ---- #
        curr_algo_running_results = ray.get([evaluate_code.remote(algo, self.instances) for algo in curr_algos])
        for algo, curr_algo_i, running_res in zip(curr_algos, curr_algos_reward_idx, curr_algo_running_results):
            code = algo.code
            if isinstance(running_res, str):
                # Running error
                res[curr_algo_i] = self.calm_args.reward_bug_in_function
                continue
            prompt = curr_prompts[curr_algo_i]
            perfs = np.array(running_res['performance'])
            perf = np.mean(perfs)
            # Any mention of 'random' (this also covers 'np.random') marks the
            # algorithm as non-deterministic.
            if 'random' in code:
                res[curr_algo_i] = self.calm_args.reward_random_algorithm
                continue
            algo.perf = perf
            algo.perfs = perfs.copy()

            # Resolve the code snippets quoted in the prompt back to members
            # of the current population.
            base_algos = []
            for base_code in prompt.base_codes:
                if base_code is None:
                    # The prompt contained no extractable code.
                    continue
                for a in self.algos:
                    if a.code.strip() == base_code.strip():
                        base_algos.append(a)
                        break

            is_new = algo not in self.algos
            # Must be evaluated before the append below changes best_perf.
            is_new_best = perf > self.best_perf
            if is_new:
                self.algos.append(algo)
            if is_new_best:
                self.age_stuck = 0
                self.log_info(f"New best performance: {perf} at step {self.log_step} by {algo.name}")
                self.log_info(f"Idea: {algo.idea}")
                os.makedirs(path.join(self.save_dir, 'algos'), exist_ok=True)
                with open(path.join(self.save_dir, 'algos', f"S{self.log_step}_{algo.sid}.py"), 'w+') as fp:
                    fp.write(code)

            if prompt.op == 'initialization' or not base_algos:
                # No reference to compare against: neutral reward. This also
                # avoids taking the max of an empty list.
                reward = 0.0
            else:
                best_base_perf = np.max([a.perf for a in base_algos])
                gap = abs(perf - best_base_perf)
                denom = min(abs(perf), abs(best_base_perf))
                if denom > 0:
                    delta_perf = np.clip(gap / denom, 1e-10, 1.0)
                else:
                    # Zero reference magnitude: use the clip bounds directly
                    # rather than dividing by zero (0/0 would give NaN).
                    delta_perf = 1.0 if gap > 0 else 1e-10
                if perf > best_base_perf:
                    reward = 1.0 + delta_perf
                elif perf >= best_base_perf:
                    reward = 0.0
                else:
                    # Worse than the base; reproducing a base algorithm
                    # verbatim gets a fixed factor instead of the gap.
                    reward = self.calm_args.reward_random_algorithm / 2 * (delta_perf if algo not in base_algos else (2*.8))

            if prompt.is_injection:
                # Injected component
                match = re.search(r"The new component ([A-Za-z()'\- ]+?) has been introduced", curr_responses[curr_algo_i])
                if match:
                    new_component = match.group(1).strip()
                    if new_component not in self.injected_components and is_new:
                        self.injected_components.append(new_component)
                        self.log_info(f'New component {new_component} has been introduced')

            res[curr_algo_i] = reward

        self.update_lr()
        perfs = map(str, sorted([a.perf for a in self.algos])[::-1])
        self.log_info(f"Number of algos: {len(self.algos)}, Perfs: {','.join(perfs)}")

        return res
    
    @property
    def system_prompt(self):
        return f"""\
            Searching superior heuristics on the {self.problem.name} problem in an evolutionary manner through conversation between User and Assistant. In this problem, {self.problem.description} The User provides existing algorithms and requests a new one.\n\n{self.prompt_algo_requirements()}""" + ('/no_think' if self.is_qwen3 else '')
    
    @property
    def prompt_creation(self):
        assert self.algorithm_template is not None, 'No template was provided while prompting for creation'
        return f"""Be very creative and inventive. Generate an efficient algorithm following the template below:\n\n{self.algorithm_template}"""
    
    def prompt_simplification(self, algos):
        return f"""Please create a simplified and more elegant version of an algorithm by distilling and refining the core ideas from the following:\n\n{self.prompt_algo_details(algos)}"""
    
    def prompt_injection(self, algos):
        prompt = f"""Inject a novel, meaningful component into the following algorithm. The component may be self-devised or inspired by ideas from other domains or problems.\n\n{self.prompt_algo_details(algos)}\n\nUse a concise noun phrase to describe the new component in the responded idea like "The new component ... has been introduced."."""
        if len(self.injected_components) > 0:
            prompt += f""" Exclude the following components that have already been explored: {', '.join(self.injected_components[-10:])}."""
        return prompt
    
    def prompt_replacement(self, algos):
        _MODE_SPECS = [
            ('a fixed, instance-independent decision rule', 'an instance-dependent rule that derives its value from the current observation'),
            ('a key hyper-parameter expressed as either a constant literal or a stationary variable', 'a more principled constant justified by theory or practice'),
            ('a fragment that assigns equal or near-equal credits to multiple elements', 'a fragment where credits are deterministically and reasonably differentiated')
        ]
        p1, p2 = _MODE_SPECS[np.random.choice(len(_MODE_SPECS))]
        prompt = f"""For the following algorithm, identify {p1} and rewrite it to {p2}.\n\n{self.prompt_algo_details(algos)}"""
        return prompt

    def prompt_algo_details(self, algos):
        algo_detail = ""
        sort_indices = np.argsort([a.perf for a in algos])[::-1]
        algos = [algos[i] for i in sort_indices]
        if len(algos) == 0:
            return f"""## The Algorithm\n* Performance: {algos[0].perf_str} {self.problem.unit}\n* Idea: {algos[0].idea}\n* Code: ```python\n{algo.code}```\n\n"""
        
        for i, algo in enumerate(algos):
            algo_detail += f"""## Algorithm {i+1}\n* Performance: {algo.perf_str} {self.problem.unit} (Rank: {i + 1})\n* Idea: {algo.idea}\n* Code:```python\n{algo.code}```\n\n"""
            algo.last_used_epoch = self.train_epoch
        return algo_detail.strip()
    
    def prompt_algo_requirements(self):
        """Return the fixed task-requirements text embedded in the system prompt.

        The literal below is passed through the `dedent` helper imported from
        `utils`. It is a plain (non-f) string, so the double braces in the
        example line are literal characters, matching the "enclosed with a
        double brace" instruction given to the model.
        """
        return dedent("""\
            ## Your Task
            You should first present a concise conceptual description, followed by a complete code implementation.
            
            * The description must:
                * Be enclosed with a double brace and starts with "The idea of the algorithm is to".
                * Ensure it is self-contained, insightful, and creatively original.
                * Not reference or rely on any prior ideas or existing code.
            * The code must:
                * Strictly follow the input-output variable names and types used in the provided implementation.
                * Be a single Python function formatted within Python code blocks.
                * Exclude any usage examples.
                * Ensure the algorithm is deterministic.
                * Avoid introducing unnecessary, arbitrarily-tuned hyperparameters; any parameters used should be essential and systematically derived from the input.
                      
            Overall, your response should be like:
            {{The idea of the algorithm is to (sepcific description here)}}
            ```python
            your code here
            ```
            Except for the idea and code, do not give additional explanations or comments.\
        """)
    
    def prompt_crossover(self, algos):
        return f"""Please generate a new algorithm that is motivated by the following algorithms but performs better on any same instance.\n{self.prompt_algo_details(algos)}
        """
    
    def get_algos_and_performance(self):
        """Load and evaluate the seed algorithm, if a seed file exists.

        The seed path is `calm_args.seed_algo_fp` joined onto this module's
        directory, or `configs/<problem_name>/seed.py` under it when that
        setting is empty. The first line of the file, minus its first
        character, is used as the algorithm's idea if it starts with
        "The idea of the algorithm is to"; otherwise a generic idea naming the
        problem is used.

        Returns:
            A list with the evaluated seed HeuristicPolicy (with perf, perfs
            and idea set), or an empty list when the seed file does not exist.
        """
        res = []

        seed_algo_fp = self.calm_args.seed_algo_fp
        if seed_algo_fp == '':
            seed_algo_fp = path.join(cfp, 'configs', self.problem_name, 'seed.py')
        else:
            seed_algo_fp = path.join(cfp, seed_algo_fp)
        if path.exists(seed_algo_fp):
            algo_name = 'seed'
            algo = HeuristicPolicy(step_func=seed_algo_fp, name=algo_name, problem_name=self.problem_name)
            perfs = algo.run_one_episode_sync(instances=self.instances)['performance']
            # Only the first line is needed; readline() also copes with an
            # empty file (it returns ''), which then falls back below.
            with open(seed_algo_fp, 'r') as seed_file:
                first_line = seed_file.readline()
            # Drop the first character of the line before checking the idea.
            idea = first_line[1:].strip()
            if not idea.startswith('The idea of the algorithm is to'):
                # f-string so the problem name is actually interpolated.
                idea = f'The idea of the algorithm is to solve the {self.problem_name} in some way'
            algo.perf = np.mean(perfs)
            algo.perfs = perfs.copy()
            algo.idea = idea
            res.append(algo)
        return res
    