from unsloth import FastLanguageModel, PatchFastRL  # Needs to be first import

PatchFastRL("GRPO", FastLanguageModel)
import concurrent.futures
import shutil
from trl import maybe_apply_chat_template
from utils.discoverybench_utils.dataset import (
    get_datasets_fpaths,
    get_dataset_description,
)
from autogen import GroupChat, GroupChatManager
from utils.discoverybench_utils.agents import get_agents
from utils.discoverybench_utils.transitions import SpeakerSelector
from utils.discoverybench_utils.dataset import load_dataset_metadata
import torch
import json
import argparse
from datetime import datetime
import os
import re
import time

# Torch device that prompt tensors are moved to before generation:
# CUDA when a GPU is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def setup_group_chat(agents, max_rounds):
    """Create the autogen group chat and the manager that drives it.

    Args:
        agents: Mapping whose values are the agents taking part in the chat.
        max_rounds: Forwarded to ``GroupChat`` as ``max_round``.

    Returns:
        A ``(group_chat, chat_manager)`` tuple. The chat starts with no
        messages, the next speaker is chosen by
        ``SpeakerSelector().select_next_speaker``, and the manager is built
        with ``llm_config=None``.
    """
    participants = [agent for agent in agents.values()]
    selector = SpeakerSelector()
    chat = GroupChat(
        agents=participants,
        messages=[],
        max_round=max_rounds,
        speaker_selection_method=selector.select_next_speaker,
    )
    manager = GroupChatManager(groupchat=chat, llm_config=None)
    return chat, manager


def handle_completion(completion, idx, log_dir, work_dir, params):
    """Execute one experiment plan through the multi-agent group chat.

    The user query and the plan (``completion``, attributed to the
    ``experiment_generator`` agent) are injected as the first two messages,
    the group chat is resumed from them, and the experiment generator then
    continues the conversation with the chat manager until it ends.

    Args:
        completion: Natural-language experiment plan to execute.
        idx: Global index of this plan; forwarded to ``get_agents``.
        log_dir: Run log directory; forwarded to ``get_agents``.
        work_dir: Directory the dataset files are copied into; forwarded to
            ``get_agents``.
        params: Run configuration. Reads ``gpt_model``, ``dataset_metadata``
            and ``qid``; the whole dict is also forwarded to ``get_agents``.

    Returns:
        The group-chat transcript decoded from
        ``chat_manager.messages_to_string`` (presumably a list of message
        dicts with ``name``/``content`` keys, which is how ``main`` reads it).
    """
    model_name = params["gpt_model"]
    temperature = 1.0
    # Must be a plain string. The previous ``("medium",)`` carried a stray
    # trailing comma that turned it into a one-element tuple.
    reasoning_effort = "medium"
    experiment_first = False
    code_timeout = 5 * 60  # 5 minutes, in seconds

    dataset_metadata = params["dataset_metadata"]
    # NOTE(review): when called from a thread pool with a shared work_dir,
    # several threads copy the same files to the same destination at once.
    for dataset_fpath in get_datasets_fpaths(dataset_metadata):
        shutil.copy(dataset_fpath, work_dir)

    exp_objective = get_dataset_description(dataset_metadata, params["qid"])
    query = "Plan an experiment to answer the question about the following dataset.\n"
    query += f"{exp_objective}"

    metadata = load_dataset_metadata(dataset_metadata)
    user_query = metadata["queries"][0][params["qid"]]["question"]

    # Effectively unbounded: the conversation is expected to end on its own.
    max_rounds = 100000

    agent_objs = get_agents(
        work_dir,
        model_name=model_name,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
        user_query=user_query,
        experiment_first=experiment_first,
        code_timeout=code_timeout,
        idx=idx,
        log_dir=log_dir,
        params=params,
    )
    groupchat, chat_manager = setup_group_chat(agent_objs, max_rounds)
    experiment_generator = agent_objs["experiment_generator"]

    # Seed the chat as if the generator had already proposed this plan, then
    # let the generator continue from the resulting last message.
    _, last_message = chat_manager.resume(
        messages=[
            {"name": "user_proxy", "role": "user", "content": query},
            {"name": "experiment_generator", "role": "user", "content": completion},
        ]
    )
    experiment_generator.initiate_chat(
        recipient=chat_manager, message=last_message, clear_history=False
    )

    return json.loads(chat_manager.messages_to_string(groupchat.messages))


