from autogen import ConversableAgent, UserProxyAgent, Agent
import os
import json
from autogen.coding import LocalCommandLineCodeExecutor
from typing import Tuple

import copy
from typing import List, Dict
import autogen.agentchat.contrib.capabilities.transforms as transforms
from autogen.agentchat.contrib.capabilities import transform_messages
from utils.discoverybench_utils.beliefs import BeliefTrueFalseCat
from utils.discoverybench_utils.structured_outputs import (
    ExperimentList,
    ExperimentHypothesisList,
    ExperimentCode,
    ExperimentAnalyst,
    ExperimentReviewer,
    ExperimentReflexion,
)

from utils.discoverybench_utils.dataset import load_dataset_metadata
from utils.discoverybench_utils.new_eval import run_eval_gold_vs_gen_NL_hypo_workflow
from utils.discoverybench_utils.new_eval import prepare_dataset_metadata_json
import pandas as pd
from utils.discoverybench_utils.lm_utils import (
    run_chatgpt_query_multi_res,
    run_chatgpt_query_multi_turn,
)
from autogen.coding.func_with_reqs import _build_python_functions_file


class NoInstallLocalCommandLineCodeExecutor(LocalCommandLineCodeExecutor):
    """Local command-line executor whose function setup performs no pip installs."""

    def _setup_functions(self) -> None:
        """Write the functions module into the work dir and mark setup as done.

        The module source is generated from ``self._functions`` and saved as
        ``<self._functions_module>.py`` inside ``self._work_dir``. No package
        requirements are gathered or installed here; the setup flag is simply
        set so the step is treated as complete.
        """
        module_source = _build_python_functions_file(self._functions)
        (self._work_dir / f"{self._functions_module}.py").write_text(module_source)

        # Intentionally no dependency installation: flag setup as finished.
        self._setup_functions_complete = True


# Python source that CodeBlockWrapperTransform prepends to every code block
# before execution. When run, it replaces ``plt.show`` with ``image_to_text``,
# which renders each open matplotlib figure to a PNG, sends it to an OpenAI
# chat model together with ``image_analyst_prompt``, prints the returned
# description to stdout, and closes the figure.
#
# This is a regular (non-raw) string, so ``\\n`` below becomes the two
# characters ``\n`` in the generated source, i.e. a newline escape there.
#
# ``image_to_text`` accepts and ignores any arguments so that generated code
# calling ``plt.show(block=False)`` (or similar) does not raise TypeError.
IMAGE_ANALYSIS_PATCH = """\
import matplotlib.pyplot as plt
import functools
from io import BytesIO
import base64
from openai import OpenAI


client = OpenAI()

image_analyst_prompt = '''Please analyze the given plot image and provide the following:

1. Plot Type: Identify the type of plot (e.g., heatmap, bar plot, scatter plot) and its purpose.
2. Axes:
    * Titles and labels, including units.
    * Value ranges for both axes.
3. Data Trends:
    * For scatter plots: note trends, clusters, or outliers.
    * For bar plots: highlight the tallest and shortest bars and patterns.
    * For heatmaps: identify areas of high and low values.
    etc...
4. Annotations and Legends: Describe key annotations or legends.
5. Statistical Insights: Provide insights based on the information presented in the plot.'''


def image_to_text(*args, **kwargs):
    # Arguments are accepted only for call-compatibility with plt.show
    # (e.g. plt.show(block=False)); they are intentionally ignored.
    for fig_num in plt.get_fignums():
        fig = plt.figure(fig_num)  # Get the current figure
        with BytesIO() as buf:
            # Save the figure to a PNG buffer
            fig.savefig(buf, format='png', dpi=200)
            buf.seek(0)
            # Encode image to base64
            base64_image = base64.b64encode(buf.read()).decode('utf-8')
            messages = [
                {
                    'role': 'system',
                    'content': 'You are a research scientist responsible for analyzing plots and figures from running experiments and providing detailed descriptions.'
                },
                {
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': image_analyst_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ]
            # Get image analysis from the LLM
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                max_tokens=1000,
            )
            analysis = response.choices[0].message.content
            print(f"\\n=== Plot Analysis (fig. {fig_num}) ===\\n")
            print(analysis)
            print("\\n" + "="*50)

        plt.close(fig)


def patch_matplotlib_show():
    # Replace plt.show with our custom function
    plt.show = functools.partial(image_to_text)


# Apply the patch
patch_matplotlib_show()
"""


