import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


class Conv2d_fw_v2(nn.Module):  # used in MAML to forward input with fast weight
    """Bank of identically configured ``nn.Conv2d`` layers, one per task group.

    ``mapping`` is a sequence of task-id groups; group ``i`` is served by
    ``self.m_list[i]``. Inputs and outputs of ``forward`` are dicts of tensors
    keyed by task groups (see ``forward`` for the routing rules).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias, groups,
        dilation: forwarded unchanged to every ``nn.Conv2d``.
        mapping: required sequence of iterables of task ids. Each group is
            stored as a ``frozenset``. The default ``None`` is not usable and
            raises ``TypeError``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1, dilation=1,
                 mapping=None):
        super(Conv2d_fw_v2, self).__init__()
        self.mapping = [frozenset(m) for m in mapping]
        self.n_modules = len(self.mapping)
        self.m_list = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
                       groups=groups, bias=bias) for _ in range(self.n_modules)])

        self.task_map = self._search_mapping()

    def set_mapping(self, mapping):
        """Replace the task grouping, growing the module bank if necessary.

        New modules are deep copies of ``self.m_list[0]``. The new mapping is
        validated *before* any attribute is touched, so a rejected mapping
        leaves the layer exactly as it was.

        Raises:
            ValueError: if ``mapping`` has fewer groups than there are modules.
        """
        new_mapping = [frozenset(m) for m in mapping]
        n_modules = len(new_mapping)

        if n_modules < self.n_modules:
            raise ValueError('Can not decrease the number of tasks in fw module')

        # Build the lookup table first: it may fail on an invalid partition.
        new_task_map = self._build_task_map(new_mapping)

        for _ in range(n_modules - self.n_modules):
            self.m_list.append(deepcopy(self.m_list[0]))

        self.mapping = new_mapping
        self.n_modules = n_modules
        self.task_map = new_task_map

    def _get_input_id(self, x, m):
        """Return the single key of ``x`` that contains every task id of group ``m``."""
        matches = [key for key in x.keys() if m.issubset(key)]

        assert len(matches) == 1, 'expected exactly one input key covering the task group'

        return matches[0]

    @staticmethod
    def _build_task_map(mapping):
        """Build ``{frozenset([task_id]): module_index}`` for ``mapping``.

        Task ids are taken to be ``0 .. n_tasks - 1`` where ``n_tasks`` is the
        summed size of all groups; each id must appear in exactly one group.
        """
        n_tasks = sum(len(m) for m in mapping)

        out = {}
        for task_id in range(n_tasks):
            owners = [i for i, m in enumerate(mapping) if task_id in m]
            assert len(owners) == 1, 'every task id must belong to exactly one group'
            out[frozenset([task_id])] = owners[0]

        return out

    def _search_mapping(self):
        """Return the task lookup table for the current ``self.mapping``."""
        return self._build_task_map(self.mapping)

    def forward(self, x):
        """Apply the per-group convolutions to a dict of tensors.

        * If ``x`` has more entries than there are modules, every key of ``x``
          is looked up in ``self.task_map`` (whose keys are single-task
          frozensets) and the output keeps the keys of ``x``.
        * Otherwise, for each group ``m`` in ``self.mapping`` the unique input
          whose key is a superset of ``m`` is convolved with module ``i`` and
          the output is keyed by ``m``.
        """
        assert isinstance(x, dict)

        out = {}

        if len(x) > self.n_modules:
            for task_id, tensor in x.items():
                module_id = self.task_map[task_id]
                out[task_id] = self.m_list[module_id](tensor)
        else:
            for i, m in enumerate(self.mapping):
                source_key = self._get_input_id(x, m)
                out[m] = self.m_list[i](x[source_key])

        return out

class BatchNorm2d_fw_v2(nn.Module):  # used in MAML to forward input with fast weight
    """Bank of identically configured ``nn.BatchNorm2d`` layers, one per task group.

    ``mapping`` is a sequence of task-id groups; group ``i`` is served by
    ``self.m_list[i]``. Inputs and outputs of ``forward`` are dicts of tensors
    keyed by task groups (see ``forward`` for the routing rules).

    Args:
        num_features, eps, momentum, affine, track_running_stats, device,
        dtype: forwarded unchanged to every ``nn.BatchNorm2d``.
        mapping: required sequence of iterables of task ids. Each group is
            stored as a ``frozenset``. The default ``None`` is not usable and
            raises ``TypeError``.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, mapping=None):
        super(BatchNorm2d_fw_v2, self).__init__()
        self.mapping = [frozenset(m) for m in mapping]
        self.n_modules = len(self.mapping)
        self.m_list = nn.ModuleList([nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
                                                    track_running_stats=track_running_stats, device=device,
                                                    dtype=dtype) for _ in range(self.n_modules)])

        self.task_map = self._search_mapping()

    def set_mapping(self, mapping):
        """Replace the task grouping, growing the module bank if necessary.

        Groups are normalised to ``frozenset`` (as in ``__init__``) because
        ``forward`` uses them as dict keys and calls ``issubset`` on them.
        New modules are deep copies of ``self.m_list[0]``, so they start with
        a copy of its parameters and buffers. The new mapping is validated
        *before* any attribute is touched, so a rejected mapping leaves the
        layer exactly as it was.

        Raises:
            ValueError: if ``mapping`` has fewer groups than there are modules.
        """
        new_mapping = [frozenset(m) for m in mapping]
        n_modules = len(new_mapping)

        if n_modules < self.n_modules:
            raise ValueError('Can not decrease the number of tasks in fw module')

        # Build the lookup table first: it may fail on an invalid partition.
        new_task_map = self._build_task_map(new_mapping)

        for _ in range(n_modules - self.n_modules):
            self.m_list.append(deepcopy(self.m_list[0]))

        self.mapping = new_mapping
        self.n_modules = n_modules
        self.task_map = new_task_map

    def _get_input_id(self, x, m):
        """Return the single key of ``x`` that contains every task id of group ``m``."""
        matches = [key for key in x.keys() if m.issubset(key)]

        assert len(matches) == 1, 'expected exactly one input key covering the task group'

        return matches[0]

    @staticmethod
    def _build_task_map(mapping):
        """Build ``{frozenset([task_id]): module_index}`` for ``mapping``.

        Task ids are taken to be ``0 .. n_tasks - 1`` where ``n_tasks`` is the
        summed size of all groups; each id must appear in exactly one group.
        """
        n_tasks = sum(len(m) for m in mapping)

        out = {}
        for task_id in range(n_tasks):
            owners = [i for i, m in enumerate(mapping) if task_id in m]
            assert len(owners) == 1, 'every task id must belong to exactly one group'
            out[frozenset([task_id])] = owners[0]

        return out

    def _search_mapping(self):
        """Return the task lookup table for the current ``self.mapping``."""
        return self._build_task_map(self.mapping)

    def forward(self, x):
        """Apply the per-group batch norms to a dict of tensors.

        * If ``x`` has more entries than there are modules, every key of ``x``
          is looked up in ``self.task_map`` (whose keys are single-task
          frozensets) and the output keeps the keys of ``x``.
        * Otherwise, for each group ``m`` in ``self.mapping`` the unique input
          whose key is a superset of ``m`` is normalised with module ``i`` and
          the output is keyed by ``m``.
        """
        assert isinstance(x, dict)

        out = {}

        if len(x) > self.n_modules:
            for task_id, tensor in x.items():
                module_id = self.task_map[task_id]
                out[task_id] = self.m_list[module_id](tensor)
        else:
            for i, m in enumerate(self.mapping):
                source_key = self._get_input_id(x, m)
                out[m] = self.m_list[i](x[source_key])

        return out

