import os
import matplotlib.pyplot as plt
import networkx as nx
from scipy.integrate import solve_ivp
import numpy as np
import torch

class Engine:
    """Base class shared by the simulation engines in this module.

    Stores the integration step ``dt`` and the nominal sizes of three
    array-like slots (``state``, ``action``, ``param``), all of which start as
    None. Every getter returns ``.copy()`` of the slot, and every setter stores
    ``.copy()`` of its argument, so callers never alias the engine's buffers.

    The hook methods (``init``, ``d``, ``step``, ``render``, ``clean``) are
    no-ops returning None here; the subclasses in this module override them.
    """

    def __init__(self, dt, state_dim, action_dim, param_dim):
        # Integration step and nominal slot sizes.
        self.dt = dt
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.param_dim = param_dim

        # All slots begin empty; subclasses fill them in init() or afterwards.
        self.state = self.action = self.param = None

        # Subclass setup hook, run once the attributes above exist.
        self.init()

    def init(self):
        """Subclass setup hook called at the end of ``__init__``; no-op here."""

    # ----- param slot -------------------------------------------------------
    def get_param(self):
        """Return a copy of ``self.param``."""
        return self.param.copy()

    def set_param(self, param):
        """Store a copy of ``param`` in ``self.param``."""
        self.param = param.copy()

    # ----- state slot -------------------------------------------------------
    def get_state(self):
        """Return a copy of ``self.state``."""
        return self.state.copy()

    def set_state(self, state):
        """Store a copy of ``state`` in ``self.state``."""
        self.state = state.copy()

    # ----- state + param together -------------------------------------------
    def get_scene(self):
        """Return ``(state_copy, param_copy)`` read directly from the slots."""
        state_copy = self.state.copy()
        param_copy = self.param.copy()
        return state_copy, param_copy

    def set_scene(self, state, param):
        """Store copies of ``state`` and then ``param`` directly in the slots."""
        self.state = state.copy()
        self.param = param.copy()

    # ----- action slot ------------------------------------------------------
    def get_action(self):
        """Return a copy of ``self.action``."""
        return self.action.copy()

    def set_action(self, action):
        """Store a copy of ``action`` in ``self.action``."""
        self.action = action.copy()

    # ----- simulation hooks (no-ops in the base class) ----------------------
    def d(self, state, t, param):
        """Time derivative of ``state`` at time ``t``; not implemented here (returns None)."""

    def step(self):
        """Advance the simulation by one step; no-op here."""

    def render(self, state, param):
        """Visualize a state; no-op here."""

    def clean(self):
        """Release resources; no-op here."""


class VoltageControlODEGraph:
    """
    Power System Voltage Control Model with Graph Structure

    This class implements a second-order dynamical system that models voltage
    dynamics in power networks. The model includes:
    - Network topology represented as a graph
    - Generator and load nodes with different dynamics
    - Reactive power control capabilities
    - Second-order voltage dynamics with damping

    For every node i the acceleration is

        ddV_i = -2*zeta*omega_n*dV_i - omega_n**2*(V_i - v_ref) + k_i*Q_mismatch_i

    with k_i = ``delta`` on generator nodes and k_i = ``gamma`` on load nodes,
    and Q_mismatch = Q_net - Q_calc (see ``second_order_rhs``).

    Note: ``alpha`` and ``beta`` are stored on the instance but are not read by
    any method of this class.
    """

    def __init__(self, num_nodes, fixed_graph=False, dt=None, conn_prob=None,
                 alpha=None, beta=None, gamma=None, delta=None, zeta=None, omega_n=None,
                 v_ref=None, gen_ratio_range=None, load_variance=None, graph_type='ring'):
        """
        Initialize the voltage control ODE system with graph structure.

        Args:
            num_nodes: Number of nodes in the power network
            fixed_graph: Whether to use a predefined graph structure
            dt: Output time step used by ``run_simulation`` (default 0.1)
            conn_prob: Connection probability for random graph generation (default 0.4)
            alpha: Voltage control gain parameter (default 0.3; currently unused)
            beta: Power flow sensitivity coefficient (default 0.5; currently unused)
            gamma: Gain on the reactive-power mismatch at load nodes (default 0.5)
            delta: Gain on the reactive-power mismatch at generator nodes (default 1.0)
            zeta: Damping ratio for voltage dynamics (default 0.3)
            omega_n: Natural frequency of voltage oscillations (default 6.0)
            v_ref: Reference voltage level in per unit (default 1.0)
            gen_ratio_range: [min, max] fraction of nodes that become generators
                (default [0.2, 0.4])
            load_variance: Load noise level (default 0.01). It is passed to
                ``np.random.normal`` as the *scale* (standard deviation) argument.
            graph_type: Topology used when ``fixed_graph`` is True: 'ring',
                'star' or 'grid'. Anything else falls back to 'ring'.
        """
        self.num_nodes = num_nodes

        # Set default values if not provided
        self.v_ref = v_ref if v_ref is not None else 1.0
        self.alpha = alpha if alpha is not None else 0.3
        self.beta = beta if beta is not None else 0.5
        self.gamma = gamma if gamma is not None else 0.5
        self.omega_n = omega_n if omega_n is not None else 6.0
        self.delta = delta if delta is not None else 1.0
        self.zeta = zeta if zeta is not None else 0.3
        self.conn_prob = conn_prob if conn_prob is not None else 0.4
        self.dt = dt if dt is not None else 0.1
        self.gen_ratio_range = gen_ratio_range if gen_ratio_range is not None else [0.2, 0.4]
        self.load_variance = load_variance if load_variance is not None else 0.01

        # Initialize network topology
        if fixed_graph:
            self.graph = self._generate_fixed_graph(graph_type=graph_type)
        else:
            self.graph = self._generate_random_graph(self.conn_prob)

        # Edge list as a (2, num_edges) array
        self.edge_index = np.array(self.graph.edges).T

        # Initialize generator and load nodes
        self._initialize_node_types()

        # Initialize bus admittance matrix
        self._initialize_admittance_matrix()

        # Control (generator) and base load reactive power, one entry per node
        self.Qgen = np.zeros(num_nodes)
        self.Qload_base = np.zeros(num_nodes)

    def _generate_random_graph(self, conn_prob):
        """
        Generate a connected random graph with the given connection probability.

        Starts from an Erdos-Renyi graph. While it is disconnected, one random
        edge is added between the first two connected components.

        Args:
            conn_prob: Connection probability between each node pair

        Returns:
            NetworkX graph object
        """
        G = nx.erdos_renyi_graph(self.num_nodes, conn_prob)

        # A disconnected (non-empty) graph always has at least two components.
        while not nx.is_connected(G):
            components = list(nx.connected_components(G))
            i = np.random.choice(list(components[0]))
            j = np.random.choice(list(components[1]))
            G.add_edge(i, j)

        return G

    def _generate_fixed_graph(self, graph_type):
        """
        Generate a predefined graph structure.

        Args:
            graph_type: 'ring', 'star' or 'grid'. Any other value yields a ring.

        Returns:
            NetworkX graph object with integer node labels 0..num_nodes-1
        """
        if graph_type == 'ring':
            G = nx.cycle_graph(self.num_nodes)
        elif graph_type == 'star':
            G = nx.star_graph(self.num_nodes - 1)
        elif graph_type == 'grid':
            side = int(np.ceil(np.sqrt(self.num_nodes)))
            G = nx.grid_2d_graph(side, side)
            G = nx.convert_node_labels_to_integers(G)
            # Labels are contiguous after relabelling, so dropping every label
            # >= num_nodes leaves exactly num_nodes nodes.
            G.remove_nodes_from(range(self.num_nodes, G.number_of_nodes()))
        else:
            # Default to ring
            G = nx.cycle_graph(self.num_nodes)

        return G

    def _initialize_node_types(self):
        """
        Randomly split the nodes into generators and loads.

        The number of generators is drawn uniformly from
        [int(min_ratio * n), int(max_ratio * n)]. ``gen_mask`` and
        ``load_mask`` are complementary boolean arrays.
        """
        min_ratio, max_ratio = self.gen_ratio_range
        num_gen = np.random.randint(int(min_ratio * self.num_nodes),
                                    int(max_ratio * self.num_nodes) + 1)
        gen_indices = np.random.choice(self.num_nodes, size=num_gen, replace=False)
        self.gen_mask = np.zeros(self.num_nodes, dtype=bool)
        self.gen_mask[gen_indices] = True
        self.load_mask = ~self.gen_mask

        # Track number of each type for reporting
        self.num_generators = np.sum(self.gen_mask)
        self.num_loads = np.sum(self.load_mask)

    def _initialize_admittance_matrix(self):
        """
        Build the bus admittance matrix ``Y_bus``.

        Every branch gets the same admittance. The off-diagonal entries are
        -1j. Each diagonal entry is minus the sum of the other entries in its
        row, so every row of ``Y_bus`` sums to zero.
        """
        self.Y_bus = np.zeros((self.num_nodes, self.num_nodes), dtype=complex)

        for i, j in self.graph.edges():
            # Off-diagonal elements are the negative of the branch admittance.
            self.Y_bus[i, j] = self.Y_bus[j, i] = -1j

        # Diagonal: sum of the admittances of all branches at the bus.
        np.fill_diagonal(self.Y_bus, -self.Y_bus.sum(axis=1))

    def set_control(self, Qgen_vals):
        """
        Set the reactive power injection for generator nodes.

        Args:
            Qgen_vals: Reactive power values [p.u.], one per generator, in
                node-index order

        Raises:
            AssertionError: if the number of values differs from the number
                of generators
        """
        assert len(Qgen_vals) == np.sum(self.gen_mask), "Control input dimension mismatch"
        self.Qgen[self.gen_mask] = Qgen_vals

    def _get_Qload(self, t):
        """
        Get the reactive power demand at every node.

        For t <= 2.0 this returns a copy of ``Qload_base``. After that,
        zero-mean Gaussian noise is added at the load nodes.

        NOTE(review): ``load_variance`` is passed as the *scale* (standard
        deviation) of np.random.normal, despite its name. Confirm which was
        intended. The noise is redrawn on every call, i.e. on every
        right-hand-side evaluation made by the ODE solver.

        Args:
            t: Current simulation time [s]

        Returns:
            Array of reactive power loads for all nodes
        """
        Qload = self.Qload_base.copy()

        settling_time = 2.0
        if t > settling_time:
            Qload[self.load_mask] += np.random.normal(0, self.load_variance,
                                                      size=np.sum(self.load_mask))
        return Qload

    def second_order_rhs(self, t, state):
        """
        Right-hand side of the second-order voltage dynamics.

        Args:
            t: Current time [s]
            state: Flat state vector [V, dV] of length 2 * num_nodes

        Returns:
            Float array [dV, ddV] of length 2 * num_nodes
        """
        # Cast to float so integer-valued states do not truncate the result.
        state = np.asarray(state, dtype=float)
        n = self.num_nodes
        V = state[:n]
        dV = state[n:]

        Qload = self._get_Qload(t)

        # Generators inject Qgen; every node withdraws its load.
        Q_net = np.where(self.gen_mask, self.Qgen, 0.0) - Qload

        # Flat-angle reactive power with B = Im(Y_bus):
        #   Q_calc_i = -sum_{j != i} V_i V_j B_ij + V_i**2 B_ii
        # NOTE(review): the textbook form is Q_i = -sum_j V_i V_j B_ij, which
        # puts a minus sign on the diagonal term too. With the sign used here
        # and the Y_bus from _initialize_admittance_matrix, a flat profile
        # (all V = 1) gives Q_calc_i = 2 * degree_i instead of 0. The sign is
        # kept as-is to preserve the existing model; confirm intent.
        B = np.imag(self.Y_bus)
        B_diag = np.diag(B)
        B_off = B - np.diag(B_diag)
        Q_calc = -V * (B_off @ V) + V ** 2 * B_diag

        Q_mismatch = Q_net - Q_calc

        # Same damped oscillator at every node; the mismatch gain depends on node type.
        gain = np.where(self.gen_mask, self.delta, self.gamma)
        ddV = (-2 * self.zeta * self.omega_n * dV
               - self.omega_n ** 2 * (V - self.v_ref)
               + gain * Q_mismatch)

        return np.concatenate([dV, ddV])

    def run_simulation(self, V0, dV0, T_sim, control_func=None):
        """
        Run a complete simulation of the voltage dynamics with RK45.

        Args:
            V0: Initial voltage magnitudes
            dV0: Initial voltage derivatives
            T_sim: Total simulation time [s]
            control_func: Optional function(t, V, dV) returning one reactive
                power value per generator. It is applied at every right-hand
                side evaluation.

        Returns:
            Dictionary with:
                't': output times produced by the solver (the requested grid
                    0, dt, 2*dt, ... < T_sim when the solve succeeds)
                'V', 'dV': arrays of shape (len(t), num_nodes)
                'Q_gen': generator reactive power at each output time, shape
                    (len(t), num_generators). It is control_func evaluated at
                    the output states, or the current Qgen if no control_func
                    is given.
                'success', 'message': solver status from solve_ivp
        """
        state0 = np.concatenate([V0, dV0])

        # np.arange can overshoot its stop value by a rounding error, and
        # solve_ivp rejects t_eval points outside t_span.
        t_eval = np.arange(0, T_sim, self.dt)
        t_eval = t_eval[t_eval <= T_sim]

        def rhs_with_control(t, state):
            if control_func is not None:
                V = state[:self.num_nodes]
                dV = state[self.num_nodes:]
                self.set_control(control_func(t, V, dV))
            return self.second_order_rhs(t, state)

        sol = solve_ivp(rhs_with_control, [0, T_sim], state0,
                        t_eval=t_eval, method='RK45')

        # Use the solver's own output times so every array has matching length
        # even if the integration stopped early.
        n_out = sol.t.size
        results = {
            't': sol.t,
            'V': sol.y[:self.num_nodes].T,
            'dV': sol.y[self.num_nodes:].T,
            'Q_gen': np.zeros((n_out, int(np.sum(self.gen_mask)))),
            'success': sol.success,
            'message': sol.message,
        }

        if control_func is None:
            # Constant control: whatever was last set through set_control.
            results['Q_gen'][:] = self.Qgen[self.gen_mask]
        else:
            # Evaluate the controller exactly at the output points. Recording it
            # inside the RHS only captured whichever internal solver stage fell
            # in each interval, and left zeros where none did.
            for k in range(n_out):
                results['Q_gen'][k] = control_func(sol.t[k], results['V'][k], results['dV'][k])

        return results

    def get_network_stats(self):
        """
        Return basic statistics about the network.

        Returns:
            Dictionary with node/edge/generator/load counts, average degree
            and diameter (inf if the graph is disconnected)
        """
        return {
            'num_nodes': self.num_nodes,
            'num_edges': self.graph.number_of_edges(),
            'num_generators': np.sum(self.gen_mask),
            'num_loads': np.sum(self.load_mask),
            'avg_degree': np.mean([d for _, d in self.graph.degree()]),
            'diameter': nx.diameter(self.graph) if nx.is_connected(self.graph) else float('inf')
        }

    def visualize_network(self, V=None, highlight_edges=None):
        """
        Visualize the power network with optional voltage values.

        Generators are drawn larger (size 300 vs 200). When ``V`` is given,
        node color encodes V / v_ref clipped to [0, 1]: red-to-green for
        generators and red-to-blue for loads.

        Args:
            V: Optional voltage values to display (color-coded and labelled)
            highlight_edges: Optional edges to draw in red on top

        Returns:
            Matplotlib figure
        """
        fig = plt.figure(figsize=(10, 8))

        # Fixed seed for a reproducible layout
        pos = nx.spring_layout(self.graph, seed=42)

        node_colors = []
        for i in range(self.num_nodes):
            if self.gen_mask[i]:
                if V is not None:
                    v_normalized = max(0, min(1, V[i] / self.v_ref))
                    node_colors.append((1 - v_normalized, v_normalized, 0))  # Red to Green
                else:
                    node_colors.append('lightgreen')  # Default generator color
            else:
                if V is not None:
                    v_normalized = max(0, min(1, V[i] / self.v_ref))
                    node_colors.append((1 - v_normalized, 0, v_normalized))  # Red to Blue
                else:
                    node_colors.append('skyblue')  # Default load color

        nx.draw_networkx_nodes(
            self.graph, pos,
            nodelist=range(self.num_nodes),
            node_color=node_colors,
            node_size=[300 if self.gen_mask[i] else 200 for i in range(self.num_nodes)]
        )

        nx.draw_networkx_edges(self.graph, pos, alpha=0.5)

        if highlight_edges is not None:
            nx.draw_networkx_edges(
                self.graph, pos,
                edgelist=highlight_edges,
                width=2, edge_color='red'
            )

        if V is not None:
            labels = {i: f"{i}: {V[i]:.2f}" for i in range(self.num_nodes)}
        else:
            labels = {i: f"{i}" for i in range(self.num_nodes)}

        nx.draw_networkx_labels(self.graph, pos, labels=labels, font_size=8)

        gen_count = np.sum(self.gen_mask)
        load_count = np.sum(self.load_mask)
        plt.title(f"Power Network: {self.num_nodes} nodes ({gen_count} generators, {load_count} loads)")
        plt.axis('off')

        return fig


