
import torch
import torch.nn as nn

from rlpyt.models.conv2d import Conv2dModel
from rlpyt.models.mlp import MlpModel
from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.

    Intended to be passed to ``nn.Module.apply``. Modules that are neither
    ``nn.Linear`` nor ``nn.Conv2d`` / ``nn.ConvTranspose2d`` are left untouched.

    * Linear: orthogonal weight, zero bias.
    * Conv2d / ConvTranspose2d: all weights zero except the central spatial
      tap, which receives an orthogonal matrix scaled by the ReLU gain
      (delta-orthogonal init); zero bias.

    Layers constructed with ``bias=False`` (``m.bias is None``) are supported;
    only their weights are initialised.

    Raises:
        AssertionError: if a convolution kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        # Only the centre of the kernel is non-zero; the slice is a view, so
        # the in-place init writes straight into the layer's weight.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


class CurlSacEncoderModel(nn.Module):
    """Observation encoder made of a conv trunk (``self.conv``) and a head
    (``self.head``) that maps the conv features to a latent vector.

    The two parts are kept as separate sub-modules for easier gradient
    control. Either ``observation_shape`` (to build a default
    ``EncoderConvModel``) or a ready-made ``conv_model`` must be supplied.
    ``hidden_sizes`` is accepted but not used by this class.
    """

    def __init__(
            self,
            latent_size,
            observation_shape=None,
            conv_model=None,
            channels=None,
            kernel_sizes=None,
            strides=None,
            paddings=None,
            hidden_sizes=None,
            layer_norm=True,
            tanh_output=False,
            ):
        super().__init__()
        assert not (observation_shape is None and conv_model is None)
        if conv_model is None:
            # Fall back to the default 4-layer configuration for any
            # architecture argument that was not given (or is empty).
            conv_kwargs = dict(
                channels=channels if channels else [32, 32, 32, 32],
                kernel_sizes=kernel_sizes if kernel_sizes else [3, 3, 3, 3],
                strides=strides if strides else [2, 1, 1, 1],
                paddings=paddings,
            )
            conv_model = EncoderConvModel(
                observation_shape=observation_shape, **conv_kwargs)
        self.conv = conv_model
        self.head = EncoderHeadModel(
            input_size=self.conv.output_size,
            latent_size=latent_size,
            layer_norm=layer_norm,
            tanh_output=tanh_output,
        )
        self.apply(weight_init)

    def forward(self, observation, prev_action=None, prev_reward=None):
        """Encode an observation of shape [c,h,w], [B,c,h,w] or [T,B,c,h,w].

        ``prev_action`` and ``prev_reward`` are accepted for interface
        compatibility and ignored. The leading dimensions of the input are
        restored on the returned latent.
        """
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        flat_obs = observation.view(T * B, *img_shape)
        features = self.conv(flat_obs, infer_dims=False)
        latent = self.head(features, infer_dims=False)
        return restore_leading_dims(latent, lead_dim, T, B)

    @property
    def output_size(self):
        """Size of the latent produced by the head."""
        return self.head.output_size

class CurlSacCriticModel(nn.Module):
    """Twin Q-function: two separate MLPs, each mapping the concatenated
    (latent, action) vector to a single value."""

    def __init__(
            self,
            input_size,
            action_size,
            hidden_sizes,
            ):
        super().__init__()
        q_input_size = input_size + action_size
        self.mlp1 = MlpModel(
            input_size=q_input_size,
            hidden_sizes=hidden_sizes,
            output_size=1,
        )
        self.mlp2 = MlpModel(
            input_size=q_input_size,
            hidden_sizes=hidden_sizes,
            output_size=1,
        )
        self.apply(weight_init)

    def forward(self, latent, prev_action, prev_reward, action,
            detach_conv=False):
        """Return ``(q1, q2)`` for the given latent/action pair.

        ``prev_action``, ``prev_reward`` and ``detach_conv`` are part of the
        call signature but are not used here. Leading dims are inferred from
        ``latent`` (treated as a vector) and restored on both outputs.
        """
        lead_dim, T, B, _ = infer_leading_dims(latent, 1)  # latent is vector

        flat_latent = latent.view(T * B, -1)
        flat_action = action.view(T * B, -1)
        q_input = torch.cat((flat_latent, flat_action), dim=1)
        q_values = tuple(
            mlp(q_input).squeeze(-1) for mlp in (self.mlp1, self.mlp2))
        q1, q2 = restore_leading_dims(q_values, lead_dim, T, B)
        return q1, q2


