import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any

from stream3r.layers import PatchEmbed
from stream3r.layers.block import Block
from stream3r.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from stream3r.layers.vision_transformer import (
    vit_small,
    vit_base,
    vit_large,
    vit_giant2,
)

logger = logging.getLogger(name=__name__)

# Per-channel mean / std (the standard torchvision ImageNet values) that
# Aggregator.forward subtracts / divides by before patch embedding.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """Alternating spatial / temporal attention aggregator over frame sequences.

    Each frame is patch-embedded and prefixed with one camera token and
    ``num_register_tokens`` register tokens. The token sequence is then refined
    by interleaving per-frame ("spatial") blocks from ``frame_blocks`` with
    cross-frame ("temporal") blocks from ``global_blocks``, in the order given
    by ``aa_order``. Without a cache, temporal attention is masked so that no
    token attends to tokens of a later frame. With ``use_cache=True``, the
    temporal blocks instead read/write a per-block key/value cache.

    Args:
        img_size: Image size forwarded to the patch embedder.
        patch_size: Patch size; also used to size the RoPE position grid.
        embed_dim: Width of the attention blocks and of the special tokens.
        depth: Number of spatial blocks and, separately, of temporal blocks.
        num_heads, mlp_ratio, qkv_bias, proj_bias, ffn_bias, qk_norm,
            init_values: Forwarded unchanged to every ``block_fn`` call.
        num_register_tokens: Register tokens prepended to every frame.
        block_fn: Factory used to build every attention block.
        patch_embed: Patch-embedder name. Any name containing ``"conv"``
            selects ``PatchEmbed``; otherwise one of the ``dinov2_*_reg`` ViTs.
        aa_order: Order of ``"spatial"`` / ``"temporal"`` passes inside one
            alternating group.
        aa_block_size: Consecutive blocks per pass; must divide ``depth``.
        rope_freq: RoPE frequency; ``<= 0`` disables rotary embeddings.

    Raises:
        ValueError: if ``depth`` is not divisible by ``aa_block_size``.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=("spatial", "temporal"),  # tuple: avoid a shared mutable default
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        # Validate before allocating the block stacks.
        if depth % aa_block_size != 0:
            raise ValueError(
                f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})"
            )

        self.__build_patch_embed__(
            patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim
        )

        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None

        def make_block():
            # Spatial and temporal stacks share one block configuration.
            return block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                qk_norm=qk_norm,
                rope=self.rope,
            )

        self.frame_blocks = nn.ModuleList([make_block() for _ in range(depth)])
        self.global_blocks = nn.ModuleList([make_block() for _ in range(depth)])

        self.depth = depth
        self.aa_order = list(aa_order)  # private copy; caller's sequence untouched
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size
        self.aa_block_num = depth // aa_block_size

        # Index 0 along dim 1 is used for the first frame, index 1 for all
        # later frames (see slice_expand_and_flatten).
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(
            torch.randn(1, 2, num_register_tokens, embed_dim)
        )

        # Patch tokens start after the camera token and the register tokens.
        self.patch_start_idx = 1 + num_register_tokens

        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            # Non-persistent: follows .to()/.cuda() but is not saved in state_dict.
            self.register_buffer(
                name, torch.FloatTensor(value).reshape(1, 1, 3, 1, 1), persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """Create ``self.patch_embed`` from the ``patch_embed`` name.

        Names containing ``"conv"`` build a ``PatchEmbed`` with ``embed_dim``
        outputs. Any other name must be a key of the ViT table below (unknown
        names raise ``KeyError``). For ViTs, a ``mask_token`` (if present) is
        frozen.
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
            return

        vit_models = {
            "dinov2_vitl14_reg": vit_large,
            "dinov2_vitb14_reg": vit_base,
            "dinov2_vits14_reg": vit_small,
            "dinov2_vitg2_reg": vit_giant2,
        }

        self.patch_embed = vit_models[patch_embed](
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            interpolate_antialias=interpolate_antialias,
            interpolate_offset=interpolate_offset,
            block_chunks=block_chunks,
            init_values=init_values,
        )

        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    @staticmethod
    def _take_last_frames(flat_tokens, B, S_total, S):
        """Keep the last ``S`` of ``S_total`` frames for every batch element.

        ``flat_tokens`` is batch-major ``(B * S_total, ...)``, as produced by
        ``slice_expand_and_flatten``; the result is ``(B * S, ...)``.
        """
        per_frame = flat_tokens.shape[1:]
        return flat_tokens.reshape(B, S_total, *per_frame)[:, -S:].reshape(
            B * S, *per_frame
        )

    def forward(
        self,
        images: torch.Tensor,
        past_key_values=None,
        use_cache=False,
        past_frame_idx=0,
    ) -> Union[
        Tuple[List[torch.Tensor], int],
        Tuple[List[torch.Tensor], int, List],
    ]:
        """Aggregate tokens for a batch of frame sequences.

        Args:
            images: ``(B, S, 3, H, W)`` tensor; normalized here with the
                ``_resnet_mean`` / ``_resnet_std`` buffers.
            past_key_values: Cache-mode only. List indexed by temporal block;
                each entry is ``None`` or that block's cached ``(k, v)``.
                ``None`` starts a fresh cache. The list is updated in place.
            use_cache: Run temporal blocks incrementally against the cache.
                Intended for one new frame per call (``S == 1``).
            past_frame_idx: Forwarded to the temporal helper (unused there).

        Returns:
            ``(output_list, patch_start_idx)``, plus the updated
            ``past_key_values`` when ``use_cache`` is true. ``output_list``
            holds one ``(B, S, P, 2 * C)`` tensor per block, each the
            concatenation of that block's spatial and temporal outputs.

        Raises:
            ValueError: on a non-3-channel input or an unknown ``aa_order`` entry.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        S_true = S
        if use_cache:
            if past_key_values is None:
                # Fresh cache: one empty slot per temporal block.
                past_key_values = [None] * self.depth
            if S > 1:
                logger.warning("Use KV cache expects S=1, got S=%d", S)
            if past_key_values[0] is not None:
                # Assumes dim 2 of the cached key tensor counts cached frames
                # (TODO confirm against Block's cache layout).
                _, _, n_cached, _, _ = past_key_values[0][0].shape
                S_true = n_cached + S

        images = (images - self._resnet_mean.to(images.device)) / self._resnet_std.to(
            images.device
        )

        images = images.reshape(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        camera_token = slice_expand_and_flatten(self.camera_token, B, S_true)
        register_token = slice_expand_and_flatten(self.register_token, B, S_true)
        if use_cache:
            # Special tokens were built for the full (cached + new) sequence, so
            # the new frame gets the "first" vs "later" token correctly; keep
            # only the newest frame(s) of *each* batch element.
            camera_token = self._take_last_frames(camera_token, B, S_true, S)
            register_token = self._take_last_frames(register_token, B, S_true, S)

        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=images.device
            )
            if self.patch_start_idx > 0:
                # Special tokens sit at position (0, 0). Patch positions are
                # shifted by +1, presumably so they never coincide with it.
                pos = pos + 1
                pos_special = torch.zeros(
                    B * S,
                    self.patch_start_idx,
                    2,
                    device=images.device,
                    dtype=pos.dtype,
                )
                pos = torch.cat([pos_special, pos], dim=1)

        _, P, C = tokens.shape

        spatial_idx = 0
        temporal_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "spatial":
                    (
                        tokens,
                        spatial_idx,
                        spatial_intermediates,
                    ) = self._process_spatial_attention(
                        tokens, B, S, P, C, spatial_idx, pos=pos
                    )
                elif attn_type == "temporal":
                    # In cache mode the helper reads and writes
                    # past_key_values[i] for every temporal block i it runs.
                    result = self._process_temporal_attention(
                        tokens,
                        B,
                        S,
                        P,
                        C,
                        temporal_idx,
                        pos=pos,
                        use_cache=use_cache,
                        past_frame_idx=past_frame_idx,
                        past_key_values_list=past_key_values if use_cache else None,
                    )
                    tokens, temporal_idx, temporal_intermediates = result[:3]
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # Both passes produce aa_block_size outputs; pair them per block.
            for spatial_out, temporal_out in zip(
                spatial_intermediates, temporal_intermediates
            ):
                output_list.append(torch.cat([spatial_out, temporal_out], dim=-1))

        if use_cache:
            return output_list, self.patch_start_idx, past_key_values
        return output_list, self.patch_start_idx

    def _process_spatial_attention(self, tokens, B, S, P, C, spatial_idx, pos=None):
        """Run ``aa_block_size`` frame blocks with each frame attended separately.

        Tokens (and ``pos``) are viewed as ``(B * S, P, ...)``. Returns the final
        tokens, the advanced block index, and each block's output reshaped to
        ``(B, S, P, C)``.
        """
        if tokens.shape != (B * S, P, C):
            tokens = tokens.reshape(B, S, P, C).reshape(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.reshape(B, S, P, 2).reshape(B * S, P, 2)

        intermediates = []

        for _ in range(self.aa_block_size):
            tokens = self.frame_blocks[spatial_idx](tokens, pos=pos)
            spatial_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        return tokens, spatial_idx, intermediates

    def _process_temporal_attention(
        self,
        tokens,
        B,
        S,
        P,
        C,
        temporal_idx,
        pos=None,
        past_key_values_block=None,
        use_cache=False,
        past_frame_idx=0,
        past_key_values_list=None,
    ) -> Union[
        Tuple[torch.Tensor, int, List[torch.Tensor]],
        Tuple[torch.Tensor, int, List[torch.Tensor], List],
    ]:
        """Run ``aa_block_size`` global blocks across all frames of each sequence.

        Tokens (and ``pos``) are viewed as ``(B, S * P, ...)``. Without a cache,
        an additive mask stops each token from attending to later frames. With
        ``use_cache`` the mask is omitted, and each block reads its cache entry.
        The entry comes from ``past_key_values_list[block_idx]`` when that list
        is given (and the new KV is written back in place); otherwise it comes
        from ``past_key_values_block``. ``past_frame_idx`` is currently unused.

        Returns:
            ``(tokens, temporal_idx, intermediates)``, plus the last block's KV
            when ``use_cache`` is true.
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.reshape(B, S, P, C).reshape(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.reshape(B, S, P, 2).reshape(B, S * P, 2)

        attn_mask = None
        if not use_cache:
            # Built once per call (it is the same for every block of the group).
            # Entry [i, j] is dtype-min when token j belongs to a later frame
            # than token i, else 0.
            frame_ids = torch.arange(S * P, device=tokens.device) // P
            future_frame = frame_ids.unsqueeze(1) < frame_ids.unsqueeze(0)
            attn_mask = future_frame.to(tokens.dtype) * torch.finfo(tokens.dtype).min

        intermediates = []
        block_kv = None

        for _ in range(self.aa_block_size):
            block = self.global_blocks[temporal_idx]
            if use_cache:
                block_past = (
                    past_key_values_list[temporal_idx]
                    if past_key_values_list is not None
                    else past_key_values_block
                )
                tokens, block_kv = block(
                    tokens,
                    pos=pos,
                    attn_mask=attn_mask,
                    past_key_values=block_past,
                    use_cache=True,
                )
                if past_key_values_list is not None:
                    past_key_values_list[temporal_idx] = block_kv
            else:
                tokens = block(tokens, pos=pos, attn_mask=attn_mask)
            temporal_idx += 1
            intermediates.append(tokens.reshape(B, S, P, C))

        if use_cache:
            return tokens, temporal_idx, intermediates, block_kv
        return tokens, temporal_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """Broadcast a two-slot special-token tensor over a batch of sequences.

    ``token_tensor`` is indexed along dim 1 as: slot 0 is used for the first
    frame, slot 1 for every remaining ``S - 1`` frame. The per-batch stack
    ``(B, S, *trailing)`` is flattened batch-major to ``(B * S, *trailing)``,
    where ``trailing`` is ``token_tensor.shape[2:]``.
    """
    trailing = token_tensor.shape[2:]
    first_frame = token_tensor[:, :1].expand(B, 1, *trailing)
    later_frames = token_tensor[:, 1:].expand(B, S - 1, *trailing)
    stacked = torch.cat((first_frame, later_frames), dim=1)
    return stacked.reshape(B * S, *stacked.shape[2:])
