"""
Policy training and generation functions.

This module defines the policy classes used here (random, epsilon-greedy,
and tabular) plus functions for training tabular behavioral cloning
policies and for generating random or exhaustive sets of tabular policies.
"""

import time
import numpy as np
import torch
from abc import ABC, abstractmethod
import gymnasium as gym
import itertools


class BasePolicy(ABC):
    """
    Abstract interface shared by the policies in this module.

    Concrete subclasses must override :meth:`get_action`; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def get_action(self, obs, deterministic=True):
        """
        Choose an action for ``obs``.

        Args:
            obs: Observation (tabular subclasses use a state index here).
            deterministic: Request deterministic behaviour; each subclass
                defines what that means (e.g. argmax vs. sampling).
        """
        raise NotImplementedError

    # NOTE(review): the disabled rollout evaluator below uses the old gym
    # API (``obs = env.reset()``, 4-tuple ``env.step``); gymnasium's reset
    # returns ``(obs, info)`` and step returns a 5-tuple, so it would need
    # updating before being re-enabled.
    # def evaluate(self, env, N=10, rollout=True):
    #     """
    #     Returns average trajectory reward over N rollouts
    #     """
    #     if not rollout:
    #         print("Warning: Rolling out policy despite rollout=False")
    #     res = 0
    #     for _ in range(N):
    #         obs = env.reset()
    #         done = False
    #         while not done:
    #             a = self.get_action(obs)
    #             obs, reward, done, _ = env.step(a)
    #             res += reward
    #     return res / N


class RandomPolicy(BasePolicy):
    """
    Policy that ignores its input and returns ``action_space.sample()``.

    Args:
        action_space: Space whose ``sample()`` method supplies every action.
    """

    def __init__(self, action_space: gym.Space):
        self.action_space = action_space

    def get_action(self, obs, deterministic=False):
        # Neither ``obs`` nor ``deterministic`` influences the result: even a
        # "deterministic" request still draws a fresh sample from the space.
        space = self.action_space
        return space.sample()


class EpsGreedyPolicy(BasePolicy):
    """
    Epsilon-greedy wrapper around another policy.

    On each stochastic call a uniform draw in [0, 1) is compared with
    ``eps``: if the draw exceeds ``eps`` the wrapped greedy policy acts
    (always in deterministic mode), otherwise a random action is sampled
    from ``action_space``. With ``deterministic=True`` no draw is made and
    the greedy policy always acts.

    Args:
        greedy_policy: BasePolicy consulted for exploitation steps
        eps: Exploration probability
        action_space: Space sampled for exploration steps
    """

    def __init__(self, greedy_policy: BasePolicy, eps: float, action_space: gym.Space):
        self.greedy = greedy_policy
        self.eps = eps
        self.action_space = action_space

    def get_action(self, obs, deterministic=False):
        # ``or`` short-circuits, so deterministic calls never consume RNG.
        use_greedy = deterministic or np.random.random() > self.eps
        if not use_greedy:
            return self.action_space.sample()
        return self.greedy.get_action(obs, deterministic=True)


class TabularPolicy(BasePolicy):
    """
    Policy backed by an explicit state-by-action table.

    Row ``s`` of ``matrix`` holds the action weights for state index ``s``:
    one-hot rows for deterministic policies, probability rows otherwise.

    Args:
        policy: 2-D array-like table; it is copied on construction so later
            mutation of the caller's array does not affect this policy.
    """

    def __init__(self, policy: np.ndarray):
        self.matrix = np.copy(policy)

    def get_action(self, state, deterministic=True):
        """
        Return an action index for ``state``.

        Args:
            state: Row index into ``matrix``.
            deterministic: If True, return the row's argmax; otherwise sample
                an action using the row as probabilities (it must sum to 1).
        """
        if deterministic:
            return np.argmax(self.matrix[state, :])
        else:
            return np.random.choice(range(self.matrix.shape[1]), p=self.matrix[state, :])

    # def evaluate(self, env, N=1, rollout=False):
    #     assert env.observation_type == "state"
    #     if rollout:
    #         return super().evaluate(env, N)
    #     else:
    #         return env.evaluate_policy(self)

    def __eq__(self, other):
        """
        Two policies are equal when their tables have the same shape and
        identical entries.

        Any object exposing a ``matrix`` attribute can be compared, so
        duck-typed tabular wrappers keep working; other objects defer to
        Python's default comparison instead of raising ``AttributeError``.
        """
        other_matrix = getattr(other, "matrix", None)
        if other_matrix is None:
            return NotImplemented
        # Compare shapes first: a plain ``==`` would broadcast, e.g. a
        # (1, A) table could compare equal to an (S, A) one.
        if tuple(np.shape(self.matrix)) != tuple(np.shape(other_matrix)):
            return False
        return bool(np.all(self.matrix == other_matrix))

    # The table is mutable, so instances are deliberately unhashable
    # (defining ``__eq__`` already implies this; stated for clarity).
    __hash__ = None


