import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from src.modules.st_transformer import STTransformer


class RewardModel(nn.Module):
    """Predict one scalar reward per frame of an image sequence.

    Every frame is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, giving ``(img_size // patch_size) ** 2``
    tokens of width ``reward_config.n_embd`` per frame.  The tokens of all
    frames are concatenated and passed through an ``STTransformer``; the
    output at the last token of each frame is projected to a single value by
    a linear head.

    ``predict`` and ``predict_all`` decode the head output with
    ``sign(x) * (exp(|x|) - 1)``.
    NOTE(review): this presumably inverts a symlog compression applied to the
    training targets somewhere outside this class -- confirm against the data
    pipeline, ``criterion`` itself does not transform the targets.
    """

    def __init__(self, reward_config, img_size, patch_size, seq_len):
        """Build the patch embedding, the transformer and the reward head.

        Args:
            reward_config: Configuration object handed to ``STTransformer``.
                It is modified in place: ``n_tokens_per_frame`` and
                ``block_size`` are derived from the other arguments and
                ``vocab_size`` is set to ``None``.
            img_size: Side length of the (square) input frames in pixels.
            patch_size: Side length of one patch in pixels.
            seq_len: Number of frames used to size ``block_size``.
        """
        super().__init__()
        tokens_per_frame = (img_size // patch_size) ** 2
        reward_config.n_tokens_per_frame = tokens_per_frame
        reward_config.block_size = seq_len * tokens_per_frame
        reward_config.vocab_size = None

        self.to_patch_embed = nn.Sequential(
            nn.Conv2d(3, reward_config.n_embd, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.reward = STTransformer(reward_config)
        self.trunk = nn.Linear(reward_config.n_embd, 1)

        self.img_size = img_size
        self.patch_size = patch_size
        self.seq_len = seq_len

    @staticmethod
    def _symexp(x):
        """Return ``sign(x) * (exp(|x|) - 1)`` element-wise."""
        # expm1 avoids the cancellation of ``exp(|x|) - 1`` for values near 0.
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def predict(self, imgs):
        """Return the decoded reward of the last frame of ``imgs``.

        Args:
            imgs: Tensor of shape ``(B, T, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, 1)``.
        """
        raw, _ = self.forward((imgs, None, None))
        return self._symexp(raw[:, -1])

    def predict_all(self, imgs, context_len=8):
        """Return the decoded reward of every frame using a sliding window.

        Frame ``t`` is scored from the frames
        ``max(0, t - context_len + 1) .. t``; one forward pass is run per
        frame and only the output for the window's last frame is kept.

        Args:
            imgs: Tensor of shape ``(B, T, 3, H, W)``.
            context_len: Maximum number of frames (current one included) in
                each window.  Defaults to 8, the value that used to be
                hard-coded here.
                NOTE(review): presumably this should not exceed ``seq_len``,
                which sizes ``block_size`` -- how ``STTransformer`` treats
                longer inputs is not visible from this file.

        Returns:
            Tensor of shape ``(B, T, 1)``.

        Raises:
            ValueError: If ``context_len`` is smaller than 1.
        """
        if context_len < 1:
            raise ValueError(f"context_len must be >= 1, got {context_len}")

        n_frames = imgs.shape[1]
        per_frame = []
        for t in range(n_frames):
            start = max(0, t - context_len + 1)
            raw, _ = self.forward((imgs[:, start:t + 1], None, None))
            per_frame.append(self._symexp(raw[:, -1:]))
        return torch.cat(per_frame, dim=1)

    def forward(self, batch):
        """Compute the raw (undecoded) reward for every frame.

        Args:
            batch: 3-tuple ``(imgs, actions, rewards)``.  Only ``imgs`` of
                shape ``(B, T, 3, H, W)`` is read; the other two entries are
                ignored and may be ``None``.

        Returns:
            Tuple ``(reward, None)`` where ``reward`` has shape ``(B, T, 1)``.
            The second slot is kept so callers can keep unpacking two values.
        """
        imgs, _, _ = batch
        B, T, C_in, H, W = imgs.size()

        # Fold time into the batch axis so the 2-D convolution sees frames.
        tokens = self.to_patch_embed(imgs.reshape(B * T, C_in, H, W)).contiguous()
        n_tokens, dim = tokens.shape[1:]

        hidden = self.reward(tokens.reshape(B, T * n_tokens, dim))
        # Keep only the output at the last token of each frame.
        last_token = hidden.reshape(B, T, n_tokens, -1)[:, :, -1]

        # Previously the second value was a leaked throwaway variable holding
        # B * T; return an explicit placeholder instead.
        return self.trunk(last_token), None

    def criterion(self, batch, output):
        """Mean-squared error between predicted and target rewards.

        Args:
            batch: 3-tuple ``(imgs, actions, rewards)``; only ``rewards`` is
                read.
            output: 2-tuple as returned by ``forward``; only the first entry
                (the predicted rewards) is read.

        Returns:
            Tuple ``(loss, logs)`` with the scalar loss tensor and a dict
            holding its Python float under ``'reward_loss'``.
        """
        _, _, rewards = batch
        p_rewards, _ = output

        # The head emits a trailing axis of size 1.  If the targets come
        # without it, drop it so the loss is element-wise instead of being
        # silently broadcast (or failing) inside mse_loss.
        if p_rewards.dim() == rewards.dim() + 1 and p_rewards.size(-1) == 1:
            p_rewards = p_rewards.squeeze(-1)

        loss = F.mse_loss(p_rewards, rewards)

        return loss, {
            'reward_loss': loss.item(),
        }
