from typing import List, Tuple, Union

import gymnasium as gym
import numpy as np
import torch

from src.utils import StateNormalizer, Trajectory

from typing_extensions import Self


class RFF:
    """Random Fourier Features for embedding state-action pairs, trajectories, or policies.

    This class implements Random Fourier Features to approximate a Gaussian kernel
    k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)), where sigma is the kernel width.
    The RFF embedding maps input points to a finite-dimensional space such that the
    inner product of two embeddings approximates the kernel evaluation:

        phi(x) = sqrt(2 / dim) * cos(W x + b),  W ~ N(0, 1 / sigma^2),  b ~ U[0, 2 * pi)

    The kernel width (sigma) determines the "similarity sensitivity" of the embedding:
    - Smaller values make the kernel more local, considering points similar only if
      they are very close.
    - Larger values make the kernel more global, considering points similar even when
      they are farther apart.

    How to choose sigma? A good rule of thumb is the mean/median distance between
    (s, a) vectors. If (s, a) are n-dimensional vectors whose elements are i.i.d.
    standard Gaussians, the mean distance is roughly sqrt(2n).
    If a falsy kernel width (None or 0) is passed, sqrt(state_dim + action_dim) is
    used, i.e. sqrt(n), which is a factor sqrt(2) smaller than that rule of thumb.

    Args:
        dim: Number of random features (embedding dimension).
        state_dim: Dimension of a single state vector.
        action_dim: Dimension of a single action vector.
        kernel_width: Gaussian kernel width sigma, or None for the default above.
        normalize: If True, states are passed through a StateNormalizer before
            being embedded.
        device: Device for the random projection. None selects CUDA when it is
            available and CPU otherwise; an explicit value is used as given.
        gamma: Discount factor used when summing embeddings along a trajectory.

    Attributes:
        training: While True, the normalizer statistics are updated on every
            embedding call. Set it to False to freeze them.
    """

    def __init__(
        self,
        dim: int,
        state_dim: int,
        action_dim: int,
        kernel_width: Union[float, None],
        normalize: bool = True,
        device=None,
        gamma: float = 0.99,
    ):
        self.dim = dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        # Falsy widths (None or 0) fall back to sqrt(state_dim + action_dim).
        self.kernel_width = (
            kernel_width if kernel_width else float(np.sqrt(state_dim + action_dim))
        )
        # Honour an explicitly requested device; auto-detect only when none is given.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.normalize = normalize
        # NOTE(review): the normalizer is always created on "cpu" and is never moved
        # by `to()`; confirm it copes with states that live on another device.
        self.normalizer = StateNormalizer(state_dim, "cpu") if normalize else None
        self.gamma = gamma
        self.training = True  # Training mode flag

        # Initialize random matrix for RFF without gradients
        with torch.no_grad():
            self.W = (
                torch.randn(self.dim, state_dim + action_dim, device=self.device)
                / self.kernel_width
            )
            self.b = 2 * torch.pi * torch.rand(self.dim, device=self.device)

    # ------------------------------------------------------------------ helpers

    def _normalize_states(self, states: torch.Tensor) -> torch.Tensor:
        """Normalize states (and update running statistics while training)."""
        if not self.normalize:
            return states
        states = self.normalizer.normalize(states)
        if self.training:
            # NOTE(review): statistics are updated with the already-normalized
            # states (kept as in the original code); confirm this is intended
            # rather than updating with the raw states.
            self.normalizer.update_stats(states)
        return states

    def _features(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Map (N, state_dim) states and (N, action_dim) actions to (N, dim) features."""
        # Match the projection's device and dtype: NumPy data is float64 by default
        # and would otherwise fail in the matmul against the float32 weights.
        target = {"device": self.W.device, "dtype": self.W.dtype}
        sa = torch.cat([states.to(**target), actions.to(**target)], dim=-1)
        scale = (2.0 / self.dim) ** 0.5
        return scale * torch.cos(sa @ self.W.T + self.b)

    def _discounted_sum(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Collapse (T, dim) step embeddings into (1 - gamma) * sum_t gamma^t * e_t."""
        steps = torch.arange(
            embeddings.shape[0], dtype=embeddings.dtype, device=embeddings.device
        )
        weights = self.gamma**steps
        return (1 - self.gamma) * (weights @ embeddings)

    @staticmethod
    def _unpack_trajectory(traj):
        """Return (states, actions) from a Trajectory-like object or a 2-tuple."""
        if hasattr(traj, "states") and hasattr(traj, "actions"):
            return traj.states, traj.actions
        return traj[0], traj[1]

    @staticmethod
    def _to_tensor(values) -> torch.Tensor:
        """Convert array-like data to a tensor, passing existing tensors through."""
        if isinstance(values, torch.Tensor):
            return values
        return torch.tensor(np.asarray(values))

    def _embed_each_trajectory(self, trajectories) -> torch.Tensor:
        """Embed every trajectory in one batched pass.

        Returns:
            Tensor of shape (N, dim) on the projection's device, one discounted
            embedding per trajectory, in input order.

        Raises:
            ValueError: If `trajectories` is empty.
        """
        if not trajectories:
            raise ValueError("At least one trajectory is required.")

        state_chunks, action_chunks = [], []
        for traj in trajectories:
            states, actions = self._unpack_trajectory(traj)
            state_chunks.append(self._to_tensor(states))
            action_chunks.append(self._to_tensor(actions))
        traj_lens = [chunk.shape[0] for chunk in state_chunks]

        # Concatenate all steps so normalization and projection run once.
        all_states = self._normalize_states(torch.cat(state_chunks))
        all_actions = torch.cat(action_chunks)
        if all_actions.dim() == 1:
            # Scalar actions (e.g. discrete action ids) become a column vector.
            all_actions = all_actions.unsqueeze(-1)

        features = self._features(all_states, all_actions)
        per_trajectory = [
            self._discounted_sum(chunk) for chunk in torch.split(features, traj_lens)
        ]
        return torch.stack(per_trajectory)

    # --------------------------------------------------------------- public API

    @torch.no_grad()
    def embed_state_action(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Embed state-action pairs.

        Args:
            state: States tensor of shape (batch_size, state_dim)
            action: Actions tensor of shape (batch_size, action_dim)

        Returns:
            Embedding tensor of shape (batch_size, dim), on the projection's device.

        Raises:
            ValueError: If `state` or `action` is not 2-dimensional.
        """
        if state.dim() != 2 or action.dim() != 2:
            raise ValueError(
                "state and action must both be 2-D (batch_size, feature_dim) tensors, "
                f"got shapes {tuple(state.shape)} and {tuple(action.shape)}"
            )
        state = self._normalize_states(state)
        return self._features(state, action)

    @torch.no_grad()
    def embed_trajectory(
        self,
        states: Union[np.ndarray, torch.Tensor],
        actions: Union[np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """Embed a single trajectory using discounted sum of state-action embeddings.

        Args:
            states: States array/tensor of shape (T, state_dim)
            actions: Actions array/tensor of shape (T, action_dim); a 1-D array of
                length T is treated as shape (T, 1)

        Returns:
            Trajectory embedding tensor of shape (dim,), on the projection's device.
        """
        # Convert to tensors if necessary (copy so the source array is never aliased)
        if isinstance(states, np.ndarray):
            states = torch.from_numpy(states.copy()).to(self.device)
        if isinstance(actions, np.ndarray):
            actions = torch.from_numpy(actions.copy()).to(self.device)
        if actions.dim() == 1:
            actions = actions.unsqueeze(-1)

        # Embed all state-action pairs, then discount and sum over time
        embeddings = self.embed_state_action(states, actions)  # [T, dim]
        return self._discounted_sum(embeddings)

    @torch.no_grad()
    def embed_trajectories(self, trajectories: List[Trajectory]) -> torch.Tensor:
        """Embed multiple trajectories and return their average embedding.

        Args:
            trajectories: List of Trajectory(states, actions, rewards) objects, where:
                        - trajectory.states is an array of shape (T, state_dim)
                        - trajectory.actions is an array of shape (T, action_dim)
                        T may vary across trajectories. Plain (states, actions)
                        tuples are accepted as well.

        Returns:
            Average embedding tensor of shape (dim,), on the CPU.

        Raises:
            ValueError: If `trajectories` is empty.
        """
        return self._embed_each_trajectory(trajectories).mean(0).cpu()

    @torch.no_grad()
    def embed_individual_trajectories(
        self, trajectories: List[Trajectory]
    ) -> torch.Tensor:
        """Embed multiple trajectories and return one embedding per trajectory.

        Args:
            trajectories: List of Trajectory(states, actions, rewards) objects, where:
                        - trajectory.states is an array of shape (T, state_dim)
                        - trajectory.actions is an array of shape (T, action_dim)
                        T may vary across trajectories. Plain (states, actions)
                        tuples are accepted as well.

        Returns:
            Embedding tensor of shape (len(trajectories), dim), on the CPU.

        Raises:
            ValueError: If `trajectories` is empty.
        """
        return self._embed_each_trajectory(trajectories).cpu()

    @torch.no_grad()
    def embed_policy(
        self, policy: torch.nn.Module, env: gym.Env, n_trajectories: int = 10
    ) -> torch.Tensor:
        """Embed a policy by averaging embeddings of trajectories collected using the policy.

        Args:
            policy: Policy object with an act(observation) method that returns actions
            env: Gym-like environment object with reset() and step() methods
            n_trajectories: Number of trajectories to collect and average over

        Returns:
            Policy embedding tensor of shape (dim,), on the CPU.

        Raises:
            ValueError: If `n_trajectories` is 0 (nothing to average).
        """
        trajectories = []

        for _ in range(n_trajectories):
            states, actions = [], []
            state, _ = env.reset()
            done = False

            while not done:
                action = policy.act(state)
                states.append(state)
                actions.append(action)
                state, _, term, trunc, _ = env.step(action)
                done = term or trunc

            # (states, actions) tuples are understood by embed_trajectories.
            trajectories.append((np.array(states), np.array(actions)))

        return self.embed_trajectories(trajectories)

    def to(self, device) -> Self:
        """Move the random projection to `device` and return self.

        Only `W` and `b` are moved; the state normalizer is left untouched.
        """
        self.W = self.W.to(device)
        self.b = self.b.to(device)
        # Record where the weights actually ended up.
        self.device = self.W.device
        return self