import numpy as np
import torch.nn as nn
from src.AIDomains.concrete_layers import Normalization
from dsnt_lib import dsnt, faster_dsnt, flat_softmax, fast_dsnt_cov

def init_weights(net):
    """Initialise the conv, transposed-conv and batch-norm layers of ``net`` in place.

    Every module yielded by ``net.modules()`` (``net`` itself included) is
    visited and initialised as follows:

    * ``nn.ConvTranspose2d``: weight ~ N(0, 0.001**2); bias is left untouched.
    * ``nn.BatchNorm2d``: weight = 1, bias = 0.
    * ``nn.Conv2d``: weight ~ N(0, 0.001**2), bias = 0.

    Parameters that do not exist (``bias=False`` convolutions, ``affine=False``
    batch norms) are skipped instead of raising.

    Because this walks the whole sub-tree itself, calling it through
    ``Module.apply`` also works but re-initialises nested layers more than once.

    Args:
        net: any ``nn.Module``.
    """
    for m in net.modules():
        if isinstance(m, nn.ConvTranspose2d):
            nn.init.normal_(m.weight, std=0.001)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False batch norms have no weight/bias to initialise.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, std=0.001)
            # Conv2d(..., bias=False) has bias None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

class UpsamplingHead(nn.Module):
    """Stack of 2x upsampling stages, each followed by BatchNorm2d and ReLU.

    Two flavours are supported:

    * ``'deconv'``: one stride-2 ``ConvTranspose2d`` (no bias) per entry of
      ``kernels``.
    * ``'nearest'``: one 3x3 ``Conv2d`` followed by ``nn.Upsample`` (scale 2)
      per entry of ``n_channels``.

    The built stack is stored in ``self._net``.

    Args:
        upsampling: ``'nearest'`` or ``'deconv'``.
        n_channels: output channels of each stage. Defaults to
            ``[256, 256, 256]``.
        kernels: transposed-conv kernel size per stage (``'deconv'`` only).
            Each must be 2, 3 or 4. Defaults to ``[4, 4, 4]``.
        in_channels: channel count of the incoming feature map.

    Raises:
        NotImplementedError: if ``upsampling`` is not a supported mode.
        ValueError: if a deconv kernel size is not 2, 3 or 4.
    """

    # (padding, output_padding) that make a stride-2 ConvTranspose2d exactly
    # double the spatial size:
    # out = (in - 1) * 2 - 2 * padding + kernel + output_padding == 2 * in
    _DECONV_PADDING = {4: (1, 0), 3: (1, 1), 2: (0, 0)}

    def __init__(self,
                 upsampling='nearest',
                 n_channels=None,
                 kernels=None,
                 in_channels=2048):
        super().__init__()
        # None sentinels: a list default would be shared by every instance.
        if n_channels is None:
            n_channels = [256, 256, 256]
        if kernels is None:
            kernels = [4, 4, 4]
        self.upsampling = upsampling
        self.n_channels = n_channels
        self.kernels = kernels
        self.in_channels = in_channels
        self.out_channels = n_channels[-1]
        if self.upsampling == 'nearest':
            self._net = self._build_upsample_head()
        elif self.upsampling == 'deconv':
            self._net = self._build_deconv_head()
        else:
            raise NotImplementedError(
                f"Unsupported upsampling mode {upsampling!r}; "
                "expected 'nearest' or 'deconv'")

    def _stage_io(self, i):
        """Return ``(in_channels, out_channels)`` of upsampling stage ``i``."""
        in_c = self.in_channels if i == 0 else self.n_channels[i - 1]
        return in_c, self.n_channels[i]

    def _build_deconv_head(self):
        """Build one ConvTranspose2d -> BatchNorm2d -> ReLU stage per kernel size."""
        layers = []
        for i, k in enumerate(self.kernels):
            try:
                padding, out_padding = self._DECONV_PADDING[k]
            except KeyError:
                # Previously an unsupported k either hit an unbound local or
                # silently reused the previous stage's padding.
                raise ValueError(
                    f"Unsupported deconv kernel size {k!r}; expected one of "
                    f"{sorted(self._DECONV_PADDING)}") from None
            in_c, out_c = self._stage_io(i)
            layers.append(nn.ConvTranspose2d(
                in_c, out_c, k, 2, padding, out_padding, bias=False))
            layers.append(nn.BatchNorm2d(out_c, momentum=0.1))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def _build_upsample_head(self,
                             scale_factor=2,
                             interpolate='nearest',
                             align_corners=False):
        """Build one Conv2d(3x3) -> BatchNorm2d -> ReLU -> Upsample stage per channel entry.

        ``align_corners`` is only forwarded when ``interpolate`` is not
        ``'nearest'``.
        """
        layers = []
        for i in range(len(self.n_channels)):
            in_c, out_c = self._stage_io(i)
            layers.append(nn.Conv2d(in_c, out_c, 3, 1, 1))
            layers.append(nn.BatchNorm2d(out_c, momentum=0.1))
            layers.append(nn.ReLU(inplace=True))
            if interpolate == 'nearest':
                layers.append(nn.Upsample(scale_factor=scale_factor))
            else:
                layers.append(nn.Upsample(scale_factor=scale_factor,
                                          mode=interpolate,
                                          align_corners=align_corners))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self._net(x)

