"""
Transformer models whose outputs parameterize a diagonal Gaussian (Gaussian_Transformer) or a Gaussian mixture model (GMM_Transformer).

Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py

"""

import logging
import torch
import torch.nn as nn
from model.diffusion.modules import SinusoidalPosEmb

logger = logging.getLogger(__name__)


class Gaussian_Transformer(nn.Module):
    """Transformer whose output parameterizes a diagonal Gaussian over a horizon.

    ``cond["state"]`` is flattened per batch element and fed to ``Transformer``
    as a single conditioning token. The decoder emits ``horizon_steps`` tokens;
    their first ``transition_dim`` channels, squashed with ``tanh``, form the
    mean.

    The standard deviation is selected by ``fixed_std`` / ``learn_fixed_std``:

    * ``fixed_std is None``: predicted by the network (the head emits an extra
      ``transition_dim`` log-variance channels per step). ``learn_fixed_std``
      is ignored in this case.
    * ``fixed_std`` set and ``learn_fixed_std`` True: one learnable
      log-variance per ``transition_dim`` channel, initialised to
      ``log(fixed_std**2)`` and shared across horizon steps and batch.
    * ``fixed_std`` set and ``learn_fixed_std`` False: the constant
      ``fixed_std``.

    Learned log-variances are clamped to ``[log(std_min**2), log(std_max**2)]``.
    """

    def __init__(
        self,
        transition_dim,
        horizon_steps,
        cond_dim,
        transformer_embed_dim=256,
        transformer_num_heads=8,
        transformer_num_layers=6,
        transformer_activation="gelu",
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):

        super().__init__()
        self.transition_dim = transition_dim
        self.horizon_steps = horizon_steps
        output_dim = transition_dim

        if fixed_std is None:  # the network predicts the logvar
            output_dim *= 2  # mean and logvar
            logger.info("Using learned std")
        elif learn_fixed_std:  # a learnable logvar, shared along the horizon
            self.logvar = torch.nn.Parameter(
                torch.log(torch.tensor([fixed_std**2 for _ in range(transition_dim)])),
                requires_grad=True,
            )
            logger.info(f"Using fixed std {fixed_std} with learning")
        else:
            logger.info(f"Using fixed std {fixed_std} without learning")
        # stored as frozen parameters so they follow the module across devices
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.learn_fixed_std = learn_fixed_std
        self.fixed_std = fixed_std

        self.transformer = Transformer(
            output_dim,
            horizon_steps,
            cond_dim,
            T_cond=1,  # right now we assume only one step of observation everywhere
            n_layer=transformer_num_layers,
            n_head=transformer_num_heads,
            n_emb=transformer_embed_dim,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            activation=transformer_activation,
        )

    def forward(self, cond):
        """Compute the Gaussian parameters for a batch of conditions.

        Args:
            cond: dict whose ``"state"`` entry is a tensor with the batch as its
                first dimension. All remaining dims are flattened, and the
                flattened size must equal ``cond_dim``.

        Returns:
            ``(mean, scale)``, each of shape ``(B, horizon_steps * transition_dim)``
            and flattened horizon-major (index ``h * transition_dim + d``).
        """
        B = len(cond["state"])

        # flatten history into one conditioning token: (B, 1, cond_dim)
        state = cond["state"].reshape(B, -1).unsqueeze(1)
        out, _ = self.transformer(state)  # (B, horizon, output_dim)

        # the first transition_dim channels are the (tanh-squashed) mean
        out_mean = torch.tanh(out[:, :, : self.transition_dim])
        out_mean = out_mean.view(B, self.horizon_steps * self.transition_dim)

        # Check fixed_std first. When it is None the logvar comes from the
        # network and no self.logvar parameter exists, even if learn_fixed_std
        # was passed as True.
        if self.fixed_std is None:
            out_logvar = out[:, :, self.transition_dim :]
            out_logvar = out_logvar.reshape(B, self.horizon_steps * self.transition_dim)
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        elif self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            # same std at every horizon step: tile (1, D) -> (B, H * D)
            out_scale = out_scale.view(1, self.transition_dim)
            out_scale = out_scale.repeat(B, self.horizon_steps)
        else:
            # ones_like already matches out_mean's device and dtype
            out_scale = torch.ones_like(out_mean) * self.fixed_std
        return out_mean, out_scale


