import os
import copy
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from einops import rearrange
import torch.nn.functional as F
from conv_util import Feature_Gather
from mamba_ssm.modules.mamba_simple import Mamba
from pwclonet_model_utils import GatedMLP
from PIL import Image
import time

class deformable_mamba_layer(nn.Module):
    """Deformable feature sampling followed by a Mamba block.

    For every query, ``sampling_points`` locations are predicted around its
    reference point on each feature level. The feature maps are sampled
    bilinearly at those locations, and the resulting short token sequence
    (sampled features followed by the query itself) is passed through
    ``GatedMLP`` and ``Mamba``, then max-pooled over the sequence.

    Args:
        dim: Channel width of the queries. The feature maps in ``srcs`` must
            have the same number of channels.
        feat_lvls: Number of feature levels; one pair of offset heads is
            created per level.
        num_head: Stored on the module but not used by the computation.
        sampling_points: Number of locations sampled per query and level.
    """

    def __init__(self, dim, feat_lvls, num_head=8, sampling_points=4):
        super(deformable_mamba_layer, self).__init__()
        self.dim = dim
        self.num_head = num_head
        self.sampling_points = sampling_points

        self.mamba_attention = nn.Sequential(
                                    nn.LayerNorm(dim),
                                    nn.ReLU(),
                                    Mamba(dim,  d_state=16, d_conv=4, expand=2),
                                    nn.LayerNorm(dim),
        )
        # NOTE(review): GatedMLP is defined elsewhere; it is presumably built
        # as (in, hidden, out) features and keeps the channel width at `dim`.
        self.gmlp_f = GatedMLP(dim, dim, dim)
        # Per-level offset heads: query -> 2*S values, refined by a second
        # linear layer of the same width.
        self.generate_offsets = nn.ModuleList()
        self.add_offsets = nn.ModuleList()
        for _ in range(feat_lvls):
            self.generate_offsets.append(nn.Linear(dim, 2 * sampling_points))
            self.add_offsets.append(nn.Linear(2 * sampling_points, 2 * sampling_points))

    def forward(self, q, srcs, reference_points, fn2_dir=None, mask_split=None):
        """Sample features around each reference point and fuse them with ``q``.

        Args:
            q: Query features, ``[B, N, dim]``.
            srcs: Sequence of feature maps ``[B, dim, H, W]``, one per level.
                Only the first ``len(srcs)`` offset heads are used.
            reference_points: ``[B, N, 2]`` locations in ``grid_sample``
                normalised coordinates, i.e. ``(x, y)`` in ``[-1, 1]``.
            fn2_dir: Unused; accepted for interface compatibility.
            mask_split: Unused; accepted for interface compatibility.

        Returns:
            Tensor ``[B, N, dim]``: per-query features max-pooled over the
            token sequence.
        """
        batch_size, num_queries, _ = q.size()
        n_points = self.sampling_points

        # [B, N, 1, 2]; broadcasts against the per-point offsets below.
        anchors = reference_points.unsqueeze(2)

        per_level = []
        for lvl, src in enumerate(srcs):
            # Predict S (x, y) offsets per query for this level.
            offsets = self.add_offsets[lvl](self.generate_offsets[lvl](q))
            offsets = offsets.view(batch_size, num_queries, n_points, 2)  # [B, N, S, 2]
            grid = anchors + offsets  # [B, N, S, 2]

            # grid_sample keeps the grid's spatial layout: output is [B, C, N, S].
            sampled = F.grid_sample(src, grid, mode='bilinear',
                                    padding_mode='zeros', align_corners=False)
            per_level.append(sampled)

        # Levels are concatenated along the point axis: [B, C, N, L*S].
        sampled = torch.cat(per_level, dim=-1)
        total_points = n_points * len(per_level)

        # Channels must be moved last *before* flattening, otherwise the
        # reshape mixes channel and position axes and a query no longer sees
        # its own sampled feature vectors.
        # NOTE: parameter names and shapes are unchanged, but weights trained
        # against the earlier (scrambled) layout saw different inputs.
        k = sampled.permute(0, 2, 3, 1)  # [B, N, L*S, C]
        k = k.reshape(batch_size * num_queries, total_points, self.dim)  # [B*N, L*S, C]

        # One token per query, appended after its sampled points.
        q_tok = q.reshape(batch_size * num_queries, 1, self.dim)  # [B*N, 1, C]

        tokens = torch.cat([k, q_tok], dim=1)  # [B*N, L*S + 1, C]
        tokens = self.mamba_attention(self.gmlp_f(tokens))

        # Max-pool over the sequence, then restore the [B, N, C] layout.
        pooled = torch.max(tokens, dim=1)[0]  # [B*N, C]
        attn_output = pooled.view(batch_size, num_queries, self.dim)

        return attn_output



