from dataclasses import dataclass

import torch
import torch.nn.functional as F

import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import VolumeRenderer
from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init
from threestudio.utils.GAN.distribution import DiagonalGaussianDistribution
from threestudio.utils.GAN.mobilenet import MobileNetV3 as GlobalEncoder
from threestudio.utils.GAN.vae import Decoder as Generator
from threestudio.utils.GAN.vae import Encoder as LocalEncoder
from threestudio.utils.typing import *


@threestudio.register("gan-volume-renderer")
class GANVolumeRenderer(VolumeRenderer):
    """Volume renderer that pairs a wrapped base renderer with a GAN generator.

    The wrapped renderer is evaluated on a down-scaled ray grid. The first
    three channels of its ``comp_rgb`` output are used as a low-resolution
    RGB image; the remaining channels parameterise a diagonal Gaussian from
    which a latent map is drawn. Both are fed, together with a global code
    from ``global_encoder``, to ``generator``; the result is resized back to
    the input resolution and returned as ``comp_gan_rgb``.
    """

    @dataclass
    class Config(VolumeRenderer.Config):
        # Registry name of the wrapped renderer, resolved via threestudio.find.
        base_renderer_type: str = ""
        # Configuration object handed to the wrapped renderer's constructor.
        base_renderer: Optional[VolumeRenderer.Config] = None

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Instantiate the wrapped renderer and the GAN sub-networks.

        Args:
            geometry: Geometry forwarded to the wrapped renderer.
            material: Material forwarded to the wrapped renderer.
            background: Background forwarded to the wrapped renderer.
        """
        self.base_renderer = threestudio.find(self.cfg.base_renderer_type)(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )
        # len(ch_mult) also fixes the ray down-scaling factor used in forward().
        self.ch_mult = [1, 2, 4]
        # in_channels=7: forward() concatenates 3 RGB channels with the latent
        # map (z_channels=4) before calling the generator.
        self.generator = Generator(
            ch=64,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=7,
            resolution=512,
            z_channels=4,
        )
        self.local_encoder = LocalEncoder(
            ch=32,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=3,
            resolution=512,
            z_channels=4,
        )
        self.global_encoder = GlobalEncoder(n_class=64)
        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False, ndf=64
        ).apply(weights_init)

    @staticmethod
    def _resize_bhwc(
        x: Float[Tensor, "B H W C"], size: Tuple[int, int]
    ) -> Float[Tensor, "B h w C"]:
        """Bilinearly resize a channels-last map to ``size`` (height, width)."""
        return F.interpolate(x.permute(0, 3, 1, 2), size, mode="bilinear").permute(
            0, 2, 3, 1
        )

    def _global_code(self, image: Float[Tensor, "B 3 H W"]) -> Tensor:
        """Run the global encoder on ``image`` resized to 224x224."""
        return self.global_encoder(F.interpolate(image, (224, 224)))

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        gt_rgb: Optional[Float[Tensor, "B H W 3"]] = None,
        multi_level_guidance: bool = False,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render a low-resolution image and refine it with the generator.

        Args:
            rays_o: Ray origins.
            rays_d: Ray directions.
            light_positions: Forwarded unchanged to the wrapped renderer.
            bg_color: Forwarded unchanged to the wrapped renderer.
            gt_rgb: Optional reference image. It is only used when
                ``multi_level_guidance`` is true.
            multi_level_guidance: When true the latent map is sampled from
                the posterior instead of taking its mode. If ``gt_rgb`` is
                also given, a random generator level in {0, 1, 2} is drawn
                and an extra strided render is produced.
            **kwargs: Forwarded unchanged to the wrapped renderer.

        Returns:
            The wrapped renderer's output dict for the down-scaled rays,
            extended/overwritten with:

            * ``comp_lr_rgb``: low-resolution RGB before resizing.
            * ``comp_rgb``: low-resolution RGB resized to (H, W).
            * ``comp_gan_rgb``: generator output resized to (H, W).
            * ``posterior``: the distribution the latent map was drawn from
              (not a tensor).
            * ``generator_level``: the level used (a Python int).
            * ``comp_int_rgb`` / ``comp_gt_rgb``: only with guidance and
              ``gt_rgb``; the strided render and the matching ``gt_rgb``
              pixels.
        """
        _, H, W, _ = rays_o.shape
        guided = gt_rgb is not None and multi_level_guidance

        if guided:
            # Render every `stride`-th pixel of the full-resolution grid at a
            # random offset so it can be compared with the same gt_rgb pixels.
            stride = 8
            generator_level = torch.randint(0, 3, (1,)).item()
            offset_x = torch.randint(0, stride, (1,)).item()
            offset_y = torch.randint(0, stride, (1,)).item()
            strided_out = self.base_renderer(
                rays_o[:, offset_y::stride, offset_x::stride],
                rays_d[:, offset_y::stride, offset_x::stride],
                light_positions,
                bg_color,
                **kwargs
            )
            comp_int_rgb = strided_out["comp_rgb"][..., :3]
            comp_gt_rgb = gt_rgb[:, offset_y::stride, offset_x::stride]
        else:
            generator_level = 0

        # The base renderer runs on a grid reduced by 2 ** (len(ch_mult) - 1).
        scale_ratio = 2 ** (len(self.ch_mult) - 1)
        low_res = (H // scale_ratio, W // scale_ratio)
        rays_o = self._resize_bhwc(rays_o, low_res)
        rays_d = self._resize_bhwc(rays_d, low_res)

        out = self.base_renderer(rays_o, rays_d, light_positions, bg_color, **kwargs)
        comp_rgb = out["comp_rgb"][..., :3]
        # Channels after RGB parameterise the latent distribution.
        latent = out["comp_rgb"][..., 3:]
        out["comp_lr_rgb"] = comp_rgb.clone()

        posterior = DiagonalGaussianDistribution(latent.permute(0, 3, 1, 2))
        z_map = posterior.sample() if multi_level_guidance else posterior.mode()
        lr_rgb = comp_rgb.permute(0, 3, 1, 2)

        # Level 0: everything comes from the render itself.
        # Level 1: the global code comes from gt_rgb.
        # Level 2: global code and latent map both come from gt_rgb.
        # Levels 1 and 2 are only drawn when gt_rgb is not None (see above).
        if generator_level == 0:
            g_code_rgb = self._global_code(lr_rgb)
        else:
            gt_bchw = gt_rgb.permute(0, 3, 1, 2)
            g_code_rgb = self._global_code(gt_bchw)
            if generator_level == 2:
                posterior = DiagonalGaussianDistribution(self.local_encoder(gt_bchw))
                z_map = posterior.sample()
        comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)

        comp_rgb = F.interpolate(lr_rgb, (H, W), mode="bilinear")
        comp_gan_rgb = F.interpolate(comp_gan_rgb, (H, W), mode="bilinear")
        out.update(
            {
                "posterior": posterior,
                "comp_gan_rgb": comp_gan_rgb.permute(0, 2, 3, 1),
                "comp_rgb": comp_rgb.permute(0, 2, 3, 1),
                "generator_level": generator_level,
            }
        )

        if guided:
            out.update({"comp_int_rgb": comp_int_rgb, "comp_gt_rgb": comp_gt_rgb})
        return out

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        """Forward the per-step update hook to the wrapped renderer."""
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        """Switch the wrapped renderer to training (or eval) mode.

        Returns ``self`` so calls can be chained, following the
        ``torch.nn.Module.train`` convention.

        NOTE(review): only the wrapped renderer is switched; the generator,
        encoders and discriminator keep whatever mode they already had.
        Confirm whether that is intended before changing it, since it
        affects layers whose behaviour depends on the mode.
        """
        self.base_renderer.train(mode)
        return self

    def eval(self):
        """Switch the wrapped renderer to evaluation mode and return ``self``.

        NOTE(review): as in ``train``, the GAN sub-networks are not switched.
        """
        self.base_renderer.eval()
        return self
