from typing import Dict, Optional, Tuple

from PIL import Image
import torch
from torch.distributions.categorical import Categorical

from envs.env import DoneTracker
from models.diffuser import WorldModel
from utils import set_seed


class WorldModelEnv:
    """Batched environment whose dynamics are simulated by a learned world model.

    Next frames come from ``wm.denoiser.denoise`` and rewards/terminations from
    ``wm.rew_end_model``. Episodes are seeded with conditioning frames and
    actions drawn from ``data_loader``.
    """

    def __init__(self, wm: WorldModel, data_loader, horizon: int) -> None:
        """Build the environment.

        Args:
            wm: World model exposing ``config``, ``device``, ``denoiser`` and
                ``rew_end_model``.
            data_loader: Iterable of batches with ``.obs`` and ``.act``
                attributes; its ``batch_sampler.batch_size`` sets the
                environment batch size.
            horizon: Number of steps after which every episode is truncated.
        """
        self.wm = wm
        self.horizon = horizon
        self.batch_size = data_loader.batch_sampler.batch_size

        self.initial_condition_generator = self.make_initial_condition_generator(data_loader)

        self.done_tracker = DoneTracker(self.batch_size, self.device)

        # Rolling conditioning window (frames and actions); set by reset()/reset_from().
        self.obs, self.act = None, None
        self.ep_len = None
        # Hidden state carried between calls to the reward/end model; cleared on reset.
        self.hx_cx = None

    @property
    def config(self):
        """The world model's configuration object."""
        return self.wm.config

    @property
    def all_done(self):
        """Forwarded from ``done_tracker.all_done``."""
        return self.done_tracker.all_done

    @property
    def device(self):
        """Device of the underlying world model."""
        return self.wm.device

    @property
    def is_alive(self):
        """Forwarded from ``done_tracker.is_alive``."""
        return self.done_tracker.is_alive

    @property
    def num_actions(self):
        """Size of the action space, read from the model config."""
        return self.config.num_actions

    def seed(self, x):
        """Seed the global RNGs via ``utils.set_seed``."""
        set_seed(x)

    @torch.no_grad()
    def reset(self) -> Tuple[torch.Tensor, None]:
        """Reset the done tracker and start episodes from the next loader batch.

        Returns:
            ``(obs, None)`` where ``obs`` is the most recent conditioning frame.
        """
        self.done_tracker.reset()
        obs, act = next(self.initial_condition_generator)
        return self.reset_from(obs, act)

    @torch.no_grad()
    def reset_from(self, obs: torch.FloatTensor, act: torch.LongTensor) -> Tuple[torch.Tensor, None]:
        """Start episodes from explicit conditioning frames and actions.

        Args:
            obs: Conditioning frames. ``obs.size(1)`` must equal
                ``config.num_steps_conditioning``. Assumed to be
                (b, t, c, h, w), as unpacked in ``_predict_next_obs``.
            act: Conditioning actions. ``act.size(1)`` must equal
                ``config.num_steps_conditioning``. The last slot is
                overwritten by the first ``step()``.

        Returns:
            ``(obs[:, -1], None)``.
        """
        assert obs.size(1) == act.size(1) == self.config.num_steps_conditioning
        # NOTE(review): unlike reset(), this does not call done_tracker.reset();
        # direct callers may need to do so themselves — confirm intent.
        self.ep_len = 0
        self.hx_cx = None
        # Copy: step() writes into self.act in place, which must not mutate the caller's tensors.
        self.obs = obs.clone()
        self.act = act.clone()
        return obs[:, -1], None

    @torch.no_grad()
    def step(self, act: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
        """Advance every environment in the batch by one step.

        Args:
            act: Actions taken at the current (last) observation.

        Returns:
            ``(next_obs, rew, end, truncated, info)``. ``truncated`` is all ones
            once ``horizon`` steps have elapsed since the last reset, otherwise
            all zeros. ``info`` is an empty dict.
        """
        # The last action slot pairs with the last (current) observation.
        self.act[:, -1] = act

        next_obs = self._predict_next_obs(self.obs, self.act)
        rew, end = self._predict_rew_end(self.obs, self.act)

        # Slide the window by one. The action slot rolled to the end is stale
        # and gets overwritten at the start of the next step.
        self.obs = self.obs.roll(-1, dims=1)
        self.act = self.act.roll(-1, dims=1)
        self.obs[:, -1] = next_obs

        # Time limit.
        self.ep_len += 1
        truncated = torch.ones_like(end) if self.ep_len >= self.horizon else torch.zeros_like(end)

        self.done_tracker.update(end, truncated)

        info = {}

        return next_obs, rew, end, truncated, info

    @torch.no_grad()
    def _predict_next_obs(self, obs: torch.FloatTensor, act: torch.LongTensor, sigma: float = 5, generator: Optional[torch.Generator] = None):
        """Generate the next frame with the denoiser, conditioned on the window.

        Args:
            obs: Conditioning frames of shape (b, t, c, h, w).
            act: Conditioning actions, passed through to the denoiser.
            sigma: Noise level handed to ``denoiser.denoise`` as a 1-element tensor.
            generator: Optional RNG used to draw the initial noise.

        Returns:
            The denoised frame, clamped to [-1, 1].
        """
        b, t, c, h, w = obs.shape
        # Stack the t conditioning frames along the channel axis.
        obs = obs.reshape(b, t * c, h, w)
        cfg = self.config
        # NOTE(review): the noise is standard normal and not multiplied by `sigma`;
        # presumably denoise() accounts for this — confirm.
        noisy_next_obs = torch.randn(b, cfg.image_channels, cfg.image_size, cfg.image_size, generator=generator, device=obs.device)
        sigma = torch.tensor([sigma], device=obs.device)
        next_obs = self.wm.denoiser.denoise(noisy_next_obs, sigma, obs, act)
        return next_obs.clamp(-1, 1)

    @torch.no_grad()
    def _predict_rew_end(self, obs: torch.FloatTensor, act: torch.LongTensor):
        """Sample reward and episode end for the current step.

        Only the latest frame/action is fed to ``rew_end_model``. Its returned
        state is stored in ``self.hx_cx`` and passed back on the next call.

        Returns:
            ``(rew, end)``. ``rew`` is a float tensor equal to the sampled
            class index minus one. ``end`` is the sampled class index.
        """
        logits_rew, logits_end, self.hx_cx = self.wm.rew_end_model(obs[:, -1:], act[:, -1:], self.hx_cx)
        rew = Categorical(logits=logits_rew).sample().squeeze(1).sub(1).float() # reward clipped to {-1, 0, 1}
        end = Categorical(logits=logits_end).sample().squeeze(1)
        return rew, end

    def make_initial_condition_generator(self, data_loader):
        """Yield ``(obs, act)`` initial conditions on the model's device.

        The loader is re-iterated after each full pass, so a finite loader does
        not exhaust ``reset()``. If a pass yields no batches, the generator
        stops instead of spinning forever, and ``next()`` raises StopIteration.
        """
        while True:
            produced_any = False
            for batch in data_loader:
                produced_any = True
                yield batch.obs.to(self.device), batch.act.to(self.device)
            if not produced_any:
                return

    @torch.no_grad()
    def render(self):
        """Return the latest frame of the window as a PIL image.

        Assumes batch size 1, since ``squeeze(0)`` only drops the batch
        dimension in that case. Also assumes pixel values in [-1, 1].
        """
        frame = self.obs[:, -1].squeeze(0)  # 1 c h w -> c h w
        # Map [-1, 1] -> [0, 255]; the clamp stops out-of-range values from wrapping in the uint8 cast.
        pixels = frame.permute(1, 2, 0).add(1).div(2).mul(255).clamp(0, 255).byte()  # c h w -> h w c
        return Image.fromarray(pixels.cpu().numpy())