def train_tabular_BC_policy(
    offline_trajs,
    N_states,
    N_actions,
    init="random",
    n_epochs=10,
    lr=0.01,
    make_deterministic=True,
    verbose=(),
):
    """
    Train a tabular behavioral cloning policy from offline trajectories.

    The policy is parameterized by an unconstrained (N_states, N_actions)
    logit table whose row-wise softmax gives the action probabilities.
    Each epoch takes one full-batch Adam step on the summed negative
    log-likelihood of all (state, action) pairs in the data. Optimizing
    logits keeps every row a valid distribution without the distorting
    "softmax of probabilities" re-normalization, which pulled rows toward
    uniform every epoch.

    Args:
        offline_trajs: Iterable of flat trajectories laid out as
            ``[s0, a0, x0, s1, a1, x1, ...]``. States are read from positions
            0, 3, 6, ... and actions from 1, 4, 7, ...; every third element
            (presumably the reward — confirm with the data producer) is
            ignored. States and actions must be int-convertible indices.
        N_states: Number of states
        N_actions: Number of actions
        init: "random" (softmax of uniform[0, 1) logits) or "uniform"
        n_epochs: Number of training epochs (full-batch gradient steps)
        lr: Adam learning rate
        make_deterministic: If True, replace each row by a one-hot vector on
            its most probable action after training
        verbose: Container of flags. "losses" or "full" prints the loss every
            10 epochs; "full" also prints the final table.

    Returns:
        TabularPolicy whose ``matrix`` is a float32 NumPy array of shape
        (N_states, N_actions).

    Raises:
        ValueError: If ``init`` is not "random" or "uniform".
    """
    if init == "random":
        # softmax(logits) equals the old init's softmax(torch.rand(...)) table.
        initial_logits = torch.rand((N_states, N_actions), dtype=torch.float32)
    elif init == "uniform":
        initial_logits = torch.zeros((N_states, N_actions), dtype=torch.float32)
    else:
        raise ValueError(f"init {init} not supported, use 'random' or 'uniform'")
    logits = torch.nn.Parameter(initial_logits)
    optimizer = torch.optim.Adam([logits], lr=lr)

    # The (state, action) pairs never change, so gather them once. ``zip``
    # drops a trailing state that has no matching action.
    state_list, action_list = [], []
    for traj in offline_trajs:
        for state, action in zip(traj[::3], traj[1::3]):
            state_list.append(int(state))
            action_list.append(int(action))
    state_idx = torch.tensor(state_list, dtype=torch.long)
    action_idx = torch.tensor(action_list, dtype=torch.long)

    for epoch in range(n_epochs):
        optimizer.zero_grad()
        log_probs = torch.nn.functional.log_softmax(logits, dim=1)
        # Summed NLL over the whole dataset; an empty dataset gives 0 and a
        # zero gradient, so Adam leaves the logits unchanged.
        total_loss = -log_probs[state_idx, action_idx].sum()
        total_loss.backward()
        optimizer.step()

        if ("full" in verbose or "losses" in verbose) and epoch % 10 == 0:
            print(f"  Epoch {epoch}, Loss: {total_loss.item():.4f}")

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)
        if make_deterministic:
            best_actions = torch.argmax(probs, dim=1)
            probs = torch.nn.functional.one_hot(best_actions, num_classes=N_actions).to(
                torch.float32
            )
    policy = TabularPolicy(probs.detach().cpu().numpy())

    if "full" in verbose:
        print(f"Trained policy matrix in {n_epochs} epochs:")
        print(policy.matrix)

    return policy


