from typing import Union, List, Optional

import torch
from torch import nn
from jaxtyping import Float
import einops
import numpy as np
from tqdm.auto import tqdm

from .rf import LatentRF2D
from .modular.layers import LayerNorm


class UnclipLatentRF2d(LatentRF2D):
    """Latent RF model whose conditioning combines a time embedding with an image embedding.

    The image embedding is produced by ``img_embedder``, projected to ``d_t``
    features with a bias-free linear layer, normalised, and added to the time
    embedding before the result is passed through ``self.mapping``. Per-sample
    conditioning dropout replaces the image embedding with zeros.

    Attributes such as ``time_cond_type``, ``time_emb``, ``time_in_proj``,
    ``mapping``, ``val_shape``, ``sample`` and ``ae`` are not defined here;
    they are presumably provided by ``LatentRF2D`` (confirm against the parent).
    """

    # Checkpoint key prefix of the image embedder and the same prefix with the
    # ``_orig_mod.`` infix that ``map_state_dict_keys`` inserts.
    _EMBEDDER_PREFIX = "img_embedder.model."
    _EMBEDDER_MAPPED_PREFIX = "img_embedder.model._orig_mod."

    # Number of ``unet.mid_level`` blocks whose ``self_attn``/``ff`` norm
    # projection weights are dropped by ``remove_time_cond_keys``.
    _TIME_COND_MID_LEVELS = 28

    # Remaining (non-``mid_level``) keys dropped by ``remove_time_cond_keys``.
    _TIME_COND_EXTRA_KEYS = (
        "time_emb.weight",
        "time_in_proj.weight",
        "mapping.in_norm.scale",
        "mapping.blocks.0.norm.scale",
        "mapping.blocks.0.up_proj.weight",
        "mapping.blocks.0.down_proj.weight",
        "mapping.blocks.1.norm.scale",
        "mapping.blocks.1.up_proj.weight",
        "mapping.blocks.1.down_proj.weight",
        "mapping.out_norm.scale",
        "img_proj.weight",
        "norm.weight",
        "norm.bias",
    )

    def __init__(
        self,
        img_embedder: nn.Module,
        d_img: int,
        d_t: int,
        c_dropout: float,
        checkpoint: Optional[str] = None,
        remove_time_cond: bool = False,
        *args,
        **kwargs,
    ):
        """Build the model and optionally initialise it from a checkpoint.

        Args:
            img_embedder: Module applied to the conditioning image; its output
                is fed to the ``d_img -> d_t`` projection.
            d_img: Input feature size of the image projection.
            d_t: Output feature size of the image projection and size of the
                normalisation layer.
            c_dropout: Threshold for conditioning dropout; a sample keeps its
                image embedding when a uniform draw is ``>= c_dropout``.
            checkpoint: Optional path passed to ``torch.load``. When given, the
                state dict is key-mapped and loaded non-strictly.
            remove_time_cond: If true, the keys listed in
                ``remove_time_cond_keys`` are removed from the checkpoint
                before loading.
            *args: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.d_img = d_img
        self.img_embedder = img_embedder
        self.c_dropout = c_dropout
        self.d_t = d_t
        self.img_proj = nn.Linear(d_img, d_t, bias=False)
        self.norm = LayerNorm(d_t)
        self.checkpoint = checkpoint

        if self.checkpoint is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(self.checkpoint, map_location="cpu")
            state_dict = self.map_state_dict_keys(state_dict)
            if remove_time_cond:
                state_dict = self.remove_time_cond_keys(state_dict)
            # strict=False: missing or unexpected keys are tolerated (and not
            # reported), which is required when keys were removed above.
            self.load_state_dict(state_dict, strict=False)
            print(f"Loaded UnCLIP checkpoint from {self.checkpoint}")

    def map_state_dict_keys(self, state_dict):
        """Return a new dict with image-embedder keys moved under ``_orig_mod.``.

        Keys starting with ``img_embedder.model.`` get ``_orig_mod.`` inserted
        right after that prefix. Keys that already carry the infix are left
        untouched so the mapping is idempotent, and only the leading prefix is
        rewritten (not later occurrences of the same substring). All other
        keys are copied unchanged.
        """
        # Default mapping used only in UnclipLatentRF2d
        plain = self._EMBEDDER_PREFIX
        mapped = self._EMBEDDER_MAPPED_PREFIX

        def rename(key):
            if key.startswith(plain) and not key.startswith(mapped):
                return mapped + key[len(plain):]
            return key

        return {rename(k): v for k, v in state_dict.items()}

    def remove_time_cond_keys(self, state_dict):
        """Delete a fixed set of keys from ``state_dict`` in place and return it.

        The removed keys are the ``self_attn``/``ff`` norm projection weights of
        ``unet.mid_level`` blocks ``0..27`` plus the entries of
        ``_TIME_COND_EXTRA_KEYS``. Keys that are absent are ignored.
        """
        mid_level_keys = (
            f"unet.mid_level.{level}.{sub}.norm.linear.weight"
            for level in range(self._TIME_COND_MID_LEVELS)
            for sub in ("self_attn", "ff")
        )
        for key in (*mid_level_keys, *self._TIME_COND_EXTRA_KEYS):
            state_dict.pop(key, None)
        return state_dict

    def get_conditioning(self, t: Float[torch.Tensor, "b"], c_img: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
        """Compute the conditioning vector from the time value and the image.

        Args:
            t: Per-sample time / noise level.
            c_img: Conditioning images, indexed along the batch dimension.
            **kwargs: Accepted and ignored.

        Returns:
            ``{"cond_norm": ...}`` holding ``self.mapping`` applied to the sum
            of the time embedding and the (possibly zeroed) image embedding.

        Raises:
            NotImplementedError: If ``self.time_cond_type`` is neither
                ``"sigma"`` nor ``"rf_t"``.
        """
        if self.time_cond_type == "sigma":
            c_noise = torch.log(t) / 4
        elif self.time_cond_type == "rf_t":
            c_noise = t
        else:
            raise NotImplementedError(f'Unknown time conditioning type "{self.time_cond_type}".')

        time_emb = self.time_in_proj(self.time_emb(c_noise[..., None]))

        # Start from an all-zero image embedding; only the samples that survive
        # conditioning dropout are embedded and written back.
        img_emb = torch.zeros_like(time_emb)
        keep_idx = torch.rand(time_emb.shape[0]) >= self.c_dropout
        if keep_idx.any():
            kept = self.img_embedder(c_img[keep_idx])
            kept = self.norm(self.img_proj(kept))
            img_emb[keep_idx] = kept

        cond_time = self.mapping(time_emb + img_emb)

        return {"cond_norm": cond_time}

    def forward(self, x: Float[torch.Tensor, "b ..."], **data_kwargs) -> Float[torch.Tensor, "b"]:
        """Run the parent forward pass using ``x`` itself as the conditioning image."""
        return super().forward(x=x, c_img=x, **data_kwargs)

    @torch.no_grad()
    def validate(
        self,
        dataloader_val: "torch.utils.data.DataLoader",
        global_rank: int,
        global_samples: int,
        max_steps: Optional[int],
        device,
        dtype,
        wandb,
        monitor: "deepspeed.monitor.monitor.MonitorMaster",
        ema_model: nn.Module | None,
    ) -> None:
        """Log conditioned samples next to their conditioning images (rank 0 only).

        For every validation batch (up to ``max_steps``) the first item of
        ``val_batch["x"]`` is used as conditioning image, one sample is drawn
        with conditioning dropout disabled, and the autoencoder round-trip of
        the conditioning image stacked on top of the sample is logged to
        ``wandb`` under ``Val/Vis/sample_{i}``.

        Conditioning dropout is restored to its previous value on every exit
        path, including exceptions. ``monitor`` and ``ema_model`` are accepted
        for interface compatibility and are not used here.
        """
        if self.val_shape is None:
            if global_rank == 0:
                print("Skipping validation for LatentRF model as val_shape is not provided.")
            return
        if global_rank != 0:
            return

        saved_dropout = self.c_dropout
        self.c_dropout = 0.0  # always condition on the image while validating
        try:
            for i, val_batch in enumerate(tqdm(dataloader_val, desc="Validating", total=max_steps)):
                # NOTE(review): torch.manual_seed reseeds the *global* RNG, so
                # the training RNG stream is reset by validation; kept as-is so
                # logged samples stay identical across runs.
                sample_tensor = torch.randn((1, *self.val_shape), dtype=dtype, generator=torch.manual_seed(i)).to(
                    device
                )
                c_img = val_batch["x"][:1].to(device=device, dtype=dtype)
                res = einops.rearrange(
                    self.sample(z=sample_tensor, c_img=c_img)[0],
                    "c h w -> h w c",
                )
                # Show the autoencoder reconstruction of the conditioning image
                # rather than the raw input.
                c_img = self.ae.decode(self.ae.encode(c_img))
                img = einops.rearrange(c_img[0], "c h w -> h w c")
                res = torch.cat([img, res], dim=0)
                wandb.log(
                    {
                        f"Val/Vis/sample_{i}": wandb.Image(
                            ((res.clip(-1, 1) / 2 + 0.5) * 255).round().float().cpu().numpy().astype(np.uint8)
                        )
                    },
                    step=global_samples,
                )

                if max_steps is not None and i + 1 >= max_steps:
                    break
        finally:
            self.c_dropout = saved_dropout
