# NOTE: this entire module is commented out; nothing below is executed or importable.
# import enum
# from typing import Set
#
# import torch
#
#
# class RewardHeuristicsStrength(enum.Enum):
#     no_heuristics = 0
#     bimodal_dense = 1  # deprecated
#     bimodal_sparse = 2  # deprecated
#     trimodal = 3  # deprecated
#     bimodal_sparse2 = 4
#     trimodal2 = 5
#     bimodal_ground_truth = 6
#     trimodal_ground_truth = 7
#     bimodal_custom = 8
#     trimodal_custom = 9
#
#
# def reward_heuristics_trimodal2(example_batch, default_reward_value):
#     batch_size = len(example_batch["chat_feedback"])
#     rewards = [default_reward_value] * batch_size
#     available = [True] * batch_size
#     for b in range(batch_size):
#         if example_batch["chat_feedback"][b] == "":
#             continue
#         chat_feedback = example_batch["chat_feedback"][b].lower()
#         if any(x in chat_feedback for x in ["yes", "good", "keep", "next", "correct"]):
#             rewards[b] = 1
#         if any(x in chat_feedback for x in ["deselect", "delete", "wrong", "incorrect"]):
#             rewards[b] = -1
#
#     return {"reward": torch.tensor(rewards, dtype=torch.float),
#             "available": torch.tensor(available, dtype=torch.float)}
#
#
# def reward_heuristics_bimodal_sparse2(example_batch, default_reward_value):
#     temp = reward_heuristics_trimodal2(example_batch, default_reward_value)
#     rewards = temp["reward"]
#     available = temp["available"]
#     available[rewards == default_reward_value] = 0.0
#     return {"reward": rewards, "available": available}
#
#
# def reward_heuristics_trimodal_ground_truth(example_batch, default_reward_value):
#     batch_size = len(example_batch["chat_feedback"])
#     rewards = [-1] * batch_size
#     available = [True] * batch_size
#     for b in range(batch_size):
#         # if example_batch["is_good_select"][b] or example_batch["is_good_deselect"][b]:
#         #     rewards[b] = 1
#         if example_batch["is_good_select"][b] and \
#                 len(example_batch["deselected"][b]) == 0:
#             # pure ground-truth good select (nothing was deselected)
#             rewards[b] = 1
#         elif example_batch["is_good_deselect"][b] and \
#                 len(example_batch["selected"][b]) == 0:
#             # pure ground-truth good deselect (nothing was selected)
#             rewards[b] = 1
#         elif example_batch["is_good_select"][b] and \
#                 example_batch["is_good_deselect"][b]:
#             # both ground-truth good select and ground-truth good deselect
#             rewards[b] = 1
#         elif example_batch["is_good_select"][b] and \
#                 len(example_batch["deselected"][b]) > 0:
#             rewards[b] = default_reward_value
#         elif example_batch["is_good_deselect"][b] and \
#                 len(example_batch["selected"][b]) > 0:
#             rewards[b] = default_reward_value
#         else:
#             rewards[b] = -1
#     return {"reward": torch.tensor(rewards, dtype=torch.float),
#             "available": torch.tensor(available, dtype=torch.float)}
#
#
# def reward_heuristics_bimodal_ground_truth(example_batch, default_reward_value):
#     temp = reward_heuristics_trimodal_ground_truth(
#         example_batch, default_reward_value)
#     rewards = temp["reward"]
#     available = temp["available"]
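#     # NOTE(review): reward_heuristics_bimodal_sparse2 masks with 0.0 here, this uses -1 — confirm which is intended.
#     available[rewards == default_reward_value] = -1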
#     return {"reward": rewards, "available": available}
#
#
# def reward_heuristics_bimodal_custom(example_batch, default_reward_value: float, **kwargs):
#     temp = reward_heuristics_trimodal_custom(
#         example_batch, default_reward_value, **kwargs)
#     rewards = temp["reward"]
#     available = temp["available"]
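#     # NOTE(review): reward_heuristics_bimodal_sparse2 masks with 0.0 here, this uses -1 — confirm which is intended.
#     available[rewards == default_reward_value] = -1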
#     return {"reward": rewards, "available": available}
#
#
# def reward_heuristics_trimodal_custom(example_batch,
#                                       default_reward_value: float,
#                                       pos_set: Set, neg_set: Set, neu_set: Set):
#     batch_size = len(example_batch["chat_feedback"])
#     rewards = [0] * batch_size
#     available = [True] * batch_size
#     for b in range(batch_size):
#         game_turn_id = example_batch["game_turn_id"][b]
#         if game_turn_id in pos_set:
#             rewards[b] = 1
#         elif game_turn_id in neg_set:
#             rewards[b] = -1
#         elif game_turn_id in neu_set:
#             rewards[b] = default_reward_value
#         else:
#             raise ValueError(f"Unknown game_turn_id: {game_turn_id}")
#     return {"reward": torch.tensor(rewards, dtype=torch.float),
#             "available": torch.tensor(available, dtype=torch.float)}
#
#
# def reward_heuristics(example_batch, strength: RewardHeuristicsStrength, default_reward_value: float, **kwargs):
#     strength = RewardHeuristicsStrength(strength)
#     if strength is RewardHeuristicsStrength.no_heuristics:
#         return {"reward": torch.full((len(example_batch["chat_feedback"]),),
#                                      fill_value=default_reward_value, dtype=torch.float),
#             "available": torch.zeros(len(example_batch["chat_feedback"]), dtype=torch.float)}
#     elif strength is RewardHeuristicsStrength.bimodal_sparse2:
#         return reward_heuristics_bimodal_sparse2(example_batch, default_reward_value)
#     elif strength is RewardHeuristicsStrength.trimodal2:
#         return reward_heuristics_trimodal2(example_batch, default_reward_value)
#     elif strength is RewardHeuristicsStrength.bimodal_ground_truth:
#         return reward_heuristics_bimodal_ground_truth(example_batch, default_reward_value)
#     elif strength is RewardHeuristicsStrength.trimodal_ground_truth:
#         return reward_heuristics_trimodal_ground_truth(example_batch, default_reward_value)
#     elif strength is RewardHeuristicsStrength.bimodal_custom:
#         return reward_heuristics_bimodal_custom(example_batch, default_reward_value, **kwargs)
#     elif strength is RewardHeuristicsStrength.trimodal_custom:
#         return reward_heuristics_trimodal_custom(example_batch, default_reward_value, **kwargs)
#     else:
#         raise ValueError(f"Unknown strength: {strength}")
