import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import print
import numpy as np

class Scale(nn.Module):
    """Multiply the input by a fixed factor.

    Args:
        scale: Multiplier applied to every input (default 1.0, i.e. identity).
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        factor = self.scale
        return x * factor

class ClipWithGrad(nn.Module):
    """Clamp to ``[-clip, clip]`` in the forward pass, identity gradient backward.

    The forward output equals the clamped input. The correction term is
    detached, so gradients reach ``x`` unchanged, even for entries that were
    clamped (a straight-through estimator).

    Args:
        clip: Symmetric bound of the clamp (default 1.0).
    """

    def __init__(self, clip=1.0):
        super().__init__()
        self.clip = clip

    def forward(self, x):
        bound = self.clip
        clamped = torch.clamp(x.detach(), min=-bound, max=bound)
        # Adding a detached offset shifts the value without touching the graph.
        correction = (clamped - x).detach()
        return x + correction

class MLPControl(nn.Module):
    """Actor-critic MLP with a Gaussian policy head and a scalar value head.

    Two independent MLPs read the same state: ``self.policy`` produces the
    action mean and ``self.value`` a scalar value estimate. The policy's log
    standard deviation is state-independent. It is stored as
    ``self.policy.logstd`` with shape ``(1, nact)``.

    Args:
        nstate: State (input) dimension.
        nact: Action dimension.
        nhidden: Hidden-layer widths, used for both heads.
        activation: Activation module class inserted after every hidden layer.
        nvf: Not used by this class; kept for interface compatibility.
        pretrain: Optional path to a saved ``state_dict`` to load after
            construction.
        clip_action: If a positive number, clamp the action mean to
            ``[-clip_action, clip_action]`` with a straight-through gradient.
            ``0``/``None``/``False`` disables clamping.
        clip_value: If a positive number, clamp the value output the same way.
            Reward shaping may push returns beyond the +/-1 range, but the
            discount factor shrinks that range, so 1.0 is probably also safe.
        init_std: Initial action standard deviation (``None`` means 1.0). When
            ``pretrain`` is given and this is not ``None``, the loaded log-std
            is shifted so that its mean equals ``log(init_std)``.
        trainable_std: If True, the log-std is a trainable ``nn.Parameter``;
            otherwise it is a fixed (but saved) buffer.
    """

    def __init__(self, nstate=120, nact=8, nhidden=(64, 64),
            activation=nn.Tanh, nvf=1, pretrain=None,
            clip_action=1.5, clip_value=1.1, init_std=1.0,
            trainable_std=True):
        super().__init__()

        # The policy is built before the value head, and layers are created in
        # the same order as before. This keeps RNG consumption (initial weights
        # under a fixed seed) and state_dict keys ("policy.0.weight", ...)
        # unchanged.
        policy_layers = self._build_mlp(nstate, nhidden, nact, activation,
                                        gain=0.1)  # gain=1.0, 0.5, 0.1?
        if self._clip_enabled(clip_action):
            policy_layers.append(ClipWithGrad(clip_action))
        self.policy = nn.Sequential(*policy_layers)

        value_layers = self._build_mlp(nstate, nhidden, 1, activation,
                                       gain=0.01)
        if self._clip_enabled(clip_value):
            value_layers.append(ClipWithGrad(clip_value))
        self.value = nn.Sequential(*value_layers)

        # Actions are expected in roughly [-1, 1] with an initial mean near 0,
        # so an initial std of about 0.5~1.0 is reasonable.
        std0 = 1. if init_std is None else init_std
        logstd = torch.full((1, nact), std0).log()
        if trainable_std:
            self.policy.logstd = nn.Parameter(logstd)
        else:
            self.policy.register_buffer("logstd", logstd)

        if pretrain:
            self._load_pretrained(pretrain, init_std)

    @staticmethod
    def _clip_enabled(clip):
        """Return True when ``clip`` is a positive bound (None/0/negative disable)."""
        return clip is not None and clip > 0

    @staticmethod
    def _build_mlp(nin, nhidden, nout, activation, gain):
        """Return the layer list of an MLP whose output layer is orthogonally
        initialised with ``gain`` and has a zero bias."""
        layers = []
        width = nin
        for nh in nhidden:
            layers.append(nn.Linear(width, nh))
            layers.append(activation())
            width = nh
        head = nn.Linear(width, nout)
        nn.init.orthogonal_(head.weight, gain=gain)
        nn.init.constant_(head.bias, 0)
        layers.append(head)
        return layers

    def _load_pretrained(self, path, init_std):
        """Load a state_dict from ``path`` and optionally re-centre the log-std."""
        print (f"load pretrain from {path}")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies into this module's own tensors anyway.
        self.load_state_dict(torch.load(path, map_location="cpu"))
        logstd = self.policy.logstd
        if init_std is not None:
            target = np.log(init_std)
            # Shift every entry by the same amount so the mean becomes
            # log(init_std) while the per-dimension differences are kept.
            logstd.data += target - logstd.data.mean()
            print (f'set logstd {logstd} mean {target}')
        else:
            print (f'logstd {logstd}')

    def forward(self, s, av=3):
        """Evaluate the policy and/or value heads.

        Args:
            s: State tensor whose last dimension is ``nstate``. Any number of
                leading batch dimensions is accepted, including none.
            av: Bitmask selecting the outputs: bit value 1 = action,
                bit value 2 = value (default 3 = both).

        Returns:
            ``(mean, logstd)`` if only the action is requested. ``values`` if
            only the value is requested. ``((mean, logstd), values)`` if both
            are requested. ``logstd`` is broadcast to ``mean``'s shape, and
            ``values`` has its trailing singleton dimension squeezed.

        Raises:
            ValueError: if ``av`` requests neither output.
        """
        want_action = bool(av & 1)
        want_value = bool(av & 2)
        if not (want_action or want_value):
            raise ValueError(
                f"av must select action (1), value (2) or both (3); got {av!r}")

        if want_action:
            mean = self.policy(s)
            # logstd is (1, nact): drop the leading 1 and broadcast over all of
            # mean's batch dimensions, so unbatched and multi-dim batches work.
            act_mean_std = (mean, self.policy.logstd.squeeze(0).expand_as(mean))
        if want_value:
            values = self.value(s).squeeze(-1)

        if want_action and want_value:
            return act_mean_std, values
        return act_mean_std if want_action else values