class ReLU_fw_v2(nn.ReLU):
    """``nn.ReLU`` that is applied independently to every tensor of a dict."""

    def __init__(self, inplace=False):
        super().__init__(inplace)

    def forward(self, x):
        """Return a new dict with the same keys and ReLU applied to each value."""
        assert isinstance(x, dict)

        # Bind the parent implementation once; it is reused for every entry.
        relu = super().forward
        return {task_id: relu(tensor) for task_id, tensor in x.items()}

class ReLU6_fw_v2(nn.ReLU6):
    """``nn.ReLU6`` that is applied independently to every tensor of a dict."""

    def __init__(self, inplace=False):
        super().__init__(inplace)

    def forward(self, x):
        """Return a new dict with the same keys and ReLU6 applied to each value."""
        assert isinstance(x, dict)

        # Bind the parent implementation once; it is reused for every entry.
        relu6 = super().forward
        return {task_id: relu6(tensor) for task_id, tensor in x.items()}

# class Linear_fw_v2(nn.Module):  # used in MAML to forward input with fast weight
#     def __init__(self, in_features, out_features, n_tasks=1):
#         super(Linear_fw_v2, self).__init__()
#         self.n_modules = n_tasks
#         self.m_list = nn.ModuleList([nn.Linear(in_features, out_features) for i in range(n_tasks)])
#
#     def set_n_tasks(self, n_tasks=1):
#         if n_tasks >= self.n_modules:
#             gap = n_tasks - self.n_modules
#             self.n_modules = n_tasks
#             for i in range(gap):
#                 module = deepcopy(self.m_list[0])
#                 self.m_list.append(module)
#         else:
#             raise ValueError('Can not decrease the number of tasks in fw module')
#
#
#     def forward(self, x):
#         if not isinstance(x, list):
#             x = [x]
#
#         out = []
#         if len(x) == 1 and self.n_modules > 1:
#             for i, ln in enumerate(self.m_list):
#                 o = ln(x[0])
#                 out.append(o)
#         elif len(x) > 1 and self.n_modules == 1:
#             for i, x_i in enumerate(x):
#                 o = self.m_list[0](x_i)
#                 out.append(o)
#         elif len(x) == self.n_modules:
#             for i, ln in enumerate(self.m_list):
#                 o = ln(x[i])
#                 out.append(o)
#         else:
#             raise ValueError('Error')
#
#         return out

