import numpy as np
import cma

from abc import (
    ABCMeta,
    abstractmethod,
)
from lift.environments import (
    BaseEnvironment,
    PositionOnly,
)

class LoggingPolicy(metaclass=ABCMeta):
    """
    Base class for expert policies.

    Contains common initialization, noise mapping, and CMA-ES optimization logic.
    Subclasses must implement their own 'restart' and 'predict' methods.
    """
    def __init__(self, env, k=0.1):
        """Initializes the base expert policy.

        Args:
            env: A BaseEnvironment or PositionOnly instance, or an object that
                wraps one in its ``env`` attribute (one or two levels deep).
            k: Decay rate used by ``map_step_to_noise``.

        Raises:
            AttributeError: If the unwrapped environment has no 'n_actions'
                attribute.
        """
        if not isinstance(env, (BaseEnvironment, PositionOnly)):
            # Strip one wrapper level when ``env.env`` is already a
            # BaseEnvironment, otherwise strip two.
            inner = env.env
            env = inner if isinstance(inner, BaseEnvironment) else inner.env

        self.env = env
        self.k = k
        if not hasattr(self.env, 'n_actions'):
            raise AttributeError("Environment object must have a 'n_actions' attribute specifying the action space size.")
        self.r_offset = np.zeros(self.env.n_actions)

    def map_step_to_noise(self, x):
        """Maps the number of steps to a noise scaling factor, exp(-k * x)."""
        return np.exp(-self.k * x)

    def get_avg_step(self, threshold, n_runs=10):
        """Runs the policy for several episodes and reports the episode lengths.

        An episode ends when the environment reports 'truncated' or
        'terminated', or as soon as the negated reward drops below
        ``threshold``.

        Args:
            threshold: Upper bound on ``-reward`` below which a run stops.
            n_runs: Number of episodes to run.

        Returns:
            Tuple ``(mean, std)`` of the number of steps taken per episode.
        """
        # NOTE(review): ``restart`` is never called here and only the last
        # entry of the prediction is forwarded to ``env.step``; this looks
        # like it expects a batched ``predict`` -- confirm against callers.
        steps = []
        for _ in range(n_runs):
            obs = self.env.reset()
            n_steps = 0
            done = False

            while not done:
                n_steps += 1
                action = self.predict(obs[np.newaxis])
                output = self.env.step(action[-1])
                obs = output['img']
                reward = output['reward']
                done = output['truncated'] or output['terminated'] or (-reward < threshold)

            steps.append(n_steps)
        return np.mean(steps), np.std(steps)

    def calculate_optimum(self):
        """Uses CMA-ES to find an optimal offset (r_offset).

        Minimizes ``env.compute_distance_to_gt`` of the image rendered at
        ``env.r_optimal + offset``, starting from a zero offset. The best
        offset found is stored in ``self.r_offset`` and printed together with
        its score and the score of the zero offset.
        """

        # Define the objective function locally within calculate_optimum
        # It captures 'self.env' from the enclosing scope
        def objective_cma(r_offset_candidate):
            """CMA-ES objective: calculates distance for a given offset."""
            # Access the environment via self
            score = self.env.compute_distance_to_gt(
                self.env._make_image(self.env.r_optimal + r_offset_candidate)
            )
            return score # CMA-ES minimizes

        # Starting point and initial sigma. The search space has one entry per
        # component of r_optimal, because every candidate is added to it.
        x0 = np.zeros(np.size(self.env.r_optimal))
        sigma0 = 0.01

        es = cma.CMAEvolutionStrategy(x0.tolist(), sigma0, {
            'popsize': 20,
            'maxiter': 50,
            'verbose': -9,
        })

        while not es.stop():
            solutions = es.ask()
            scores = [objective_cma(x) for x in solutions]
            es.tell(solutions, scores)

        self.r_offset = es.result.xbest
        print(f"CMA-ES finished. Found optimal offset: {self.r_offset}: {es.result.fbest}", objective_cma(np.zeros_like(self.r_offset)))

    @abstractmethod
    def restart(self, obs):
        """Resets any per-episode state of the policy for observation ``obs``."""
        pass

    @abstractmethod
    def predict(self, obs):
        """Returns the action to take for observation ``obs``."""
        pass



class OptimalPolicy(LoggingPolicy):
    """
    Policy that heads straight for the environment's optimum.

    Each prediction is the vector from ``env.r`` to ``env.r_optimal``,
    shortened to ``max_step_length`` when it is longer than that.
    """

    def __init__(self, env, max_step_length=0.1):
        """Stores the step-length cap, then runs the base initialization."""
        self.max_step_length = max_step_length

        super().__init__(env)

    def restart(self, obs):
        """Nothing to reset: this policy keeps no per-episode state."""
        pass

    def predict(self, obs):
        """Returns the (possibly length-capped) step toward ``env.r_optimal``.

        ``obs`` is not used; the step is derived from the environment state.
        """
        displacement = self.env.r_optimal - self.env.r
        distance = np.linalg.norm(displacement)
        too_far = distance > self.max_step_length

        # Rescale only when the full displacement exceeds the allowed length.
        return self.max_step_length * displacement / distance if too_far else displacement



class CoordinateWalkPolicy(LoggingPolicy):
    """
    Policy that moves along one action dimension at a time.

    For the active dimension it steps with a fixed size toward
    ``env.r_optimal`` until the distance measured when the dimension was
    selected has been exceeded, then moves on to the next dimension in
    round-robin order. A dimension's step size is halved each time it is
    selected.
    """

    def __init__(self, env, initial_step_length=0.1):
        """Runs the base initialization and stores the initial step length."""
        super().__init__(env)
        self.initial_step_length = initial_step_length

    def restart(self, obs):
        """Resets the walk state and selects dimension 0. ``obs`` is unused."""
        # Not read by any code visible in this class.
        self.dims_optimized = set()

        # Start at twice the initial step length, because selecting a
        # dimension halves its step size before the first step is taken.
        sizes = {}
        for dim in range(self.env.n_actions):
            sizes[dim] = 2*self.initial_step_length
        self.step_sizes = sizes

        self.current_dim = -1
        self._init_next_dim()

    def _init_next_dim(self):
        """Advances to the next dimension (wrapping around) and sets it up."""
        dim = (self.current_dim + 1) % self.env.n_actions
        self.current_dim = dim

        # Halve the step size of the newly selected dimension.
        self.step_sizes[dim] = 0.5*self.step_sizes[dim]

        # Signed distance to the optimum along the selected dimension.
        signed_gap = (self.env.r_optimal - self.env.r)[dim]
        self.step_length = signed_gap
        self.step_direction = np.sign(signed_gap)
        self.step_length_remaining = np.abs(signed_gap)

    def predict(self, obs):
        """Returns a step that is non-zero only in the active dimension.

        ``obs`` is not used. Switches to the next dimension first when the
        remaining distance of the current one has gone below zero.
        """
        if self.step_length_remaining < 0:
            self._init_next_dim()

        dim = self.current_dim
        size = self.step_sizes[dim]

        action = np.zeros(self.env.n_actions)
        action[dim] = self.step_direction*size
        self.step_length_remaining -= size

        return action
