import torch
import numpy as np
from torch.distributions import Normal
import torch.nn.functional as F

from .encoder import Encoder
from .mlp import MLP


class Policy(torch.nn.Module):
    """Image-conditioned stochastic policy.

    A convolutional ``Encoder`` turns the image into a feature vector, the
    optional ``append`` and ``latent`` vectors are concatenated to it, and a
    ``GaussianPolicy`` head maps the result to a tanh-squashed action.

    All public methods accept either a batch of images (NCHW) or a single
    image (CHW); in the latter case a batch dimension is added internally and
    removed again from tensor outputs. Images must be ``uint8`` and are
    rescaled from [0, 255] to [0, 1].
    """

    def __init__(
        self,
        input_n_channel,
        mlp_dim,
        action_dim,
        activation_type,  # for MLP; ReLU default for conv
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        action_mag=1,
        latent_dim=0,
        append_dim=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        dual_conv=False,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder and the Gaussian head.

        Args:
            input_n_channel: forwarded to ``Encoder``.
            mlp_dim: sequence of hidden layer widths of the head.
            action_dim: size of the action vector (last layer width).
            activation_type: activation name forwarded to the head.
            img_sz: image size; a scalar is stored as ``[img_sz, img_sz]`` on
                ``self.img_sz`` (the encoder receives the value as given).
            kernel_sz, stride, padding, n_channel: forwarded to ``Encoder``.
            action_mag: forwarded to ``GaussianPolicy``.
            latent_dim: width of the ``latent`` vector concatenated to the
                encoder features.
            append_dim: width of the ``append`` vector concatenated to the
                encoder features.
            use_sm: forwarded to ``Encoder``; when true the encoder output is
                assumed to have ``2 * n_channel[-1]`` entries.
            use_ln: forwarded to the head.
            use_bn, use_residual, dual_conv: forwarded to ``Encoder``.
            device: device used for numpy inputs and for the sub-modules.
            verbose: forwarded to the head.
        """
        super().__init__()
        self.device = device
        self.img_sz = img_sz
        if np.isscalar(img_sz):
            self.img_sz = [img_sz, img_sz]

        # Convolutional feature extractor
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               dual_conv=dual_conv,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # Linear layers: [features + extras] -> hidden ... -> action_dim
        mlp_dim = [dim_conv_out + append_dim + latent_dim] + \
                    list(mlp_dim) + [action_dim]
        self.mlp = GaussianPolicy(mlp_dim, action_mag, activation_type, use_ln,
                                  device, verbose)

    def _as_float_tensor(self, value):
        """Convert a numpy array to a float tensor on ``self.device``.

        Anything that is not a numpy array (including ``None``) is returned
        unchanged.
        """
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value).float().to(self.device)
        return value

    def _encode(self, image, append, latent, detach_encoder):
        """Shared preprocessing: normalize inputs and build the head input.

        Returns:
            Tuple ``(features, np_input, num_extra_dim)`` where ``features``
            is the encoder output concatenated with ``append`` and ``latent``
            (in that order), ``np_input`` tells whether ``image`` was a numpy
            array, and ``num_extra_dim`` is the number of leading dimensions
            that were added and must be squeezed from the outputs.

        Raises:
            TypeError: if the image is not ``uint8``.
        """
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)
        append = self._as_float_tensor(append)
        latent = self._as_float_tensor(latent)

        # Convert [0, 255] to [0, 1]
        if image.dtype != torch.uint8:
            raise TypeError(
                f"image must have dtype uint8, got {image.dtype}")
        image = image.float() / 255.0

        # A single CHW image gets a batch dimension; so do its 1-D extras.
        num_extra_dim = 0
        if image.dim() == 3:
            image = image.unsqueeze(0)
            if append is not None and append.dim() == 1:
                append = append.unsqueeze(0)
            if latent is not None and latent.dim() == 1:
                latent = latent.unsqueeze(0)
            num_extra_dim = 1

        features = self.encoder.forward(image, detach=detach_encoder)
        extras = [extra for extra in (append, latent) if extra is not None]
        if extras:
            features = torch.cat([features, *extras], dim=-1)
        return features, np_input, num_extra_dim

    def forward(
            self,
            image,  # NCHW or CHW
            append=None,  # N x append_dim
            latent=None,  # N x z_dim
    ):
        """Return the output of the head's ``forward`` for the given input.

        The result has the same leading dimensions as ``image`` and is
        converted back to a numpy array when ``image`` was one.
        """
        features, np_input, num_extra_dim = self._encode(
            image, append, latent, detach_encoder=False)

        output = self.mlp(features)

        # Restore dimension
        for _ in range(num_extra_dim):
            output = output.squeeze(0)

        # Convert back to np
        if np_input:
            output = output.detach().cpu().numpy()
        return output

    def sample(self,
               image,
               append=None,
               latent=None,
               detach_encoder=False,
        ):
        """Sample an action from the policy.

        Returns:
            Tuple ``(action, log_prob)`` of whatever ``self.mlp.sample``
            returns, with the added batch dimension removed. ``log_prob`` is
            ``None`` when the head does not compute it (the head is called
            with its default arguments). Outputs are not converted to numpy.
        """
        features, _, num_extra_dim = self._encode(
            image, append, latent, detach_encoder=detach_encoder)

        action, log_prob = self.mlp.sample(features)

        # Restore dimension; log_prob may legitimately be None.
        for _ in range(num_extra_dim):
            action = action.squeeze(0)
            if log_prob is not None:
                log_prob = log_prob.squeeze(0)
        return (action, log_prob)

    def log_prob(self, image,
                        action,
                        append=None,
                        latent=None):
        """Log-probability of a tanh-squashed ``action`` under the policy.

        ``action`` is expected in (-1, 1), i.e. in the range produced by
        ``sample``. It is mapped back through ``atanh`` to the pre-squash
        value, scored under the Gaussian returned by ``self.mlp.get_density``
        and corrected by the log-determinant of the tanh Jacobian. The result
        is summed over the last (action) dimension.
        """
        features, _, num_extra_dim = self._encode(
            image, append, latent, detach_encoder=False)
        action = self._as_float_tensor(action)

        # Recover the pre-squash action; no gradient flows through it.
        with torch.no_grad():
            eps = 1e-5
            # Clamp first: atanh is infinite at +-1 and NaN beyond.
            squashed = torch.clamp(action, min=-1 + eps, max=1 - eps)
            raw_action = torch.atanh(squashed)

        density = self.mlp.get_density(features)
        log_prob = density.log_prob(raw_action).sum(dim=-1)
        # log(1 - tanh(u)^2) written in a numerically stable form.
        log_two = float(np.log(2.0))
        log_prob = log_prob - (
            2 * (log_two - raw_action - F.softplus(-2 * raw_action))
        ).sum(dim=-1)

        # Restore dimension
        for _ in range(num_extra_dim):
            log_prob = log_prob.squeeze(0)
        return log_prob

    def density(self, image,
                    append=None,
                    latent=None):
        """Return the pre-squash action distribution for the given input.

        The distribution is whatever ``self.mlp.get_density`` returns; it keeps
        the batch dimension even when a single image was passed.
        """
        features, _, _ = self._encode(
            image, append, latent, detach_encoder=False)
        return self.mlp.get_density(features)

class Baseline(torch.nn.Module):
    """State-value baseline: conv encoder followed by a single linear layer.

    The linear layer can be trained by gradient descent through ``forward``
    or set in closed form from observed returns with ``fit``.
    """

    def __init__(
        self,
        input_n_channel,
        img_sz,
        kernel_sz,
        stride,
        padding,
        n_channel,
        latent_dim=0,
        append_dim=0,
        use_sm=True,
        use_ln=True,
        use_bn=False,
        use_residual=False,
        device='cpu',
        verbose=True,
    ):
        """Build the encoder and the one-layer value head.

        Args:
            input_n_channel, kernel_sz, stride, padding, n_channel: forwarded
                to ``Encoder``.
            img_sz: image size; a scalar is stored as ``[img_sz, img_sz]`` on
                ``self.img_sz`` (the encoder receives the value as given).
            latent_dim: width of the ``latent`` vector concatenated to the
                encoder features.
            append_dim: width of the ``append`` vector concatenated to the
                encoder features.
            use_sm: forwarded to ``Encoder``; when true the encoder output is
                assumed to have ``2 * n_channel[-1]`` entries.
            use_ln: forwarded to the value head ``MLP``.
            use_bn, use_residual: forwarded to ``Encoder``.
            device: device used for numpy inputs and for the sub-modules.
            verbose: print the value head architecture when true.
        """
        super().__init__()
        self.device = device
        self.img_sz = img_sz
        if np.isscalar(img_sz):
            self.img_sz = [img_sz, img_sz]

        # Convolutional feature extractor
        self.encoder = Encoder(input_n_channel=input_n_channel,
                               img_sz=img_sz,
                               kernel_sz=kernel_sz,
                               stride=stride,
                               padding=padding,
                               n_channel=n_channel,
                               use_sm=use_sm,
                               use_spec=False,
                               use_bn=use_bn,
                               use_residual=use_residual,
                               device=device,
                               verbose=False)
        if use_sm:
            dim_conv_out = n_channel[-1] * 2  # assume spatial softmax
        else:
            dim_conv_out = self.encoder.get_output_dim()

        # one linear layer only, no activation
        mlp_dim = [dim_conv_out + latent_dim + append_dim, 1]
        self.V = MLP(mlp_dim,
                      'Identity',
                      out_activation_type='Identity',
                      use_ln=use_ln,
                      verbose=False).to(device)
        if verbose:
            print("The MLP for critic has the architecture as below:")
            print(self.V.moduleList)

    def features(self, image,
                        append=None,
                        latent=None,
                        detach_encoder=True):
        """Run the encoder on a ``uint8`` image (NCHW or CHW).

        ``append`` and ``latent`` are accepted for interface compatibility but
        do not influence the result.

        Returns:
            Tuple ``(conv_out, np_input, num_extra_dim)``: the encoder output,
            whether ``image`` was a numpy array, and how many leading
            dimensions were added to a single image.

        Raises:
            TypeError: if the image is not ``uint8``.
        """
        np_input = isinstance(image, np.ndarray)
        if np_input:
            image = torch.from_numpy(image).to(self.device)

        # Convert [0, 255] to [0, 1]
        if image.dtype != torch.uint8:
            raise TypeError(
                f"image must have dtype uint8, got {image.dtype}")
        image = image.float() / 255.0

        # A single CHW image gets a batch dimension.
        num_extra_dim = 0
        if image.dim() == 3:
            image = image.unsqueeze(0)
            num_extra_dim = 1

        conv_out = self.encoder.forward(image, detach=detach_encoder)
        return conv_out, np_input, num_extra_dim

    def forward(self,
                image,
                append=None,
                latent=None,
                detach_encoder=True):
        """Return the value estimate, shape N x 1 (or 1 for a single image).

        ``append`` and ``latent`` may be tensors or numpy arrays; they are
        concatenated, in that order, to the encoder features. The result is a
        numpy array when ``image`` was one.
        """
        conv_out, np_input, num_extra_dim = self.features(
            image, detach_encoder=detach_encoder)

        # Append, latent: bring them to the same form as the features.
        extras = []
        for extra in (append, latent):
            if extra is None:
                continue
            if isinstance(extra, np.ndarray):
                extra = torch.from_numpy(extra).float().to(self.device)
            if num_extra_dim and extra.dim() == 1:
                extra = extra.unsqueeze(0)
            extras.append(extra)
        if extras:
            conv_out = torch.cat([conv_out, *extras], dim=-1)

        # MLP
        v = self.V(conv_out)

        # Restore dimension
        for _ in range(num_extra_dim):
            v = v.squeeze(0)

        # Convert back to np
        if np_input:
            v = v.detach().cpu().numpy()
        return v    #Nx1

    def fit(self, image, returns):
        """Set the value head weight by ridge regression on encoder features.

        Solves ``(F^T F + reg * I) w = F^T returns`` for ``w`` where ``F`` are
        the encoder features of ``image``, and writes ``w`` into the weight of
        ``self.V.moduleList[0][0]``. Only that weight is modified.

        Args:
            image: batch of ``uint8`` images (NCHW).
            returns: N or N x 1 tensor (or numpy array) of regression targets.

        NOTE(review): only the encoder features are used (no ``append`` or
        ``latent``), so this presumably requires ``latent_dim == 0`` and
        ``append_dim == 0`` - confirm against callers.
        """
        baseline_reg = 1e-5
        if isinstance(returns, np.ndarray):
            returns = torch.from_numpy(returns).float().to(self.device)

        # V only has one sequential, and the sequential only has a linear layer
        weight = self.V.moduleList[0][0].weight

        with torch.no_grad():
            features = self.features(image)[0]
            targets = returns.reshape(features.size(0), -1).to(features.dtype)
            reg = baseline_reg * torch.eye(features.size(1),
                                           dtype=features.dtype,
                                           device=features.device)
            A = features.t() @ features + reg
            b = features.t() @ targets
            # torch.linalg.lstsq takes (A, B) and solves A X = B.
            coeffs = torch.linalg.lstsq(A, b).solution  # D x 1
            # The linear layer stores its weight as out_features x in_features.
            weight.copy_(coeffs.t())


class GaussianPolicy(torch.nn.Module):
    """Diagonal Gaussian policy head with tanh-squashed actions.

    Two separate ``MLP`` networks with the same layer sizes produce the mean
    and the log standard deviation of a Normal distribution over pre-squash
    actions; sampled actions are passed through ``tanh``.
    """

    def __init__(self,
                 dimList,
                 action_mag,    # assume 1
                 activation_type='ReLU',
                 use_ln=True,
                 device='cpu',
                 verbose=True):
        """Build the mean and log-std networks.

        Args:
            dimList: layer widths, forwarded to both ``MLP`` networks.
            action_mag: accepted for interface compatibility; not used, the
                actions are always in [-1, 1].
            activation_type: hidden activation name forwarded to ``MLP``.
            use_ln: forwarded to ``MLP``.
            device: device the networks and inputs are moved to.
            verbose: print the mean network architecture when true.
        """
        super(GaussianPolicy, self).__init__()
        self.device = device
        self.mean = MLP(dimList,
                        activation_type,
                        out_activation_type='Identity',
                        use_ln=use_ln,
                        verbose=False).to(device)
        self.log_std = MLP(dimList,
                           activation_type,
                           out_activation_type='Identity',
                           use_ln=use_ln,
                           verbose=False).to(device)
        if verbose:
            print("The MLP for MEAN has the architecture as below:")
            print(self.mean.moduleList)
        # Bounds applied to the predicted log standard deviation.
        self.LOG_STD_MAX = 1
        self.LOG_STD_MIN = -10
        self.eps = 1e-8

    def forward(self, state):  # mean only
        """Return the deterministic action ``tanh(mean(state))``."""
        state_tensor = state.to(self.device)
        return torch.tanh(self.mean(state_tensor))

    def get_density(self, state):
        """Return the Normal distribution over pre-squash actions.

        The log standard deviation is clamped to
        ``[LOG_STD_MIN, LOG_STD_MAX]`` before exponentiation.
        """
        state_tensor = state.to(self.device)
        mean = self.mean(state_tensor)
        log_std = self.log_std(state_tensor)
        log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
        return Normal(mean, torch.exp(log_std))

    def sample(self, state, get_prob=False):
        """Draw a reparameterized sample and squash it with ``tanh``.

        Args:
            state: input features; the last dimension is the feature axis.
            get_prob: when true also return the log-probability of the
                squashed action, summed over the last (action) dimension.

        Returns:
            Tuple ``(action, log_prob)``; ``log_prob`` is ``None`` when
            ``get_prob`` is false.
        """
        normal_rv = self.get_density(state)
        raw_action = normal_rv.rsample()

        logp_pi = None
        if get_prob:
            logp_pi = normal_rv.log_prob(raw_action).sum(dim=-1)
            # Change of variables for a = tanh(u):
            # log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
            # which is numerically stable for large |u|.
            log_two = float(np.log(2.0))
            logp_pi = logp_pi - (
                2 * (log_two - raw_action - F.softplus(-2 * raw_action))
            ).sum(dim=-1)

        return torch.tanh(raw_action), logp_pi

