import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch import einsum


class TopkOperation(nn.Module):
    """Mix several top-k sparsified softmax attentions with learned weights.

    For every ratio in ``ratio_list`` the ``int(dim * ratio)`` largest entries
    along the last axis of ``attn`` are kept, all other entries are set to
    ``-inf`` before the softmax, and the attended values are scaled by a
    learnable scalar (initialised to 0.2). The per-ratio results are summed.

    Args:
        dim: Reference size used to turn each ratio into a ``k`` for
            ``torch.topk``; it must not exceed ``attn.shape[-1]``.
        ratio_list: Iterable of fractions of ``dim`` to keep.
    """

    def __init__(self, dim, ratio_list) -> None:
        super().__init__()

        self.dim = dim
        self.ratio_list = ratio_list
        # One learnable mixing weight per ratio.
        self.attn_list = nn.ParameterList([
            nn.Parameter(torch.tensor([0.2]), requires_grad=True) for _ in ratio_list
        ])

    def forward(self, attn, v):
        """Return ``sum_i w_i * softmax(topk_i(attn)) @ v``.

        Args:
            attn: Attention logits; top-k selection runs over the last axis.
            v: Values; must be matmul-compatible with ``attn``.
        """
        out = 0
        for i, ratio in enumerate(self.ratio_list):
            # Keep at least one entry: with k == 0 every logit would be -inf
            # and the softmax would return NaN.
            k = max(1, int(self.dim * ratio))
            index = torch.topk(attn, k=k, dim=-1, largest=True)[1]
            # Start from -inf everywhere and copy back only the selected logits.
            masked = torch.full_like(attn, float('-inf'))
            masked = masked.scatter(-1, index, attn.gather(-1, index))
            attn_i = masked.softmax(dim=-1)
            out = out + (attn_i @ v) * self.attn_list[i]
        return out


class Topk_WindowSpectralAttn(nn.Module):
    """Window-wise channel ("spectral") attention with top-k sparsification.

    The input map is split into non-overlapping ``window_size`` x
    ``window_size`` windows. Inside every window and head a ``d x d``
    attention matrix between feature channels is built from L2-normalised
    keys and queries and applied to the values through ``TopkOperation``.

    Note: ``pos_emb`` is built for ``dim`` channels but is applied to the
    value projection, which has ``dim_head * heads`` channels, so the two
    must be equal.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            window_size=8
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        # Learnable per-head scaling of the attention logits.
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1, 1))
        self.proj = nn.Conv2d(inner_dim, dim, 1, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim
        self.window_size = window_size
        self.topk = TopkOperation(self.dim_head, ratio_list=[1/2, 2/3, 3/4, 4/5])

    def forward(self, x_in):
        """
        x_in: [b,c,h,w]
        return out: [b,c,h,w]
        """
        win = self.window_size
        _, _, height, width = x_in.shape
        rows, cols = height // win, width // win

        def to_windows(t):
            # [b, (n d), H, W] -> [b, n, windows, pixels-per-window, d]
            return rearrange(t, 'b (n d) (nh h) (nw w) -> b n (nh nw) (h w) d',
                             n=self.num_heads, h=win, w=win)

        q_inp, k_inp, v_inp = self.to_qkv(x_in).chunk(3, dim=1)
        # Normalise each channel over the pixels of its window.
        q = F.normalize(to_windows(q_inp), dim=-2, p=2)
        k = F.normalize(to_windows(k_inp), dim=-2, p=2)
        v = to_windows(v_inp).transpose(-2, -1)

        # A = K^T * Q, a [d, d] matrix per window and head.
        attn = (k.transpose(-2, -1) @ q) * self.rescale
        x = self.topk(attn, v)
        x = rearrange(x, 'b n (nh nw) d (h w) -> b (n d) (nh h) (nw w)',
                      nh=rows, nw=cols, h=win, w=win)
        return self.proj(x) + self.pos_emb(v_inp)
    

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a normal clipped to [a, b].

    Uniform samples are drawn between the CDF values of the bounds (rescaled
    to [-1, 1]), pushed through the inverse error function, scaled and
    shifted, and finally clamped to ``[a, b]``. Runs under ``torch.no_grad``.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        return (1. + math.erf(value / sqrt_two)) / 2.

    below_range = mean < a - 2 * std
    above_range = mean > b + 2 * std
    if below_range or above_range:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        cdf_low = standard_normal_cdf((a - mean) / std)
        cdf_high = standard_normal_cdf((b - mean) / std)
        # Every call below works in place and returns ``tensor`` itself.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1) \
            .erfinv_() \
            .mul_(std * sqrt_two) \
            .add_(mean) \
            .clamp_(min=a, max=b)
        return tensor


def trunc_normal_(
        tensor: torch.Tensor,
        mean: float = 0.,
        std: float = 1.,
        a: float = -2.,
        b: float = 2.,
) -> torch.Tensor:
    """Fill ``tensor`` in place from a truncated normal and return it.

    Thin public wrapper that forwards all arguments to
    ``_no_grad_trunc_normal_``.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build an ``nn.Conv2d`` whose padding is ``kernel_size // 2``."""
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
    return layer