class myPoseNet(nn.Module):
    """Convolutional trunk plus upsampling head that outputs one map per keypoint.

    Layer order in ``self.blocks``:
    ``Normalization -> [Conv2d (-> BatchNorm2d) (-> MaxPool2d(2)) -> ReLU] * len(conv_widths)
    -> UpsamplingHead stack -> 1x1 Conv2d (num_keypoints channels)``.

    Args:
        device: not used by this class.
        dataset: selects normalisation statistics for ``fashionmnist``,
            ``cifar10``, ``tinyimagenet`` and ``boeing``. Any other name falls
            back to the ``mean`` / ``sigma`` arguments.
        n_class, linear_sizes, depth_conv, pool, bn2: stored as attributes
            (where applicable) but not used to build layers.
        input_size, input_channel: spatial size and channel count of the input.
        num_keypoints: number of output channels of the final 1x1 conv.
        upsamp_channels: output channels of each of the 5 upsampling stages.
            Defaults to ``[128, 128, 128, 64, 64]``.
        up_samp: True for nearest-neighbour upsampling, False for transposed
            convolutions.
        conv_widths, kernel_sizes, paddings, strides, dilations: per-conv
            hyper-parameters. A list whose length differs from ``conv_widths``
            is replaced by its first element repeated.
        net_dim: spatial size used for ``self.dims`` bookkeeping (defaults to
            ``input_size``).
        bn: add BatchNorm2d after each conv.
        max: add MaxPool2d(2) after every conv except the fifth.
        scale_width: multiply every conv width by 16.
        mean, sigma: fallback normalisation statistics.

    Raises:
        ValueError: if ``upsamp_channels`` does not have exactly 5 entries.
    """

    def __init__(self, device, dataset, n_class=10, input_size=32, input_channel=3, num_keypoints=24, upsamp_channels=None, up_samp=False, conv_widths=None,
                 kernel_sizes=None, linear_sizes=None, depth_conv=None, paddings=None, strides=None,
                 dilations=None, pool=False, net_dim=None, bn=True, bn2=False, max=True, scale_width=True, mean=0, sigma=1):
        super(myPoseNet, self).__init__()
        # None sentinel: a list default would be shared between instances.
        if upsamp_channels is None:
            upsamp_channels = [128, 128, 128, 64, 64]
        if kernel_sizes is None:
            kernel_sizes = [3]
        if conv_widths is None:
            conv_widths = [2]
        if linear_sizes is None:
            linear_sizes = [200]
        if paddings is None:
            paddings = [1]
        if strides is None:
            strides = [2]
        if dilations is None:
            dilations = [1]
        if net_dim is None:
            net_dim = input_size

        # Validate before building anything (an assert would vanish under -O).
        n_layers_upsamp = 5
        if len(upsamp_channels) != n_layers_upsamp:
            raise ValueError('Length of upsamp_channels must match n_layers_upsamp')

        # Broadcast single-entry hyper-parameter lists to one entry per conv.
        if len(conv_widths) != len(kernel_sizes):
            kernel_sizes = len(conv_widths) * [kernel_sizes[0]]
        if len(conv_widths) != len(paddings):
            paddings = len(conv_widths) * [paddings[0]]
        if len(conv_widths) != len(strides):
            strides = len(conv_widths) * [strides[0]]
        if len(conv_widths) != len(dilations):
            dilations = len(conv_widths) * [dilations[0]]

        self.n_class=n_class
        self.input_size=input_size
        self.input_channel=input_channel
        self.conv_widths=conv_widths
        self.kernel_sizes=kernel_sizes
        self.paddings=paddings
        self.strides=strides
        self.dilations = dilations
        self.linear_sizes=linear_sizes
        self.depth_conv=depth_conv
        self.net_dim = net_dim
        self.bn=bn
        self.bn2=bn2
        self.max=max
        self.up_samp=up_samp
        self.num_keypoints=num_keypoints

        if dataset == "fashionmnist":
            mean = 0.1307
            sigma = 0.3081
        elif dataset == "cifar10":
            mean = [0.4914, 0.4822, 0.4465]
            sigma = [0.2023, 0.1994, 0.2010]
        elif dataset == "tinyimagenet":
            mean = [0.4802, 0.4481, 0.3975]
            sigma = [0.2302, 0.2265, 0.2262]
        elif dataset == "boeing":
            mean = [0.485, 0.456, 0.406]
            sigma = [0.299, 0.224, 0.225]

        layers = []
        layers += [Normalization((input_channel,input_size,input_size),mean, sigma)]

        N = net_dim
        n_channels = input_channel
        self.dims = [(n_channels,N,N)]

        conv_params = zip(conv_widths, kernel_sizes, paddings, strides, dilations)
        for conv_index, (width, kernel_size, padding, stride, dilation) in enumerate(conv_params, start=1):
            if scale_width:
                width *= 16
            width = int(width)
            N = int(np.floor((N + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
            layers += [nn.Conv2d(n_channels, width, kernel_size, stride=stride, padding=padding, dilation=dilation)]
            self.dims += [(width, N, N)]
            if self.bn:
                layers += [nn.BatchNorm2d(width)]
            # Every conv except the fifth is followed by 2x2 max-pooling.
            if self.max and conv_index != 5:
                layers += [nn.MaxPool2d(2)]
                N //= 2  # keep the bookkeeping size in sync with the pooled output
            layers += [nn.ReLU()]
            n_channels = width
            self.dims += [(n_channels, N, N)]

        upsampling = 'nearest' if self.up_samp else 'deconv'
        # NOTE CMU edit: kernel 2 so the transposed convs use padding 0.
        deconv_kernels = [2] * n_layers_upsamp
        # The head consumes the trunk output, so its input width must be the
        # last conv width (this was hard-coded to 128 before).
        up_head = UpsamplingHead(
            upsampling=upsampling,
            n_channels=upsamp_channels,
            kernels=deconv_kernels,
            in_channels=n_channels)._net
        up_head.apply(init_weights)
        layers += [up_head]

        final_layer = nn.Conv2d(
            in_channels=upsamp_channels[-1],
            out_channels=self.num_keypoints,
            kernel_size=1,
            stride=1,
            padding=0
        )
        final_layer.apply(init_weights)
        layers += [final_layer]

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)

class MyPoseNetDSNT(myPoseNet):
    """``myPoseNet`` whose forward pass adds a DSNT coordinate read-out.

    The layers (conv trunk, upsampling head, 1x1 keypoint conv) are built by
    ``myPoseNet``. This subclass only changes ``forward``: the per-keypoint
    maps are normalised with ``flat_softmax`` and converted to coordinates
    with ``faster_dsnt`` (both from ``dsnt_lib``). The deconv head design
    borrows from the Microsoft Human Pose Estimation project.
    """

    def forward(self, x):
        """Return ``(heatmaps, coords)`` for the input batch ``x``.

        ``heatmaps`` is ``flat_softmax`` applied to the network output.
        ``coords`` is ``faster_dsnt(heatmaps)``. Exact shapes and coordinate
        range come from ``dsnt_lib``; presumably (B, K, H, W) and (B, K, 2).
        TODO confirm.
        """
        x = self.blocks(x)
        # DSNT step 1: normalise the heatmaps.
        heatmaps = flat_softmax(x)
        # DSNT step 2: turn the normalised heatmaps into coordinates.
        coords = faster_dsnt(heatmaps)
        return heatmaps, coords



class myNet(nn.Module):
    """Configurable classifier: normalisation, conv stack, optional pooling, MLP.

    Layer order in ``self.blocks``:
    ``Normalization -> [Conv2d (-> BatchNorm2d) (-> MaxPool2d(2)) -> ReLU] * len(conv_widths)
    (-> 1x1 Conv2d -> ReLU) (-> AdaptiveAvgPool2d(1)) -> Flatten
    -> [Linear (-> BatchNorm1d) -> ReLU] per non-zero linear size -> Linear(n_class)``.

    ``self.dims`` records shapes along the network. It holds the input shape,
    two entries per conv / linear stage (apparently the affine layer's output
    and the activation's output), one entry after ``Flatten`` and one for the
    logits.

    Args:
        device: not used by this class.
        dataset: selects normalisation statistics for ``fashionmnist``,
            ``cifar10`` and ``tinyimagenet``. Any other name falls back to the
            ``mean`` / ``sigma`` arguments.
        n_class: number of output logits.
        input_size, input_channel: spatial size and channel count of the input.
        conv_widths, kernel_sizes, paddings, strides, dilations: per-conv
            hyper-parameters. A list whose length differs from ``conv_widths``
            is replaced by its first element repeated.
        linear_sizes: hidden Linear widths. Zeros are skipped.
        depth_conv: if given, add a 1x1 conv with this many output channels.
        pool: add a global average pool before flattening.
        net_dim: spatial size used for shape bookkeeping (defaults to
            ``input_size``).
        bn, bn2: add BatchNorm after convs / hidden linears.
        max: add 2x2 max-pooling after each conv.
        scale_width: multiply every conv width by 16.
        mean, sigma: fallback normalisation statistics.
    """

    def __init__(self, device, dataset, n_class=10, input_size=32, input_channel=3, conv_widths=None,
                 kernel_sizes=None, linear_sizes=None, depth_conv=None, paddings=None, strides=None,
                 dilations=None, pool=False, net_dim=None, bn=False, bn2=False, max=False, scale_width=True, mean=0, sigma=1):
        super(myNet, self).__init__()
        if kernel_sizes is None:
            kernel_sizes = [3]
        if conv_widths is None:
            conv_widths = [2]
        if linear_sizes is None:
            linear_sizes = [200]
        if paddings is None:
            paddings = [1]
        if strides is None:
            strides = [2]
        if dilations is None:
            dilations = [1]
        if net_dim is None:
            net_dim = input_size

        # Broadcast single-entry hyper-parameter lists to one entry per conv.
        if len(conv_widths) != len(kernel_sizes):
            kernel_sizes = len(conv_widths) * [kernel_sizes[0]]
        if len(conv_widths) != len(paddings):
            paddings = len(conv_widths) * [paddings[0]]
        if len(conv_widths) != len(strides):
            strides = len(conv_widths) * [strides[0]]
        if len(conv_widths) != len(dilations):
            dilations = len(conv_widths) * [dilations[0]]

        self.n_class=n_class
        self.input_size=input_size
        self.input_channel=input_channel
        self.conv_widths=conv_widths
        self.kernel_sizes=kernel_sizes
        self.paddings=paddings
        self.strides=strides
        self.dilations = dilations
        self.linear_sizes=linear_sizes
        self.depth_conv=depth_conv
        self.net_dim = net_dim
        self.bn=bn
        self.bn2=bn2
        self.max=max

        if dataset == "fashionmnist":
            mean = 0.1307
            sigma = 0.3081
        elif dataset == "cifar10":
            mean = [0.4914, 0.4822, 0.4465]
            sigma = [0.2023, 0.1994, 0.2010]
        elif dataset == "tinyimagenet":
            mean = [0.4802, 0.4481, 0.3975]
            sigma = [0.2302, 0.2265, 0.2262]

        layers = []
        layers += [Normalization((input_channel,input_size,input_size),mean, sigma)]

        N = net_dim
        n_channels = input_channel
        self.dims = [(n_channels,N,N)]

        # NOTE: nn.ReLU's only argument is `inplace`. The shape / width values
        # that used to be passed were silently read as inplace=True.
        for width, kernel_size, padding, stride, dilation in zip(conv_widths, kernel_sizes, paddings, strides, dilations):
            if scale_width:
                width *= 16
            width = int(width)
            N = int(np.floor((N + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
            layers += [nn.Conv2d(n_channels, width, kernel_size, stride=stride, padding=padding, dilation=dilation)]
            self.dims += [(width, N, N)]
            if self.bn:
                layers += [nn.BatchNorm2d(width)]
            if self.max:
                # 2x2 pooling (as in myPoseNet). The old MaxPool2d(width) used
                # the channel count as the kernel size.
                layers += [nn.MaxPool2d(2)]
                N //= 2
            layers += [nn.ReLU()]
            n_channels = width
            self.dims += [(n_channels, N, N)]

        if depth_conv is not None:
            layers += [nn.Conv2d(n_channels, depth_conv, 1, stride=1, padding=0),
                       nn.ReLU()]
            n_channels = depth_conv
            self.dims += 2*[(n_channels,N,N)]

        if pool:
            # torch.nn has no GlobalAvgPool2d; AdaptiveAvgPool2d(1) averages
            # each channel over the whole spatial extent.
            layers += [nn.AdaptiveAvgPool2d(1)]
            self.dims += 2 * [(n_channels, 1, 1)]
            N=1

        layers += [nn.Flatten()]
        N = n_channels * N ** 2
        self.dims += [(N,)]

        for width in linear_sizes:
            if width == 0:
                continue
            layers += [nn.Linear(int(N), int(width))]
            if self.bn2:
                layers += [nn.BatchNorm1d(int(width))]
            layers += [nn.ReLU()]
            N = width
            self.dims+=2*[(N,)]

        layers += [nn.Linear(int(N), n_class)]
        self.dims+=[(n_class,)]

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)


class FFNN(myNet):
    """Fully-connected network: a ``myNet`` with no convolutional layers.

    ``sizes`` is passed as ``linear_sizes``, i.e. the hidden Linear widths.
    """

    def __init__(self, device, dataset, sizes, n_class=10, input_size=32, input_channel=3, net_dim=None):
        super().__init__(
            device,
            dataset,
            n_class=n_class,
            input_size=input_size,
            input_channel=input_channel,
            conv_widths=[],
            linear_sizes=sizes,
            net_dim=net_dim,
        )


def ConvMed_tiny(dataset, bn=False, bn2=False, device="cuda"):
    """Build a small two-conv ``myNet`` with one 50-unit hidden layer.

    Conv widths are 1 and 2 before ``myNet``'s width scaling, with kernels 5
    and 4, stride 2 and padding 1.
    """
    in_ch, in_dim, n_class = get_dataset_info(dataset)
    return myNet(
        device,
        dataset,
        n_class=n_class,
        input_size=in_dim,
        input_channel=in_ch,
        conv_widths=[1, 2],
        kernel_sizes=[5, 4],
        strides=[2, 2],
        paddings=[1, 1],
        linear_sizes=[50],
        net_dim=None,
        bn=bn,
        bn2=bn2,
    )


def CNN7(dataset, bn, bn2, device="cuda"):
    """Build the standard CNN7 ``myNet``.

    It has five 3x3 convs (widths 4, 4, 8, 8, 8 before scaling) with stride 2
    on the third, followed by a 512-unit hidden layer.
    """
    in_ch, in_dim, n_class = get_dataset_info(dataset)
    n_convs = 5
    return myNet(
        device,
        dataset,
        n_class=n_class,
        input_size=in_dim,
        input_channel=in_ch,
        conv_widths=[4, 4, 8, 8, 8],
        kernel_sizes=[3] * n_convs,
        paddings=[1] * n_convs,
        strides=[1, 1, 2, 1, 1],
        linear_sizes=[512],
        net_dim=None,
        bn=bn,
        bn2=bn2,
    )

def CNN7_pose(dataset, num_keypoints, upsamp_channels, up_samp, bn=True, bn2=True,device="cuda"):
    """Build the CNN7-style keypoint network (``MyPoseNetDSNT``) for ``dataset``.

    Args:
        dataset: name understood by ``get_dataset_info``.
        num_keypoints: number of output keypoint maps.
        upsamp_channels: per-stage channels of the upsampling head.
        up_samp: True for nearest-neighbour upsampling, False for transposed
            convolutions.
        bn, bn2: passed through to the network constructor.
        device: passed through to the network constructor.
    """
    in_ch, in_dim, n_class = get_dataset_info(dataset)
    return MyPoseNetDSNT(device, dataset, n_class, in_dim, in_ch, num_keypoints, upsamp_channels,
                         # up_samp used to be dropped here, so the head was always 'deconv'.
                         up_samp=up_samp,
                         conv_widths=[4, 4, 8, 8, 8], kernel_sizes=[3, 3, 3, 3, 3],
                         linear_sizes=[512], strides=[2, 1, 1, 1, 1], paddings=[1, 1, 1, 1, 1],
                         net_dim=None, bn=bn, bn2=bn2)


def CNN7_narrow(dataset, bn, bn2, device="cuda"):
    """Build the narrow CNN7 ``myNet``.

    It has five 3x3 convs (widths 2, 2, 4, 4, 4 before scaling) with stride 2
    on the third, followed by a 216-unit hidden layer.
    """
    in_ch, in_dim, n_class = get_dataset_info(dataset)
    n_convs = 5
    return myNet(
        device,
        dataset,
        n_class=n_class,
        input_size=in_dim,
        input_channel=in_ch,
        conv_widths=[2, 2, 4, 4, 4],
        kernel_sizes=[3] * n_convs,
        paddings=[1] * n_convs,
        strides=[1, 1, 2, 1, 1],
        linear_sizes=[216],
        net_dim=None,
        bn=bn,
        bn2=bn2,
    )


def CNN7_wide(dataset, bn, bn2, device="cuda"):
    """Build the wide CNN7 ``myNet``.

    It has five 3x3 convs (widths 6, 6, 12, 12, 12 before scaling) with stride
    2 on the third, followed by a 512-unit hidden layer.
    """
    in_ch, in_dim, n_class = get_dataset_info(dataset)
    n_convs = 5
    return myNet(
        device,
        dataset,
        n_class=n_class,
        input_size=in_dim,
        input_channel=in_ch,
        conv_widths=[6, 6, 12, 12, 12],
        kernel_sizes=[3] * n_convs,
        paddings=[1] * n_convs,
        strides=[1, 1, 2, 1, 1],
        linear_sizes=[512],
        net_dim=None,
        bn=bn,
        bn2=bn2,
    )


def get_dataset_info(dataset):
    """Return ``(input_channels, input_size, num_classes)`` for a dataset name.

    Raises:
        ValueError: if ``dataset`` is not a known dataset name.
    """
    known = {
        "mnist": (1, 28, 10),
        "emnist": (1, 28, 10),
        "fashionmnist": (1, 28, 10),
        "svhn": (3, 32, 10),
        "cifar10": (3, 32, 10),
        "tinyimagenet": (3, 56, 200),
        "boeing": (3, 256, 10),
    }
    # Compare with == in declaration order rather than a hashed lookup, so
    # that any non-matching input (even an unhashable one) reaches the
    # ValueError below.
    for name, info in known.items():
        if dataset == name:
            return info
    raise ValueError(f"Dataset {dataset} not available")


# Registry: model name -> factory function building that architecture.
Models = dict(
    ConvMed_tiny=ConvMed_tiny,
    CNN7=CNN7,
    CNN7_narrow=CNN7_narrow,
    CNN7_wide=CNN7_wide,
    CNN7_pose=CNN7_pose,
)
