from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .encoder import Encoder
from data import Batch
from .unet import UNet


def init_unet(model):
    """Re-initialise selected convolution weights of a UNet in place.

    * ``conv2`` of every resnet in ``model.down_blocks`` and ``model.up_blocks``
      has its weight zeroed;
    * the ``conv`` of every down-sampler (for down blocks that have any) gets an
      orthogonal init;
    * ``model.conv_out`` has its weight zeroed.

    Only ``.weight`` tensors are modified; biases keep their existing init.
    NOTE(review): presumably this makes residual branches and the final output
    start near zero -- the rationale is not stated in this file.
    """
    def zero_resnet_convs(block):
        # Zero the second convolution of each residual unit in ``block``.
        for resnet in block.resnets:
            resnet.conv2.weight.data.zero_()

    for down_block in model.down_blocks:
        zero_resnet_convs(down_block)
        samplers = down_block.downsamplers
        if samplers is None:
            continue
        for sampler in samplers:
            torch.nn.init.orthogonal_(sampler.conv.weight)

    for up_block in model.up_blocks:
        zero_resnet_convs(up_block)

    model.conv_out.weight.data.zero_()


def add_dims(input, n):
    """Return ``input`` reshaped with trailing singleton axes up to ``n`` dims.

    If ``input`` already has ``n`` or more dimensions no axes are added (the
    reshape targets its own shape). Useful for broadcasting a per-sample
    vector against higher-rank tensors.
    """
    n_missing = n - input.ndim
    target_shape = tuple(input.shape) + (1,) * n_missing
    return input.reshape(target_shape)


def sample_sigma(*size, device, loc=-0.4, scale=1.2, sigma_min=0.002, sigma_max=20):
    """Draw log-normal noise levels of shape ``size`` on ``device``.

    ``log(sigma)`` is Gaussian with mean ``loc`` and std ``scale``; the
    resulting ``sigma`` values are clipped to ``[sigma_min, sigma_max]``.
    """
    standard_normal = torch.randn(size, device=device)
    log_sigma = standard_normal.mul(scale).add(loc)
    return torch.clip(log_sigma.exp(), sigma_min, sigma_max)


@dataclass
class WorldModelConfig:
    """Hyper-parameters shared by the Denoiser and the RewardEndModel.

    WorldModel passes the same instance to both sub-models.
    """

    image_size: int  # passed to Encoder as ``sample_size`` in RewardEndModel
    image_channels: int  # channels per frame; the UNet input has image_channels * (num_steps_conditioning + 1)
    num_actions: int  # number of rows in the action nn.Embedding tables
    num_steps_conditioning: int  # number of past frames/actions the denoiser conditions on
    sigma_data: float  # sigma_data of the EDM preconditioning; conditioning frames are divided by it
    sigma_offset_noise: float  # std of per-channel offset noise in Denoiser.compute_loss; also folded into the effective sigma


