import numpy as np
from torch import nn
import torch
from einops.layers.torch import Rearrange
from positional_encodings.torch_encodings import PositionalEncoding1D, Summer

class ActionVAE(nn.Module):
    """Transformer VAE over blocks of actions.

    ``encode`` projects actions to ``encoder_dim``, runs a strided temporal conv
    stack (``ResidualTemporalBlock``), adds a 1-D positional encoding and applies
    a ``TransformerEncoder``. ``quant_proj`` maps every encoder token to
    ``2 * latent_dim`` moments that parameterise a ``DiagonalGaussianDistribution``.
    ``decode`` runs a ``TransformerDecoder`` whose queries are a positional
    encoding of length ``skill_block_size`` and whose memory is the latent tokens
    (after ``post_quant_proj``); ``action_head`` maps the result back to
    ``action_dim``.

    Actions are presumably shaped (batch, time, action_dim) — the conv stack
    transposes dims 1 and 2 — and ``forward`` compares them against a
    reconstruction of length ``skill_block_size`` (TODO confirm callers always
    pass ``time == skill_block_size``).

    Args:
        action_dim: size of the last axis of the action tensor.
        encoder_dim: width of the encoder transformer and conv stack.
        decoder_dim: width of the decoder transformer.
        skill_block_size: number of action steps produced by ``decode``.
        downsample_factor: temporal downsampling applied by the conv stack;
            must be a power of 2.
        attn_pdrop: dropout inside the transformer layers.
        use_causal_encoder: causal convs plus a causal mask in the encoder.
        use_causal_decoder: causal self-attention mask in the decoder.
        encoder_heads, encoder_layers, decoder_heads, decoder_layers:
            transformer sizes.
        before_latent_dim: accepted for config compatibility; unused.
        latent_dim: size of each latent token.
        kl_weight: weight of the KL term in the training loss.

    Raises:
        ValueError: if ``downsample_factor`` is not a power of 2.
    """

    def __init__(self,
                 action_dim=7,
                 encoder_dim=256,
                 decoder_dim=256,

                 skill_block_size=32,
                 downsample_factor=2,

                 attn_pdrop=0.1,
                 use_causal_encoder=True,
                 use_causal_decoder=True,

                 encoder_heads=4,
                 encoder_layers=2,
                 decoder_heads=4,
                 decoder_layers=4,

                 before_latent_dim=8,
                 latent_dim=16,
                 kl_weight=1e-5,
                 ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.skill_block_size = skill_block_size

        self.use_causal_encoder = use_causal_encoder
        self.use_causal_decoder = use_causal_decoder

        self.kl_weight = kl_weight
        self.loss_fn = torch.nn.L1Loss()

        # Raise rather than ``assert`` so the check survives ``python -O``;
        # the ``> 0`` guard keeps log2 away from zero/negative inputs.
        if not (downsample_factor > 0 and float(np.log2(downsample_factor)).is_integer()):
            raise ValueError('downsample_factor must be a power of 2')
        num_downsamples = int(np.log2(downsample_factor))
        if num_downsamples <= 0:
            # No temporal downsampling: two stride-1 convs.
            kernel_sizes = [3, 2]
            strides = [1, 1]
        else:
            # One stride-2 conv per factor of two, followed by a stride-1 conv.
            kernel_sizes = [5] + [3] * num_downsamples
            strides = [2] * num_downsamples + [1]

        self.action_proj = nn.Linear(action_dim, encoder_dim)
        self.action_head = nn.Linear(decoder_dim, action_dim)
        self.conv_block = ResidualTemporalBlock(
            encoder_dim, encoder_dim, kernel_size=kernel_sizes,
            stride=strides, causal=use_causal_encoder)

        encoder_layer = nn.TransformerEncoderLayer(d_model=encoder_dim,
                                                   nhead=encoder_heads,
                                                   dim_feedforward=4 * encoder_dim,
                                                   dropout=attn_pdrop,
                                                   activation='gelu',
                                                   batch_first=True,
                                                   norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             num_layers=encoder_layers,
                                             enable_nested_tensor=False)
        decoder_layer = nn.TransformerDecoderLayer(d_model=decoder_dim,
                                                   nhead=decoder_heads,
                                                   dim_feedforward=4 * decoder_dim,
                                                   dropout=attn_pdrop,
                                                   activation='gelu',
                                                   batch_first=True,
                                                   norm_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_layers)
        self.add_positional_emb = Summer(PositionalEncoding1D(encoder_dim))
        self.fixed_positional_emb = PositionalEncoding1D(decoder_dim)

        self.latent_dim = latent_dim
        # Each encoder token -> concatenated (mean, logvar), each of size latent_dim.
        self.quant_proj = nn.Linear(encoder_dim, self.latent_dim * 2)
        self.post_quant_proj = nn.Linear(self.latent_dim, decoder_dim)

    def encode(self, act):
        """Encode actions into a sequence of ``encoder_dim`` tokens (one per post-conv step)."""
        x = self.action_proj(act)
        x = self.conv_block(x)
        x = self.add_positional_emb(x)

        if self.use_causal_encoder:
            mask = nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device)
            return self.encoder(x, mask=mask, is_causal=True)
        return self.encoder(x)

    def decode(self, codes):
        """Decode latent tokens (already in ``decoder_dim``) into ``skill_block_size`` action steps."""
        # Decoder queries: a positional encoding of shape
        # (batch, skill_block_size, decoder_dim), computed from a zero tensor.
        x = self.fixed_positional_emb(
            torch.zeros((codes.shape[0], self.skill_block_size, self.decoder_dim), dtype=codes.dtype,
                        device=codes.device))

        if self.use_causal_decoder:
            mask = nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device)
            x = self.decoder(x, codes, tgt_mask=mask, tgt_is_causal=True)
        else:
            x = self.decoder(x, codes)
        return self.action_head(x)

    def _posterior(self, act):
        """Encode ``act`` and wrap the projected moments in a diagonal Gaussian."""
        moments = self.quant_proj(self.encode(act))
        return DiagonalGaussianDistribution(moments)

    def get_sample(self, act):
        """Return one reparameterised latent sample (one ``latent_dim`` vector per encoder token)."""
        return self._posterior(act).sample()

    def get_action(self, z):
        """Project latent samples to ``decoder_dim`` and decode them into actions."""
        z = self.post_quant_proj(z)
        return self.decode(z)

    def get_action_latent(self, act):
        """Return the encoder tokens for ``act`` (before ``quant_proj``)."""
        return self.encode(act)

    def forward(self, act):
        """Compute the VAE training loss.

        Returns:
            ``(loss, info)`` where ``loss = L1(recon, act) + kl_weight * mean(KL)``
            and ``info`` holds the three terms as Python floats.
        """
        posterior = self._posterior(act)
        z = posterior.sample()
        dec = self.decode(self.post_quant_proj(z))

        l1_loss = self.loss_fn(dec, act)
        kl_loss = posterior.kl().mean()
        loss = l1_loss + kl_loss * self.kl_weight
        info = {
            'recon_loss': l1_loss.item(),
            'kl_loss': kl_loss.item(),
            'total_loss': loss.item(),
        }
        return loss, info

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by concatenated (mean, log-variance).

    ``parameters`` is split in half along dim 2: the first half is the mean,
    the second half the log-variance. The input therefore needs at least three
    dims and an even size on dim 2. The log-variance is clamped to [-30, 20]
    before ``std``/``var`` are derived from it.

    Args:
        parameters: tensor of concatenated moments.
        deterministic: if True, ``std``/``var`` are zero, ``sample`` returns the
            mean, and ``kl``/``nll`` return a zero tensor.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=2)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of ``mean``.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Zero result for the deterministic case, on the moments' device and dtype."""
        return torch.zeros(1, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        # randn_like draws noise directly on the moments' device and dtype:
        # no CPU allocation + copy, and no promotion of fp16/bf16 to float32.
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        """KL(self || other), summed over every dim except dim 0.

        ``other=None`` means the standard normal N(0, I).
        """
        if self.deterministic:
            return self._zero()
        dims = list(range(1, self.mean.dim()))
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                   + self.var - 1.0 - self.logvar,
                                   dim=dims)
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample):
        """Gaussian negative log-likelihood of ``sample``, summed over every dim except dim 0."""
        if self.deterministic:
            return self._zero()
        dims = list(range(1, sample.dim()))
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mean (the mode of a Gaussian)."""
        return self.mean

class CausalConv1d(nn.Module):
    """1-D convolution whose output step t only reads inputs at steps <= t * stride.

    Tensors are laid out as ``nn.Conv1d`` expects: (batch, channels, length),
    or unbatched (channels, length).

    By default ``nn.Conv1d`` pads ``dilation * (kernel_size - 1)`` on both
    sides. Outputs that would read right-hand padding are then dropped, which
    leaves ``ceil(length / stride)`` outputs. With ``no_pad=True`` the raw,
    unpadded convolution output is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, stride, no_pad=False):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = 0 if no_pad else dilation * (kernel_size - 1)
        self.stride = stride
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding,
                              dilation=dilation, stride=stride)

    def forward(self, x):
        length = x.shape[-1]
        x = self.conv(x)
        if self.padding == 0:
            return x
        # Output t spans padded inputs [t*stride, t*stride + padding], i.e.
        # original inputs up to t*stride. Keep exactly the t with
        # t*stride <= length - 1. This is exact for any dilation and does not
        # depend on padding being divisible by stride.
        n_causal = (length - 1) // self.stride + 1
        return x[..., :n_causal]