class CodeBlockWrapperTransform(transforms.MessageTransform):
    """Rewrite the last message into a fenced python block the executor can run.

    The last message's content is expected to be a JSON object with a
    ``"code"`` field. The transform replaces that content with a
    ```` ```python ```` fenced block containing ``IMAGE_ANALYSIS_PATCH``
    followed by the extracted code.
    """

    # Placeholder executed when no code can be extracted from the message.
    _PARSE_FAILURE_CODE = "# Failed to parse code from message"

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Return a deep copy of ``messages`` with the last message wrapped.

        Args:
            messages: Conversation messages; the input list is never mutated.

        Returns:
            A deep copy of ``messages``. If it is non-empty, the last message's
            ``"content"`` is replaced by the fenced code block. When the
            content is missing, not a JSON string, or not a JSON object, the
            placeholder comment is used as the code instead of raising.
        """
        # Deep copy so the caller's message history is left untouched.
        transformed_messages = copy.deepcopy(messages)
        if not transformed_messages:
            # Nothing to wrap; avoid IndexError on an empty conversation.
            return transformed_messages

        message = transformed_messages[-1]

        try:
            payload = json.loads(message.get("content"))
            code = payload.get("code", self._PARSE_FAILURE_CODE)
        except (json.JSONDecodeError, TypeError, AttributeError):
            # JSONDecodeError: content is not valid JSON.
            # TypeError: content is None / a list rather than a string.
            # AttributeError: the JSON value is not an object (no ``.get``).
            code = self._PARSE_FAILURE_CODE

        message["content"] = f"```python\n{IMAGE_ANALYSIS_PATCH}\n\n{code}\n```"

        return transformed_messages

    def get_logs(
        self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]
    ) -> Tuple[str, bool]:
        """Return a fixed log label and ``True`` regardless of the inputs."""
        return "CodeBlockWrapperTransform", True


def get_openai_config(
    api_key: str | None = None,
    temperature: float | None = None,
    reasoning_effort: str | None = None,
    timeout: int = 600,
    model_name="o4-mini",
):
    config = {
        "api_type": "openai",
        "model": model_name,
        "timeout": timeout,
        "api_key": api_key,
        "max_retries": 3,
        "cache_seed": None,  # Disabling caching also addresses this bug: https://github.com/ag2ai/ag2/issues/1103
    }
    if temperature is not None:
        config["temperature"] = temperature

    # Make o-series specific changes
    if model_name.startswith("o"):
        if reasoning_effort is not None:
            config["reasoning_effort"] = reasoning_effort  # Defaults to medium

    return config


def get_agents(
    work_dir,
    model_name="o4-mini",
    temperature=None,
    reasoning_effort=None,
    user_query=None,
    experiment_first=False,
    code_timeout=2 * 60,
    idx=0,
    log_dir="",
    params=None,
) -> dict[str, ConversableAgent]:
    """Create all agents of the experiment pipeline, keyed by agent name.

    Args:
        work_dir: Working directory handed to the code executor.
        model_name: Model name passed to ``get_openai_config``.
        temperature: Optional sampling temperature for the LLM agents.
        reasoning_effort: Optional reasoning effort for the LLM agents.
        user_query: Question embedded in the reviewer's system message.
        experiment_first: Currently unused; kept for interface compatibility.
        code_timeout: Code execution timeout in seconds.
        idx: Index used in the evaluation output file name
            (``eval_<idx>.json``).
        log_dir: Directory the evaluation output file is written to.
        params: Run parameters read by the evaluator when it replies. It
            looks up the keys ``"dataset_metadata"`` (path to the metadata
            JSON), ``"metadata_type"`` and ``"qid"``. Defaults to an empty
            dict.

    Returns:
        A dict mapping each agent's ``name`` to the agent: ``user_proxy``,
        ``experiment_generator``, ``experiment_programmer``,
        ``code_executor``, ``experiment_code_analyst``,
        ``experiment_reviewer``, ``experiment_evaluator`` and
        ``experiment_reflector``.
    """
    # Avoid a shared mutable default argument.
    params = {} if params is None else params

    llm_config = get_openai_config(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name=model_name,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
    )

    # Create token limit transform
    token_limit_capability = transform_messages.TransformMessages(
        transforms=[
            transforms.MessageTokenLimiter(
                max_tokens_per_message=10_000, min_tokens=12_000
            )
        ]
    )

    # Example shown to the programmer agent for installing a package.
    install_snippet = """\nimport subprocess
