import torch
import torch.nn as nn
import torch.nn.functional as F

def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling layer for 1D, 2D or 3D inputs.

    Every extra positional and keyword argument is forwarded unchanged to the
    matching ``nn.AvgPool1d`` / ``nn.AvgPool2d`` / ``nn.AvgPool3d`` constructor.

    :param dims: spatial dimensionality of the signal (1, 2 or 3).
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for 1D, 2D or 3D inputs.

    Every extra positional and keyword argument is forwarded unchanged to the
    matching ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d`` constructor.

    :param dims: spatial dimensionality of the signal (1, 2 or 3).
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")



class Downsample(nn.Module):
    """
    Reduce the spatial resolution of a feature map by a factor of 2.

    :param channels: number of channels expected in the input.
    :param use_conv: if truthy, downsample with a strided 3x3 convolution;
                     otherwise use average pooling (which requires the output
                     channel count to equal ``channels``).
    :param dims: dimensionality of the signal (1, 2 or 3). For 3D inputs the
                 stride is ``(1, 2, 2)``, so only the last two dimensions are
                 downsampled.
    :param out_channels: number of output channels; a falsy value means
                         ``channels``.
    :param padding: padding of the strided convolution (unused when pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their leading spatial axis at full resolution.
        step = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=step, stride=step)
            return
        self.op = conv_nd(
            dims,
            self.channels,
            self.out_channels,
            3,
            stride=step,
            padding=padding,
        )

    def forward(self, x):
        # Channel axis is dim 1; it must match the configured input channels.
        assert x.shape[1] == self.channels
        return self.op(x)


class ResnetBlock(nn.Module):
    """
    Residual block: [optional downsample] -> [optional input projection]
    -> conv -> ReLU -> conv, added to a skip connection.

    :param in_c: channels of the input tensor.
    :param out_c: channels of the output tensor.
    :param down: if truthy, the input is first passed through
                 ``Downsample(in_c, use_conv=use_conv)``.
    :param ksize: kernel size of ``in_conv``, ``block2`` and ``skep``
                  (``block1`` is always 3x3). Padding is ``ksize // 2``, so
                  odd kernel sizes preserve the spatial resolution.
    :param sk: if falsy, the skip path is a learned conv (``skep``) and the
               input projection ``in_conv`` is always created; if truthy, the
               skip is the identity and ``in_conv`` is created only when
               ``in_c != out_c``.
    :param use_conv: forwarded to ``Downsample`` when ``down`` is set.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        ps = ksize // 2
        # Submodules are created in a fixed order (in_conv, block1, act,
        # block2, skep, down_opt). Keep it so seeded initialisation and
        # state_dict layout stay unchanged.
        if in_c != out_c or not sk:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
        else:
            # Identity skip with matching channels: no projection needed.
            self.in_conv = None
        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        if not sk:
            # The skip conv sees the output of in_conv, hence out_c -> out_c.
            self.skep = nn.Conv2d(out_c, out_c, ksize, 1, ps)
        else:
            self.skep = None

        self.down = down
        if self.down:
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        if self.down:
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        h = self.block1(x)
        h = self.act(h)
        h = self.block2(h)
        # Learned skip when configured, identity otherwise.
        skip = x if self.skep is None else self.skep(x)
        return h + skip