class Conv1dBlock(nn.Module):
    """Conv1d -> GroupNorm -> Mish on a (batch, channels, horizon) tensor.

    The convolution is a ``CausalConv1d`` when ``causal`` is True; otherwise it
    is a plain ``nn.Conv1d`` padded by ``kernel_size // 2``. ``no_pad`` only
    affects the causal variant. The Sequential order is conv, Rearrange,
    GroupNorm, Rearrange, Mish. It defines the ``block.<idx>`` state_dict keys,
    so it must not change.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=4, causal=True, no_pad=False):
        super().__init__()
        conv = (
            CausalConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride, no_pad=no_pad)
            if causal
            else nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2, stride=stride)
        )
        # GroupNorm is applied on a 4-D view: a singleton axis is inserted
        # before it and removed after it.
        layers = [
            conv,
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, normalisation and activation in sequence."""
        return self.block(x)


# TODO: delete deconv modules for final release version
class CausalDeConv1d(nn.Module):
    """Unpadded transposed 1-D convolution trimmed to ``length * stride`` outputs.

    ``nn.ConvTranspose1d`` without padding yields
    ``(length - 1) * stride + dilation * (kernel_size - 1) + 1`` outputs. Output
    j only receives contributions from inputs i with ``i * stride <= j``. The
    trailing surplus of ``dilation * (kernel_size - 1) + 1 - stride`` steps is
    dropped. When that surplus is not positive, the output is returned as-is.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, stride):
        super(CausalDeConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        # ``dilation`` is forwarded so the trim below matches the effective kernel.
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)

    def forward(self, x):
        x = self.conv(x)
        last_n = self.dilation * (self.kernel_size - 1) + 1 - self.stride
        if last_n > 0:
            return x[..., :-last_n]
        return x

class DeConv1dBlock(nn.Module):
    """Transposed conv -> GroupNorm -> Mish on a (batch, channels, horizon) tensor.

    With ``causal`` the upsampler is a ``CausalDeConv1d`` (dilation 1).
    Otherwise it is an ``nn.ConvTranspose1d`` with ``padding=kernel_size // 2``
    and ``output_padding=stride - 1``. The Sequential order (upsampler,
    Rearrange, GroupNorm, Rearrange, Mish) defines the state_dict keys.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=8, causal=True):
        super().__init__()
        if not causal:
            upsample = nn.ConvTranspose1d(
                inp_channels, out_channels, kernel_size,
                padding=kernel_size // 2, stride=stride, output_padding=stride - 1)
        else:
            upsample = CausalDeConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride)

        to_4d = Rearrange('batch channels horizon -> batch channels 1 horizon')
        to_3d = Rearrange('batch channels 1 horizon -> batch channels horizon')
        self.block = nn.Sequential(upsample, to_4d, nn.GroupNorm(n_groups, out_channels), to_3d, nn.Mish())

    def forward(self, x):
        """Upsample, normalise and activate ``x``."""
        return self.block(x)


