import torch
from torch import nn
import numpy as np
import copy

from .encoder import Encoder
from .policy import GaussianPolicy
from .mlp import MLP 


class LangPiNetwork(torch.nn.Module):
    """Actor network conditioned on an image and a language embedding.

    The image is passed through ``self.encoder``; the language embedding
    (the ``append`` argument of ``forward``/``sample``) is passed through
    the small ``self.mlp_lang`` network. Both feature vectors are
    concatenated along the last dim and fed to ``self.mlp``, a
    ``GaussianPolicy`` head.
    """

    def __init__(
        self,
        input_n_channel,
        mlp_dim,
        action_dim,
        action_mag,
        activation_type,  # for MLP; ReLU default for conv
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        append_dim=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        dual_conv=False,
        use_film=False,
        lang_dim=768,
        lang_mlp_dim=128,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder, the policy head and the language MLP.

        Args:
            input_n_channel, img_sz, kernel_sz, stride, padding, n_channel,
                use_sm, use_bn, use_residual, dual_conv, use_film: forwarded
                unchanged to ``Encoder``.
            mlp_dim: list of hidden layer widths of the policy head; the
                input and output widths are added here.
            action_dim: width of the policy output.
            action_mag, activation_type, use_ln, verbose: forwarded to
                ``GaussianPolicy``.
            append_dim: accepted for signature compatibility; this class
                does not read it.
            lang_dim: width of the raw language embedding.
            lang_mlp_dim: width of the language features fed to the head.
            device: device the sub-modules are created on and to which
                NumPy inputs are moved.
        """
        super().__init__()
        self.device = device
        self.use_film = use_film
        self.lang_dim = lang_dim

        # Stored as a [height, width]-style pair; the encoder still
        # receives the caller's original ``img_sz`` value.
        self.img_sz = [img_sz, img_sz] if np.isscalar(img_sz) else img_sz

        # Convolutional image encoder.
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               dual_conv=dual_conv,
                               use_film=use_film,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # The head consumes encoder features concatenated with the language
        # features and emits ``action_dim`` values.
        mlp_dim = [dim_conv_out + lang_mlp_dim] + mlp_dim + [action_dim]

        # Policy head
        self.mlp = GaussianPolicy(mlp_dim, action_mag, activation_type, use_ln,
                                  device, verbose)

        # Single linear + ReLU layer mapping the language embedding to
        # ``lang_mlp_dim`` features.
        self.mlp_lang = nn.Sequential(
            nn.Linear(lang_dim, lang_mlp_dim),
            nn.ReLU(),
        ).to(device)

    def _encode_inputs(self, image, append, detach_encoder, detach_lang):
        """Shared preprocessing for ``forward`` and ``sample``.

        Converts NumPy inputs to tensors, rescales the uint8 image to
        [0, 1], adds a leading batch dim to a 3-D image, runs the encoder
        and, when ``append`` is given, concatenates the language features.

        Returns:
            (features, unbatched, np_input): the head input, whether a batch
            dim was added (and must be removed from the outputs), and
            whether ``image`` arrived as a NumPy array.

        Raises:
            TypeError: if the image is not uint8.
        """
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)
        if isinstance(append, np.ndarray):
            append = torch.from_numpy(append).float().to(self.device)

        # Convert [0, 255] to [0, 1]; only uint8 images are accepted.
        if image.dtype != torch.uint8:
            raise TypeError(
                f"image must have dtype torch.uint8, got {image.dtype}")
        image = image.float() / 255.0

        # A 3-D image is a single unbatched sample: add a batch dim.
        unbatched = image.dim() == 3
        if unbatched:
            image = image.unsqueeze(0)
            if append is not None:
                append = append.unsqueeze(0)

        # CNN
        features = self.encoder.forward(image, detach=detach_encoder)

        # Language features
        if append is not None:
            lang_features = self.mlp_lang(append)
            if detach_lang:
                # Block gradients from flowing into mlp_lang via this call.
                lang_features = lang_features.detach()
            features = torch.cat((features, lang_features), dim=-1)
        return features, unbatched, np_input

    def forward(
            self,
            image,  # NCHW or LNCHW
            append=None,  # LN x append_dim #! assume embedding
            latent=None,
            detach_encoder=False,
            detach_lang=True,    #! assume critic updates
    ):
        """Run the policy head on the encoded image and language embedding.

        Args:
            image: uint8 tensor or NumPy array, either batched or a single
                3-D image (a batch dim is then added and removed again).
            append: language embedding with the same leading dims as
                ``image``. NOTE(review): the head input width was sized for
                encoder + language features, so omitting it presumably
                causes a shape mismatch -- confirm with callers.
            latent: unused.
            detach_encoder: forwarded to the encoder as ``detach``.
            detach_lang: detach the language features before concatenation.

        Returns:
            Output of ``self.mlp``; a NumPy array if ``image`` was one.

        Raises:
            TypeError: if the image is not uint8.
        """
        features, unbatched, np_input = self._encode_inputs(
            image, append, detach_encoder, detach_lang)

        # MLP
        output = self.mlp(features)

        # Restore dimension
        if unbatched:
            output = output.squeeze(0)

        # Convert back to np
        if np_input:
            output = output.detach().cpu().numpy()
        return output

    def sample(self,
               image,
               append=None,
               latent=None,
               detach_encoder=False,
               detach_lang=True,  #! assume critic updates
               get_prob=False):
        """Sample an action from the policy head.

        Takes the same inputs as ``forward`` plus ``get_prob``, which is
        passed straight to ``self.mlp.sample``. Outputs are never converted
        back to NumPy.

        Returns:
            (action, log_prob) as returned by ``self.mlp.sample``, with the
            added batch dim removed for a single 3-D image. ``log_prob`` is
            passed through untouched if the head returns ``None``.

        Raises:
            TypeError: if the image is not uint8.
        """
        features, unbatched, _ = self._encode_inputs(
            image, append, detach_encoder, detach_lang)

        # MLP
        action, log_prob = self.mlp.sample(features, get_prob)

        # Restore dimension
        if unbatched:
            action = action.squeeze(0)
            if log_prob is not None:
                log_prob = log_prob.squeeze(0)
        return (action, log_prob)


class LangTwinnedQNetwork(torch.nn.Module):
    """Twin Q-value network conditioned on image, language and action.

    The image is passed through ``self.encoder`` and the language embedding
    (the ``append`` argument of ``forward``) through ``self.mlp_lang``. The
    two feature vectors and the action are concatenated along the last dim
    and evaluated by two MLP heads, ``self.Q1`` and ``self.Q2``.
    """

    def __init__(
        self,
        input_n_channel,
        mlp_dim,
        action_dim,
        activation_type,  # for MLP; ReLU default for conv
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        append_dim=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        dual_conv=False,
        use_film=False,
        lang_dim=768,
        lang_mlp_dim=128,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder, the two Q heads and the language MLP.

        Args:
            input_n_channel, img_sz, kernel_sz, stride, padding, n_channel,
                use_sm, use_bn, use_residual, dual_conv, use_film: forwarded
                unchanged to ``Encoder``.
            mlp_dim: list of hidden layer widths of each Q head; the input
                width and the single output unit are added here.
            action_dim: width of the action vector appended to the features.
            activation_type, use_ln: forwarded to ``MLP``.
            append_dim: accepted for signature compatibility; this class
                does not read it.
            lang_dim: width of the raw language embedding.
            lang_mlp_dim: width of the language features fed to the heads.
            device: device the sub-modules are created on and to which
                NumPy inputs are moved.
            verbose: print the architecture of the Q head.
        """
        super().__init__()
        self.device = device
        self.use_film = use_film
        self.lang_dim = lang_dim

        # Stored as a [height, width]-style pair; the encoder still
        # receives the caller's original ``img_sz`` value.
        self.img_sz = [img_sz, img_sz] if np.isscalar(img_sz) else img_sz

        # Convolutional image encoder.
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               dual_conv=dual_conv,
                               use_film=use_film,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # Each head consumes encoder features + language features + action
        # and emits a single Q-value.
        head_input_dim = dim_conv_out + lang_mlp_dim + action_dim
        mlp_dim = [head_input_dim] + mlp_dim + [1]

        self.Q1 = MLP(mlp_dim,
                      activation_type,
                      out_activation_type='Identity',
                      use_ln=use_ln,
                      verbose=False).to(device)
        # Q2 starts as an exact copy of Q1, i.e. with the same initial
        # weights. NOTE(review): if both heads always receive identical
        # updates they would stay identical -- confirm this is intended.
        self.Q2 = copy.deepcopy(self.Q1)
        if verbose:
            print("The MLP for critic has the architecture as below:")
            print(self.Q1.moduleList)

        # Single linear + ReLU layer mapping the language embedding to
        # ``lang_mlp_dim`` features.
        self.mlp_lang = nn.Sequential(
            nn.Linear(lang_dim, lang_mlp_dim),
            nn.ReLU(),
        ).to(device)

    def forward(self,
                image,
                actions,
                latent=None,
                append=None,
                detach_encoder=False,
                detach_lang=False,
                ):
        """Evaluate both Q heads for the given observation and action.

        Args:
            image: uint8 tensor or NumPy array, either batched or a single
                3-D image (a batch dim is then added to all inputs and
                removed from the outputs).
            actions: action tensor or NumPy array with the same leading
                dims as ``image``; NumPy input is cast to float32.
            latent: unused.
            append: language embedding with the same leading dims as
                ``image``. NOTE(review): the head input width was sized to
                include the language features, so omitting it presumably
                causes a shape mismatch -- confirm with callers.
            detach_encoder: forwarded to the encoder as ``detach``.
            detach_lang: detach the language features before concatenation.

        Returns:
            (q1, q2): outputs of the two heads; NumPy arrays if ``image``
            was a NumPy array.

        Raises:
            TypeError: if the image is not uint8.
        """
        # Convert to torch
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)
        if isinstance(actions, np.ndarray):
            # Without this conversion a NumPy action cannot be unsqueezed
            # or concatenated with the tensor features below.
            actions = torch.from_numpy(actions).float().to(self.device)
        if isinstance(append, np.ndarray):
            append = torch.from_numpy(append).float().to(self.device)

        # Convert [0, 255] to [0, 1]; only uint8 images are accepted.
        if image.dtype != torch.uint8:
            raise TypeError(
                f"image must have dtype torch.uint8, got {image.dtype}")
        image = image.float() / 255.0

        # A 3-D image is a single unbatched sample: add a batch dim.
        unbatched = image.dim() == 3
        if unbatched:
            image = image.unsqueeze(0)
            actions = actions.unsqueeze(0)
            if append is not None:
                append = append.unsqueeze(0)

        # Forward thru conv
        features = self.encoder.forward(image, detach=detach_encoder)

        # Language features
        if append is not None:
            lang_features = self.mlp_lang(append)
            if detach_lang:
                # Block gradients from flowing into mlp_lang via this call.
                lang_features = lang_features.detach()
            features = torch.cat((features, lang_features), dim=-1)

        # Append action to mlp input
        features = torch.cat((features, actions), dim=-1)

        # MLP
        q1 = self.Q1(features)
        q2 = self.Q2(features)

        # Restore dimension
        if unbatched:
            q1 = q1.squeeze(0)
            q2 = q2.squeeze(0)

        # Convert back to np
        if np_input:
            q1 = q1.detach().cpu().numpy()
            q2 = q2.detach().cpu().numpy()
        return q1, q2