# Power system voltage control environment: exposes the simplified grid model (VoltageControlODEGraph) through the Engine interface
class VoltageControlEngine(Engine):
    """
    Power system voltage control environment that simulates a power grid network.

    This environment models voltage dynamics in electrical power systems,
    allowing for control actions on generator reactive power injections.

    Internally the state is the flat vector [V_0..V_{n-1}, dV_0..dV_{n-1}].
    ``self.param`` holds one reactive power value per generator node and is
    applied on every step. ``self.t`` is the simulation clock. Each ``step()``
    advances it by ``dt``, and ``reset()`` sets it back to 0.
    """
    def __init__(self, num_nodes, state_dim=None, action_dim=None, param_dim=None,
                 fixed_graph=False, dt=None, conn_prob=None,
                 alpha=None, beta=None, gamma=None, delta=None, zeta=None, omega_n=None,
                 v_ref=None, gen_ratio_range=None, load_variance=None, seed=None):
        """
        Initialize the voltage control environment.

        Args:
            num_nodes: Number of nodes in the power grid
            state_dim: State dimension (defaults to 2*num_nodes)
            action_dim: Action dimension (defaults to num_nodes)
            param_dim: Parameter dimension (defaults to num_nodes)
            fixed_graph: Whether to use a fixed graph topology
            dt: Time step (default 0.05). It is used by both this engine and
                the underlying VoltageControlODEGraph.
            conn_prob: Connection probability between nodes
            alpha, beta, gamma, delta, zeta: Model parameters forwarded to
                VoltageControlODEGraph
            omega_n: Natural frequency
            v_ref: Reference voltage value
            gen_ratio_range: Range for generator ratio [min, max]
            load_variance: Load noise level forwarded to VoltageControlODEGraph
            seed: If given, seeds numpy's global RNG before the graph is built
        """
        if state_dim is None:
            state_dim = 2 * num_nodes
        if action_dim is None:
            action_dim = num_nodes
        if param_dim is None:
            # NOTE(review): self.param has one entry per *generator*, which is
            # usually fewer than num_nodes. Confirm what callers expect here.
            param_dim = num_nodes

        # Seed before any random graph / node-type / initial-state draws.
        if seed is not None:
            np.random.seed(seed)

        # Resolve the step once so the engine and the ODE model agree.
        # Previously the engine stepped with 0.05 while the ODE model fell back to 0.1.
        dt = dt if dt is not None else 0.05
        super().__init__(dt, state_dim, action_dim, param_dim)

        # Initialize voltage control ODE simulation with graph topology
        self.vc_sim = VoltageControlODEGraph(
            num_nodes=num_nodes,
            fixed_graph=fixed_graph,
            dt=dt,
            conn_prob=conn_prob,
            alpha=alpha,
            beta=beta,
            gamma=gamma,
            delta=delta,
            zeta=zeta,
            omega_n=omega_n,
            v_ref=v_ref,
            gen_ratio_range=gen_ratio_range,
            load_variance=load_variance
        )

        # Control parameters: generator reactive power (Qgen), one per generator
        self.param = self.vc_sim.Qgen[self.vc_sim.gen_mask].copy()
        self.reset()

    def reset(self):
        """
        Reset the environment to a random initial state and restart the clock.

        Initial voltages are uniformly distributed between 0 and 2 p.u.
        Initial voltage derivatives are uniformly distributed between -1 and 1 p.u./s.
        ``self.param`` is left unchanged.
        """
        v_min, v_max = 0.0, 2.0
        dv_min, dv_max = -1.0, 1.0

        V0 = np.random.uniform(v_min, v_max, size=self.vc_sim.num_nodes)
        dV0 = np.random.uniform(dv_min, dv_max, size=self.vc_sim.num_nodes)
        self.state = np.concatenate([V0, dV0])
        # Restart the simulation clock so the load settling period applies again.
        self.t = 0.0

    def get_param(self):
        """Return a copy of the current control parameters."""
        return self.param.copy()

    def d(self, state, t, param):
        """
        Return the derivatives of the second-order system [dV, ddV].

        Side effect: writes ``param`` into the generator entries of
        ``self.vc_sim.Qgen``.

        Args:
            state: Current flat state [V, dV]
            t: Current time (the ODE model adds load noise for t > 2)
            param: Control parameters (generator reactive power)

        Returns:
            Derivatives [dV, ddV]
        """
        self.vc_sim.set_control(param)
        return self.vc_sim.second_order_rhs(t, state)

    def step(self):
        """
        Advance the system state by one time step using RK45 integration.

        Integrates over [t, t + dt] on the simulation clock. Previously the
        interval was always [0, dt], so the time-dependent load noise was
        never reached.

        Returns:
            Updated flat system state
        """
        t0 = getattr(self, 't', 0.0)
        t1 = t0 + self.dt
        sol = solve_ivp(lambda t, y: self.d(y, t, self.param),
                        [t0, t1],
                        self.state,
                        t_eval=[t1],
                        method='RK45')
        self.state = sol.y[:, -1]
        self.t = t1
        return self.state

    def get_state(self):
        """
        Get the system state structured as node voltages and derivatives.

        Returns:
            State array of shape [num_nodes, 2], with V and dV for each node
        """
        V = self.state[:self.vc_sim.num_nodes]
        dV = self.state[self.vc_sim.num_nodes:]
        return np.stack([V, dV], axis=1)

    def set_state(self, state):
        """
        Set the system state (stored as a copy).

        Args:
            state: Either the flat vector [V, dV] of length 2 * num_nodes, or
                the (num_nodes, 2) layout returned by ``get_state``
        """
        state = np.asarray(state)
        if state.ndim == 2:
            # Accept the (num_nodes, 2) layout so set_state(get_state()) round-trips.
            state = np.concatenate([state[:, 0], state[:, 1]])
        self.state = state.copy()

    def get_action(self):
        """
        Get the current control action for all nodes.

        Returns:
            Array of length num_nodes holding the generator reactive power at
            generator nodes and 0 elsewhere
        """
        control = np.zeros(self.vc_sim.num_nodes)
        control[self.vc_sim.gen_mask] = self.param
        return control

    def render(self, step_id=0, save_path=None):
        """
        Visualize the current voltage state of the power grid.

        Args:
            step_id: Step identifier used in the saved file name
            save_path: Directory to save the frame into. If falsy, the figure
                is shown interactively instead.
        """
        V = self.state[:self.vc_sim.num_nodes]

        self.vc_sim.visualize_network(V)

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            img_path = os.path.join(save_path, f"step_{step_id:03d}.png")
            plt.savefig(img_path)
            print(f"Saved frame to {img_path}")
        else:
            plt.show()

        plt.close()

    def clean(self):
        """Close all matplotlib figures."""
        plt.close('all')