class GMM_Transformer(nn.Module):
    """Transformer whose output parameterizes a Gaussian mixture over a horizon.

    ``cond["state"]`` is flattened per batch element and fed to ``Transformer``
    as a single conditioning token. Each of the ``horizon_steps`` decoder
    tokens carries, for every mode, ``transition_dim`` mean channels
    (``tanh``-squashed). When ``fixed_std is None`` it also carries the same
    number of log-variance channels. Mixture weights come from a separate
    linear head (``modes_head``) over the pre-head features of all horizon
    tokens.

    Std handling mirrors ``Gaussian_Transformer``:

    * ``fixed_std is None``: predicted by the network, per mode and step.
      ``learn_fixed_std`` is ignored.
    * ``fixed_std`` set and ``learn_fixed_std`` True: a learnable log-variance
      per (mode, transition dim), shared along the horizon.
    * ``fixed_std`` set and ``learn_fixed_std`` False: the constant
      ``fixed_std``.

    Learned log-variances are clamped to ``[log(std_min**2), log(std_max**2)]``.
    """

    def __init__(
        self,
        transition_dim,
        horizon_steps,
        cond_dim,
        num_modes=5,
        transformer_embed_dim=256,
        transformer_num_heads=8,
        transformer_num_layers=6,
        transformer_activation="gelu",
        p_drop_emb=0,
        p_drop_attn=0,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1,
    ):

        super().__init__()
        self.num_modes = num_modes
        self.transition_dim = transition_dim
        self.horizon_steps = horizon_steps
        # per-step channels: one mean per (mode, transition dim). Mixture
        # weights are NOT part of this output; they come from modes_head.
        output_dim = transition_dim * num_modes

        if fixed_std is None:
            output_dim += num_modes * transition_dim  # logvar for each mode
            logger.info("Using learned std")
        elif learn_fixed_std:
            # initialise to fixed_std, separate for each transition dim and mode,
            # but shared along the horizon
            self.logvar = torch.nn.Parameter(
                torch.log(
                    torch.tensor(
                        [fixed_std**2 for _ in range(num_modes * transition_dim)]
                    )
                ),
                requires_grad=True,
            )
            logger.info(f"Using fixed std {fixed_std} with learning")
        else:
            logger.info(f"Using fixed std {fixed_std} without learning")
        # stored as frozen parameters so they follow the module across devices
        self.logvar_min = torch.nn.Parameter(
            torch.log(torch.tensor(std_min**2)), requires_grad=False
        )
        self.logvar_max = torch.nn.Parameter(
            torch.log(torch.tensor(std_max**2)), requires_grad=False
        )
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

        self.transformer = Transformer(
            output_dim,
            horizon_steps,
            cond_dim,
            T_cond=1,  # right now we assume only one step of observation everywhere
            n_layer=transformer_num_layers,
            n_head=transformer_num_heads,
            n_emb=transformer_embed_dim,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            activation=transformer_activation,
        )
        self.modes_head = nn.Linear(horizon_steps * transformer_embed_dim, num_modes)

    def _to_mode_major(self, x, B):
        """Reorder per-step channels (B, H, M*D) into per-mode vectors (B, M, H*D)."""
        x = x.reshape(B, self.horizon_steps, self.num_modes, self.transition_dim)
        x = x.permute(0, 2, 1, 3)  # swap horizon and mode axes
        return x.reshape(B, self.num_modes, self.horizon_steps * self.transition_dim)

    def forward(self, cond):
        """Compute the mixture parameters for a batch of conditions.

        Args:
            cond: dict whose ``"state"`` entry is a tensor with the batch as its
                first dimension. All remaining dims are flattened, and the
                flattened size must equal ``cond_dim``.

        Returns:
            ``(mean, scale, weights)``: ``mean`` and ``scale`` have shape
            ``(B, num_modes, horizon_steps * transition_dim)`` (horizon-major
            within each mode). ``weights`` has shape ``(B, num_modes)`` and is
            the raw ``modes_head`` output; no softmax is applied here.
        """
        B = len(cond["state"])

        # flatten history into one conditioning token: (B, 1, cond_dim)
        state = cond["state"].reshape(B, -1).unsqueeze(1)
        # (B, horizon, output_dim), (B, horizon, emb_dim)
        out, out_prehead = self.transformer(state)

        n_mean = self.num_modes * self.transition_dim
        out_mean = self._to_mode_major(torch.tanh(out[:, :, :n_mean]), B)

        # Check fixed_std first. When it is None the logvar comes from the
        # network and no self.logvar parameter exists, even if learn_fixed_std
        # was passed as True.
        if self.fixed_std is None:
            # the logvar channels are exactly the second n_mean channels; the
            # output carries no extra mode-weight channels
            out_logvar = self._to_mode_major(out[:, :, n_mean : 2 * n_mean], B)
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
        elif self.learn_fixed_std:
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            # same std at every horizon step: tile (1, M, D) -> (B, M, H * D)
            out_scale = out_scale.view(1, self.num_modes, self.transition_dim)
            out_scale = out_scale.repeat(B, 1, self.horizon_steps)
        else:
            # ones_like already matches out_mean's device and dtype
            out_scale = torch.ones_like(out_mean) * self.fixed_std

        # mode weights see the pre-head features of every horizon step
        out_weights = self.modes_head(out_prehead.reshape(B, -1))
        return out_mean, out_scale, out_weights


