from abc import ABC, abstractmethod
from pathlib import Path

import gymnasium
import numpy as np
import torch

from .value_iteration import state_to_idx


class Policy(ABC):
    """Abstract base class that every evaluation policy implements.

    Concrete subclasses must provide :meth:`get_eval_action`. Converting a
    policy to a string yields the name of its concrete class.
    """

    def __str__(self):
        return type(self).__name__

    @abstractmethod
    def get_eval_action(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Choose an action for the observation ``x``.

        Returns a pair of tensors: the selected action and a second tensor
        whose meaning is defined by the subclass (NOTE(review): presumably an
        auxiliary quantity such as a log-prob or value — confirm with callers).
        """
        raise NotImplementedError


class OptimalPolicy(Policy):
    """Optimal policy for a given Gridworld.

    Actions are looked up in a precomputed policy table ``pi_opt`` using the
    state index produced by ``state_to_idx`` from the agent's true position
    and direction, read directly from the first sub-environment.
    """

    def __init__(
        self, envs: gymnasium.vector.SyncVectorEnv, V_opt: np.ndarray, pi_opt: np.ndarray, vi_path: Path | None = None
    ):
        """Store the environment handle and the value-iteration results.

        Args:
            envs: Vectorized environment; only ``envs.envs[0]`` is tracked.
            V_opt: State-value table, stored as a float32 tensor
                (not read by the methods of this class).
            pi_opt: Action per state index, stored as an int32 tensor with an
                extra axis inserted at dim 1 (presumably 1-D input giving
                shape ``(n_states, 1)`` — TODO confirm).
            vi_path: Accepted for interface compatibility; currently unused.
        """
        self.env = envs.envs[0]
        self.V_opt = torch.from_numpy(V_opt).to(torch.float32)
        self.pi_opt = torch.from_numpy(pi_opt).unsqueeze(1).to(torch.int32)

    def get_eval_action(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the tabulated optimal action for the agent's current state.

        Args:
            x: Observation; ignored, because the state is read from the
                environment directly.

        Returns:
            ``(action, torch.tensor(0.0))`` where ``action`` is the row
            ``pi_opt[idx]`` for the current state index.
        """
        # Use distinct names for the grid coordinates so the observation
        # parameter ``x`` is not silently overwritten.
        base_env = self.env.unwrapped
        pos_x, pos_y = base_env.agent_pos
        idx = state_to_idx(self.env, int(pos_x), int(pos_y), int(base_env.agent_dir))
        action = self.pi_opt[idx]

        return action, torch.tensor(0.0)


class CircularPolicy(Policy):
    """Policy generating circular data.

    NOTE(review): placeholder — no action-selection logic is implemented yet.
    """

    def __init__(self): ...

    def get_eval_action(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Not implemented yet.

        Raises:
            NotImplementedError: Always, until the circular policy is written.
        """
        # Fail loudly instead of implicitly returning ``None``, which would only
        # surface later as a confusing unpacking error in the caller.
        raise NotImplementedError(f"{type(self).__name__}.get_eval_action is not implemented yet")


class RandomPolicy(Policy):
    """Baseline policy that samples actions uniformly at random."""

    def __init__(self, envs: gymnasium.vector.SyncVectorEnv):
        # Number of actions of a single sub-environment; assumes a space that
        # exposes ``.n`` (e.g. Discrete).
        self.n_actions = envs.single_action_space.n

    def get_eval_action(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Draw one action index in ``[0, n_actions)``; ``x`` is ignored.

        Returns:
            A shape-``(1,)`` integer tensor holding the sampled action, and a
            scalar ``0.0`` tensor.
        """
        upper = self.n_actions
        sampled_action = torch.randint(0, upper, (1,))
        placeholder = torch.tensor(0.0)
        return sampled_action, placeholder
