import pickle

import torch
import numpy as np
import tensorflow as tf

class DOPEPolicy:
    """Gaussian MLP policy restored from a pickled dictionary of weights.

    The network has two hidden layers (``fc0``, ``fc1``) followed by two
    parallel output heads: ``last_fc`` produces the action mean and
    ``last_fc_log_std`` produces the log standard deviation. Weight matrices
    are applied as ``x @ W.T + b``, i.e. they are stored as
    ``(out_features, in_features)``.

    The hidden activation is ``tanh`` when ``weights['nonlinearity']`` is
    ``'tanh'`` and ReLU otherwise. The output is squashed with ``tanh`` when
    ``weights['output_distribution']`` is ``'tanh_gaussian'`` and left
    untouched otherwise.
    """

    def __init__(self, policy_file, device="cpu"):
        """Load the policy weights from ``policy_file`` onto ``device``.

        Args:
            policy_file: Path to the pickled weight dictionary; opened with
                ``tf.io.gfile.GFile``.
            device: Anything accepted by ``torch.device`` (default ``"cpu"``).

        Raises:
            KeyError: If an expected entry is missing from the weight dict.
        """
        self.device = torch.device(device)

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only load policy files that come from a trusted source.
        with tf.io.gfile.GFile(policy_file, 'rb') as f:
            weights = pickle.load(f)

        def as_param(key):
            # Cast to float32 so the parameters always match the float32
            # states built in select_action (torch.mm rejects mixed dtypes).
            return torch.from_numpy(weights[key]).to(
                device=self.device, dtype=torch.float32)

        self.fc0_w = as_param('fc0/weight')
        self.fc0_b = as_param('fc0/bias')
        self.fc1_w = as_param('fc1/weight')
        self.fc1_b = as_param('fc1/bias')
        self.fclast_w = as_param('last_fc/weight')
        self.fclast_b = as_param('last_fc/bias')
        self.fclast_w_logstd = as_param('last_fc_log_std/weight')
        self.fclast_b_logstd = as_param('last_fc_log_std/bias')

        if weights['nonlinearity'] == 'tanh':
            self.nonlinearity = torch.tanh
        else:
            self.nonlinearity = torch.relu

        if weights['output_distribution'] == 'tanh_gaussian':
            self.output_transformation = torch.tanh
        else:
            self.output_transformation = self._identity

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged (used when the output is not squashed)."""
        return x

    def select_action(self, state, deterministic=False):
        """Compute an action for one state or a batch of states.

        Args:
            state: Array-like or tensor of shape ``(state_dim,)`` or
                ``(batch, state_dim)``. A 1-D input is treated as a batch
                of one.
            deterministic: If true, return the transformed mean. Otherwise
                sample ``mean + exp(logstd) * eps`` with ``eps ~ N(0, I)``
                before applying the output transformation.

        Returns:
            ``numpy.ndarray`` of shape ``(batch, action_dim)``.
        """
        # Convert first so numpy arrays, lists and tensors (on any device)
        # are all handled without a round trip through numpy.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        if state.dim() == 1:
            state = state.unsqueeze(0)

        # Pure inference: no_grad keeps .numpy() valid even when the incoming
        # state tensor requires grad.
        with torch.no_grad():
            x = self.nonlinearity(torch.mm(state, self.fc0_w.T) + self.fc0_b)
            x = self.nonlinearity(torch.mm(x, self.fc1_w.T) + self.fc1_b)
            mean = torch.mm(x, self.fclast_w.T) + self.fclast_b

            if deterministic:
                action = self.output_transformation(mean)
            else:
                logstd = (torch.mm(x, self.fclast_w_logstd.T)
                          + self.fclast_b_logstd)
                # Standard-normal noise; a constant tensor of ones would make
                # this branch deterministic and biased by exactly one std.
                noise = torch.randn_like(logstd)
                action = self.output_transformation(
                    mean + torch.exp(logstd) * noise)

        return action.cpu().numpy()
