import numpy as np
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange
from tqdm import tqdm, trange
from PIL import Image

from src.modules import STTransformer, AttentionActor


class PolicyEmbedModel(nn.Module):
    """Latent policy-embedding model built from a posterior/prior transformer pair.

    Frames and actions are patch-embedded and fed to a posterior and a prior
    STTransformer. Each produces logits for ``cat_size`` categorical latents
    with ``class_size`` classes each. A sampled latent ``z`` is embedded,
    added to the frame patch tokens and decoded into action predictions by
    ``policy_model``. Training combines the posterior action-reconstruction
    loss with a balanced KL term between posterior and prior.
    """

    def __init__(self, posterior_config, prior_config, policy_config, kl_config, img_size, patch_size, seq_len):
        """Build all sub-networks.

        Note: ``posterior_config`` and ``prior_config`` are mutated in place
        (``n_tokens_per_frame``, ``block_size`` and ``vocab_size`` are set)
        before the transformers are constructed.

        Args:
            posterior_config: config for the posterior STTransformer; must also
                provide ``cat_size``, ``class_size``, ``n_embd`` and ``bias``.
            prior_config: config for the prior STTransformer (``n_embd``, ``bias``).
            policy_config: config for the AttentionActor (``action_dim``, ``n_embd``).
            kl_config: object with ``use_free_nats``, ``free_nats``,
                ``kl_balance_scale`` and ``kl_loss_coef``.
            img_size: frame side length; the patch grid is taken to be square,
                i.e. ``(img_size // patch_size) ** 2`` tokens per frame.
            patch_size: side length of each square patch (conv kernel/stride).
            seq_len: number of frames per sequence.
        """
        super().__init__()
        posterior_config.n_tokens_per_frame = (img_size // patch_size) ** 2
        posterior_config.block_size = seq_len * posterior_config.n_tokens_per_frame
        posterior_config.vocab_size = None
        prior_config.n_tokens_per_frame = (img_size // patch_size) ** 2
        prior_config.block_size = seq_len * prior_config.n_tokens_per_frame
        prior_config.vocab_size = None

        self.posterior_model = STTransformer(posterior_config)
        self.prior_model = STTransformer(prior_config)
        self.policy_model = AttentionActor(policy_config)

        # Non-overlapping patch embedding: (N, 3, H, W) -> (N, num_patches, n_embd).
        self.to_patch_post_embed = nn.Sequential(
            nn.Conv2d(3, posterior_config.n_embd, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.to_patch_prior_embed = nn.Sequential(
            nn.Conv2d(3, prior_config.n_embd, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.prior_output_proj = nn.Linear(prior_config.n_embd, posterior_config.cat_size*posterior_config.class_size, bias=prior_config.bias)
        self.posterior_output_proj = nn.Linear(posterior_config.n_embd, posterior_config.cat_size*posterior_config.class_size, bias=posterior_config.bias)
        self.a_post_embed = nn.Linear(policy_config.action_dim, posterior_config.n_embd)
        self.a_prior_embed = nn.Linear(policy_config.action_dim, prior_config.n_embd)
        self.z_embed = nn.Linear(posterior_config.cat_size*posterior_config.class_size, policy_config.n_embd, bias=True)

        self.img_size = img_size
        self.patch_size = patch_size
        self.seq_len = seq_len
        self.cat_size = posterior_config.cat_size
        self.class_size = posterior_config.class_size
        self.kl_info = kl_config

    def get_z_distribution(self, logits):
        """Wrap ``(..., cat_size, class_size)`` logits in a latent distribution.

        The result is an ``Independent`` ``OneHotCategoricalStraightThrough``
        whose second-to-last dim (the ``cat_size`` latents) is treated as part
        of the event.
        """
        return torch.distributions.Independent(
            torch.distributions.OneHotCategoricalStraightThrough(logits=logits),
            reinterpreted_batch_ndims=1
        )

    def get_z_sample(self, dist, reparameterize=False, deterministic=False):
        """Draw a latent from ``dist`` and flatten its last two dims.

        Args:
            dist: distribution returned by :meth:`get_z_distribution`.
            reparameterize: use ``rsample`` (straight-through gradients)
                instead of ``sample``; ignored when ``deterministic`` is set.
            deterministic: return the one-hot encoding of the most probable
                class of each latent instead of a random sample.

        Returns:
            Tensor of shape ``(..., cat_size * class_size)``.
        """
        if deterministic:
            probs = dist.base_dist.probs
            # Build the one-hot directly in probs' dtype/device so the result is
            # not silently cast to float32 (torch.eye's default dtype).
            sample = F.one_hot(probs.argmax(dim=-1), num_classes=probs.shape[-1]).to(probs)
        elif reparameterize:
            sample = dist.rsample()
        else:
            sample = dist.sample()
        return sample.reshape(*sample.shape[:-2], -1)

    def posterior(self, x, a):
        """Compute posterior latent logits from a full frame/action sequence.

        Args:
            x: frames of shape (B, T, C, H, W).
            a: actions; after embedding they are reshaped to (B*T, 1, n_embd),
                so presumably of shape (B, T, action_dim) — TODO confirm.

        Returns:
            Logits of shape (B, 1, cat_size, class_size), read from the last
            patch token of the last frame.
        """
        B, T, _, _, _ = x.size()
        x = self.to_patch_post_embed(x.reshape(B*T, *x.shape[-3:])).contiguous()
        # Add each frame's action embedding to every patch token of that frame.
        a = self.a_post_embed(a).reshape(B*T, 1, -1).repeat(1, x.shape[1], 1)
        x += a
        _, HW, _ = x.size()
        x = x.reshape(B, T*HW, -1)
        h = self.posterior_model(x)
        logits = self.posterior_output_proj(h.reshape(B, T, HW, -1)[:, -1:, -1]).reshape(B, 1, self.cat_size, self.class_size)
        return logits

    def prior(self, x, a):
        """Compute prior latent logits, one set per frame.

        Args:
            x: frames of shape (B, T, C, H, W).
            a: actions, embedded and added to each frame's patch tokens like in
                :meth:`posterior` (presumably (B, T, action_dim) — TODO confirm).

        Returns:
            Logits of shape (B, T, cat_size, class_size), read from the last
            patch token of each frame.
        """
        B, T, _, _, _ = x.size()
        x = self.to_patch_prior_embed(x.reshape(B*T, *x.shape[-3:])).contiguous()
        a = self.a_prior_embed(a).reshape(B*T, 1, -1).repeat(1, x.shape[1], 1)
        x += a
        _, HW, _ = x.size()
        x = x.reshape(B, T*HW, -1)
        h = self.prior_model(x)
        logits = self.prior_output_proj(h.reshape(B, T, HW, -1)[:, :, -1].reshape(B*T, -1)).reshape(B, T, self.cat_size, self.class_size)
        return logits

    def reconstruct(self, x, z):
        """Decode actions from frames conditioned on flattened latents ``z``.

        Args:
            x: frames of shape (B, T, C, H, W).
            z: flattened latents of shape (B, zT, cat_size * class_size), where
                ``T`` must be divisible by ``zT``. Latent ``k`` conditions the
                ``T // zT`` consecutive frames starting at ``k * (T // zT)``
                (so ``zT == 1`` shares one latent across all frames and
                ``zT == T`` gives each frame its own latent).

        Returns:
            The output of ``policy_model(x, z)`` with both inputs of shape
            (B*T, HW, n_embd).
        """
        B, T, _, _, _ = x.size()
        x = self.to_patch_post_embed(x.reshape(B*T, *x.shape[-3:])).contiguous()
        _, HW, _ = x.size()
        x = x.reshape(B*T, HW, -1)

        zT = z.shape[1]
        assert T % zT == 0, f"number of frames ({T}) must be divisible by number of latents ({zT})"
        z = self.z_embed(z)
        # repeat_interleave keeps each latent's copies contiguous, so after the
        # reshape frame t sees latent t // (T // zT) on all HW patches. A plain
        # `repeat` would tile the sequence and mix latents within a frame.
        z = z.repeat_interleave((T // zT) * HW, dim=1).reshape(B*T, HW, -1)
        x = x + z

        return self.policy_model(x, z)

    def _kl_loss(self, prior_logits, posterior_logits):
        """Compute the balanced KL(posterior || prior), with optional free nats.

        ``kl_lhs`` detaches the posterior (gradients reach only the prior) and
        ``kl_rhs`` detaches the prior (gradients reach only the posterior). The
        two are mixed with weight ``kl_balance_scale``. When ``use_free_nats``
        is set, each batch-mean KL term is floored at ``free_nats``.
        Posterior logits of shape (B, 1, ...) broadcast against prior logits of
        shape (B, T, ...).
        """
        kl_divergence = torch.distributions.kl.kl_divergence
        kl_lhs = torch.mean(kl_divergence(self.get_z_distribution(posterior_logits.detach()), self.get_z_distribution(prior_logits)))
        kl_rhs = torch.mean(kl_divergence(self.get_z_distribution(posterior_logits), self.get_z_distribution(prior_logits.detach())))
        if self.kl_info.use_free_nats:
            free_nats = self.kl_info.free_nats
            kl_lhs = kl_lhs.clamp(min=free_nats)
            kl_rhs = kl_rhs.clamp(min=free_nats)
        kl_loss = self.kl_info.kl_balance_scale * kl_lhs + (1 - self.kl_info.kl_balance_scale) * kl_rhs
        return kl_loss

    def forward(self, batch):
        """Run the posterior, the prior and both reconstructions on a batch.

        Args:
            batch: ``(x, a, _)`` with frames ``x`` of shape (B, T, C, H, W) and
                actions ``a``; the third element is ignored.

        Returns:
            ``(posterior_logits, prior_logits, posterior_action_pred,
            prior_action_pred)``. The posterior reconstruction uses a
            straight-through sample. The prior reconstruction uses the argmax
            latent and is computed without gradients.
        """
        x, a, _ = batch
        posterior_logits = self.posterior(x, a)
        # NOTE(review): the prior is also given the actions `a` that are later
        # reconstructed from its latent — confirm this is intended.
        prior_logits = self.prior(x, a)

        z_posterior_dist = self.get_z_distribution(posterior_logits)
        z_prior_dist = self.get_z_distribution(prior_logits)

        z_posterior = self.get_z_sample(z_posterior_dist, reparameterize=True, deterministic=False)
        posterior_action_pred = self.reconstruct(x, z_posterior)
        with torch.no_grad():
            z_prior = self.get_z_sample(z_prior_dist, reparameterize=False, deterministic=True)
            prior_action_pred = self.reconstruct(x, z_prior)

        return posterior_logits, prior_logits, posterior_action_pred, prior_action_pred

    @staticmethod
    def _mean_entropy(logits):
        """Return the mean categorical entropy over the last dim of ``logits``."""
        # log_softmax keeps p * log(p) at 0 (not NaN) when p underflows to 0.
        log_p = F.log_softmax(logits, dim=-1)
        return torch.mean(-(log_p.exp() * log_p).sum(dim=-1))

    def criterion(self, batch, output):
        """Compute the training loss and scalar logging metrics.

        Args:
            batch: ``(x, a, r)``; only ``a`` is used. It is flattened over its
                first two dims to form the action target.
            output: the tuple returned by :meth:`forward`.

        Returns:
            ``(loss, info)``. ``loss`` is the posterior action loss plus
            ``kl_loss_coef`` times the balanced KL. ``info`` maps metric names
            to Python floats. The prior action loss is reported but not
            optimized.
        """
        _, a, _ = batch
        posterior_logits, prior_logits, posterior_action_pred, prior_action_pred = output

        target = torch.flatten(a, 0, 1)
        posterior_action_loss = self.policy_model.loss(posterior_action_pred, target)
        prior_action_loss = self.policy_model.loss(prior_action_pred, target)

        kl_loss = self._kl_loss(prior_logits, posterior_logits)

        loss = posterior_action_loss + self.kl_info.kl_loss_coef * kl_loss

        prior_entropy = self._mean_entropy(prior_logits)
        posterior_entropy = self._mean_entropy(posterior_logits)

        return loss, {
            'loss': loss.item(),
            'post_rec_loss': posterior_action_loss.item(),
            'prior_rec_loss': prior_action_loss.item(),
            'prior_entropy': prior_entropy.item(),
            'posterior_entropy': posterior_entropy.item(),
            'kl_loss': kl_loss.item(),
        }
