import numpy as np


class Interception:
    """Two-player differential game advanced with a fixed-step RK4 integrator.

    The full state has 11 entries.  ``state[0]`` is a clock (its derivative is
    the constant 1) and ``x = state[1:]`` obeys the linear dynamics

        dx/dt = A x + B u + C v

    where ``u`` and ``v`` are points on two ellipses, each selected by a single
    angle (the players' scalar actions).  ``A`` is block diagonal with two
    identical 5x5 blocks; within one block, writing ``z`` for its five entries:

        z0' = z1,  z1' = z2,  z2' = -z2 + u_i,  z3' = z4,  z4' = v_i

    The episode ends once the clock reaches ``terminal_time`` (within half a
    ``dt``); the only non-zero reward is paid on that final step and equals the
    Euclidean distance between ``(state[1], state[6])`` and
    ``(state[4], state[9])`` -- presumably the positions of the pursuer
    (u-player) and the evader (v-player); confirm against the problem
    statement.

    Args:
        alpha: Stored on the instance as ``self.alpha``.
            NOTE(review): no method of this class reads it; the ``-1`` entries
            of ``A`` are hard-coded.  Confirm whether ``alpha`` was meant to
            scale them.
        initial_state: Sequence of 11 numbers ``[t0, x_0, ..., x_9]``.  It is
            copied into a float array, so later changes to the argument do not
            affect the environment.  ``None`` selects the built-in default.
        terminal_time: Clock value at which the episode terminates.
        dt: Duration of one call to :meth:`step`.
        inner_step_n: Number of RK4 sub-steps taken per :meth:`step`.

    Raises:
        ValueError: If ``initial_state`` does not have exactly 11 entries.
    """

    # Default start.  Kept as an immutable tuple so it cannot be shared and
    # mutated across instances.
    # NOTE(review): the default clock value (4) already exceeds the default
    # terminal_time (3), so with all defaults the very first step() reports
    # done=True -- confirm that this is intended.
    _DEFAULT_INITIAL_STATE = (4, 1, 0, 1, 0, 1, 1.1, 1, -2, 0, 0)

    def __init__(self, alpha=1, initial_state=None,
                 terminal_time=3, dt=0.2, inner_step_n=10):
        self.state_dim = 11
        self.u_action_dim = 1
        self.v_action_dim = 1

        # Each action is one angle; the two players use disjoint ranges.
        self.u_action_min = np.array([0])
        self.u_action_max = np.array([2 * np.pi])
        self.v_action_min = np.array([-2 * np.pi])
        self.v_action_max = np.array([0])

        # Semi-axes of the ellipses the controls u and v are drawn from.
        self.u_radius = np.array([0.67, 1]) * 1.3
        self.v_radius = np.array([0.71, 1])

        if initial_state is None:
            initial_state = self._DEFAULT_INITIAL_STATE
        # np.array(...) always copies, so the environment owns its own start
        # state regardless of what the caller does with the argument later.
        initial_state = np.array(initial_state, dtype=float)
        if initial_state.shape != (self.state_dim,):
            raise ValueError(
                "initial_state must have shape (%d,), got %s"
                % (self.state_dim, initial_state.shape)
            )
        self.initial_state = initial_state

        self.terminal_time = terminal_time
        self.dt = dt
        self.inner_step_n = inner_step_n
        self.inner_dt = self.dt / self.inner_step_n

        self.alpha = alpha

        # One 5x5 block of the drift matrix; A repeats it on the diagonal.
        axis_block = np.array([[0, 1, 0, 0, 0],
                               [0, 0, 1, 0, 0],
                               [0, 0, -1, 0, 0],
                               [0, 0, 0, 0, 1],
                               [0, 0, 0, 0, 0]])
        no_coupling = np.zeros_like(axis_block)
        self.A = np.block([[axis_block, no_coupling],
                           [no_coupling, axis_block]])

        # u enters the third equation of each block (x[2] and x[7]).
        self.B = np.zeros((10, 2), dtype=int)
        self.B[2, 0] = 1
        self.B[7, 1] = 1

        # v enters the fifth equation of each block (x[4] and x[9]).
        self.C = np.zeros((10, 2), dtype=int)
        self.C[4, 0] = 1
        self.C[9, 1] = 1

        self.state = self.reset()

    def F(self, state, u_action, v_action):
        """Return the time derivative of the full 11-entry state.

        Args:
            state: Full state ``[t, x_0, ..., x_9]``.
            u_action: Indexable whose first entry is the u-player's angle.
            v_action: Indexable whose first entry is the v-player's angle.

        Returns:
            Array of length 11: ``1`` for the clock followed by
            ``A x + B u + C v``.
        """
        x = state[1:]
        u_angle = u_action[0]
        v_angle = v_action[0]
        # Map each angle to a point on the corresponding ellipse.
        u = self.u_radius * np.array([np.sin(u_angle), np.cos(u_angle)])
        v = self.v_radius * np.array([np.sin(v_angle), np.cos(v_angle)])
        dx = np.dot(self.A, x) + np.dot(self.B, u) + np.dot(self.C, v)
        return np.hstack((1, dx))

    def reset(self):
        """Restore the initial state and return it.

        A fresh copy is made on every call so that in-place edits of the
        returned array cannot corrupt ``self.initial_state``.
        """
        self.state = self.initial_state.copy()
        return self.state

    def _rk4_step(self, state, u_action, v_action):
        """Advance ``state`` by one ``inner_dt`` with the classical RK4 rule."""
        h = self.inner_dt
        k1 = self.F(state, u_action, v_action)
        k2 = self.F(state + k1 * h / 2, u_action, v_action)
        k3 = self.F(state + k2 * h / 2, u_action, v_action)
        k4 = self.F(state + k3 * h, u_action, v_action)
        return state + (k1 + 2 * k2 + 2 * k3 + k4) * h / 6

    def step(self, u_action, v_action):
        """Hold both actions constant for ``dt`` and integrate the dynamics.

        Args:
            u_action: u-player's angle; clipped to
                ``[u_action_min, u_action_max]``.
            v_action: v-player's angle; clipped to
                ``[v_action_min, v_action_max]``.

        Returns:
            Tuple ``(state, reward, done, info)``.  ``reward`` is ``0`` until
            the terminal step, where it is the Euclidean norm of
            ``(state[1] - state[4], state[6] - state[9])``.  ``info`` is
            always ``None``.
        """
        u_action = np.clip(u_action, self.u_action_min, self.u_action_max)
        v_action = np.clip(v_action, self.v_action_min, self.v_action_max)

        state = self.state
        for _ in range(self.inner_step_n):
            state = self._rk4_step(state, u_action, v_action)
        self.state = state

        reward = 0
        done = False
        # Half a step of slack absorbs round-off accumulated in the clock.
        if self.state[0] >= self.terminal_time - self.dt / 2:
            reward = np.linalg.norm([self.state[1] - self.state[4],
                                     self.state[6] - self.state[9]])
            done = True

        return self.state, reward, done, None
    