from __future__ import print_function, division
import math
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

from .shared import conv_block, up_conv


class AdaptiveRFConv_block(nn.Module):
    """Blend three convolution branches with per-pixel softmax weights.

    Each of the three branches is a two-layer conv-BN-ReLU stack with its own
    kernel size and dilation.  A small selector network looks at a
    ``patch_size`` x ``patch_size`` neighbourhood around every pixel and
    predicts three weights (softmax over the branch axis) that are used to mix
    the branch outputs at that pixel.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of every branch.
        ks: Kernel sizes of the three branches; entries 0..2 are read.  The
            padding used keeps the spatial size only for odd kernel sizes.
        di: Dilations of the three branches; entries 0..2 are read.
        patch_size: Side length of the neighbourhood fed to the selector.
            Must be a positive odd multiple of 3: odd so that the unfold step
            produces exactly one patch per pixel, and a multiple of 3 so that
            the 3x3 average pooling tiles the patch exactly.
        hidden: Number of channels the input is reduced to before patches are
            extracted.

    Raises:
        ValueError: If ``patch_size`` is not a positive odd multiple of 3.

    Note:
        ``forward`` materialises a ``(B, hidden * patch_size**2, H * W)``
        tensor of unfolded patches, so memory grows quadratically with
        ``patch_size``.
    """

    def __init__(self, in_ch, out_ch, ks, di, patch_size, hidden=12):
        super().__init__()
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``.  Even sizes are rejected too, because with
        # ``padding=patch_size // 2`` they yield (H+1)*(W+1) patches and the
        # reshape in ``forward`` can no longer map patches back onto pixels.
        if patch_size <= 0 or patch_size % 3 != 0 or patch_size % 2 == 0:
            raise ValueError(
                "patch_size must be a positive odd multiple of 3, got {!r}".format(patch_size)
            )

        self.ps = patch_size

        # 1x1 convolutions: (B, in_ch, H, W) -> (B, hidden, H, W)
        self.ch_reduce = nn.Sequential(
            nn.Conv2d(in_ch, 64, kernel_size=1, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, hidden, kernel_size=1, stride=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        # One zero-padded patch per pixel: (B, hidden*ps*ps, H*W)
        self.Unfold = nn.Unfold(kernel_size=patch_size, padding=patch_size // 2, stride=1)
        # Shrinks every ps x ps patch to (ps/3) x (ps/3)
        self.avgpool = nn.AvgPool2d(kernel_size=3, stride=3)

        # Pooled patch descriptor plus 4 summary statistics (mean/var/max/min).
        selector_in = patch_size * patch_size * hidden // 9 + 4
        self.mlp = nn.Sequential(
            nn.Conv2d(selector_in, 128, kernel_size=1, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 3, kernel_size=1, stride=1),
        )

        # Branches are created in the order conv1, conv2, conv3 so that
        # parameter names and registration order stay stable.
        self.conv1 = self._make_branch(in_ch, out_ch, ks[0], di[0])
        self.conv2 = self._make_branch(in_ch, out_ch, ks[1], di[1])
        self.conv3 = self._make_branch(in_ch, out_ch, ks[2], di[2])

    @staticmethod
    def _make_branch(in_ch, out_ch, kernel_size, dilation):
        """Build one two-layer conv-BN-ReLU branch with size-preserving padding."""
        # Half of the dilated kernel extent; for odd kernel sizes this leaves
        # H and W unchanged at stride 1.
        pad = (kernel_size + (kernel_size - 1) * (dilation - 1)) // 2
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1, padding=pad, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size, stride=1, padding=pad, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Mix the three branches of ``x`` with per-pixel weights.

        Args:
            x: Input tensor of shape ``(B, in_ch, H, W)``.

        Returns:
            A tuple ``(out, weights)`` where ``out`` is the weighted sum of the
            three branch outputs and ``weights`` has shape ``(B, 3, H, W)``
            and sums to 1 along dimension 1.
        """
        B, _, H, W = x.shape

        reduced = self.ch_reduce(x)                                   # (B, hidden, H, W)
        patches = self.Unfold(reduced)                                # (B, hidden*ps*ps, H*W)
        patches = patches.permute(0, 2, 1).reshape(B, -1, self.ps, self.ps)  # (B, H*W*hidden, ps, ps)
        # Pool inside each patch, then lay the descriptors back out per pixel.
        desc = self.avgpool(patches).reshape(B, H, W, -1).permute(0, 3, 1, 2)  # (B, hidden*ps*ps/9, H, W)

        # Per-pixel summary statistics over the descriptor channels.
        desc_mean = desc.mean(dim=1, keepdim=True)
        desc_var = desc.var(dim=1, keepdim=True)
        desc_max = desc.max(dim=1, keepdim=True)[0]
        desc_min = desc.min(dim=1, keepdim=True)[0]

        logits = self.mlp(torch.cat((desc, desc_mean, desc_var, desc_max, desc_min), dim=1))
        # The factor 2 scales the logits before the softmax over the 3 branches.
        weights = F.softmax(logits * 2, dim=1)

        mixed = (
            self.conv1(x) * weights[:, 0:1, :, :]
            + self.conv2(x) * weights[:, 1:2, :, :]
            + self.conv3(x) * weights[:, 2:3, :, :]
        )
        return mixed, weights



class UNet_arf(nn.Module):
    """Stack of adaptive receptive-field blocks between a conv stem and a 1x1 head.

    The input goes through ``Conv1``, a 2x2 max-pool and ``Conv2``, then
    through a chain of ``AdaptiveRFConv_block`` stages with growing dilation
    and patch size, and finally through ``final``, a 1x1 prediction
    convolution and a 2x bilinear upsampling.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by the prediction head.
    """

    def __init__(self, in_ch=4, out_ch=5):
        super().__init__()

        stem_ch = 64
        body_ch = stem_ch * 2

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, stem_ch)
        self.Conv2 = conv_block(stem_ch, body_ch)

        # (stage index, dilation of the 3x3 and 5x5 branches, selector patch size).
        # Every stage registers two blocks, ``ARFConv<stage>_1`` then
        # ``ARFConv<stage>_2``.  Note that ``ARFConv4_2`` is constructed here
        # but is not called in ``forward``.
        stage_specs = ((1, 1, 9), (2, 2, 15), (3, 3, 15), (4, 4, 21))
        for stage, dilation, patch in stage_specs:
            for repeat in (1, 2):
                setattr(
                    self,
                    "ARFConv{}_{}".format(stage, repeat),
                    AdaptiveRFConv_block(
                        body_ch,
                        body_ch,
                        ks=[1, 3, 5],
                        di=[1, dilation, dilation],
                        patch_size=patch,
                    ),
                )

        self.final = conv_block(body_ch, stem_ch)
        self.pred = nn.Conv2d(stem_ch, out_ch, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor; presumably ``(B, in_ch, H, W)`` -- the layout
                expected by ``conv_block`` is not visible here, TODO confirm.

        Returns:
            A tuple ``(pred, maps)``.  ``pred`` is the 2x-upsampled output of
            the prediction head.  ``maps`` is a list with the 2x-upsampled
            branch-weight maps returned by the first block of each of the
            four stages; the maps of the second blocks are discarded.
        """
        feats = self.Conv2(self.Maxpool1(self.Conv1(x)))

        feats, sel1 = self.ARFConv1_1(feats)
        feats, _ = self.ARFConv1_2(feats)
        feats, sel2 = self.ARFConv2_1(feats)
        feats, _ = self.ARFConv2_2(feats)
        feats, sel3 = self.ARFConv3_1(feats)
        feats, _ = self.ARFConv3_2(feats)
        feats, sel4 = self.ARFConv4_1(feats)
        # ARFConv4_2 is intentionally not applied here.

        logits = self.up(self.pred(self.final(feats)))
        upsampled_maps = [self.up(sel) for sel in (sel1, sel2, sel3, sel4)]

        return logits, upsampled_maps
