import hydra
import numpy as np
import json
import logging
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import openai
import re
import subprocess
from pathlib import Path
import shutil
import time
import random

from utils.misc import *
from utils.file_utils import find_files_with_substring, load_tensorboard_logs
from utils.create_task import create_task
from utils.extract_task_code import *

# Working directory at import time (i.e. before main() runs); prompt and env
# paths used by main() are resolved relative to it.
REVOLVE_ROOT_DIR = os.getcwd()
# IsaacGymEnvs package directory, expected in a sibling checkout next to Revolve.
ISAAC_ROOT_DIR = "/".join((REVOLVE_ROOT_DIR, "..", "isaacgymenvs", "isaacgymenvs"))
# OpenAI API key taken from the OPENAI_API_KEY environment variable (None when
# unset). Export the variable rather than hard-coding a secret in source.
API_KEY = os.environ.get("OPENAI_API_KEY")

class EvoNode:
    """One candidate reward function in the evolutionary search.

    Args:
        prompt: chat messages that were sent to the LLM to produce this candidate.
        reward_func: source code of the generated reward function.
        history_exec_state: feedback text describing how training with this
            reward went (metric summary or error traceback).
        exec_file_path: path of the environment file that embeds this reward.
    """

    def __init__(self, prompt=None, reward_func=None, history_exec_state="", exec_file_path=""):
        self.reward_func = reward_func
        # Score used to rank the elite set (higher is better); the caller marks
        # failed runs with a large negative sentinel.
        self.reward_cur = 0.0
        self.history_exec_state = history_exec_state
        self.exec_file_path = exec_file_path
        self.prompt = prompt
        # Stride over logged training points used when summarizing the curves
        # for the LLM; unknown until training has finished.
        self.epoch_freq = None

class Revolve:
    """Elite-set manager that builds LLM prompts for evolutionary reward search.

    New candidates are produced from scratch (``initialize``), by mutating one
    elite (``mutation``) or by crossing over two elites (``crossover``).
    """

    def __init__(self, cfg, prompts_dict, island_num=1, max_population_num=16):
        """
        Args:
            cfg: run configuration; stored for reference, not read by this class.
            prompts_dict: prompt templates keyed by 'initial_system',
                'initial_user', 'mutation', 'crossover' and 'reward_code_group'.
            island_num: stored for reference, not read by this class.
            max_population_num: stored for reference, not read by this class.
        """
        self.cfg = cfg
        self.island_num = island_num
        self.max_population_num = max_population_num
        self.elite_set = []  # EvoNode objects; expected to be kept sorted best-first
        self.prompts_dict = prompts_dict
        # Selection weight of rank r is 1 / (r + 1 + bias): a larger bias
        # flattens the distribution toward uniform sampling.
        self.elite_weight_bias = 1
        # Probability of choosing crossover instead of mutation.
        self.crossover_prob = 0.5
        self.elite_max_length = 13

    def select_nodes(self, nums):
        """Sample ``nums`` elite nodes with replacement, favouring low ranks.

        Assumes ``elite_set`` is sorted best-first so index 0 gets the largest
        weight. Returns an empty list when the elite set is empty.
        """
        if not self.elite_set:
            return []
        weights = [1 / (rank + 1 + self.elite_weight_bias) for rank in range(len(self.elite_set))]
        return random.choices(self.elite_set, weights=weights, k=nums)

    def sample_from_elite(self, iter, initial_failed=False):
        """Build the chat messages for the next candidate.

        Args:
            iter: current iteration index; iteration 0 always starts from scratch.
            initial_failed: True while no run has trained successfully yet.

        Returns:
            A list of ``{"role", "content"}`` message dicts. Falls back to the
            initial prompt when the elite set is empty.
        """
        if iter == 0 or initial_failed or not self.elite_set:
            return self.initialize()
        if random.random() >= self.crossover_prob:
            return self.mutation(self.select_nodes(1)[0])
        return self.crossover(self.select_nodes(2))

    def initialize(self):
        """Return the system + user messages that request a fresh reward function."""
        return [
            {"role": "system", "content": self.prompts_dict['initial_system']},
            {"role": "user", "content": self.prompts_dict['initial_user']},
        ]

    def mutation(self, node):
        """Return messages asking the LLM to improve a single elite ``node``."""
        mutation_prompt = self.prompts_dict['mutation'].format(
            reward_function=node.reward_func,
            epoch_freq=node.epoch_freq,
            trained_results=node.history_exec_state,
        )
        return [
            {"role": "system", "content": self.prompts_dict['initial_system']},
            {"role": "user", "content": self.prompts_dict['initial_user'] + mutation_prompt},
        ]

    def crossover(self, nodes):
        """Return messages asking the LLM to combine several elite ``nodes``.

        ``nodes`` must be non-empty; the first node's ``epoch_freq`` is used
        in the prompt.
        """
        reward_func_group = "".join(
            self.prompts_dict['reward_code_group'].format(
                i=index,
                reward_function=node.reward_func,
                trained_results=node.history_exec_state,
            )
            for index, node in enumerate(nodes, start=1)
        )
        crossover_prompt = self.prompts_dict['crossover'].format(
            reward_func_group=reward_func_group, epoch_freq=nodes[0].epoch_freq)
        return [
            {"role": "system", "content": self.prompts_dict['initial_system']},
            {"role": "user", "content": self.prompts_dict['initial_user'] + crossover_prompt},
        ]

    def save_elite(self, filename="elite.json"):
        """Write the current elite set to ``filename`` as a JSON list."""
        elite_info = [
            {
                'reward_func': node.reward_func,
                'reward_cur': node.reward_cur,
                'history_exec_state': node.history_exec_state,
                'exec_file_path': node.exec_file_path,
                'prompt': node.prompt,
                'epoch_freq': node.epoch_freq,
            }
            for node in self.elite_set
        ]

        # Save the elite set information as a JSON file.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(elite_info, f, indent=4)
        logging.info(f"Elite set saved to {filename}")


