# '''
# @Author: 
# @Email: 
# @Date: 2020-06-09 23:39:36
# LastEditTime: 2022-12-27 20:00:52
# @Description: 
# '''

import numpy as np
import torch
import time

from .optimizer import Optimizer

class RandomOptimizer_Parallel(Optimizer):
    """Random-shooting optimizer with per-environment candidate batches.

    Every ``obtain_solution_*`` method draws ``popsize`` random action
    sequences of length ``horizon``, scores them with the cost function
    registered through :meth:`setup`, and returns the lowest-cost sequence(s).

    The ``unlock``, ``crash`` and ``lift`` variants sample one population per
    environment (leading axis of size ``num_envs``) and pick one winner per
    environment.  The ``tower`` and ``chemistry`` variants sample a single
    un-batched population and return one sequence.
    """

    def __init__(self, num_envs, action_dim, horizon, popsize):
        """
        :param num_envs: Number of environments sampled in parallel.
        :param action_dim: Size of the action vector (read by ``obtain_solution_lift``).
        :param horizon: Number of time steps in each sampled action sequence.
        :param popsize: Number of candidate sequences sampled per call.
        """
        super().__init__()
        self.num_envs = num_envs
        self.horizon = horizon
        self.popsize = popsize
        self.action_dim = action_dim
        self.solution = None
        self.cost_function = None

    def setup(self, cost_function):
        """Register the callable used to score sampled action sequences."""
        self.cost_function = cost_function

    def reset(self):
        """No-op: this optimizer keeps no state between calls."""
        pass

    @staticmethod
    def _one_hot(indices, num_classes):
        """Append a one-hot axis of length ``num_classes`` to an integer array."""
        return (np.arange(num_classes) == indices[..., None]).astype(int)

    @staticmethod
    def _best_per_env(candidates, costs):
        """Pick, for every row of ``costs``, the candidate with the lowest cost.

        ``costs`` is indexed as ``costs[env, candidate]`` and ``candidates`` as
        ``candidates[env, candidate, ...]``.
        """
        env_index = np.arange(0, costs.shape[0])
        return candidates[env_index, np.argmin(costs, axis=1)]

    def obtain_solution_tower(self, *args, **kwargs):
        """Sample one-hot (color, shape) plus a binary stop flag; return the best sequence.

        Positional and keyword arguments are accepted but ignored.

        :return: Array of shape ``(horizon, 7)``: 4 color slots, 2 shape slots
            and 1 stop flag.
        """
        color_dim = 5
        shape_dim = 3
        # NOTE(review): indices are drawn from ``dim - 1`` classes, so the last
        # color/shape id is never sampled -- confirm this is intentional.
        color = np.random.randint(0, color_dim - 1, size=(self.popsize, self.horizon))
        color = self._one_hot(color, color_dim - 1)
        shape = np.random.randint(0, shape_dim - 1, size=(self.popsize, self.horizon))
        shape = self._one_hot(shape, shape_dim - 1)
        stop = np.random.randint(0, 2, size=(self.popsize, self.horizon, 1))

        # NOTE(review): this population is not batched over ``num_envs``.
        action = np.concatenate([color, shape, stop], axis=2)

        costs = self.cost_function(action)
        return action[np.argmin(costs)]

    def generate_one_action(self, low, high, size):
        """Sample integer actions in ``[low, high)`` with torch and one-hot encode them.

        The tensor is created on CUDA when available, otherwise on the CPU.

        :param low: Inclusive lower bound of the sampled indices.
        :param high: Exclusive upper bound of the sampled indices; also the
            width of the one-hot axis.
        :param size: Shape of the index tensor before encoding.
        :return: ``int64`` tensor of shape ``(*size, high)``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        indices = torch.randint(low, high, size=tuple(size), device=device)
        # Pin the width explicitly: without ``num_classes`` the width is
        # inferred from the largest sampled index and may vary between calls.
        return torch.nn.functional.one_hot(indices, num_classes=high)

    def obtain_solution_unlock(self, *args, **kwargs):
        """Sample one-hot (move, pick_key, open_door) actions; return the best per environment.

        Positional and keyword arguments are accepted but ignored.

        :return: Array of shape ``(costs.shape[0], horizon, 8)``: 4 move slots,
            2 pick-key slots and 2 open-door slots.
        """
        move_dim = 4
        key_dim = 2
        door_dim = 2
        batch_shape = (self.num_envs, self.popsize, self.horizon)

        move = self._one_hot(np.random.randint(0, move_dim, size=batch_shape), move_dim)
        pick_key = self._one_hot(np.random.randint(0, key_dim, size=batch_shape), key_dim)
        open_door = self._one_hot(np.random.randint(0, door_dim, size=batch_shape), door_dim)

        action = np.concatenate([move, pick_key, open_door], axis=3)

        costs = self.cost_function(action)
        return self._best_per_env(action, costs)

    def obtain_solution_crash(self, action_dim):
        """Sample uniform actions in ``[-1, 1)`` with every second dimension zeroed.

        :param action_dim: Size of the last axis of the sampled actions.
        :return: Array of shape ``(costs.shape[0], horizon, action_dim)``.
        """
        # Uniform sampling is used to increase diversity.
        actions = np.random.uniform(
            -1, 1, size=(self.num_envs, self.popsize, self.horizon, action_dim)
        )

        # Zero the steering components, i.e. action indices 1, 3, 5, ...
        actions[..., 1::2] = 0

        # The cost function returns three values here; only the first is used.
        costs, _, _ = self.cost_function(actions)
        return self._best_per_env(actions, costs)

    def obtain_solution_chemistry(self, action_dim):
        """Sample one-hot actions over ``action_dim`` classes; return the best sequence.

        :param action_dim: Number of discrete actions (width of the one-hot axis).
        :return: Array of shape ``(horizon, action_dim)``.
        """
        # NOTE(review): this population is not batched over ``num_envs``.
        action = np.random.randint(0, action_dim, size=(self.popsize, self.horizon))
        action = self._one_hot(action, action_dim)
        costs = self.cost_function(action)
        return action[np.argmin(costs)]

    def obtain_solution_lift(self, low, high, oracle_action=None):
        """Sample continuous motion plus a discrete gripper command; return the best per environment.

        The first ``self.action_dim - 1`` components are drawn uniformly from
        ``[low, high)``; the last component is drawn from ``{-1, 0, 1}``.

        :param low: Lower bound of the continuous components.
        :param high: Upper bound of the continuous components.
        :param oracle_action: Accepted for call compatibility; not used.
        :return: Array of shape ``(costs.shape[0], horizon, self.action_dim)``.
        """
        # Uniform sampling is used to increase diversity.
        actions_xyz = np.random.uniform(
            low,
            high,
            size=(self.num_envs, self.popsize, self.horizon, self.action_dim - 1),
        )
        actions_gripper = (
            np.random.randint(0, 3, size=(self.num_envs, self.popsize, self.horizon, 1)) - 1
        )
        actions = np.concatenate([actions_xyz, actions_gripper], axis=3)

        costs = self.cost_function(actions)
        return self._best_per_env(actions, costs)
# '''
# @Author: 
# @Email: 
# @Date: 2020-06-09 23:39:36
# LastEditTime: 2022-12-27 20:00:52
# @Description: 
# '''

# import numpy as np
# import torch
# import time

# from .optimizer import Optimizer

class RandomOptimizer(Optimizer):
    """Random-shooting optimizer for a single environment.

    Every ``obtain_solution_*`` method draws ``popsize`` random action
    sequences of length ``horizon``, scores them with the cost function
    registered through :meth:`setup`, and returns the sequence at the
    ``argmin`` of the returned costs.
    """

    def __init__(self, action_dim, horizon, popsize):
        """
        :param action_dim: Size of the action vector (read by ``obtain_solution_lift``).
        :param horizon: Number of time steps in each sampled action sequence.
        :param popsize: Number of candidate sequences sampled per call.
        """
        super().__init__()
        self.horizon = horizon
        self.popsize = popsize
        self.action_dim = action_dim
        self.solution = None
        self.cost_function = None

    def setup(self, cost_function):
        """Register the callable used to score sampled action sequences."""
        self.cost_function = cost_function

    def reset(self):
        """No-op: this optimizer keeps no state between calls."""
        pass

    @staticmethod
    def _one_hot(indices, num_classes):
        """Append a one-hot axis of length ``num_classes`` to an integer array."""
        return (np.arange(num_classes) == indices[..., None]).astype(int)

    def obtain_solution_tower(self, *args, **kwargs):
        """Sample one-hot (color, shape) plus a binary stop flag; return the best sequence.

        Positional and keyword arguments are accepted but ignored.

        :return: Array of shape ``(horizon, 7)``: 4 color slots, 2 shape slots
            and 1 stop flag.
        """
        color_dim = 5
        shape_dim = 3
        # NOTE(review): indices are drawn from ``dim - 1`` classes, so the last
        # color/shape id is never sampled -- confirm this is intentional.
        color = np.random.randint(0, color_dim - 1, size=(self.popsize, self.horizon))
        color = self._one_hot(color, color_dim - 1)
        shape = np.random.randint(0, shape_dim - 1, size=(self.popsize, self.horizon))
        shape = self._one_hot(shape, shape_dim - 1)
        stop = np.random.randint(0, 2, size=(self.popsize, self.horizon, 1))

        action = np.concatenate([color, shape, stop], axis=2)

        costs = self.cost_function(action)
        return action[np.argmin(costs)]

    def generate_one_action(self, low, high, size):
        """Sample integer actions in ``[low, high)`` with torch and one-hot encode them.

        The tensor is created on CUDA when available, otherwise on the CPU.

        :param low: Inclusive lower bound of the sampled indices.
        :param high: Exclusive upper bound of the sampled indices; also the
            width of the one-hot axis.
        :param size: Shape of the index tensor before encoding.
        :return: ``int64`` tensor of shape ``(*size, high)``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        indices = torch.randint(low, high, size=tuple(size), device=device)
        # Pin the width explicitly: without ``num_classes`` the width is
        # inferred from the largest sampled index and may vary between calls.
        return torch.nn.functional.one_hot(indices, num_classes=high)

    def obtain_solution_unlock(self, *args, **kwargs):
        """Sample one-hot (move, pick_key, open_door) actions; return the best sequence.

        Positional and keyword arguments are accepted but ignored.

        :return: Array of shape ``(horizon, 8)``: 4 move slots, 2 pick-key
            slots and 2 open-door slots.
        """
        move_dim = 4
        key_dim = 2
        door_dim = 2
        batch_shape = (self.popsize, self.horizon)

        move = self._one_hot(np.random.randint(0, move_dim, size=batch_shape), move_dim)
        pick_key = self._one_hot(np.random.randint(0, key_dim, size=batch_shape), key_dim)
        open_door = self._one_hot(np.random.randint(0, door_dim, size=batch_shape), door_dim)
        action = np.concatenate([move, pick_key, open_door], axis=2)

        costs = self.cost_function(action)
        return action[np.argmin(costs)]

    def obtain_solution_crash(self, action_dim):
        """Sample uniform actions in ``[-1, 1)`` with every second dimension zeroed.

        :param action_dim: Size of the last axis of the sampled actions.
        :return: Array of shape ``(horizon, action_dim)``.
        """
        # Uniform sampling is used to increase diversity.
        actions = np.random.uniform(-1, 1, size=(self.popsize, self.horizon, action_dim))

        # Zero the steering components, i.e. action indices 1, 3, 5, ...
        actions[..., 1::2] = 0

        # The cost function returns three values here; only the first is used.
        costs, _, _ = self.cost_function(actions)
        return actions[np.argmin(costs)]

    def obtain_solution_chemistry(self, action_dim):
        """Sample one-hot actions over ``action_dim`` classes; return the best sequence.

        :param action_dim: Number of discrete actions (width of the one-hot axis).
        :return: Array of shape ``(horizon, action_dim)``.
        """
        action = np.random.randint(0, action_dim, size=(self.popsize, self.horizon))
        action = self._one_hot(action, action_dim)
        costs = self.cost_function(action)
        return action[np.argmin(costs)]

    def obtain_solution_lift(self, low, high):
        """Sample continuous motion plus a discrete gripper command; return the best sequence.

        The first ``self.action_dim - 1`` components are drawn uniformly from
        ``[low, high)``; the last component is drawn from ``{-1, 0, 1}``.

        :param low: Lower bound of the continuous components.
        :param high: Upper bound of the continuous components.
        :return: Array of shape ``(horizon, self.action_dim)``.
        """
        # Uniform sampling is used to increase diversity.
        actions_xyz = np.random.uniform(
            low, high, size=(self.popsize, self.horizon, self.action_dim - 1)
        )
        actions_gripper = np.random.randint(0, 3, size=(self.popsize, self.horizon, 1)) - 1
        actions = np.concatenate([actions_xyz, actions_gripper], axis=2)

        costs = self.cost_function(actions)
        return actions[np.argmin(costs)]