from openai.types.chat.completion_create_params import ResponseFormat
from utils.base_task import Task
from utils.discoverybench_utils.discoverybench_prompts import DiscoveryBenchPrompts
from utils.discoverybench_utils.discovery_eval import evaluation
from typing import Any
import numpy as np
import json
import re
from sentence_transformers import SentenceTransformer
import random
from rdkit import Chem
from rdkit.Chem import QED
from dockstring import load_target
from func_timeout import func_timeout, FunctionTimedOut
from multiprocessing import Pool
import os
from datetime import datetime
import concurrent.futures
import shutil
from utils.discoverybench_utils.dataset import (
    get_datasets_fpaths,
    get_dataset_description,
)
from utils.discoverybench_utils.dataset import load_dataset_metadata
from utils.discoverybench_utils.agents import get_agents
from autogen import GroupChat, GroupChatManager
from utils.discoverybench_utils.transitions import SpeakerSelector

# Turn off HuggingFace `tokenizers` internal parallelism for this process.
# NOTE(review): presumably set to avoid the tokenizers fork/threading warning
# when sentence_transformers is used alongside worker threads/processes — confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})


def extract_json(response):
    """Pull the outermost ``{...}`` span out of *response* when it mentions "response".

    The pattern is greedy and matches across newlines, so the candidate runs from the
    first ``{`` to the last ``}`` in the text. If there is no such span, or the span
    does not contain the substring ``"response"``, the input is returned unchanged.
    """
    found = re.search(r"{.*}", response, re.DOTALL)
    candidate = found.group(0) if found is not None else None
    if candidate is not None and "response" in candidate:
        return candidate
    return response


def setup_group_chat(agents, max_rounds):
    """Create an autogen group chat over *agents* plus a manager that drives it.

    Args:
        agents: Mapping of agent name -> agent; only the values are handed to the chat.
        max_rounds: Upper bound on conversation rounds for the group chat.

    Returns:
        ``(group_chat, chat_manager)``. Speaker turns are chosen by
        ``SpeakerSelector().select_next_speaker``; the manager has no LLM config.
    """
    participants = [agent for agent in agents.values()]
    next_speaker = SpeakerSelector().select_next_speaker
    group_chat = GroupChat(
        agents=participants,
        messages=[],
        max_round=max_rounds,
        speaker_selection_method=next_speaker,
    )
    manager = GroupChatManager(groupchat=group_chat, llm_config=None)
    return group_chat, manager


def handle_completion(completion, idx, log_dir, work_dir, params):
    """Run a single experiment plan through the multi-agent group chat.

    The dataset files are copied into *work_dir*. The chat is seeded with a
    ``user_proxy`` planning request followed by *completion* as the
    ``experiment_generator``'s first message. It is then resumed and continued
    until it stops.

    Args:
        completion: Experiment-plan text produced by the LLM.
        idx: Experiment id forwarded to ``get_agents`` (presumably used to tag
            per-experiment logs — confirm in ``get_agents``).
        log_dir: Directory forwarded to ``get_agents``.
        work_dir: Working directory the datasets are copied into; also forwarded
            to ``get_agents``.
        params: Dict providing ``"gpt_model"``, ``"dataset_metadata"`` and ``"qid"``;
            the whole dict is also forwarded to ``get_agents``.

    Returns:
        The chat transcript, decoded from ``chat_manager.messages_to_string``.
        Presumably this is a list of message dicts with ``"name"``/``"content"`` keys,
        which is how ``DiscoveryBenchTask.get_bb_score`` reads it.
    """
    model_name = params["gpt_model"]
    temperature = 1.0
    # Must be a plain string: the previous `("medium",)` was a one-element tuple
    # created by a stray trailing comma.
    reasoning_effort = "medium"
    experiment_first = False
    code_timeout = 5 * 60  # 5 minutes, in seconds
    # Effectively unbounded; presumably termination is driven by the agents /
    # speaker selector rather than this cap — confirm.
    max_rounds = 100000

    dataset_metadata = params["dataset_metadata"]
    # NOTE(review): concurrent callers that share work_dir all copy the same
    # files into it; harmless if the copies are identical, but racy.
    for dataset_fpath in get_datasets_fpaths(dataset_metadata):
        shutil.copy(dataset_fpath, work_dir)

    exp_objective = get_dataset_description(dataset_metadata, params["qid"])
    query = (
        "Plan an experiment to answer the question about the following dataset.\n"
        f"{exp_objective}"
    )

    metadata = load_dataset_metadata(dataset_metadata)
    user_query = metadata["queries"][0][params["qid"]]["question"]

    agent_objs = get_agents(
        work_dir,
        model_name=model_name,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
        user_query=user_query,
        experiment_first=experiment_first,
        code_timeout=code_timeout,
        idx=idx,
        log_dir=log_dir,
        params=params,
    )
    groupchat, chat_manager = setup_group_chat(agent_objs, max_rounds)
    experiment_generator = agent_objs["experiment_generator"]

    # Pre-load the conversation so the generator's plan is already "spoken",
    # then continue the chat from its last message.
    _, last_message = chat_manager.resume(
        messages=[
            {"name": "user_proxy", "role": "user", "content": query},
            {"name": "experiment_generator", "role": "user", "content": completion},
        ]
    )
    experiment_generator.initiate_chat(
        recipient=chat_manager, message=last_message, clear_history=False
    )

    return json.loads(chat_manager.messages_to_string(groupchat.messages))