# (model, tokenizer) pairs keyed by model name. The local model is loaded once
# per process instead of once per iteration; reloading it every call costs
# minutes and piles up GPU memory.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, loading it at most once."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = FastLanguageModel.from_pretrained(
            model_name,
            device_map="auto",
            max_seq_length=8192,
            gpu_memory_utilization=0.8,
        )
    return _MODEL_CACHE[model_name]


def _build_query(exp_objective, batch_size, prev_plans):
    """Assemble the user prompt: task, dataset description, prior results, request.

    ``prev_plans`` items are dicts with ``plan`` and ``score`` and an optional
    ``reflection``. Numeric scores are rounded to 3 decimals. Any other score
    value is printed as-is instead of crashing ``round``.
    """
    query = "Plan an experiment to answer the question about the following dataset.\n"
    query += f"{exp_objective}"
    if prev_plans:
        query += "\n\n##### PRIOR EXPERIMENTS #####\n"
        for i, plan in enumerate(prev_plans):
            score = plan["score"]
            if isinstance(score, (int, float)):
                score = round(score, 3)
            query += f"\n\nPlan #{i + 1}:\n{plan['plan']}\n"
            query += f"Evaluation Score: {score}"
            if plan.get("reflection"):
                query += f"\nFeedback: {plan['reflection']}"
    query += "\n\n"
    query += (
        f"Now create exactly {batch_size} new experiment plans that could answer the scientific question. "
        """(Note: give only a list of experiment plans in the provided JSON format, e.g. {"response": ["experiment_plan_1", "experiment_plan_2", ...]})"""
    )
    return query


