import numpy as np
from solver.policy.array_policy import ArrayPolicy


class DirichletArrayPolicy(ArrayPolicy):
    """
    Implements a finite action space feedback policy with Dirichlet initialization.

    Every row ``policy_array[t, s]`` (for ``t < time_steps`` and
    ``s < state_space[-1].n``) is an independent draw from a Dirichlet
    distribution over the ``action_space.n`` actions, so each row is a
    probability vector (non-negative entries summing to 1).
    """

    def __init__(self,
                 time_steps,
                 state_space,
                 action_space,
                 alpha=1.0,
                 rng=None):
        """
        Initialize the policy with samples from a Dirichlet distribution.

        Args:
            time_steps: Number of time steps
            state_space: Sequence of state spaces; the number of states is
                taken from ``state_space[-1].n``
            action_space: Action space; ``action_space.n`` is the number of
                actions
            alpha: Concentration parameter for the Dirichlet distribution.
                Either a scalar (symmetric Dirichlet; alpha=1.0 corresponds
                to a uniform distribution over the probability simplex) or an
                array-like of length ``action_space.n`` giving one
                concentration per action.
            rng: Optional ``numpy.random.Generator`` or
                ``numpy.random.RandomState`` used for sampling. When None
                (the default), the global ``numpy.random`` state is used.

        Raises:
            ValueError: If any concentration value is not strictly positive
                (or is NaN), or if ``alpha`` cannot be broadcast to
                ``(action_space.n,)``.
        """
        super().__init__(time_steps, state_space, action_space)
        # Kept so that copy() can carry the configuration over.
        self.alpha = alpha

        n_states = state_space[-1].n
        # A scalar alpha becomes a symmetric vector; a per-action vector is
        # passed through unchanged (a wrong length fails to broadcast).
        alpha_vec = np.ones(action_space.n) * alpha
        # Validate up front so an invalid alpha is reported even when there is
        # nothing to sample (time_steps == 0 or no states). `not (x > 0)` also
        # rejects NaN.
        if not np.all(alpha_vec > 0):
            raise ValueError(f"alpha must be strictly positive, got {alpha!r}")

        sampler = np.random if rng is None else rng
        # One vectorized call instead of a Python loop over (t, s); the result
        # has shape (time_steps, n_states, action_space.n).
        samples = sampler.dirichlet(alpha_vec, size=(time_steps, n_states))
        # Override the parent's uniform initialization with the samples.
        self.policy_array[:time_steps, :n_states] = samples

    def copy(self):
        """
        Create a copy of this policy with an independent policy array.

        The copy shares ``state_space`` and ``action_space`` with this policy
        but owns a separate copy of ``policy_array``, so mutating one policy's
        array does not affect the other. No Dirichlet samples are drawn, so
        copying neither wastes work nor advances the random number state.

        Returns:
            A new DirichletArrayPolicy instance with the same policy array
        """
        new_policy = DirichletArrayPolicy.__new__(DirichletArrayPolicy)
        # Run only the base initializer: the Dirichlet sampling done by our
        # own __init__ would be overwritten immediately below anyway.
        ArrayPolicy.__init__(
            new_policy,
            self.policy_array.shape[0],  # time_steps
            self.state_space,
            self.action_space,
        )
        new_policy.alpha = self.alpha
        new_policy.policy_array = self.policy_array.copy()
        return new_policy
