import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils import data
import numpy as np
import time

class CRM(nn.Module):
    """Multi-scale convolution with one kernel per scale shared by all channels.

    Every input channel is convolved independently with the *same*
    1-in/1-out kernel at three sizes (3x3, 5x5, 7x7, reflect padding).
    The three responses are averaged and, optionally, clipped to the
    bounds ``[xlb, xub]`` passed to :meth:`forward`.

    Args:
        stride: stride used by all three convolutions.
        useRepaire: if True, clip the output with :meth:`repair`.
    """

    def __init__(self, stride=1, useRepaire=True):
        super().__init__()

        # padding == kernel_size // 2 keeps H and W unchanged when stride == 1.
        self.depth_conv3 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=3, stride=stride,
            padding=1, padding_mode='reflect')
        self.depth_conv5 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=5, stride=stride,
            padding=2, padding_mode='reflect')
        self.depth_conv7 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=7, stride=stride,
            padding=3, padding_mode='reflect')

        self.useRepaire = useRepaire
        # Bounds from the most recent forward() call; read by repair().
        # None means "no bound on that side".
        self.xlb = None
        self.xub = None

    @staticmethod
    def _per_channel(conv, x):
        """Apply the 1->1 channel ``conv`` to every channel of 4-D ``x`` independently."""
        b, c, h, w = x.shape
        # Folding channels into the batch dim is equivalent to looping over
        # channels and concatenating, but issues a single conv call.
        out = conv(x.reshape(b * c, 1, h, w))
        return out.reshape(b, c, out.shape[-2], out.shape[-1])

    def forward(self, x, xlb, xub):
        """Apply the shared multi-scale convolution channel by channel.

        Args:
            x: 4-D input of shape (B, C, H, W).
            xlb: lower clip bound (scalar, tensor broadcastable to the
                output, or None for no lower bound).
            xub: upper clip bound (same kinds as ``xlb``).

        Returns:
            Tensor of shape (B, C, H', W'): the mean of the 3x3, 5x5 and
            7x7 responses, clipped to ``[xlb, xub]`` when ``useRepaire``.
        """
        # Stored on the module so repair() can be called with just a tensor.
        self.xub = xub
        self.xlb = xlb

        x1 = self._per_channel(self.depth_conv3, x)
        x2 = self._per_channel(self.depth_conv5, x)
        x3 = self._per_channel(self.depth_conv7, x)

        y = (x1 + x2 + x3) / 3
        if self.useRepaire:
            y = self.repair(y)
        return y

    def repair(self, x):
        """Return ``x`` clipped to ``[self.xlb, self.xub]``; ``x`` is not modified.

        Bounds may be Python scalars or tensors broadcastable to ``x``; a
        ``None`` bound is skipped. The upper bound is applied first, so if
        ``xlb > xub`` every element ends up equal to ``xlb``.
        """
        xub = getattr(self, 'xub', None)
        xlb = getattr(self, 'xlb', None)
        # Out-of-place clamp: keeps autograd happy and supports tensor bounds.
        if xub is not None:
            x = torch.clamp(x, max=xub)
        if xlb is not None:
            x = torch.clamp(x, min=xlb)
        return x
    
    
if __name__ == '__main__':
    # Smoke test: train CRM so its clipped output approaches all zeros.
    # The input is not named `data`, to avoid shadowing the imported
    # torch.utils.data module.
    inputs = torch.randn((2, 3, 10, 10))
    net = CRM()
    label = torch.zeros_like(inputs)
    lf = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # forward() requires clip bounds; the original call omitted them and crashed.
    lower, upper = -1.0, 1.0
    for i in range(1000):
        y = net(inputs, lower, upper)
        loss = lf(y, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
