# -*- coding: utf-8 -*-
# @File : tadconv.py
# @Author : 王军
# @Time : 2022/10/30 9:49
# @Software : PyCharm
""" TAdaConv. """

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _triple,_pair,_single
from einops import rearrange
from einops.layers.torch import Reduce,Rearrange
################################################
class TAdaFeaturConv(nn.Module):
    """2D convolution whose kernel is rescaled per sample along the input channels.

    One base kernel ``W_b`` of shape ``(1, out_channels, in_channels, k_h, k_w)``
    is shared by every sample.  In ``forward`` sample ``n`` is convolved with
    ``W_b * alpha[n]`` (``alpha[n]`` holds one factor per input channel); all
    per-sample convolutions are evaluated by a single grouped ``F.conv2d`` call.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Int or pair giving the 2D kernel size.
        stride: Int or pair, forwarded to ``F.conv2d``.
        padding: Int or pair, forwarded to ``F.conv2d``.
        dilation: Int or pair, forwarded to ``F.conv2d``.
        bias: If ``True``, a learnable bias of shape ``(1, out_channels)`` is
            added; otherwise ``self.bias`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        # The kernel is two-dimensional, so an int expands to (k, k).
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Base weights (W_b); the leading axis of size 1 broadcasts over the batch.
        self.weight = nn.Parameter(
            torch.empty(1, out_channels, in_channels, kernel_size[0], kernel_size[1])
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(1, out_channels))
        else:
            self.register_parameter('bias', None)

        # NOTE(review): the init helpers read dim 1 as the input-feature axis, so
        # on this 5-D layout fan_in comes out as
        # out_channels * in_channels * k_h * k_w rather than nn.Conv2d's
        # in_channels * k_h * k_w.  Left unchanged so the initial distribution
        # stays the same -- confirm whether that is intended.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, alpha):
        """Convolve each sample of ``x`` with its own ``alpha``-scaled kernel.

        Args:
            x: 4-D tensor, unpacked as ``(b, c, s, t)``.
            alpha: Tensor of shape ``(b, in_channels)`` with the per-sample,
                per-input-channel scaling factors.

        Returns:
            Tensor of shape ``(b, out_channels, s_out, t_out)``.

        Raises:
            ValueError: If ``alpha`` is not of shape ``(b, in_channels)``.
        """
        b, in_c, s, t = x.shape
        if alpha.dim() != 2 or alpha.shape[0] != b or alpha.shape[1] != self.in_channels:
            raise ValueError(
                f"alpha must have shape ({b}, {self.in_channels}), "
                f"got {tuple(alpha.shape)}"
            )

        # Fold the batch into the channel axis: (b, c, s, t) -> (1, b*c, s, t).
        x = x.reshape(1, b * in_c, s, t)

        # (1, o, i, kh, kw) * (b, 1, i, 1, 1) -> (b, o, i, kh, kw) -> (b*o, i, kh, kw)
        adp_w = self.weight * alpha[:, None, :, None, None]
        adp_w = adp_w.reshape(
            b * self.out_channels, self.in_channels, *self.weight.shape[-2:]
        )

        # With bias disabled, F.conv2d must receive None explicitly.
        bias = None
        if self.bias is not None:
            bias = self.bias.repeat(b, 1).reshape(-1)

        # groups=b makes every sample see only its own block of kernels.
        out = F.conv2d(x, weight=adp_w, bias=bias, stride=self.stride,
                       padding=self.padding, dilation=self.dilation, groups=b)

        # Unfold the batch again: (1, b*o, s', t') -> (b, o, s', t').
        return out.reshape(b, self.out_channels, *out.shape[-2:])
    
class AlphaFeature(nn.Module):
    """Predict one scaling factor per input channel from the input itself.

    ``self.model`` chains: 1x1 conv bottleneck -> BatchNorm -> ELU -> 1x1 conv
    back to ``in_channels`` -> ELU -> adaptive average pooling to
    ``(1, temporal_patch)`` -> flatten -> linear layer to ``in_channels``.
    ``forward`` returns that output plus 1.

    Args:
        in_channels: Number of channels of the input (and of the output vector).
        temporal_patch: Width of the adaptive average pooling output along the
            last axis.  Defaults to 2.
        ratio: Channel reduction factor of the bottleneck.  Defaults to 2.  The
            bottleneck is never narrower than one channel.
    """

    def __init__(self, in_channels, temporal_patch=2, ratio=2):
        super().__init__()
        self.temporal_patch = temporal_patch
        # Attribute name kept as-is (sic) for backward compatibility.
        self.radio = ratio
        # Integer division, floored at 1 so small inputs (e.g. in_channels=1)
        # do not produce a zero-width hidden layer.
        hidden = max(1, int(in_channels // ratio))
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=hidden, kernel_size=(1, 1)),
            nn.BatchNorm2d(hidden),
            nn.ELU(),
            nn.Conv2d(in_channels=hidden, out_channels=in_channels, kernel_size=(1, 1)),
            nn.ELU(),
            nn.AdaptiveAvgPool2d((1, temporal_patch)),
            # (b, c, 1, p) -> (b, c*p)
            nn.Flatten(start_dim=1),
            nn.Linear(temporal_patch * in_channels, in_channels),
        )

    def forward(self, x):
        """Return a ``(b, in_channels)`` tensor: the network output plus 1."""
        return self.model(x) + 1
    
class TAdFeatureCNN(nn.Module):
    """Self-calibrating wrapper around ``TAdaFeaturConv``.

    The scaling factors fed to the adaptive convolution are computed from the
    input by an ``AlphaFeature`` sub-module, so callers only pass ``x``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size, forwarded to ``TAdaFeaturConv``.
        stride: Forwarded to ``TAdaFeaturConv``.
        padding: Forwarded to ``TAdaFeaturConv``.
        dilation: Forwarded to ``TAdaFeaturConv``.
        bias: Forwarded to ``TAdaFeaturConv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding,
                            dilation=dilation, bias=bias)
        # Adaptive convolution first, factor predictor second (registration order).
        self.ad_conv = TAdaFeaturConv(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      **conv_options)
        self.alpha = AlphaFeature(in_channels=in_channels)

    def forward(self, x):
        """Predict the factors from ``x`` and apply the adaptive convolution."""
        factors = self.alpha(x)
        return self.ad_conv(x, factors)