def generate_random_tabular_policies(
    N_states,  # number of states
    N_actions,  # number of actions
    N_policies=100,  # number of policies to generate
    make_deterministic=True,  # whether to make policies deterministic
    dtype=np.int32,
):
    """
    Sample ``N_policies`` random tabular policies, one at a time.

    Deterministic policies get a one-hot row per state, choosing the action
    with ``np.random.choice(N_actions)`` and storing the table as ``dtype``.
    Stochastic policies get ``np.random.rand`` rows normalized to sum to 1
    (``dtype`` is not applied to them).

    Returns:
        List of TabularPolicy objects
    """

    def _sample_matrix():
        if not make_deterministic:
            weights = np.random.rand(N_states, N_actions)
            return weights / weights.sum(axis=1, keepdims=True)
        table = np.zeros((N_states, N_actions), dtype=dtype)
        for row in range(N_states):
            table[row, np.random.choice(N_actions)] = 1
        return table

    return [TabularPolicy(_sample_matrix()) for _ in range(N_policies)]


def generate_random_tabular_policies_vectorized(
    N_states,
    N_actions,
    N_policies=100,
    make_deterministic=True,
    dtype=np.int32,
):
    """
    Generate random tabular policies using vectorized NumPy operations.

    Deterministic policies pick one action per state with
    ``np.random.randint`` and are stored as one-hot tables of ``dtype``.
    Stochastic policies use ``np.random.rand`` rows normalized to sum to 1
    (``dtype`` is not applied to them).

    Returns:
        List of TabularPolicy objects
    """
    if make_deterministic:
        # Shape: (N_policies, N_states) -- one action per state per policy.
        actions = np.random.randint(0, N_actions, size=(N_policies, N_states))
        # Row ``a`` of the identity is the one-hot vector for action ``a``, so
        # fancy-indexing with the whole action array builds every table at once
        # (shape (N_policies, N_states, N_actions)) with no Python-level loop.
        policy_matrices = np.eye(N_actions, dtype=dtype)[actions]
    else:
        # Random probabilities, normalized along the action axis.
        policy_matrices = np.random.rand(N_policies, N_states, N_actions)
        policy_matrices /= policy_matrices.sum(axis=2, keepdims=True)
    return [TabularPolicy(matrix) for matrix in policy_matrices]


def generate_random_tabular_policies_torch(
    N_states, N_actions, N_policies=100, make_deterministic=True, device="cpu", dtype=None
):
    """
    Generate random tabular policies with PyTorch tensor ops on ``device``.

    Deterministic policies pick one action per state with ``torch.randint``
    and become one-hot tables; stochastic policies use ``torch.rand`` rows
    normalized to sum to 1. Tables are float32 NumPy arrays unless
    ``dtype`` is given, in which case deterministic tables are cast to it
    (mirroring the NumPy generators, which apply ``dtype`` only to one-hot
    tables).

    Args:
        N_states: Number of states
        N_actions: Number of actions
        N_policies: Number of policies to generate
        make_deterministic: Whether to generate one-hot (deterministic) tables
        device: Torch device used for sampling
        dtype: Optional NumPy dtype for deterministic tables

    Returns:
        List of TabularPolicy objects
    """
    if make_deterministic:
        # For each policy, for each state, randomly select one action.
        actions = torch.randint(0, N_actions, (N_policies, N_states), device=device)
        policy_matrices = torch.zeros((N_policies, N_states, N_actions), device=device)
        policy_matrices.scatter_(2, actions.unsqueeze(-1), 1)
    else:
        policy_matrices = torch.rand((N_policies, N_states, N_actions), device=device)
        policy_matrices = policy_matrices / policy_matrices.sum(dim=2, keepdim=True)

    # One device->host transfer for the whole batch instead of one per policy.
    matrices = policy_matrices.cpu().numpy()
    if make_deterministic and dtype is not None:
        matrices = matrices.astype(dtype)
    return [TabularPolicy(matrix) for matrix in matrices]