def _extract_code_string(response_text):
    """Extract reward-function source code from an LLM response.

    Tries progressively looser delimiters (a ```python fence, any ``` fence,
    then triple/double/single quotes) and falls back to the whole response
    when nothing matches or the match is empty. Lines before the first line
    starting with ``def `` (typically imports or decorators) are dropped.
    """
    patterns = [
        r'```python(.*?)```',
        r'```(.*?)```',
        r'"""(.*?)"""',
        r'""(.*?)""',
        r'"(.*?)"',
    ]
    code_string = None
    for pattern in patterns:
        match = re.search(pattern, response_text, re.DOTALL)
        if match is not None:
            code_string = match.group(1).strip()
            break
    if not code_string:
        code_string = response_text

    # Keep everything from the FIRST function definition on, so nested defs
    # and helper functions defined after it are not cut off.
    lines = code_string.split("\n")
    for i, line in enumerate(lines):
        if line.strip().startswith("def "):
            return "\n".join(lines[i:])
    return code_string


def _find_tensorboard_logdir(stdout_str):
    """Return the path on the 'Tensorboard Directory:' line of a training log, or None."""
    for line in stdout_str.split('\n'):
        if line.startswith('Tensorboard Directory:'):
            # partition (not split) so a path that itself contains ':' survives.
            return line.partition(':')[2].strip()
    return None


def _launch_training(task_name, cfg, log_path, max_iterations, extra_args=()):
    """Start ``train_with_seed.py`` for ``task_name``, logging to ``log_path``.

    ``set_freest_gpu()`` is called first and the current environment is passed
    to the child. ``extra_args`` overrides are inserted just before the
    ``max_iterations`` override. Returns the ``subprocess.Popen`` handle; the
    caller must wait on it.
    """
    set_freest_gpu()
    env = os.environ.copy()
    with open(log_path, 'w') as f:
        process = subprocess.Popen(['python', '-u', f'{ISAAC_ROOT_DIR}/train_with_seed.py',
                                    'hydra/output=subprocess',
                                    f'task={task_name}', f'wandb_activate={cfg.use_wandb}',
                                    f'wandb_entity={cfg.wandb_username}', f'wandb_project={cfg.wandb_project}',
                                    f'headless={not cfg.capture_video}', f'capture_video={cfg.capture_video}',
                                    'force_render=False', *extra_args,
                                    f'max_iterations={max_iterations}'],
                                   stdout=f, stderr=f, env=env)
    return process


def _reward_correlation(tensorboard_logs):
    """Pearson correlation of the logged gt_reward and gpt_reward curves, or None if either is missing."""
    if "gt_reward" in tensorboard_logs and "gpt_reward" in tensorboard_logs:
        gt_reward = np.array(tensorboard_logs["gt_reward"])
        gpt_reward = np.array(tensorboard_logs["gpt_reward"])
        return np.corrcoef(gt_reward, gpt_reward)[0, 1]
    return None