class PreNorm(nn.Module):
    """Normalise a [b, c, h, w] tensor, then hand it to ``fn``.

    ``norm_type == 'ln'`` uses ``nn.LayerNorm(dim)`` over the channel axis
    (the tensor is moved to channels-last for the call and back afterwards);
    any other value uses ``nn.GroupNorm(dim, dim)``.
    """

    def __init__(self, dim, fn, norm_type='ln'):
        super().__init__()
        self.fn = fn
        self.norm_type = norm_type
        self.norm = nn.LayerNorm(dim) if norm_type == 'ln' else nn.GroupNorm(dim, dim)
        # Visits this module and every submodule, including those inside ``fn``.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Reset LayerNorm modules to bias 0 and weight 1; ignore the rest."""
        if not isinstance(m, nn.LayerNorm):
            return
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

    def forward(self, x, *args, **kwargs):
        """Apply the norm and call ``fn`` with any extra arguments."""
        if self.norm_type != 'ln':
            return self.fn(self.norm(x), *args, **kwargs)
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.norm(channels_last).permute(0, 3, 1, 2)
        return self.fn(normed, *args, **kwargs)


class GELU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu``."""

    def forward(self, x):
        activated = F.gelu(x)
        return activated


## Channel Attention Layer
class CALayer(nn.Module):
    """Channel attention: rescale each channel by a gate in (0, 1).

    The gate is computed from the globally average-pooled input through a
    1x1 conv bottleneck (``channel // reduction`` hidden channels), ReLU,
    a second 1x1 conv and a sigmoid.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        hidden = channel // reduction
        # Collapse every feature map to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Squeeze then expand the channels to produce per-channel weights.
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate."""
        gate = self.conv_du(self.avg_pool(x))
        return x * gate



class SAM(nn.Module):
    """Predict two strictly positive per-pixel maps from a pair of inputs.

    The two inputs are concatenated along the channel axis, passed through a
    3x3 and a 1x1 convolution, channel attention, and a small conv head that
    ends in ``Softplus``; ``1e-6`` is added so both maps stay above zero.

    Args:
        in_c: Channel count of the concatenated input (``x`` plus ``x_phi``).
        ch: Width of the internal feature maps.
        bias: Whether ``conv1`` and ``conv2`` use a bias term.
    """

    def __init__(self, in_c, ch, bias=True):
        super(SAM, self).__init__()
        self.conv1 = conv(in_c, ch, 3, bias=bias)
        self.conv2 = conv(ch, ch, 1, bias=bias)
        self.ca = CALayer(ch, reduction=8)
        self.head = nn.Sequential(
            nn.Conv2d(ch, ch // 4, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch // 4, 2, 3, padding=1),
            nn.Softplus()
        )

    def forward(self, x, x_phi):
        """Return two maps of shape [b, h, w] for inputs shaped [b, *, h, w]."""
        img = torch.cat([x, x_phi], dim=1)
        feat = self.conv2(self.conv1(img))
        noise_map = self.head(self.ca(feat)) + 1e-6
        # Indexing the channel already yields [b, h, w]. The former
        # ``.squeeze(1)`` acted on the height axis and would silently drop
        # it whenever h == 1.
        return noise_map[:, 0], noise_map[:, 1]



class HS_MSA(nn.Module):
    """Window-based multi-head self-attention on [b, c, h, w] feature maps.

    Three modes are supported:

    * ``only_local_branch=True``: windows are folded into the batch axis and
      attention runs between the pixels of each window (bias ``pos_emb``).
    * ``non_local=False`` (default): attention between the pixels of each
      window, windows kept as a separate axis (bias ``pos_emb1``).
    * ``non_local=True``: attention between windows, for every in-window
      position (bias ``pos_emb2``).

    Args:
        dim: Number of input and output channels.
        window_size: Window height/width as a tuple, or a single int.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        only_local_branch: Select the first mode described above.
        non_local: Select between the second and third mode.

    Note:
        ``pos_emb2`` is sized from a hard-coded 256 x 320 reference map
        divided by ``heads``, so the non-local branch only broadcasts when
        the number of windows equals that size.
    """

    def __init__(
            self,
            dim,
            window_size=(8, 8),
            dim_head=28,
            heads=8,
            only_local_branch=False,
            non_local=False,
    ):
        super().__init__()

        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        if isinstance(window_size, int):
            window_size = (window_size, window_size)

        self.window_size = window_size
        self.only_local_branch = only_local_branch
        self.non_local = non_local

        # Learned additive position biases for the attention logits.
        seq_l1 = window_size[0] * window_size[1]
        if only_local_branch:
            self.pos_emb = nn.Parameter(torch.empty(1, heads, seq_l1, seq_l1))
            trunc_normal_(self.pos_emb)
        else:
            self.pos_emb1 = nn.Parameter(torch.empty(1, 1, heads, seq_l1, seq_l1))
            h, w = 256 // self.heads, 320 // self.heads
            seq_l2 = h * w // seq_l1
            self.pos_emb2 = nn.Parameter(torch.empty(1, 1, heads, seq_l2, seq_l2))
            trunc_normal_(self.pos_emb1)
            # Previously left as uninitialised memory: the non-local branch
            # added arbitrary values (possibly NaN/inf) and checkpoints
            # stored them.
            trunc_normal_(self.pos_emb2)

        inner_dim = dim_head * heads
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def _attend(self, q, k, v, pos_emb):
        """Multi-head attention over axis 2 of [b, n, tokens, heads*d] inputs."""
        q, k, v = (rearrange(t, 'b n mm (h d) -> b n h mm d', h=self.heads)
                   for t in (q, k, v))
        q = q * self.scale
        sim = einsum('b n h i d, b n h j d -> b n h i j', q, k)
        sim = sim + pos_emb
        attn = sim.softmax(dim=-1)
        out = einsum('b n h i j, b n h j d -> b n h i d', attn, v)
        return rearrange(out, 'b n h mm d -> b n mm (h d)')

    def forward(self, x):
        """
        x: [b,c, h, w]
        return out: [b, c, h, w]
        """
        b, c, h, w = x.shape
        w_size = self.window_size
        assert h % w_size[0] == 0 and w % w_size[1] == 0, 'fmap dimensions must be divisible by the window size'
        if self.only_local_branch:
            # Every window becomes its own batch element.
            x_inp = rearrange(x, 'b c (h b0) (w b1) -> (b h w) (b0 b1) c', b0=w_size[0], b1=w_size[1])
            q = self.to_q(x_inp)
            k, v = self.to_kv(x_inp).chunk(2, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
            q = q * self.scale
            sim = einsum('b h i d, b h j d -> b h i j', q, k)
            sim = sim + self.pos_emb
            attn = sim.softmax(dim=-1)
            out = einsum('b h i j, b h j d -> b h i d', attn, v)
            out = rearrange(out, 'b h n d -> b n (h d)')
            out = self.to_out(out)
            out = rearrange(out, '(b h w) (b0 b1) c -> b c (h b0) (w b1)', h=h // w_size[0], w=w // w_size[1],
                            b0=w_size[0])
        else:
            x = x.permute(0, 2, 3, 1)
            q = self.to_q(x)
            k, v = self.to_kv(x).chunk(2, dim=-1)
            # [b, H, W, c] -> [b, windows, pixels-per-window, c]
            q, k, v = (rearrange(t, 'b (h b0) (w b1) c -> b (h w) (b0 b1) c',
                                 b0=w_size[0], b1=w_size[1]) for t in (q, k, v))

            if self.non_local:
                # Non-local branch: swap the window and pixel axes so that
                # attention runs across windows, then swap back.
                q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))
                out = self._attend(q, k, v, self.pos_emb2)
                out = out.permute(0, 2, 1, 3)
            else:
                # Local branch: attention between the pixels of each window.
                # Open question kept from the original author: should the
                # "heads divided by 2" be dropped here?
                out = self._attend(q, k, v, self.pos_emb1)

            out = self.to_out(out)
            out = rearrange(out, 'b (h w) (b0 b1) c -> b c (h b0) (w b1)', h=h // w_size[0], w=w // w_size[1],
                            b0=w_size[0])
        return out

class HSAB(nn.Module):
    """Stack of ``num_blocks`` residual (HS_MSA, FeedForward) pairs.

    Each sub-layer is wrapped in ``PreNorm``. The attention layer switches
    to its ``only_local_branch`` mode when ``heads == 1``.
    """

    def __init__(
            self,
            dim,
            window_size=(8, 8),
            dim_head=64,
            heads=8,
            num_blocks=2,
            non_local_sign=False,
    ):
        super().__init__()
        self.non_local = non_local_sign
        single_head = (heads == 1)

        stages = []
        for _ in range(num_blocks):
            attention = HS_MSA(dim=dim, window_size=window_size, dim_head=dim_head, heads=heads,
                               only_local_branch=single_head, non_local=self.non_local)
            stages.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, FeedForward(dim=dim)),
            ]))
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        for attention, feed_forward in self.blocks:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x
    

class SpectralAttnBlock(nn.Module):
    """Stack of ``num_blocks`` residual (Topk_WindowSpectralAttn, FeedForward) pairs.

    Both sub-layers of every pair are wrapped in ``PreNorm``.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            window_size=8,
            num_blocks=2
    ):
        super().__init__()
        stages = []
        for _ in range(num_blocks):
            attention = Topk_WindowSpectralAttn(dim=dim, dim_head=dim_head, heads=heads,
                                                window_size=window_size)
            stages.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, FeedForward(dim=dim)),
            ]))
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        for attention, feed_forward in self.blocks:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x


