from typing import Union, List, Optional

import torch
from torch import nn
from jaxtyping import Float
from einops import rearrange, repeat
from tqdm.auto import tqdm
from contextlib import nullcontext, contextmanager

from ..rf import LatentRF2D
from ..unclip import UnclipLatentRF2d
from ..modular.layers.rope import make_axial_pos_2d, make_axial_pos_3d
from ..modular.layers import LayerNorm
import numpy as np

from diffusion.utils import transform, sample_mask
import einops
from PIL import Image
import random
import gc


# This only works with zero stages < 3
@contextmanager
@torch.no_grad()
def use_distributed_ema(module: nn.Module, rank: int, ema_module: nn.Module | None):
    orig_weights = {k: v.detach().clone() for k, v in module.state_dict().items()}
    if rank == 0:
        assert not ema_module is None
        ema_weights = ema_module.state_dict()
        torch.distributed.scatter_object_list([None], [ema_weights] * torch.distributed.get_world_size(), src=0)
    else:
        l = [None]
        torch.distributed.scatter_object_list(l, None, src=0)  # if rank != 0, it receives weights from rank 0
        ema_weights = l[0]

    for k, p in module.named_parameters():
        p.data.copy_(ema_weights[k].data.to(p.data))
    try:
        yield module
    finally:
        for k, p in module.named_parameters():
            p.data.copy_(orig_weights[k].data)


@contextmanager
def freeze(module):
    """Disable gradients for every parameter of `module` inside the ``with`` body.

    The ``requires_grad`` flag of each parameter is recorded on entry and
    written back on exit, also when the body raises.
    """
    try:
        previous_flags = tuple(param.requires_grad for param in module.parameters())
        for param in module.parameters():
            param.requires_grad_(False)
        yield
    finally:
        # Flags are matched to parameters by iteration order.
        for param, was_trainable in zip(module.parameters(), previous_flags):
            param.requires_grad_(was_trainable)


