import numpy as np
import pickle

class Policy(object):
    """Base class holding the attributes shared by every policy type.

    ``policy_params`` must provide the keys ``'ob_dim'``, ``'ac_dim'`` and
    ``'env_name'``. Subclasses are expected to replace ``self.weights`` with
    real parameters and to implement ``get_action`` and ``update_weights``.
    """

    def __init__(self, policy_params):
        self.ob_dim = policy_params['ob_dim']
        self.ac_dim = policy_params['ac_dim']
        # Empty placeholder until a subclass initializes real parameters.
        self.weights = np.empty(0)
        self.env_name = policy_params['env_name']

    def update_weights(self, coords_to_perturb, new_weights, total_coordinates):
        """Add ``new_weights`` to the selected coordinates (subclass hook).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def get_weights(self):
        """Return the policy parameters (the stored object, not a copy)."""
        return self.weights

    def get_action(self, ob):
        """Map an observation to an action (subclass hook).

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def act(self, ob):
        """Alias for :meth:`get_action`.

        Concrete policies implement ``get_action``; forwarding here keeps the
        base-class interface consistent with them, so either name works.
        """
        return self.get_action(ob)

class LinearPolicy(Policy):
    """Linear policy computing ``np.dot(W, ob)`` with ``W`` of shape ``(ac_dim, ob_dim)``.

    Training perturbs blocks of individual weight coordinates. The coordinates
    not yet perturbed in the current sweep are tracked in
    ``unseen_coordinates`` as flat (row-major) indices into ``W``.
    """

    def __init__(self, policy_params):
        Policy.__init__(self, policy_params)

        weight_type = policy_params['weight_type']
        if weight_type == "random":
            # Our initial guess is random (standard normal).
            self.weights = np.random.randn(self.ac_dim, self.ob_dim)
        elif weight_type == "zeros":
            # Begin from all zeros.
            self.weights = np.zeros((self.ac_dim, self.ob_dim), dtype=np.float64)
        elif weight_type == "pre_trained":
            # Start from a previously trained policy.
            # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
            with open(f'./TrainedPolicies/{self.env_name}_finalModel.obj', 'rb') as policy_file:
                pre_trained_policy = pickle.load(policy_file)
            # save_weights pickles the whole policy object; also accept a plain
            # {"weights": ...} dict for checkpoints written in that format.
            if isinstance(pre_trained_policy, dict):
                self.weights = pre_trained_policy["weights"]
            else:
                self.weights = pre_trained_policy.weights
        else:
            raise NotImplementedError

        # Flat indices of every weight coordinate not yet perturbed.
        self.unseen_coordinates = np.arange(self.weights.shape[0] * self.weights.shape[1])

    def get_action(self, ob):
        """Return ``np.dot(self.weights, ob)``."""
        return np.dot(self.weights, ob)

    def update_weights(self, coords_to_perturb, new_weights, total_coordinates):
        """Add ``new_weights[k]`` to the weight at ``coords_to_perturb[k]``.

        Args:
            coords_to_perturb: sequence of ``[row, col]`` pairs, as returned by
                ``pick_coords``.
            new_weights: per-coordinate increments.
            total_coordinates: number of leading entries to apply.
        """
        for k in range(total_coordinates):
            row, col = coords_to_perturb[k]
            self.weights[row, col] += new_weights[k]

    def refresh_unseen_coordinates(self, coords_to_perturb):
        """Restart the sweep when fewer than ``coords_to_perturb`` coordinates remain unseen.

        Any coordinates left over from the previous sweep are discarded.
        """
        if len(self.unseen_coordinates) < coords_to_perturb:
            self.unseen_coordinates = np.arange(self.weights.shape[0] * self.weights.shape[1])

    def pick_coords(self, coords_to_perturb):
        """Sample ``coords_to_perturb`` distinct unseen coordinates and mark them seen.

        Returns:
            A list of ``[row, col]`` index pairs into ``self.weights``.
        """
        chosen = np.random.choice(self.unseen_coordinates, size=coords_to_perturb, replace=False)
        self.unseen_coordinates = np.setdiff1d(self.unseen_coordinates, chosen)

        # Convert flat row-major indices to (row, col) with exact integer division.
        n_cols = self.weights.shape[1]
        return [[int(flat) // n_cols, int(flat) % n_cols] for flat in chosen]

    def update_layer(self):
        """No-op: a linear policy has a single layer."""
        return None

    def save_weights(self):
        """Checkpoint: pickle this policy object to ``./TrainedPolicies/<env_name>_finalModel.obj``."""
        import os

        # Create the checkpoint directory if needed instead of failing on open().
        os.makedirs('./TrainedPolicies', exist_ok=True)
        with open(f'./TrainedPolicies/{self.env_name}_finalModel.obj', 'wb') as policy_file:
            pickle.dump(self, policy_file)

class LLPolicy(Policy):
    """Policy used for LunarLander experiments.

    The model is a network with one hidden layer of ``hl_size`` units:
    ``tanh(tanh(ob @ W1) @ W2)``. Here ``W1`` has shape ``(ob_dim, hl_size)``
    and ``W2`` has shape ``(hl_size, ac_dim)``. Coordinate perturbation works
    on one layer at a time, selected by ``current_layer``.
    """

    def __init__(self, policy_params):
        Policy.__init__(self, policy_params)

        weight_type = policy_params['weight_type']
        if weight_type == "random":
            self.weights = {
                'W1': np.random.randn(self.ob_dim, policy_params["hl_size"]),
                'W2': np.random.randn(policy_params["hl_size"], self.ac_dim),
            }
        elif weight_type == "zeros":
            self.weights = {
                'W1': np.zeros((self.ob_dim, policy_params["hl_size"])),
                'W2': np.zeros((policy_params["hl_size"], self.ac_dim)),
            }
        elif weight_type == "pre_trained":
            # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
            with open(f'./TrainedPolicies/{self.env_name}_finalModel.obj', 'rb') as policy_file:
                pre_trained_policy = pickle.load(policy_file)
            # save_weights pickles the whole policy object (not subscriptable),
            # so read `.weights`; still accept a plain {"weights": ...} dict.
            if isinstance(pre_trained_policy, dict):
                self.weights = pre_trained_policy["weights"]
            else:
                self.weights = pre_trained_policy.weights
        else:
            raise NotImplementedError

        self.current_layer = policy_params["start_layer"]

        # Per layer, enumerate every weight coordinate so it can be addressed
        # by a single integer index.
        self.coordinate_shapes = {}
        self.unseen_coordinates = {}
        for layer in self.weights:
            shape = self.weights[layer].shape
            # meshgrid's default 'xy' indexing enumerates coordinates in
            # column-major order; every index tuple still appears exactly once.
            grids = np.meshgrid(*[range(d) for d in shape])
            dims = np.stack([np.reshape(g, [-1]) for g in grids], -1)
            n_coords = dims.shape[0]
            self.coordinate_shapes[layer] = {
                'inp_shape': shape,
                'dims': dims,    # row k holds the index tuple for coordinate k
                'd': n_coords,
                'd_list': list(range(n_coords)),
            }
            self.unseen_coordinates[layer] = list(range(n_coords))

    def get_action(self, ob):
        """Return ``tanh(tanh(ob @ W1) @ W2)``.

        ``ob`` presumably has a trailing dimension of size ``ob_dim``
        (a single observation or a batch) -- confirm against callers.
        """
        hl = np.tanh(np.matmul(ob, self.weights['W1']))
        return np.tanh(np.matmul(hl, self.weights['W2']))

    def update_weights(self, coords_to_perturb, new_weights, total_coordinates):
        """Add ``new_weights[k]`` to coordinate ``coords_to_perturb[k]`` of the current layer.

        ``coords_to_perturb`` holds flat coordinate indices as returned by
        ``pick_coords``. Only the first ``total_coordinates`` entries are applied.
        """
        dims = self.coordinate_shapes[self.current_layer]['dims']
        layer_weights = self.weights[self.current_layer]
        for k in range(total_coordinates):
            idx = dims[coords_to_perturb[k]]
            layer_weights[idx[0], idx[1]] += new_weights[k]

    def refresh_unseen_coordinates(self, coords_to_perturb):
        """Restart the current layer's sweep when fewer than ``coords_to_perturb`` coordinates remain.

        Any coordinates left over from the previous sweep are discarded.
        """
        if len(self.unseen_coordinates[self.current_layer]) < coords_to_perturb:
            self.unseen_coordinates[self.current_layer] = list(self.coordinate_shapes[self.current_layer]['d_list'])

    def pick_coords(self, coords_to_perturb):
        """Sample ``coords_to_perturb`` distinct unseen flat indices of the current layer and mark them seen."""
        chosen = np.random.choice(self.unseen_coordinates[self.current_layer], coords_to_perturb, replace=False)
        self.unseen_coordinates[self.current_layer] = np.setdiff1d(self.unseen_coordinates[self.current_layer], chosen)
        return chosen

    def update_layer(self):
        """Switch perturbation to the other layer (any layer other than 'W1' switches to 'W1')."""
        self.current_layer = 'W2' if self.current_layer == 'W1' else 'W1'

    def save_weights(self):
        """Checkpoint: pickle this policy object to ``./TrainedPolicies/<env_name>_finalModel.obj``."""
        import os

        # Create the checkpoint directory if needed instead of failing on open().
        os.makedirs('./TrainedPolicies', exist_ok=True)
        with open(f'./TrainedPolicies/{self.env_name}_finalModel.obj', 'wb') as policy_file:
            pickle.dump(self, policy_file)