class SplitBlock(nn.Module):
    """Run two sub-modules on the two channel halves of the input.

    The input is chunked in two along dim 1; ``b1`` processes the first
    chunk, ``b2`` the second, and the results are concatenated on dim 1.
    """

    def __init__(self, b1, b2) -> None:
        super().__init__()
        self.block1 = b1
        self.block2 = b2

    def forward(self, x):
        first_half, second_half = x.chunk(2, dim=1)
        processed = [self.block1(first_half), self.block2(second_half)]
        return torch.cat(processed, dim=1)


class SpatialSpectralAttnBlock(nn.Module):
    """Stack of residual attention + feed-forward pairs.

    When ``dim == 28`` the attention layer is a single ``HS_MSA`` over all
    channels. Otherwise the channels are halved by ``SplitBlock``: one half
    goes through ``HS_MSA`` and the other through
    ``Topk_WindowSpectralAttn``, each with ``heads // 2`` heads.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            window_size=8,
            num_blocks=2
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        single_head = (heads == 1)
        half_dim = dim // 2
        half_heads = heads // 2

        for _ in range(num_blocks):
            if dim != 28:
                spa = HS_MSA(dim=half_dim, window_size=window_size, dim_head=dim_head, heads=half_heads,
                             only_local_branch=single_head, non_local=False)
                spec = Topk_WindowSpectralAttn(dim=half_dim, dim_head=dim_head, heads=half_heads,
                                               window_size=window_size)
                attention = SplitBlock(PreNorm(dim=half_dim, fn=spa), PreNorm(half_dim, fn=spec))
            else:
                attention = PreNorm(dim, HS_MSA(dim=dim, window_size=window_size, dim_head=dim_head,
                                                heads=heads, only_local_branch=single_head,
                                                non_local=False))
            self.blocks.append(nn.ModuleList([
                attention,
                PreNorm(dim, FeedForward(dim=dim)),
            ]))

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        for attention, feed_forward in self.blocks:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x


class FeedForward(nn.Module):
    """Convolutional feed-forward block.

    1x1 conv expanding to ``dim * mult`` channels, GELU, 3x3 depthwise conv,
    GELU, then a 1x1 conv back to ``dim`` channels. All convs are bias-free.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Conv2d(dim, hidden, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(hidden, hidden, 3, 1, 1, bias=False, groups=hidden),
            GELU(),
            nn.Conv2d(hidden, dim, 1, 1, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        return self.net(x)

# try:
#     from DAUHST import HSAB
# except:
#     from .DAUHST import HSAB

class HST(nn.Module):
    """U-shaped encoder/bottleneck/decoder network on [b, c, h, w] inputs.

    The encoder and bottleneck use ``HSAB`` stages, the decoder uses
    ``SpectralAttnBlock`` stages, and encoder features are fused into the
    decoder through 1x1 convolutions. The first ``dim`` input channels are
    added to the output as a residual.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        dim: Base feature width; doubled at every encoder stage.
        num_blocks: Blocks per scale; its length sets the number of scales
            (the last entry configures the bottleneck).
    """

    def __init__(self, in_dim=28, out_dim=28, dim=28, num_blocks=(1, 1, 1)):
        super(HST, self).__init__()
        self.dim = dim
        self.scales = len(num_blocks)

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_scale = dim
        for i in range(self.scales-1):
            self.encoder_layers.append(nn.ModuleList([
                HSAB(dim=dim_scale, num_blocks=num_blocks[i], dim_head=dim, heads=dim_scale // dim, non_local_sign=False),
                nn.Conv2d(dim_scale, dim_scale * 2, 4, 2, 1, bias=False)
            ]))
            dim_scale *= 2

        # Bottleneck
        self.bottleneck = HSAB(dim=dim_scale, dim_head=dim, heads=dim_scale // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(self.scales-1):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_scale, dim_scale // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_scale, dim_scale // 2, 1, 1, bias=False),
                SpectralAttnBlock(dim=dim_scale // 2, dim_head=dim, heads=(dim_scale // 2) // dim, num_blocks=num_blocks[self.scales - 2 - i]),
            ]))
            dim_scale //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # Weight initialisation for Linear / LayerNorm submodules.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers; reset LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """
        b, c, h_inp, w_inp = x.shape
        # The deepest stage works at 1 / 2**(scales-1) resolution and its
        # attention layers use their default 8x8 windows, so the padded size
        # must be a multiple of 8 * 2**(scales-1). The historical value 32 is
        # kept as a floor so that results for <= 3 scales are unchanged.
        hb = wb = max(32, 8 * 2 ** (self.scales - 1))
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        # Embedding
        fea = self.embedding(x)
        # Only the first ``dim`` channels serve as the residual.
        x = x[:, :self.dim, :, :]

        # Encoder
        fea_encoder = []
        for (block, downsample) in self.encoder_layers:
            fea = block(fea)
            fea_encoder.append(fea)
            fea = downsample(fea)

        # Bottleneck
        fea = self.bottleneck(fea)

        # Decoder
        for i, (upsample, fusion, block) in enumerate(self.decoder_layers):
            fea = upsample(fea)
            fea = fusion(torch.cat([fea, fea_encoder[self.scales-2-i]], dim=1))
            fea = block(fea)

        # Mapping, then crop the padding away.
        out = self.mapping(fea) + x
        return out[:, :, :h_inp, :w_inp]

def A(x,Phi):
    """Multiply ``x`` by ``Phi`` elementwise and sum the product over dim 1."""
    return (x * Phi).sum(dim=1)

def At(y,Phi):
    """Copy ``y`` across ``Phi.shape[1]`` channels and multiply by ``Phi``.

    ``y`` gets a new axis at position 1, which is expanded to the channel
    count of ``Phi`` before the elementwise product.
    """
    num_channels = Phi.shape[1]
    tiled = y.unsqueeze(1).expand(-1, num_channels, -1, -1)
    return tiled * Phi

def shift_3d(inputs,step=2):
    """Roll channel ``i`` of a 4-D tensor by ``step * i`` along its last axis.

    The tensor is modified in place and the same object is returned.
    """
    _, num_channels, _, _ = inputs.shape
    for ch in range(num_channels):
        # ``inputs[:, ch]`` is 3-D, so dims=2 is the last axis of ``inputs``.
        inputs[:, ch] = torch.roll(inputs[:, ch], shifts=step * ch, dims=2)
    return inputs

def shift_back_3d(inputs,step=2):
    """Roll channel ``i`` of a 4-D tensor by ``-step * i`` along its last axis.

    The tensor is modified in place and the same object is returned.
    """
    _, num_channels, _, _ = inputs.shape
    for ch in range(num_channels):
        # ``inputs[:, ch]`` is 3-D, so dims=2 is the last axis of ``inputs``.
        inputs[:, ch] = torch.roll(inputs[:, ch], shifts=(-1) * step * ch, dims=2)
    return inputs

class HyPaNet(nn.Module):
    """Map a feature map to two groups of strictly positive scalars.

    Pipeline: 1x1 conv to ``channel`` maps, ReLU, stride-2 3x3 conv, global
    average pooling, then a three-layer 1x1-conv MLP ending in ``Softplus``
    (plus ``1e-6``). The ``out_nc`` output channels are split at
    ``out_nc // 2`` into the two returned tensors.
    """

    def __init__(self, in_nc=29, out_nc=8, channel=64):
        super().__init__()
        self.fution = nn.Conv2d(in_nc, channel, 1, 1, 0, bias=True)
        self.down_sample = nn.Conv2d(channel, channel, 3, 2, 1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        mlp_layers = [
            nn.Conv2d(channel, channel, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, channel, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, out_nc, 1, padding=0, bias=True),
            nn.Softplus(),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.relu = nn.ReLU(inplace=True)
        self.out_nc = out_nc

    def forward(self, x):
        """Return two tensors shaped [b, out_nc // 2, 1, 1] and [b, out_nc - out_nc // 2, 1, 1]."""
        fused = self.relu(self.fution(x))
        pooled = self.avg_pool(self.down_sample(fused))
        params = self.mlp(pooled) + 1e-6
        half = self.out_nc // 2
        return params[:, :half, :, :], params[:, half:, :, :]

class CIDNet(nn.Module):
    """Unrolled reconstruction network alternating a data step and a denoiser.

    Every iteration projects the current estimate with ``A``, corrects it
    with the residual against the measurement ``y`` (weighted by a map from
    ``SAM``), undoes the per-channel shift, and runs an ``HST`` denoiser.

    Args:
        num_iterations: Number of unrolled iterations (one ``HST`` each).
        ch: Number of channels of the reconstructed cube.
        step: Per-channel shift, in pixels, along the last axis.
    """

    def __init__(self, num_iterations=1, ch=28, step=2):
        super(CIDNet, self).__init__()
        self.ch = ch
        self.step = step
        self.para_estimator = HyPaNet(in_nc=ch, out_nc=num_iterations*2)
        self.fution = nn.Conv2d(ch*2, ch, 1, padding=0, bias=True)
        self.num_iterations = num_iterations
        self.denoisers = nn.ModuleList([])
        # in_c=2: one channel for the projection, one for the measurement.
        self.attention = SAM(in_c=2, ch=ch)
        for _ in range(num_iterations):
            self.denoisers.append(
                HST(in_dim=ch+1, out_dim=ch, dim=ch, num_blocks=[1,1,1]),
            )

    def initial(self, y, Phi):
        """
        :param y: [b,row,col] measurement
        :param Phi: [b,ch,row,col] mask
        :return: z: [b,ch,row,col]; alpha: [b,num_iterations,1,1]; beta: [b,num_iterations,1,1]
        """
        nC, step = self.ch, self.step
        y = y / nC * 2
        bs,row,col = y.shape
        y_shift = torch.zeros(bs, nC, row, col, device=y.device, dtype=torch.float32)
        valid = col - (nC - 1) * step
        for i in range(nC):
            # Channel i only keeps the columns inside its own shifted window.
            y_shift[:, i, :, step * i:step * i + valid] = y[:, :, step * i:step * i + valid]
        z = self.fution(torch.cat([y_shift, Phi], dim=1))
        # Reuse z: the original ran the same fusion conv a second time here.
        alpha, beta = self.para_estimator(z)
        return z, alpha, beta

    def forward(self, y, input_mask=None, save_iterations=False):
        """Reconstruct the cube from a measurement.

        :param y: [b,row,col] measurement
        :param input_mask: pair ``(Phi, Phi_s)``; required despite the
            ``None`` default. ``Phi`` is [b,ch,row,col]; ``Phi_s`` must
            broadcast against [b,row,col].
        :param save_iterations: when True, also return a list holding a copy
            of every iteration's denoiser output, cropped like the result.
        :return: [b,ch,row,col-(ch-1)*step], or ``(result, list)`` when
            ``save_iterations`` is True.
        """
        Phi, Phi_s = input_mask
        z, _, _ = self.initial(y, Phi)
        b, h, w = y.shape
        out_w = w - (self.ch - 1) * self.step
        iteration_results = [] if save_iterations else None
        for i in range(self.num_iterations):
            Phi_z = A(z, Phi)
            w1, w2 = self.attention(Phi_z.unsqueeze(1), y.unsqueeze(1))
            x = z + At(torch.div(y-Phi_z, w1+Phi_s), Phi)
            x = shift_back_3d(x, self.step)
            w2_expanded = w2.unsqueeze(1)
            z = self.denoisers[i](torch.cat([x, w2_expanded], dim=1))
            if save_iterations:
                # The list was never filled before. Clone because
                # ``shift_3d`` below mutates ``z`` in place.
                iteration_results.append(z[:, :, :, 0:out_w].clone())
            if i<self.num_iterations-1:
                z = shift_3d(z, self.step)
        final_result = z[:, :, :, 0:out_w]
        if save_iterations:
            return final_result, iteration_results
        else:
            return final_result


if __name__ == "__main__":
    # Smoke test and profiling: build a 5-iteration model, run one forward
    # pass on random data, then report FLOPs and parameter counts.
    from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table

    ch = 28
    H = 256
    W = 256
    # Measurement width: W plus the shift (2 px per channel) of the last channel.
    shifted_w = W + (ch - 1) * 2

    model = CIDNet(5, ch=ch, step=2)
    y = torch.randn(1, H, shifted_w)
    Phi = torch.randn(1, ch, H, shifted_w)
    PhiPhi_T = torch.randn(1, H, shifted_w)

    output = model(y, (Phi, PhiPhi_T))
    print(output.shape)

    flops = FlopCountAnalysis(model, (y, (Phi, PhiPhi_T)))
    print(flops.total() / 1e9)
    print(parameter_count_table(model))
    print(flop_count_table(flops, max_depth=5))