class Identity_fw_v2(nn.Identity):
    """``nn.Identity`` for dict inputs: passes every tensor through unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a new dict holding the same keys and the same tensor objects."""
        assert isinstance(x, dict)

        passthrough = super().forward
        return {task_id: passthrough(tensor) for task_id, tensor in x.items()}

class Dropout_fw_v2(nn.Dropout):
    """``nn.Dropout`` that is applied independently to every tensor of a dict."""

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, x):
        """Return a new dict with the same keys and dropout applied to each value."""
        assert isinstance(x, dict)

        drop = super().forward
        return {task_id: drop(tensor) for task_id, tensor in x.items()}

class AdaptiveAvgPool2d_fw_v2(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` applied independently to every tensor of a dict."""

    def __init__(self, output_size):
        super().__init__(output_size)

    def forward(self, x):
        """Return a new dict with the same keys and each value pooled to ``output_size``."""
        assert isinstance(x, dict)

        pool = super().forward
        return {task_id: pool(tensor) for task_id, tensor in x.items()}

def cat_fw_v2(x, y, dim=0):
    """Concatenate two tensor dicts entry by entry along ``dim``.

    Both arguments must be dicts of equal length. The result has the keys of
    ``x``; each value is ``torch.cat((x[key], y[key]), dim=dim)``, so every
    key of ``x`` must also be present in ``y``.
    """
    assert isinstance(x, dict)
    assert isinstance(y, dict)
    assert len(x) == len(y)

    return {key: torch.cat((left, y[key]), dim=dim) for key, left in x.items()}

def interpolate_fw_v2(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    """Apply ``F.interpolate`` with the given options to every tensor of a dict.

    ``input`` must be a dict; the result has the same keys. All remaining
    arguments are forwarded unchanged to ``torch.nn.functional.interpolate``.
    """
    assert isinstance(input, dict)

    return {
        task_id: F.interpolate(tensor,
                               size=size,
                               scale_factor=scale_factor,
                               mode=mode,
                               align_corners=align_corners,
                               recompute_scale_factor=recompute_scale_factor,
                               antialias=antialias)
        for task_id, tensor in input.items()
    }


def clone_fw_v2(x):
    """Return a new dict with the same keys whose values are ``.clone()`` copies."""
    assert isinstance(x, dict)

    copies = {}
    for key, tensor in x.items():
        copies[key] = tensor.clone()
    return copies

def activation_func_fw(activation):
    """Return the dict-aware activation module registered under ``activation``.

    Recognised names are ``'relu'``, ``'relu6'`` and ``'none'``. Any other
    name raises ``KeyError`` from the ``nn.ModuleDict`` lookup.
    """
    registry = nn.ModuleDict()
    registry['relu'] = ReLU_fw_v2(inplace=True)
    registry['relu6'] = ReLU6_fw_v2(inplace=True)
    registry['none'] = Identity_fw_v2()
    return registry[activation]

class ConvBNReLU_fw_v2(nn.Module):
    """Per-task-group convolution, batch norm and activation, chained in ``self.op``.

    The convolution is created without bias. ``affine`` is passed to the
    batch norm, ``activation`` is resolved through ``activation_func_fw`` and
    ``mapping`` is handed to both the convolution and the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, affine=True, activation='relu', mapping=None):
        super().__init__()

        conv = Conv2d_fw_v2(in_channels, out_channels, kernel_size,
                            stride=stride, padding=padding, dilation=dilation,
                            groups=groups, bias=False, mapping=mapping)
        norm = BatchNorm2d_fw_v2(out_channels, affine=affine, mapping=mapping)
        act = activation_func_fw(activation)

        # Order matters: it fixes the sub-module indices 0, 1, 2 inside ``op``.
        self.op = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through convolution, batch norm and activation in turn."""
        return self.op(x)

class InvertedResidual_fw_v2(nn.Module):
    """Inverted-residual block built from ``ConvBNReLU_fw_v2`` layers.

    ``self.chain`` is: an optional 1x1 expansion conv (skipped when
    ``expansion == 1``), a depthwise ``kernel_size`` conv (``groups`` equal to
    the hidden width) and a 1x1 projection conv without activation. A
    residual connection is added when ``skip_connect`` is set, the channel
    counts match and the effective stride is 1.

    Args:
        in_channels, out_channels: channel counts of the block.
        stride: 1 or 2. With ``stride == 2`` and ``dilation > 1`` the block
            uses stride 1 and half the dilation instead.
        expansion: multiplier giving the hidden width
            ``round(in_channels * expansion)``.
        kernel_size: depthwise kernel size, one of 1, 3, 5, 7.
        groups: groups of the two 1x1 convolutions.
        dilation: dilation of the depthwise convolution.
        skip_connect: allow the residual connection.
        final_affine: ``affine`` flag of the last batch norm.
        activation: activation name for all but the last layer.
        mapping: task grouping handed to every layer.
    """

    def __init__(self, in_channels, out_channels, stride, expansion, kernel_size=3, groups=1,
                 dilation=1, skip_connect=True, final_affine=True, activation='relu', mapping=None):
        super().__init__()
        assert kernel_size in [1, 3, 5, 7]
        assert stride in [1, 2]
        if stride == 2 and dilation > 1:
            # Trade the stride for a halved dilation.
            stride = 1
            dilation = dilation // 2
        padding = int((kernel_size - 1) * dilation / 2)
        hidden_dim = round(in_channels * expansion)

        self.mapping = mapping

        layers = []
        if expansion != 1:
            layers.append(ConvBNReLU_fw_v2(in_channels,
                                           hidden_dim,
                                           1,
                                           stride=1,
                                           padding=0,
                                           groups=groups,
                                           activation=activation, mapping=mapping))
        layers.extend([
            ConvBNReLU_fw_v2(hidden_dim,
                             hidden_dim,
                             kernel_size,
                             stride=stride,
                             padding=padding,
                             groups=hidden_dim,
                             dilation=dilation,
                             activation=activation, mapping=mapping),
            ConvBNReLU_fw_v2(hidden_dim,
                             out_channels,
                             1,
                             stride=1,
                             padding=0,
                             groups=groups,
                             affine=final_affine,
                             activation='none', mapping=mapping)])

        self.chain = nn.Sequential(*layers)

        # Uses the (possibly rewritten) stride from above.
        self.res_flag = bool(skip_connect and in_channels == out_channels and stride == 1)

    def _search_module(self, k, keys):
        """Return the single entry of ``keys`` that covers task group ``k``.

        ``k`` may be a set/frozenset of task ids or a single task id; an entry
        of ``keys`` matches when it contains every id of ``k``. Exactly one
        match is required.
        """
        wanted = k if isinstance(k, (set, frozenset)) else frozenset([k])
        matches = [key for key in keys if wanted.issubset(key)]

        assert len(matches) == 1, 'expected exactly one identity key covering the output group'
        return matches[0]

    def forward(self, x):
        """Run the layer chain on the tensor dict ``x`` and add the residual.

        When the chain keeps the number of entries, the residual is added key
        by key. When the chain produces more entries than it received (the
        task groups were split), each output receives the input whose key
        covers it. If the chain produces fewer entries, no residual is added.
        """
        identity = x
        out = self.chain(x)
        if self.res_flag:
            if len(out) == len(identity):
                for key in out.keys():
                    out[key] += identity[key]
            elif len(out) > len(identity):
                for key in out.keys():
                    source_key = self._search_module(key, identity.keys())
                    out[key] += identity[source_key]

        return out

class RASPP_fw_v2(nn.Module):
    """Two-branch pooling head operating on dicts of tensors.

    Branch 1 is a 1x1 ``ConvBNReLU_fw_v2``. Branch 2 pools each input to 1x1,
    applies a 1x1 ``ConvBNReLU_fw_v2`` and is resized back to the input's
    spatial size. Both branches are concatenated along dim 1, optionally
    passed through dropout, and projected by a final 1x1 ``ConvBNReLU_fw_v2``.
    """

    def __init__(self, in_channels, out_channels, activation='relu6',
                 drop_rate=0, final_affine=True, mapping=None):
        super().__init__()

        self.drop_rate = drop_rate
        self.mapping = mapping

        # Branch 1: plain 1x1 convolution.
        self.aspp_branch_1 = ConvBNReLU_fw_v2(
            in_channels, out_channels, kernel_size=1, stride=1,
            activation=activation, mapping=mapping)

        # Branch 2: image-level pooling followed by a 1x1 convolution.
        pooling = AdaptiveAvgPool2d_fw_v2(output_size=(1, 1))
        pooled_conv = ConvBNReLU_fw_v2(
            in_channels, out_channels, kernel_size=1, stride=1,
            activation=activation, mapping=mapping)
        self.aspp_branch_2 = nn.Sequential(pooling, pooled_conv)

        # Projection of the concatenated branches (hence 2 * out_channels in).
        self.aspp_projection = ConvBNReLU_fw_v2(
            2 * out_channels, out_channels, kernel_size=1, stride=1,
            activation=activation, affine=final_affine, mapping=mapping)

        self.dropout = Dropout_fw_v2(p=self.drop_rate)

    def forward(self, x):
        """Fuse the two branches of the tensor dict ``x`` and project the result."""
        # The target size is read from the first entry of ``x`` only.
        reference = list(x.values())[0]
        target_size = (reference.size(2), reference.size(3))

        local_features = self.aspp_branch_1(x)
        pooled_features = interpolate_fw_v2(input=self.aspp_branch_2(x),
                                            size=target_size,
                                            mode='bilinear',
                                            align_corners=False)

        # Concatenate the parallel streams along dim 1.
        fused = cat_fw_v2(local_features, pooled_features, dim=1)

        if self.drop_rate > 0:
            fused = self.dropout(fused)

        return self.aspp_projection(fused)

