import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


class Conv2d_fw(nn.Module):  # used in MAML to forward input with fast weight
    """Bank of ``n_tasks`` separately constructed ``nn.Conv2d`` layers.

    The layers live in ``self.m_list``. ``forward`` accepts a tensor or a
    list of tensors and always returns a list of tensors.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias,
        groups, dilation: forwarded unchanged to every ``nn.Conv2d``.
        n_tasks: number of convolution layers to create.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1, dilation=1,
                 n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        convs = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
            for _ in range(n_tasks)
        ]
        self.m_list = nn.ModuleList(convs)

    def set_n_tasks(self, n_tasks=1):
        """Grow the bank to ``n_tasks`` layers by deep-copying the first one.

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')

        missing = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # Every new layer starts as an exact copy of layer 0.
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(missing)])

    def forward(self, x):
        """Apply the bank to ``x`` and return one output tensor per pairing.

        * one input, several layers  -> the input is fed to every layer;
        * several inputs, one layer  -> the layer is applied to every input;
        * as many inputs as layers   -> layer ``k`` is applied to input ``k``.

        Raises:
            ValueError: for any other combination of counts.
        """
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)

        if n_inputs == 1 and self.n_tasks > 1:
            shared = inputs[0]
            return [conv(shared) for conv in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            conv = self.m_list[0]
            return [conv(tensor) for tensor in inputs]
        if n_inputs == self.n_tasks:
            return [conv(inputs[k]) for k, conv in enumerate(self.m_list)]
        raise ValueError('Error')

class BatchNorm2d_fw(nn.Module):  # used in MAML to forward input with fast weight
    """Bank of ``n_tasks`` separately constructed ``nn.BatchNorm2d`` layers.

    The layers live in ``self.m_list``. ``forward`` accepts a tensor or a
    list of tensors and always returns a list of tensors.

    Args:
        num_features, eps, momentum, affine, track_running_stats, device,
        dtype: forwarded unchanged to every ``nn.BatchNorm2d``.
        n_tasks: number of batch-norm layers to create.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        norms = []
        for _ in range(n_tasks):
            norms.append(nn.BatchNorm2d(num_features, eps, momentum, affine,
                                        track_running_stats, device, dtype))
        self.m_list = nn.ModuleList(norms)

    def set_n_tasks(self, n_tasks=1):
        """Grow the bank to ``n_tasks`` layers by deep-copying the first one.

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')

        extra = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # Copies include whatever state layer 0 currently holds.
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(extra)])

    def forward(self, x):
        """Normalise ``x`` and return one output tensor per pairing.

        * one input, several layers  -> the input is fed to every layer;
        * several inputs, one layer  -> the layer is applied to every input,
          in list order;
        * as many inputs as layers   -> layer ``k`` is applied to input ``k``.

        Raises:
            ValueError: for any other combination of counts.
        """
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)

        if n_inputs == 1 and self.n_tasks > 1:
            shared = inputs[0]
            return [bn(shared) for bn in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            bn = self.m_list[0]
            return [bn(tensor) for tensor in inputs]
        if n_inputs == self.n_tasks:
            return [bn(inputs[k]) for k, bn in enumerate(self.m_list)]
        raise ValueError('Error')

class ReLU_fw(nn.ReLU):
    """``nn.ReLU`` that maps over a list of tensors.

    A non-list input is treated as a one-element list; the result is always
    a list with one activated tensor per input.
    """

    def __init__(self, inplace=False):
        super().__init__(inplace)

    def forward(self, x):
        """Apply ReLU to every tensor in ``x`` and return the results as a list."""
        tensors = x if isinstance(x, list) else [x]
        # Bind the parent implementation once, outside the comprehension scope.
        relu = super().forward
        return [relu(tensor) for tensor in tensors]

class ReLU6_fw(nn.ReLU6):
    """``nn.ReLU6`` that maps over a list of tensors.

    A non-list input is treated as a one-element list; the result is always
    a list with one activated tensor per input.
    """

    def __init__(self, inplace=False):
        super().__init__(inplace)

    def forward(self, x):
        """Apply ReLU6 to every tensor in ``x`` and return the results as a list."""
        tensors = x if isinstance(x, list) else [x]
        # Bind the parent implementation once, outside the comprehension scope.
        relu6 = super().forward
        return [relu6(tensor) for tensor in tensors]

class Linear_fw(nn.Module):  # used in MAML to forward input with fast weight
    """Bank of ``n_tasks`` separately constructed ``nn.Linear`` layers.

    The layers live in ``self.m_list``. ``forward`` accepts a tensor or a
    list of tensors and always returns a list of tensors.

    Args:
        in_features, out_features: forwarded unchanged to every ``nn.Linear``.
        n_tasks: number of linear layers to create.
    """

    def __init__(self, in_features, out_features, n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        layers = []
        for _ in range(n_tasks):
            layers.append(nn.Linear(in_features, out_features))
        self.m_list = nn.ModuleList(layers)

    def set_n_tasks(self, n_tasks=1):
        """Grow the bank to ``n_tasks`` layers by deep-copying the first one.

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')

        extra = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # Every new layer starts as an exact copy of layer 0.
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(extra)])

    def forward(self, x):
        """Apply the bank to ``x`` and return one output tensor per pairing.

        * one input, several layers  -> the input is fed to every layer;
        * several inputs, one layer  -> the layer is applied to every input;
        * as many inputs as layers   -> layer ``k`` is applied to input ``k``.

        Raises:
            ValueError: for any other combination of counts.
        """
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)

        if n_inputs == 1 and self.n_tasks > 1:
            shared = inputs[0]
            return [layer(shared) for layer in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            layer = self.m_list[0]
            return [layer(tensor) for tensor in inputs]
        if n_inputs == self.n_tasks:
            return [layer(inputs[idx]) for idx, layer in enumerate(self.m_list)]
        raise ValueError('Error')

class Identity_fw(nn.Identity):
    """``nn.Identity`` that maps over a list of tensors.

    A non-list input is treated as a one-element list; the result is always
    a new list holding the inputs passed through ``nn.Identity``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the elements of ``x`` (wrapped if needed) as a fresh list."""
        items = x if isinstance(x, list) else [x]
        # Bind the parent implementation once, outside the comprehension scope.
        passthrough = super().forward
        return [passthrough(item) for item in items]

class Dropout_fw(nn.Dropout):
    """``nn.Dropout`` that maps over a list of tensors.

    A non-list input is treated as a one-element list; dropout is applied to
    each element separately, in list order, and the results are returned as
    a list.
    """

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, x):
        """Apply dropout to every tensor in ``x`` and return the results as a list."""
        tensors = x if isinstance(x, list) else [x]
        # Bind the parent implementation once, outside the comprehension scope.
        drop = super().forward
        return [drop(tensor) for tensor in tensors]

class AdaptiveAvgPool2d_fw(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` that maps over a list of tensors.

    A non-list input is treated as a one-element list; the result is always
    a list with one pooled tensor per input.
    """

    def __init__(self, output_size):
        super().__init__(output_size)

    def forward(self, x):
        """Pool every tensor in ``x`` to ``output_size`` and return a list."""
        tensors = x if isinstance(x, list) else [x]
        # Bind the parent implementation once, outside the comprehension scope.
        pool = super().forward
        return [pool(tensor) for tensor in tensors]

def cat_fw(x, y, dim=0):
    """Concatenate ``x`` and ``y`` pairwise along ``dim``, returning a list.

    Each argument may be a tensor or a list of tensors; a bare tensor is
    treated as a one-element list. A one-element side is reused against
    every element of the other side; otherwise both sides must have the
    same length and are paired index by index. ``x`` always comes first in
    each concatenation.

    Raises:
        ValueError: if the two lengths differ and neither is 1.
    """
    left = x if isinstance(x, list) else [x]
    right = y if isinstance(y, list) else [y]
    n_left, n_right = len(left), len(right)

    if n_left == 1 and n_right > 1:
        head = left[0]
        return [torch.cat((head, tail), dim=dim) for tail in right]
    if n_left > 1 and n_right == 1:
        tail = right[0]
        return [torch.cat((head, tail), dim=dim) for head in left]
    if n_left == n_right:
        return [torch.cat(pair, dim=dim) for pair in zip(left, right)]
    raise ValueError('Error')

def interpolate_fw(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    """Run ``F.interpolate`` on a tensor or on every tensor of a list.

    A non-list ``input`` is treated as a one-element list. All remaining
    arguments are forwarded unchanged to ``F.interpolate`` for each element.

    Returns:
        A list with one resized tensor per input element.
    """
    tensors = input if isinstance(input, list) else [input]
    return [
        F.interpolate(tensor,
                      size=size,
                      scale_factor=scale_factor,
                      mode=mode,
                      align_corners=align_corners,
                      recompute_scale_factor=recompute_scale_factor,
                      antialias=antialias)
        for tensor in tensors
    ]

def clone_fw(x):
    """Return a list holding ``.clone()`` of every element of ``x``.

    A non-list ``x`` is treated as a one-element list.
    """
    sources = x if isinstance(x, list) else [x]

    copies = []
    for tensor in sources:
        copies.append(tensor.clone())
    return copies

def activation_func_fw(activation):
    """Return a new list-aware activation module selected by name.

    Supported names are ``'relu'`` (in-place ``ReLU_fw``), ``'relu6'``
    (in-place ``ReLU6_fw``) and ``'none'`` (``Identity_fw``).

    Raises:
        KeyError: if ``activation`` is not one of the supported names.
    """
    builders = {
        'relu': lambda: ReLU_fw(inplace=True),
        'relu6': lambda: ReLU6_fw(inplace=True),
        'none': Identity_fw,
    }
    build = builders[activation]
    return build()

class ConvBNReLU_fw(nn.Module):
    """Bias-free ``Conv2d_fw`` -> ``BatchNorm2d_fw`` -> activation, as one unit.

    The three stages are stored in order in ``self.op`` (an ``nn.Sequential``).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups: forwarded to ``Conv2d_fw``.
        affine: forwarded to ``BatchNorm2d_fw``.
        activation: name understood by ``activation_func_fw``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, affine=True, activation='relu'):
        super().__init__()
        # Stages are created in pipeline order: conv, then norm, then activation.
        conv = Conv2d_fw(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=False)
        norm = BatchNorm2d_fw(out_channels, affine=affine)
        act = activation_func_fw(activation)
        self.op = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through ``self.op`` and return its result."""
        result = self.op(x)
        return result

class InvertedResidual_fw(nn.Module):
    """Inverted-residual block built from list-aware ``ConvBNReLU_fw`` stages.

    ``self.chain`` contains, in order: an optional 1x1 expansion stage (only
    when ``expansion != 1``), a depthwise ``kernel_size`` stage
    (``groups=hidden_dim``), and a 1x1 projection stage without activation.

    Args:
        in_channels: channels of the input tensors.
        out_channels: channels produced by the projection stage.
        stride: 1 or 2. If 2 is requested together with ``dilation > 1``,
            the block uses stride 1 and half the dilation instead.
        expansion: multiplier giving ``hidden_dim = round(in_channels * expansion)``.
        kernel_size: depthwise kernel size; one of 1, 3, 5, 7.
        groups: groups for the two 1x1 stages.
        dilation: dilation of the depthwise stage.
        skip_connect: allow the residual connection.
        final_affine: ``affine`` flag of the projection stage's batch norm.
        activation: activation name for the expansion and depthwise stages.
    """

    def __init__(self, in_channels, out_channels, stride, expansion, kernel_size=3, groups=1,
                 dilation=1, skip_connect=True, final_affine=True, activation='relu'):
        super().__init__()
        assert kernel_size in [1, 3, 5, 7]
        assert stride in [1, 2]
        if stride == 2 and dilation > 1:
            stride = 1
            dilation = dilation // 2
        padding = int((kernel_size - 1) * dilation / 2)
        hidden_dim = round(in_channels * expansion)

        # Collect the stages in a local list so ``self.chain`` is only ever
        # assigned a registered module.
        layers = []
        if expansion != 1:
            layers.append(ConvBNReLU_fw(in_channels,
                                        hidden_dim,
                                        1,
                                        stride=1,
                                        padding=0,
                                        groups=groups,
                                        activation=activation))
        layers.extend([
            ConvBNReLU_fw(hidden_dim,
                          hidden_dim,
                          kernel_size,
                          stride=stride,
                          padding=padding,
                          groups=hidden_dim,
                          dilation=dilation,
                          activation=activation),
            ConvBNReLU_fw(hidden_dim,
                          out_channels,
                          1,
                          stride=1,
                          padding=0,
                          groups=groups,
                          affine=final_affine,
                          activation='none')])
        self.chain = nn.Sequential(*layers)

        # The residual is only added when input and output shapes can match;
        # note that ``stride`` here is the effective stride after the
        # dilation adjustment above.
        self.res_flag = bool(skip_connect and in_channels == out_channels and stride == 1)

    def forward(self, x):
        """Run the block on a tensor or list of tensors and return a list.

        When the residual connection is active, inputs are added to outputs:
        index by index if the counts match, or with the single side reused
        against every element of the other side.

        Raises:
            ValueError: if the residual is active and the input and output
                counts differ while neither is 1.
        """
        # Normalise to a list first: ``len()`` of a bare tensor is its batch
        # size, which must not be mistaken for a number of task inputs.
        identity = x if isinstance(x, list) else [x]
        out = self.chain(x)
        if not self.res_flag:
            return out

        n_identity, n_out = len(identity), len(out)
        if n_identity == n_out:
            for i, skip in enumerate(identity):
                out[i] += skip
        elif n_identity == 1:
            # One shared input, several task outputs: every output gets the
            # residual, not just the first one.
            skip = identity[0]
            for i in range(n_out):
                out[i] += skip
        elif n_out == 1:
            out = [out[0] + skip for skip in identity]
        else:
            raise ValueError(
                'Cannot add residual: got {} inputs but {} outputs'.format(n_identity, n_out))

        return out

class RASPP_fw(nn.Module):
    """Two-branch pooling head built from list-aware ``*_fw`` layers.

    Branch 1 is a 1x1 ``ConvBNReLU_fw``. Branch 2 pools each input to 1x1,
    applies a 1x1 ``ConvBNReLU_fw`` and is resized back to the input's
    spatial size. The branches are concatenated on the channel axis,
    optionally passed through dropout, and projected by a final 1x1
    ``ConvBNReLU_fw``.

    Args:
        in_channels: channels of the input tensors.
        out_channels: channels produced by each branch and by the projection.
        activation: activation name used by all three conv stages.
        drop_rate: dropout probability; dropout is skipped when it is not > 0.
        final_affine: ``affine`` flag of the projection stage's batch norm.
    """

    def __init__(self, in_channels, out_channels, activation='relu6',
                 drop_rate=0, final_affine=True):
        super().__init__()

        self.drop_rate = drop_rate

        # 1x1 convolution
        self.aspp_branch_1 = ConvBNReLU_fw(in_channels,
                                           out_channels,
                                           kernel_size=1,
                                           stride=1,
                                           activation=activation)
        # image pooling feature
        self.aspp_branch_2 = nn.Sequential(
            AdaptiveAvgPool2d_fw(output_size=(1, 1)),
            ConvBNReLU_fw(in_channels, out_channels, kernel_size=1, stride=1,
                          activation=activation))

        self.aspp_projection = ConvBNReLU_fw(2 * out_channels, out_channels, kernel_size=1, stride=1,
                                             activation=activation, affine=final_affine)

        self.dropout = Dropout_fw(p=self.drop_rate)

    def forward(self, x):
        """Run the head on a tensor or list of tensors and return a list.

        The target size for the pooled branch is taken from the last two
        dimensions of the first input tensor.
        """
        # ``x`` may be a bare tensor (every sub-layer accepts one), in which
        # case ``x[0]`` would be a single sample rather than the first input.
        reference = x[0] if isinstance(x, list) else x
        h, w = reference.shape[-2:]

        branch_1 = self.aspp_branch_1(x)
        branch_2 = self.aspp_branch_2(x)
        branch_2 = interpolate_fw(input=branch_2, size=(h, w),
                                  mode='bilinear', align_corners=False)

        # Concatenate the parallel streams
        out = cat_fw(branch_1, branch_2, dim=1)

        if self.drop_rate > 0:
            out = self.dropout(out)

        out = self.aspp_projection(out)

        return out

class DeepLabV3PlusDecoder_fw(nn.Module):
    """Decoder that fuses low-level features with upsampled high-level ones.

    Low-level features are reduced to 48 channels by a 1x1 ``ConvBNReLU_fw``,
    concatenated with the resized high-level features, refined by two conv
    stages and mapped to ``num_outputs`` channels by a 1x1 ``Conv2d_fw``.

    Args:
        num_outputs: output channels of the final 1x1 convolution.
        in_channels_low: channels of the low-level feature map(s).
        out_f_classifier: channels of the high-level feature map(s) and of
            the two refinement stages.
        use_separable: if true, each refinement stage is a depthwise 3x3
            stage followed by a 1x1 stage; otherwise a single 3x3 stage.
        activation: activation name used by every ``ConvBNReLU_fw`` stage.
    """

    def __init__(self,
                 num_outputs,
                 in_channels_low=256,
                 out_f_classifier=256,
                 use_separable=False,
                 activation='relu'):

        super().__init__()

        projected_filters = 48
        # Channel count after concatenating reduced low-level and high-level features.
        fused_filters = out_f_classifier + projected_filters

        self.low_level_reduce = ConvBNReLU_fw(in_channels_low,
                                              projected_filters,
                                              kernel_size=1,
                                              stride=1,
                                              activation=activation)

        if use_separable:
            self.conv_1 = nn.Sequential(ConvBNReLU_fw(fused_filters,
                                                      fused_filters,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1,
                                                      groups=fused_filters,
                                                      activation=activation),
                                        ConvBNReLU_fw(fused_filters,
                                                      out_f_classifier,
                                                      kernel_size=1,
                                                      stride=1,
                                                      activation=activation))
            self.conv_2 = nn.Sequential(ConvBNReLU_fw(out_f_classifier,
                                                      out_f_classifier,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1,
                                                      groups=out_f_classifier,
                                                      activation=activation),
                                        ConvBNReLU_fw(out_f_classifier,
                                                      out_f_classifier,
                                                      kernel_size=1,
                                                      stride=1,
                                                      activation=activation))
        else:
            self.conv_1 = ConvBNReLU_fw(fused_filters,
                                        out_f_classifier,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1,
                                        activation=activation)
            self.conv_2 = ConvBNReLU_fw(out_f_classifier,
                                        out_f_classifier,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1,
                                        activation=activation)

        self.conv_logits = Conv2d_fw(out_f_classifier,
                                     num_outputs,
                                     kernel_size=1,
                                     bias=True)

    def forward(self, x, x_low, input_shape):
        """Decode high-level features ``x`` using low-level features ``x_low``.

        Args:
            x: tensor or list of tensors with the high-level features.
            x_low: tensor or list of tensors with the low-level features.
            input_shape: spatial size the final output is resized to.

        Returns:
            A list of tensors resized to ``input_shape``.
        """
        # ``x_low`` may be a list (every layer below accepts one); take the
        # decoder resolution from its first element in that case.
        low_reference = x_low[0] if isinstance(x_low, list) else x_low
        decoder_height, decoder_width = low_reference.shape[-2:]
        x = interpolate_fw(x, size=(decoder_height, decoder_width),
                           mode='bilinear',
                           align_corners=False)

        x_low = self.low_level_reduce(x_low)
        x = cat_fw(x_low, x, dim=1)
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_logits(x)
        x = interpolate_fw(x, size=input_shape, mode='bilinear',
                           align_corners=False)
        return x