class RopeEngine(Engine):
    """Rope simulated in pymunk as a chain of balls joined by damped springs.

    ``param`` is ``[n_ball, init_x, k, damping, gravity]``. The first ball is
    constrained by a GrooveJoint to slide horizontally at y = ``height``, and
    actions are horizontal impulses applied to that ball.

    NOTE(review): ``pymunk``, ``Vec2d``, ``cv2``, ``rand_int`` and
    ``rand_float`` do not appear in this module's visible import block. They
    must be provided elsewhere in the module for this class to run.
    """

    def __init__(self, dt, state_dim, action_dim, param_dim,
                 num_mass_range=(4, 8), k_range=(500., 1500.), gravity_range=(-2., -8.),
                 position_range=(-0.6, 0.6), bihop=True):
        """
        Args:
            dt: Physics time step
            state_dim, action_dim, param_dim: Nominal dimensions (the original
                notes give 4, 1 and 5)
            num_mass_range: (lo, hi) passed to rand_int to sample the ball count
            k_range: (lo, hi) passed to rand_float to sample spring stiffness
            gravity_range: (lo, hi) passed to rand_float to sample gravity
            position_range: (lo, hi) range for the initial x of the chain
            bihop: Also connect every ball to the ball two positions further on
        """
        # Ball geometry and mass, shared by all balls
        self.radius = 0.06
        self.mass = 1.

        self.num_mass_range = num_mass_range
        self.k_range = k_range
        self.gravity_range = gravity_range
        self.position_range = position_range

        self.bihop = bihop

        super(RopeEngine, self).__init__(dt, state_dim, action_dim, param_dim)

    def init(self, param=None):
        """(Re)build the pymunk space for the rope.

        Args:
            param: Optional sequence ``[n_ball, init_x, k, damping, gravity]``.
                Entries that are None are filled in as follows. n_ball comes
                from ``num_mass_range`` via rand_int. init_x is uniform in
                ``position_range``. k and gravity come from ``k_range`` and
                ``gravity_range`` via rand_float. damping defaults to k / 20.
        """
        if param is None:
            self.n_ball, self.init_x, self.k, self.damping, self.gravity = [None] * 5
        else:
            self.n_ball, self.init_x, self.k, self.damping, self.gravity = param
            self.n_ball = int(self.n_ball)

        num_mass_range = self.num_mass_range
        position_range = self.position_range
        if self.n_ball is None:
            self.n_ball = rand_int(num_mass_range[0], num_mass_range[1])
        if self.init_x is None:
            self.init_x = np.random.rand() * (position_range[1] - position_range[0]) + position_range[0]
        if self.k is None:
            self.k = rand_float(self.k_range[0], self.k_range[1])
        if self.damping is None:
            self.damping = self.k / 20.
        if self.gravity is None:
            self.gravity = rand_float(self.gravity_range[0], self.gravity_range[1])
        self.param = np.array([self.n_ball, self.init_x, self.k, self.damping, self.gravity])

        self.space = pymunk.Space()
        self.space.gravity = (0., self.gravity)

        # Height of the top ball and vertical spacing between consecutive balls
        self.height = 1.0
        self.rest_len = 0.3

        self.add_masses()
        self.add_rels()

        # Positions before the last physics step (used for finite-difference velocity)
        self.state_prv = None

    @property
    def num_obj(self):
        """Number of simulated balls."""
        return self.n_ball

    def add_masses(self):
        """Create ``n_ball`` circular bodies stacked downward from (init_x, height)."""
        inertia = pymunk.moment_for_circle(self.mass, 0, self.radius, (0, 0))
        x = self.init_x
        y = self.height
        self.balls = []

        for i in range(self.n_ball):
            body = pymunk.Body(self.mass, inertia)
            body.position = Vec2d(x, y)
            shape = pymunk.Circle(body, self.radius, (0, 0))

            if i == 0:
                # Pin the first ball to a horizontal groove at its starting height.
                move_joint = pymunk.GrooveJoint(self.space.static_body, body, (-2, y), (2, y), (0, 0))
                self.space.add(body, shape, move_joint)
            else:
                self.space.add(body, shape)

            self.balls.append(body)
            y -= self.rest_len

    def add_rels(self):
        """Connect the balls with damped springs (neighbors and, optionally, second neighbors)."""
        # Springs rest slightly longer (7.5%) than the initial spacing.
        give = 1. + 0.075
        for i in range(self.n_ball - 1):
            c = pymunk.DampedSpring(
                self.balls[i], self.balls[i + 1], (0, 0), (0, 0),
                rest_length=self.rest_len * give, stiffness=self.k, damping=self.damping)
            self.space.add(c)

        # Two-hop springs at half stiffness
        if self.bihop:
            for i in range(self.n_ball - 2):
                c = pymunk.DampedSpring(
                    self.balls[i], self.balls[i + 2], (0, 0), (0, 0),
                    rest_length=self.rest_len * give * 2, stiffness=self.k * 0.5, damping=self.damping)
                self.space.add(c)

    def add_impulse(self):
        """Apply the horizontal impulse ``(action[0], 0)`` to the first ball."""
        impulse = (self.action[0], 0)
        self.balls[0].apply_impulse_at_local_point(impulse=impulse, point=(0, 0))

    def get_param(self):
        """Return ``(n_ball, init_x, k, damping, gravity)`` as a tuple."""
        return self.n_ball, self.init_x, self.k, self.damping, self.gravity

    def get_state(self):
        """Return an (n_ball, 4) array of [x, y, vx, vy] per ball.

        Columns from ``state_dim // 2`` onward are overwritten with a
        finite-difference velocity (position change since ``state_prv``,
        divided by dt). They are zeros when no previous state exists.
        """
        state = np.zeros((self.n_ball, 4))
        for i in range(self.n_ball):
            ball = self.balls[i]
            state[i] = np.array([ball.position[0], ball.position[1], ball.velocity[0], ball.velocity[1]])

        vel_dim = self.state_dim // 2
        if self.state_prv is None:
            state[:, vel_dim:] = 0
        else:
            state[:, vel_dim:] = (state[:, :vel_dim] - self.state_prv[:, :vel_dim]) / self.dt

        return state

    def step(self):
        """Apply the current action, remember the pre-step state, and advance physics by dt."""
        self.add_impulse()
        self.state_prv = self.get_state()
        self.space.step(self.dt)

    def render(self, states, actions=None, param=None, video=True, image=False, path=None,
               act_scale=None, draw_edge=True, lim=(-2.5, 2.5, -2.5, 2.5), states_gt=None,
               count_down=False, gt_border=False):
        """Draw a rollout frame by frame and write it as a video and/or images.

        Args:
            states: Array indexed as ``states[t, ball, :]``. Columns 0 and 1
                are used as the x and y position of each ball.
            actions: Optional per-step actions. If 3-D, only the first entry
                along axis 1 is kept. ``actions[t, 0] / 4`` is drawn as a
                horizontal arrow next to the first ball.
            param, act_scale, gt_border: Accepted for interface compatibility; unused.
            video: Write ``path + '.avi'`` (MJPG, 25 fps, 640x480).
            image: Save ``fig_<t>.png`` files into the directory ``path + '_img'``.
            path: Output path prefix; required when ``video`` or ``image`` is True.
            draw_edge: Draw a segment between consecutive balls.
            lim: Axis limits (xmin, xmax, ymin, ymax).
            states_gt: Optional ground-truth states, drawn in 'orangered'.
            count_down: Overlay the number of remaining frames.
        """
        # These matplotlib names are not in the module's visible import block.
        from matplotlib.collections import PatchCollection
        from matplotlib.patches import Circle

        if video:
            video_path = path + '.avi'
            fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
            print('Save video as %s' % video_path)
            out = cv2.VideoWriter(video_path, fourcc, 25, (640, 480))

        if image:
            image_path = path + '_img'
            print('Save images to %s' % image_path)
            # Portable, and avoids passing a path through the shell.
            os.makedirs(image_path, exist_ok=True)

        c = ['royalblue', 'tomato', 'limegreen', 'orange', 'violet', 'chocolate', 'lightsteelblue']

        time_step = states.shape[0]
        n_ball = states.shape[1]

        if actions is not None and actions.ndim == 3:
            # Keep only the action of the first ball.
            actions = actions[:, 0, :]

        font = {'family': 'serif',
                'color': 'darkred',
                'weight': 'normal',
                'size': 16}

        try:
            for i in range(time_step):
                fig, ax = plt.subplots(1)
                plt.xlim(lim[0], lim[1])
                plt.ylim(lim[2], lim[3])
                plt.axis('off')

                if draw_edge:
                    for x in range(n_ball - 1):
                        plt.plot([states[i, x, 0], states[i, x + 1, 0]],
                                 [states[i, x, 1], states[i, x + 1, 1]],
                                 '-', color=c[1], lw=2, alpha=0.5)

                circles = []
                circles_color = []
                for j in range(n_ball):
                    circle = Circle((states[i, j, 0], states[i, j, 1]), radius=self.radius * 5 / 4)
                    circles.append(circle)
                    circles_color.append(c[0])

                pc = PatchCollection(circles, facecolor=circles_color, linewidth=0, alpha=1.)
                ax.add_collection(pc)

                if states_gt is not None:
                    circles = []
                    circles_color = []
                    for j in range(n_ball):
                        circle = Circle((states_gt[i, j, 0], states_gt[i, j, 1]), radius=self.radius * 5 / 4)
                        circles.append(circle)
                        circles_color.append('orangered')
                    pc = PatchCollection(circles, facecolor=circles_color, linewidth=0, alpha=1.)
                    ax.add_collection(pc)

                if actions is not None:
                    F = actions[i, 0] / 4
                    # F is a scalar, so its norm is its absolute value.
                    normF = abs(F)
                    if normF >= 1e-10:
                        ax.arrow(states[i, 0, 0] + F / normF * 0.1, states[i, 0, 1],
                                 F, 0., fc='Orange', ec='Orange', width=0.04, head_width=0.2, head_length=0.2)

                ax.set_aspect('equal')

                if count_down:
                    plt.text(-2.5, 1.5, 'CountDown: %d' % (time_step - i - 1), fontdict=font)

                plt.tight_layout()

                if video:
                    fig.canvas.draw()
                    # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring pair
                    # and already carries the correct (H, W, 4) pixel shape.
                    frame = np.asarray(fig.canvas.buffer_rgba())
                    frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
                    out.write(frame)
                    if i == time_step - 1:
                        # Hold the last frame a little longer.
                        for _ in range(5):
                            out.write(frame)

                if image:
                    plt.savefig(os.path.join(image_path, 'fig_%s.png' % i), bbox_inches='tight')

                plt.close(fig)
        finally:
            # Always finalize the video file, even if drawing fails midway.
            if video:
                out.release()


