import torch
import torch.nn as nn
import torch.nn.functional as F

from nerv.training import BaseModel
from nerv.models import deconv_out_shape, conv_norm_act, deconv_norm_act

# from transformers import AutoImageProcessor, EfficientNetModel
from transformers import AutoImageProcessor

from ...base_slots.models.utils import assert_shape, SoftPositionEmbed, torch_cat

from slotformer.base_slots.models import StoSAVi

from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

def pack_one(x, pattern):
    """Pack a single tensor with einops ``pack``.

    ``x`` is wrapped in a one-element list and handed to ``pack`` together
    with ``pattern``; whatever ``pack`` returns is passed back unchanged
    (callers in this file unpack it as ``packed, ps``).
    """
    singleton = [x]
    return pack(singleton, pattern)

def unpack_one(x, ps, pattern):
    """Undo ``pack_one``: run einops ``unpack`` and return its first element.

    Args:
        x: packed tensor.
        ps: packed-shape info as produced by ``pack`` / ``pack_one``.
        pattern: einops pattern forwarded to ``unpack``.
    """
    pieces = unpack(x, ps, pattern)
    return pieces[0]

def get_sin_pos_enc(seq_len, d_model):
    """Sinusoid absolute positional encoding.

    Positions are enumerated in reverse (``seq_len - 1`` down to ``0``), so
    the last row of the result encodes position 0. The sine half and the
    cosine half are concatenated along the channel axis (not interleaved).

    Args:
        seq_len: number of positions ``L``.
        d_model: number of channels ``C`` of the encoding.

    Returns:
        Float tensor of shape ``[1, L, C]``.
    """
    inv_freq = 1. / (10000**(torch.arange(0.0, d_model, 2.0) / d_model))
    pos_seq = torch.arange(seq_len - 1, -1, -1).type_as(inv_freq)
    sinusoid_inp = torch.outer(pos_seq, inv_freq)
    pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
    # For odd d_model there are ceil(d_model / 2) frequencies, so the concat
    # has d_model + 1 channels; trim so the result is always [L, d_model].
    # This is a no-op for even d_model.
    pos_emb = pos_emb[:, :d_model]
    return pos_emb.unsqueeze(0)  # [1, L, C]


def build_pos_enc(pos_enc, input_len, d_model):
    """Build a positional encoding parameter of shape [1, L, D].

    Args:
        pos_enc: encoding kind. Any falsy value disables the encoding;
            ``'learnable'`` gives a zero-initialised trainable parameter;
            any value containing ``'sin'`` (e.g. 'sin', 'sine') gives a
            frozen sinusoidal table.
        input_len: sequence length ``L``.
        d_model: channel size ``D``.

    Returns:
        ``None`` when ``pos_enc`` is falsy, otherwise an ``nn.Parameter``.

    Raises:
        NotImplementedError: for any other ``pos_enc`` value.
    """
    if not pos_enc:
        return None

    # ViT, BEiT etc. all use zero-init learnable pos enc
    if pos_enc == 'learnable':
        zeros = torch.zeros(1, input_len, d_model)
        return nn.Parameter(zeros)

    # in SlotFormer, we find out that sine P.E. is already good enough
    if 'sin' in pos_enc:  # 'sin', 'sine'
        table = get_sin_pos_enc(input_len, d_model)
        return nn.Parameter(table, requires_grad=False)

    raise NotImplementedError(f'unsupported pos enc {pos_enc}')