#################################################
class AlphaSpace(nn.Module):
    """Predict one scalar per position of the third input axis.

    ``self.models`` chains: 1x1 conv bottleneck -> BatchNorm -> ELU -> 1x1 conv
    to a single channel.  ``forward`` then averages over the last axis and drops
    the singleton channel, mapping ``(b, in_c, s, t)`` to ``(b, s)``.

    Args:
        in_c: Number of channels of the input.
        ratio: Channel reduction factor of the bottleneck.  Defaults to 2.  The
            bottleneck is never narrower than one channel.
    """

    def __init__(self, in_c, ratio=2):
        super().__init__()
        # Integer division, floored at 1 so in_c=1 does not produce a
        # zero-width hidden layer.
        hidden = max(1, int(in_c // ratio))
        self.models = nn.Sequential(
            nn.Conv2d(in_channels=in_c, out_channels=hidden, kernel_size=(1, 1)),
            nn.BatchNorm2d(hidden),
            nn.ELU(),
            nn.Conv2d(in_channels=hidden, out_channels=1, kernel_size=(1, 1)),
        )

    def forward(self, x):
        """Return a ``(b, s)`` tensor for a ``(b, in_c, s, t)`` input."""
        # (b, 1, s, t) -> mean over t -> (b, 1, s) -> (b, s)
        return self.models(x).mean(dim=-1).squeeze(1)

class TAdaSpaceCovn1d(nn.Module):
    """1D convolution along the last axis with a per-sample, per-row kernel scale.

    One base kernel ``W_b`` of shape ``(1, 1, out_channels, in_channels, k)`` is
    shared by all inputs.  In ``forward`` the row ``x[n, :, j, :]`` is convolved
    with ``W_b * alpha[n, j]``; all ``b * s`` convolutions are evaluated by a
    single grouped ``F.conv1d`` call.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Int (or 1-tuple) giving the 1D kernel size.
        stride: Int or 1-tuple, forwarded to ``F.conv1d``.
        padding: Int or 1-tuple, forwarded to ``F.conv1d``.
        dilation: Int or 1-tuple, forwarded to ``F.conv1d``.
        bias: If ``True``, a learnable bias of shape ``(1, out_channels)`` is
            added; otherwise ``self.bias`` is ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        super().__init__()
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # Base weights (W_b); the two leading axes of size 1 broadcast over
        # the batch and row axes.
        self.weight = nn.Parameter(
            torch.empty(1, 1, out_channels, in_channels, kernel_size[0])
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(1, out_channels))
        else:
            self.register_parameter('bias', None)

        # NOTE(review): the init helpers read dim 1 (size 1 here) as the
        # input-feature axis and everything after it as the receptive field, so
        # fan_in comes out as out_channels * in_channels * k rather than
        # nn.Conv1d's in_channels * k.  Left unchanged so the initial
        # distribution stays the same -- confirm whether that is intended.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, alpha):
        """Convolve every row of ``x`` with its own ``alpha``-scaled kernel.

        Args:
            x: 4-D tensor, unpacked as ``(b, c, s, t)``.
            alpha: Tensor of shape ``(b, s)`` with one scaling factor per
                sample and row.

        Returns:
            Tensor of shape ``(b, out_channels, s, t_out)`` (a permuted view,
            not necessarily contiguous).

        Raises:
            ValueError: If ``alpha`` is not of shape ``(b, s)``.
        """
        b, in_c, s, t = x.shape
        if alpha.dim() != 2 or alpha.shape[0] != b or alpha.shape[1] != s:
            raise ValueError(
                f"alpha must have shape ({b}, {s}), got {tuple(alpha.shape)}"
            )

        # Fold batch and rows into the channel axis:
        # (b, c, s, t) -> (b, s, c, t) -> (1, b*s*c, t)
        x = x.permute(0, 2, 1, 3).reshape(1, b * s * in_c, t)

        # (1, 1, o, i, k) * (b, s, 1, 1, 1) -> (b, s, o, i, k) -> (b*s*o, i, k)
        weights = self.weight * alpha[:, :, None, None, None]
        weights = weights.reshape(
            b * s * self.out_channels, self.in_channels, self.weight.shape[-1]
        )

        # With bias disabled, F.conv1d must receive None explicitly.
        bias = None
        if self.bias is not None:
            bias = self.bias.repeat(b * s, 1).reshape(-1)

        # groups=b*s makes every (sample, row) pair see only its own kernels.
        out = F.conv1d(x, weight=weights, bias=bias, padding=self.padding,
                       stride=self.stride, dilation=self.dilation, groups=b * s)

        # (1, b*s*o, t') -> (b, s, o, t') -> (b, o, s, t')
        return out.reshape(b, s, self.out_channels, out.shape[-1]).permute(0, 2, 1, 3)
    
class TAdSpaceCNN(nn.Module):
    """Self-calibrating wrapper around ``TAdaSpaceCovn1d``.

    The scaling factors fed to the adaptive 1D convolution are computed from
    the input by an ``AlphaSpace`` sub-module, so callers only pass ``x``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size, forwarded to ``TAdaSpaceCovn1d``.
        stride: Forwarded to ``TAdaSpaceCovn1d``.
        padding: Forwarded to ``TAdaSpaceCovn1d``.
        dilation: Forwarded to ``TAdaSpaceCovn1d``.
        bias: Forwarded to ``TAdaSpaceCovn1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding,
                            dilation=dilation, bias=bias)
        # Adaptive convolution first, factor predictor second (registration order).
        self.ad_cnn = TAdaSpaceCovn1d(in_channels, out_channels, kernel_size,
                                      **conv_options)
        self.alpha_model = AlphaSpace(in_channels)

    def forward(self, x):
        """Predict the factors from ``x`` and apply the adaptive convolution."""
        factors = self.alpha_model(x)
        return self.ad_cnn(x, factors)
################################################


if __name__ == '__main__':
    # Smoke test: build the feature-adaptive block, push one random batch
    # through it and print the shape of the result.
    demo_model = TAdFeatureCNN(
        in_channels=32,
        out_channels=22,
        kernel_size=3,
        padding=1,
    )
    demo_batch = torch.randn((64, 32, 207, 12))
    demo_output = demo_model(demo_batch)
    print(demo_output.shape)


