import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from ..modules.actor import AttentionActor


class PolicyModel(nn.Module):
    """Policy network: conv patch embedding followed by an ``AttentionActor``.

    Args:
        policy_config: Config object handed to ``AttentionActor``; its
            ``n_embd`` attribute sets the patch-embedding width.
        img_size: Stored on the instance as ``img_size``; not otherwise used
            inside this class.
        patch_size: Kernel size and stride of the patch-embedding
            convolution; also stored as ``patch_size``.
    """

    def __init__(self, policy_config, img_size, patch_size):
        super().__init__()
        self.actor = AttentionActor(policy_config)
        # kernel_size == stride, so the conv cuts each 3-channel frame into
        # non-overlapping patches; the rearrange then flattens the patch grid
        # into a single token axis: (b, c, h, w) -> (b, h*w, c).
        self.to_patch_embed = nn.Sequential(
            nn.Conv2d(
                3,
                policy_config.n_embd,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            Rearrange('b c h w -> b (h w) c'),
        )

        self.img_size = img_size
        self.patch_size = patch_size

    def forward(self, batch):
        """Embed every frame into patch tokens and run the actor on them.

        Args:
            batch: Sequence of exactly three items. Only the first is used;
                it must be a 5-D tensor (assumed to be (B, T, 3, H, W) --
                TODO confirm against the data loader).

        Returns:
            A 2-tuple ``(pred, last_dim)``. ``pred`` is whatever
            ``self.actor`` returns for the patch tokens. ``last_dim`` is the
            size of the input's final dimension.
        """
        frames, _, _ = batch
        batch_size, seq_len, _, _, last_dim = frames.size()

        # Fold the second axis into the first so the 2-D conv sees one frame
        # per row: (B, T, ...) -> (B*T, ...).
        flat_frames = frames.reshape(batch_size * seq_len, *frames.shape[-3:])
        tokens = self.to_patch_embed(flat_frames).contiguous()

        # NOTE(review): the original comment here claimed a (B, T, A) output;
        # criterion() compares against actions flattened over their first two
        # axes, so the real shape depends on AttentionActor -- confirm.
        pred = self.actor(tokens, None)

        # NOTE(review): the second slot has always carried the value the
        # throwaway name `_` was last bound to, i.e. the input's final
        # dimension. It is kept so callers see the same tuple; criterion()
        # discards it. It looks accidental -- confirm before relying on it.
        return pred, last_dim

    def criterion(self, batch, output):
        """Compute the actor loss for one batch.

        Args:
            batch: Sequence of exactly three items; only the second (the
                action targets) is used. Its first two axes are flattened
                into one before being passed to ``self.actor.loss``.
            output: The 2-tuple returned by ``forward``; only the first item
                (the prediction) is used.

        Returns:
            ``(loss, metrics)`` where ``loss`` is the value returned by
            ``self.actor.loss`` and ``metrics`` maps ``'loss'``,
            ``'post_rec_loss'`` and ``'prior_rec_loss'`` to the same Python
            scalar (presumably to satisfy a shared logging schema -- verify
            against the trainer).
        """
        _, actions, _ = batch
        pred, _ = output

        loss = self.actor.loss(pred, torch.flatten(actions, 0, 1))

        # .item() copies the scalar to the host (a sync point on CUDA), so
        # read it once and reuse it for every metric key.
        loss_value = loss.item()
        return loss, {
            'loss': loss_value,
            'post_rec_loss': loss_value,
            'prior_rec_loss': loss_value,
        }
