import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange
from einops import rearrange, repeat
from typing import List
import os
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os
import matplotlib.pyplot as plt


def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Build a fixed 2D sine/cosine positional embedding for an ``h`` x ``w`` grid.

    The feature axis is laid out as ``[sin(x), cos(x), sin(y), cos(y)]``, each
    chunk holding ``dim // 4`` frequencies that decay geometrically from ``1``
    down to ``1 / temperature``.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        dim: Embedding width; must be a multiple of 4.
        temperature: Base of the geometric frequency progression.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(h * w, dim)`` on the default device, rows ordered
        row-major over the grid (``y`` varies slowest).

    Raises:
        AssertionError: If ``dim`` is not a multiple of 4.
    """
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")

    num_freqs = dim // 4
    # With dim == 4 there is a single frequency and ``num_freqs - 1`` is zero;
    # dividing by it would yield 0/0 = NaN for the whole embedding, so the
    # denominator is clamped to 1 (the single frequency then equals 1.0).
    omega = torch.arange(num_freqs) / max(num_freqs - 1, 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

def prepare_mask(mask, h, w, save_vis=False, mask_size=64):
    """Turn pixel masks into boolean patch-to-patch attention masks.

    Each mask is first reduced to a patch grid: a patch is active when at
    least one pixel falling into it equals 1. Pixel indices are mapped to
    patch indices by integer division with ``mask_size // h`` (rows) and
    ``mask_size // w`` (columns). An attention entry ``(i, j)`` is then
    ``True`` when patch ``i`` or patch ``j`` is active.

    Args:
        mask: Tensor of shape ``(N, H, W)``; entries equal to 1 mark the region.
        h: Number of patch rows.
        w: Number of patch columns.
        save_vis: Unused; kept for call compatibility.
        mask_size: Resolution used to derive the pixel-to-patch stride.
            Defaults to 64, the value the stride was previously fixed to.

    Returns:
        Bool tensor of shape ``(N, P, P)`` on ``mask``'s device, where
        ``P = (H // (mask_size // h)) * (W // (mask_size // w))``.
    """
    stride_h = mask_size // h
    stride_w = mask_size // w
    num_masks = mask.shape[0]

    # Allocate on the mask's device so the index assignment below never mixes
    # a CPU tensor with indices that live on another device.
    patch_mask = torch.zeros(
        num_masks,
        mask.shape[1] // stride_h,
        mask.shape[2] // stride_w,
        dtype=torch.bool,
        device=mask.device,
    )

    # Patchify every sample at once: map each active pixel to its patch.
    batch_idx, row_idx, col_idx = torch.where(mask == 1)
    patch_mask[batch_idx, row_idx // stride_h, col_idx // stride_w] = True

    flat = patch_mask.view(num_masks, -1)
    # Rows and columns of active patches are fully enabled: active[i] OR active[j].
    return flat.unsqueeze(2) | flat.unsqueeze(1)

# classes
class MaskAttentionComposer(nn.Module):
    """Combine several spatial masks into one ``(B, N, N)`` attention weight map.

    Three learnable scalars control the result:

    * ``base_weight`` scales the outer product of every mask with itself,
    * ``cross_person_scale`` scales outer products between different masks,
      restricted to key columns lying in an overlap patch,
    * ``overlap_boost`` scales the outer product of the overlap map with itself.

    With ``constrained=True`` the scalars are stored as logits and exposed
    through a sigmoid, which keeps ``base_weight`` in (0, 2), ``overlap_boost``
    in (1, 5) and ``cross_person_scale`` in (0, 1). With ``constrained=False``
    the stored parameters are used directly. In both modes the parameters are
    registered as ``_base_weight``, ``_overlap_boost`` and
    ``_cross_person_scale``; the public names are read-only properties.

    Args:
        base_weight_init: Initial effective value of ``base_weight``.
        overlap_boost_init: Initial effective value of ``overlap_boost``.
        cross_person_scale_init: Initial effective value of ``cross_person_scale``.
        constrained: Whether to bound the scalars through a sigmoid.
        debug_print: If true, print the scalars on calls 1, 101, 201, ...
    """

    def __init__(self,
                 base_weight_init=1.0,
                 overlap_boost_init=2.5,
                 cross_person_scale_init=0.6,
                 constrained=True,
                 debug_print=True):
        super().__init__()
        self.constrained = constrained
        self.debug_print = debug_print
        self.print_counter = 0

        if constrained:
            # Store logits so that the sigmoid-based properties start at the
            # requested values (inits outside the valid range are clamped).
            raw_base = self._inverse_sigmoid(base_weight_init / 2.0)
            raw_boost = self._inverse_sigmoid((overlap_boost_init - 1.0) / 4.0)
            raw_cross = self._inverse_sigmoid(cross_person_scale_init)
        else:
            raw_base = torch.tensor(float(base_weight_init))
            raw_boost = torch.tensor(float(overlap_boost_init))
            raw_cross = torch.tensor(float(cross_person_scale_init))

        # Always register under the underscore names: the public names are
        # properties on the class, and nn.Module refuses to register a
        # parameter whose name already resolves to an existing attribute.
        self._base_weight = nn.Parameter(raw_base)
        self._overlap_boost = nn.Parameter(raw_boost)
        self._cross_person_scale = nn.Parameter(raw_cross)

    def _inverse_sigmoid(self, x, eps=1e-6):
        """Return ``logit(x)`` as a float32 tensor, clamping ``x`` to ``[eps, 1 - eps]``."""
        x = torch.tensor(x, dtype=torch.float32)
        x = torch.clamp(x, eps, 1 - eps)
        return torch.log(x / (1 - x))

    @staticmethod
    def _grad_norm(param):
        """Return the norm of ``param.grad`` as a float, or 0.0 if there is no grad yet."""
        if param.grad is None:
            return 0.0
        return param.grad.norm().item()

    @property
    def base_weight(self):
        """Effective weight of the per-mask self term."""
        if self.constrained:
            return torch.sigmoid(self._base_weight) * 2.0
        return self._base_weight

    @property
    def overlap_boost(self):
        """Effective weight of the overlap-with-overlap term."""
        if self.constrained:
            return torch.sigmoid(self._overlap_boost) * 4.0 + 1.0
        return self._overlap_boost

    @property
    def cross_person_scale(self):
        """Effective weight of the cross-mask term."""
        if self.constrained:
            return torch.sigmoid(self._cross_person_scale)
        return self._cross_person_scale

    def forward(self, multi_masks: List[torch.Tensor], h: int, w: int, save_vis=False):
        """Compose the attention weight map for a list of masks.

        Args:
            multi_masks: List of ``P`` tensors, each of shape ``(B, H, W)``.
                ``H`` and ``W`` must be multiples of ``h`` and ``w``; otherwise
                the reshape into patches fails.
            h: Number of patch rows.
            w: Number of patch columns.
            save_vis: Unused; kept for call compatibility.

        Returns:
            ``None`` for an empty list, otherwise a float tensor of shape
            ``(B, h * w, h * w)`` on the device of the first mask.
        """
        if len(multi_masks) == 0:
            return None

        B, H, W = multi_masks[0].shape
        device = multi_masks[0].device
        ph, pw = H // h, W // w
        N = h * w

        # Max-pool every mask over its (ph, pw) pixel window and flatten the
        # patch grid. Casting to float lets bool/int masks go through bmm, and
        # reshape (instead of view) accepts non-contiguous inputs.
        flat_masks = [
            m.float().reshape(B, h, ph, w, pw).max(dim=2)[0].max(dim=3)[0].reshape(B, -1)
            for m in multi_masks
        ]
        stacked = torch.stack(flat_masks, dim=1)  # (B, P, N)
        P = stacked.shape[1]

        # Patches whose pooled mask values sum to more than 1 across masks
        # (for binary masks: covered by at least two masks).
        overlap = (stacked.sum(dim=1) > 1).float()  # (B, N)

        # Evaluate the scalars once; they do not change inside the loops.
        base_weight = self.base_weight.to(device)
        cross_person_scale = self.cross_person_scale.to(device)
        overlap_boost = self.overlap_boost.to(device)

        attn_mask = torch.zeros(B, N, N, device=device)

        # Self term: outer product of each mask with itself.
        for p in range(P):
            m_p = stacked[:, p, :]
            attn_mask += torch.bmm(m_p.unsqueeze(2), m_p.unsqueeze(1)) * base_weight

        # Cross term: ordered pairs of different masks; only columns that lie
        # in an overlap patch are kept.
        overlap_cols = overlap.unsqueeze(1)  # (B, 1, N)
        for p in range(P):
            m_p = stacked[:, p, :]
            for q in range(P):
                if q == p:
                    continue
                m_q = stacked[:, q, :]
                cross_mask = torch.bmm(m_p.unsqueeze(2), m_q.unsqueeze(1)) * overlap_cols
                attn_mask += cross_mask * cross_person_scale

        # Overlap term: overlap patches attending to overlap patches.
        overlap_mat = torch.bmm(overlap.unsqueeze(2), overlap.unsqueeze(1))
        attn_mask += overlap_mat * overlap_boost

        if self.debug_print:
            self.print_counter += 1
            if self.print_counter % 100 == 1:
                # Gradients live on the stored leaf parameters; the property
                # values are derived (non-leaf) tensors in constrained mode
                # and never carry a populated ``.grad``.
                print(f"\n[DEBUG Step {self.print_counter}]")
                print(f"base_weight: {self.base_weight.item():.6f} | grad: {self._grad_norm(self._base_weight):.6f}")
                print(f"overlap_boost: {self.overlap_boost.item():.6f} | grad: {self._grad_norm(self._overlap_boost):.6f}")
                print(f"cross_person_scale: {self.cross_person_scale.item():.6f} | grad: {self._grad_norm(self._cross_person_scale):.6f}")

        return attn_mask



class FeedForward(nn.Module):
    """Position-wise MLP: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        dim: Width of the input and output features.
        hidden_dim: Width of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Sub-modules are created in execution order so that parameter
        # initialisation consumes the RNG in the same sequence.
        pre_norm = nn.LayerNorm(dim)
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(pre_norm, expand, activation, project)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention with an optional per-sample mask.

    Args:
        dim: Width of the input and output features.
        heads: Number of attention heads.
        dim_head: Width of each head.

    Attributes:
        last_attn_map: Detached attention probabilities of the most recent
            forward call, shape ``(B, heads, N, N)``; ``None`` before the
            first call.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        self.last_attn_map = None

    def forward(self, x, mask):
        """Run self-attention over ``x``.

        Args:
            x: Tensor of shape ``(B, N, dim)``.
            mask: ``None`` or a tensor of shape ``(B, N, N)`` shared by all heads.

                * bool mask: logits where the mask is ``False`` are replaced
                  by the smallest logit of the whole batch (a soft
                  suppression, not ``-inf``).
                * any other dtype: logits where the mask equals 0 are replaced
                  the same way, then all logits are multiplied by the mask.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        x = self.norm(x)
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = h) for t in qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if mask is not None:
            # A singleton head axis broadcasts against ``dots``; there is no
            # need to materialise one copy of the mask per head.
            mask = mask.unsqueeze(1)
            min_val = torch.min(dots).detach()
            if mask.dtype == torch.bool:
                dots.masked_fill_(~mask, min_val)
            else:
                dots.masked_fill_(mask == 0, min_val)
                # NOTE(review): entries just filled with ``min_val`` are
                # multiplied by 0 here, so they end up as logit 0 rather than
                # the minimum. Confirm this is the intended weighting.
                dots = dots * mask

        attn = self.attend(dots)
        self.last_attn_map = attn.detach()
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of ``depth`` residual blocks, each an attention layer followed by an MLP.

    Args:
        dim: Feature width.
        depth: Number of blocks.
        heads: Attention heads per block.
        dim_head: Width of each attention head.
        mlp_dim: Hidden width of each feed-forward layer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = Attention(dim, heads = heads, dim_head = dim_head)
            feed_forward = FeedForward(dim, mlp_dim)
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask):
        """Apply every block to ``x``.

        Args:
            x: Input features.
            mask: ``None``, or an indexable collection that is read at
                position ``i`` for block ``i`` (so it needs one entry per block).

        Returns:
            The transformed features.
        """
        for layer_idx, (attn, ff) in enumerate(self.layers):
            layer_mask = None if mask is None else mask[layer_idx]
            x = attn(x, layer_mask) + x
            x = ff(x) + x
        return x

class IASAM(nn.Module):
    """Mask-guided transformer block operating on a feature map.

    Pipeline: 3x3 conv to 128 channels -> split into patches -> linear
    projection to ``dim`` -> add fixed sin/cos position embedding ->
    transformer (optionally masked) -> fold patches back into a map ->
    3x3 conv back to ``channels``.

    The output stage unfolds each token into ``patch_height * patch_width``
    pixels of 128 channels, so ``dim`` has to equal
    ``128 * patch_height * patch_width``.

    Args:
        image_size: Input height/width, as an int or a ``(height, width)`` tuple.
        patch_size: Patch height/width, as an int or a ``(height, width)`` tuple.
        dim: Token width inside the transformer.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of the feed-forward layers.
        channels: Channel count of the input and output feature maps.
        dim_head: Width of each attention head.
    """

    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim=512, channels = 320, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        self.depth = depth

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        patch_dim = 128 * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(channels,128,kernel_size=3,stride=1,padding=1),
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Patch grid size.
        self.h = image_height // patch_height
        self.w = image_width // patch_width

        # Plain tensor (not a parameter/buffer); moved to the input's device
        # and dtype on every forward call.
        self.pos_embedding = posemb_sincos_2d(
            h = self.h,
            w = self.w,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.to_out = nn.Sequential(
            # Use the already computed grid height so tuple-valued
            # image_size / patch_size work here as well.
            Rearrange("b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1 = patch_height, p2 = patch_width, h = self.h),
            nn.Conv2d(128,channels,kernel_size=3,stride=1,padding=1),
        )

        self.prepare_multi_masks = MaskAttentionComposer()

    def _build_attn_mask(self, mask, t, batch_size, device):
        """Convert ``mask`` into the list of attention masks passed to the transformer."""
        if isinstance(mask, list):
            if mask[0].shape[0] != batch_size:
                # Keep a single channel of every mask, chosen by the step ``t``.
                channel = 1 if t < 500 else 3
                selected = [mask_i[[channel], :, :] for mask_i in mask]
                attn_mask = [self.prepare_multi_masks(selected, self.h, self.w, save_vis=True).to(device)]
                assert self.depth == len(attn_mask)
                return attn_mask
            return [self.prepare_multi_masks(mask, self.h, self.w, save_vis=True).to(device)]

        if mask.shape[0] != batch_size:
            # The tensor holds several batch-sized chunks; ``t`` picks one.
            chunks = torch.split(mask, batch_size)
            chosen = chunks[0] if t < 500 else chunks[1]
            attn_mask = [prepare_mask(chosen, self.h, self.w, save_vis=True).to(device)]
            assert self.depth == len(attn_mask)
            return attn_mask
        return [prepare_mask(mask, self.h, self.w, save_vis=True).to(device)]

    def forward(self, img, mask=None, t=None, save_attn=False, save_dir='./filnal_attn_vis'):
        """Transform ``img`` and return a feature map of the same layout.

        Args:
            img: Input feature map of shape ``(B, channels, H, W)``.
            mask: Optional guidance.

                * list of tensors: composed with ``MaskAttentionComposer``.
                  When the first dimension of the entries differs from ``B``,
                  channel 1 (``t < 500``) or channel 3 (otherwise) of every
                  entry is used; when it equals ``B`` the entries are used
                  as they are.
                * tensor: converted with ``prepare_mask``. When its first
                  dimension differs from ``B`` it is split into chunks of
                  ``B`` and the first (``t < 500``) or second chunk is used.
            t: Step value compared against 500; required whenever the mask's
                first dimension differs from ``B``. Presumably a diffusion
                timestep (confirm with the caller).
            save_attn: Unused.
            save_dir: Unused.

        Returns:
            Tensor of shape ``(B, channels, H, W)``.

        Raises:
            AssertionError: If a mask whose first dimension differs from ``B``
                is given and ``depth`` is not 1 (a single attention mask is
                produced in that case).
        """
        with torch.cuda.amp.autocast():
            device = img.device
            x = self.to_patch_embedding(img)
            x += self.pos_embedding.to(device, dtype=x.dtype)

            attn_mask = None
            if mask is not None:
                attn_mask = self._build_attn_mask(mask, t, x.shape[0], device)

            x = self.transformer(x, attn_mask)
            x = self.to_out(x)

            return x