import sys

def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", package])\n\n\n"""

    # Experiment Programmer
    experiment_programmer = ConversableAgent(
        name="experiment_programmer",
        llm_config={**llm_config, "response_format": ExperimentCode},
        system_message=(
            "You are a scientific experiment programmer proficient in writing python code given an experiment plan. "
            "Your code will be included in a python file that is executed and any relevant results should be printed to standard out or presented using plt.show appropriately. "
            "Make sure you provide python code in the proper format to execute. "
            "Ensure your code is clean and concise, and include debug statements only when they are absolutely necessary. "
            "Use only the dataset given and do not assume any other files are available. The state is not preserved between code blocks, so do not assume any variables or imports from previous code blocks. "
            "Import any libraries you need to use. Always attempt to import a library before installing it (it may already be installed). "
            "If you need to install a library, use the following code example:"
            f"{install_snippet}"
            "When installing python packages, use the --quiet option to minimize unnecessary output."
            "Prefer using installed libraries over installing new libraries whenever possible. "
            "If possible, instead of downgrading library versions, try to adapt your code to work with a more updated version that is already installed. "
            "Never attempt to create a new environment. Always use the current environment. "
            "If the code requires generating plots, use plt.show (not plt.savefig).  "
            "Avoid printing the whole data structure to the console directly if it is large; instead, print concise results that are directly relevant to the experiment. "
            "You are allowed 6 total attempts to run the code, including debugging attempts.\n\n"
            "Debugging instructions:\n"
            "1. Only debug if you are either unsure about the executability or validity of the code (i.e., whether it satisfies the proposed experiment).\n"
            '2. If the code you are writing is intended for debugging, the first line of your code must be "# [debug]" only.\n'
            '3. DO NOT use "[debug]" anywhere else in your code.\n'
            "4. DO NOT combine any debug code and the actual experiment implementation code; keep them separate.\n"
            "5. For each experiment, you are allowed to debug at most 3 times.\n"
            "6. As much as possible, minimize the number of debugging steps you use."
        ),
        human_input_mode="NEVER",
    )

    # Experiment Analyst
    experiment_analyst = ConversableAgent(
        name="experiment_code_analyst",
        llm_config={**llm_config, "response_format": ExperimentAnalyst},
        system_message=(
            "You are a research scientist responsible for evaluating the code execution output for a scientific experiment written by a programmer. "
            "If no code was executed, there was an error, or the code fails silently, return the success status as **false**. "
            'If the code includes a line "# [debug]" i.e "[debug]" as a comment, strictly treat this as a debugging experiment. '
            "In such cases, strictly return the success status as **false**, provide information that it was a debug code execution, "
            "give feedback and request the experiment to be retried with the new information. "
            "Otherwise, analyze the results and provide a short summary of the code output."
        ),
        human_input_mode="NEVER",
    )

    # Experiment Reviewer
    experiment_reviewer = ConversableAgent(
        name="experiment_reviewer",
        llm_config={**llm_config, "response_format": ExperimentReviewer},
        system_message=(
            "You are a research scientist responsible for holistically reviewing the entire experiment pipeline, i.e., the generated code, the output, and the analysis w.r.t. the original experiment plan. "
            "Write down a scientific hypothesis in natural language, derived from the executed experiment, clearly stating the context of hypothesis (if any), *only* relevant variables chosen (if any), and relationship between "
            "those variables (if any) including any statistical significance. Make sure the hypothesis answers the scientific question. Keep the hypothesis concise just like the following examples and only include the most important information. "
            "Also generate a summary of the full workflow starting from data loading that led to the final answer.\n"
            "The following are examples of valid hypotheses:\n"
            '- "From 1995 to 2009, the number of sandhill cranes around the tundra (Indigilka River) surged by an astounding ~10X"\n'
            '- "Per unit increased ease of immigration reduces 0.1059 unit of the share of offshore employment"\n'
            '- "Higher time preference associated with higher BMI for 1989 data. BMI is postively related with if person spent more than their saving with a coefficient 0.3596. BMI is also positively correlated with if the savings of a person remained unchaged with a coefficient 0.4858."\n\n'
            f"Question: {user_query}"
        ),
        human_input_mode="NEVER",
    )

    # Experiment Reflector
    experiment_reflector = ConversableAgent(
        name="experiment_reflector",
        llm_config={**llm_config, "response_format": ExperimentReflexion},
        system_message=(
            "You are a discovery agent. You will be given an attempt at a data driven discovery task where it includes handling datasets and python implementation to answer the query. "
            "You will also be given an evaluation of the attempt with a score between 0 (lowest) and 1 (highest). Your goal is to write 1-3 sentences to explain why the attempt is wrong as indicated by the score. "
            "Keep the reflection direct and concise. "
            "This will be a hint for when the experiment is attempted again later. Only provide the few sentence description in your answer, not the implementation."
        ),
        human_input_mode="NEVER",
    )

    # Timeout Code Executor
    executor = NoInstallLocalCommandLineCodeExecutor(
        timeout=code_timeout,  # Timeout in seconds
        work_dir=work_dir,
    )

    # Create an agent with code executor configuration.
    code_executor = ConversableAgent(
        "code_executor",
        llm_config=False,
        code_execution_config={"executor": executor},
        human_input_mode="NEVER",
    )

    # The executor receives structured JSON from the programmer; this transform
    # turns the last message into a fenced python block before execution.
    transform_messages_capability = transform_messages.TransformMessages(
        transforms=[CodeBlockWrapperTransform()]
    )
    transform_messages_capability.add_to_agent(code_executor)

    user_proxy = UserProxyAgent(
        name="user_proxy",
        description="Responsible for providing the initial query",
        code_execution_config=False,
        human_input_mode="NEVER",
    )

    # The generator is created without an LLM (llm_config=False).
    experiment_generator = ConversableAgent(
        name="experiment_generator", llm_config=False, human_input_mode="NEVER"
    )

    def evaluate_hypo(recipient, messages, sender, config):
        """Reply function scoring the predicted hypothesis against the gold one.

        Reads ``"hypothesis"`` from the JSON content of the last message, asks
        an LLM judge (5 samples) how similar it is to the gold hypothesis,
        writes the result to ``<log_dir>/eval_<idx>.json`` and returns
        ``(True, reply)`` where ``reply`` is a JSON string with the query, the
        previous attempt and the evaluation score.
        """
        callback = config.get("callback")
        if callback is not None:
            callback(sender, recipient, messages[-1])

        metadata_path = params["dataset_metadata"]
        metadata_type = params["metadata_type"]
        with open(metadata_path, "r") as f:
            data_metadata = json.load(f)
        metadata = load_dataset_metadata(metadata_path)
        query = metadata["queries"][0][params["qid"]]["question"]

        # Get gold hypo
        if "test" in metadata_path:
            # Test-split gold hypotheses are looked up in the answer key CSV.
            df = pd.read_csv("discoverybench/real/test/answer_key_real.csv")
            # NOTE(review): assumes a path shaped like
            # discoverybench/real/test/<dataset>/metadata_<N>.json — confirm.
            dataset_value = metadata_path.split("/")[3]
            metadataid = int(
                "".join([c for c in metadata_path.split("/")[-1] if c.isdigit()])
            )
            qid = params["qid"]
            matching_row = df[
                (df["dataset"] == dataset_value)
                & (df["metadataid"] == metadataid)
                & (df["query_id"] == qid)
            ]
            # Raises IndexError when the answer key has no matching row.
            gold_hypo = matching_row["gold_hypo"].values[0]
        else:
            gold_hypo = metadata["hypotheses"]["main"][0]["text"]

        pred_hypo = json.loads(messages[-1]["content"])["hypothesis"]
        eval_output = os.path.join(log_dir, f"eval_{idx}.json")

        # Beliefs scoring. Describe the datasets of the task being evaluated
        # (the metadata loaded above), not a fixed dataset.
        datasets_json = prepare_dataset_metadata_json(
            data_metadata,
            dataset_type=metadata_type,
            use_column_metadata=True,
        )

        system_prompt = "You are an AI assistant that helps evaluate a data-driven hypothesis. You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation."
        prompt_text = f"""\
You are going to compare two natural-language hypotheses HypoA and HypoB. Your task is to determine whether the two hypotheses are semantically the same. \
Both the hypotheses answer a natural language over the dataset(s) described by dataset description(s) and column description(s) below. \
Compare HypoA and HypoB in terms of three aspects: **Contexts**, **Variables**, and **Relations**.
E.g., for the hypothesis "From 1995 to 2009, the number of sandhill cranes around the tundra (Indigilka River) surged by an astounding ~10X":

Definitions:
- **Contexts**: The stratification or scope of the data under which the hypothesis is claimed to be true. E.g., "From 1995 to 2009", "For all women".
- **Variables**: The data variables (either dependent or independent) used in the hypothesis. E.g., number of sandhill cranes, location.
- **Relations**: The relationship or trend expressed between the variables. E.g., "increased by ~10x", "correlated negatively".

Here is the metadata for the task:
```json
{{
"datasets": {datasets_json},
"HypoA": {gold_hypo},
"HypoB": {pred_hypo},
}}
```

Choose **one** of the following similarity ratings:
A) definitely similar
B) somewhat similar
C) uncertain
D) somewhat different
E) definitely different

Note: Two hypotheses are definitely similar when they exhibit the same contexts, variables, and relations within the same scope. \
Use the provided dataset metadata to judge similarity -- accounting for variables that may be paraphrased differently. \
Two hypotheses should be judged as similar if they use slightly different wording but their semantic meanings are the same. \
It is okay to assign a partial score if there are some overlap between the hypotheses. However, reserve "definitely similar" for hypotheses that express the exact same scope. \

Return your answer as a JSON object in the following format:
```json
{{
"answer": one of the options from "A) definitely similar", "B) somewhat similar", "C) uncertain", "D) somewhat different", "E) definitely different"
}}```"""
        judge_messages = [
            {"content": system_prompt, "role": "system"},
            {"content": prompt_text, "role": "user"},
        ]

        judge_response = run_chatgpt_query_multi_res(
            messages=judge_messages,
            model_name="gpt-5-nano",
            max_tokens=2048,
            temperature=1.0,
            json_response=True,
            n=5,
        )
        eval_result = {
            "query": query,
            "HypoA": gold_hypo,
            "HypoB": pred_hypo,
            "score_dist": {},
            "final_score": 0,
        }
        if judge_response is not None:
            answers = []
            for res in judge_response.choices:
                try:
                    # Remove an optional markdown fence. str.strip("```json")
                    # would strip a *character set* and could eat into the
                    # payload, so the exact prefix/suffix is removed instead.
                    output = res.message.content.strip()
                    output = output.removeprefix("```json").removeprefix("```")
                    output = output.removesuffix("```")
                    answers.append(json.loads(output))
                except (AttributeError, TypeError, ValueError) as e:
                    # Missing content or invalid JSON: record a neutral answer.
                    print("RES ERROR", e)
                    answers.append({"answer": "cannot comment"})
            distribution = BeliefTrueFalseCat.parse_response(answers)
            eval_result["final_score"] = distribution.get_mean_belief()
            eval_result["score_dist"] = distribution.to_dict()

        with open(eval_output, "w") as f:
            json.dump(eval_result, f, indent=2)

        reply = {
            "Query": query,
            "Previous attempt": pred_hypo,
            "Evaluation score": eval_result["final_score"],
        }

        return True, json.dumps(reply)

    experiment_evaluator = ConversableAgent(
        name="experiment_evaluator", llm_config=False, human_input_mode="NEVER"
    )
    experiment_evaluator.register_reply(
        [Agent, None], reply_func=evaluate_hypo, config={"callback": None}
    )

    agents = [
        user_proxy,
        experiment_generator,
        experiment_programmer,
        code_executor,
        experiment_analyst,
        experiment_reviewer,
        experiment_evaluator,
        experiment_reflector,
    ]

    # Apply token limit to all agents
    for agent in agents:
        token_limit_capability.add_to_agent(agent)

    return {agent.name: agent for agent in agents}