class CurlSacActorModel(nn.Module):
    """Policy head: an MLP over the latent whose output is split in half
    into a mean and a log-std, the latter squashed into
    ``[min_log_std, max_log_std]``."""

    def __init__(
            self,
            input_size,
            action_size,
            hidden_sizes,
            min_log_std=-10.,
            max_log_std=2.,
            ):
        super().__init__()
        self.min_log_std = min_log_std
        self.max_log_std = max_log_std
        self.mlp = MlpModel(
            input_size=input_size,
            hidden_sizes=hidden_sizes,
            output_size=action_size * 2,  # mean and log-std per action dim
        )
        self.apply(weight_init)

    def forward(self, latent, prev_action, prev_reward):
        """Return ``(mu, log_std)`` with the leading dims of ``latent``.

        ``prev_action`` and ``prev_reward`` are accepted but unused.
        """
        lead_dim, T, B, _ = infer_leading_dims(latent, 1)  # latent is vector

        mlp_out = self.mlp(latent.view(T * B, -1))
        mu, raw_log_std = mlp_out.chunk(chunks=2, dim=-1)
        # tanh gives (-1, 1); shift and scale it onto the configured range.
        squashed = torch.tanh(raw_log_std)
        half_range = 0.5 * (self.max_log_std - self.min_log_std)
        log_std = self.min_log_std + half_range * (1 + squashed)
        mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
        return mu, log_std


