import torch
import torch.nn as nn
import numpy as np
from torch.distributions.normal import Normal

# Clipping bounds read by _activation(); setting one to None disables it.
clip_min_stddev = 0.1  # stddev lower clip
clip_max_stddev = 10  # stddev upper clip
clip_mean = 30  # mean is squashed into (-clip_mean, clip_mean)

class RNNModel(nn.Module):
    """LSTM that maps a sequence to one prediction taken from its last step.

    The LSTM is built with ``batch_first=False``, so inputs are laid out as
    ``(seq_len, batch, input_size)``. Only the output of the final time step
    is passed through a linear head ("many to one").
    """

    def __init__(self, input_size, output_size, hidden_dim, n_layers=1, device='cpu'):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # Stored for callers; this class does not move itself to ``device``.
        self.device = device

        self.rnn = nn.LSTM(
            input_size,
            hidden_dim,
            n_layers,
            bias=True,
            batch_first=False,
        )
        # Linear read-out applied to the final LSTM output.
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        """Predict the next embedding from a history of states and actions.

        Args:
            x: tensor of shape ``(H, B, input_size)``.

        Returns:
            Tensor of shape ``(B, output_size)``.
        """
        sequence_out, _ = self.rnn(x)
        last_step = sequence_out[-1, :, :]
        return self.fc(last_step)


def _activation(t):
    """Bound a ``(mean, log_std)`` pair for a diagonal Gaussian.

    ``mean`` is squashed into ``(-clip_mean, clip_mean)``. ``log_std`` is
    squashed in log space into ``(log(clip_min_stddev), log(clip_max_stddev))``,
    so ``exp(log_std)`` stays inside ``(clip_min_stddev, clip_max_stddev)``.
    A bound set to ``None`` leaves that side unbounded.

    Args:
        t: pair ``(mean, log_std)`` of tensors.

    Returns:
        ``[mean, log_std]`` after squashing.

    Raises:
        ValueError: if the stddev bounds do not satisfy
            ``clip_min_stddev < 1 < clip_max_stddev``. ``squash_to_range``
            needs ``low < 0 < high``, i.e. ``log(min) < 0 < log(max)``.
    """
    mean, log_std = t
    low = -np.inf if clip_mean is None else -clip_mean
    high = np.inf if clip_mean is None else clip_mean
    mean = squash_to_range(mean, low, high)

    # Squash directly in log space. Squashing exp(log_std) cannot enforce a
    # lower bound: exp(.) is always positive, so the negative branch of
    # squash_to_range is never taken. exp(.) can also underflow to 0, which
    # would make log(std) = -inf.
    low = -np.inf if clip_min_stddev is None else np.log(clip_min_stddev)
    high = np.inf if clip_max_stddev is None else np.log(clip_max_stddev)
    if not low < 0 < high:
        raise ValueError(
            "stddev clip bounds must satisfy clip_min_stddev < 1 < clip_max_stddev"
        )
    log_std = squash_to_range(log_std, low, high)
    return [mean, log_std]

def squash_to_range(t, low=-np.inf, high=np.inf):
    """Squash ``t`` smoothly into the range ``[low, high]``.

    Negative entries are mapped through ``-low * tanh(t / -low)`` and
    non-negative entries through ``high * tanh(t / high)``. An infinite
    bound leaves that side of the input unchanged. The bounds are expected
    to satisfy ``low < 0 < high``; this is not checked.
    """
    below_zero = t if low == -np.inf else -low * torch.tanh(t / (-low))
    at_or_above_zero = t if high == np.inf else high * torch.tanh(t / high)
    return torch.where(t < 0, below_zero, at_or_above_zero)

def tie_weights(src, trg):
    """Make ``trg`` share ``src``'s ``weight`` and ``bias`` parameters.

    After the call both modules hold the same Parameter objects, so an
    update through either one is visible to the other. Both modules must be
    of exactly the same type; this is checked with an ``assert``.
    """
    assert type(src) is type(trg)
    for attr in ('weight', 'bias'):
        setattr(trg, attr, getattr(src, attr))

def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    * ``nn.Linear``: orthogonal weight, zero bias.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: delta-orthogonal init. All
      kernel taps are zeroed, then the ``[mid, mid]`` tap gets an orthogonal
      matrix scaled by the ReLU gain. The bias is zeroed. Kernels must be
      square (asserted).

    Layers built with ``bias=False`` (``bias is None``) are supported. Any
    other module type is left untouched, so the function is safe to pass to
    ``module.apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        # Orthogonal init of the (out, in) slice at the central tap only.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class MLP_encoder(nn.Module):
    """MLP that encodes a flat observation into a diagonal Gaussian.

    The network is ``obs_shape[0] -> 256 -> 256 -> 2 * feature_num`` with
    ReLU activations. Its output is split in half along the last dimension
    into ``mu`` and ``logstd``. ``num_layers`` and ``num_filters`` are
    accepted for signature compatibility with the other encoders but are
    not used.
    """

    def __init__(self, obs_shape, feature_num=10, num_layers=2, num_filters=32, output_logits=False):
        super().__init__()

        self.obs_shape = obs_shape
        self.feature_num = feature_num
        # Stored for callers; forward() does not read it.
        self.output_logits = output_logits

        width = 256
        self.net = nn.Sequential(
            nn.Linear(obs_shape[0], width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 2 * feature_num),
        )
        # NOTE(review): registered (so it is part of state_dict) but never
        # applied in forward().
        self.ln = nn.LayerNorm(2 * self.feature_num)

    def forward(self, obs, output_mean=False):
        """Return ``[mu, logstd, state]`` for ``obs``.

        ``state`` is ``mu`` when ``output_mean`` is true; otherwise it is a
        reparameterized sample ``mu + eps * exp(logstd)``.
        """
        mu, logstd = self.net(obs).chunk(2, dim=-1)
        sample = self.reparameterize(mu, logstd, output_mean=output_mean)
        return [mu, logstd, sample]

    def reparameterize(self, mu, logstd, output_mean=False):
        """Return ``mu`` or the reparameterized sample ``mu + eps * exp(logstd)``."""
        std = logstd.exp()
        # The noise is drawn even when output_mean is set, so the global RNG
        # advances the same way in both modes.
        noise = torch.randn_like(std)
        return mu if output_mean else mu + noise * std