# ===================================================================
'''
For Soft and Swim
'''


def get_init_p_fish_8():
    """Return the fixed 8-box fish layout as a float64 array of shape (8, 3).

    Each row is ``[x, y, material]``. The boxes form two columns (x = 0 and
    x = 1) of four boxes each (y = 0..3). Inside each column the material
    alternates: 2 on even y and 0 on odd y.
    """
    rows = [
        [col, level, 2 if level % 2 == 0 else 0]
        for col in range(2)
        for level in range(4)
    ]
    return np.array(rows, dtype=np.float64)


def sample_init_p_flight(n_box, shape_type=None, aug=False, train=False,
                         min_offset=False, max_offset=False):
    """Sample a column-based layout of ``n_box`` boxes.

    A layout is a list of column heights. For each ``n_box`` in 5..9 there
    are several candidates, selected by ``shape_type``. If ``shape_type`` is
    None it is drawn with rand_int; otherwise it is taken modulo the number
    of candidates.

    Vertical offsets between adjacent columns are sampled for the left half
    and mirrored (negated) on the right half. ``min_offset`` / ``max_offset``
    override each sampled offset after it is drawn. The layout is then
    shifted so its lowest box sits at y = 0.

    Each box gets a material. 0 is drawn with the "actuated" probability, 1
    with the "soft" probability and 2 with the "rigid" probability, and the
    probabilities depend on the box's position in its column. Materials are
    redrawn until at least two boxes are actuated (material 0).

    With ``aug``, three independent coin flips optionally negate y, negate x,
    and swap x with y, in that order.

    Args:
        n_box: Number of boxes, 5 <= n_box < 10 (checked with assert).
        shape_type: Optional index of the layout candidate.
        aug: Apply the random flips/swap described above.
        train: Accepted for interface compatibility; unused.
        min_offset, max_offset: Force the smallest / largest column offsets.

    Returns:
        Float array of shape (n_box, 3) with rows ``[x, y, material]``.

    NOTE(review): rand_int is defined outside this view. It is presumably an
    integer sampler over [lo, hi); confirm its bounds.
    """
    assert 5 <= n_box < 10
    candidates = {
        5: [[1, 3, 1], [2, 1, 2]],
        6: [[3, 3], [2, 2, 2]],
        7: [[2, 3, 2], [1, 2, 1, 2, 1], [2, 1, 1, 1, 2]],
        8: [[2, 2, 2, 2], [1, 2, 2, 2, 1], [2, 1, 2, 1, 2], [3, 2, 3]],
        9: [[2, 2, 1, 2, 2], [1, 2, 3, 2, 1], [2, 1, 3, 1, 2], [3, 3, 3]],
    }[n_box]
    n_candidates = len(candidates)

    if shape_type is None:
        shape_type = rand_int(0, n_candidates)
    else:
        shape_type = shape_type % n_candidates

    heights = candidates[shape_type]
    n_cols = len(heights)

    # step[k] is the vertical offset of column k relative to column k - 1.
    step = np.zeros(n_cols)
    for k in range(1, (n_cols + 1) // 2):
        prev_height, cur_height = heights[k - 1], heights[k]
        step[k] = rand_int(1 - cur_height, prev_height)
        if min_offset:
            step[k] = 1 - cur_height
        if max_offset:
            step[k] = prev_height
        # Mirror on the right-hand side so the shape is symmetric.
        step[n_cols - k] = -step[k]
        assert n_cols - k > k

    base = np.cumsum(step)
    base -= base.min()

    def draw_material(height, pos):
        """Draw one material code for the box at position ``pos`` in a column of ``height``."""
        r = np.random.rand()
        if height == 1:
            p_actuated, p_soft = 0.25, 0.25
        elif pos == 0:
            p_actuated, p_soft = 0.0, 0.5
        elif pos == height - 1:
            p_actuated, p_soft = 0.75, 0.25
        else:
            p_actuated, p_soft = 0.4, 0.2
        if r < p_actuated:
            return 0
        if r < p_actuated + p_soft:
            return 1
        return 2

    init_p = np.zeros((n_box, 3))
    while True:
        boxes = (
            (col, base[col] + pos, draw_material(height, pos))
            for col, height in enumerate(heights)
            for pos in range(height)
        )
        for row_idx, box in enumerate(boxes):
            init_p[row_idx, :] = box

        # Require at least two actuated boxes.
        if np.count_nonzero(init_p[:, 2] == 0) >= 2:
            break

    if aug:
        flip_y, flip_x, swap_xy = [np.random.rand() > 0.5 for _ in range(3)]
        if flip_y:
            init_p[:, 1] *= -1
        if flip_x:
            init_p[:, 0] *= -1
        if swap_xy:
            init_p[:, [0, 1]] = init_p[:, [1, 0]]

    return init_p


def sample_init_p_regular(n_box, shape_type=None, aug=False):
    """Build one of a few hand-designed box layouts.

    Each row of the returned array is ``[x, y, material]`` with material
    0 - soft & actuated, 1 - soft, 2 - rigid.

    Args:
        n_box: number of boxes to lay out.
        shape_type: 0 - "0"/"U" shape, 1 - vertical bar, 2 - "I" shape,
            3 - "T" shape, 4 - reinforced "T" (requires n_box == 10).
            ``None`` picks one with ``rand_int(0, 4)``.
        aug: if True, randomly flip the y axis and/or swap the x and y axes.

    Returns:
        ``(n_box, 3)`` float array.

    Raises:
        ValueError: if ``shape_type`` is not one of the layouts above.
    """
    print('sample_init_p')
    init_p = np.zeros((n_box, 3))

    if shape_type is None: shape_type = rand_int(0, 4)
    print('shape_type', shape_type)

    if shape_type == 0:  # 0 or u shape
        init_p[0, :] = np.array([0, 0, 2])
        init_p[1, :] = np.array([-1, 0, 2])
        init_p[2, :] = np.array([1, 0, 2])
        idx = 3
        y = 0
        x = [-1, 0, 1]
        res = n_box - 3
        while res > 0:
            y += 1
            if res == 3:
                # exactly three left: close the top of the "0"
                i_list = [0, 1, 2]
            elif res == 1:
                # a single leftover box (n_box == 4) sits on the centre column;
                # placing two boxes here would write past the end of init_p
                i_list = [1]
            else:
                i_list = [0, 2]
            # only boxes below the top row may be soft; the top row is actuated
            material = [0, 1][int(np.random.rand() < 0.5 and res > 3)]
            for i in i_list:
                init_p[idx, :] = np.array([x[i], y, material])
                idx += 1
                res -= 1

    elif shape_type == 1:  # 1 shape
        init_p[0, :] = np.array([0, 0, 2])
        for i in range(1, n_box):
            material = [0, 1][int(np.random.rand() < 0.5 and i < n_box - 1)]
            init_p[i, :] = np.array([0, i, material])

    elif shape_type == 2:  # I shape
        if n_box < 7:
            init_p[0, :] = np.array([0, 0, 2])
            for i in range(1, n_box - 3):
                material = [0, 1][int(np.random.rand() < 0.5 and i < n_box - 1)]
                init_p[i, :] = np.array([0, i, material])
            init_p[n_box - 1, :] = np.array([-1, n_box - 3, 0])
            init_p[n_box - 2, :] = np.array([0, n_box - 3, 0])
            init_p[n_box - 3, :] = np.array([1, n_box - 3, 0])
        else:
            init_p[0, :] = np.array([-1, 0, 2])
            init_p[1, :] = np.array([0, 0, 2])
            init_p[2, :] = np.array([1, 0, 2])
            for i in range(3, n_box - 3):
                material = [0, 1][int(np.random.rand() < 0.5 and i < n_box - 1)]
                init_p[i, :] = np.array([0, i - 2, material])
            init_p[n_box - 1, :] = np.array([-1, n_box - 5, 0])
            init_p[n_box - 2, :] = np.array([0, n_box - 5, 0])
            init_p[n_box - 3, :] = np.array([1, n_box - 5, 0])

    elif shape_type == 3:  # T shape
        if n_box < 6:
            init_p[0, :] = np.array([-1, 0, 2])
            init_p[1, :] = np.array([0, 0, 2])
            init_p[2, :] = np.array([1, 0, 2])
            for i in range(3, n_box):
                material = [0, 1][int(np.random.rand() < 0.5 and i < n_box - 1)]
                init_p[i, :] = np.array([0, i - 2, material])
        else:
            init_p[0, :] = np.array([-2, 0, 2])
            init_p[1, :] = np.array([-1, 0, 2])
            init_p[2, :] = np.array([0, 0, 2])
            init_p[3, :] = np.array([1, 0, 2])
            init_p[4, :] = np.array([2, 0, 2])
            for i in range(5, n_box):
                material = [0, 1][int(np.random.rand() < 0.5 and i < n_box - 1)]
                init_p[i, :] = np.array([0, i - 4, material])

    elif shape_type == 4:  # stronger T
        assert n_box == 10
        init_p[0, :] = np.array([0, -4, 0])
        init_p[1, :] = np.array([1, -4, 1])
        init_p[2, :] = np.array([0, -3, 0])
        init_p[3, :] = np.array([1, -3, 0])
        init_p[4, :] = np.array([0, -2, 1])
        init_p[5, :] = np.array([1, -2, 0])
        init_p[6, :] = np.array([-1, -1, 2])
        init_p[7, :] = np.array([0, -1, 2])
        init_p[8, :] = np.array([1, -1, 2])
        init_p[9, :] = np.array([2, -1, 2])

    else:
        # previously fell through and returned all-zero (overlapping) boxes
        raise ValueError('Unknown shape_type %r' % (shape_type,))

    if aug:
        if np.random.rand() > 0.5:
            # flip y
            init_p[:, 1] = -init_p[:, 1]
        if np.random.rand() > 0.5:
            # swap x and y
            x, y = init_p[:, 0], init_p[:, 1]
            init_p[:, 0], init_p[:, 1] = y.copy(), x.copy()

    return init_p


class SoftEngine(Engine):
    """Soft robot built from unit boxes whose corners are pymunk point masses.

    Scene parameters, as returned by ``get_param``, are ``(n_box, k, damping, init_p)``.
    ``init_p`` is an ``(n_box, 3)`` array of ``[x, y, type]`` rows. The type codes are
    0 - soft & actuated, 1 - soft, 2 - rigid and 3 - the fixed anchor box (box 0).

    Corner ``c`` of box ``i`` is body ``self.balls[4 * i + c]``. The corner codes are
    0 (-x,-y), 1 (-x,+y), 2 (+x,-y) and 3 (+x,+y), relative to the box centre.

    ``get_state`` returns an ``(n_box, 16)`` array: the 4 corner positions followed by
    the 4 corner velocities, each as (x, y).
    """

    # Offset ``delta = p_i - p_j`` of a touching neighbour j, mapped to the pairs of
    # (corner of i, corner of j) that coincide. Side contacts share two corners;
    # diagonal contacts share one.
    _CORNER_CONTACTS = {
        (1, 0): ((1, 3), (0, 2)),     # j on the left of i
        (-1, 0): ((3, 1), (2, 0)),    # j on the right of i
        (0, 1): ((0, 1), (2, 3)),     # j below i
        (0, -1): ((1, 0), (3, 2)),    # j above i
        (1, 1): ((0, 3),),
        (1, -1): ((1, 2),),
        (-1, 1): ((2, 1),),
        (-1, -1): ((3, 0),),
    }

    def __init__(self, dt, state_dim, action_dim, param_dim,
                 num_box_range=(5, 10), k_range=(600, 1000.)):
        """Set geometry constants, then let ``Engine.__init__`` call ``init()``.

        Args:
            dt: simulation time step (passed to pymunk in ``step``).
            state_dim, action_dim, param_dim: forwarded to ``Engine``. ``get_state``
                produces 16 columns per box.
            num_box_range: bounds passed to ``rand_int`` when sampling n_box.
            k_range: bounds passed to ``rand_float`` when sampling the stiffness k.
        """
        self.side_length = 1.
        # The defaults are tuples so instances never share one mutable object. They
        # are copied to lists to keep the attribute type the class always exposed.
        self.num_box_range = list(num_box_range)
        self.k_range = list(k_range)
        self.radius = 0.01
        self.mass = 1.

        super(SoftEngine, self).__init__(dt, state_dim, action_dim, param_dim)

    @property
    def num_obj(self):
        """Number of boxes in the current scene."""
        return self.n_box

    def inside_lim(self, x, y, lim):
        """Return True if grid cell (x, y) lies in the half-open square [lim[0], lim[1])^2."""
        return lim[0] <= x < lim[1] and lim[0] <= y < lim[1]

    @staticmethod
    def _roll_box_type(r_actuated, r_soft):
        """Draw one box type: 0 w.p. r_actuated, 1 w.p. r_soft, otherwise 2 (rigid)."""
        roll = np.random.rand()
        if roll < r_actuated:
            return 0
        if roll < r_actuated + r_soft:
            return 1
        return 2

    def sample_init_p(self):
        """Sample a random connected layout of ``self.n_box`` unit boxes.

        Box 0 is the fixed anchor (type 3) at grid cell (0, -4). Each further box
        takes a random type and a cell drawn from the frontier. The frontier is the
        set of free cells adjacent to the shape inside the [-4, 4)^2 grid. Types are
        re-drawn until at least two boxes are actuated.

        Returns:
            ``(n_box, 3)`` float array of ``[x, y, type]`` rows.

        Raises:
            ValueError: if n_box < 3. Only n_box - 1 boxes can be actuated, so the
                "at least two actuated" re-draw loop would never terminate.
        """
        n_box = self.n_box
        if n_box < 3:
            raise ValueError(
                'SoftEngine needs at least 3 boxes to place 2 actuators, got %d' % n_box)

        r_actuated, r_soft = 0.5, 0.25  # the remaining 0.25 is rigid
        lim = -4, 4
        # taken[x - lim[0], y - lim[0]] marks cells that are occupied or already queued
        taken = np.zeros((lim[1] - lim[0], lim[1] - lim[0]))

        def mark(cx, cy):
            taken[cx - lim[0], cy - lim[0]] = 1

        def is_free(cx, cy):
            return self.inside_lim(cx, cy, lim) and taken[cx - lim[0], cy - lim[0]] == 0

        init_p = np.zeros((n_box, 3))

        # fixed anchor box; its three in-grid neighbours seed the frontier
        x, y = 0, -4
        init_p[0] = np.array([x, y, 3])
        frontier = [[x - 1, y], [x, y + 1], [x + 1, y]]
        mark(x, y)
        for cx, cy in frontier:
            mark(cx, cy)

        for i in range(1, n_box):
            init_p[i, 2] = self._roll_box_type(r_actuated, r_soft)

            if frontier:
                x, y = frontier.pop(rand_int(0, len(frontier)))
            else:
                x = rand_int(lim[0], lim[1])
                y = rand_int(lim[0], lim[1])

            init_p[i, 0], init_p[i, 1] = x, y
            mark(x, y)
            # the neighbour order (+x, -x, +y, -y) fixes the frontier order, and hence
            # which cell a given random draw selects
            for cx, cy in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if is_free(cx, cy):
                    frontier.append([cx, cy])
                    mark(cx, cy)

        while (init_p[:, 2] == 0).sum() < 2:
            # fewer than two actuated boxes: re-draw the types of all non-anchor boxes
            for i in range(1, n_box):
                init_p[i, 2] = self._roll_box_type(r_actuated, r_soft)

        return init_p

    def init(self, param=None):
        """(Re)build the pymunk scene.

        Args:
            param: optional ``(n_box, k, damping, init_p)``. When omitted, everything
                is sampled. When given, n_box is cast to int, and any of k / damping /
                init_p that is None is sampled (damping defaults to k / 20).
        """
        if param is None:
            self.n_box, self.k, self.damping, self.init_p = [None] * 4
        else:
            self.n_box, self.k, self.damping, self.init_p = param
            self.n_box = int(self.n_box)

        if self.n_box is None:
            self.n_box = rand_int(self.num_box_range[0], self.num_box_range[1])
        if self.k is None:
            self.k = rand_float(self.k_range[0], self.k_range[1])
        if self.damping is None:
            self.damping = self.k / 20.
        if self.init_p is None:
            self.init_p = self.sample_init_p()

        self.space = pymunk.Space()
        self.space.gravity = (0., 0.)

        self.add_masses()
        self.add_rels()

        # no previous state yet, so get_state reports zero velocity
        self.state_prv = None

    def add_masses(self):
        """Create four point-mass bodies per box, one at each corner, in corner order 0..3."""
        inertia = pymunk.moment_for_circle(self.mass, 0, self.radius, (0, 0))
        self.balls = []
        half = self.side_length / 2.
        offsets = ((-half, -half), (-half, half), (half, -half), (half, half))

        for i in range(self.n_box):
            x, y, _ = self.init_p[i]
            for dx, dy in offsets:
                body = pymunk.Body(self.mass, inertia)
                body.position = Vec2d(x + dx, y + dy)
                self.space.add(body)
                self.balls.append(body)

    def _spring(self, body_a, body_b, rest_length):
        """Return a damped spring between two corner bodies using the scene's k / damping."""
        return pymunk.DampedSpring(
            body_a, body_b, (0, 0), (0, 0),
            rest_length=rest_length, stiffness=self.k, damping=self.damping)

    def _link_box(self, base, soft):
        """Add a box's 4 sides then 2 diagonals; ``base`` is the index of its corner 0.

        The links are damped springs if ``soft``, otherwise pin joints.
        """
        side = self.side_length
        diag = self.side_length * np.sqrt(2)
        links = ((0, 1, side), (0, 2, side), (3, 1, side), (3, 2, side),
                 (0, 3, diag), (1, 2, diag))
        for a, b, rest in links:
            body_a, body_b = self.balls[base + a], self.balls[base + b]
            if soft:
                c = self._spring(body_a, body_b, rest)
            else:
                c = pymunk.PinJoint(body_a, body_b, (0, 0), (0, 0))
            self.space.add(c)

    def _pin_coincident_corners(self):
        """Pin together corners of different boxes that coincide (distance < 1e-4)."""
        for i in range(self.n_box):
            for j in range(i):
                for ii in range(4):
                    for jj in range(4):
                        x, y = i * 4 + ii, j * 4 + jj
                        if calc_dis(self.balls[x].position, self.balls[y].position) < 1e-4:
                            c = pymunk.PinJoint(self.balls[x], self.balls[y], (0, 0), (0, 0))
                            self.space.add(c)

    def add_rels(self):
        """Add all constraints: the anchor box, the remaining boxes, then inter-box pins."""
        # anchor box 0: corners 0 and 2 are pinned to the static world
        for idx in (0, 2):
            ball = self.balls[idx]
            c = pymunk.PinJoint(self.space.static_body, ball, (ball.position[0], ball.position[1]), (0, 0))
            self.space.add(c)
        side = self.side_length
        diag = self.side_length * np.sqrt(2)
        # remaining anchor-box springs; the 0-2 side needs none since both ends are pinned
        for a, b, rest in ((0, 1, side), (1, 3, side), (2, 3, side), (1, 2, diag), (0, 3, diag)):
            self.space.add(self._spring(self.balls[a], self.balls[b], rest))

        for i in range(1, self.n_box):
            # types 0 (actuated) and 1 are soft; anything else is rigid
            self._link_box(i * 4, self.init_p[i, 2] <= 1)

        self._pin_coincident_corners()

    def add_force(self):
        """Apply actuation forces to the corners of actuated (type 0) boxes.

        Each corner of such a box is pushed along the unit vector towards the
        diagonally opposite corner, scaled by ``self.action[i]``.
        """
        for i in range(self.n_box):
            if self.init_p[i, 2] == 0:
                for j in range(4):
                    x, y = i * 4 + j, i * 4 + (3 - j)
                    direct = np.array([
                        self.balls[y].position[0] - self.balls[x].position[0],
                        self.balls[y].position[1] - self.balls[x].position[1]])
                    direct /= norm(direct)
                    force = direct * self.action[i]
                    self.balls[x].apply_force_at_local_point(
                        force=(force[0], force[1]), point=(0, 0))

    def get_param(self):
        """Return ``(n_box, k, damping, init_p)`` (no copies)."""
        return self.n_box, self.k, self.damping, self.init_p

    def get_state(self):
        """Return the ``(n_box, 16)`` state array.

        Columns 0-7 are the four corner positions, each averaged over every box that
        shares that corner. Columns 8-15 are velocities, estimated by finite
        difference against ``self.state_prv``; they are zero while ``state_prv`` is None.
        """
        state = np.zeros((self.n_box, 16))
        for i in range(self.n_box):
            for j in range(4):
                ball = self.balls[i * 4 + j]
                state[i, j * 2: (j + 1) * 2] = \
                    np.array([ball.position[0], ball.position[1]])
                state[i, 8 + j * 2: 8 + (j + 1) * 2] = \
                    np.array([ball.velocity[0], ball.velocity[1]])

        state_acc = state.copy()
        count = np.zeros((self.n_box, 1, 8))

        for i in range(self.n_box):
            count[i, :, :] += 1  # every box contributes its own corners
            for j in range(self.n_box):
                if i == j:
                    continue

                delta = self.init_p[i, :2] - self.init_p[j, :2]
                assert (np.abs(delta) > 0).any()

                # boxes that do not touch have no entry in the table
                for own, other in self._CORNER_CONTACTS.get((delta[0], delta[1]), ()):
                    o, t = own * 2, other * 2
                    count[i, :, o:o + 2] += 1
                    state_acc[i, o:o + 2] += state[j, t:t + 2]
                    state_acc[i, o + 8:o + 10] += state[j, t + 8:t + 10]

        # the same per-corner count divides both the position and velocity halves
        state_acc = state_acc.reshape(self.n_box, 2, 8) / count
        state_acc = state_acc.reshape(self.n_box, 16)

        # Split at the actual layout (8 position columns). For state_dim == 16 this
        # equals the old state_dim // 2; other state_dim values used to break the assignment.
        n_pos = state_acc.shape[1] // 2
        if self.state_prv is None:
            state_acc[:, n_pos:] = 0
        else:
            state_acc[:, n_pos:] = (state_acc[:, :n_pos] - self.state_prv[:, :n_pos]) / self.dt

        return state_acc

    def step(self):
        """Apply forces, cache the pre-step state for velocity estimation, advance by dt."""
        self.add_force()
        self.state_prv = self.get_state()
        self.space.step(self.dt)

    @staticmethod
    def _box_polygons(corners, n_box):
        """One Polygon per box from an ``(n_ball, 2)`` array of corner positions.

        Corners are visited in the order 0, 1, 3, 2, so the outline goes around the
        box instead of crossing its diagonal.
        """
        return [Polygon(np.array([corners[b * 4], corners[b * 4 + 1],
                                  corners[b * 4 + 3], corners[b * 4 + 2]]))
                for b in range(n_box)]

    @staticmethod
    def _box_color(box_type, act, act_scale):
        """Face colour for a box.

        Actuated boxes blend from limegreen (act = -act_scale) to tomato
        (act = +act_scale). They are cornflowerblue when the action is (nearly) zero.
        """
        if box_type == 0:
            r = (act + act_scale) / (act_scale * 2)
            if np.abs(r - 0.5) < 1e-4:
                return 'cornflowerblue'
            # to_rgba returns a plain tuple; convert before scaling by a float
            color = (np.asarray(to_rgba('tomato')[:3]) * r
                     + np.asarray(to_rgba('limegreen')[:3]) * (1. - r))
            return np.clip(color, 0., 1.)
        if box_type == 1 or box_type == 3:
            return 'lightsteelblue'
        if box_type == 2:
            return 'dimgray'
        raise AssertionError("Unknown box type %f" % box_type)

    def render(self, states, actions=None, param=None, act_scale=10.,
               video=True, image=False, path=None, lim=(-5., 5., -6., 4.),
               states_gt=None, count_down=False, gt_border=False):
        """Render a trajectory to ``path + '.avi'`` and/or ``path + '_img/fig_<t>.png'``.

        Args:
            states: array indexed ``[time, box, feature]``. Its first 8 features are
                the four corner positions (the ``get_state`` layout).
            actions: optional array indexed ``[time, box]``, used to tint actuated boxes.
            param: ``(n_box, k, damping, init_p)``; ``init_p`` supplies the box types.
                Defaults to this engine's ``init_p`` when omitted.
            act_scale: action magnitude mapped to the extreme tint colours.
            video, image: select the video and/or per-frame PNG outputs.
            lim: ``(xmin, xmax, ymin, ymax)`` plot window.
            states_gt: optional ground-truth trajectory drawn over ``states``.
            count_down: overlay the number of remaining frames.
            gt_border: draw the ground truth as outlines instead of translucent fills.
        """
        out = None
        if video:
            video_path = path + '.avi'
            fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
            print('Save video as %s' % video_path)
            out = cv2.VideoWriter(video_path, fourcc, 25, (640, 480))

        if image:
            image_path = path + '_img'
            print('Save images to %s' % image_path)
            # portable and safe for paths containing spaces or shell metacharacters
            os.makedirs(image_path, exist_ok=True)

        init_p = self.init_p if param is None else param[3]

        time_step = states.shape[0]
        n_box = states.shape[1]
        n_ball = n_box * 4
        states = states[:, :, :8].reshape((time_step, n_ball, 2))
        if states_gt is not None:
            states_gt = states_gt[:, :, :8].reshape((time_step, n_ball, 2))

        font = {'family': 'serif',
                'color': 'darkred',
                'weight': 'normal',
                'size': 16}
        anchored = (0, 2)  # corners of box 0 pinned to the world in add_rels

        try:
            for i in range(time_step):
                fig, ax = plt.subplots(1)
                try:
                    plt.xlim(lim[0], lim[1])
                    plt.ylim(lim[2], lim[3])
                    plt.axis('off')

                    polys = self._box_polygons(states[i], n_box)
                    polys_color = []
                    for j in range(n_box):
                        act = 0.
                        if init_p[j, 2] == 0 and actions is not None:
                            act = actions[i, j]
                        polys_color.append(self._box_color(init_p[j, 2], act, act_scale))

                    circles = [Circle((states[i, j, 0], states[i, j, 1]), radius=0.1)
                               for j in anchored]
                    circles_color = ['orangered'] * len(anchored)

                    ax.add_collection(PatchCollection(
                        polys, facecolor=polys_color, linewidth=0, alpha=1.))
                    ax.add_collection(PatchCollection(
                        circles, facecolor=circles_color, linewidth=0, alpha=1.))

                    if states_gt is not None:
                        polys_gt = self._box_polygons(states_gt[i], n_box)
                        if gt_border:
                            pc_polys_gt = PatchCollection(
                                polys_gt, facecolor=(0., 0., 0., 0.), edgecolor='orangered', lw=1.)
                        else:
                            pc_polys_gt = PatchCollection(
                                polys_gt, facecolor=polys_color, linewidth=0, alpha=0.5)
                        # ground-truth anchors come from states_gt, not the prediction
                        circles_gt = [Circle((states_gt[i, j, 0], states_gt[i, j, 1]), radius=0.1)
                                      for j in anchored]
                        ax.add_collection(pc_polys_gt)
                        ax.add_collection(PatchCollection(
                            circles_gt, facecolor=circles_color, linewidth=0, alpha=0.5))

                    ax.set_aspect('equal')

                    if count_down:
                        plt.text(-5, 3, 'CountDown: %d' % (time_step - i - 1), fontdict=font)

                    plt.tight_layout()

                    if video:
                        fig.canvas.draw()
                        # buffer_rgba replaces the deprecated tostring_rgb / np.fromstring
                        frame = cv2.cvtColor(np.asarray(fig.canvas.buffer_rgba()),
                                             cv2.COLOR_RGBA2BGR)
                        # hold the last frame for 10 extra frames
                        for _ in range(11 if i == time_step - 1 else 1):
                            out.write(frame)

                    if image:
                        plt.savefig(os.path.join(image_path, 'fig_%s.png' % i), bbox_inches='tight')
                finally:
                    plt.close(fig)
        finally:
            # release the writer even if rendering a frame fails
            if out is not None:
                out.release()

class SwimEngine(Engine):

    def __init__(self, dt, state_dim, action_dim, param_dim,
                 num_box_range=(5, 10), k_range=(600, 800.)):
        """Set geometry constants, then let ``Engine.__init__`` call ``init()``.

        Args:
            dt: simulation time step.
            state_dim, action_dim, param_dim: forwarded to ``Engine``.
            num_box_range: bounds passed to ``rand_int`` when sampling n_box.
            k_range: bounds passed to ``rand_float`` when sampling the stiffness k.
        """
        # param = (n_box, k, damping, init_p); init_p rows are [x, y, type] with
        # type 0 - soft & actuated, 1 - soft, 2 - rigid.
        self.side_length = 1.
        # The defaults are tuples so instances never share one mutable object. They
        # are copied to lists to keep the attribute type the class always exposed.
        self.num_box_range = list(num_box_range)
        self.k_range = list(k_range)
        self.radius = 0.01
        self.mass = 1.

        super(SwimEngine, self).__init__(dt, state_dim, action_dim, param_dim)

    @property
    def num_obj(self):
        """Number of simulated objects, i.e. boxes in the current scene."""
        n_box = self.n_box
        return n_box

    def inside_lim(self, x, y, lim):
        """Return True when cell (x, y) is inside the half-open square [lim[0], lim[1])^2."""
        low, high = lim[0], lim[1]
        return low <= x < high and low <= y < high

    def sample_init_p(self):
        """Sample a random connected layout of ``self.n_box`` free-floating unit boxes.

        The first box lands on a random cell drawn with ``rand_int(-2, 2)`` per axis.
        Each later box takes a cell drawn from the frontier. The frontier is the set
        of free cells adjacent to the shape inside the [-4, 4)^2 grid. Every box gets
        a random type, and types are re-drawn until at least two boxes are actuated.

        Returns:
            ``(n_box, 3)`` float array of ``[x, y, type]`` rows.

        Raises:
            ValueError: if n_box < 2; the "at least two actuated" loop could never end.
        """
        n_box = self.n_box
        if n_box < 2:
            raise ValueError(
                'SwimEngine needs at least 2 boxes to place 2 actuators, got %d' % n_box)

        r_actuated, r_soft = 0.5, 0.25  # the remaining 0.25 is rigid
        lim = -4, 4
        # taken[x - lim[0], y - lim[0]] marks cells that are occupied or already queued
        taken = np.zeros((lim[1] - lim[0], lim[1] - lim[0]))

        def roll_type():
            # 0 - actuated, 1 - soft, 2 - rigid
            roll = np.random.rand()
            if roll < r_actuated:
                return 0
            if roll < r_actuated + r_soft:
                return 1
            return 2

        init_p = np.zeros((n_box, 3))
        frontier = []

        for i in range(n_box):
            init_p[i, 2] = roll_type()

            if frontier:
                x, y = frontier.pop(rand_int(0, len(frontier)))
            else:
                x = rand_int(lim[0] // 2, lim[1] // 2)
                y = rand_int(lim[0] // 2, lim[1] // 2)

            init_p[i, 0], init_p[i, 1] = x, y
            taken[x - lim[0], y - lim[0]] = 1
            # the neighbour order (+x, -x, +y, -y) fixes the frontier order
            for cx, cy in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                if self.inside_lim(cx, cy, lim) and taken[cx - lim[0], cy - lim[0]] == 0:
                    frontier.append([cx, cy])
                    taken[cx - lim[0], cy - lim[0]] = 1

        while (init_p[:, 2] == 0).sum() < 2:
            # fewer than two actuated boxes: re-draw every box type
            for i in range(n_box):
                init_p[i, 2] = roll_type()

        return init_p

    def calc_outside(self):
        """Fill ``self.outside[i, e]`` with 1 if edge e of box i is on the body's boundary.

        Edges are 0 - left, 1 - top, 2 - right, 3 - bottom. An edge is internal (0)
        when another box sits directly next to it.
        """
        # offset p_i - p_j of a side neighbour j -> edge of i that it covers
        covered_edge = {(1, 0): 0, (-1, 0): 2, (0, 1): 3, (0, -1): 1}

        self.outside = np.ones((self.n_box, 4))
        for a in range(self.n_box):
            for b in range(self.n_box):
                if a == b:
                    continue
                delta = self.init_p[a, :2] - self.init_p[b, :2]
                assert (np.abs(delta) > 0).any()

                edge = covered_edge.get((delta[0], delta[1]))
                if edge is not None:
                    self.outside[a, edge] = 0

    def init(self, param=None):
        """(Re)build the swimmer scene.

        Args:
            param: optional ``(n_box, k, damping, init_p)``. When omitted, everything
                is sampled. When given, n_box is cast to int, and any of k / damping /
                init_p that is None is sampled (damping defaults to k / 20).
        """
        if param is not None:
            self.n_box, self.k, self.damping, self.init_p = param
            self.n_box = int(self.n_box)
        else:
            self.n_box = self.k = self.damping = self.init_p = None

        # sample whatever is missing; the draw order is n_box, then k, then init_p
        if self.n_box is None:
            low, high = self.num_box_range[0], self.num_box_range[1]
            self.n_box = rand_int(low, high)
        if self.k is None:
            low, high = self.k_range[0], self.k_range[1]
            self.k = rand_float(low, high)
        if self.damping is None:
            self.damping = self.k / 20.
        if self.init_p is None:
            self.init_p = self.sample_init_p()

        space = pymunk.Space()
        space.gravity = (0., 0.)
        self.space = space

        # bodies first, then constraints, then the boundary-edge table
        self.add_masses()
        self.add_rels()
        self.calc_outside()

        # reset the cached previous state
        self.state_prv = None

    def add_masses(self):
        """Create four point-mass bodies per box, one at each corner of its unit square.

        Corner k of box i is ``self.balls[4 * i + k]``. The corner codes are
        k = 0 (-x,-y), 1 (-x,+y), 2 (+x,-y) and 3 (+x,+y), relative to the box centre.
        """
        inertia = pymunk.moment_for_circle(self.mass, 0, self.radius, (0, 0))
        self.balls = []

        half = self.side_length / 2.
        offsets = ((-half, -half), (-half, half), (half, -half), (half, half))

        for box in range(self.n_box):
            cx, cy, _ = self.init_p[box]
            for dx, dy in offsets:
                corner = pymunk.Body(self.mass, inertia)
                corner.position = Vec2d(cx + dx, cy + dy)
                self.space.add(corner)
                self.balls.append(corner)

    def add_rels(self):
        """Connect the corner bodies with constraints.

        Within each box, the four sides and then the two diagonals become damped
        springs (soft boxes, type <= 1) or pin joints (otherwise). Corners of
        different boxes that coincide (distance < 1e-4) are then pinned together.
        """
        side = self.side_length
        diag = self.side_length * np.sqrt(2)
        # (corner_a, corner_b, rest length): four sides first, then the two diagonals
        links = ((0, 1, side), (0, 2, side), (3, 1, side), (3, 2, side),
                 (0, 3, diag), (1, 2, diag))

        for box in range(self.n_box):
            base = 4 * box
            is_soft = self.init_p[box, 2] <= 1
            for a, b, rest in links:
                first, second = self.balls[base + a], self.balls[base + b]
                if is_soft:
                    joint = pymunk.DampedSpring(
                        first, second, (0, 0), (0, 0),
                        rest_length=rest, stiffness=self.k, damping=self.damping)
                else:
                    joint = pymunk.PinJoint(first, second, (0, 0), (0, 0))
                self.space.add(joint)

        # glue coincident corners of different boxes
        for box_i in range(self.n_box):
            for box_j in range(box_i):
                for ci in range(4):
                    for cj in range(4):
                        u, v = 4 * box_i + ci, 4 * box_j + cj
                        if calc_dis(self.balls[u].position, self.balls[v].position) < 1e-4:
                            self.space.add(
                                pymunk.PinJoint(self.balls[u], self.balls[v], (0, 0), (0, 0)))

    def add_force(self):
        """Apply this step's actuation forces, then the swimming drag.

        Actuation: each corner of an actuated (type 0) box is pushed along the unit
        vector towards the opposite corner, scaled by ``self.action[box]``.

        Drag: for actuated boxes whose action is negative, the corners of each
        boundary edge (``self.outside == 1``) are slowed. A corner moving along the
        edge normal ``(-d_y, d_x)`` gets a force proportional to
        velocity_component^2 * edge_length * 50 against that normal.
        """
        # (a, b) corner pairs for edges 0 - left, 1 - top, 2 - right, 3 - bottom
        edges = ((0, 1), (1, 3), (3, 2), (2, 0))

        for box in range(self.n_box):
            if self.init_p[box, 2] != 0:
                continue
            base = 4 * box
            for corner in range(4):
                src = self.balls[base + corner]
                dst = self.balls[base + 3 - corner]
                direct = np.array([dst.position[0] - src.position[0],
                                   dst.position[1] - src.position[1]])
                direct /= norm(direct)
                force = direct * self.action[box]
                src.apply_force_at_local_point(force=(force[0], force[1]), point=(0, 0))

        for box in range(self.n_box):
            if self.init_p[box, 2] != 0 or not self.action[box] < 0:
                continue
            base = 4 * box

            # snapshot of [x, y, vx, vy] per corner
            snap = np.zeros((4, 4))
            for corner in range(4):
                body = self.balls[base + corner]
                snap[corner] = [body.position[0], body.position[1],
                                body.velocity[0], body.velocity[1]]

            for edge, (a, b) in enumerate(edges):
                if self.outside[box, edge] != 1:
                    continue
                direct = snap[b, :2] - snap[a, :2]
                dist = norm(direct)
                direct /= dist
                normal = np.array([-direct[1], direct[0]])

                for corner in (a, b):
                    v_scale = np.dot(snap[corner, 2:], normal)
                    if v_scale > 0.:
                        f = - v_scale ** 2 * normal * dist * 50.
                        self.balls[base + corner].apply_force_at_local_point(
                            force=(f[0], f[1]), point=(0, 0))

    def get_param(self):
        """Return the scene parameters as the tuple ``(n_box, k, damping, init_p)``."""
        params = (self.n_box, self.k, self.damping, self.init_p)
        return params

    def get_state(self):
        """Return the smoothed per-box state as an ``(n_box, 16)`` float array.

        Layout of each row before the velocity override below: columns
        ``2k:2k+2`` hold the ``x, y`` position of corner ``k`` (``k = 0..3``)
        and columns ``8 + 2k:8 + 2k + 2`` hold that corner's velocity.

        Corners coincident with a neighbouring box are averaged. A neighbour
        is a box whose grid offset in ``init_p[:, :2]`` is a unit side step
        (e.g. ``(1, 0)``) or a unit diagonal step (e.g. ``(1, 1)``). Each such
        corner value becomes the mean over every box that shares it, so
        adjacent boxes report identical shared corners. Other offsets
        (farther away, or within one cell but not a unit step) are treated as
        non-contacts.

        Columns ``state_dim // 2:`` are then overwritten with the finite
        difference ``(positions - self.state_prv positions) / self.dt``, or
        with zeros when ``self.state_prv`` is ``None``.

        Raises:
            ValueError: if two boxes share the same grid coordinates in
                ``init_p[:, :2]``.
        """
        # Grid offset (dx, dy) of box i relative to box j ->
        # ((corner of i, coincident corner of j), ...).
        # With dx = +1, box i's corners 0/1 meet j's corners 2/3; with dy = +1,
        # box i's corners 0/2 meet j's corners 1/3. So in grid terms the
        # corners are 0=(low x, low y), 1=(low x, high y), 2=(high x, low y),
        # 3=(high x, high y).
        contact_pairs = {
            (1, 0): ((1, 3), (0, 2)),
            (-1, 0): ((3, 1), (2, 0)),
            (0, 1): ((0, 1), (2, 3)),
            (0, -1): ((1, 0), (3, 2)),
            (1, 1): ((0, 3),),
            (1, -1): ((1, 2),),
            (-1, 1): ((2, 1),),
            (-1, -1): ((3, 0),),
        }

        n_box = self.n_box
        state = np.zeros((n_box, 16))
        for i in range(n_box):
            for k in range(4):
                ball = self.balls[i * 4 + k]
                state[i, 2 * k:2 * k + 2] = ball.position[0], ball.position[1]
                state[i, 8 + 2 * k:8 + 2 * k + 2] = ball.velocity[0], ball.velocity[1]

        state_acc = state.copy()
        # Per-box divisor for the 8 position slots, shared by the 8 velocity slots.
        count = np.zeros((n_box, 1, 8))

        for i in range(n_box):
            for j in range(n_box):
                if i == j:
                    count[i, :, :] += 1
                    continue

                delta = self.init_p[i, :2] - self.init_p[j, :2]
                if not (np.abs(delta) > 0).any():
                    raise ValueError(
                        'boxes %d and %d share grid position %s' % (i, j, self.init_p[i, :2]))

                if (np.abs(delta) > 1).any():
                    # Not adjacent: no shared corners.
                    continue

                pairs = contact_pairs.get((delta[0], delta[1]))
                if pairs is None:
                    # Within one cell but not a unit side/diagonal step: no contact.
                    continue

                for mine, theirs in pairs:
                    a, b = 2 * mine, 2 * theirs
                    count[i, :, a:a + 2] += 1
                    state_acc[i, a:a + 2] += state[j, b:b + 2]
                    state_acc[i, a + 8:a + 10] += state[j, b + 8:b + 10]

        state_acc = (state_acc.reshape(n_box, 2, 8) / count).reshape(n_box, 16)

        vel_dim = self.state_dim // 2
        if self.state_prv is None:
            state_acc[:, vel_dim:] = 0
        else:
            # Velocity from consecutive smoothed positions (state_prv is set in step()).
            state_acc[:, vel_dim:] = (state_acc[:, :vel_dim] - self.state_prv[:, :vel_dim]) / self.dt

        return state_acc

    def step(self):
        """Advance the simulation by one ``dt``.

        Applies this step's forces, stores the current smoothed state in
        ``self.state_prv`` (which ``get_state`` uses to finite-difference
        velocities), then steps ``self.space`` by ``self.dt``.
        """
        self.add_force()
        snapshot = self.get_state()
        self.state_prv = snapshot
        self.space.step(self.dt)

    def render(self, states, actions=None, param=None, act_scale=10.,
               video=True, image=False, path=None, lim=(-6., 6., -7., 5.),
               states_gt=None, count_down=False, gt_border=False):
        """Draw a box trajectory frame by frame to an AVI video and/or PNG images.

        Args:
            states: per-step box states. Only the first 8 entries of the last
                axis (four corner ``x, y`` pairs per box) are used; assumed
                shape ``(time_step, n_box, >=8)``.
            actions: optional actions indexed as ``actions[i, j]`` (step ``i``,
                box ``j``), used to tint actuated boxes (``init_p[j, 2] == 0``).
                When ``None``, actuated boxes get the zero-action colour.
            param: scene parameters as returned by ``get_param``; only
                ``param[3]`` (``init_p``) is read. Falls back to
                ``self.init_p`` when ``None``.
            act_scale: action magnitude mapped to the ends of the colour blend.
            video: write ``path + '.avi'`` (MJPG, 25 fps, 640x480).
            image: write ``fig_<i>.png`` files into the ``path + '_img'`` directory.
            path: output path prefix; required when ``video`` or ``image`` is set.
            lim: axis limits ``(xmin, xmax, ymin, ymax)``.
            states_gt: optional reference trajectory with the same layout as
                ``states``, added to the axes after the main boxes.
            count_down: overlay the number of remaining frames.
            gt_border: draw ``states_gt`` as outlines instead of translucent fills.

        Raises:
            ValueError: if ``path`` is None while ``video`` or ``image`` is set.
            AssertionError: if ``init_p`` contains an unknown box type.
        """
        if (video or image) and path is None:
            raise ValueError('render() needs an output path when video or image is enabled')

        init_p = self.init_p if param is None else param[3]

        # End points of the action colour blend; converted to arrays so that
        # scaling by a float (or numpy scalar/array) is element-wise.
        rgb_hi = np.asarray(to_rgba('tomato')[:3])
        rgb_lo = np.asarray(to_rgba('limegreen')[:3])

        time_step = states.shape[0]
        n_ball = states.shape[1] * 4
        states = states[:, :, :8].reshape((time_step, n_ball, 2))
        if states_gt is not None:
            states_gt = states_gt[:, :, :8].reshape((time_step, n_ball, 2))

        font = {'family': 'serif',
                'color': 'darkred',
                'weight': 'normal',
                'size': 16}

        def box_polygons(pts):
            # One polygon per box; corners visited in the order 0, 1, 3, 2.
            return [Polygon(np.array([pts[j * 4], pts[j * 4 + 1],
                                      pts[j * 4 + 3], pts[j * 4 + 2]]))
                    for j in range(self.n_box)]

        def box_colors(i):
            # Face colour of every box at step i, chosen by its type in init_p[:, 2].
            colors = []
            for j in range(self.n_box):
                box_type = init_p[j, 2]
                if box_type == 0:
                    act = actions[i, j] if actions is not None else 0.
                    r = (act + act_scale) / (act_scale * 2)
                    if np.abs(r - 0.5) < 1e-4:
                        colors.append('cornflowerblue')
                    else:
                        colors.append(np.clip(rgb_hi * r + rgb_lo * (1. - r), 0., 1.))
                elif box_type == 1 or box_type == 3:
                    colors.append('lightsteelblue')
                elif box_type == 2:
                    colors.append('dimgray')
                else:
                    raise AssertionError("Unknown box type %f" % init_p[j, 2])
            return colors

        out = None
        try:
            if video:
                video_path = path + '.avi'
                fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
                print('Save video as %s' % video_path)
                out = cv2.VideoWriter(video_path, fourcc, 25, (640, 480))

            if image:
                image_path = path + '_img'
                print('Save images to %s' % image_path)
                os.makedirs(image_path, exist_ok=True)

            for i in range(time_step):
                fig, ax = plt.subplots(1)
                try:
                    plt.xlim(lim[0], lim[1])
                    plt.ylim(lim[2], lim[3])
                    plt.axis('off')

                    polys_color = box_colors(i)
                    ax.add_collection(PatchCollection(
                        box_polygons(states[i]), facecolor=polys_color, linewidth=0, alpha=1.))

                    if states_gt is not None:
                        polys_gt = box_polygons(states_gt[i])
                        if gt_border:
                            pc_polys_gt = PatchCollection(
                                polys_gt, facecolor=(0., 0., 0., 0.), edgecolor='orangered', lw=1.)
                        else:
                            pc_polys_gt = PatchCollection(
                                polys_gt, facecolor=polys_color, linewidth=0, alpha=0.5)
                        ax.add_collection(pc_polys_gt)

                    ax.set_aspect('equal')

                    if count_down:
                        plt.text(-7, 4, 'CountDown: %d' % (time_step - i - 1), fontdict=font)

                    plt.tight_layout()

                    if out is not None:
                        fig.canvas.draw()
                        # NOTE(review): tostring_rgb is deprecated in recent matplotlib
                        # releases; consider switching to buffer_rgba.
                        frame = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
                        frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (3,))
                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        out.write(frame)
                        if i == time_step - 1:
                            # Write the last frame 10 extra times so the video lingers on it.
                            for _ in range(10):
                                out.write(frame)

                    if image:
                        plt.savefig(os.path.join(image_path, 'fig_%s.png' % i), bbox_inches='tight')
                finally:
                    plt.close(fig)
        finally:
            if out is not None:
                out.release()


if __name__ == '__main__':
    # Manual smoke test: simulate one engine with random actions and render
    # the rollout (video + images) under ./test. ``--env`` selects the demo
    # (e.g. Rope, Soft, Swim).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', default='')
    args = parser.parse_args()

    # Output directory for rendered results (created by shelling out to mkdir).
    os.system('mkdir -p test')

    if args.env == 'Rope':
        dt = 1. / 50.
        state_dim = 4
        action_dim = 1
        param_dim = 5  # n_ball, init_x, k, damping, gravity

        # Each action is a random term in [-act_scale, act_scale) minus
        # ret_scale * states[i, 0, 0] (presumably ball 0's x position, which
        # would pull the rope back towards x = 0 -- confirm).
        act_scale = 2.
        ret_scale = 1.

        engine = RopeEngine(dt, state_dim, action_dim, param_dim)

        time_step = 300
        states = np.zeros((time_step, engine.n_ball, engine.state_dim))
        actions = np.zeros((time_step, engine.action_dim))

        for i in range(time_step):
            states[i] = engine.get_state()
            act = (np.random.rand() * 2. - 1.) * act_scale - states[i, 0, 0] * ret_scale
            engine.set_action(np.array([act]))
            engine.step()
            actions[i] = engine.get_action()

        engine.render(states, None, engine.get_param(), video=True, image=True, path='test/Rope')

    elif args.env == 'Soft':
        dt = 1. / 50.
        state_dim = 16
        action_dim = 1
        param_dim = 4  # n_box, k, damping, init_p

        # Actuated boxes follow a random walk: increments come from
        # rand_float(-act_delta, act_delta) and the result is clipped to
        # [-act_scale, act_scale].
        act_scale = 800.
        act_delta = 200.

        engine = SoftEngine(dt, state_dim, action_dim, param_dim)
        # NOTE(review): if SoftEngine's constructor already calls init() (as the
        # Engine base class does), this re-initializes the scene; confirm intended.
        engine.init()

        time_step = 100
        states = np.zeros((time_step, engine.n_box, state_dim))
        actions = np.zeros((time_step, engine.n_box, action_dim))

        for i in range(time_step):
            states[i] = engine.get_state()
            box_type = engine.init_p[:, 2]
            for j in range(engine.n_box):
                if box_type[j] == 0:
                    # if this is an actuated box
                    if i == 0:
                        actions[i, j] = rand_float(-act_delta, act_delta)
                    else:
                        actions[i, j] = actions[i - 1, j] + rand_float(-act_delta, act_delta)
                        actions[i, j] = np.clip(actions[i, j], -act_scale, act_scale)
                elif box_type[j] >= 1:
                    # if this is a soft box without actuation OR a rigid box
                    actions[i, j] = 0

            engine.set_action(actions[i])
            engine.step()
            # The engine is expected to store the action unchanged.
            assert np.array_equal(actions[i], engine.get_action())

        engine.render(states, None, engine.get_param(), act_scale=act_scale, video=True, image=True, path='test/Soft',
                      count_down=False)

    elif args.env == 'Swim':
        dt = 1. / 50.
        state_dim = 16
        action_dim = 1
        param_dim = 4  # n_box, k, damping, init_p

        act_scale = 600.
        act_delta = 300.

        engine = SwimEngine(dt, state_dim, action_dim, param_dim)

        # Action-generation mode; index 0 selects 'rand'. Only 'rotate' changes
        # the actions below; the tag is also used in the output path.
        tag = ['rand', 'forward', 'rotate'][0]
        for epoch in range(5):
            for num in [8]:
                init_p = sample_init_p_flight(num, epoch, True, train=False)
                engine.init(param=[num, None, None, init_p])

                # Disabled alternative setup (a no-op string expression):
                '''
                init_p = get_init_p_fish_8()
                engine.init(param=[8, None, None, init_p])
                '''

                time_step = 100
                states = np.zeros((time_step, engine.n_box, state_dim))
                actions = np.zeros((time_step, engine.n_box, action_dim))
                # Per-box [mode selector, period divisor, phase], drawn at step 0.
                actions_param = np.zeros((engine.n_box, 3))

                # Rollout-wide switch; when True every actuated box uses the sinusoid.
                sin_motion = np.random.rand() < 0.5

                for i in range(time_step):
                    states[i] = engine.get_state()
                    box_type = engine.init_p[:, 2]
                    for j in range(engine.n_box):
                        if box_type[j] == 0:
                            # if this is an actuated box
                            if i == 0:
                                actions_param[j] = np.array(
                                    [rand_float(0., 1.), rand_float(0.5, 4.), rand_float(0, np.pi * 2)])

                            # sin_motion is a bool; `== 0` is True when it is False.
                            if actions_param[j, 0] < 0.5 and sin_motion == 0:
                                # Bounded random walk: each action lies within
                                # act_delta of the previous one and within +/-act_scale.
                                if i == 0:
                                    actions[i, j] = rand_float(-act_delta, act_delta)
                                else:
                                    lo = max(actions[i - 1, j] - act_delta, -act_scale)
                                    hi = min(actions[i - 1, j] + act_delta, act_scale)
                                    actions[i, j] = rand_float(lo, hi)
                                    actions[i, j] = np.clip(actions[i, j], -act_scale, act_scale)
                            else:
                                # Sinusoid whose amplitude is re-drawn from
                                # [act_scale / 2, act_scale] at every step.
                                actions[i, j] = np.sin(i / actions_param[j, 1] + actions_param[j, 2]) * \
                                                rand_float(act_scale / 2., act_scale)

                            if tag == 'rotate':
                                # First half of the boxes: non-negative actions only;
                                # second half: non-positive actions only.
                                if j < engine.n_box // 2:
                                    if actions[i, j] < 0: actions[i, j] = 0
                                else:
                                    if actions[i, j] > 0: actions[i, j] = 0

                        elif box_type[j] >= 1:
                            # if this is a soft box without actuation OR a rigid box
                            actions[i, j] = 0

                    engine.set_action(actions[i])
                    engine.step()
                    # The engine is expected to store the action unchanged.
                    assert np.array_equal(actions[i], engine.get_action())

                os.system('mkdir -p test/swim_{}_train'.format(tag))
                engine.render(
                    states, None, engine.get_param(), act_scale=act_scale, video=True, image=True,
                    path='test/swim_{}_train/Swim_{}_{}'.format(tag, num, epoch), count_down=False)
