from collections import OrderedDict

import torch
import torch.nn as nn
from torchmeta.modules import (MetaModule, MetaConv2d,
                               MetaSequential, MetaLinear)

# Default compute device: the first CUDA device when one is visible to
# PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def conv_block(in_channels, out_channels, **kwargs):
    """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage.

    Args:
        in_channels: Channels of the stage input.
        out_channels: Channels produced by the convolution (and normalized
            by the batch-norm layer).
        **kwargs: Forwarded unchanged to ``MetaConv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``, ``bias``).

    Returns:
        A ``MetaSequential`` whose sub-modules are named ``conv``, ``norm``,
        ``relu`` and ``pool`` (in that order), so parameter keys look like
        ``conv.weight`` / ``norm.weight``.
    """
    stage = OrderedDict()
    stage['conv'] = MetaConv2d(in_channels, out_channels, **kwargs)
    # track_running_stats=False: normalize with the current batch's
    # statistics in both train and eval mode (no running buffers kept).
    stage['norm'] = nn.BatchNorm2d(out_channels, momentum=1.,
                                   track_running_stats=False)
    stage['relu'] = nn.ReLU()
    # Kernel 2 / default stride 2: halves each spatial dimension (floored).
    stage['pool'] = nn.MaxPool2d(2)
    return MetaSequential(stage)


def conv_drop_block(in_channels, out_channels, drop_p, **kwargs):
    """Build a conv -> batch-norm -> ReLU -> dropout -> 2x2 max-pool stage.

    Identical to ``conv_block`` except for an ``nn.Dropout`` inserted between
    the activation and the pooling layer.

    Args:
        in_channels: Channels of the stage input.
        out_channels: Channels produced by the convolution.
        drop_p: Dropout probability passed to ``nn.Dropout``.
        **kwargs: Forwarded unchanged to ``MetaConv2d``.

    Returns:
        A ``MetaSequential`` with sub-modules ``conv``, ``norm``, ``relu``,
        ``dropout`` and ``pool``, in that order.
    """
    stage = OrderedDict()
    stage['conv'] = MetaConv2d(in_channels, out_channels, **kwargs)
    # Batch statistics are used in both train and eval mode.
    stage['norm'] = nn.BatchNorm2d(out_channels, momentum=1.,
                                   track_running_stats=False)
    stage['relu'] = nn.ReLU()
    stage['dropout'] = nn.Dropout(drop_p)
    stage['pool'] = nn.MaxPool2d(2)
    return MetaSequential(stage)


