import os
from typing import List, Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed
from .dpt_head import DPTHead
from iggt.heads.block import MemEffAttention, MemEffCrossAttention
from .window_sa import SwinSA, SwinCA

class PartHead(DPTHead):
    """DPT-style prediction head with optional point-feature cross-attention.

    Four feature maps are fused top-down through the ``scratch`` refinement
    blocks (see ``scratch_forward``). When point features are given, they are
    injected between fusion stages via token cross-attention and window
    cross-attention. The fused map is refined with window self-attention
    (``SwinSA``), resized to the patch-aligned image resolution and projected
    to ``output_dim`` channels per frame.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "relu",
        features: int = 256,
        out_channels: List[int] = None,
        intermediate_layer_idx: List[int] = None,
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
        for_tracker: bool = False,
    ) -> None:
        """
        Args:
            dim_in: Forwarded to ``DPTHead.__init__``.
            patch_size: Patch size used to compute the output resolution.
            output_dim: Number of channels of the final prediction.
            activation, pos_embed, feature_only, for_tracker: Stored on the
                instance but not read by this class.
                NOTE(review): presumably consumed by DPTHead or callers.
            features: Channel width of the fusion blocks and cross-attention.
            out_channels: Input channel counts of the four feature maps.
                Defaults to ``[256, 256, 256, 256]``.
            intermediate_layer_idx: Stored on the instance.
                Defaults to ``[4, 11, 17, 23]``.
            down_ratio: Divisor applied to the patch-aligned output resolution.
        """
        # Resolve defaults here so instances never share one mutable list.
        if out_channels is None:
            out_channels = [256, 256, 256, 256]
        if intermediate_layer_idx is None:
            intermediate_layer_idx = [4, 11, 17, 23]

        super(PartHead, self).__init__(dim_in=dim_in)
        self.for_tracker = for_tracker
        self.patch_size = patch_size
        self.activation = activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        head_features_1 = features
        head_features_2 = 32

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None

        # refinenet4 starts the top-down path with a single input,
        # so it needs no residual (skip) unit.
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
        )
        conv2_in_channels = head_features_1 // 2

        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )

        # Token cross-attention with point features. cross_attention_2 runs
        # after refinenet4; cross_attention_1 runs after refinenet3.
        self.cross_attention_1 = MemEffCrossAttention(
            dim=head_features_1,
            num_heads=8,
            qkv_bias=True,
            attn_drop=0.0,
            proj_drop=0.0,
            qk_norm=False,
        )

        self.cross_attention_2 = MemEffCrossAttention(
            dim=head_features_1,
            num_heads=8,
            qkv_bias=True,
            attn_drop=0.0,
            proj_drop=0.0,
            qk_norm=False,
        )

        # Window self-attention over the fused map (applied after output_conv1).
        self.window_self_atten = SwinSA(
            # Original note: img_size presumably only affects internal window
            # partitioning; the real input size comes from forward. Confirm in SwinSA.
            img_size=512,
            out_chans=conv2_in_channels,
            embed_dim=conv2_in_channels,
            num_heads=4,
            window_size=8,  # or another suitable value
        )

        # Window cross-attention with point_feat[0] (applied after refinenet2).
        self.window_cross_attention = SwinCA(
            img_size=128,
            out_chans=head_features_1,
            embed_dim=head_features_1,
            num_heads=4,
            window_size=8,  # or another suitable value
        )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 12,
        point_feature: List[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the head, optionally processing frames in chunks to bound memory.

        Args:
            aggregated_tokens_list (List[Tensor]): The four feature maps consumed
                by ``scratch_forward``. They are passed directly to the
                ``layerN_rn`` convolutions.
                NOTE(review): presumably each is (B*S, C_i, h_i, w_i) with frames
                laid out batch-major. Confirm with the caller.
            images (Tensor): Input images of shape [B, S, 3, H, W]. Only the
                shape is used.
            patch_start_idx (int): Accepted for interface compatibility; unused.
            frames_chunk_size (int, optional): Frames per chunk. If None or
                >= S, all frames are processed at once. Default: 12.
            point_feature (List[Tensor], optional): Point features injected via
                cross-attention; entries 0, 1 and 2 are used. None disables the
                injection.

        Returns:
            Tensor: Predictions of shape [B, S, output_dim, H', W'], where
            H' = int((H // patch_size) * patch_size / down_ratio), likewise W'.

        Raises:
            ValueError: If chunking is requested with a non-positive chunk size.
        """
        S = images.shape[1]

        # If frames_chunk_size is not specified or >= S, process all frames at once.
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx, point_feat=point_feature)

        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be positive, got {frames_chunk_size}")

        # Otherwise process frames in chunks; each chunk is [B, chunk, ...].
        chunk_preds = [
            self._forward_impl(
                aggregated_tokens_list,
                images,
                patch_start_idx,
                point_feat=point_feature,
                frames_start_idx=start,
                frames_end_idx=min(start + frames_chunk_size, S),
            )
            for start in range(0, S, frames_chunk_size)
        ]
        return torch.cat(chunk_preds, dim=1)

    @staticmethod
    def _point_cross_attention(attn: nn.Module, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Cross-attend flattened map ``x`` (B, C, H, W) to ``context``; returns x's shape."""
        tokens = x.flatten(2).permute(0, 2, 1)  # (B, N, C)
        ctx = context.flatten(2).permute(0, 2, 1)  # (B, M, C)
        tokens = attn(tokens, ctx, ctx)
        return tokens.permute(0, 2, 1).reshape(x.shape)  # back to (B, C, H, W)

    def _point_window_cross_attention(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Apply ``window_cross_attention`` in channels-last layout; returns x's shape."""
        ctx = context.permute(0, 2, 3, 1)
        out = self.window_cross_attention(x.permute(0, 2, 3, 1), ctx, ctx)
        return out.permute(0, 3, 1, 2)

    def scratch_forward(self, features: List[torch.Tensor], point_feat: List[torch.Tensor] = None) -> torch.Tensor:
        """
        Fuse the four feature maps top-down, then apply ``output_conv1``.

        Args:
            features (List[Tensor]): Four feature maps, ordered shallow to deep.
            point_feat (List[Tensor], optional): If given, point features are
                injected at three stages:
                - ``point_feat[2]`` via ``cross_attention_2`` after refinenet4;
                - ``point_feat[1]`` via ``cross_attention_1`` after refinenet3;
                - ``point_feat[0]`` via window cross-attention after refinenet2.

        Returns:
            Tensor: Fused feature map with ``features // 2`` channels.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4
        if point_feat is not None:
            out = self._point_cross_attention(self.cross_attention_2, out, point_feat[2])

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3
        if point_feat is not None:
            # NOTE(review): this result was previously computed but discarded.
            # Checkpoints trained before this fix never trained cross_attention_1.
            out = self._point_cross_attention(self.cross_attention_1, out, point_feat[1])

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2
        if point_feat is not None:
            out = self._point_window_cross_attention(out, point_feat[0])

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        return self.scratch.output_conv1(out)

    @staticmethod
    def _select_frames(x: torch.Tensor, batch: int, num_frames: int, start: int, end: int) -> torch.Tensor:
        """Keep frames [start, end) of a (batch * num_frames, ...) tensor.

        Assumes frames are laid out batch-major, matching the final
        ``view(B, S, ...)`` in ``_forward_impl``.
        """
        per_frame = x.reshape(batch, num_frames, *x.shape[1:])
        return per_frame[:, start:end].reshape(batch * (end - start), *x.shape[1:])

    def _forward_impl(
        self,
        input: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        point_feat: List[torch.Tensor] = None,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> torch.Tensor:
        """
        Predict for all frames, or only frames [frames_start_idx, frames_end_idx).

        ``patch_start_idx`` is accepted for interface compatibility and unused.
        When a frame range is given, every tensor in ``input`` and
        ``point_feat`` is sliced with ``_select_frames``. This assumes each one
        has a leading B*S dimension (TODO confirm with caller).

        Returns:
            Tensor: [B, S_chunk, output_dim, H', W'].
        """
        B, S, _, H, W = images.shape
        features = input

        if frames_start_idx is not None or frames_end_idx is not None:
            start = 0 if frames_start_idx is None else frames_start_idx
            end = S if frames_end_idx is None else frames_end_idx
            features = [self._select_frames(f, B, S, start, end) for f in features]
            if point_feat is not None:
                point_feat = [self._select_frames(p, B, S, start, end) for p in point_feat]
            S = end - start

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        # Fuse features from multiple layers.
        out = self.scratch_forward(features, point_feat)

        # Refine fused features with window self-attention (channels-last I/O).
        out = out.permute(0, 2, 3, 1).contiguous()
        out = self.window_self_atten(out)
        out = out.permute(0, 3, 1, 2).contiguous()

        # Interpolate fused output to match the (patch-aligned) target resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        out = self.scratch.output_conv2(out)

        return out.view(B, S, *out.shape[1:])

################################################################################
# Modules
################################################################################


def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Create a ``FeatureFusionBlock`` with this head's fixed settings.

    The block uses an in-place ReLU, no batch norm, no channel halving
    (``expand=False``) and ``align_corners=True`` resampling. ``size``,
    ``has_residual`` and ``groups`` are forwarded unchanged.
    """
    activation = nn.ReLU(inplace=True)
    fixed_options = dict(deconv=False, bn=False, expand=False, align_corners=True)
    return FeatureFusionBlock(
        features,
        activation,
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed_options,
    )


def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
    scratch = nn.Module()
    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )
    return scratch


class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: (act -> 3x3 conv) twice, plus the input."""

    def __init__(self, features, activation, bn, groups=1):
        """
        Args:
            features (int): Channel count (input and output are equal).
            activation (nn.Module): Activation applied before each convolution.
            bn (bool): Stored as ``self.bn``. No normalisation layer is built
                from it; ``norm1``/``norm2`` stay None.
            groups (int): Group count for both convolutions.
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        def make_conv():
            return nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=groups)

        self.conv1 = make_conv()
        self.conv2 = make_conv()

        # Placeholders for optional normalisation; never instantiated here.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``conv2(act(conv1(act(x)))) + x``, with optional norms after each conv.

        NOTE: if ``activation`` is in-place (e.g. ``nn.ReLU(inplace=True)``),
        ``x`` is overwritten by the first activation. The skip connection then
        adds ``act(x)``, not the original ``x``.
        """
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        out = x
        for conv, norm in stages:
            out = conv(self.activation(out))
            if norm is not None:
                out = norm(out)
        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Merge a feature map with an optional skip input, refine, resample, project."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """
        Args:
            features (int): Input channel count.
            activation (nn.Module): Activation passed to the residual units.
            deconv (bool): Stored only; not used by this block.
            bn (bool): Forwarded to the residual units.
            expand (bool): If equal to True, ``out_conv`` halves the channels.
            align_corners (bool): Forwarded to the bilinear resize.
            size: Default resize target used when ``forward`` gets no ``size``.
            has_residual (bool): Whether a second (skip) input is fused.
            groups (int): Group count for all convolutions.
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        halve_channels = self.expand == True  # noqa: E712 -- equality test kept on purpose
        out_features = features // 2 if halve_channels else features

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        # The skip-path unit only exists when a second input is expected.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse ``xs[0]`` (and ``xs[1]`` when ``has_residual``) and resample.

        The resize target is the ``size`` argument if given, else
        ``self.size``. If both are None, the map is upsampled by a factor of 2.

        Returns:
            tensor: The resampled map after ``out_conv``.
        """
        fused = xs[0]
        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        target = self.size if size is None else size
        resize_args = {"scale_factor": 2} if target is None else {"size": target}

        fused = custom_interpolate(fused, **resize_args, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)


def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Resize ``x`` with ``F.interpolate``, splitting very large outputs along dim 0.

    ``nn.functional.interpolate`` can fail on very large outputs (INT_MAX
    indexing issues). When N * C * out_H * out_W exceeds a conservative
    threshold, the batch is split into chunks that are resized separately and
    concatenated.

    Args:
        x: Input tensor; dims 0 and 1 and the last two dims are read
            (typically (N, C, H, W)).
        size: Target (H, W). Takes precedence over ``scale_factor``.
        scale_factor: Multiplier for H and W when ``size`` is None. Output
            sizes are truncated with ``int``.
        mode: Interpolation mode forwarded to ``F.interpolate``.
        align_corners: Forwarded only for modes that accept it (linear,
            bilinear, bicubic, trilinear). It is dropped for other modes such
            as "nearest", where ``F.interpolate`` would otherwise raise.

    Returns:
        The resized tensor.

    Raises:
        TypeError: If both ``size`` and ``scale_factor`` are None.
    """
    if size is None:
        if scale_factor is None:
            raise TypeError("custom_interpolate requires either `size` or `scale_factor`")
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    # F.interpolate rejects align_corners for non-linear-family modes.
    align = align_corners if mode in ("linear", "bilinear", "bicubic", "trilinear") else None

    # Conservative element limit (1.5 * 2**30), below 2**31 - 1.
    max_elements = 1610612736
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if output_elements <= max_elements:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align)

    num_chunks = output_elements // max_elements + 1
    resized = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align)
        for chunk in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(resized, dim=0).contiguous()