class ResidualTemporalBlock(nn.Module):
    """Stack of ``Conv1dBlock`` layers applied along the time axis.

    Input and output are laid out as [B, T, D]. They are transposed to
    [B, D, T] for the convolutions and transposed back at the end.

    Args:
        inp_channels: channels of the input; only the first block consumes them.
        out_channels: channels produced by every block.
        kernel_size: one kernel size per conv block.
        stride: one stride per conv block, indexed in step with ``kernel_size``.
        n_groups: GroupNorm groups for every block.
        causal: use causal convolutions.
        residual: add a skip connection from the block input to the output.
        pooling_layers: indices of blocks followed by ``AvgPool1d(2, 2)``.
    """

    def __init__(self, inp_channels, out_channels, kernel_size=(5, 3), stride=(2, 2), n_groups=8,
                 causal=True, residual=False, pooling_layers=()):
        super().__init__()
        # Copy so the instance never aliases a caller-owned or default list.
        self.pooling_layers = list(pooling_layers) if pooling_layers else []
        self.blocks = nn.ModuleList()
        total_downsample = 1
        for i, k in enumerate(kernel_size):
            self.blocks.append(Conv1dBlock(
                inp_channels if i == 0 else out_channels,
                out_channels,
                k,
                stride[i],
                n_groups=n_groups,
                causal=causal,
            ))
            total_downsample *= stride[i]
            if i in self.pooling_layers:
                total_downsample *= 2
        if residual:
            # The main path shrinks time by the *product* of the strides, plus
            # a factor 2 per pooled block. The skip path must match that factor.
            if out_channels == inp_channels and total_downsample == 1:
                self.residual_conv = nn.Identity()
            else:
                self.residual_conv = nn.Conv1d(inp_channels, out_channels, kernel_size=1,
                                               stride=total_downsample)
        if self.pooling_layers:
            self.pooling = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, input_dict):
        # Despite its name, ``input_dict`` is a tensor laid out as [B, T, D].
        x = torch.transpose(input_dict, 1, 2)  # [B, T, D] -> [B, D, T]
        pooling = getattr(self, 'pooling', None)
        out = x
        for layer_num, block in enumerate(self.blocks):
            out = block(out)
            if pooling is not None and layer_num in self.pooling_layers:
                out = pooling(out)
        residual_conv = getattr(self, 'residual_conv', None)
        if residual_conv is not None:
            out = out + residual_conv(x)
        return torch.transpose(out, 1, 2)



