#  Copyright (c) 2025

import math


class FootBallHeuristicCurriculum:
    """Curriculum that schedules the strength of the heuristic red opponent.

    At construction the red controller of both the training and test scenario
    is enabled and set to ``ai_strength_start``.  On every iteration
    ``on_iter``:

    * from ``iter_disable_beta`` onwards sets ``rew_coeff_others`` to 0 on both
      scenarios;
    * from ``iter_add_adversary`` onwards moves the red controller's
      ``speed_strength`` linearly from ``ai_strength_start`` to
      ``ai_strength_end`` over ``n_iters_annealing`` iterations (in either
      direction), then holds it at ``ai_strength_end``;
    * logs the current strength under the ``"ai_strength"`` key.
    """

    def __init__(
        self,
        iter_add_adversary,
        ai_strength_start,
        ai_strength_end,
        n_iters_annealing,
        nonlinear_annealing,
        env,
        test_env,
        iter_disable_beta,
    ):
        """Store the schedule and apply the starting strength to both envs.

        Args:
            iter_add_adversary: first iteration at which annealing starts.
            ai_strength_start: strength applied before and at the start of
                annealing.
            ai_strength_end: strength reached after ``n_iters_annealing``
                iterations; may be lower than ``ai_strength_start``.
            n_iters_annealing: length of the linear ramp, in iterations.
                A value <= 0 jumps straight to ``ai_strength_end``.
            nonlinear_annealing: if true, the applied strength is floored to
                a multiple of 0.1, giving a staircase instead of a ramp.
            env: training env; must expose ``env.scenario.red_controller``.
            test_env: evaluation env with the same structure as ``env``.
            iter_disable_beta: iteration from which ``rew_coeff_others`` is
                forced to 0 on both scenarios.
        """
        super().__init__()
        self.nonlinear_annealing = nonlinear_annealing
        self.iter_add_adversary = iter_add_adversary
        self.ai_strength_start = ai_strength_start
        self.ai_strength_end = ai_strength_end
        self.n_iters_annealing = n_iters_annealing
        self.iter_disable_beta = iter_disable_beta
        # Unquantised scheduled strength. Initialised here so the attribute
        # exists even before ``iter_add_adversary`` is reached.
        self.ai_strength = ai_strength_start
        self._set_ai_hardness(env, test_env, ai_strength_start)

    def _set_ai_hardness(self, env, test_env, hardness):
        """Enable the red controller of both scenarios and set its strength."""
        for scenario in (env.scenario, test_env.scenario):
            scenario.red_controller.enable()
            scenario.red_controller.speed_strength = hardness

    def _annealed_strength(self, iteration):
        """Return the linearly scheduled strength at ``iteration``.

        The result is clamped to the closed interval spanned by
        ``ai_strength_start`` and ``ai_strength_end``. This works whether the
        schedule increases or decreases. For an increasing schedule it matches
        a plain ``min(ai_strength_end, ...)``.
        """
        start, end = self.ai_strength_start, self.ai_strength_end
        if self.n_iters_annealing <= 0:
            # Degenerate ramp: avoid dividing by zero and switch immediately.
            return end
        value = (
            start
            + (end - start)
            * (iteration - self.iter_add_adversary)
            / self.n_iters_annealing
        )
        low, high = (start, end) if start <= end else (end, start)
        return min(high, max(low, value))

    def on_iter(self, iter, env, env_test, logger):
        """Update reward coefficients and opponent strength for ``iter``.

        Args:
            iter: current training iteration. The name shadows the builtin
                but is kept for caller compatibility.
            env: training env whose scenario is updated in place.
            env_test: evaluation env whose scenario is updated in place.
            logger: object exposing ``logger.experiment.log(dict, commit=...)``.
                The strength is logged with ``commit=False``.
        """
        if iter >= self.iter_disable_beta:
            for scenario in (env.scenario, env_test.scenario):
                scenario.rew_coeff_others = 0

        if iter < self.iter_add_adversary:
            # Annealing has not started yet. The starting strength was already
            # applied in __init__, so only log it.
            logger.experiment.log(
                {"ai_strength": self.ai_strength_start},
                commit=False,
            )
            return

        self.ai_strength = self._annealed_strength(iter)
        if self.nonlinear_annealing:
            # Floor to a multiple of 0.1 (staircase schedule).
            # NOTE(review): a value whose float product with 10 lands just
            # below an integer steps down one level early.
            strength = math.floor(self.ai_strength * 10) / 10
        else:
            strength = self.ai_strength
        logger.experiment.log(
            {"ai_strength": strength},
            commit=False,
        )
        self._set_ai_hardness(env, env_test, strength)