class MotionImageReconstructor(UnclipLatentRF2d):
    """Frame reconstructor conditioned on a start frame and a motion embedding.

    Extends ``UnclipLatentRF2d`` with a linear projection + layer norm that maps
    a ``d_motion``-dimensional motion vector to the width of the time embedding
    (``self.d_t``). The projected motion is added to the time embedding before
    ``self.mapping``; the start frame is encoded with ``self.ae`` and passed on
    as ``x_cond``.
    """

    # Checkpoint key prefixes that are relocated under ``post_proj`` when loading.
    _RENAMED_KEY_PREFIXES = (
        ("unet.mid_split.norm.", "unet.mid_split.post_proj.norm."),
        ("unet.mid_split.proj.", "unet.mid_split.post_proj.proj."),
    )

    def __init__(
        self,
        d_motion: int,
        x_cond_dropout: bool = True,
        **kwargs,
    ):
        """
        Args:
            d_motion: Dimensionality of the incoming motion embeddings.
            x_cond_dropout: If true, the start-frame embedding of a sample is
                zeroed with probability ``self.c_dropout`` in
                :meth:`get_conditioning`.
            **kwargs: Forwarded to the base class. ``no_start_frame`` (default
                ``False``) is consumed here; when set, the start-frame
                embedding is always replaced by zeros.
        """
        # Assigned before the base-class constructor runs, as in the original code.
        self.no_start_frame = kwargs.pop("no_start_frame", False)
        super().__init__(**kwargs)
        self.d_motion = d_motion
        self.x_cond_dropout = x_cond_dropout
        self.motion_proj = nn.Linear(d_motion, self.d_t)
        self.motion_norm = LayerNorm(self.d_t)

        if self.checkpoint is not None:
            # When a checkpoint is configured, ``mid_merge.proj`` is reset to the
            # identity; the checkpoint's own ``mid_merge.proj.weight`` is mapped
            # to ``mid_merge.pre_proj`` in ``map_state_dict_keys``.
            torch.nn.init.eye_(self.unet.mid_merge.proj.weight)

    def map_state_dict_keys(self, state_dict):
        """Rename checkpoint keys to this model's layout.

        Applies the base-class mapping first, then moves
        ``unet.mid_split.{norm,proj}.*`` under ``unet.mid_split.post_proj`` and
        ``unet.mid_merge.proj.weight`` to ``unet.mid_merge.pre_proj.proj.weight``.
        All other keys are kept unchanged.
        """
        # Start with base class mapping
        state_dict = super().map_state_dict_keys(state_dict)
        new_state_dict = {}
        for key, value in state_dict.items():
            if key == "unet.mid_merge.proj.weight":
                key = "unet.mid_merge.pre_proj.proj.weight"
            else:
                for old_prefix, new_prefix in self._RENAMED_KEY_PREFIXES:
                    if key.startswith(old_prefix):
                        # Rewrite the prefix only, never a later occurrence.
                        key = new_prefix + key[len(old_prefix):]
                        break
            new_state_dict[key] = value
        return new_state_dict

    # Name-mangled (double underscore) so it cannot clash with a helper of the
    # base classes, which are defined outside this file.
    def __time_embedding(self, t: Float[torch.Tensor, "b"]) -> torch.Tensor:
        """Embed ``t`` according to ``self.time_cond_type``.

        Raises:
            NotImplementedError: For an unknown ``time_cond_type``.
        """
        if self.time_cond_type == "sigma":
            c_noise = torch.log(t) / 4
        elif self.time_cond_type == "rf_t":
            c_noise = t
        else:
            raise NotImplementedError(f'Unknown time conditioning type "{self.time_cond_type}".')
        return self.time_in_proj(self.time_emb(c_noise[..., None]))

    def get_conditioning(
        self, t: Float[torch.Tensor, "b"], c_img: Float[torch.Tensor, "b c h w"], c_motion: Float[torch.Tensor, "b d"]
    ) -> dict[str, torch.Tensor]:
        """Build the conditioning with per-sample conditioning dropout.

        Args:
            t: Time / noise level per sample.
            c_img: Start frame in image space; encoded with ``self.ae``.
            c_motion: Motion embedding per sample.

        Returns:
            ``{"cond_norm": ..., "x_cond": ...}`` where ``cond_norm`` is
            ``self.mapping(time_emb + motion_emb)`` and ``x_cond`` is the
            channels-last start-frame latent.

        Raises:
            NotImplementedError: For an unknown ``time_cond_type``.
        """
        time_emb = self.__time_embedding(t)

        # Motion dropout: a sample keeps its motion embedding with probability
        # 1 - c_dropout; dropped samples contribute an all-zero embedding.
        motion_emb = torch.zeros_like(time_emb)
        keep_idx = torch.rand(time_emb.shape[0]) >= self.c_dropout
        if keep_idx.any():
            motion_emb[keep_idx] = self.motion_norm(self.motion_proj(c_motion[keep_idx]))

        image_emb = rearrange(self.ae.encode(c_img), "b c ... -> b ... c")
        if self.no_start_frame:
            # Start frame disabled: condition on an all-zero embedding.
            image_emb = torch.zeros_like(image_emb)
        elif self.x_cond_dropout:
            drop_idx = torch.rand(image_emb.shape[0]) < self.c_dropout
            image_emb[drop_idx] = 0.0
            # NOTE(review): the original code called ``motion_emb.detach()`` here
            # and discarded the result. ``Tensor.detach()`` is out-of-place, so
            # the call never had any effect and gradients have always reached
            # the motion branch. The dead statement was removed and that
            # behaviour kept; confirm a real detach was never intended.

        cond_time = self.mapping(time_emb + motion_emb)

        return {
            "cond_norm": cond_time,
            "x_cond": image_emb,
        }

    def get_unconditional_conditioning(
        self,
        t: Float[torch.Tensor, "b"],
        x: Float[torch.Tensor, "b c h w"],
        c_motion: Float[torch.Tensor, "b d"] | None = None,
    ) -> dict[str, torch.Tensor]:
        """Build the conditioning with the start frame removed.

        ``x_cond`` is an all-zero tensor shaped like ``x`` (channels-last).
        The motion embedding is still applied when ``c_motion`` is given and
        replaced by zeros otherwise.

        Raises:
            NotImplementedError: For an unknown ``time_cond_type``.
        """
        time_emb = self.__time_embedding(t)

        if c_motion is not None:
            motion_emb = self.motion_norm(self.motion_proj(c_motion))
        else:
            motion_emb = torch.zeros_like(time_emb)

        cond_time = self.mapping(time_emb + motion_emb)

        return {
            "cond_norm": cond_time,
            "x_cond": rearrange(torch.zeros_like(x), "b c ... -> b ... c"),
        }

    def forward(
        self,
        x: Float[torch.Tensor, "b c h w"],
        c_img: Float[torch.Tensor, "b c h w"],
        c_motion: Float[torch.Tensor, "b d"],
    ) -> Float[torch.Tensor, "b"]:
        """Run the training forward pass for target ``x`` given ``c_img`` and ``c_motion``.

        Calls ``LatentRF2D.forward`` explicitly instead of going through
        ``super()``.
        """
        return LatentRF2D.forward(self, x=x, c_img=c_img, c_motion=c_motion)