def _extract_plans(text):
    """Return the non-empty ``response`` list from the first JSON object in ``text``.

    Every ``{`` is tried as the start of a JSON value, so pretty-printed
    (multi-line) JSON and JSON surrounded by extra prose are both handled.
    Non-string plan entries are serialized to JSON strings because each plan
    is later used as chat message content. Returns None if nothing usable is
    found.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            plans = obj.get("response")
            if isinstance(plans, list) and plans:
                return [p if isinstance(p, str) else json.dumps(p) for p in plans]
        start = text.find("{", start + 1)
    return None


def gen_experiment_plans(
    dataset_metadata, model_name, work_dir, params, prev_plans=None
):
    """Ask the local model for a batch of new experiment plans.

    Args:
        dataset_metadata: Path to the DiscoveryBench metadata JSON.
        model_name: Name/path of the local model to load (cached per process).
        work_dir: Directory the dataset files are copied into.
        params: Run configuration. Reads ``qid`` and ``batch_size``.
        prev_plans: Optional results of the previous batch (dicts with
            ``plan``, ``score`` and optionally ``reflection``), shown to the
            model as prior experiments.

    Returns:
        A non-empty list of plan strings parsed from the model's
        ``{"response": [...]}`` output, or None if none of 3 sampled
        generations could be parsed.
    """
    model, tokenizer = _load_model(model_name)

    for dataset_fpath in get_datasets_fpaths(dataset_metadata):
        shutil.copy(dataset_fpath, work_dir)

    exp_objective = get_dataset_description(dataset_metadata, params["qid"])
    query = _build_query(exp_objective, params["batch_size"], prev_plans)

    system_message = (
        "You are a research scientist who is interested in data-driven research using the provided dataset(s) and query. "
        # 'Be creative and think of an interesting new hypothesis and an experiment to verify it. '
        "Be creative and think of an interesting new experiment to help answer the provided scientific query. "
        # 'The hypothesis should be a falsifiable statement that can be sufficiently tested by the proposed experiment. '
        # 'Along with the hypothesis, explain in natural language the experiment plan that the programmer should follow (do not provide the code yourself). '
        "Explain in natural language the experiment plan that the programmer should follow (do not provide the code yourself). "
        # 'Remember, you are interested in open-ended research, so do not hesitate to propose hypotheses that lack a direct connection to the previously explored hypotheses. '
        "Here are a few instructions that you must follow:\n"
        "1. Strictly use only the dataset(s) provided and do not simulate dummy/synthetic data or columns that cannot be derived from the existing columns.\n"
        # '2. Each hypothesis (and experiment plan) should be creative, independent, and self-contained.\n'
        "2. The experiment plan should be creative, independent, and self-contained.\n"
        "3. Use the prior experiments (if any) as inspiration to think of an interesting and creative new experiment. However, do not repeat the same experiments.\n\n"
        # 'Here is a possible approach to coming up with a new hypothesis and experiment plan:\n'
        "Here is a possible approach to coming up with a new experiment plan:\n"
        "1. Find an interesting context: this could be a specific subset of the data. E.g., if the dataset has multiple categorical variables, you could split the data based on specific values of such variables, which would then allow you to validate a hypothesis in the specific contexts defined by the values of those variables.\n"
        "2. Find interesting variables: these could be the columns in the dataset that you find interesting or relevant to the context. You are allowed and encouraged to create composite variables derived from the existing variables.\n"
        "3. Find interesting relationships: these are interactions between the variables that you find interesting or relevant to the context. You are encouraged to propose experiments involving complex predictive or causal models.\n"
        # '4. You must require that your proposed hypotheses are verifiable using robust statistical tests. Remember, your programmer can install python packages via pip which can allow it to write code for complex statistical analyses.\n'
        "4. You must require that your proposed experiment plan is based on robust statistical tests. Remember, your programmer can install python packages via pip which can allow it to write code for complex statistical analyses.\n"
        # '5. Multiple datasets: If you are provided with more than one dataset, then try to also propose hypotheses that utilize contexts, variables, and relationships across datasets, e.g., this may involve using join or similar operations.\n\n'
        "5. Multiple datasets: If you are provided with more than one dataset, then try to also propose an experiment that utilize contexts, variables, and relationships across datasets, e.g., this may involve using join or similar operations.\n\n"
        "Generally, in typical data-driven research, you will need to explore and visualize the data for possible high-level insights, clean, transform, or derive new variables from the dataset to be suited for the investigation, deep-dive into specific parts of the data for fine-grained analysis, perform data modeling, and run statistical tests.\n\n"
        # f'Now, generate exactly {branching_factor} new hypotheses (and experiment plans).'
        # f'Now, generate exactly 1 new hypothesis (and experiment plan).'
        "Examples of valid experiment plans:\n"
        "Experiment plan #1:\n"
        "1. Merge the datasets offshore, immigration, and native_employment on the common columns 'year' and 'beaind'.\n2. Replace infinite values with NaNs and drop rows with NaNs in any column.\n3. Independent variables: 'iv_offshoring_1', 'penetration'\n4. Fit the OLS regression modela\n\n"
        "Experiment plan #2:\n"
        "1. Chose BMI as dependent variable.\n2. Time preference (independent) variables as 'DISSAVED' and 'SAMESAVE'.\n3. Fit an OLS regression model and returned the model summary.\n\n"
        "Experiment plan #3:\n"
        "1. Selected appropriate variables from the raw data\ne.g.: Height: Height of respondent on 1985 instead of 1981.\n      Income: Total net family income, 1989 (There are many other income variables in the raw data)\n      Age:    Age of respondent at 1989 (Derived from Age at 1979).\nData Transformation:\n2. Replaced -1 to -5 values (unavailable data) with NaN\n3. Imputed the missing values in the AGE and INCOME variable with mean.\n4. AGE_1989 had missing values, hence derived the variable as [AGE_1979 + 10]\n5. Created a BMI variable using: bmi = (weight) * 0.453592 / (height) * 0.0254\n6. Divided the Family income variable by 1000$ (Mentioned in the paper)\n7. Created an AGE^2 variable (From the paper)\n8. One-hot encoded RACE variable into BLACK and HISPANIC\n9. One-hot encooded GENDER variable into MALE and FEMALE\n10. Selected 'Was more money put into or taken out of R/spouse savings since last interview, 1989' as the Time Preference variable.\n  DISSAVED = 1 if 'TOOK MORE MONEY OUT' else 0\n  SAMESAVE = 1 if 'NO SAVINGS' or 'NO CHANGE' else 0\n11. Dropped the unimportant columns for replication\n12. DISSAVED and SAMESAVE as independent variables and BMI as dependent variable\n13. Fit an OLS Regression Model\n\n"
        # f'Now, generate exactly 5 new experiment plan.'
    )
    conversation = [
        {"content": system_message, "role": "system"},
        {"content": query, "role": "user"},
    ]
    prompt_text = maybe_apply_chat_template({"prompt": conversation}, tokenizer)[
        "prompt"
    ]
    prompt_inputs = tokenizer(
        [prompt_text],
        return_tensors="pt",
        padding=True,
        padding_side="left",
        add_special_tokens=False,
    )
    # Moved to the device once; the tensors are identical for every attempt.
    prompt_ids = prompt_inputs["input_ids"].to(DEVICE)
    prompt_mask = prompt_inputs["attention_mask"].to(DEVICE)
    prompt_length = prompt_ids.size(1)

    # Sampling is stochastic, so retry a few times when the output is unusable.
    for _ in range(3):
        with torch.no_grad():
            completion_ids = model.generate(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                do_sample=True,
                temperature=1.0,
                max_new_tokens=1024,
            )
        # Only the newly generated tokens, i.e. everything after the prompt.
        completion = tokenizer.batch_decode(
            completion_ids[:, prompt_length:], skip_special_tokens=True
        )[0]
        plans = _extract_plans(completion)
        if plans is not None:
            return plans
    return None


def _json_field(msg, key):
    """Return ``json.loads(msg["content"])[key]``, or None when that fails.

    Agent messages are not guaranteed to be JSON or to carry ``key``, so a
    parse failure just means "this message has no such field".
    """
    try:
        return json.loads(msg["content"])[key]
    except (json.JSONDecodeError, KeyError, TypeError):
        return None


def _summarize_chat(chat):
    """Extract ``(score, hypothesis, reflection)`` from a group-chat transcript.

    Reads ``experiment_evaluator`` ("Evaluation score"), ``experiment_reviewer``
    ("hypothesis") and ``experiment_reflector`` ("reflection") messages. A
    later parsable message from the same agent overrides an earlier one.
    Defaults are ``0``, ``""`` and ``""``. The score is coerced to float so it
    can be compared and rounded downstream.
    """
    score = 0
    hypo = ""
    reflection = ""
    for msg in chat:
        name = msg.get("name")
        if name == "experiment_evaluator":
            value = _json_field(msg, "Evaluation score")
            if value is not None:
                try:
                    score = float(value)
                except (TypeError, ValueError):
                    pass  # non-numeric score: keep the previous value
        elif name == "experiment_reviewer":
            value = _json_field(msg, "hypothesis")
            if value is not None:
                hypo = value
        elif name == "experiment_reflector":
            value = _json_field(msg, "reflection")
            if value is not None:
                reflection = value
    return score, hypo, reflection


def _execute_plans(experiment_plans, first_id, log_dir, work_dir, params):
    """Run every plan through ``handle_completion`` concurrently.

    Returns a list aligned index-by-index with ``experiment_plans``: the chat
    transcript for each plan, or None where execution raised. The order is
    preserved deliberately; collecting with ``as_completed`` would pair
    plans with the wrong chats.
    """
    chats = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                handle_completion,
                plan,
                first_id + offset,
                log_dir,
                work_dir,
                params,
            )
            for offset, plan in enumerate(experiment_plans)
        ]
        for future in futures:
            try:
                chats.append(future.result())
            except Exception as e:  # one failed plan must not abort the batch
                print(f"Error processing experiment plan: {e}")
                chats.append(None)
    return chats


def main(params):
    """Run the generate / execute / evaluate loop for one DiscoveryBench query.

    Each iteration asks the local model for a batch of experiment plans,
    feeding back the previous batch's scores and reflections. Every plan is
    executed concurrently through ``handle_completion``, and the score,
    hypothesis and reflection are extracted from each chat. Results are
    logged under ``logs/discoverybench/<dataset>/<metadata>_<qid>/<timestamp>/``.
    The loop stops early once any plan scores above 0.8.

    Args:
        params: Parsed CLI options as a dict (see ``parse_arguments``). Reads
            ``dataset_metadata``, ``model_name``, ``qid``, ``batch_size`` and
            ``iterations``, and is saved verbatim to ``args.json``. Note that
            ``params["work_dir"]`` is not used: the working directory is
            always ``<log_dir>/work``.
    """
    dataset_metadata = params["dataset_metadata"]
    model_name = params["model_name"]
    batch_size = params["batch_size"]

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir_prefix = f"logs/discoverybench/{dataset_metadata.split('/')[-2]}/{dataset_metadata.split('/')[-1][:-5]}_{params['qid']}"
    log_dir = os.path.join(log_dir_prefix, timestamp)
    work_dir = os.path.join(log_dir, "work")
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(work_dir, exist_ok=True)

    args_file = os.path.join(log_dir, "args.json")
    with open(args_file, "w") as f:
        json.dump(params, f, indent=2)

    log_file = os.path.join(log_dir, "results.json")
    all_results = {"experiment_plans": []}
    prev_results = None
    # Running plan id. It equals i * batch_size when every batch has exactly
    # batch_size plans, and it stays unique when the model returns more or
    # fewer plans.
    next_id = 0

    start_time = time.time()
    for _ in range(params["iterations"]):
        # Generate experiment plans from local model
        experiment_plans = gen_experiment_plans(
            dataset_metadata, model_name, work_dir, params, prev_plans=prev_results
        )
        if not experiment_plans:
            # Flat placeholder entries, shaped like real results, so that the
            # final best_score computation still works.
            all_results["experiment_plans"].extend(
                {
                    "id": next_id + offset,
                    "plan": "",
                    "gen_hypo": "",
                    "score": 0,
                    "reflection": "",
                }
                for offset in range(batch_size)
            )
            next_id += batch_size
            continue

        # Execute experiment plans with GPT model
        chats = _execute_plans(experiment_plans, next_id, log_dir, work_dir, params)

        new_results = []
        for offset, (plan, chat) in enumerate(zip(experiment_plans, chats)):
            # A plan whose execution failed is kept with the default score 0.
            score, hypo, reflection = _summarize_chat(chat or [])
            new_results.append(
                {
                    "id": next_id + offset,
                    "plan": plan,
                    "gen_hypo": hypo,
                    "score": score,
                    "reflection": reflection,
                }
            )
        all_results["experiment_plans"] += new_results
        if new_results:
            prev_results = new_results

        # Log new results
        with open(log_file, "w") as file:
            json.dump(all_results, file, indent=2)

        # Log individual group chats, named by the id of the plan they executed
        for result, chat in zip(new_results, chats):
            if chat is None:
                continue
            chat_file = os.path.join(log_dir, f"chat_history{result['id']}.json")
            with open(chat_file, "w") as file:
                json.dump(chat, file, indent=2)

        next_id += len(experiment_plans)

        # Early stop
        if new_results and max(r["score"] for r in new_results) > 0.8:
            print("Early Stopping.")
            break

    all_results["duration"] = time.time() - start_time
    all_results["best_score"] = max(
        (x["score"] for x in all_results["experiment_plans"]), default=0
    )
    with open(log_file, "w") as file:
        json.dump(all_results, file, indent=2)


def parse_arguments():
    """Parse the command-line options for a DiscoveryBench run.

    Each option is a single typed value with a default; the options are
    registered in a fixed order, which is also the key order of the resulting
    namespace (and therefore of ``args.json``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    option_specs = (
        (
            "--dataset_metadata",
            str,
            "discoverybench/real/train/evolution_freshwater_fish/metadata_0.json",
        ),
        ("--metadata_type", str, "real"),
        ("--qid", int, 0),
        ("--work_dir", str, "work"),
        ("--model_name", str, "Qwen/Qwen2.5-7B-Instruct"),  # local plan generator
        ("--gpt_model", str, "gpt-5-nano"),  # model used by the executing agents
        ("--batch_size", int, 5),
        ("--iterations", int, 40),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        arg_parser.add_argument(flag, type=value_type, default=fallback)
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Parse the CLI and hand the options to main() as a plain dict.
    main(vars(parse_arguments()))