class Deformable_Mamba(nn.Module):
    """Wrap ``deformable_mamba_layer`` for a single feature map.

    ``forward`` normalises point coordinates to ``[-1, 1]``, runs the
    deformable Mamba layer on one feature level, and returns the result in
    ``[B, C, 1, N]`` layout.
    """

    # Extents used to normalise coordinates: entry 0 applies to points[..., 0]
    # and entry 1 to points[..., 1].
    # NOTE(review): presumably image width and height in pixels - confirm.
    size_range = (1296.0, 384.0)

    def __init__(self, dim, out_dim, fold_w=2, fold_h=2, heads=4, head_dim=24,
                 return_center=False):
        """
        :param dim: channel number of the queries and of the feature map
        :param out_dim: output channel number of ``self.proj``
        :param fold_w: stored on the module; not used by ``forward``
        :param fold_h: stored on the module; not used by ``forward``
        :param heads: together with ``head_dim`` sets the width
            (``heads * head_dim``) of the ``f``, ``v`` and ``proj`` convolutions
        :param head_dim: see ``heads``
        :param return_center: stored on the module; not used by ``forward``
            (deprecated).
        """
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim
        # The layers and scalars below are registered (and therefore part of
        # the state dict) but are not used by forward().
        self.f = nn.Conv2d(dim, heads * head_dim, kernel_size=1)  # for similarity
        self.proj = nn.Conv2d(heads * head_dim, out_dim, kernel_size=1)  # for projecting channel number
        self.v = nn.Conv2d(dim, heads * head_dim, kernel_size=1)  # for value
        self.sim_alpha = nn.Parameter(torch.ones(1))
        self.sim_beta = nn.Parameter(torch.zeros(1))
        self.defomable_mamba_layer = deformable_mamba_layer(dim, feat_lvls=1, num_head=8, sampling_points=4)
        self.fold_w = fold_w
        self.fold_h = fold_h
        self.return_center = return_center

    def forward(self, init_query, points, x, device, fn2_dir=None):  # [b,c,h,w]
        """Run the deformable Mamba layer for ``points`` on feature map ``x``.

        Args:
            init_query: Query features, ``[B, N, C]``.
            points: ``[B, N, 2]`` coordinates in the un-normalised range given
                by ``size_range``. The tensor is NOT modified.
            x: Feature map ``[B, C, H, W]``.
            device: Device the normalised points are moved to before sampling.
                When ``None``, the device of ``x`` is used.
            fn2_dir: Forwarded unchanged to the inner layer.

        Returns:
            Tensor ``[B, C, 1, N]``.
        """
        range_0, range_1 = self.size_range

        # Validity mask computed on the raw coordinates (forwarded to the
        # inner layer).
        val_flag_1 = (points[:, :, 1] > 0) & (points[:, :, 1] <= range_1)
        val_flag_2 = (points[:, :, 0] > 0) & (points[:, :, 0] <= range_0)
        mask_split = val_flag_1 & val_flag_2

        # Normalise on a copy: writing into `points` would corrupt the
        # caller's tensor and double-normalise it on a second call.
        normalized = points.clone()
        normalized[:, :, 0] = points[:, :, 0] / (range_0 - 1.0) * 2.0 - 1.0
        normalized[:, :, 1] = points[:, :, 1] / (range_1 - 1.0) * 2.0 - 1.0

        # `.to` rather than `.cuda` so CPU and other devices work as well.
        target_device = x.device if device is None else device
        normalized = normalized.to(target_device)

        out = self.defomable_mamba_layer(init_query, [x], normalized, fn2_dir, mask_split)  # [B N C]
        out = out.permute(0, 2, 1).unsqueeze(2)  # [B C 1 N]
        return out