class MotionRepresentationLearner(nn.Module):
    """Learns per-frame motion embeddings by reconstructing later frames.

    A video is (optionally) embedded frame by frame with ``frame_embedder`` and
    then processed together with learned motion class tokens by
    ``sequence_embedder``. The resulting motion embedding of frame ``t`` and
    frame ``t`` itself condition ``reconstructor`` to reconstruct frame
    ``t + reconstruction_skip``; the reconstructor's output is the training loss.
    """

    def __init__(
        self,
        sequence_embedder: nn.Module,
        reconstructor: nn.Module,
        d_motion_cls: int,
        num_motion_tokens_per_frame: int = 1,
        frame_embedder: nn.Module | None = None,
        lambda_dis: float = 0.0,
        reconstruction_skip: int = 1,
        params_masking: dict | None = None,
        params_geometric_augmentations: dict | None = None,
        params_photometric_augmentations: dict | None = None,
        size=256,
    ):
        """
        Args:
            sequence_embedder: Called as ``sequence_embedder(frames, pos=...,
                x_extra=..., pos_extra=..., [keep_indices=...])``.
            reconstructor: Called as ``reconstructor(x=..., c_img=...,
                c_motion=...)`` during training; ``validate`` also uses its
                ``ae.encode`` and ``sample``.
            d_motion_cls: Width of the learned motion class tokens.
            num_motion_tokens_per_frame: Number of class tokens per frame.
            frame_embedder: Optional per-frame embedder; when ``None`` the raw
                video is passed to ``sequence_embedder``.
            lambda_dis: Stored as ``self.lambda_dis``; not used in this class.
            reconstruction_skip: Frame offset between conditioning frame and
                reconstruction target. Must be at least 1.
            params_masking: Masking config exposing ``enabled``, ``mask_ratio``
                and ``mask_type`` (plain dict or attribute-style config).
                ``None`` / empty disables masking.
            params_geometric_augmentations: Forwarded to ``transform``.
            params_photometric_augmentations: Forwarded to ``transform``.
            size: Stored as ``self.size``; not used in this class.

        Raises:
            ValueError: If ``reconstruction_skip`` is smaller than 1 (the
                ``[: -skip]`` slices used throughout would be empty).
        """
        super().__init__()
        if reconstruction_skip < 1:
            raise ValueError(f"reconstruction_skip must be >= 1, got {reconstruction_skip}.")

        self.frame_embedder = frame_embedder
        self.sequence_embedder = sequence_embedder
        self.motion_cls_token = nn.Parameter(torch.randn(num_motion_tokens_per_frame, d_motion_cls))
        self.num_motion_tokens_per_frame = num_motion_tokens_per_frame
        self.reconstructor = reconstructor
        self.reconstruction_skip = reconstruction_skip
        self.lambda_dis = lambda_dis

        # ``None`` defaults avoid sharing one mutable dict between instances.
        self.params_geometric_augmentations = (
            {} if params_geometric_augmentations is None else params_geometric_augmentations
        )
        self.params_photometric_augmentations = (
            {} if params_photometric_augmentations is None else params_photometric_augmentations
        )
        self.params_masking = {} if params_masking is None else params_masking

        self.size = size

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    def _masking_enabled(self) -> bool:
        """Return whether masking is switched on; a missing ``enabled`` entry means off."""
        cfg = self.params_masking
        if isinstance(cfg, dict):
            return bool(cfg.get("enabled", False))
        return bool(getattr(cfg, "enabled", False))

    def _masking_option(self, key: str):
        """Read a required masking option from a plain dict or an attribute-style config."""
        cfg = self.params_masking
        return cfg[key] if isinstance(cfg, dict) else getattr(cfg, key)

    def _augment_videos(self, videos: Float[torch.Tensor, "b c t h w"]) -> Float[torch.Tensor, "b c t h w"]:
        """Apply ``transform`` to every video of the batch independently."""
        augmented = [
            einops.rearrange(
                transform(
                    einops.rearrange(video, "c t h w -> t c h w"),
                    self.params_photometric_augmentations,
                    self.params_geometric_augmentations,
                ),
                "t c h w -> c t h w",
            )
            for video in videos
        ]
        return torch.stack(augmented, dim=0).to(dtype=videos.dtype)

    def _validation_motion(self, videos: Float[torch.Tensor, "b c t h w"]) -> Float[torch.Tensor, "b t d"]:
        """Motion embeddings without augmentation/masking, trimmed like in ``forward``."""
        motion = self.get_motion_embeddings(videos, training_enabled=False)[..., : -self.reconstruction_skip]
        return rearrange(motion, "b d t -> b t d")

    def _autoregressive_rollout(
        self,
        first_frame: Float[torch.Tensor, "b c 1 h w"],
        motion: Float[torch.Tensor, "b t d"],
        device,
        dtype,
    ) -> Float[torch.Tensor, "b c t_out h w"]:
        """Generate frames one after another, each conditioned on the previous one.

        Step ``k`` conditions on the latest frame and ``motion[:, k]``. This
        mirrors the pairing used in ``forward`` (frame ``t`` together with the
        motion embedding at index ``t``). The original validation code used
        ``motion[:, k + 1]`` for that step and never used index 0.

        Returns:
            ``first_frame`` followed by one generated frame per motion step,
            concatenated along the time axis.
        """
        frames = [first_frame]
        latent_shape = None
        for step in range(motion.shape[1]):
            cond_img = frames[-1].squeeze(2)
            if latent_shape is None:
                # Encoded once, only to learn the latent shape for the noise.
                latent_shape = self.reconstructor.ae.encode(cond_img).shape[1:]
            noise = torch.randn((cond_img.shape[0], *latent_shape), dtype=dtype, device=device)
            pred_frame = self.reconstructor.sample(
                z=noise,
                c_img=cond_img,
                c_motion=motion[:, step, :],
            )
            frames.append(pred_frame.unsqueeze(2))
        return torch.cat(frames, dim=2)

    @staticmethod
    def _tile_videos(videos: Float[torch.Tensor, "b c t h w"]) -> torch.Tensor:
        """Lay the batch out side by side: ``(t, h, b*w, c)``."""
        return einops.rearrange(videos, "b c t h w -> t h (b w) c")

    @staticmethod
    def _to_uint8_video(rows) -> np.ndarray:
        """Stack tiled rows vertically and convert from [-1, 1] to a uint8 ``(t, c, h, w)`` array."""
        combined = torch.cat(rows, dim=1)
        frames = ((combined.clip(-1, 1) / 2 + 0.5) * 255).round().float().to(torch.uint8)
        return np.transpose(frames.cpu().numpy(), (0, 3, 1, 2))

    @staticmethod
    def _label_key(label):
        """Return a hashable, value-based key for a class label.

        Tensors hash by identity, so tensor labels would make every sample its
        own class; they are converted to Python scalars / tuples instead.
        """
        if isinstance(label, torch.Tensor):
            return label.item() if label.numel() == 1 else tuple(label.flatten().tolist())
        return label

    @classmethod
    def _collect_samples_by_class(cls, dataloader, min_classes: int = 8, min_samples: int = 4) -> dict:
        """Group validation videos by ``batch["cls"]``.

        Reading stops once at least ``min_classes`` classes hold at least
        ``min_samples`` videos each, or when the dataloader is exhausted.
        """
        samples_by_class = {}
        for batch in dataloader:
            for video, label in zip(batch["x"], batch["cls"]):
                samples_by_class.setdefault(cls._label_key(label), []).append(video)
            filled = sum(len(vids) >= min_samples for vids in samples_by_class.values())
            if filled >= min_classes:
                break
        return samples_by_class

    @staticmethod
    def _sample_class_pair_batch(samples_by_class: dict, batch_size: int = 8):
        """Draw ``batch_size`` distinct classes and two random videos from each.

        Returns:
            ``(source_videos, target_videos)`` stacked along a new batch axis,
            or ``(None, None)`` when fewer than ``batch_size`` classes have at
            least two videos or the videos cannot be stacked.
        """
        valid_classes = [label for label, vids in samples_by_class.items() if len(vids) >= 2]
        print(f"Valid classes for sampling: {valid_classes}", flush=True)
        if len(valid_classes) < batch_size:
            print(f"Not enough valid classes to sample from: {valid_classes}", flush=True)
            return None, None

        class_labels = random.sample(valid_classes, k=batch_size)
        source_videos = []
        target_videos = []
        for label in class_labels:
            vids = samples_by_class[label]
            source_videos.append(random.choice(vids).clone())
            target_videos.append(random.choice(vids).clone())
        try:
            return torch.stack(source_videos, dim=0), torch.stack(target_videos, dim=0)
        except (RuntimeError, TypeError) as e:
            # e.g. videos of different length / resolution within one draw
            print(f"Failed to stack videos for labels {class_labels}, error: {e}", flush=True)
            return None, None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def get_motion_embeddings(
        self, x: Float[torch.Tensor, "b c t h w"], training_enabled: bool = True
    ) -> Float[torch.Tensor, "b d t"]:
        """Compute motion embeddings for a batch of videos.

        Args:
            x: Input videos.
            training_enabled: When true, the videos are augmented with
                ``transform`` and, if masking is enabled in ``params_masking``
                and a frame embedder is present, ``keep_indices`` from
                ``sample_mask`` are passed to the sequence embedder.

        Returns:
            The output of ``self.sequence_embedder``.
        """
        kwargs = {}
        B, C, T, H, W = x.shape

        # apply augmentations to the input video
        if training_enabled:
            x = self._augment_videos(x)

        if self.frame_embedder is not None:
            # The frame embedder is called without positional input. The 2D
            # axial positions the original code built here were never passed
            # on, so that dead computation was dropped.
            frame_embeddings: Float[torch.Tensor, "b t d h' w'"] = rearrange(
                self.frame_embedder(rearrange(x, "b c t h w -> (b t) c h w")),
                "(b t) c h w -> b t c h w",
                b=B,
            )

            # only do masking during training
            if training_enabled and self._masking_enabled():
                B_, T_, C_, H_, W_ = frame_embeddings.shape
                keep_indices: List[Float[torch.Tensor, "n"]] = sample_mask(
                    mask_ratio=self._masking_option("mask_ratio"),
                    mask_type=self._masking_option("mask_type"),
                    fs=T_,
                    hs=H_,
                    ws=W_,
                    batch_size=B_,
                )
                kwargs["keep_indices"] = keep_indices

            frame_embeddings = rearrange(frame_embeddings, "b t c h w -> b c t h w")
        else:
            # if no frame embedder is provided, use the input as the frame embeddings
            frame_embeddings = x

        h, w = frame_embeddings.shape[-2:]
        pos_3d: Float[torch.Tensor, "b d t h w"] = (
            make_axial_pos_3d(T, h, w, device=x.device).view(1, T, h, w, -1).expand(B, -1, -1, -1, -1)
        ).movedim(-1, 1)

        # One set of N learned class tokens per frame; each token is positioned
        # at the spatial mean position of its frame.
        N = self.num_motion_tokens_per_frame
        cls_tokens: Float[torch.Tensor, "b (t n) d"] = self.motion_cls_token.view(1, N, -1).repeat(B, T, 1)
        cls_pos: Float[torch.Tensor, "b (t n) d"] = repeat(
            rearrange(pos_3d.mean(dim=(-2, -1)), "b d t -> b t d"), "b t d -> b (t n) d", n=N
        )

        return self.sequence_embedder(
            frame_embeddings,
            pos=pos_3d,
            x_extra=cls_tokens,
            pos_extra=cls_pos,
            **kwargs,
        )

    def forward(
        self, x: Float[torch.Tensor, "b c t h w"], **data_kwargs
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute the reconstruction loss for a batch of videos.

        Frame ``t`` and the motion embedding at index ``t`` condition the
        reconstruction of frame ``t + reconstruction_skip``.

        Returns:
            ``(loss, metrics)`` where ``loss`` is the reconstructor output and
            ``metrics`` is an empty dict. ``data_kwargs`` are ignored.
        """
        motion_embeddings = self.get_motion_embeddings(x, training_enabled=True)[..., : -self.reconstruction_skip]

        loss = self.reconstructor(
            x=rearrange(x[:, :, self.reconstruction_skip :], "b c t h w -> (b t) c h w"),  # x_t+1
            c_img=rearrange(x[:, :, : -self.reconstruction_skip], "b c t h w -> (b t) c h w"),  # x_t
            c_motion=rearrange(motion_embeddings, "b d t -> (b t) d"),  # motion embeddings of x_t
        )
        metrics = {}
        return loss, metrics

    @torch.no_grad()
    def validate(
        self,
        dataloader_val: "torch.utils.data.DataLoader",
        dataloader_train: "torch.utils.data.DataLoader",
        global_rank: int,
        global_samples: int,
        max_steps: int | None,
        device,
        dtype,
        wandb,
        monitor: "deepspeed.monitor.monitor.MonitorMaster",
        ema_model: nn.Module | None,
    ) -> None:
        """Log qualitative validation videos to ``wandb`` (work is done on rank 0).

        Three visualisations are produced, each logged at ``step=global_samples``:

        1. Motion transfer: motion embeddings of a source video drive an
           autoregressive rollout that starts from the first frame of another
           video of the same class.
        2. Autoregressive reconstruction of validation videos from their own
           first frame and motion embeddings.
        3. Training videos next to their augmented versions.

        Every rank has to call this method: it contains
        ``torch.distributed.barrier()`` calls, and all ranks iterate
        ``dataloader_train`` in lock-step.

        Args:
            max_steps: Number of videos logged per visualisation. ``None``
                means one motion-transfer video (random pairs have no natural
                end) and the full dataloader for the other two.
            monitor: Accepted for interface compatibility; not used.
            ema_model: Accepted for interface compatibility; not used (EMA
                weight swapping is currently disabled).
        """
        is_main = global_rank == 0

        # Only rank 0 builds the samples_by_class dictionary.
        if is_main:
            samples_by_class = self._collect_samples_by_class(dataloader_val)

        # free GPU memory
        torch.cuda.empty_cache()

        # -------------------------------
        # Motion Transfer Validation Loop
        # -------------------------------
        if is_main:
            num_transfer_steps = 1 if max_steps is None else max_steps
            for i in tqdm(range(num_transfer_steps), desc="Motion Transfer Validation", disable=False):
                source_batch, target_batch = self._sample_class_pair_batch(samples_by_class, batch_size=8)
                if source_batch is None:
                    print(f"Skipping step {i} due to insufficient valid classes.", flush=True)
                    continue

                source_batch = source_batch.to(device=device, dtype=dtype)
                target_batch = target_batch.to(device=device, dtype=dtype)

                source_motion = self._validation_motion(source_batch)
                generated_video = self._autoregressive_rollout(
                    target_batch[:, :, :1], source_motion, device=device, dtype=dtype
                )
                num_pred_frames = generated_video.shape[2]

                # Rows from top to bottom: source video, repeated target start frame, prediction.
                rows = [
                    self._tile_videos(source_batch[:, :, :num_pred_frames]),
                    self._tile_videos(target_batch[:, :, 0:1].expand(-1, -1, num_pred_frames, -1, -1)),
                    self._tile_videos(generated_video),
                ]
                wandb.log(
                    {f"Val/Vis/MotionTransfer_{i}": wandb.Video(self._to_uint8_video(rows), fps=6, format="gif")},
                    step=global_samples,
                )

        # sync so the other ranks wait for rank 0
        torch.distributed.barrier()
        torch.cuda.empty_cache()

        # -------------------------------
        # Reconstruction Validation Loop
        # -------------------------------
        if is_main:
            # ``zip(range(n), loader)`` stops before fetching batch n + 1.
            if max_steps is None:
                val_steps = enumerate(dataloader_val)
            else:
                val_steps = zip(range(max_steps), dataloader_val)

            for i, val_batch in tqdm(
                val_steps, total=max_steps, desc="Reconstruction Validation", disable=False
            ):
                x = val_batch["x"].to(device=device, dtype=dtype)

                motion_embeddings = self._validation_motion(x)
                generated_video = self._autoregressive_rollout(
                    x[:, :, 0:1], motion_embeddings, device=device, dtype=dtype
                )
                num_pred_frames = generated_video.shape[2]

                # Rows from top to bottom: original video, prediction.
                rows = [
                    self._tile_videos(x[:, :, :num_pred_frames]),
                    self._tile_videos(generated_video),
                ]
                wandb.log(
                    {f"Val/Vis/AR_Reconstruction_{i}": wandb.Video(self._to_uint8_video(rows), fps=6, format="gif")},
                    step=global_samples,
                )

        # final sync so the other ranks can continue
        torch.distributed.barrier()

        # -------------------------------
        # Visualizing Train Frames Loop
        # -------------------------------
        for i, train_batch in enumerate(
            tqdm(dataloader_train, desc="Visualizing Train Frames", total=max_steps, disable=False)
        ):
            try:
                if is_main:
                    x_orig = train_batch["x"].to(device=device, dtype=dtype)
                    x_augmented = self._augment_videos(x_orig)

                    # Rows from top to bottom: original video, augmented video.
                    rows = [self._tile_videos(x_orig), self._tile_videos(x_augmented)]
                    wandb.log(
                        {f"Val/Vis/TrainFrames_{i}": wandb.Video(self._to_uint8_video(rows), fps=6, format="gif")},
                        step=global_samples,
                    )
            except Exception as e:
                # Best effort: a failed visualisation must not stop training.
                print(f"Exception at Visualizing Train Frames step {i}: {e}", flush=True)
            finally:
                # Reached by every rank on every step, keeping the ranks in lock-step.
                torch.distributed.barrier()

            # Kept outside ``finally``: a ``break`` there would silently discard
            # an in-flight exception such as KeyboardInterrupt.
            if max_steps is not None and i + 1 >= max_steps:
                break
