import os.path as osp
import pickle

import numpy as np

# Absolute path of the directory that contains this file,
# e.g. root = ~/.../copo/copo/maour_environment
root = osp.dirname(osp.abspath(__file__))

# Module-level cache keyed by checkpoint path; PolicyFunction stores the
# weights it loads here so that a checkpoint already seen is not read again.
_checkpoints_buffers = {}


def relu(x):
    """Rectified linear unit: replace every negative entry of ``x`` with 0.

    Values that are already non-negative are returned unchanged; there is
    no upper bound.
    """
    lower_bound = 0
    return np.clip(x, a_min=lower_bound, a_max=None)

def sac_policy(weights, obs, deterministic=False):
    """Run the SAC actor network in NumPy and return tanh-squashed actions.

    The network is an MLP with two ReLU hidden layers (``action_1`` and
    ``action_2``) followed by a linear output layer (``action_out``). The
    output is split in half along the feature axis: the first half is the
    mean and the second half the log standard deviation of a Gaussian.

    Args:
        weights: Mapping that holds the ``kernel`` and ``bias`` arrays of the
            three layers under the ``default_policy/sequential/`` prefix.
        obs: A single observation (1-D) or a batch of observations (2-D,
            one row per observation). A 1-D input is treated as a batch of
            one.
        deterministic: If True, the Gaussian mean is used as the pre-squash
            action. Otherwise the action is sampled with
            ``np.random.normal``, so the result depends on NumPy's global
            random state.

    Returns:
        A 2-D array with one row of squashed actions per observation.

    Raises:
        AssertionError: If ``obs`` is neither 1-D nor 2-D.
    """
    obs = np.asarray(obs)
    if obs.ndim == 1:
        # Promote a single observation to a batch of one.
        obs = obs[np.newaxis, :]
    assert obs.ndim == 2, "obs must be 1-D or 2-D, got shape {}".format(obs.shape)

    prefix = "default_policy/sequential/"

    # Two hidden layers: affine transform followed by ReLU.
    x = obs
    for layer in ("action_1", "action_2"):
        x = np.matmul(x, weights[prefix + layer + "/kernel"]) + \
            weights[prefix + layer + "/bias"]
        x = np.clip(x, 0, None)

    # Linear output layer producing [mean | log_std] side by side.
    x = np.matmul(x, weights[prefix + "action_out/kernel"]) + \
        weights[prefix + "action_out/bias"]
    mean, log_std = np.split(x, 2, axis=1)

    if deterministic:
        # The std is not needed here; skipping exp() avoids wasted work and
        # spurious overflow warnings when log_std is large.
        action = mean
    else:
        # NOTE(review): log_std is used unclipped. If the distribution used
        # during training clipped it, sampled actions may differ - confirm.
        action = np.random.normal(mean, np.exp(log_std))

    # tanh is mapped to [0, 1] and then rescaled to [-1, 1]; algebraically
    # this equals tanh(action).
    squashed = ((np.tanh(action) + 1.0) / 2.0) * 2 - 1
    return squashed


def read_weight(ckpt_path, remove_value_network=True):
    """Load the ``default_policy`` weights stored in a pickled checkpoint.

    The file must unpickle to a mapping whose ``"worker"`` entry is itself a
    pickled mapping; the weights are taken from
    ``worker["state"]["default_policy"]``. An ``"_optimizer_variables"``
    entry, if present, is always discarded.

    NOTE: unpickling can execute arbitrary code, so only checkpoint files
    from a trusted source should be passed in.

    Args:
        ckpt_path: Path of the checkpoint file.
        remove_value_network: If True, also drop every entry whose key
            contains one of ``twin_q``, ``cost_q``, ``q_hidden``, ``q_out``,
            ``value`` or ``alpha``.

    Returns:
        A dict mapping weight names to their stored values.
    """
    with open(ckpt_path, "rb") as ckpt_file:
        checkpoint = pickle.load(ckpt_file)

    # The worker state is stored as a nested pickle inside the checkpoint.
    worker_state = pickle.loads(checkpoint["worker"])["state"]
    policy_weights = worker_state["default_policy"]
    policy_weights.pop("_optimizer_variables", None)

    if not remove_value_network:
        return policy_weights

    excluded_tags = ("twin_q", "cost_q", "q_hidden", "q_out", "value", "alpha")
    return {
        name: value
        for name, value in policy_weights.items()
        if not any(tag in name for tag in excluded_tags)
    }


class PolicyFunction:
    """Callable policy backed by the weights of one checkpoint file.

    Weights are loaded with ``read_weight`` the first time a checkpoint is
    requested and kept in the module-level ``_checkpoints_buffers`` dict, so
    later instances built from the same ``ckpt`` reuse the loaded weights.
    """

    def __init__(self, ckpt):
        # Populate the shared cache on first use, then always read from it.
        if ckpt not in _checkpoints_buffers:
            _checkpoints_buffers[ckpt] = read_weight(ckpt)
        self.w = _checkpoints_buffers[ckpt]
        self.ckpt = ckpt

    def policy(self, obs, deterministic=False):
        """Evaluate ``sac_policy`` on ``obs`` with this instance's weights."""
        return sac_policy(self.w, obs, deterministic=deterministic)

    def __call__(self, obs, deterministic=False):
        """Return the actions for ``obs``; delegates to ``policy``."""
        return self.policy(obs, deterministic=deterministic)

    def reset(self):
        """Do nothing; this policy keeps no per-episode state."""
        return None