class TokenLearner(nn.Module):
    """
    TokenLearner, https://arxiv.org/abs/2106.11297

    Uses the 1.1 variant, where the attention map of each output token is
    produced by a small MLP (two 1x1 grouped convolutions with a GELU in
    between), one group per output token.
    """

    def __init__(
        self,
        *,
        dim,
        ff_mult = 2,
        num_output_tokens = 8,
        num_layers = 2
    ):
        """
        Args:
            dim: channel count of the incoming feature map.
            ff_mult: hidden-width multiplier of the per-token MLP.
            num_output_tokens: number of tokens produced per feature map.
            num_layers: accepted for interface compatibility; not used by
                this implementation.
        """
        super().__init__()
        self.num_output_tokens = num_output_tokens

        n_groups = num_output_tokens
        hidden_channels = dim * ff_mult * num_output_tokens

        # one independent 2-layer MLP per output token, via grouped 1x1 convs
        expand_conv = nn.Conv2d(
            dim * num_output_tokens, hidden_channels, 1, groups = n_groups)
        attn_conv = nn.Conv2d(
            hidden_channels, num_output_tokens, 1, groups = n_groups)
        self.net = nn.Sequential(expand_conv, nn.GELU(), attn_conv)

    def forward(self, x):
        """Reduce feature maps ``[..., c, h, w]`` to tokens ``[..., c, g]``.

        ``g`` is ``num_output_tokens``; any leading dims are packed into one
        batch axis for the computation and restored on return.
        """
        g = self.num_output_tokens

        feats, packed_shape = pack_one(x, '* c h w')

        # give every output token its own copy of the channels
        tiled = repeat(feats, 'b c h w -> b (g c) h w', g = g)

        # one spatial attention map per token, broadcast over channels
        attn_maps = rearrange(self.net(tiled), 'b g h w -> b 1 g h w')
        per_token = rearrange(tiled, 'b (g c) h w -> b c g h w', g = g)

        # attention-weighted spatial average
        pooled = reduce(per_token * attn_maps, 'b c g h w -> b c g', 'mean')
        return unpack_one(pooled, packed_shape, '* c n')