class Control_adapter(nn.Module):
    """
    Multi-scale feature extractor that pools every scale to a vector, projects
    it, and returns one concatenated embedding per input.

    Pipeline: ``PixelUnshuffle(8)`` (space-to-depth, so ``cin`` must equal the
    input channel count * 64) -> ``conv_in`` -> ``len(channels)`` stages of
    ``nums_rb`` ``ResnetBlock``s. Every stage after the first starts with a
    downsampling block that also changes the channel count. Each stage output
    is globally average-pooled and linearly projected to ``channels[i] // 8``
    features. The projections are concatenated along the last dimension.

    :param channels: output channels of each stage; defaults to
                     ``[320, 640, 1280, 1280]``.
    :param nums_rb: number of ResnetBlocks per stage.
    :param cin: input channels of ``conv_in`` (i.e. after the 8x unshuffle).
    :param ksize: kernel size forwarded to every ResnetBlock.
    :param sk: ``sk`` flag forwarded to every ResnetBlock.
    :param use_conv: forwarded to every ResnetBlock; selects strided-conv
                     downsampling instead of average pooling.
    :param align_training_size: if > 0, the input is split into
                     ``align_training_size ** 2`` strided sub-images that are
                     processed as one batch and re-interleaved afterwards.
                     H and W must then be divisible by
                     ``8 * align_training_size``.
    """

    def __init__(self, channels=None,
                 nums_rb=2, cin=64, ksize=3, sk=True,
                 use_conv=False, align_training_size=0):
        super().__init__()
        # None sentinel instead of a mutable list default: the list is stored
        # on the instance, so a shared default could be mutated through it.
        if channels is None:
            channels = [320, 640, 1280, 1280]
        else:
            channels = list(channels)
        self.align_training_size = align_training_size
        self.unshuffle = nn.PixelUnshuffle(8)
        self.channels = channels
        self.nums_rb = nums_rb
        # Construction order (body, conv_in, image_projs) is kept as-is so
        # seeded initialisation and state_dict layout are unchanged.
        blocks = []
        for i in range(len(channels)):
            for j in range(nums_rb):
                # First block of every stage but the first: change channel
                # count and downsample.
                is_transition = i != 0 and j == 0
                blocks.append(
                    ResnetBlock(channels[i - 1] if is_transition else channels[i],
                                channels[i], down=is_transition, ksize=ksize,
                                sk=sk, use_conv=use_conv))
        self.body = nn.ModuleList(blocks)
        self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)

        # One head per stage: global average pool -> (B, dim) -> (B, dim // 8).
        self.image_projs = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(dim, dim // 8)
            ) for dim in channels
        ])

    def _split_into_subimages(self, x):
        """Rearrange (B, C, H, W) into (B * s*s, C, H/s, W/s), s = align_training_size.

        Sub-image ``b * s*s + i * s + j`` equals ``x[b, :, i::s, j::s]``.
        """
        s = self.align_training_size
        b, c, h, w = x.shape
        # (B, C*s*s, H/s, W/s); the s*s offsets of one channel are contiguous.
        x = F.pixel_unshuffle(x, s)
        x = x.reshape(b, c, -1, h // s, w // s)
        return x.permute(0, 2, 1, 3, 4).reshape(-1, c, h // s, w // s)

    def _merge_subimages(self, feat, batch_size):
        """Inverse of ``_split_into_subimages`` for a feature map of any channel count."""
        s = self.align_training_size
        _, c, h, w = feat.shape
        feat = feat.reshape(batch_size, -1, c, h, w)
        feat = feat.permute(0, 2, 1, 3, 4).reshape(batch_size, -1, h, w)
        return F.pixel_shuffle(feat, s)

    def forward(self, x):
        """
        :param x: input image tensor, assumed (B, C, H, W) with H and W
                  divisible by 8 (and by ``8 * align_training_size`` when set).
        :return: tensor of shape (B, sum(dim // 8 for dim in channels)).
        """
        align = self.align_training_size > 0
        if align:
            batch_size = x.shape[0]
            x = self._split_into_subimages(x)

        x = self.unshuffle(x)
        x = self.conv_in(x)
        features = []
        for i in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[i * self.nums_rb + j](x)
            # Keep the output of each stage.
            features.append(x)

        if align:
            features = [self._merge_subimages(f, batch_size) for f in features]
        projected = [proj(f) for f, proj in zip(features, self.image_projs)]
        return torch.cat(projected, dim=-1)

if __name__ == '__main__':
    # Smoke test. Fall back to CPU so the script also runs without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    adapter = Control_adapter(
        channels=[32, 64, 128, 128],
        nums_rb=2,
        cin=192,  # 3 input channels * 8 * 8 after PixelUnshuffle(8)
        ksize=3,
        sk=True,
        use_conv=False,
    ).to(device)
    x = torch.randn(1, 3, 512, 512, device=device)
    with torch.no_grad():
        y = adapter(x)
    # Expected torch.Size([1, 44]): 32//8 + 64//8 + 128//8 + 128//8.
    print(y.shape)