"""
Store all the configuration here
rewards
players
collect_iter: number of iterations to collect data and train
collect_episodes: number of episodes to collect data in one collection iter
iterations: number of iterations to train in each collection iter
all_commits_ratio: percentage that all initial states are local commit during data collection
crash_ratio: probability that a node will crash during data collection
receive_crash_ratio: probability that an alive node can receive a message from a crashing node.
is_memory: whether to use the memorized value as the training target
"""

# --- Rewards ---------------------------------------------------------------
# Reward value per outcome label. Note that "Good" maps to a one-element
# list while "Bad" and "Neutral" map to plain ints.
REWARDS = {
    "Bad": -1,
    "Good": [1],
    "Neutral": -1,
}

# --- Data-collection ratios (described in the module docstring) -------------
ALL_COMMITS_RATIO = 0.5
CRASH_RATIO = 0.5
RECEIVE_CRASH_RATIO = 0.5

# Reported by configuration.log() only for the "atomic_commit" protocol.
# NOTE(review): exact meaning is not visible in this file -- confirm.
IS_ENCODED_HISTORY = False

# --- Model / optimisation settings ------------------------------------------
# Reported by configuration.log() only when the model type is "mlp";
# presumably the widths of the fully connected layers -- TODO confirm.
FC_LAYER_PARAMS = (128, 32, 8)
# Presumably the exploration rate of an epsilon-greedy policy -- TODO confirm.
EPSILON = 0.05
LEARNING_RATE = 0.001
# Loss function selected by name; the lookup happens outside this file.
LOSS_FN = "element_wise_squared_loss"

# --- Refinement -------------------------------------------------------------
IS_REFINE = True
# configuration.log() reports this as the iteration to start refinement
# training (printed only when IS_REFINE is true).
REFINE_THRESH = 5000

# --- Protocols --------------------------------------------------------------
# Protocol names; configuration branches on "atomic_commit", "math" and
# "primary_backup".
SUPPORT_PROTOCOLS = [
    "atomic_commit",
    "distributed_locking",
    "counter",
    "math",
    "primary_backup",
]

# --- Miscellaneous ----------------------------------------------------------
# Presumably a profiling frequency -- TODO confirm against its users.
PROFILE_FRE = 500
# NOTE(review): meaning is not visible in this file -- confirm with callers.
CUTOF_INDEX = 2


class configuration:
    """Per-run settings copied from a parsed options object.

    Attributes:
        players: copied from ``opt.players``.
        protocol: copied from ``opt.protocol``.
        collect_iter: number of iterations to collect data and train
            (from ``opt.collect_iter``).
        collect_episodes: number of episodes to collect in one collection
            iteration (from ``opt.num_episodes``).
        iterations: number of training iterations in each collection
            iteration (from ``opt.iteration``).
        use_gpu: copied from ``opt.gpu``.
        math_fn: ``opt.math_fn`` when the protocol is "math", else ``None``.
        is_memory: whether to use the memorized value as the training
            target; always initialised to ``False``.
    """

    def __init__(self, opt):
        """Copy the run settings out of ``opt``.

        Args:
            opt: object exposing ``players``, ``protocol``, ``collect_iter``,
                ``num_episodes``, ``iteration`` and ``gpu`` attributes, plus
                ``math_fn`` when ``protocol`` is "math" (presumably an
                argparse namespace -- confirm against the caller).
        """
        self.players = opt.players
        self.protocol = opt.protocol
        self.collect_iter = opt.collect_iter
        self.collect_episodes = opt.num_episodes
        self.iterations = opt.iteration
        self.use_gpu = opt.gpu
        # Always define math_fn so that reading it never raises
        # AttributeError; opt.math_fn is only touched for the "math"
        # protocol, because other option sets may not carry it.
        self.math_fn = opt.math_fn if self.protocol == "math" else None
        self.is_memory = False

    def log(self, opt):
        """Print the effective configuration to stdout.

        Args:
            opt: options object; ``opt.model`` is always read, and
                ``opt.math_fn`` is read when the protocol is "math".
        """
        print("protocol: ", self.protocol)
        print("rewards: ", REWARDS)
        print("players: ", self.players)
        print("collect_iter: ", self.collect_iter)
        print("collect_episodes: ", self.collect_episodes)
        print("iterations: ", self.iterations)
        print("model type: ", opt.model)
        if opt.model == "mlp":
            # Layer sizes are only reported for the MLP model.
            print("fc_layer_params: ", FC_LAYER_PARAMS)
        print("epsilon: ", EPSILON)
        print("learning_rate: ", LEARNING_RATE)
        print("loss_fn: ", LOSS_FN)
        print("is_refine: ", IS_REFINE)
        print("use gpu: ", self.use_gpu)
        if IS_REFINE:
            print("iteration to start refinement training: ", REFINE_THRESH)

        # Protocol-specific settings.
        if self.protocol == "atomic_commit":
            # Label spelling fixed (was "all_commmits_ratio") so it matches
            # the ALL_COMMITS_RATIO constant and the module docstring.
            print("all_commits_ratio: ", ALL_COMMITS_RATIO)
            print("crash_ratio: ", CRASH_RATIO)
            print("receive_crash_ratio: ", RECEIVE_CRASH_RATIO)
            print("is_memory: ", self.is_memory)
            print("is_encoded_history: ", IS_ENCODED_HISTORY)
        if self.protocol == "math":
            print("math function to learn: ", opt.math_fn)
        if self.protocol == "primary_backup":
            print("crash_ratio: ", CRASH_RATIO)
            print("receive_crash_ratio: ", RECEIVE_CRASH_RATIO)