class Transition_model(nn.Module):
    """MLP that predicts a diagonal Gaussian over output features.

    The network is ``input_feature_num -> 256 -> 256 -> 2 * output_feature_num``
    with ReLU activations. If ``output_logits`` is true, a LayerNorm is
    applied to the raw output before it is split into ``mu`` and ``logstd``.
    ``num_layers`` is stored but does not change the network depth.
    """

    def __init__(self, input_feature_num, output_feature_num=50, num_layers=2, output_logits=False):
        super().__init__()
        self.num_layers = num_layers
        self.output_logits = output_logits

        self.net = nn.Sequential(
            nn.Linear(input_feature_num, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 2 * output_feature_num),
        )

        self.ln = nn.LayerNorm(2 * output_feature_num)

    def forward(self, obs_actions, detach=False):
        """Return ``[mu, logstd, dist]`` where ``dist = Normal(mu, exp(logstd))``.

        Args:
            obs_actions: input tensor whose last dimension is
                ``input_feature_num``.
            detach: if true, the raw network output is detached, so no
                gradient reaches ``net``. ``ln`` still receives gradients
                when ``output_logits`` is set.
        """
        output = self.net(obs_actions)

        if detach:
            output = output.detach()

        # Normalise only when it is actually used.
        if self.output_logits:
            output = self.ln(output)

        mu, logstd = output.chunk(2, dim=-1)
        # No sample is drawn here. Use reparameterize() explicitly if one is
        # needed; this keeps forward() from consuming the global RNG.
        dist = Normal(loc=mu, scale=torch.exp(logstd))
        return [mu, logstd, dist]

    def reparameterize(self, mu, logstd, output_mean=False):
        """Return ``mu`` or the reparameterized sample ``mu + eps * exp(logstd)``."""
        std = torch.exp(logstd)
        eps = torch.randn_like(std)

        if output_mean:
            output = mu
        else:
            output = mu + eps * std
        return output

class IdentityEncoder(nn.Module):
    """Pass-through encoder for 1-D observation vectors.

    ``feature_dim`` is taken from ``obs_shape[0]``. The ``feature_dim``,
    ``num_layers`` and ``num_filters`` arguments and any extra positional
    arguments are accepted for signature compatibility and ignored.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters, *args):
        super().__init__()
        assert len(obs_shape) == 1
        self.feature_dim = obs_shape[0]

    def forward(self, obs, output_mean=False):
        """Return ``(None, None, obs)``; there is no distribution to report."""
        no_stat = None
        return no_stat, no_stat, obs

    def copy_conv_weights_from(self, source):
        """No-op: this encoder has no conv weights."""
        return None

    def log(self, L, step, log_freq):
        """No-op: nothing to log."""
        return None

class PixelEncoder(nn.Module):
    """Placeholder for a pixel-observation encoder; not implemented yet.

    Construction succeeds and yields a valid (parameter-free) ``nn.Module``,
    but ``forward`` raises ``NotImplementedError``.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters, *args):
        # nn.Module.__init__ must run. Without it the instance lacks
        # _parameters/_modules, and even .parameters() or .to() fail with
        # AttributeError.
        super().__init__()
        self.obs_shape = obs_shape
        self.feature_dim = feature_dim

    def forward(self, obs, output_mean=False):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError("PixelEncoder is not implemented")




# Registry used by make_encoder(): encoder_type string -> encoder class.
_AVAILABLE_ENCODERS = dict(
    pixel=PixelEncoder,
    identity=IdentityEncoder,
    prop=MLP_encoder,
)


def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False
):
    """Instantiate the encoder registered under ``encoder_type``.

    Args:
        encoder_type: a key of ``_AVAILABLE_ENCODERS``.
        obs_shape, feature_dim, num_layers, num_filters, output_logits:
            passed positionally to the encoder class's constructor.

    Returns:
        The constructed encoder module.

    Raises:
        AssertionError: if ``encoder_type`` is not registered (when
            assertions are enabled).
    """
    assert encoder_type in _AVAILABLE_ENCODERS, (
        f"unknown encoder_type {encoder_type!r}; "
        f"expected one of {sorted(_AVAILABLE_ENCODERS)}"
    )
    encoder_cls = _AVAILABLE_ENCODERS[encoder_type]
    return encoder_cls(
        obs_shape, feature_dim, num_layers, num_filters, output_logits
    )