# Copyright (c) 2024-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/BorealisAI/llm-pddl-planning/blob/main/src/evaluation.py


from typing import Callable, Dict, Optional

import numpy as np

from tp_lodge.random_walk.random_walk_env import RandomWalkEnv
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from tp_lodge.utils.pddl_lib_utils import remove_types_from_domain
from tp_lodge.utils.planning_cache_utils import SasAction


def harmonic_mean(a, b):
    """Return the harmonic mean ``2ab / (a + b)`` of two numbers.

    A zero denominator yields ``0.0`` instead of raising ZeroDivisionError.
    """
    denominator = a + b
    if denominator != 0:
        return 2 * a * b / denominator
    return 0.0


class PlanRatings:
    """Sentinel scores for evaluation outcomes.

    Negative values are failure codes. Values in ``[0.0, 1.0]`` are ordinary
    scores: the ratio of random walks that are executable. ``SOLUTION_FOUND``
    sits above that range.
    """

    EMPTY_CODE: float = -6.0
    INVALID_MODIFICATION: float = -5.0
    # Sanity failure, e.g. actions whose effects are empty.
    PDDL_SANITY_ERROR: float = -4.0
    # Domain did not validate, e.g. it references undefined predicates.
    INVALID_DOMAIN: float = -3.0
    # No plan is possible, e.g. the initial state and goal are disconnected.
    NO_PLAN: float = -1.0
    # (0.0 .. 1.0 is reserved for the executable-random-walk ratio.)
    SOLUTION_FOUND: float = 2.0


class RandomWalkEvaluator:
    """Compare a generated PDDL domain/problem to a ground-truth one via random walks.

    Random-walk plans are sampled alternately in the generated and the
    ground-truth ("target") environment. Each plan is replayed in the other
    environment, and the fraction of executable replays is recorded per
    direction. The final rating is the harmonic mean of the two fractions.
    """

    def __init__(
        self,
        gt_domain: str,
        gt_function_mapping: Dict[str, Callable[[SasAction], str]],
        gt_problem: Optional[str] = None,
    ):
        """Store the ground-truth domain and optionally build its environment.

        Args:
            gt_domain: Ground-truth PDDL domain text.
            gt_function_mapping: Forwarded to ``RandomWalkEnv`` as
                ``function_mapping`` for the ground-truth environment
                (presumably keyed by action name -- confirm against RandomWalkEnv).
            gt_problem: Optional ground-truth PDDL problem. When given, the
                target environment is built here. Otherwise a problem must be
                passed to ``evaluate_task``.
        """
        self.gt_domain = gt_domain
        self.gt_function_mapping = gt_function_mapping
        self.ai_validator = AIValidator()

        if gt_problem is not None:
            self.target_env = RandomWalkEnv(
                domain_pddl=gt_domain, problem_pddl=gt_problem, function_mapping=gt_function_mapping
            )
        else:
            self.target_env = None

    def evaluate_task(
        self,
        gen_domain: str,
        gen_problem: str,
        gt_problem: Optional[str],
        gen_function_mapping: Dict[str, Callable[[SasAction], str]],
        *,
        n_walks: int = 100,
        max_length: int = 10,
    ):
        """Rate a generated domain/problem by cross-executing random walks.

        Args:
            gen_domain: Generated PDDL domain text.
            gen_problem: Generated PDDL problem text.
            gt_problem: Ground-truth PDDL problem. If not None, a new target
                environment is built from it and stored on ``self.target_env``.
                If None, the one built in ``__init__`` is reused.
            gen_function_mapping: Forwarded to ``RandomWalkEnv`` as
                ``function_mapping`` for the generated environment.
            n_walks: Total number of walks. Even-indexed walks go
                generated -> target; odd-indexed walks go target -> generated.
                Must be >= 2 so that both directions are sampled.
            max_length: Per direction, the ``max_steps`` passed to the sampler
                cycles through 1..max_length. Must be >= 1.

        Returns:
            ``(rating, t_to_gen_ratio, gen_to_t_ratio)``. On success, ``rating``
            is the harmonic mean of the two executable ratios.
            ``PlanRatings.INVALID_DOMAIN`` is returned if the generated domain
            fails validation. ``PlanRatings.NO_PLAN`` is returned if the
            generated environment yields an empty walk. In both failure cases
            the two ratios are ``0``.

        Raises:
            ValueError: ``n_walks < 2``, ``max_length < 1``, or no ground-truth
                problem was supplied here or in ``__init__``.
        """
        # Fail fast: with fewer than 2 walks one direction is never sampled and
        # the ratio computation below would divide by zero.
        if n_walks < 2:
            raise ValueError(f"n_walks must be >= 2 to sample both directions, got {n_walks}")
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        # Check if the domain is valid (ignore types by removing them)
        _, is_domain_valid = self.ai_validator.validate(
            domain=remove_types_from_domain(gen_domain),
            problem=gen_problem,
            plan=None,
            options="-v",
        )
        if not is_domain_valid:
            return PlanRatings.INVALID_DOMAIN, 0, 0

        if gt_problem is not None:
            self.target_env = RandomWalkEnv(
                domain_pddl=self.gt_domain, problem_pddl=gt_problem, function_mapping=self.gt_function_mapping
            )
        if self.target_env is None:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError("A ground-truth problem must be given to __init__ or evaluate_task")
        target_env = self.target_env
        # NOTE(review): when gt_problem is None, the env built in __init__ is
        # reused and closed at the end of every call. Confirm that a closed
        # RandomWalkEnv can be reused (or pass gt_problem on each call).

        t_to_gen_exec, gen_to_t_exec = 0, 0
        t_to_gen_all, gen_to_t_all = 0, 0
        gen_env = None
        try:
            gen_env = RandomWalkEnv(
                domain_pddl=gen_domain, problem_pddl=gen_problem, function_mapping=gen_function_mapping
            )
            for walk_idx in range(n_walks):
                if walk_idx % 2 == 0:
                    # Generated -> target: sample in the generated env, replay in ground truth.
                    max_steps = (gen_to_t_all % max_length) + 1
                    random_walk_plan = gen_env.get_random_walk_plan(max_steps=max_steps)
                    if len(random_walk_plan) == 0:
                        # `finally` still closes both environments on this early exit.
                        return PlanRatings.NO_PLAN, 0, 0
                    is_executable = target_env.get_plan_execution_feedback(random_walk_plan)
                    gen_to_t_all += 1
                    # Only a literal True counts as executable.
                    gen_to_t_exec += is_executable is True
                else:
                    # Target -> generated: sample in ground truth, replay in the generated env.
                    max_steps = (t_to_gen_all % max_length) + 1
                    random_walk_plan = target_env.get_random_walk_plan(max_steps=max_steps)
                    assert len(random_walk_plan) > 0, "Empty plan generated by target environment!"
                    is_executable = gen_env.get_plan_execution_feedback(random_walk_plan)
                    t_to_gen_all += 1
                    t_to_gen_exec += is_executable is True
        finally:
            if gen_env is not None:
                gen_env.close()
            target_env.close()

        # n_walks >= 2 guarantees both counters are at least 1.
        t_to_gen_ratio = t_to_gen_exec / t_to_gen_all
        gen_to_t_ratio = gen_to_t_exec / gen_to_t_all
        return harmonic_mean(t_to_gen_ratio, gen_to_t_ratio), t_to_gen_ratio, gen_to_t_ratio
