import numpy as np
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from PIL import Image

from src.modules import STTransformer
from src.modules.actor import CrossAttentionBlock


class DynamicModel(nn.Module):
    """Action-conditioned next-frame token model on top of a frozen video tokenizer.

    ``video_tokenizer`` turns every frame into an ``H x W`` grid of discrete tokens.
    The decoder is trained so that its outputs at the positions of frame ``t``
    predict the token indices of frame ``t + 1``. Its input is the state embeddings
    of frame ``t`` plus an embedding of the action taken at ``t``, broadcast to every
    token of that frame. When ``embed_model`` is provided, a deterministic latent
    sample from its prior is also fused into those embeddings with a
    cross-attention block.
    """

    def __init__(self, video_tokenizer, embed_model, decoder_config):
        """Build the dynamics model.

        Args:
            video_tokenizer: tokenizer exposing ``encode``/``decode``, ``quantizer``
                (``n_e``, ``e_dim``, ``embedding``, ``norm``), ``decoder.config`` and
                ``seq_len``. Its parameters are frozen here.
            embed_model: optional model providing ``prior``, ``get_z_distribution``,
                ``get_z_sample``, ``cat_size`` and ``class_size``. Frozen if given.
            decoder_config: ``STTransformer`` config. NOTE: mutated in place;
                ``n_tokens_per_frame``, ``block_size`` and ``vocab_size`` are
                overwritten with values taken from the tokenizer.
        """
        super().__init__()
        self.video_tokenizer = video_tokenizer
        self.video_tokenizer.requires_grad_(False)

        self.action_dim = decoder_config.action_dim

        # The decoder must agree with the tokenizer on sequence layout and vocabulary.
        decoder_config.n_tokens_per_frame = video_tokenizer.decoder.config.n_tokens_per_frame
        decoder_config.block_size = video_tokenizer.decoder.config.block_size
        decoder_config.vocab_size = video_tokenizer.quantizer.n_e

        self.decoder = STTransformer(decoder_config)

        self.embed_state = nn.Linear(video_tokenizer.quantizer.e_dim, decoder_config.n_embd)
        self.embed_action = nn.Linear(decoder_config.action_dim, decoder_config.n_embd)

        if embed_model is not None:
            # Presumably a categorical latent of cat_size variables x class_size
            # classes, flattened to one vector per step (TODO confirm).
            self.embed_z = nn.Linear(embed_model.cat_size*embed_model.class_size, decoder_config.n_embd, bias=True)
            self.embed_sa_z = CrossAttentionBlock(decoder_config)
            self.embed_model = embed_model
            self.embed_model.requires_grad_(False)
        else:
            self.embed_model = None

    # ----------------------------------------------------------------- helpers

    def _tokenize(self, imgs):
        """Encode ``imgs`` with the frozen tokenizer and recover the square token grid.

        Args:
            imgs: frame sequence whose first two dims are (B, T).

        Returns:
            ``(states, token_idxs, H, W)``, where ``states`` has shape (B, T, H, W, C)
            and ``token_idxs`` has shape (B, T, H, W).

        Raises:
            ValueError: if the encoder's token count is not ``T`` square grids.
        """
        B, T = imgs.shape[:2]
        with torch.no_grad():
            states, _, token_idxs = self.video_tokenizer.encode(imgs)  # (B, T*H*W, C)
            n_tokens = states.shape[1]
            side = int(round(np.sqrt(n_tokens // T)))
            # Fail fast with a clear message instead of an opaque reshape error later.
            if side * side * T != n_tokens:
                raise ValueError(
                    f"expected {T} frames of a square token grid, got {n_tokens} tokens"
                )
            states = states.reshape(B, T, side, side, -1)
            token_idxs = token_idxs.reshape(B, T, side, side)
        return states, token_idxs, side, side

    def _sample_prior(self, imgs, actions):
        """Return ``(prior_logits, z_prior)`` from the frozen embed model (no grad)."""
        with torch.no_grad():
            prior_logits = self.embed_model.prior(imgs, actions)
            z_prior_dist = self.embed_model.get_z_distribution(prior_logits)
            z_prior = self.embed_model.get_z_sample(z_prior_dist, reparameterize=False, deterministic=True)
        return prior_logits, z_prior

    def _embed_inputs(self, states, actions, z_prior=None):
        """Build flattened decoder inputs from per-frame states, actions and latents.

        Args:
            states: (N, S, H, W, C) token embeddings for S frames.
            actions: actions for the same S frames, reshapeable to (N*S, action_dim).
            z_prior: optional latent sample for the same S frames; each frame's
                latent is flattened and broadcast to all of its tokens.

        Returns:
            (N, S*H*W, n_embd) tensor.
        """
        n, steps, h, w = states.shape[:4]
        per_frame = h * w
        sa_embeds = self.embed_state(states).reshape(n * steps, per_frame, -1)
        # One action per frame, broadcast to every token of that frame.
        sa_embeds = sa_embeds + self.embed_action(actions).reshape(n * steps, 1, -1).repeat(1, per_frame, 1)
        if z_prior is not None:
            z_embeds = self.embed_z(z_prior.reshape(n * steps, 1, -1).repeat(1, per_frame, 1))
            sa_embeds = self.embed_sa_z(sa_embeds, z_embeds)
        return sa_embeds.reshape(n, steps * per_frame, -1)

    # -------------------------------------------------------------- public API

    def forward(self, batch):
        """Compute next-frame token logits for a training batch.

        Args:
            batch: ``(imgs, actions, rewards)``; ``rewards`` is ignored. ``imgs`` is
                (B, T, C, H, W) and ``actions`` is (B, T, action_dim).

        Returns:
            ``(target_idxs, dynamic_logits)``, or, when ``embed_model`` is set,
            ``(target_idxs, dynamic_logits, prior_logits)``. ``target_idxs`` holds
            the (B, T-1, H, W) token indices of frames 1..T-1. ``dynamic_logits``
            is (B, T-1, H, W, V) and is predicted from frames/actions 0..T-2.
        """
        imgs, actions, _ = batch
        B, T = imgs.shape[:2]
        states, state_token_idxs, H, W = self._tokenize(imgs)

        prior_logits = None
        z_inputs = None
        if self.embed_model is not None:
            prior_logits, z_prior = self._sample_prior(imgs, actions)
            z_inputs = z_prior[:, :-1]

        # The last frame has no successor to predict, so it is dropped from the inputs.
        sa_embeds = self._embed_inputs(states[:, :-1], actions[:, :-1], z_inputs)
        dynamic_logits = self.decoder(sa_embeds).reshape(B, T - 1, H, W, -1)
        target_idxs = state_token_idxs[:, 1:]

        if prior_logits is None:
            return target_idxs, dynamic_logits
        return target_idxs, dynamic_logits, prior_logits

    def criterion(self, batch, output):
        """Label-smoothed token cross-entropy, plus prior entropy as a logged stat.

        Args:
            batch: unused; kept for interface compatibility.
            output: the tuple returned by :meth:`forward`.

        Returns:
            ``(model_loss, stats)``. ``stats`` holds ``'model_loss'`` and, when
            ``embed_model`` is set, ``'prior_entropy'`` (mean over all leading dims).
        """
        if self.embed_model is not None:
            target_idxs, logits, prior_logits = output
        else:
            target_idxs, logits = output

        model_loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_idxs.reshape(-1), label_smoothing=0.1)
        stats = {'model_loss': model_loss.item()}

        if self.embed_model is not None:
            # log_softmax keeps this finite when a probability underflows to 0;
            # softmax followed by log would give 0 * -inf = NaN.
            prior_log_probs = F.log_softmax(prior_logits, dim=-1)
            prior_entropy = -(prior_log_probs.exp() * prior_log_probs).sum(dim=-1).mean()
            stats['prior_entropy'] = prior_entropy.item()

        return model_loss, stats

    def step(self, img_seq, action_seq):
        """Predict and decode the frame that follows ``img_seq``.

        Args:
            img_seq: (B, T, C, H, W) context frames.
            action_seq: (B, T, action_dim) actions taken at each context frame.

        Returns:
            The last frame produced by ``video_tokenizer.decode``. This is the
            greedily predicted next frame, decoded together with the context tokens.
        """
        B, T = img_seq.shape[:2]
        states, state_token_idxs, H, W = self._tokenize(img_seq)

        z_prior = None
        if self.embed_model is not None:
            _, z_prior = self._sample_prior(img_seq, action_seq)

        dynamic_logits = self.decoder(self._embed_inputs(states, action_seq, z_prior))
        # Outputs at the last frame's positions predict the next frame (greedy argmax).
        next_token_idxs = torch.argmax(dynamic_logits, dim=-1)[:, -H * W:]
        img_token_idxs = torch.cat([state_token_idxs.reshape(B, T * H * W), next_token_idxs], dim=-1)

        quantizer = self.video_tokenizer.quantizer
        img_tokens = quantizer.norm(quantizer.embedding(img_token_idxs))
        recon_imgs = self.video_tokenizer.decode(img_tokens)
        return recon_imgs[:, -1]

    @torch.no_grad()
    def rollout(self, batch):
        """Autoregressively imagine frames 1..T-1 from frame 0 and the given actions.

        Args:
            batch: ``(imgs, actions, rewards)``. Only frame 0 of ``imgs`` and the
                actions are used.

        Returns:
            (B, T, ...) sequence: the real first frame followed by T-1 predictions.
        """
        imgs, actions, _ = batch
        T = imgs.shape[1]
        context_len = self.video_tokenizer.seq_len - 1
        img_seq = imgs[:, :1]
        for i in range(1, T):
            # Condition on at most the `context_len` most recent frames and their actions.
            start = max(0, i - context_len)
            next_img = self.step(img_seq[:, start:], actions[:, start:i])
            img_seq = torch.cat([img_seq, next_img.unsqueeze(dim=1)], dim=1)

        return img_seq