class DiscoveryBenchTask(Task):
    """DiscoveryBench task whose candidates are experiment plans.

    Each plan is scored by running it through the multi-agent chat in
    ``handle_completion``. Construction creates a per-run directory
    ``logs/<task>/<strategy>/<dataset dir>/<metadata stem>_<qid>/<timestamp>``
    with a ``work`` sub-directory that is shared by all agent runs.
    """

    # Experiment ids are ``iteration * _IDS_PER_ITERATION + position``.
    # NOTE(review): ids stay unique only while each scoring call handles at most
    # this many plans.
    _IDS_PER_ITERATION = 5

    # Agent name -> (key in the per-plan summary, key inside that agent's JSON message).
    _CHAT_FIELDS = {
        "experiment_evaluator": ("score", "Evaluation score"),
        "experiment_reviewer": ("gen_hypo", "hypothesis"),
        "experiment_reflector": ("reflection", "reflection"),
    }

    def __init__(self, **kargs):
        super().__init__(**kargs)
        self.prompts = DiscoveryBenchPrompts(self)
        self.metadata_type = kargs.get("discoverybench_metadata_type")
        # Presumably the path to the DiscoveryBench metadata .json file — confirm.
        self.dataset_metadata = kargs.get("target")
        self.qid = kargs.get("discoverybench_qid")

        path_parts = self.dataset_metadata.split("/")
        dataset_dir = path_parts[-2]
        metadata_stem = path_parts[-1][:-5]  # assumes a 5-character ".json" suffix
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        log_dir_prefix = (
            f"logs/{kargs.get('task')}/{kargs.get('strategy')}/"
            f"{dataset_dir}/{metadata_stem}_{self.qid}"
        )
        self.log_dir = os.path.join(log_dir_prefix, timestamp)
        self.work_dir = os.path.join(self.log_dir, "work")
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.work_dir, exist_ok=True)
        self.log_file = os.path.join(self.log_dir, "results.json")
        # Number of completed get_bb_score calls; used to derive experiment ids.
        self.iteration = 0

    def decode_response(self, response: str) -> Any:
        """Decode an LLM completion into a list of stripped experiment plans.

        Expects a JSON object with a ``"response"`` key holding a list of strings.
        A bare string value is treated as a single plan rather than being split
        into characters. Any malformed input yields ``[]``.
        """
        try:
            plans = json.loads(extract_json(response))["response"]
            if isinstance(plans, str):
                plans = [plans]
            return [plan.strip() for plan in plans]
        except Exception:
            # Deliberate best-effort: LLM output is untrusted free text.
            return []

    def encode_representation(self, representation) -> str:
        """
        Encode task's desired format into string
        """
        return ""

    @classmethod
    def _summarize_chat(cls, chat):
        """Return the last parseable hypothesis, score, and reflection found in *chat*.

        Defaults are ``""``, ``0``, and ``""``. A message that is not valid JSON or
        lacks the expected key leaves the previously extracted value in place.
        """
        summary = {"gen_hypo": "", "score": 0, "reflection": ""}
        for msg in chat:
            field = cls._CHAT_FIELDS.get(msg.get("name"))
            if field is None:
                continue
            summary_key, json_key = field
            try:
                summary[summary_key] = json.loads(msg["content"])[json_key]
            except (ValueError, KeyError, TypeError):
                # Malformed / non-JSON agent message: keep the earlier value.
                pass
        return summary

    def get_bb_score(self, solution, attempt):
        """Score each experiment plan by running it through the agent group chat.

        Args:
            solution: Reference solution for the query; not used by this scorer.
            attempt: Iterable of experiment-plan strings.

        Returns:
            One score per plan, aligned index-for-index with *attempt*. A plan whose
            chat raised, or whose evaluator never emitted a parseable
            ``"Evaluation score"``, scores 0.
        """
        experiment_plans = list(attempt)
        params = {
            "gpt_model": "gpt-5-nano",
            "dataset_metadata": self.dataset_metadata,
            "qid": self.qid,
            "metadata_type": self.metadata_type,
        }
        base_id = self.iteration * self._IDS_PER_ITERATION

        # One slot per plan, pre-filled with an empty chat. A failed plan keeps
        # its slot (score 0) instead of shifting later plans onto wrong results.
        chats = [[] for _ in experiment_plans]
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_pos = {
                executor.submit(
                    handle_completion,
                    plan,
                    base_id + pos,
                    self.log_dir,
                    self.work_dir,
                    params,
                ): pos
                for pos, plan in enumerate(experiment_plans)
            }
            # as_completed yields in *finish* order, so map each future back to
            # its submission position rather than appending.
            for future in concurrent.futures.as_completed(future_to_pos):
                pos = future_to_pos[future]
                try:
                    chats[pos] = future.result()
                except Exception as e:
                    print(f"Error processing experiment plan: {e}")

        new_results = [
            {"id": base_id + pos, "plan": plan, **self._summarize_chat(chat)}
            for pos, (plan, chat) in enumerate(zip(experiment_plans, chats))
        ]
        self.iteration += 1
        return [x["score"] for x in new_results]

    def clean_completion(self, completion) -> list:
        """Return the decoded plans of *completion* as a new list (``[]`` if undecodable)."""
        return list(self.decode_response(completion))

    def evaluate_completion(self, completion, inputs) -> list:
        """Decode *completion* and score its plans; ``inputs[0]["solution"][0]`` is passed through."""
        guesses = self.decode_response(completion)
        return self.get_bb_score(inputs[0]["solution"][0], guesses)

    def evaluate_and_log(self, completions, inputs, iteration, logfile):
        """Decode and score every completion, log the results, and return a fixed-size pool.

        Args:
            completions: Raw LLM completions; each may decode to several plans.
            inputs: Task inputs; ``inputs[0]`` must provide ``"problem"`` and ``"solution"``.
            iteration: Optimisation step recorded in the log entry.
            logfile: Accepted for interface compatibility; logs go to ``self.log_file``.

        Returns:
            ``(responses, bb_scores)`` of length
            ``max(self.num_guesses - self.migrate_beta, 0)``. If there are too many
            decoded plans, the pool is truncated. If there are too few, it is padded
            by resampling the decoded plans; when nothing decoded, raw completions
            are resampled with score 0.

        Raises:
            ValueError: if padding is needed but *completions* is empty.
        """
        responses, bb_scores = [], []
        for completion in completions:
            decoded = self.clean_completion(completion)
            scores = self.evaluate_completion(completion, inputs)
            responses += decoded
            bb_scores += scores

        self.log_repsonse(responses, bb_scores, iteration, logfile, inputs)

        # Clamp so a misconfigured migrate_beta > num_guesses can't turn the
        # slice below into "drop items from the end".
        target = max(self.num_guesses - self.migrate_beta, 0)
        num_decoded = len(responses)
        responses, bb_scores = responses[:target], bb_scores[:target]
        if len(responses) < target and not completions:
            raise ValueError("cannot pad the guess pool: no completions were given")
        while len(responses) < target:
            if num_decoded > 0:
                idx = random.randint(0, num_decoded - 1)
                responses.append(responses[idx])
                bb_scores.append(bb_scores[idx])
            else:
                idx = random.randint(0, len(completions) - 1)
                responses.append(completions[idx])
                bb_scores.append(0)

        return responses, bb_scores

    def log_repsonse(self, completions, bb_scores, iteration, logfile, inputs):
        """Append one iteration's ``(completion, score)`` pairs to ``self.log_file``.

        The misspelled name is kept for existing callers. ``logfile`` is unused.
        A missing log file, or one without a ``"guesses"`` list, is initialised on
        first write.
        """
        # NOTE(review): the trailing comma makes this a 1-tuple, so each entry is
        # serialised as a single-element JSON list. Kept as-is because readers of
        # results.json may index entries with [0] — confirm before changing.
        entry = (
            {
                "Iteration": iteration,
                "completions_scores": list(zip(completions, bb_scores)),
                "problem": str(inputs[0]["problem"]),
                "solution": str(inputs[0]["solution"]),
            },
        )
        try:
            with open(self.log_file, "r", encoding="utf-8") as file:
                data = json.load(file)
        except FileNotFoundError:
            data = {}
        data.setdefault("guesses", []).append(entry)
        with open(self.log_file, "w", encoding="utf-8") as file:
            json.dump(data, file, indent=4)

    # def create_batch_completions(self, guesses, scores, past_guesses):
    #     n = 10
    #     pairs = list(zip(guesses, scores))
    #     if len(pairs) > n:
    #         pairs = random.sample(pairs, k=n)
    #     elif len(pairs) < n:
    #         pairs += random.choices(pairs, k=n - len(pairs))
    #     guesses, scores = zip(*pairs)
    #
    #     group_size = 2
    #
    #     guesses, scores = zip(
    #         *sorted(zip(guesses, scores), key=lambda x: x[1], reverse=True)
    #     )
    #     completions, rewards = [], []
    #     for i in range(self.batch_size):
    #         batch = list(guesses[i * group_size : i * group_size + group_size])
    #         random.shuffle(batch)
    #         if all(" " not in word for word in batch):
    #             completions.append(json.dumps({"response": batch}))
    #         else:
    #             if i == 0:
    #                 completions.append(
    #                     json.dumps({"response": [guesses[0]] * group_size})
    #                 )
    #             else:
    #                 completions.append(guesses[i * group_size])
    #         rewards.append(
    #             np.max(scores[i * group_size : i * group_size + group_size]).item()
    #         )
    #     return completions, rewards