def generate_all_deterministic_stationary_policies(N_states, N_actions, dtype=np.int32):
    """
    Enumerate every deterministic stationary policy for a tabular problem.

    Each policy assigns exactly one action to every state, giving
    ``N_actions ** N_states`` policies. That count grows exponentially, so
    this is only practical for very small problems. Policies follow
    ``itertools.product`` order: the last state's action varies fastest.

    Args:
        N_states: Number of states
        N_actions: Number of actions
        dtype: dtype of the one-hot policy tables

    Returns:
        List of TabularPolicy objects with one-hot (N_states, N_actions) tables
    """
    # Row ``a`` of the identity is the one-hot encoding of action ``a``.
    one_hot = np.eye(N_actions, dtype=dtype)
    return [
        TabularPolicy(one_hot[np.asarray(actions, dtype=np.intp)])
        for actions in itertools.product(range(N_actions), repeat=N_states)
    ]


# class SB3PolicyTabularWrapper:
#     """
#     Wrapper for Stable Baselines3 policies to work with tabular interface.

#     TODO: Move from preferences_offlineRL/models/mlp_policies.py
#     """

#     def __init__(self, sb3_model, minigrid_env_instance):
#         """
#         Initialize wrapper for SB3 policy.

#         Args:
#             sb3_model: Stable Baselines3 model
#             minigrid_env_instance: MiniGrid environment instance
#         """
#         print("TODO: Initialize SB3PolicyTabularWrapper")

#         self.sb3_model = sb3_model
#         self.minigrid_env = minigrid_env_instance

#         # Extract dimensions
#         if hasattr(minigrid_env_instance, "env"):
#             raw_env = minigrid_env_instance.env
#             self.grid_width = raw_env.width
#             self.grid_height = raw_env.height
#         else:
#             self.grid_width = 4  # Default
#             self.grid_height = 4  # Default

#         self.num_directions = 4
#         self.N_states = self.grid_width * self.grid_height * self.num_directions
#         self.N_actions = minigrid_env_instance.action_space.n

#         self._matrix = None

#     @property
#     def matrix(self):
#         """Get tabular policy matrix."""
#         if self._matrix is None:
#             self._matrix = self._build_matrix()
#         return self._matrix

#     def _build_matrix(self):
#         """Build policy matrix from SB3 model."""
#         print(f"TODO: Build policy matrix ({self.N_states}x{self.N_actions}) from SB3 model")

#         # Placeholder: return uniform random policy
#         policy_matrix = np.ones((self.N_states, self.N_actions)) / self.N_actions
#         return policy_matrix

#     def get_action_distribution(self, state_obs_or_idx):
#         """Get action distribution for given state."""
#         if isinstance(state_obs_or_idx, int):
#             return self.matrix[state_obs_or_idx]
#         else:
#             # Handle observation input
#             return np.ones(self.N_actions) / self.N_actions  # Placeholder

#     def predict(self, observation, state=None, episode_start=None, deterministic=True):
#         """Mimic SB3 predict interface."""
#         if isinstance(observation, int) and deterministic:
#             action = np.argmax(self.matrix[observation])
#             return action, state
#         elif isinstance(observation, int) and not deterministic:
#             action = np.random.choice(self.N_actions, p=self.matrix[observation])
#             return action, state
#         else:
#             return self.sb3_model.predict(
#                 observation, state=state, episode_start=episode_start, deterministic=deterministic
#             )

if __name__ == "__main__":
    # Quick benchmark of the three random-policy generators on the same sizes.
    N_states = 4
    N_actions = 4
    N_policies = int(1e5)
    make_deterministic = True

    benchmarks = [
        (
            "Default version",
            lambda: generate_random_tabular_policies(
                N_states, N_actions, N_policies, make_deterministic
            ),
        ),
        (
            "Vectorized version",
            lambda: generate_random_tabular_policies_vectorized(
                N_states, N_actions, N_policies, make_deterministic
            ),
        ),
        (
            "Torch version",
            lambda: generate_random_tabular_policies_torch(
                N_states, N_actions, N_policies, make_deterministic, device="cpu"
            ),
        ),
    ]
    for label, run in benchmarks:
        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock and can jump (NTP, DST), which corrupts duration measurements.
        start = time.perf_counter()
        _ = run()
        print(f"{label}: time taken: {time.perf_counter() - start:.2f} seconds")
