import math
from typing import Dict, Tuple

from einops import rearrange
from PIL import Image
import torch
from torch.distributions.categorical import Categorical

from envs.env import DoneTracker
from .gpt import GPT
from .tokenizer import Tokenizer
from utils import set_seed


class WorldModelEnv:
    """Batched environment whose dynamics are simulated by a learned world model.

    Observations are encoded into discrete tokens by ``tok``. ``gpt`` then
    samples reward, episode end and the next observation's tokens from a
    rolling window of blocks, each made of K observation tokens followed by one
    action token. Initial conditions are drawn from ``data_loader``.
    """

    def __init__(self, tok: Tokenizer, gpt: GPT, data_loader, horizon: int) -> None:
        """
        Args:
            tok: tokenizer used to encode observations into tokens and to
                embed/decode tokens back into images.
            gpt: transformer world model.
            data_loader: source of initial conditions. Each batch must expose
                ``obs`` and ``act``; ``batch_sampler.batch_size`` sets the
                environment batch size.
            horizon: number of steps after which episodes are flagged as truncated.
        """
        self.tok = tok
        self.gpt = gpt
        self.horizon = horizon
        self.batch_size = data_loader.batch_sampler.batch_size

        self.initial_condition_generator = self.make_initial_condition_generator(data_loader)

        self.done_tracker = DoneTracker(self.batch_size, self.device)

        self.obs = None          # latest observation, set by reset_from_seq / step
        self.ep_len = None       # number of steps taken since the last reset
        self.keys_values_gpt, self.num_obs_tokens = None, None

        # Rolling GPT context: observation tokens (B, T, K) and action tokens (B, T', 1).
        self.obs_tokens_buffer, self.act_tokens_buffer = None, None

    @property
    def all_done(self):
        """Whether the DoneTracker reports every environment in the batch as done."""
        return self.done_tracker.all_done

    @property
    def device(self):
        """Device of the world model (taken from the GPT positional embedding)."""
        return self.gpt.pos_emb.weight.device

    @property
    def is_alive(self):
        """Per-environment liveness as reported by the DoneTracker."""
        return self.done_tracker.is_alive

    @property
    def num_actions(self):
        """Size of the GPT action vocabulary."""
        return self.gpt.act_vocab_size

    def seed(self, x):
        """Forward ``x`` to ``utils.set_seed``."""
        set_seed(x)

    @torch.no_grad()
    def reset(self) -> Tuple[torch.FloatTensor, Dict]:
        """Start new episodes from the next batch of initial conditions.

        Returns:
            The last observation of the sampled context, and an empty info dict.

        Raises:
            StopIteration: once the data loader has been fully consumed (the
                initial-condition generator makes a single pass over it).
        """
        self.done_tracker.reset()
        obs, act = next(self.initial_condition_generator)
        return self.reset_from_seq(obs, act), {}

    @torch.no_grad()
    def reset_from_seq(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.FloatTensor:
        """Reset the episodes from an explicit context of observations and actions.

        Args:
            obs: context observations, (B, T, C, H, W).
            act: context actions, (B, T). Only ``act[:, :-1]`` is kept, so the
                last observation stays unpaired until :meth:`step` appends an action.

        Returns:
            The last context observation, ``obs[:, -1]``.
        """
        assert obs.size(1) == act.size(1)
        self.ep_len = 0

        obs_tokens = self.tok.encode(obs).tokens    # (B, T, C, H, W) -> (B, T, K)
        _, _, num_obs_tokens = obs_tokens.shape
        if self.num_obs_tokens is None:
            self.num_obs_tokens = num_obs_tokens

        self.obs = obs[:, -1]

        self.obs_tokens_buffer = obs_tokens                   # (B, T, K)
        self.act_tokens_buffer = act[:, :-1].unsqueeze(-1)    # (B, T - 1, 1)

        return self.obs

    @torch.no_grad()
    def decode_obs_tokens(self, obs_tokens):
        """Decode observation tokens (B, K) into images clamped to [-1, 1].

        The K token embeddings are laid out on a grid with ``isqrt(K)`` rows
        before being passed to the tokenizer decoder.
        """
        obs_tokens_emb = self.tok.embedding(obs_tokens)     # (B, K, E)
        rows = math.isqrt(self.num_obs_tokens)
        z = rearrange(obs_tokens_emb, 'b (h w) e -> b e h w', h=rows)
        rec = self.tok.decode(z)                            # (B, C, H, W)
        return rec.clamp(-1, 1)

    @torch.no_grad()
    def step(self, act: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
        """Advance every environment in the batch by one step.

        Args:
            act: one action per environment (reshaped to (B, 1, 1)).

        Returns:
            ``(obs, rew, end, truncated, info)``: the decoded next observation
            (B, C, H, W), the sampled reward shifted to {-1, 0, 1} as float (B,),
            the sampled end flag (B,), the time-limit truncation flag (B,), and an
            empty info dict.
        """
        # Context window length, in (K obs tokens + 1 action token) blocks.
        # NOTE(review): the "- 3" headroom presumably keeps the context plus the
        # K generated tokens within config.max_tokens — confirm.
        h = self.gpt.config.max_blocks - 3

        # Window BOTH buffers to the same length. reset_from_seq may hand over a
        # context longer than h, which would otherwise make the concat below
        # fail on mismatched time dimensions.
        self.obs_tokens_buffer = self.obs_tokens_buffer[:, -h:]
        self.act_tokens_buffer = torch.cat((self.act_tokens_buffer, act.reshape(-1, 1, 1)), dim=1)[:, -h:]

        tokens = rearrange(torch.cat((self.obs_tokens_buffer, self.act_tokens_buffer), dim=2), 'b t k1 -> b (t k1)')  # (B, T(K+1))

        # The KV cache is rebuilt from scratch on every step; nothing carries over.
        self.keys_values_gpt = self.gpt.transformer.generate_empty_keys_values(n=act.size(0), max_tokens=self.gpt.config.max_tokens)

        outputs = self.gpt(tokens, past_keys_values=self.keys_values_gpt)

        # Only the last position's predictions are used, so slice before sampling
        # rather than sampling the whole sequence and discarding it.
        rew = Categorical(logits=outputs.logits_rew[:, -1]).sample().sub(1).float()  # reward clipped to {-1, 0, 1}
        end = Categorical(logits=outputs.logits_end[:, -1]).sample()

        token = Categorical(logits=outputs.logits_obs[:, -1:]).sample()  # (B, 1)
        obs_tokens = [token]

        # Sample the remaining K - 1 observation tokens one at a time. The final
        # token needs no extra forward pass: its cache entry would be discarded
        # when the cache is rebuilt at the next step.
        for _ in range(self.num_obs_tokens - 1):
            outputs = self.gpt(token, past_keys_values=self.keys_values_gpt)
            token = Categorical(logits=outputs.logits_obs).sample()
            obs_tokens.append(token)

        obs_tokens = torch.cat(obs_tokens, dim=1)        # (B, K)
        self.obs = self.decode_obs_tokens(obs_tokens)

        # Update buffers
        self.obs_tokens_buffer = torch.cat((self.obs_tokens_buffer, obs_tokens.unsqueeze(1)), dim=1)[:, -h:]

        # time limit
        self.ep_len += 1
        truncated = torch.ones_like(end) if self.ep_len >= self.horizon else torch.zeros_like(end)

        self.done_tracker.update(end, truncated)

        # Empty dict, matching the Dict annotation and reset()'s info value.
        info = {}

        return self.obs, rew, end, truncated, info

    def make_initial_condition_generator(self, data_loader):
        """Yield ``(obs, act)`` batches from ``data_loader``, moved to the model device.

        This makes a single pass over ``data_loader``. Once it is exhausted, the
        next ``next()`` on the generator (e.g. in :meth:`reset`) raises StopIteration.
        """
        for batch in data_loader:
            yield batch.obs.to(self.device), batch.act.to(self.device)

    @torch.no_grad()
    def render(self):
        """Return the current observation as a PIL image.

        Only works with a batch size of 1: the batch dim is squeezed and the
        remaining (C, H, W) tensor is permuted to (H, W, C). Pixel values are
        mapped from [-1, 1] to [0, 255].
        """
        return Image.fromarray(self.obs.squeeze(0).permute(1, 2, 0).add(1).div(2).mul(255).byte().cpu().numpy())  # 1 c h w -> h w c
