## Frequency-control environment
import numpy as np


class Frequency:
    """Discrete-time frequency-control environment.

    A state is a row vector of length ``2 * dim_action`` (or a batch of such
    rows). The first ``dim_action`` entries are the angle states ``delta``
    (picked out by ``select_delta``). The last ``dim_action`` entries are the
    frequency states ``w`` (picked out by ``select_w``).

    ``step`` advances every row by one explicit-Euler step of size
    ``delta_t``, using only the previous state on the right-hand side::

        delta' = delta + delta_t * omega_scale * w
        w'     = w + delta_t / M * (Pm - D * w - action - coupling)
        coupling_i = sum_j F[i, j] * sin(delta_i - delta_j)

    Args:
        Pm: Power input used only to build ``state_transfer3``. ``step``
            takes its own ``Pm`` argument. Accepts a 1-D array of length
            ``dim_action`` or a ``(1, dim_action)`` array.
        M: Per-node coefficients, length ``dim_action``; every
            ``delta_t``-scaled input to ``w`` is divided by ``M``
            (presumably inertia -- confirm).
        D: Per-node coefficients, length ``dim_action``, multiplying ``w``
            (presumably damping -- confirm).
        F: Coupling coefficients, broadcast against a
            ``(dim_action, dim_action)`` matrix.
        delta_t: Integration step size.
        max_action: Stored only; not used by this class.
        dim_action: Number of nodes, i.e. the length of ``action``.
        Penalty_action: Stored only; not used by this class.
    """

    def __init__(self, Pm, M, D, F, delta_t, max_action, dim_action, Penalty_action):
        self.param_gamma = 1  # weight of the sum(w**2) loss returned by step()
        self.M = M
        self.D = D
        self.Pm = Pm
        self.max_action = max_action
        self.dim_action = dim_action
        self.omega_scale = 2 * np.pi  # converts w into an angle increment per unit time
        self.state = []  # populated by reset() or set_state()
        self.delta_t = delta_t
        self.Penalty_action = Penalty_action

        # Accept array-likes. Use 1.0 / M rather than M ** (-1): NumPy rejects
        # negative integer powers of integer arrays, and the result is the same
        # (including dtype) for float inputs.
        M_arr = np.asarray(M)
        inv_M = 1.0 / M_arr
        eye = np.identity(dim_action, dtype=np.float32)
        zeros = np.zeros((dim_action, dim_action), dtype=np.float32)

        # Linear part of the update, applied as state @ state_transfer1:
        #   delta' = delta + delta_t * omega_scale * w
        #   w'     = (I - delta_t * diag(D / M)) w
        self.state_transfer1 = np.vstack((
            np.hstack((eye, zeros)),
            np.hstack((delta_t * self.omega_scale * eye,
                       eye - delta_t * np.diag(np.asarray(D) / M_arr))),
        ))

        # state_transferF[i, j] = -delta_t * F[i, j] / M[i]
        self.state_transferF = -delta_t * (inv_M.reshape(dim_action, 1)
                                           @ np.ones((1, dim_action), dtype=np.float32)) * F
        # Maps a per-node vector onto the w half of the state.
        self.state_transfer2 = np.hstack((zeros, eye))

        # Not used by step(); kept for callers that read it. atleast_2d lets a
        # 1-D Pm through (hstack with the 2-D zeros row rejects 1-D input).
        self.state_transfer3 = np.hstack((np.zeros((1, dim_action), dtype=np.float32),
                                          delta_t * np.atleast_2d(Pm) * inv_M))
        # Pm and action enter only the w half, scaled by +/- delta_t / M.
        self.state_transfer3_Pm = np.hstack((zeros, delta_t * np.diag(inv_M)))
        self.state_transfer4 = np.hstack((zeros, -delta_t * np.diag(inv_M)))

        self.select_add_w = np.vstack((np.zeros((dim_action, 1), dtype=np.float32),
                                       np.ones((dim_action, 1), dtype=np.float32)))
        self.select_w = np.vstack((np.zeros((dim_action, dim_action), dtype=np.float32),
                                   np.identity(dim_action, dtype=np.float32)))
        self.select_delta = np.vstack((np.identity(dim_action, dtype=np.float32),
                                       np.zeros((dim_action, dim_action), dtype=np.float32)))

    def step(self, action, Pm):
        """Advance the stored state by one time step.

        Args:
            action: Control input, shape ``(dim_action,)`` or
                ``(batch, dim_action)``. Subtracted from ``w`` after scaling
                by ``delta_t / M``.
            Pm: Power input with the same shape conventions as ``action``.
                Added to ``w`` after scaling by ``delta_t / M``.

        Returns:
            ``(state, loss)``. ``state`` is the new state, also stored in
            ``self.state``. ``loss`` is ``param_gamma * sum(w' ** 2)`` per
            row, with shape ``(batch, 1)`` for a 2-D state.
        """
        state = self.state
        delta = state @ self.select_delta
        # phase_diff[..., i, j] = delta_i - delta_j. Broadcasting works for any
        # number of rows; the former outer product with a ones row only worked
        # for a single-row state.
        phase_diff = delta[..., :, None] - delta[..., None, :]
        coupling = np.sum(np.sin(phase_diff) * self.state_transferF, axis=-1)
        self.state = (state @ self.state_transfer1
                      + coupling @ self.state_transfer2
                      + Pm @ self.state_transfer3_Pm
                      + action @ self.state_transfer4)
        loss = self.param_gamma * pow(self.state, 2) @ self.select_add_w
        return self.state, loss

    def set_state(self, state_input):
        """Replace the stored state with ``state_input``. The input is stored as-is, not copied."""
        self.state = state_input

    def reset(self, rng=None):
        """Draw, store and return a random float32 state of shape ``(1, 2 * dim_action)``.

        Angle states are drawn uniformly from [0.0, 0.3). Frequency states are
        drawn uniformly from [-0.03, 0.03).

        Args:
            rng: Optional ``numpy.random.Generator`` (or ``RandomState``) for
                reproducible resets. Defaults to the global ``np.random``
                state.
        """
        if rng is None:
            rng = np.random
        initial_state1 = rng.uniform(0.0, 0.3, (1, self.dim_action))
        initial_state2 = rng.uniform(-0.03, 0.03, (1, self.dim_action))
        self.state = np.hstack((initial_state1, initial_state2)).astype(np.float32)
        return self.state