# input image + inst -> output action
# language conditional robotics transformer
class RoboticsTransformer(BaseModel):
    """Language-conditioned robotics transformer for behavioral cloning.

    Instead of FiLM conditioning inside EfficientNet-b3, the (projected)
    instruction embedding is fed to a transformer decoder as cross-attention
    memory. Per-frame image features are compressed with a TokenLearner,
    projected to ``enc_out_channels`` and, together with one learned action
    token, form the decoder's target sequence. The action is read from the
    output at the action-token position.

    NOTE: ``self.encoder`` is a placeholder (``None``); a backbone whose
    output exposes ``last_hidden_state`` must be assigned before ``forward``
    can run.
    NOTE: the dict defaults are shared between instances; this class only
    reads them.
    """

    def __init__(
            self,
            resolution=(64, 64),
            enc_dict=dict(
                enc_model='google/efficientnet-b3',
                num_token_learner=8,
                enc_out_channels=128 
            ),
            act_dec_dict=dict(
                history_len=6,
                t_pe='sin',
                d_model=128,
                num_layers=4,
                num_heads=8,
                ffn_dim=512,
                inst_size=768,
                act_size=2,
                norm_first=True,
            ),
            loss_dict=dict(
                rollout_len=6,
                use_img_recon_loss=False,
            ),
            eps=1e-6,
    ):
        """Store the configs and build the encoder and action decoder.

        Args:
            resolution: stored on the instance; not otherwise used here.
            enc_dict: encoder config (`enc_model`, `num_token_learner`,
                `enc_out_channels`).
            act_dec_dict: action decoder config (`history_len`, `t_pe`,
                `d_model`, `num_layers`, `num_heads`, `ffn_dim`,
                `inst_size`, `act_size`, `norm_first`).
            loss_dict: stored on the instance; not otherwise used here.
            eps: stored on the instance; not otherwise used here.
        """
        super().__init__()

        self.resolution = resolution
        self.eps = eps
        self.enc_dict = enc_dict
        self.act_dec_dict = act_dec_dict
        self.loss_dict = loss_dict

        self._build_encoder()
        self._build_dt()

        self.loss_decay_factor = 1.  # temporal loss weighting

        # NOTE(review): encoder tokens are concatenated with the action
        # token and fed to the decoder, so `enc_out_channels` presumably
        # has to equal `d_model`; the check is left disabled as before.
        #assert self.enc_dict['enc_out_channels'] == self.act_dec_dict['d_model']

    def _build_encoder(self):
        """Build image preprocessor, TokenLearner and its output projection."""
        self.encoder_preprocessor = AutoImageProcessor.from_pretrained(self.enc_dict['enc_model'])
        # TODO: no pretrained model for RT-1 output size [B*T, 512, 9, 9] use [B*T, 1536, 10, 10]
        # eff_dim = 1536, num_token = 8
        # self.encoder = EfficientNetModel.from_pretrained(self.enc_dict['enc_model'])
        self.encoder = None

        # TODO: fetch efficientnet output size for tokenlearner dim
        eff_dim = 1536
        self.token_learner = TokenLearner(
            dim=eff_dim,
            ff_mult=2,
            num_output_tokens=self.enc_dict['num_token_learner'],
            num_layers=2,
        )
        self.token_learner_projector = nn.Linear(
            eff_dim, self.enc_dict['enc_out_channels'])

    def _get_encoder_out(self, img):
        """Encode frames into TokenLearner tokens.

        Args:
            img: batch of frames, expected [B*T, 3, H, W].

        Returns:
            Tensor [B*T, num_token, enc_out_channels].

        Raises:
            RuntimeError: if no image encoder has been assigned.
        """
        if self.encoder is None:
            # fail with a clear message instead of "'NoneType' object is
            # not callable" further down
            raise RuntimeError(
                'RoboticsTransformer.encoder is None; assign an image '
                'encoder returning `last_hidden_state` before calling '
                'forward()')

        device = self.device
        img = self.encoder_preprocessor(img, return_tensors="pt")
        # move preprocessed tensors next to the model weights (not
        # unconditionally onto CUDA)
        for k in img:
            if torch.is_tensor(img[k]):
                img[k] = img[k].to(device)
        # the image backbone is run without gradients
        with torch.no_grad():
            outputs = self.encoder(**img)['last_hidden_state'] # [B*T, eff_dim, h, w]

        encoder_out = self.token_learner(outputs) # [B*T, eff_dim, num_token]
        encoder_out = rearrange(encoder_out, 'bt d tk -> bt tk d') # [B*T, num_token, eff_dim]
        encoder_out = self.token_learner_projector(encoder_out) # [B*T, num_token, d_model]
        return encoder_out

    def _build_dt(self):
        """Build instruction projection, temporal P.E., decoder and head."""
        cfg = self.act_dec_dict
        self.inst_proj = nn.Linear(cfg['inst_size'], cfg['d_model'])
        self.act_enc_t_pe = build_pos_enc(
            cfg['t_pe'], cfg['history_len'], cfg['d_model'])

        dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg['d_model'],
            nhead=cfg['num_heads'],
            dim_feedforward=cfg['ffn_dim'],
            norm_first=cfg['norm_first'],
            batch_first=True,
        )

        self.action_decoder = nn.TransformerDecoder(
            decoder_layer=dec_layer, num_layers=cfg['num_layers'])

        self.out_proj = nn.Linear(cfg['d_model'], cfg['act_size'])
        self.action_token = nn.Parameter(
            nn.init.normal_(torch.empty(1, cfg['d_model'])))

    def forward(self, data_dict):
        """Forward pass.

        Args:
            data_dict: dict with
                'img': frames, [B, T, 3, H, W];
                'instruction': instruction embedding, [B, D_inst].

        Returns:
            dict with 'actions': predicted action, [B, D_a].
        """
        frames = data_dict['img']  # [B, T, 3, H, W]
        inst = data_dict['instruction']  # [B, D_inst]
        B = frames.shape[0]
        T = frames.shape[1]
        device = self.device

        # the instruction acts as a single-token cross-attention memory
        inst = self.inst_proj(inst).unsqueeze(1) # [B, 1, D]

        frames = frames.flatten(0, 1) # [B*T, 3, H, W]
        encoder_out = self._get_encoder_out(frames) # [B*T, num_token, D]
        encoder_out = rearrange(encoder_out, '(b t) tk c -> b t tk c', b=B) # [B, T, num_token, D]
        num_token = encoder_out.shape[2]

        # temporal P.E. is shared by all tokens of a frame, shouldn't be None
        # [1, T, D] --> [1, T, 1, D], broadcast over batch and tokens
        encoder_out = encoder_out + self.act_enc_t_pe.unsqueeze(2)

        encoder_out = encoder_out.flatten(1, 2) # [B, T*num_token, D]
        action_token = self.action_token.unsqueeze(0).expand(B, -1, -1)

        decoder_in = torch.cat((encoder_out, action_token), dim=1) # [B, T*num_token + 1, D]

        # Causal mask over [frame tokens..., action token]; True = blocked.
        # It must be boolean: a float 0/1 mask would merely be *added* to
        # the attention scores and would not mask anything.
        # - tokens of frame i cannot attend to tokens of frames j > i
        # - frame tokens cannot attend to the action token
        # - the action token attends to everything
        mask = torch.ones(T, T, device=device).triu(diagonal=1)
        mask = mask.repeat_interleave(num_token, dim=0)
        mask = mask.repeat_interleave(num_token, dim=1)
        mask = torch.cat(
            (mask, torch.ones(mask.shape[0], 1, device=device)), dim=1)
        mask = torch.cat(
            (mask, torch.zeros(1, mask.shape[1], device=device)), dim=0)
        mask = mask.bool()

        actions = self.action_decoder(decoder_in, inst, tgt_mask=mask) # [B, T*num_token + 1, D]
        # read the prediction from the action-token position
        actions = self.out_proj(actions[:, -1]) # [B, D_a]
        out_dict = {
            'actions': actions,  # [B, D_a]
        }
        return out_dict

    def calc_train_loss(self, data_dict, out_dict):
        """Compute training loss.

        Args:
            data_dict: dict holding ground-truth 'actions'.
            out_dict: output of `forward`, holding predicted 'actions'.

        Returns:
            dict with 'action_loss': mean squared error over all elements.
        """
        loss_dict = {}
        gt_actions = data_dict['actions']
        pred_actions = out_dict['actions']
        actions_loss = F.mse_loss(pred_actions, gt_actions, reduction='none')

        # apply temporal loss weighting as done in RPIN
        # penalize more for early steps, less for later steps
        # NOTE(review): only active when `loss_decay_factor` is set below 1
        # (the constructor sets 1.). The indexing below assumes a 4-D loss
        # with steps on dim 1, while `forward` yields [B, D_a] predictions
        # -- confirm the intended layout before enabling it.
        if self.loss_decay_factor < 1.:
            w = self.loss_decay_factor**torch.arange(gt_actions.shape[1])
            w = w.type_as(actions_loss)
            # w should sum up to rollout_T
            w = w / w.sum() * gt_actions.shape[1]
            actions_loss = actions_loss * w[None, :, None, None]

        loss_dict['action_loss'] = actions_loss.mean()

        return loss_dict

    @property
    def dtype(self):
        """dtype of the model, read from the instruction projection weight."""
        # nn.Linear has no `.dtype` attribute; its parameters do
        return self.inst_proj.weight.dtype

    @property
    def device(self):
        """Device of the model, read from the instruction projection weight."""
        # nn.Linear has no `.device` attribute; its parameters do
        return self.inst_proj.weight.device

    def train(self, mode=True):
        """Set train/eval mode; currently identical to the parent's."""
        super().train(mode)
        # keep decoder part in eval mode
        #self.decoder.eval()
        #self.decoder_pos_embedding.eval()
        return self