class MetaConvModel(MetaModule):
    """Four-stage convolutional backbone with a linear head, for meta-learning.

    The network is ``features`` (four conv stages built by ``conv_block`` or
    ``conv_drop_block``, each ending in a 2x2 max-pool) followed by
    ``classifier`` (a ``MetaLinear`` mapping ``feature_size`` to
    ``out_features``). Being torchmeta modules, both parts accept an optional
    ``params`` dict at call time that overrides their own weights (e.g. fast
    weights produced by an inner-loop update).

    Args:
        in_channels: Number of input image channels.
        out_features: Number of output logits.
        hidden_size: Channel width of every conv stage.
        feature_size: Input width of the classifier. It must equal the
            flattened size of the ``features`` output,
            ``hidden_size * H' * W'``. Each stage preserves the spatial size
            in its conv (kernel 3, stride 1, padding 1) and halves it in the
            pool. For example, a 28x28 input becomes 28 -> 14 -> 7 -> 3 -> 1,
            so ``feature_size == hidden_size``.
        drop_p: Dropout probability applied inside every conv stage;
            ``0`` (the default) builds the stages without dropout.
    """

    def __init__(self, in_channels, out_features, hidden_size=64, feature_size=64, drop_p=0.):
        super(MetaConvModel, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size
        self.feature_size = feature_size
        self.drop_p = drop_p
        # When True, ``features`` always runs with its own weights and any
        # ``features.*`` entries in ``params`` are ignored (see _forward_all).
        self.anil = False

        if self.drop_p > 0.:
            conv = conv_drop_block
            kwargs = {'drop_p': self.drop_p}
        else:
            conv = conv_block
            kwargs = {}

        # Identity placeholder, kept under its original (misspelled) name for
        # backward compatibility. It was previously assigned identically in
        # both branches above and is not used by this class. It is registered
        # before ``classifier``/``features`` to preserve the original
        # sub-module order.
        self.drop_classifer = nn.Identity()

        # Created before ``features`` so parameter initialization consumes
        # the RNG in the same order as before.
        self.classifier = MetaLinear(feature_size, out_features, bias=True)

        # layer1: in_channels -> hidden_size; layers 2-4: hidden_size -> hidden_size.
        channels = [in_channels] + [hidden_size] * 4
        self.features = MetaSequential(OrderedDict(
            ('layer{}'.format(index),
             conv(c_in, c_out, kernel_size=3, stride=1, padding=1,
                  bias=True, **kwargs))
            for index, (c_in, c_out) in enumerate(zip(channels, channels[1:]), start=1)
        ))

    def _forward_all(self, inputs, params=None, inner_update_type='encoder_only'):
        """Run one batch through ``features`` and ``classifier``.

        Parameter routing:
          * ``features`` uses ``params['features.*']``, unless ``self.anil``
            is True or ``inner_update_type == 'linear_only'``. In those cases
            it uses its own weights.
          * ``classifier`` uses its own weights when
            ``inner_update_type == 'encoder_only'``. For any other value it
            uses ``params['classifier.*']``.

        Returns:
            ``(logits, features)``, where ``features`` is the backbone output
            flattened to ``(batch, -1)``.
        """
        if self.anil or inner_update_type == 'linear_only':
            params_feature = None
        else:
            params_feature = self.get_subdict(params, 'features')

        features = self.features(inputs, params=params_feature)
        features = features.view((features.size(0), -1))

        if inner_update_type == 'encoder_only':
            logits = self.classifier(features)
        else:
            logits = self.classifier(features, params=self.get_subdict(params, 'classifier'))

        return logits, features

    def _forward_views(self, views, num_views, params, params2, inner_update_type):
        """Forward 0, 1 or 2 views of a batch.

        ``num_views`` must already be validated to be 0, 1 or 2.

        Returns:
            ``(None, None)`` for 0 views, ``(logits, features)`` for 1 view,
            and ``((logits1, logits2), (features1, features2))`` for 2 views.
            With 2 views, the first view uses ``params`` and the second
            uses ``params2``.
        """
        if num_views == 0:
            return None, None
        if num_views == 1:
            return self._forward_all(views, params, inner_update_type)
        view1, view2 = views
        logits1, z1 = self._forward_all(view1, params, inner_update_type)
        logits2, z2 = self._forward_all(view2, params2, inner_update_type)
        return (logits1, logits2), (z1, z2)

    def forward(self, qry, adv=None, sprt=None, qry_num=1, adv_num=0, sprt_num=0, params=None, params2=None, feat=False, inner_update_type='encoder_only'):
        """Forward query and, optionally, adversarial and support batches.

        Args:
            qry: A batch, or a pair of batches when ``qry_num == 2``.
            adv: Optional adversarial batch, or a pair when ``adv_num == 2``.
            sprt: Optional support batch, or a pair when ``sprt_num == 2``.
            qry_num: Number of query views, 1 or 2.
            adv_num: Number of adversarial views, 0, 1 or 2.
            sprt_num: Number of support views, 0, 1 or 2.
            params: Parameter overrides for the first view of each input.
            params2: Parameter overrides for the second view (2-view inputs only).
            feat: Also return the flattened features.
            inner_update_type: Parameter-routing mode; see ``_forward_all``.

        Returns:
            If ``adv_num`` or ``sprt_num`` is positive:
            ``(logits_qry, logits_adv, logits_sprt)``, extended with
            ``(z_qry, z_adv, z_sprt)`` when ``feat`` is True. Otherwise
            ``logits_qry``, or ``(logits_qry, z_qry)`` when ``feat`` is True.
            Each entry is a pair for a 2-view input and ``None`` for a
            0-view input.

        Raises:
            ValueError: If a view count is outside its supported range.
        """
        # Validate up front. Previously an adv_num/sprt_num above 2 failed
        # with an UnboundLocalError only after the query had been forwarded.
        if qry_num not in (1, 2):
            raise ValueError('qry_num must be 1 or 2, got {!r}'.format(qry_num))
        for name, num in (('adv_num', adv_num), ('sprt_num', sprt_num)):
            if num not in (0, 1, 2):
                raise ValueError('{} must be 0, 1 or 2, got {!r}'.format(name, num))

        # Evaluation order (qry, adv, sprt; view 1 before view 2) matters
        # for reproducibility when dropout draws from the RNG.
        logits_qry, z_qry = self._forward_views(qry, qry_num, params, params2, inner_update_type)
        logits_adv, z_adv = self._forward_views(adv, adv_num, params, params2, inner_update_type)
        logits_sprt, z_sprt = self._forward_views(sprt, sprt_num, params, params2, inner_update_type)

        if adv_num > 0 or sprt_num > 0:
            if feat:
                return logits_qry, logits_adv, logits_sprt, z_qry, z_adv, z_sprt
            return logits_qry, logits_adv, logits_sprt
        if feat:
            return logits_qry, z_qry
        return logits_qry
    