class CurlSacTransformModel(nn.Module):
    """Bilinear similarity between anchors and positives.

    Holds a learnable square matrix ``W`` of side ``latent_size`` and scores
    every anchor against every positive: ``logits = anchor @ (W @ positive.T)``.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.W = nn.Parameter(torch.rand(latent_size, latent_size))

    def forward(self, anchor, positive):
        """Params anchor and positive are in the latent space.

        Both must have the same shape. They are flattened to
        ``[T * B, latent]`` before scoring, so the logits matrix is
        ``[T * B, T * B]`` ahead of ``restore_leading_dims``.

        Raises:
            AssertionError: if ``anchor`` and ``positive`` differ in shape.
        """
        assert anchor.shape == positive.shape
        lead_dim, T, B, _ = infer_leading_dims(anchor, 1)
        Wz = torch.matmul(self.W, positive.view(T * B, -1).T)
        logits = torch.matmul(anchor.view(T * B, -1), Wz)
        # Subtract each row's own maximum (numerical stabilisation that leaves
        # the row-wise softmax unchanged). keepdim=True is required: a 1-D
        # max would broadcast along the last axis and shift column j by the
        # max of row j instead.
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0]
        logits = restore_leading_dims(logits, lead_dim, T, B)
        return logits


class CurlSacSeqTransformModel(nn.Module):
    """Bilinear anchor/positive similarity with an optional residual MLP
    applied to the anchor.

    When ``anchor_mlp`` is true, an ``MlpModel`` with input size
    ``latent_size + seq_input_size`` and output size ``latent_size`` is
    built; its output is added to the anchor before scoring. Otherwise the
    anchor is scored as-is.
    """

    def __init__(self, latent_size, seq_input_size, hidden_sizes, anchor_mlp=True):
        super().__init__()
        self.W = nn.Parameter(torch.rand(latent_size, latent_size))
        if anchor_mlp:
            self.mlp = MlpModel(
                input_size=latent_size + seq_input_size,
                hidden_sizes=hidden_sizes,
                output_size=latent_size,  # at least this linear layer
            )
        else:
            self.mlp = None

    def forward(self, anchor, positive, actions=None, rewards=None):
        """Params anchor and positive are in the latent space.

        ``actions`` and ``rewards`` must be given together or not at all.
        When given (and the MLP exists) they are transposed from time-major
        to batch-major, flattened per batch entry and concatenated to the
        anchor as MLP input. ``rewards`` must be 2-D with its second
        dimension equal to the batch size inferred from ``anchor``.

        Raises:
            AssertionError: on shape mismatch between ``anchor`` and
                ``positive``, if only one of ``actions`` / ``rewards`` is
                supplied, or if the sequence batch size differs from ``B``.
        """
        assert anchor.shape == positive.shape
        lead_dim, T, B, _ = infer_leading_dims(anchor, 1)
        Wz = torch.matmul(self.W, positive.view(T * B, -1).T)

        anchor = anchor.view(T * B, -1)
        assert (actions is None) == (rewards is None)
        if self.mlp is not None:
            if actions is not None:
                # actions and rewards from the cpc_delta_T timesteps
                _, seq_B = rewards.shape
                assert seq_B == B
                # [T,B,..] --> [B,T,...]
                actions = actions.transpose(1, 0)
                rewards = rewards.transpose(1, 0)
                # can't use actions.view() because of transpose:
                mlp_input = torch.cat(
                    [anchor, actions.reshape(seq_B, -1), rewards],
                    dim=1)
            else:
                mlp_input = anchor
            anchor = anchor + self.mlp(mlp_input)  # skipper

        logits = torch.matmul(anchor, Wz)
        # Subtract each row's own maximum (numerical stabilisation that leaves
        # the row-wise softmax unchanged). keepdim=True is required: a 1-D
        # max would broadcast along the last axis and shift column j by the
        # max of row j instead.
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0]
        logits = restore_leading_dims(logits, lead_dim, T, B)
        return logits


class EncoderConvModel(nn.Module):
    """Per-frame conv encoder that also emits frame-to-frame feature
    differences.

    The input is reinterpreted as a stack of 3-channel 108x108 frames, three
    frames per sample (both sizes are hard-coded in ``forward``). Each frame
    is passed through the same ``Conv2dModel``. For every sample the output
    concatenates, along the channel axis, the features of frames 1 and 2 and
    the differences ``frame[k] - frame[k-1]`` for k = 1, 2 (the subtracted
    earlier-frame features are detached from the graph).
    """

    def __init__(self, observation_shape, **kwargs):
        """Build the conv stack; ``kwargs`` are forwarded to ``Conv2dModel``.

        Only the height and width of ``observation_shape`` are used, to
        compute the reported output size.
        """
        super().__init__()
        _, h, w = observation_shape
        self.conv2d = Conv2dModel(in_channels=3, **kwargs)
        # forward() returns four feature maps per sample (two frames plus two
        # differences), hence the factor of 4 on the per-frame size.
        # NOTE(review): presumably conv_out_size is the flattened conv output
        # size for a single frame — confirm against Conv2dModel.
        self._output_size = self.conv2d.conv_out_size(h, w) * 4

    def forward(self, img, infer_dims=True):
        """When calling as standalone, safest to use infer_dims=True.

        The image is converted to float and scaled by 1/255 (uint8 input is
        assumed, but float input is accepted). The caller's tensor is never
        modified.
        """
        # copy=True guarantees a fresh tensor even when img is already float,
        # so the in-place scaling below cannot rescale the caller's data.
        img = img.to(torch.float, copy=True).mul_(1. / 255)
        if infer_dims:
            lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
            img = img.view(T * B, *img_shape)
            # Fails unless every sample holds exactly three 3x108x108 frames.
            img = img.view(T * B * 3, 3, 108, 108)
        img = img.view(-1, 3, 108, 108)
        conv_out = self.conv2d(img)
        n_frames, feat_c, feat_h, feat_w = conv_out.shape
        # Regroup the frames: [N*3, c, h, w] -> [N, 3, c, h, w].
        conv_out = conv_out.view(n_frames // 3, 3, feat_c, feat_h, feat_w)
        conv_current = conv_out[:, 1:, :, :, :]
        # Difference to the preceding frame; no gradient through the
        # subtracted (earlier) features.
        conv_delta = conv_current - conv_out[:, :2, :, :, :].detach()
        conv_out = torch.cat([conv_current, conv_delta], dim=1)
        # Fold the frame axis into channels: [N, 4, c, h, w] -> [N, 4*c, h, w].
        out = conv_out.view(
            conv_out.size(0),
            conv_out.size(1) * conv_out.size(2),
            conv_out.size(3),
            conv_out.size(4),
        )
        if infer_dims:
            out = restore_leading_dims(out, lead_dim, T, B)
        return out

    @property
    def output_size(self):
        """Output size computed in ``__init__`` (four per-frame conv sizes)."""
        return self._output_size




class EncoderHeadModel(nn.Module):
    """Maps flattened conv features to a latent vector with one linear
    layer, optionally followed by LayerNorm and/or tanh."""

    def __init__(self, input_size, latent_size, layer_norm=True,
            tanh_output=False):
        super().__init__()
        self.linear = nn.Linear(input_size, latent_size)
        if layer_norm:
            self.layer_norm = nn.LayerNorm(latent_size)
        else:
            self.layer_norm = None
        self._tanh_output = tanh_output
        self._output_size = latent_size

    def forward(self, conv_out, infer_dims=True):
        """When calling as standalone, safest to use infer_dims=True.

        With ``infer_dims=False`` the input is flattened from dim 1 onward
        (assumed ``[B,c,h,w]``) and the result is returned as ``[B, latent]``.
        With ``infer_dims=True`` leading dims are inferred and restored.
        """
        if infer_dims:
            lead_dim, T, B, _ = infer_leading_dims(conv_out, 3)
            features = conv_out.view(T * B, -1)
        else:
            features = conv_out.flatten(start_dim=1)  # Assume was [B,c,h,w]

        latent = self.linear(features)
        if self.layer_norm is not None:
            latent = self.layer_norm(latent)
        if self._tanh_output:
            latent = torch.tanh(latent)

        if not infer_dims:
            return latent
        return restore_leading_dims(latent, lead_dim, T, B)

    @property
    def output_size(self):
        """The latent size given at construction."""
        return self._output_size


class CurlSacModel(nn.Module):
    """Bundles an encoder and a policy MLP behind the standard
    ``agent.model`` interface (shared params, etc.)."""

    def __init__(self, encoder, pi_mlp):
        super().__init__()
        self.encoder = encoder
        self.pi_mlp = pi_mlp

    def forward(self, observation, prev_action, prev_reward):
        """Encode the observation, then feed the latent together with
        ``prev_action`` and ``prev_reward`` to the policy MLP and return
        its output unchanged."""
        return self.pi_mlp(self.encoder(observation), prev_action, prev_reward)


class IdentityEncoder(nn.Module):
    """Encoder stand-in that returns the observation untouched.

    It still carries ``conv`` and ``head`` sub-modules (both pass-through)
    so that code which accesses those attributes on an encoder keeps working.
    """

    def __init__(self):
        super().__init__()
        # (Exposed to optimized algo/agent.)
        self.conv, self.head = IdentitySubEncoder(), IdentitySubEncoder()

    def forward(self, observation, prev_action=None, prev_reward=None):
        """Return ``observation`` as-is; the other arguments are ignored."""
        return observation


class IdentitySubEncoder(nn.Module):
    """Pass-through module: hands back its input and ignores any keyword
    arguments it is called with."""

    def forward(self, input, **kwargs):
        """Return ``input`` unchanged; ``kwargs`` are accepted and dropped."""
        return input
