import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math
from collections import OrderedDict
from RFN import RFN

import torch
import torch.nn as nn
import torch.nn.functional as F

# channel-wise global time residual
class Residual1D(nn.Module):
    """Add a 1-D convolved, dim-2-averaged summary of the input back onto it.

    The input is averaged over dim 2, passed through a bias-free
    ``nn.Conv1d`` that keeps the channel count, and the result is broadcast
    back over dim 2 and added to the input.

    Args:
        channels: Input and output channel count of the 1-D convolution.
        k: Kernel size of the 1-D convolution.
        dilation: Dilation of the 1-D convolution.
    """

    def __init__(self, channels, k=5, dilation=1):
        super(Residual1D, self).__init__()
        # For odd ``k`` this padding keeps the length of the last dim unchanged.
        self.conv1d = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=k,
            padding=dilation * (k // 2),
            dilation=dilation,
            bias=False,
        )

    def forward(self, x):
        """Return ``x`` plus the broadcast 1-D residual.

        Assumes ``x`` is laid out as (batch, channels, freq, time) — TODO confirm.
        """
        pooled = x.mean(dim=2)          # collapse dim 2 -> 3-D tensor
        context = self.conv1d(pooled)   # 1-D convolution along the last dim
        return x + context.unsqueeze(2).expand_as(x)

