import torch
from torch import nn
import numpy as np
import copy

from .encoder import Encoder
from .policy import GaussianPolicy
from .mlp import MLP 


class PiNetwork(torch.nn.Module):
    """Actor network: convolutional encoder followed by a Gaussian policy head.

    Images are passed through ``Encoder``; the resulting feature vector is
    optionally concatenated with the ``append`` and ``latent`` vectors and fed
    to ``GaussianPolicy``.

    NOTE(review): the ``rec_*`` arguments are accepted, but this constructor
    always sets ``self.rec = None`` and never creates ``self.layernorm``, so
    the recurrent branches below are unreachable unless those attributes are
    assigned from outside this class.
    """

    def __init__(
        self,
        input_n_channel,
        mlp_dim,
        action_dim,
        action_mag,
        activation_type,  # for MLP; ReLU default for conv
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        latent_dim=0,
        append_dim=0,
        rec_type=None,
        rec_hidden_size=0,
        rec_num_layers=1,
        rec_bidirectional=False,
        rec_dropout=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        dual_conv=False,
        use_film=False,
        lang_dim=256,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder and the policy head.

        Args:
            input_n_channel, img_sz, kernel_sz, stride, padding, n_channel,
                use_sm, use_bn, use_residual, dual_conv: forwarded to
                ``Encoder`` unchanged.
            mlp_dim: list of hidden layer widths of the policy head.
            action_dim: width of the policy output.
            action_mag, activation_type, use_ln, verbose: forwarded to
                ``GaussianPolicy``.
            latent_dim: width of the ``latent`` vector concatenated to the
                features.
            append_dim: width of the ``append`` vector. When ``use_film`` is
                set, this includes the leading ``lang_dim`` entries that are
                routed to the encoder instead of the MLP.
            rec_type, rec_dropout: accepted but not used here.
            rec_hidden_size, rec_num_layers, rec_bidirectional: only used to
                size the state returned by ``sample_init_rnn_state``.
            use_film: pass the first ``lang_dim`` entries of ``append`` to the
                encoder as ``lang``.
            lang_dim: number of leading ``append`` entries treated as language
                features when ``use_film`` is set.
            device: device that numpy inputs are moved to.
        """
        super().__init__()
        self.device = device
        self.rec_hidden_size = rec_hidden_size
        self.use_film = use_film
        self.lang_dim = lang_dim
        if use_film:
            # The language part of `append` goes to the encoder, not the MLP.
            append_dim -= lang_dim

        if rec_bidirectional:
            self.rec_hidden_batch_dim = 2 * rec_num_layers
        else:
            self.rec_hidden_batch_dim = rec_num_layers
        self.img_sz = img_sz
        if np.isscalar(img_sz):
            self.img_sz = [img_sz, img_sz]

        # Conv layers. NOTE(review): the original comment said these are
        # shared with the critic; no sharing happens in this class.
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               dual_conv=dual_conv,
                               use_film=use_film,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # No recurrent module is constructed (see class docstring).
        self.rec = None
        mlp_dim = ([dim_conv_out + append_dim + latent_dim] + mlp_dim
                   + [action_dim])

        # Linear layers
        self.mlp = GaussianPolicy(mlp_dim, action_mag, activation_type, use_ln,
                                  device, verbose)

    def _extract_features(self, image, append, latent, detach_encoder,
                          detach_rec, init_rnn_state):
        """Shared front end of ``forward`` and ``sample``.

        Converts numpy inputs to tensors, rescales the uint8 image to
        [0, 1], normalizes leading dimensions, runs the encoder and
        concatenates ``append`` / ``latent`` to the encoder features.

        Returns:
            Tuple ``(features, rnn_state, num_extra_dim, np_input)`` where
            ``rnn_state`` is ``None`` unless ``self.rec`` is set,
            ``num_extra_dim`` is the number of leading singleton dims that
            were added and must be squeezed from the output, and ``np_input``
            tells whether ``image`` arrived as a numpy array.

        Raises:
            TypeError: if ``image`` is not uint8, or ``use_film`` is set and
                ``append`` is ``None``.
        """
        # Convert to torch
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)
        if isinstance(append, np.ndarray):
            append = torch.from_numpy(append).float().to(self.device)
        if isinstance(latent, np.ndarray):
            latent = torch.from_numpy(latent).float().to(self.device)

        # Convert [0, 255] to [0, 1]; only uint8 images are accepted.
        if image.dtype != torch.uint8:
            raise TypeError(
                "image must have dtype torch.uint8, got {}".format(
                    image.dtype))
        image = image.float() / 255.0

        # Get dimensions
        seq_len = 0
        num_extra_dim = 0
        if image.dim() == 3:  # single unbatched image (CHW)
            image = image.unsqueeze(0)
            if append is not None:
                append = append.unsqueeze(0)
            num_extra_dim += 1
            batch = image.shape[0]
        elif image.dim() == 4:  # NCHW
            batch = image.shape[0]
        else:  # LNCHW: fold the sequence dim into the batch for the conv
            seq_len, batch, n_ch, height, width = image.shape
            # reshape (not view) so non-contiguous inputs are accepted
            image = image.reshape(seq_len * batch, n_ch, height, width)
        if self.rec is not None and seq_len == 0:
            # recurrent but input does not have L
            if append is not None:
                append = append.unsqueeze(0)
            num_extra_dim += 1
            seq_len = 1
        restore_seq = seq_len > 0

        # Forward thru conv
        if self.use_film:  # assume lang at front of append
            if append is None:
                raise TypeError(
                    "use_film=True requires `append` to carry the language "
                    "features in its first lang_dim entries")
            # Slice the feature (last) axis so inputs with a leading
            # sequence dim are split correctly as well.
            lang = append[..., :self.lang_dim]
            append = append[..., self.lang_dim:]
            if lang.dim() > 2:
                # The encoder sees a flat (L*N) batch of images.
                lang = lang.reshape(-1, lang.shape[-1])
            conv_out = self.encoder.forward(image,
                                            detach=detach_encoder,
                                            lang=lang)
        else:
            conv_out = self.encoder.forward(image, detach=detach_encoder)

        # Put dimension back
        if restore_seq:
            conv_out = conv_out.reshape(seq_len, batch, -1)

        # Append, recurrent, latent
        rnn_state = None
        if self.rec is not None:
            # NOTE(review): self.layernorm is never defined in this class.
            conv_out = self.layernorm(conv_out)
            if append is not None:
                conv_out = torch.cat((conv_out, append), dim=-1)
            conv_out, (hn, cn) = self.rec(conv_out, init_rnn_state, detach_rec)
            rnn_state = (hn, cn)
        elif append is not None:
            conv_out = torch.cat((conv_out, append), dim=-1)
        if latent is not None:
            # Give latent the leading dims that were added to the features
            # above; a latent that already has them is left untouched.
            while latent.dim() < conv_out.dim():
                latent = latent.unsqueeze(0)
            conv_out = torch.cat((conv_out, latent), dim=-1)

        return conv_out, rnn_state, num_extra_dim, np_input

    def forward(
            self,
            image,  # NCHW or LNCHW
            append=None,  # LN x append_dim
            latent=None,  # LN x z_dim
            detach_encoder=False,
            detach_rec=False,
            detach_lang=False,
            init_rnn_state=None,  # N x hidden_dim
    ):
        """Return ``self.mlp(features)`` for the given observation.

        Assume all arguments have the same number of leading dims (L and N),
        and returns the same number of leading dims. init_rnn_state is always
        L=1.

        Args:
            image: uint8 tensor or numpy array shaped CHW, NCHW or LNCHW.
            append: optional vector concatenated to the encoder features.
            latent: optional vector concatenated after ``append``.
            detach_encoder: forwarded to the encoder as ``detach``.
            detach_rec, init_rnn_state: forwarded to ``self.rec`` when set.
            detach_lang: accepted but unused.

        Returns:
            The policy head output, as a numpy array if ``image`` was a numpy
            array. When ``self.rec`` is set, ``(output, (hn, cn))``.

        Raises:
            TypeError: if ``image`` is not uint8.
        """
        features, rnn_state, num_extra_dim, np_input = self._extract_features(
            image, append, latent, detach_encoder, detach_rec, init_rnn_state)

        # MLP
        output = self.mlp(features)

        # Restore dimension
        for _ in range(num_extra_dim):
            output = output.squeeze(0)

        # Convert back to np
        if np_input:
            output = output.detach().cpu().numpy()

        if self.rec is not None:
            return output, rnn_state
        return output

    def sample(self,
               image,
               append=None,
               latent=None,
               detach_encoder=False,
               detach_rec=False,
               detach_lang=False,
               init_rnn_state=None,
               get_prob=False):
        """Return ``self.mlp.sample(features, get_prob)`` for the observation.

        Assume all arguments have the same number of leading dims (L and N),
        and returns the same number of leading dims. init_rnn_state is always
        L=1. Not converting back to numpy.

        Arguments are the same as for ``forward``; ``get_prob`` is forwarded
        to the policy head.

        Returns:
            ``(action, log_prob)``, or ``((action, log_prob), (hn, cn))`` when
            ``self.rec`` is set.

        Raises:
            TypeError: if ``image`` is not uint8.
        """
        features, rnn_state, num_extra_dim, _ = self._extract_features(
            image, append, latent, detach_encoder, detach_rec, init_rnn_state)

        # MLP
        action, log_prob = self.mlp.sample(features, get_prob)

        # Restore dimension
        for _ in range(num_extra_dim):
            action = action.squeeze(0)
            # Guard in case the head returns no log-prob (e.g. when
            # get_prob is False) — TODO confirm against GaussianPolicy.
            if log_prob is not None:
                log_prob = log_prob.squeeze(0)

        if self.rec is not None:
            return (action, log_prob), rnn_state
        return (action, log_prob)

    def sample_init_rnn_state(self):
        """Return a random ``(h0, c0)`` pair for a batch of one.

        Each tensor has shape ``(rec_hidden_batch_dim, 1, rec_hidden_size)``.

        Raises:
            NotImplementedError: if no recurrent module is set.
        """
        if self.rec is None:
            raise NotImplementedError
        state_shape = (self.rec_hidden_batch_dim, 1, self.rec_hidden_size)
        return (torch.randn(*state_shape).to(self.device),
                torch.randn(*state_shape).to(self.device))


class TwinnedQNetwork(torch.nn.Module):
    """Critic network: convolutional encoder followed by two Q-value MLPs.

    Images are passed through ``Encoder``; the resulting feature vector is
    concatenated with the optional ``append`` and ``latent`` vectors and with
    the action, then fed to two MLP heads ``Q1`` and ``Q2``. ``Q2`` starts as
    a deep copy of ``Q1``.

    NOTE(review): the ``rec_*`` arguments are accepted, but this constructor
    always sets ``self.rec = None`` and never creates ``self.layernorm``, so
    the recurrent branch in ``forward`` is unreachable unless those
    attributes are assigned from outside this class.
    """

    def __init__(
        self,
        input_n_channel,
        mlp_dim,
        action_dim,
        activation_type,  # for MLP; ReLU default for conv
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        latent_dim=0,
        append_dim=0,
        rec_type=None,
        rec_hidden_size=0,
        rec_num_layers=1,
        rec_bidirectional=False,
        rec_dropout=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        dual_conv=False,
        use_film=False,
        lang_dim=256,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder and the two Q heads.

        Args:
            input_n_channel, img_sz, kernel_sz, stride, padding, n_channel,
                use_sm, use_bn, use_residual, dual_conv: forwarded to
                ``Encoder`` unchanged.
            mlp_dim: list of hidden layer widths of each Q head.
            action_dim: width of the action vector concatenated to the
                features.
            activation_type, use_ln: forwarded to ``MLP``.
            latent_dim: width of the ``latent`` vector.
            append_dim: width of the ``append`` vector. When ``use_film`` is
                set, this includes the leading ``lang_dim`` entries that are
                routed to the encoder instead of the MLP.
            rec_type, rec_hidden_size, rec_num_layers, rec_bidirectional,
                rec_dropout: accepted but not used here.
            use_film: pass the first ``lang_dim`` entries of ``append`` to the
                encoder as ``lang``.
            lang_dim: number of leading ``append`` entries treated as language
                features when ``use_film`` is set.
            device: device the Q heads and numpy inputs are moved to.
            verbose: print the architecture of ``Q1`` after construction.
        """
        super().__init__()
        self.device = device
        self.img_sz = img_sz
        self.use_film = use_film
        self.lang_dim = lang_dim
        if use_film:
            # The language part of `append` goes to the encoder, not the MLP.
            append_dim -= self.lang_dim
        if np.isscalar(img_sz):
            self.img_sz = [img_sz, img_sz]

        # Conv layers of the critic.
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               dual_conv=dual_conv,
                               use_film=use_film,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # No recurrent module is constructed (see class docstring).
        self.rec = None
        mlp_dim = ([dim_conv_out + latent_dim + append_dim + action_dim]
                   + mlp_dim + [1])

        self.Q1 = MLP(mlp_dim,
                      activation_type,
                      out_activation_type='Identity',
                      use_ln=use_ln,
                      verbose=False).to(device)
        self.Q2 = copy.deepcopy(self.Q1)
        if verbose:
            print("The MLP for critic has the architecture as below:")
            print(self.Q1.moduleList)

    def forward(self,
                image,
                actions,
                append=None,
                latent=None,
                detach_encoder=False,
                detach_rec=False,
                detach_lang=False,
                init_rnn_state=None):
        """Return the two Q-value estimates for an observation/action pair.

        Assume all arguments have the same number of leading dims (L and N),
        and returns the same number of leading dims. init_rnn_state is always
        L=1.

        Args:
            image: uint8 tensor or numpy array shaped CHW, NCHW or LNCHW.
            actions: action tensor or numpy array, concatenated last.
            append: optional vector concatenated to the encoder features.
            latent: optional vector concatenated after ``append``.
            detach_encoder: forwarded to the encoder as ``detach``.
            detach_rec, init_rnn_state: forwarded to ``self.rec`` when set.
            detach_lang: accepted but unused.

        Returns:
            ``(q1, q2)``, as numpy arrays if ``image`` was a numpy array.
            When ``self.rec`` is set, ``(q1, q2, (hn, cn))``.

        Raises:
            TypeError: if ``image`` is not uint8, or ``use_film`` is set and
                ``append`` is ``None``.
        """
        # Convert to torch
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)
        if isinstance(append, np.ndarray):
            append = torch.from_numpy(append).float().to(self.device)
        if isinstance(latent, np.ndarray):
            latent = torch.from_numpy(latent).float().to(self.device)
        if isinstance(actions, np.ndarray):
            # Numpy callers pass numpy actions too; torch.cat needs tensors.
            actions = torch.from_numpy(actions).float().to(self.device)

        # Convert [0, 255] to [0, 1]; only uint8 images are accepted.
        if image.dtype != torch.uint8:
            raise TypeError(
                "image must have dtype torch.uint8, got {}".format(
                    image.dtype))
        image = image.float() / 255.0

        # Get dimensions
        seq_len = 0
        num_extra_dim = 0
        if image.dim() == 3:  # single unbatched image (CHW)
            image = image.unsqueeze(0)
            actions = actions.unsqueeze(0)
            if append is not None:
                append = append.unsqueeze(0)
            num_extra_dim += 1
            batch = image.shape[0]
        elif image.dim() == 4:  # NCHW
            batch = image.shape[0]
        else:  # LNCHW: fold the sequence dim into the batch for the conv
            seq_len, batch, n_ch, height, width = image.shape
            # reshape (not view) so non-contiguous inputs are accepted
            image = image.reshape(seq_len * batch, n_ch, height, width)
        if self.rec is not None and seq_len == 0:
            # recurrent but input does not have L
            if append is not None:
                append = append.unsqueeze(0)
            actions = actions.unsqueeze(0)
            num_extra_dim += 1
            seq_len = 1
        restore_seq = seq_len > 0

        # Forward thru conv
        if self.use_film:  # assume lang at front of append
            if append is None:
                raise TypeError(
                    "use_film=True requires `append` to carry the language "
                    "features in its first lang_dim entries")
            # Slice the feature (last) axis so inputs with a leading
            # sequence dim are split correctly as well.
            lang = append[..., :self.lang_dim]
            append = append[..., self.lang_dim:]
            if lang.dim() > 2:
                # The encoder sees a flat (L*N) batch of images.
                lang = lang.reshape(-1, lang.shape[-1])
            conv_out = self.encoder.forward(image,
                                            detach=detach_encoder,
                                            lang=lang)
        else:
            conv_out = self.encoder.forward(image, detach=detach_encoder)

        # Put dimension back
        if restore_seq:
            conv_out = conv_out.reshape(seq_len, batch, -1)

        # Append, recurrent, latent
        rnn_state = None
        if self.rec is not None:
            # NOTE(review): self.layernorm is never defined in this class.
            conv_out = self.layernorm(conv_out)
            if append is not None:
                conv_out = torch.cat((conv_out, append), dim=-1)
            conv_out, (hn, cn) = self.rec(conv_out, init_rnn_state, detach_rec)
            rnn_state = (hn, cn)
        elif append is not None:
            conv_out = torch.cat((conv_out, append), dim=-1)
        if latent is not None:
            # Give latent the leading dims that were added to the features
            # above; a latent that already has them is left untouched.
            while latent.dim() < conv_out.dim():
                latent = latent.unsqueeze(0)
            conv_out = torch.cat((conv_out, latent), dim=-1)

        # Append action to mlp
        conv_out = torch.cat((conv_out, actions), dim=-1)

        # MLP
        q1 = self.Q1(conv_out)
        q2 = self.Q2(conv_out)

        # Restore dimension
        for _ in range(num_extra_dim):
            q1 = q1.squeeze(0)
            q2 = q2.squeeze(0)

        # Convert back to np
        if np_input:
            q1 = q1.detach().cpu().numpy()
            q2 = q2.detach().cpu().numpy()

        if self.rec is not None:
            return q1, q2, rnn_state
        return q1, q2