class Transformer(nn.Module):
    """Encoder-decoder transformer mapping conditioning tokens to ``horizon`` output tokens.

    The conditioning sequence is linearly embedded, given learned positional
    embeddings (``cond_pos_emb``) and passed through ``self.encoder``. The
    encoder is a pre-norm ``nn.TransformerEncoder`` when ``n_cond_layers > 0``
    and a per-token MLP otherwise; its output is the decoder memory. The
    decoder queries are the learned positional embeddings ``pos_emb`` alone
    (one per horizon step), so the output depends on ``cond`` only through
    cross-attention.

    Args:
        output_dim: channels produced by the final linear head.
        horizon: number of decoder output tokens.
        cond_dim: channels of each conditioning token.
        T_cond: maximum number of conditioning tokens (size of ``cond_pos_emb``).
        n_layer: number of decoder layers.
        n_head: attention heads in every encoder/decoder layer.
        n_emb: model width.
        p_drop_emb: dropout applied to the input embeddings.
        p_drop_attn: dropout inside the encoder/decoder layers.
        causal_attn: if True, register additive ``mask`` / ``memory_mask``
            buffers; otherwise both are None.
        n_cond_layers: encoder depth; 0 selects the MLP encoder.
        activation: activation passed to the transformer layers.
    """

    def __init__(
        self,
        output_dim,
        horizon,
        cond_dim,
        T_cond=1,
        n_layer=12,
        n_head=12,
        n_emb=768,
        p_drop_emb=0.0,
        p_drop_attn=0.0,
        causal_attn=False,
        n_cond_layers=0,
        activation="gelu",
    ):
        super().__init__()

        # encoder for observations: linear embedding + learned positional embedding
        self.cond_obs_emb = nn.Linear(cond_dim, n_emb)
        self.cond_pos_emb = nn.Parameter(torch.zeros(1, T_cond, n_emb))
        if n_cond_layers > 0:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=n_emb,
                nhead=n_head,
                dim_feedforward=4 * n_emb,
                dropout=p_drop_attn,
                activation=activation,
                batch_first=True,
                norm_first=True,
            )
            self.encoder = nn.TransformerEncoder(
                encoder_layer=encoder_layer,
                num_layers=n_cond_layers,
            )
        else:
            # no attention across cond tokens: a per-token MLP instead
            self.encoder = nn.Sequential(
                nn.Linear(n_emb, 4 * n_emb),
                nn.Mish(),
                nn.Linear(4 * n_emb, n_emb),
            )

        # decoder: queries are learned per-horizon-step positional embeddings
        self.pos_emb = nn.Parameter(torch.zeros(1, horizon, n_emb))
        self.drop = nn.Dropout(p_drop_emb)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=n_emb,
            nhead=n_head,
            dim_feedforward=4 * n_emb,
            dropout=p_drop_attn,
            activation=activation,
            batch_first=True,
            norm_first=True,  # important for stability
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer, num_layers=n_layer
        )

        # attention mask
        if causal_attn:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
            sz = horizon
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = (
                mask.float()
                .masked_fill(mask == 0, float("-inf"))
                .masked_fill(mask == 1, float(0.0))
            )
            self.register_buffer("mask", mask)

            # cross-attention mask, shape (horizon, T_cond)
            # NOTE(review): the `s - 1` offset is inherited from upstream
            # diffusion_policy, where the first cond token is a diffusion-time
            # embedding. No such token is prepended in this class, so decoder
            # step t may attend to cond tokens s <= t + 1. With T_cond == 1 this
            # mask is all zeros (no masking).
            t, s = torch.meshgrid(
                torch.arange(horizon), torch.arange(T_cond), indexing="ij"
            )
            mask = t >= (
                s - 1
            )
            mask = (
                mask.float()
                .masked_fill(mask == 0, float("-inf"))
                .masked_fill(mask == 1, float(0.0))
            )
            self.register_buffer("memory_mask", mask)
        else:
            self.mask = None
            self.memory_mask = None

        # decoder head
        self.ln_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, output_dim)

        # constants
        self.T_cond = T_cond
        self.horizon = horizon

        # init: visits self and every submodule (see _init_weights)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise ``module`` in place; called on every submodule via ``self.apply``.

        Linear/Embedding weights and MultiheadAttention projection weights are
        drawn from N(0, 0.02) with zero biases. LayerNorm is reset to identity,
        and this class's positional embeddings are drawn from N(0, 0.02). The
        listed ``ignore_types`` are skipped; ``apply`` still visits their
        children. Any other module type raises RuntimeError, so a newly added
        layer cannot silently keep its default initialisation.
        """
        ignore_types = (
            nn.Dropout,
            SinusoidalPosEmb,
            nn.TransformerEncoderLayer,
            nn.TransformerDecoderLayer,
            nn.TransformerEncoder,
            nn.TransformerDecoder,
            nn.ModuleList,
            nn.Mish,
            nn.Sequential,
        )
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.MultiheadAttention):
            weight_names = [
                "in_proj_weight",
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
            ]
            for name in weight_names:
                weight = getattr(module, name)
                if weight is not None:
                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
            bias_names = ["in_proj_bias", "bias_k", "bias_v"]
            for name in bias_names:
                bias = getattr(module, name)
                if bias is not None:
                    torch.nn.init.zeros_(bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, Transformer):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
            # always true in this class: cond_obs_emb is unconditionally created
            if module.cond_obs_emb is not None:
                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
        elif isinstance(module, ignore_types):
            # no param
            pass
        else:
            raise RuntimeError("Unaccounted module {}".format(module))

    def forward(
        self,
        cond: torch.Tensor,
        **kwargs,
    ):
        """Run the encoder on ``cond`` and decode ``horizon`` tokens.

        cond: (B, T, cond_dim), with T <= T_cond
        kwargs: accepted but ignored
        returns: (output, prehead) where output is (B, horizon, output_dim) and
            prehead is the final-LayerNorm features (B, horizon, n_emb) fed to
            the head
        """
        # encoder
        cond_embeddings = self.cond_obs_emb(cond)  # (B,To,n_emb)
        tc = cond_embeddings.shape[1]
        position_embeddings = self.cond_pos_emb[
            :, :tc, :
        ]  # each position maps to a (learnable) vector
        x = self.drop(cond_embeddings + position_embeddings)
        x = self.encoder(x)
        memory = x
        # (B,T_cond,n_emb)

        # decoder: the target sequence is only the positional embeddings
        position_embeddings = self.pos_emb[
            :, : self.horizon, :
        ]  # each position maps to a (learnable) vector
        position_embeddings = position_embeddings.expand(
            cond.shape[0], self.horizon, -1
        )  # repeat for batch dimension
        x = self.drop(position_embeddings)
        # (B,T,n_emb)
        x = self.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=self.mask,
            memory_mask=self.memory_mask,
        )
        # (B,T,n_emb)

        # head
        x_prehead = self.ln_f(x)
        x = self.head(x_prehead)
        # (B,T,n_out)
        return x, x_prehead


if __name__ == "__main__":
    # Smoke test: build a small Transformer and check the output shapes.
    # causal_attn is disabled because it is not needed for delta control.
    # From Cheng: causal attention masking was found to be critical for the
    # transformer variant of diffusion policy. Without it the model may "cheat"
    # by looking ahead at future end-effector poses, which are almost identical
    # to the action of the current timestep.
    transformer = Transformer(
        output_dim=10,
        horizon=4,
        T_cond=1,
        cond_dim=16,
        causal_attn=False,
        n_cond_layers=0,
    )

    cond = torch.zeros((4, 1, 16))  # B x 1 x cond_dim
    out, out_prehead = transformer(cond)
    n_emb = transformer.pos_emb.shape[-1]
    assert out.shape == (4, 4, 10), out.shape  # (B, horizon, output_dim)
    assert out_prehead.shape == (4, 4, n_emb), out_prehead.shape  # (B, horizon, n_emb)
    print(f"out: {tuple(out.shape)}, prehead: {tuple(out_prehead.shape)}")