# BC-Resnet block
class BCResBlock(nn.Module):
    """Residual block: two 5x5 conv/batch-norm stages, a ``Residual1D`` and a skip path.

    Both convolutions are bias-free and dilated only along the last dim
    (``dilation=(1, dilation)``), with padding chosen to match.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        stride: Stride of the first convolution and of the projection shortcut.
        dilation: Dilation applied along the last dim of both convolutions
            and inside the ``Residual1D``.
    """

    def __init__(self, in_c, out_c, stride=(1, 1), dilation=1):
        super(BCResBlock, self).__init__()
        kernel = 5
        half = kernel // 2
        shared = dict(
            padding=(half, dilation * half),
            dilation=(1, dilation),
            bias=False,
        )

        self.conv1 = nn.Conv2d(in_c, out_c, kernel, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(out_c)

        self.conv2 = nn.Conv2d(out_c, out_c, kernel, **shared)
        self.bn2 = nn.BatchNorm2d(out_c)

        self.br1d = Residual1D(out_c, k=kernel, dilation=dilation)

        # NOTE: the comparison is against the tuple (1, 1); an int stride of 1
        # does not compare equal and therefore selects the projection path.
        if in_c == out_c and stride == (1, 1):
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the skip path matches channels and resolution.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_c),
            )

    def forward(self, x):
        """Run both conv stages, add the 1-D residual and the skip, then ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        hidden = self.br1d(hidden)            # 1-D residual
        merged = hidden + self.shortcut(x)    # skip connection
        return F.relu(merged)



class BCResNet1(nn.Module):
    """BC-ResNet-style network built from ``BCResBlock`` stages.

    Pipeline: the project ``RFN`` module on the raw input, a strided 5x5 stem
    convolution (1 input channel) with batch norm and ReLU, four stages of
    ``BCResBlock`` (each stage followed by an ``RFN``; the first two stages
    are additionally followed by 2x2 max pooling), a 1x1 convolution to
    ``n_classes`` channels, and global average pooling.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution,
            i.e. the width of the vector returned by ``forward``.
        base_c: Base channel width from which the stage widths are derived.
        tau: Accepted for interface compatibility; not used by this class.

    Attributes:
        out_dim: Width of the vector returned by ``forward``. This equals
            ``n_classes`` because ``forward`` applies ``final_conv`` before
            pooling, so a head built as ``Linear(out_dim, ...)`` matches the
            actual output.
    """

    def __init__(self, n_classes = 10, base_c = 16, tau = 1):
        super().__init__()
        stem_c = 2 * base_c
        self.conv1 = nn.Conv2d(1, stem_c, 5, stride=2, padding=2, bias=False)
        self.bn1   = nn.BatchNorm2d(stem_c)
        self.relu  = nn.ReLU(inplace=True)

        # Applied to the raw input in forward(); behavior is defined in RFN.py.
        self.rfn = RFN(lam=0.5)

        cfg = [
            # (out_c, n_blocks)
            (base_c, 2),
            (int(base_c * 1.5), 2),
            (base_c * 2, 2),
            (int(base_c * 2.5), 3),
        ]

        layers, in_c = [], stem_c
        for idx, (out_c, n_blocks) in enumerate(cfg):
            for _ in range(n_blocks):
                layers.append(BCResBlock(in_c, out_c))
                in_c = out_c
            layers.append(RFN(lam=0.5))
            # Only the first two stages are followed by spatial down-sampling.
            if idx in (0, 1):
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        self.layers = nn.Sequential(*layers)
        self.final_conv = nn.Conv2d(in_c, n_classes, kernel_size=1, bias=False)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # forward() returns the pooled output of final_conv, which has
        # n_classes channels (not the in_c channels of the last stage).
        self.out_dim = n_classes

    def forward(self, x):
        """Return a ``(batch, n_classes)`` tensor.

        Assumes ``x`` is (batch, 1, freq, time) and that ``RFN`` preserves
        the input shape — TODO confirm against RFN.py.
        """
        x = self.rfn(x)  # applied to the raw input
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layers(x)
        x = self.final_conv(x)
        x = self.global_pool(x)
        # Drop the two trailing singleton spatial dims left by the (1, 1) pool.
        x = x.squeeze(-1).squeeze(-1)
        return x


class AudioClassifier(nn.Module):
    """Single fully connected layer mapping ``in_dim`` features to ``n_classes`` outputs.

    Args:
        in_dim: Size of the last dimension of the input.
        n_classes: Size of the last dimension of the output.
    """

    def __init__(self, in_dim: int, n_classes: int):
        super(AudioClassifier, self).__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=n_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        logits = self.fc(x)
        return logits





class AlexNetCaffe(nn.Module):
    """AlexNet variant with grouped convolutions and local response normalization.

    The network is split into a named convolutional ``features`` stack, an
    adaptive average pool to 6x6, and a named three-layer ``classifier``.

    Args:
        n_classes: Output size of the last fully connected layer (``fc8``).
    """

    def __init__(self, n_classes=1000):
        super().__init__()

        # Layer names are part of the state-dict keys; keep them unchanged.
        feature_layers = [
            ("conv1", nn.Conv2d(3, 96, kernel_size=11, stride=4)),
            ("relu1", nn.ReLU(inplace=True)),
            ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm1", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv2", nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)),
            ("relu2", nn.ReLU(inplace=True)),
            ("pool2", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm2", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv3", nn.Conv2d(256, 384, kernel_size=3, padding=1)),
            ("relu3", nn.ReLU(inplace=True)),
            ("conv4", nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2)),
            ("relu4", nn.ReLU(inplace=True)),
            ("conv5", nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2)),
            ("relu5", nn.ReLU(inplace=True)),
            ("pool5", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
        ]
        self.features = nn.Sequential(OrderedDict(feature_layers))

        # Fixes the spatial size at 6x6 so fc6 always sees 256 * 6 * 6 inputs.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        head_layers = [
            ("fc6", nn.Linear(256 * 6 * 6, 4096)),
            ("relu6", nn.ReLU(inplace=True)),
            ("drop6", nn.Dropout()),
            ("fc7", nn.Linear(4096, 4096)),
            ("relu7", nn.ReLU(inplace=True)),
            ("drop7", nn.Dropout()),
            ("fc8", nn.Linear(4096, n_classes)),
        ]
        self.classifier = nn.Sequential(OrderedDict(head_layers))

    def forward(self, x):
        """Scale the input by 57.6, then run features, pooling and classifier."""
        scaled = x * 57.6
        feats = self.avgpool(self.features(scaled))
        return self.classifier(feats.flatten(1))


class AlexNetBackbone(nn.Module):
    """Wrap a module so its input is multiplied by 57.6 before the call.

    Args:
        backbone: Module that receives the scaled input.
    """

    def __init__(self, backbone):
        super(AlexNetBackbone, self).__init__()
        self.backbone = backbone

    def forward(self, x):
        """Return ``backbone(x * 57.6)``."""
        scaled = x * 57.6
        return self.backbone(scaled)


class BasicBlock(nn.Module):
    """Pre-activation residual block with two 3x3 convolutions.

    Order of operations: BN -> ReLU -> conv1 -> BN -> ReLU -> (dropout) ->
    conv2, plus a skip connection. When the channel counts differ, the skip
    path is a 1x1 convolution applied to the pre-activated input; otherwise
    the raw input is added back.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the first convolution and of the 1x1 shortcut.
        drop_rate: Dropout probability applied before the second convolution;
            dropout is skipped when this is not greater than 0.
    """

    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.drop_rate = drop_rate
        self.is_in_equal_out = (in_planes == out_planes)
        # The projection is chosen on channel count only; stride is not checked.
        if self.is_in_equal_out:
            self.conv_shortcut = None
        else:
            self.conv_shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                           stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return the block output: skip path plus the two-conv residual path."""
        # bn1 allocates a new tensor, so the in-place ReLU leaves ``x`` intact.
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(out)
        if self.is_in_equal_out:
            skip = x
        else:
            # The projection shortcut consumes the pre-activated input.
            skip = self.conv_shortcut(activated)
        return torch.add(skip, out)