class InnerModel(nn.Module):
    """Action-conditioned UNet used by the Denoiser.

    The UNet input stacks the current (noisy) frame with
    ``num_steps_conditioning`` past frames along the channel axis. The actions
    for those past steps are embedded and flattened into a single vector whose
    width equals ``unet.temb_dim``. This vector is passed to the UNet as its
    third argument. NOTE(review): presumably it is combined with the
    time/noise embedding inside UNet; confirm there.
    """

    def __init__(self, config: "WorldModelConfig") -> None:
        super().__init__()
        self.unet = UNet(
            in_channels=config.image_channels + config.num_steps_conditioning * config.image_channels,
            out_channels=config.image_channels,
            block_out_channels=(64, 64, 64, 64),
            norm_num_groups=8,
            add_attention=True, # still attn in middle block
            down_block_types=("DownBlock2D",) * 4,
            up_block_types=("UpBlock2D",) * 4,
            resnet_time_scale_shift="ada_group",
        )
        init_unet(self.unet)

        temb_dim = self.unet.temb_dim
        n_steps = config.num_steps_conditioning
        # Each conditioning step gets an equal slice of the embedding, so the
        # flattened (b, t * e) result is exactly temb_dim wide. Use an explicit
        # check rather than `assert`, which is stripped under `python -O`.
        if n_steps <= 0 or temb_dim % n_steps != 0:
            raise ValueError(
                f"unet.temb_dim ({temb_dim}) must be divisible by a positive "
                f"num_steps_conditioning (got {n_steps})"
            )
        self.act_emb = nn.Sequential(
            nn.Embedding(config.num_actions, temb_dim // n_steps),
            nn.Flatten(),  # b t e -> b (t e)
        )

    def forward(self, x, t, act):
        """Run the UNet on ``x`` at noise input ``t``, conditioned on ``act``.

        ``act`` holds integer action ids, expected as (b, t) per the Flatten
        comment above. It is embedded to (b, t, e), flattened to (b, t * e) and
        passed to the UNet.
        """
        return self.unet(x, t, self.act_emb(act))


class Denoiser(nn.Module):
    """Next-frame denoiser with Karras et al. (EDM)-style preconditioning.

    The wrapped InnerModel sees the conditioning frames scaled by
    ``1 / sigma_data``, concatenated on the channel axis with the noisy next
    frame scaled by ``c_in``. It also receives ``c_noise`` as its noise-level
    input. The effective noise level used for the scalings also includes the
    offset noise std (``sigma_offset_noise``).
    """

    def __init__(self, config: WorldModelConfig) -> None:
        super().__init__()
        self.config = config
        self.sigma_data = config.sigma_data
        self.sigma_offset_noise = config.sigma_offset_noise
        self.inner_model = InnerModel(config)

    def __repr__(self) -> str:
        return 'denoiser'

    def forward(self, noisy_next_obs, sigma, obs, act):
        """Return the raw network output (in preconditioned space).

        ``obs`` is expected to have its past frames already folded into the
        channel axis, as compute_loss does before calling this. It is joined
        with the scaled noisy frame along dim 1.
        """
        c_in, *_, c_noise = self._compute_conditioners(sigma, extra_dims=noisy_next_obs.ndim)
        network_input = torch.cat(
            (obs / self.sigma_data, noisy_next_obs * c_in),
            dim=1,
        )
        return self.inner_model(network_input, c_noise, act)

    def compute_loss(self, batch: Batch):
        """Denoising MSE for predicting step ``n`` from steps ``0 .. n-1``.

        ``n`` is ``config.num_steps_conditioning`` and ``batch.obs`` must hold
        exactly ``n + 1`` steps (asserted). The procedure is:

        * Conditioning frames are reshaped (b, t, c, h, w) -> (b, t * c, h, w).
        * One sigma per sample is drawn with sample_sigma.
        * The clean next frame is corrupted with per-(sample, channel) offset
          noise plus Gaussian noise of std sigma.
        * Samples with ``mask_padding[:, n]`` False are excluded from the loss.

        Returns ``(loss, {'loss_denoising': detached loss})``.
        """
        n_cond = self.config.num_steps_conditioning
        assert batch.obs.size(1) == n_cond + 1
        obs = batch.obs[:, :n_cond]
        next_obs = batch.obs[:, n_cond]
        act = batch.act[:, :n_cond]
        mask = batch.mask_padding[:, n_cond]

        b, t, c, h, w = obs.shape
        obs = obs.reshape(b, t * c, h, w)
        device = next_obs.device

        # RNG draw order matters for reproducibility: sigma, offset, gaussian.
        sigma = sample_sigma(b, device=device)
        _, c_out, c_skip, _ = self._compute_conditioners(sigma, extra_dims=next_obs.ndim)

        offset_noise = self.sigma_offset_noise * torch.randn(b, c, 1, 1, device=device)
        gaussian_noise = torch.randn_like(next_obs) * add_dims(sigma, next_obs.ndim)
        noisy_next_obs = next_obs + offset_noise + gaussian_noise

        prediction = self(noisy_next_obs, sigma, obs, act)
        # Target chosen so that prediction * c_out + noisy * c_skip == next_obs.
        target = (next_obs - c_skip * noisy_next_obs) / c_out
        loss = F.mse_loss(prediction[mask], target[mask])
        return loss, {'loss_denoising': loss.detach()}

    def denoise(self, noisy_next_obs, sigma, obs, act) -> torch.FloatTensor:
        """Estimate the clean next frame: ``c_out * F(x) + c_skip * x``."""
        _, c_out, c_skip, _ = self._compute_conditioners(sigma, extra_dims=noisy_next_obs.ndim)
        prediction = self(noisy_next_obs, sigma, obs, act)
        return prediction * c_out + noisy_next_obs * c_skip

    def _compute_conditioners(self, sigma: torch.Tensor, extra_dims=0):
        """Return ``(c_in, c_out, c_skip, c_noise)`` for noise level(s) ``sigma``.

        ``sigma`` is first combined with the offset-noise std as
        ``sqrt(sigma**2 + sigma_offset_noise**2)``. ``c_in``, ``c_out`` and
        ``c_skip`` get trailing singleton dims up to ``extra_dims`` dimensions
        for broadcasting. ``c_noise`` keeps the shape of ``sigma``.
        """
        sigma = (sigma ** 2 + self.sigma_offset_noise ** 2).sqrt()
        total_var = sigma ** 2 + self.sigma_data ** 2
        c_in = 1 / total_var.sqrt()
        c_skip = self.sigma_data ** 2 / total_var
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        c_in, c_out, c_skip = (add_dims(coef, extra_dims) for coef in (c_in, c_out, c_skip))
        return c_in, c_out, c_skip, c_noise


class RewardEndModel(nn.Module):
    """Recurrent model predicting reward and episode-end logits for each step.

    Each frame is encoded by Encoder, with the embedded action passed as its
    second argument. The per-frame features are flattened and fed to a
    single-layer LSTM. An MLP head then maps every LSTM output to 3 reward
    logits and 2 end logits.
    """

    def __init__(self, config: WorldModelConfig):
        super().__init__()
        self.encoder = Encoder(
            sample_size=config.image_size,
            in_channels=config.image_channels,
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            block_out_channels=(16, 16, 16, 16, 16),
            norm_num_groups=8,
            class_embed_type='identity'
        )
        # Action embedding width matches encoder.time_embed_dim.
        self.act_emb = nn.Embedding(config.num_actions, self.encoder.time_embed_dim)
        # NOTE(review): LSTM input size 256 must equal the flattened encoder
        # output (e * h * w) for the configured image_size; it is hard-coded
        # here -- TODO confirm or derive it from the encoder.
        self.lstm = nn.LSTM(256, 256, batch_first=True)
        # 3 reward classes (sign of reward: -1, 0, 1) + 2 end classes.
        self.head = nn.Sequential(nn.Linear(256, 256), nn.SiLU(), nn.Linear(256, 3 + 2, bias=False))

        self._init_lstm()

    def _init_lstm(self) -> None:
        """Xavier input weights, orthogonal recurrent weights, forget-gate bias 1."""
        for name, p in self.lstm.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(p.data)
            elif "weight_hh" in name:
                nn.init.orthogonal_(p.data)
            elif "bias_ih" in name:
                p.data.fill_(0)
                # PyTorch packs LSTM gate biases as (input | forget | cell | output),
                # so the second quarter is the forget gate; set it to 1.
                n = p.size(0)
                p.data[(n // 4) : (n // 2)].fill_(1)
            elif "bias_hh" in name:
                p.data.fill_(0)

    def __repr__(self) -> str:
        return 'rew_end_model'

    def forward(self, obs: torch.FloatTensor, act: torch.LongTensor, hx_cx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """Predict per-step reward and end logits.

        Args:
            obs: frames of shape (b, t, c, h, w).
            act: integer actions with b * t entries (presumably shaped (b, t)).
            hx_cx: optional LSTM ``(h, c)`` state to continue from.

        Returns:
            ``(logits_rew, logits_end, hx_cx)`` with logits_rew of shape
            (b, t, 3), logits_end of shape (b, t, 2), and the final LSTM state.
        """
        b, t, c, h, w = obs.shape
        obs, act = obs.reshape(b * t, c, h, w), act.reshape(b * t)
        x = self.encoder(obs, self.act_emb(act))
        x = x.reshape(b, t, -1) # (b t) e h w -> b t (e h w)
        x, hx_cx = self.lstm(x, hx_cx)
        logits = self.head(x)
        return logits[:, :, :-2], logits[:, :, -2:], hx_cx

    def compute_loss(self, batch: Batch) -> Tuple[torch.Tensor, dict]:
        """Sum of reward and end cross-entropies over non-padded steps.

        Returns ``(loss, {'loss_rew': ..., 'loss_end': ...})`` with detached
        components in the dict.
        """
        mask = batch.mask_padding
        logits_rew, logits_end, _ = self(batch.obs, batch.act)
        # Reward clipped to {-1, 0, 1}, shifted to class ids {0, 1, 2}.
        target_rew = batch.rew[mask].sign().long().add(1)
        # cross_entropy needs integer class indices; cast like the reward target
        # so non-int64 `end` tensors (e.g. bool or float) are accepted.
        target_end = batch.end[mask].long()
        loss_rew = F.cross_entropy(logits_rew[mask], target=target_rew)
        loss_end = F.cross_entropy(logits_end[mask], target=target_end)
        return loss_rew + loss_end, {'loss_rew': loss_rew.detach(), 'loss_end': loss_end.detach()}
    

class WorldModel(nn.Module):
    """Bundle of the diffusion Denoiser and the RewardEndModel.

    Both sub-models are built from the same WorldModelConfig, which is also
    kept on the instance as ``config``.
    """

    def __init__(self, config: WorldModelConfig) -> None:
        super().__init__()
        self.config = config
        # Construction order kept: each sub-model's init consumes the global RNG.
        self.denoiser = Denoiser(config)
        self.rew_end_model = RewardEndModel(config)

    def __repr__(self) -> str:
        return 'world_model'

    @property
    def device(self):
        """Device of the denoiser UNet's ``conv_in`` weight, used as the model device."""
        unet = self.denoiser.inner_model.unet
        first_conv_weight = unet.conv_in.weight
        return first_conv_weight.device