@hydra.main(config_path="cfg", config_name="config_revolve", version_base="1.1")
def main(cfg):
    """Run the Revolve evolutionary reward-design loop, then re-evaluate the best reward.

    Each iteration asks the LLM for ``cfg.sample`` reward functions (fresh,
    mutated or crossed over from the elite set), trains a policy with each,
    scores the runs, and updates the elite set. The best reward code found is
    finally retrained ``cfg.num_eval`` times with seeds 0..num_eval-1.
    """
    import sys

    workspace_dir = Path.cwd()
    logging.info(f"Workspace: {workspace_dir}")
    logging.info(f"Project Root: {REVOLVE_ROOT_DIR}")

    openai.api_key = API_KEY

    task = cfg.env.task
    task_description = cfg.env.description
    suffix = cfg.suffix
    model = cfg.model
    logging.info(f"Using LLM: {model}")
    logging.info("Task: " + task)
    logging.info("Task description: " + task_description)

    env_name = cfg.env.env_name.lower()
    env_parent = 'isaac' if f'{env_name}.py' in os.listdir(f'{REVOLVE_ROOT_DIR}/envs/isaac') else 'bidex'
    task_file = f'{REVOLVE_ROOT_DIR}/envs/{env_parent}/{env_name}.py'
    task_obs_file = f'{REVOLVE_ROOT_DIR}/envs/{env_parent}/{env_name}_obs.py'
    shutil.copy(task_obs_file, "env_init_obs.py")
    task_code_string = file_to_string(task_file)
    task_obs_code_string = file_to_string(task_obs_file)
    output_file = f"{ISAAC_ROOT_DIR}/tasks/{env_name}{suffix.lower()}.py"

    # Loading all text prompts
    prompt_dir = f'{REVOLVE_ROOT_DIR}/utils/prompts'
    prompt_dir_revolve = f'{REVOLVE_ROOT_DIR}/utils/prompts_revolve'
    initial_system = file_to_string(f'{prompt_dir}/initial_system.txt')
    code_output_tip = file_to_string(f'{prompt_dir}/code_output_tip.txt')
    code_feedback = file_to_string(f'{prompt_dir}/code_feedback.txt')
    initial_user = file_to_string(f'{prompt_dir}/initial_user.txt')
    reward_signature = file_to_string(f'{prompt_dir}/reward_signature.txt')
    policy_feedback = file_to_string(f'{prompt_dir}/policy_feedback.txt')
    execution_error_feedback = file_to_string(f'{prompt_dir}/execution_error_feedback.txt')

    initial_system = initial_system.format(task_reward_signature_string=reward_signature) + code_output_tip
    initial_user = initial_user.format(task_obs_code_string=task_obs_code_string, task_description=task_description)

    mutation_feedback = file_to_string(f'{prompt_dir_revolve}/mutation_auto.txt')
    crossover_feedback = file_to_string(f'{prompt_dir_revolve}/crossover_auto.txt')
    reward_code_group = file_to_string(f'{prompt_dir_revolve}/reward_code_group.txt')

    prompts_dict = {'initial_system': initial_system, 'initial_user': initial_user,
                    'reward_code_group': reward_code_group, 'mutation': mutation_feedback,
                    'crossover': crossover_feedback}

    task_code_string = task_code_string.replace(task, task + suffix)
    # Create Task YAML files
    create_task(ISAAC_ROOT_DIR, cfg.env.task, cfg.env.env_name, suffix)

    DUMMY_FAILURE = -10000.
    max_successes = []
    max_successes_reward_correlation = []
    execute_rates = []
    best_code_paths = []
    max_success_overall = DUMMY_FAILURE
    max_success_reward_correlation_overall = DUMMY_FAILURE
    max_reward_code_path = None

    revolve = Revolve(cfg, prompts_dict=prompts_dict)

    # Stays True until some run trains without error; until then every sample
    # is generated from the initial prompt instead of from the elite set.
    initial_failed = True

    # Revolve generation loop
    for iter_idx in range(cfg.iteration):
        responses = []
        total_token = 0
        total_completion_token = 0
        prompt_tokens = 0
        chunk_size = 1

        logging.info(f"Iteration {iter_idx}: Generating {cfg.sample} samples with {cfg.model}")

        temp_node_list = []
        for _ in range(cfg.sample):
            messages = revolve.sample_from_elite(iter_idx, initial_failed)
            temp_node_list.append(EvoNode(prompt=messages))

            # Reset per sample so a request that fails every attempt can never
            # silently reuse the previous sample's response.
            response_cur = None
            for attempt in range(1000):
                try:
                    response_cur = openai.ChatCompletion.create(
                        model=model,
                        messages=messages,
                        temperature=cfg.temperature,
                        n=chunk_size
                    )
                    break
                except Exception as e:
                    if attempt >= 10:
                        chunk_size = max(int(chunk_size / 2), 1)
                        print("Current Chunk Size", chunk_size)
                    logging.info(f"Attempt {attempt + 1} failed with error: {e}")
                    time.sleep(1)
            if response_cur is None:
                logging.info("Code terminated due to too many failed attempts!")
                sys.exit(1)

            responses.extend(response_cur["choices"])
            prompt_tokens = response_cur["usage"]["prompt_tokens"]
            total_completion_token += response_cur["usage"]["completion_tokens"]
            total_token += response_cur["usage"]["total_tokens"]

        # Logging Token Information
        logging.info(
            f"Iteration {iter_idx}: Prompt Tokens: {prompt_tokens}, Completion Tokens: {total_completion_token}, Total Tokens: {total_token}")

        # run_ids[k] is the response id of the k-th launched run, so that log
        # files, nodes and responses stay aligned when some responses are skipped.
        run_ids = []
        rl_runs = []
        for response_id in range(cfg.sample):
            response_text = responses[response_id]["message"]["content"]
            logging.info(f"Iteration {iter_idx}: Processing Code Run {response_id}")
            node = temp_node_list[response_id]

            code_string = _extract_code_string(response_text)

            # Add the Revolve Reward Signature to the environment code
            try:
                gpt_reward_signature, _ = get_function_signature(code_string)
            except Exception:
                logging.info(f"Iteration {iter_idx}: Code Run {response_id} cannot parse function signature!")
                # Keep unusable candidates out of the elite set.
                node.reward_cur = DUMMY_FAILURE
                continue

            reward_lines = [
                f"self.rew_buf[:], self.rew_dict = {gpt_reward_signature}",
                "self.extras['gpt_reward'] = self.rew_buf.mean()",
                "for rew_state in self.rew_dict: self.extras[rew_state] = self.rew_dict[rew_state].mean()",
            ]
            indent = " " * 8
            reward_block = "\n".join(indent + line for line in reward_lines)
            if "def compute_reward(self)" in task_code_string:
                task_code_string_iter = task_code_string.replace("def compute_reward(self):",
                                                                 "def compute_reward(self):\n" + reward_block)
            elif "def compute_reward(self, actions)" in task_code_string:
                task_code_string_iter = task_code_string.replace("def compute_reward(self, actions):",
                                                                 "def compute_reward(self, actions):\n" + reward_block)
            else:
                raise NotImplementedError("Task code has no supported compute_reward signature")

            if "@torch.jit.script" not in code_string:
                code_string = "@torch.jit.script\n" + code_string

            # Save the new environment code when the output contains valid code string!
            with open(output_file, 'w') as file:
                file.write(task_code_string_iter + '\n')
                file.write("from typing import Tuple, Dict\n")
                file.write("import math\n")
                file.write("import torch\n")
                file.write("from torch import Tensor\n")
                file.write(code_string + '\n')

            with open(f"env_iter{iter_idx}_response{response_id}_rewardonly.py", 'w') as file:
                file.write(code_string + '\n')

            node.reward_func = code_string
            node.exec_file_path = f"env_iter{iter_idx}_response{response_id}.py"

            # Copy the generated environment code to hydra output directory for bookkeeping
            shutil.copy(output_file, node.exec_file_path)

            rl_filepath = f"env_iter{iter_idx}_response{response_id}.txt"
            process = _launch_training(f'{task}{suffix}', cfg, rl_filepath, cfg.max_iterations)
            block_until_training(rl_filepath, log_status=True, iter_num=iter_idx, response_id=response_id)
            run_ids.append(response_id)
            rl_runs.append(process)

        # Gather RL training results; every list below gets exactly one entry
        # per launched run so they stay index-aligned with run_ids.
        contents = []
        successes = []
        reward_correlations = []
        code_paths = []

        exec_success = False
        for response_id, rl_run in zip(run_ids, rl_runs):
            rl_run.communicate()
            node = temp_node_list[response_id]
            rl_filepath = f"env_iter{iter_idx}_response{response_id}.txt"
            code_paths.append(f"env_iter{iter_idx}_response{response_id}.py")
            try:
                with open(rl_filepath, 'r') as f:
                    stdout_str = f.read()
            except (OSError, ValueError):
                content = execution_error_feedback.format(
                    traceback_msg="Code Run cannot be executed due to function signature error! Please re-write an entirely new reward function!")
                content += code_output_tip
                contents.append(content)
                successes.append(DUMMY_FAILURE)
                reward_correlations.append(DUMMY_FAILURE)
                node.history_exec_state = content
                node.reward_cur = DUMMY_FAILURE
                continue

            content = ''
            traceback_msg = filter_traceback(stdout_str)
            stable_train_flag = filter_cuda_oom(stdout_str)
            run_ok = traceback_msg == '' and stable_train_flag
            tensorboard_logdir = _find_tensorboard_logdir(stdout_str) if run_ok else None
            if run_ok and tensorboard_logdir is None:
                traceback_msg = "Training log did not report a Tensorboard directory."
                run_ok = False

            if run_ok:
                initial_failed = False
                # If RL execution has no error, provide policy statistics feedback
                exec_success = True
                tensorboard_logs = load_tensorboard_logs(tensorboard_logdir)
                # .get / membership checks avoid inserting keys if the logs
                # container is a defaultdict.
                if "gt_reward" in tensorboard_logs:
                    num_points = len(tensorboard_logs["gt_reward"])
                else:
                    num_points = max((len(values) for values in tensorboard_logs.values()), default=0)
                epoch_freq = max(int(num_points // 10), 1)
                node.epoch_freq = epoch_freq

                # Correlation between human-engineered and GPT rewards
                correlation = _reward_correlation(tensorboard_logs)
                reward_correlations.append(DUMMY_FAILURE if correlation is None else correlation)

                has_task_score = "consecutive_successes" in tensorboard_logs
                # Add reward components log to the feedback
                for metric, values in tensorboard_logs.items():
                    if "/" in metric or not values:
                        continue
                    metric_cur = ['{:.2f}'.format(x) for x in values[::epoch_freq]]
                    metric_cur_max = max(values)
                    metric_cur_mean = sum(values) / len(values)
                    metric_cur_min = min(values)
                    if metric == "consecutive_successes":
                        node.reward_cur = metric_cur_max
                    if metric not in ("gt_reward", "gpt_reward"):
                        metric_name = "task_score" if metric == "consecutive_successes" else metric
                        content += f"{metric_name}: {metric_cur}, Max: {metric_cur_max:.2f}, Mean: {metric_cur_mean:.2f}, Min: {metric_cur_min:.2f} \n"
                    elif metric == "gt_reward" and not has_task_score:
                        # Provide ground-truth score when success rate not applicable
                        content += f"ground-truth score: {metric_cur}, Max: {metric_cur_max:.2f}, Mean: {metric_cur_mean:.2f}, Min: {metric_cur_min:.2f} \n"
                        node.reward_cur = metric_cur_max
                successes.append(node.reward_cur)
                node.history_exec_state = content
            else:
                # Otherwise, provide execution traceback error feedback
                successes.append(DUMMY_FAILURE)
                reward_correlations.append(DUMMY_FAILURE)
                content += execution_error_feedback.format(traceback_msg=traceback_msg)
                node.history_exec_state = content
                node.reward_cur = DUMMY_FAILURE

            content += code_output_tip
            contents.append(content)

        # Repeat the iteration if all code generation failed (or nothing ran).
        if not successes or (not exec_success and cfg.sample != 1):
            execute_rates.append(0.)
            max_successes.append(DUMMY_FAILURE)
            max_successes_reward_correlation.append(DUMMY_FAILURE)
            best_code_paths.append(None)
            logging.info("All code generation failed! Repeat this iteration from the current message checkpoint!")
            continue

        # Select the best code sample based on the success rate
        best_idx = int(np.argmax(np.array(successes)))
        best_response_id = run_ids[best_idx]
        best_content = contents[best_idx]

        max_success = successes[best_idx]
        max_success_reward_correlation = reward_correlations[best_idx]
        execute_rate = np.sum(np.array(successes) >= 0.) / cfg.sample

        # Update the best Revolve Output
        if max_success > max_success_overall:
            max_success_overall = max_success
            max_success_reward_correlation_overall = max_success_reward_correlation
            max_reward_code_path = code_paths[best_idx]

        execute_rates.append(execute_rate)
        max_successes.append(max_success)
        max_successes_reward_correlation.append(max_success_reward_correlation)
        best_code_paths.append(code_paths[best_idx])

        logging.info(
            f"Iteration {iter_idx}: Max Success: {max_success}, Execute Rate: {execute_rate}, Max Success Reward Correlation: {max_success_reward_correlation}")
        logging.info(f"Iteration {iter_idx}: Best Generation ID: {best_response_id}")
        logging.info(
            f"Iteration {iter_idx}: GPT Output Content:\n" + responses[best_response_id]["message"]["content"] + "\n")
        logging.info(f"Iteration {iter_idx}: User Content:\n" + best_content + "\n")

        # Plot the success rate
        fig, axs = plt.subplots(2, figsize=(6, 6))
        fig.suptitle(f'{cfg.env.task}')

        x_axis = np.arange(len(max_successes))

        axs[0].plot(x_axis, np.array(max_successes))
        axs[0].set_title("Max Success")
        axs[0].set_xlabel("Iteration")

        axs[1].plot(x_axis, np.array(execute_rates))
        axs[1].set_title("Execute Rate")
        axs[1].set_xlabel("Iteration")

        fig.tight_layout(pad=3.0)
        fig.savefig('summary.png')
        plt.close(fig)  # one figure is created per iteration; release it
        np.savez('summary.npz', max_successes=max_successes, execute_rates=execute_rates,
                 best_code_paths=best_code_paths, max_successes_reward_correlation=max_successes_reward_correlation)

        for evonode in temp_node_list:
            if evonode.reward_cur != DUMMY_FAILURE:
                revolve.elite_set.append(evonode)

        revolve.elite_set.sort(key=lambda x: x.reward_cur, reverse=True)
        if len(revolve.elite_set) > revolve.elite_max_length:
            revolve.elite_set = revolve.elite_set[:revolve.elite_max_length]

        revolve.save_elite()

    for i, node in enumerate(revolve.elite_set):
        logging.info(
            f"Elite reward node {i} file: {node.exec_file_path}, current reward: {node.reward_cur}")

    # Evaluate the best reward code many times
    if max_reward_code_path is None:
        logging.info("All iterations of code generation failed, aborting...")
        logging.info("Please double check the output env_iter*_response*.txt files for repeating errors!")
        sys.exit(1)
    logging.info(
        f"Task: {task}, Max Training Success {max_success_overall}, Correlation {max_success_reward_correlation_overall}, Best Reward Code Path: {max_reward_code_path}")
    logging.info(f"Evaluating best reward code {cfg.num_eval} times")
    shutil.copy(max_reward_code_path, output_file)

    eval_runs = []
    for i in range(cfg.num_eval):
        rl_filepath = f"reward_code_eval{i}.txt"
        process = _launch_training(f'{task}{suffix}', cfg, rl_filepath, cfg.test_max_iterations,
                                   extra_args=(f'seed={i}',))
        block_until_training(rl_filepath)
        eval_runs.append(process)

    reward_code_final_successes = []
    reward_code_correlations_final = []
    for i, rl_run in enumerate(eval_runs):
        rl_run.communicate()
        rl_filepath = f"reward_code_eval{i}.txt"
        with open(rl_filepath, 'r') as f:
            stdout_str = f.read()
        tensorboard_logdir = _find_tensorboard_logdir(stdout_str)
        if tensorboard_logdir is None:
            logging.info(f"Evaluation run {i} did not report a Tensorboard directory; skipping it.")
            continue
        tensorboard_logs = load_tensorboard_logs(tensorboard_logdir)
        # Fall back to the ground-truth reward for tasks without a success metric.
        scores = tensorboard_logs.get("consecutive_successes") or tensorboard_logs.get("gt_reward")
        if scores:
            reward_code_final_successes.append(max(scores))
        else:
            logging.info(f"Evaluation run {i} logged no task score; skipping its success value.")

        correlation = _reward_correlation(tensorboard_logs)
        if correlation is not None:
            reward_code_correlations_final.append(correlation)

    logging.info(
        f"Final Success Mean: {np.mean(reward_code_final_successes)}, Std: {np.std(reward_code_final_successes)}, Raw: {reward_code_final_successes}")
    logging.info(
        f"Final Correlation Mean: {np.mean(reward_code_correlations_final)}, Std: {np.std(reward_code_correlations_final)}, Raw: {reward_code_correlations_final}")
    np.savez('final_eval.npz', reward_code_final_successes=reward_code_final_successes,
             reward_code_correlations_final=reward_code_correlations_final)


if __name__ == '__main__':
    # main() takes no arguments here; the @hydra.main decorator builds cfg.
    main()