class NetworkBlock(nn.Module):
    """Sequential stack of ``nb_layers`` blocks created by the ``block`` factory.

    Args:
        nb_layers: Number of blocks in the stack.
        in_planes: Input channels of the first block.
        out_planes: Output channels of every block (and input channels of
            every block after the first).
        block: Callable invoked as ``block(in_planes, out_planes, stride,
            drop_rate)`` that returns an ``nn.Module``.
        stride: Stride of the first block; later blocks use stride 1.
        drop_rate: Dropout rate forwarded to every block.
    """

    def __init__(self,
                 nb_layers,
                 in_planes,
                 out_planes,
                 block,
                 stride,
                 drop_rate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers,
                                      stride, drop_rate)

    @staticmethod
    def _make_layer(block, in_planes, out_planes, nb_layers, stride,
                    drop_rate):
        """Build the ``nn.Sequential`` of blocks described on the class."""
        # Conditional expressions are used instead of ``cond and a or b``
        # so that a falsy ``in_planes``/``stride`` is forwarded unchanged
        # rather than silently replaced by the fallback value.
        layers = [
            block(
                in_planes if i == 0 else out_planes,
                out_planes,
                stride if i == 0 else 1,
                drop_rate,
            )
            for i in range(nb_layers)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network: a stem convolution, three ``NetworkBlock`` stages and a linear head.

    Args:
        depth: Total depth; must satisfy ``(depth - 4) % 6 == 0``. Each of
            the three stages gets ``(depth - 4) // 6`` ``BasicBlock``s.
        num_classes: Output size of the linear head ``mlp``.
        widen_factor: Multiplier applied to the stage widths 16/32/64.
        drop_rate: Dropout rate forwarded to every ``BasicBlock``.
    """

    def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0):
        super(WideResNet, self).__init__()
        n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, "depth must satisfy (depth - 4) % 6 == 0"
        blocks_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(
            3, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(blocks_per_stage, n_channels[0],
                                   n_channels[1], block, 1, drop_rate)
        # 2nd block
        self.block2 = NetworkBlock(blocks_per_stage, n_channels[1],
                                   n_channels[2], block, 2, drop_rate)
        # 3rd block
        self.block3 = NetworkBlock(blocks_per_stage, n_channels[2],
                                   n_channels[3], block, 2, drop_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(n_channels[3])
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        # Classification head used by forward(); without it forward() would
        # fail on the missing ``mlp`` attribute.
        self.mlp = nn.Linear(n_channels[3], num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init scaled by kernel area times output channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return ``(batch, num_classes)`` outputs for a 3-channel image batch."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = self.pool(out)
        out = self.flatten(out)
        return self.mlp(out)

        

def get_network(network, n_classes, instancenorm = True):
    """Build a ``(backbone, classifier)`` pair for the requested architecture.

    Args:
        network: One of ``'resnet18'``, ``'wide_resnet164'`` or
            ``'bcresnet18'``.
        n_classes: Number of outputs of the classifier head.
        instancenorm: When truthy, an affine ``nn.InstanceNorm2d`` is appended
            to the backbone of ``'resnet18'`` and ``'wide_resnet164'``.
            Ignored for ``'bcresnet18'``.

    Returns:
        Tuple ``(backbone, classifier)`` of ``nn.Module`` objects.

    Raises:
        NotImplementedError: If ``network`` is not a supported name.
    """
    if network == 'resnet18':
        # children(): conv1, bn1, relu, maxpool, layer1, layer2, layer3,
        # layer4, avgpool, fc
        stages = list(torchvision.models.resnet18(pretrained=True).children())

        # Backbone = conv1, bn1, relu, maxpool, layer1, layer2.
        backbone_layers = stages[:6]
        if instancenorm:
            # 128 matches the channel count of layer2 in torchvision's ResNet-18.
            backbone_layers.append(nn.InstanceNorm2d(128, affine=True))
        backbone = nn.Sequential(*backbone_layers)

        classifier = nn.Sequential(
            *stages[6:-1],                  # layer3, layer4, avgpool
            # Kept so the Sequential indices (and thus state-dict keys) of
            # the layers below stay unchanged.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, n_classes))

    elif network == 'wide_resnet164':
        model = WideResNet(16, num_classes=n_classes, widen_factor=4, drop_rate=0.0)

        backbone_layers = [model.conv1, model.block1]
        if instancenorm:
            # block1 outputs 16 * widen_factor = 64 channels.
            backbone_layers.append(nn.InstanceNorm2d(64, affine=True))
        backbone = nn.Sequential(*backbone_layers)

        # block3 outputs 64 * widen_factor = 256 channels.
        classifier = nn.Sequential(model.block2, model.block3,
                                   model.bn1, model.relu, model.pool,
                                   model.flatten, nn.Linear(256, n_classes))

    elif network == 'bcresnet18':
        backbone = BCResNet1(n_classes=n_classes, tau=1)
        classifier = AudioClassifier(backbone.out_dim, n_classes)

    else:
        raise NotImplementedError(f"unsupported network: {network!r}")
    return backbone, classifier

