import os
import sys
import argparse
import time
import datetime
import json
import textgrad as tg

# Stop textgrad's logger from passing its records up to ancestor loggers
# (assumes tg.logger is a standard logging.Logger — TODO confirm).
tg.logger.propagate = False
# Add the parent of this script's directory to sys.path; presumably this is
# what lets the `reward_design` package imports below resolve when the file
# is run directly — verify against the repository layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import reward_design.preference_model
import reward_design.utils
import reward_design.reward_function_buffer
import reward_design.prompts.d4rl_prompt as d4rl_prompt
import reward_design.prompts.preference_prompt as preference_prompt


def get_config(argv=None) -> argparse.Namespace:
    """Build the command-line parser for the reward-design run and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour existing callers rely on.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Reward-Design-for-Offline-RL")
    parser.add_argument("--benchmark", type=str, default="d4rl", help="Benchmark name.")
    parser.add_argument(
        "--log_dir",
        type=str,
        default="./reward_logs",
        help="Directory in which logs are saved.",
    )
    parser.add_argument(
        "--env_name", type=str, default="antmaze-umaze-v0", help="Environment name."
    )
    parser.add_argument(
        "--llm", type=str, default="gpt-4o-2024-11-20", help="LLM API model name."
    )
    parser.add_argument(
        "--model_seed", type=int, default=0, help="Random seed for reproducibility."
    )

    # Sampling parameters used for the first-generation LLM queries.
    parser.add_argument(
        "--temperature_response",
        type=float,
        default=0.7,
        help="Sampling temperature for generation.",
    )
    parser.add_argument(
        "--max_tokens_response",
        type=int,
        default=10000,
        help="Max tokens to generate for each output sequence.",
    )
    parser.add_argument(
        "--top_p_response", type=float, default=1.0, help="Diversity of output."
    )
    parser.add_argument(
        "--use_prompt_caching",
        type=reward_design.utils.str2bool,
        default=False,
        help="Whether to use prompt caching.",
    )

    # Optimisation loop controls.
    parser.add_argument(
        "--use_preference",
        type=reward_design.utils.str2bool,
        default=True,
        help="Whether to use preference.",
    )
    parser.add_argument(
        "--iter_num",
        type=int,
        default=3,
        help="Number of TextGrad iterations.",
    )
    parser.add_argument(
        "--sample_num",
        type=int,
        default=5,
        help="Number of responses to sample for each step.",
    )

    # Preference-model settings (forwarded to PreferenceModel by main()).
    parser.add_argument(
        "--num_top_episodes",
        type=int,
        default=1,
        help="Number of expert trajectories.",
    )
    parser.add_argument(
        "--expert_tolerance",
        type=float,
        default=0.01,
        help="Tolerance allowed when comparing non-expert trajectory return.",
    )
    parser.add_argument(
        "--noise_tolerance",
        type=float,
        default=0.01,
        help="Tolerance allowed when comparing noisy trajectory return.",
    )
    parser.add_argument(
        "--obs_noise_scale",
        type=float,
        default=0.05,
        help="The scaling factor of the state used to generate the noise trajectory.",
    )
    parser.add_argument(
        "--action_noise_scale",
        type=float,
        default=0.05,
        help="The scaling factor of the action used to generate the noise trajectory.",
    )
    parser.add_argument(
        "--mask_error_code",
        type=reward_design.utils.str2bool,
        default=True,
        help="Whether to mask error reward function code when getting best and worst.",
    )
    parser.add_argument(
        "--use_return_per_step",
        type=reward_design.utils.str2bool,
        default=False,
        help="Whether to calculate return per step.",
    )
    parser.add_argument(
        "--design_mode",
        type=str,
        choices=["sa", "sas", "ss"],
        default="sas",
        help="The mode of reward design.",
    )

    # only for antmaze
    parser.add_argument(
        "--disable_goal",
        type=reward_design.utils.str2bool,
        default=False,
        help="Whether to disable goals (antmaze only).",
    )
    parser.add_argument(
        "--fix_goal",
        type=reward_design.utils.str2bool,
        default=True,
        help="Whether to use fixed goals (antmaze only).",
    )

    # argv=None makes argparse read sys.argv[1:], exactly as before.
    return parser.parse_args(argv)


def _banner(text: str) -> str:
    """Return *text* framed by ten ``=`` characters on each side (log header style)."""
    return "=" * 10 + text + "=" * 10


def _report_all_failed(log_path: str, cur_iter, cur_sample, all_failed_flag) -> None:
    """Print an all-failed notice and append the same line to ``error_report.txt``.

    Args:
        log_path: Run directory that holds ``error_report.txt``.
        cur_iter: Iteration the notice belongs to (0 is the first generation).
        cur_sample: Sample index recorded next to the iteration.
        all_failed_flag: Flag returned by the reward buffer; echoed verbatim.
    """
    message = (
        f"Error: iter={cur_iter}, sample={cur_sample}, "
        f"all_failed_flag={all_failed_flag}"
    )
    print(_banner(message))
    report_path = os.path.join(log_path, "error_report.txt")
    with open(report_path, "a", encoding="utf-8") as f:
        f.write(message + "\n")


def _run_reward_design(args: argparse.Namespace, log_path: str) -> None:
    """Sample initial reward functions, refine them with TextGrad, and save the buffer.

    Args:
        args: Parsed command-line options (see ``get_config``).
        log_path: Existing run directory; receives ``llm_args.json``,
            ``error_report.txt`` (only when needed) and whatever
            ``RewardFunctionBuffer.save`` writes.
    """
    # print config
    print(_banner("args is:"))
    print(json.dumps(vars(args), indent=4))
    with open(os.path.join(log_path, "llm_args.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    # define context engine
    llm_engine = tg.get_engine(
        args.llm,
        base_url=os.getenv("OPENAI_API_BASE"),
        # extra args
        use_prompt_caching=args.use_prompt_caching,
    )
    tg.set_backward_engine(llm_engine, override=True)

    # define preference model
    pref_model = reward_design.preference_model.PreferenceModel(
        env_name=args.env_name,
        design_mode=args.design_mode,
        model_seed=args.model_seed,
        num_top_episodes=args.num_top_episodes,
        expert_tolerance=args.expert_tolerance,
        noise_tolerance=args.noise_tolerance,
        obs_noise_scale=args.obs_noise_scale,
        action_noise_scale=args.action_noise_scale,
        use_return_per_step=args.use_return_per_step,
        disable_goal=args.disable_goal,
        fix_goal=args.fix_goal,
    )

    # define reward buffer
    func_buffer = reward_design.reward_function_buffer.RewardFunctionBuffer(
        pref_model=pref_model, mask_error_code=args.mask_error_code
    )

    # Prepare generation params
    gen_params = {
        "temperature": args.temperature_response,
        "max_tokens": args.max_tokens_response,
        "top_p": args.top_p_response,
    }

    # 1) First Generation
    # 1.1) Initial sampling for candidates
    question = d4rl_prompt.get_d4rl_prompt(args.env_name, args.design_mode)

    print(_banner("First Generation") + "\n")
    for cur_sample in range(1, args.sample_num + 1):
        print(_banner(f"Sample Number = {cur_sample}/{args.sample_num}") + "\n")
        response = llm_engine(question, **gen_params)
        # NOTE(review): rsp / rsp_model / rsp_finish_reason / rsp_usage are read
        # straight off the engine after each call; presumably they come from the
        # locally modified textgrad engine — confirm they exist on the engine in use.
        rsp_model = llm_engine.rsp_model
        rsp_finish_reason = llm_engine.rsp_finish_reason
        rsp_usage = llm_engine.rsp_usage
        code = reward_design.utils.extract_code(response=response)
        reward_obj = {
            "iter": 0,
            "sample": cur_sample,
            "stage": ["first_generation"],
            "response": [response],
            "model": [rsp_model],
            "finish_reason": [rsp_finish_reason],
            "usage": [rsp_usage],
            "code": code,
            "score": None,
            "metrics": None,
            "executable_flag": None,
        }
        func_buffer.update_buffer([reward_obj])

    # 1.2) Initial "chosen_resp_text" and "rej_resp_text"
    (
        (chosen_resp_text, best_score),
        (rej_resp_text, worst_score),
        _delta,
        all_failed_flag,
    ) = func_buffer.get_preference_result()
    print(
        _banner(f"First Generation Completed ({args.sample_num}/{args.sample_num})")
        + "\n"
    )
    if all_failed_flag:
        # args.sample_num equals the last loop index and, unlike the loop
        # variable, is also defined when no sample was drawn.
        _report_all_failed(log_path, 0, args.sample_num, all_failed_flag)

    # 2) Reward Improvement
    # 2.1) Define fixed constants (they depend only on the CLI options)
    if args.use_preference:
        response_role = "a chosen response to a user query"
        # Constraints for textual updates
        constraints = [
            "Only generate a chosen response.",
            "Do NOT generate a rejected response.",
        ]
    else:
        response_role = "a model response to a user query"
        constraints = ["Only generate a model response."]

    for cur_iter in range(1, args.iter_num + 1):
        # Build the evaluation prompt once per iteration: "rej_resp_text"
        # only changes in "2.4)" at the end of an iteration.
        if args.use_preference:
            # Using Preference, includes rejected response
            evaluation_sys_text = preference_prompt.EVALUATION_SYS_TEMPLATE.format(
                query=question,
                rejected_response=rej_resp_text,
            )
        else:
            # No rejected sample provided
            evaluation_sys_text = (
                preference_prompt.EVALUATION_SYS_TEMPLATE_REVISION.format(
                    query=question
                )
            )

        # 2.2) Perform {args.sample_num} optimizations independently, each one
        # starting from the same "chosen_resp_text"
        for cur_sample in range(1, args.sample_num + 1):
            progress = (
                f"(iter={cur_iter}/{args.iter_num}, "
                f"sample={cur_sample}/{args.sample_num})"
            )

            # 2.2.1) Define the variable to be optimized using "chosen_resp_text"
            # The "chosen_resp_text" will be updated after each iter in "2.3)"
            # NOTE(review): no newline is inserted after "```python"; this
            # presumably relies on chosen_resp_text starting with one — confirm.
            chosen_response = tg.Variable(
                "**Chosen Response**:\n" + "```python" + chosen_resp_text + "```",
                requires_grad=True,
                role_description=response_role,
            )
            print(
                "chosen_resp_text hash is:",
                hash(chosen_resp_text),
                best_score,
                hash(rej_resp_text),
                worst_score,
            )

            # 2.2.3) Define loss
            loss_fn = tg.TextLoss(evaluation_sys_text)

            # 2.2.4) Create the TPO optimizer
            optimizer = tg.TextualGradientDescent(
                engine=llm_engine,
                parameters=[chosen_response],
                constraints=constraints,
            )

            # 2.2.5) Start reward improvement, clear optimizer
            print(_banner(f"Start Reward Improvement {progress}:"), end="\n\n")
            optimizer.zero_grad()

            # 2.2.6) Compute textual loss
            print("score:", best_score, worst_score)
            print(_banner(f"Calculate Loss {progress}:"), end="\n\n")
            loss = loss_fn(chosen_response)
            # Snapshot the engine's response fields right after each stage,
            # before the next call replaces them.
            loss_rsp = llm_engine.rsp
            loss_model = llm_engine.rsp_model
            loss_finish_reason = llm_engine.rsp_finish_reason
            loss_usage = llm_engine.rsp_usage

            # 2.2.7) Compute textual gradients
            print(_banner(f"Calculate Backward {progress}:"), end="\n\n")
            loss.backward()
            backward_rsp = llm_engine.rsp
            backward_model = llm_engine.rsp_model
            backward_finish_reason = llm_engine.rsp_finish_reason
            backward_usage = llm_engine.rsp_usage

            # 2.2.8) Update variable using textual gradients
            print(_banner(f"Calculate Optimization {progress}:"), end="\n\n")
            optimizer.step()
            optimizer_rsp = llm_engine.rsp
            optimizer_model = llm_engine.rsp_model
            optimizer_finish_reason = llm_engine.rsp_finish_reason
            optimizer_usage = llm_engine.rsp_usage

            # 2.2.9) Get improved reward function code
            print(_banner(f"Optimization code {progress}:"), end="\n\n")
            code = reward_design.utils.extract_code(response=chosen_response.value)
            reward_obj = {
                "iter": cur_iter,
                "sample": cur_sample,
                "response": [loss_rsp, backward_rsp, optimizer_rsp],
                "model": [loss_model, backward_model, optimizer_model],
                "finish_reason": [
                    loss_finish_reason,
                    backward_finish_reason,
                    optimizer_finish_reason,
                ],
                "usage": [loss_usage, backward_usage, optimizer_usage],
                "code": code,
                "score": None,
                "metrics": None,
                "executable_flag": None,
            }
            func_buffer.update_buffer([reward_obj])

        # 2.3) Update the "chosen_resp_text", correspond to "2.2.1)"
        # 2.4) Update the "rej_resp_text", correspond to "2.2.2)"
        (
            (chosen_resp_text, best_score),
            (rej_resp_text, worst_score),
            _delta,
            all_failed_flag,
        ) = func_buffer.get_preference_result()

        print(_banner(f"Iteration={cur_iter}/{args.iter_num} Completed") + "\n")
        if all_failed_flag:
            # Record the iteration that actually failed, not a constant 0.
            _report_all_failed(log_path, cur_iter, args.sample_num, all_failed_flag)

    # 3.save data
    func_buffer.save(log_path)


def main():
    """Run the reward-design pipeline with all console output captured to a log file.

    Parses the CLI options, creates a run directory whose path encodes the
    options plus a timestamp, redirects ``sys.stdout``/``sys.stderr`` to
    ``llm_output.log`` in that directory, runs the pipeline and prints the
    elapsed time. The original streams are restored and the logger closed
    even when the pipeline raises; the exception is re-raised afterwards.
    """
    import traceback

    start_time = time.time()
    args = get_config()
    formatted_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    log_path = os.path.join(
        args.log_dir,
        args.benchmark,
        args.llm,
        f"model_seed={args.model_seed}",
        f"mask_error_code={args.mask_error_code}",
        f"use_return_per_step={args.use_return_per_step}",
        # for loss, backward, and optimization
        # Modify the textgrad source code to make it effective.
        # The purpose here is just to record.
        "temperature_textgrad=0.7",
        "max_tokens_textgrad=10000",
        "top_p_textgrad=1.0",
        # for query
        f"temperature_response={args.temperature_response}",
        f"max_tokens_response={args.max_tokens_response}",
        f"top_p_response={args.top_p_response}",
        # for ablation
        f"iter_num={args.iter_num}",
        f"sample_num={args.sample_num}",
        f"expert_tolerance={args.expert_tolerance}",
        f"noise_tolerance={args.noise_tolerance}",
        f"obs_noise_scale={args.obs_noise_scale}",
        f"action_noise_scale={args.action_noise_scale}",
        f"use_preference={args.use_preference}",
        f"num_top_episodes={args.num_top_episodes}",
        f"design_mode={args.design_mode}",
        f"env_name={args.env_name}",
        formatted_time,
    )
    os.makedirs(log_path, exist_ok=True)
    file_logger = reward_design.utils.FileLogger(
        filename=os.path.join(log_path, "llm_output.log")
    )
    sys.stdout = file_logger
    sys.stderr = file_logger
    try:
        _run_reward_design(args, log_path)

        # calculate time
        end_time = time.time()
        elapsed_time_str = str(datetime.timedelta(seconds=int(end_time - start_time)))
        print(f"Elapsed Time: {elapsed_time_str}")
        print(_banner("finished"))
    except Exception:
        # Write the traceback while stderr still points at the file logger,
        # so the failure stays visible in llm_output.log.
        traceback.print_exc()
        raise
    finally:
        # Restore original stdout/stderr and release the log file
        sys.stdout = file_logger.stdout
        sys.stderr = file_logger.stderr
        file_logger.close()


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
