import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmeta.modules import MetaModule, MetaSequential, MetaConv2d, MetaBatchNorm2d, MetaLinear
from timm.models.registry import register_model

class MetaBasicBlock(MetaModule):
    """Two-convolution residual block whose weights can be overridden per call.

    Main path: 3x3 conv (carrying ``stride``) + BN -> ReLU -> 3x3 conv + BN.
    Shortcut: a 1x1 conv + BN projection when the stride or the channel count
    changes, otherwise an empty ``MetaSequential`` that passes its input through.
    The sum of both paths goes through a final ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (``expansion == 1``).
        stride: stride of the first convolution and of the projection shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(MetaBasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.layer1 = MetaSequential(
            MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            MetaBatchNorm2d(planes),
        )
        self.layer2 = MetaSequential(
            MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            MetaBatchNorm2d(planes),
        )

        # A projection is needed whenever the residual cannot be added as-is.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = MetaSequential(
                MetaConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                MetaBatchNorm2d(out_planes),
            )
        else:
            self.shortcut = MetaSequential()

    def forward(self, x, params=None):
        """Apply the block to ``x``.

        Args:
            x: input feature map.
            params: optional parameter overrides. The ``layer1``/``layer2``/
                ``shortcut`` sub-dicts are routed to the matching sub-modules.
                ``None`` uses the module's own weights.
        """
        def sub(name):
            return self.get_subdict(params, name)

        hidden = F.relu(self.layer1(x, params=sub('layer1')))
        hidden = self.layer2(hidden, params=sub('layer2'))
        hidden += self.shortcut(x, params=sub('shortcut'))
        return F.relu(hidden)


class MetaBottleneck(MetaModule):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) with per-call parameter overrides.

    The output has ``expansion * planes`` channels. The 3x3 convolution carries
    ``stride``. A 1x1 conv + BN projection shortcut is used whenever the stride or
    the channel count changes. Otherwise the shortcut is an empty
    ``MetaSequential``, which passes its input through.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width (channels of the 1x1 and 3x3 convolutions).
        stride: stride of the 3x3 convolution and of the projection shortcut.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(MetaBottleneck, self).__init__()
        self.layer1 = MetaSequential(
            MetaConv2d(in_planes, planes, kernel_size=1, bias=False),
            MetaBatchNorm2d(planes)
        )
        self.layer2 = MetaSequential(
            MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            MetaBatchNorm2d(planes)
        )
        self.layer3 = MetaSequential(
            MetaConv2d(planes, self.expansion * planes, kernel_size=1, bias=False),
            MetaBatchNorm2d(self.expansion * planes)
        )

        self.shortcut = MetaSequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = MetaSequential(
                MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                MetaBatchNorm2d(self.expansion * planes)
            )

    def forward(self, x, params=None):
        """Apply the block to ``x``.

        Args:
            x: input feature map.
            params: optional parameter overrides. The ``layer1``/``layer2``/
                ``layer3``/``shortcut`` sub-dicts are routed to the matching
                sub-modules. ``None`` uses the module's own weights.
        """
        # Each conv+BN stage must actually run. ReLU follows the first two stages;
        # the third is summed with the shortcut before the final ReLU.
        out = F.relu(self.layer1(x, params=self.get_subdict(params, 'layer1')))
        out = F.relu(self.layer2(out, params=self.get_subdict(params, 'layer2')))
        out = self.layer3(out, params=self.get_subdict(params, 'layer3'))
        out += self.shortcut(x, params=self.get_subdict(params, 'shortcut'))
        return F.relu(out)


class MetaResNet(MetaModule):
    """ResNet whose convolutional trunk accepts per-call parameter overrides.

    The trunk ``features`` consists of:
    - a 3x3 stem convolution (stride 1) + BatchNorm + ReLU;
    - four stages of ``block`` with widths 64/128/256/512 and first-block strides 1/2/2/2;
    - a final ``nn.AvgPool2d(4)``.

    With the 3x3/padding-1 blocks in this file, a 32x32 input leaves 4x4 maps
    after the last stage. The pooled, flattened features then have
    ``512 * block.expansion`` values, matching ``fc``. Other input sizes may not
    match ``fc`` (TODO confirm intended input resolution).

    Args:
        block: residual block class with an ``expansion`` attribute, built as
            ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the linear head ``fc``.
        in_channel: number of channels of the input images.
        contrastive_learning: stored on the instance; not used by this class.

    NOTE(review): ``fc`` is a plain ``nn.Linear``, so per-call ``params`` only
    reach ``features``. The head always uses its own weights.
    """

    def __init__(self, block, num_blocks, num_classes=10, in_channel=3, contrastive_learning=True):
        super(MetaResNet, self).__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 64
        # Dimension of the flattened feature vector fed to fc.
        self.inplanes = 512 * block.expansion
        self.in_channel = in_channel
        self.contrastive_learning = contrastive_learning

        self.features = MetaSequential(
            MetaConv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False),
            MetaBatchNorm2d(64),
            nn.ReLU(),
            self._make_layer(block, 64, num_blocks[0], stride=1),
            self._make_layer(block, 128, num_blocks[1], stride=2),
            self._make_layer(block, 256, num_blocks[2], stride=2),
            self._make_layer(block, 512, num_blocks[3], stride=2),
            nn.AvgPool2d(4),
        )
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks; only the first uses ``stride``.

        Side effect: updates ``self.in_planes`` to the stage's output channel count.
        """
        layers = []
        for layer_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, layer_stride))
            self.in_planes = planes * block.expansion
        return MetaSequential(*layers)

    def _forward_all(self, x, params, inner_update_type):
        """Run one batch through the network.

        Args:
            x: input batch.
            params: parameter overrides or ``None``. Only the ``features``
                sub-dict is used (see the class note about ``fc``).
            inner_update_type: accepted for interface parity with
                ``SmallMetaResNet``; not used here.

        Returns:
            ``(logits, features)``, where ``features`` is the trunk output
            flattened to ``(batch, -1)``.
        """
        out = self.features(x, params=self.get_subdict(params, 'features'))
        out = out.view((out.size(0), -1))

        logits = self.fc(out)

        return logits, out

    def _forward_views(self, inputs, num_views, params, params2, inner_update_type):
        """Forward zero, one or two views of one input group.

        Returns ``(None, None)`` when ``num_views <= 0``. For one view it returns
        ``(logits, features)``. For two views, ``inputs`` is unpacked as a pair
        and the second view runs with ``params2``; the result is
        ``((logits1, logits2), (features1, features2))``.

        Raises:
            ValueError: if ``num_views`` is larger than 2.
        """
        if num_views <= 0:
            return None, None
        if num_views == 1:
            return self._forward_all(inputs, params, inner_update_type)
        if num_views == 2:
            first, second = inputs
            logits1, feats1 = self._forward_all(first, params, inner_update_type)
            logits2, feats2 = self._forward_all(second, params2, inner_update_type)
            return (logits1, logits2), (feats1, feats2)
        raise ValueError(f"expected 0, 1 or 2 views, got {num_views}")

    def forward(self, qry, adv=None, sprt=None, qry_num=1, adv_num=0, sprt_num=0, params=None, params2=None, feat=False, inner_update_type='both'):
        """Run the ``qry`` group and, optionally, the ``adv`` and ``sprt`` groups.

        Args:
            qry: batch, or a pair of batches when ``qry_num == 2``.
            adv: extra input group, used when ``adv_num`` is 1 or 2.
            sprt: extra input group, used when ``sprt_num`` is 1 or 2.
            qry_num, adv_num, sprt_num: number of views per group (0 or less
                means the group is absent, which is not allowed for ``qry``).
                With 2 views the input is unpacked as a pair and the second view
                is run with ``params2``.
            params: parameter overrides for the first view of every group.
            params2: parameter overrides for the second view of every group.
            feat: if True, also return the flattened features.
            inner_update_type: forwarded to ``_forward_all``.

        Returns:
            If neither ``adv`` nor ``sprt`` is requested: ``logits_qry``, or
            ``(logits_qry, z_qry)`` when ``feat``.

            Otherwise: ``(logits_qry, logits_adv, logits_sprt)``, extended with
            ``z_qry, z_adv, z_sprt`` when ``feat``.

            Each entry is a single result for one view, a pair for two views,
            and ``None`` for an absent group.

        Raises:
            ValueError: if ``qry_num`` is not 1 or 2, or if ``adv_num`` /
                ``sprt_num`` is larger than 2.
        """
        if qry_num not in (1, 2):
            raise ValueError(f"qry_num must be 1 or 2, got {qry_num}")

        logits_qry, z_qry = self._forward_views(qry, qry_num, params, params2, inner_update_type)
        logits_adv, z_adv = self._forward_views(adv, adv_num, params, params2, inner_update_type)
        logits_sprt, z_sprt = self._forward_views(sprt, sprt_num, params, params2, inner_update_type)

        has_extra_groups = adv_num > 0 or sprt_num > 0
        if feat:
            if has_extra_groups:
                return logits_qry, logits_adv, logits_sprt, z_qry, z_adv, z_sprt
            return logits_qry, z_qry
        if has_extra_groups:
            return logits_qry, logits_adv, logits_sprt
        return logits_qry

class SmallMetaBasicBlock(MetaModule):
    """Three-convolution residual block with LeakyReLU and max-pool downsampling.

    All three 3x3 convolutions use stride 1. Spatial reduction, when enabled,
    happens after the residual sum via ``nn.MaxPool2d`` with kernel size and
    stride ``stride``.

    Args:
        inplanes: number of input channels.
        planes: output channels of every convolution.
        stride: kernel size and stride of the trailing max-pool.
        downsample: optional module applied to the input to form the residual
            branch. It is called with a ``params`` keyword.
        drop_block: stored on the instance; not used by ``forward``.
        block_size: stored on the instance; not used by ``forward``.
        max_padding: padding of the trailing max-pool. Pooling is applied only
            when ``stride != max_padding``.
        track_running_stats: forwarded to every ``MetaBatchNorm2d``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1,
                 downsample=None, drop_block=False,
                 block_size=1, max_padding=0, track_running_stats=False):
        super(SmallMetaBasicBlock, self).__init__()

        def conv_bn(in_ch, with_activation):
            # 3x3 conv (stride 1, padding 1) + BN, optionally followed by LeakyReLU.
            parts = [
                MetaConv2d(in_ch, planes, kernel_size=3, stride=1, padding=1, bias=False),
                MetaBatchNorm2d(planes, track_running_stats=track_running_stats),
            ]
            if with_activation:
                parts.append(nn.LeakyReLU())
            return MetaSequential(*parts)

        self.layer1 = conv_bn(inplanes, True)
        self.layer2 = conv_bn(planes, True)
        # No activation here: it is applied after the residual sum instead.
        self.layer3 = conv_bn(planes, False)
        self.relu = nn.LeakyReLU()
        self.maxpool = nn.MaxPool2d(stride=stride, kernel_size=[stride, stride],
                                    padding=max_padding)
        self.max_pool = stride != max_padding
        self.downsample = downsample
        self.stride = stride
        # Plain Python counter incremented on every forward call.
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size

    def forward(self, x, params=None):
        """Apply the block to ``x``.

        Args:
            x: input feature map.
            params: optional parameter overrides, routed by sub-module name
                (``layer1``/``layer2``/``layer3``/``downsample``).
        """
        self.num_batches_tracked += 1

        out = x
        for name in ('layer1', 'layer2', 'layer3'):
            out = getattr(self, name)(out, params=self.get_subdict(params, name))

        if self.downsample is None:
            residual = x
        else:
            residual = self.downsample(x, params=self.get_subdict(params, 'downsample'))
        out += residual
        out = self.relu(out)

        return self.maxpool(out) if self.max_pool else out


class SmallMetaResNet(MetaModule):
    """Four-stage residual network with a meta-learnable linear head.

    Each stage is a single ``block`` built with ``stride=2``. An optional
    ``nn.AdaptiveAvgPool2d(1)`` is applied before flattening. ``fc`` is a
    ``MetaLinear``, so per-call ``params`` can reach both the trunk and the head.

    Args:
        blocks: sequence of four block classes, one per stage.
        avg_pool: if True, average-pool the last stage to 1x1 before flattening.
        dropblock_size: forwarded as ``block_size`` to every block.
        out_features: number of output logits.
        wh_size: spatial side of the last feature map that reaches ``fc``.
            ``fc`` expects ``512 * wh_size * wh_size`` inputs, so use 1 with
            ``avg_pool``.
        inductive_bn: forwarded as ``track_running_stats`` to every BatchNorm.
        contrastive_learning: stored on the instance; not used by this class.
        anil: if True, the trunk always runs with its own weights and ``params``
            only reaches ``fc`` (same as ``inner_update_type='linear_only'``).
    """

    def __init__(self, blocks, avg_pool=True, dropblock_size=5,
                 out_features=5, wh_size=1, inductive_bn=False, contrastive_learning=True,
                 anil=False):
        super(SmallMetaResNet, self).__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_channel = 3
        # Dimension of the flattened feature vector fed to fc.
        self.inplanes = 512 * wh_size * wh_size
        self.contrastive_learning = contrastive_learning
        self.anil = anil

        self.inductive_bn = inductive_bn
        self.layer1 = self._make_layer(blocks[0], 64, stride=2, drop_block=True,
                                       block_size=dropblock_size)
        self.layer2 = self._make_layer(blocks[1], 128, stride=2, drop_block=True,
                                       block_size=dropblock_size)
        self.layer3 = self._make_layer(blocks[2], 256, stride=2, drop_block=True,
                                       block_size=dropblock_size)
        self.layer4 = self._make_layer(blocks[3], 512, stride=2, drop_block=True,
                                       block_size=dropblock_size)

        if avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.keep_avg_pool = avg_pool
        # MetaLinear is an nn.Linear subclass (same state-dict keys) that also
        # accepts `params`. The 'both' and 'linear_only' update modes need it.
        self.fc = MetaLinear(512 * wh_size * wh_size, out_features)

        # Xavier-uniform init for every meta conv and for the meta linear head.
        for m in self.modules():
            if isinstance(m, (MetaConv2d, MetaLinear)):
                nn.init.xavier_uniform_(m.weight)

    def _make_layer(self, block, planes, stride=1,
                    drop_block=False, block_size=1, max_padding=0):
        """Build one stage containing a single ``block``.

        A 1x1 conv + BN ``downsample`` is attached when the stride or the channel
        count changes. It uses stride 1, because the block itself reduces spatial
        size with a max-pool after the residual sum.

        Side effect: updates ``self.in_channel`` to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.in_channel != planes * block.expansion:
            downsample = MetaSequential(
                MetaConv2d(self.in_channel, planes * block.expansion,
                           kernel_size=1, stride=1, bias=False),
                MetaBatchNorm2d(planes * block.expansion,
                                track_running_stats=self.inductive_bn),
            )

        stage_block = block(self.in_channel, planes, stride,
                            downsample, drop_block, block_size, max_padding,
                            track_running_stats=self.inductive_bn)
        self.in_channel = planes * block.expansion
        return MetaSequential(stage_block)

    def _forward_all(self, x, params, inner_update_type):
        """Run one batch through the network.

        Args:
            x: input batch.
            params: parameter overrides or ``None``. The ``layer1``..``layer4``
                sub-dicts go to the trunk and the ``fc`` sub-dict goes to the head.
            inner_update_type:
                - ``'linear_only'``: the trunk ignores ``params``.
                - ``'encoder_only'``: the head ignores ``params``.
                - any other value (e.g. ``'both'``): ``params`` is routed to both.

        Returns:
            ``(logits, features)``, where ``features`` is the (optionally pooled)
            trunk output flattened to ``(batch, -1)``.
        """
        # getattr keeps instances pickled before `anil` existed loadable.
        if getattr(self, 'anil', False) or inner_update_type == 'linear_only':
            params_feature = [None] * 4
        else:
            params_feature = [self.get_subdict(params, f'layer{i + 1}') for i in range(4)]

        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for stage, stage_params in zip(stages, params_feature):
            x = stage(x, params=stage_params)
        if self.keep_avg_pool:
            x = self.avgpool(x)
        features = x.view((x.size(0), -1))

        if inner_update_type == 'encoder_only':
            logits = self.fc(features)
        else:
            logits = self.fc(features, params=self.get_subdict(params, 'fc'))

        return logits, features

    def _forward_views(self, inputs, num_views, params, params2, inner_update_type):
        """Forward zero, one or two views of one input group.

        Returns ``(None, None)`` when ``num_views <= 0``. For one view it returns
        ``(logits, features)``. For two views, ``inputs`` is unpacked as a pair
        and the second view runs with ``params2``; the result is
        ``((logits1, logits2), (features1, features2))``.

        Raises:
            ValueError: if ``num_views`` is larger than 2.
        """
        if num_views <= 0:
            return None, None
        if num_views == 1:
            return self._forward_all(inputs, params, inner_update_type)
        if num_views == 2:
            first, second = inputs
            logits1, feats1 = self._forward_all(first, params, inner_update_type)
            logits2, feats2 = self._forward_all(second, params2, inner_update_type)
            return (logits1, logits2), (feats1, feats2)
        raise ValueError(f"expected 0, 1 or 2 views, got {num_views}")

    def forward(self, qry, adv=None, sprt=None, qry_num=1, adv_num=0, sprt_num=0, params=None, params2=None, feat=False, inner_update_type='both'):
        """Run the ``qry`` group and, optionally, the ``adv`` and ``sprt`` groups.

        Args:
            qry: batch, or a pair of batches when ``qry_num == 2``.
            adv: extra input group, used when ``adv_num`` is 1 or 2.
            sprt: extra input group, used when ``sprt_num`` is 1 or 2.
            qry_num, adv_num, sprt_num: number of views per group (0 or less
                means the group is absent, which is not allowed for ``qry``).
                With 2 views the input is unpacked as a pair and the second view
                is run with ``params2``.
            params: parameter overrides for the first view of every group.
            params2: parameter overrides for the second view of every group.
            feat: if True, also return the flattened features.
            inner_update_type: forwarded to ``_forward_all``.

        Returns:
            If neither ``adv`` nor ``sprt`` is requested: ``logits_qry``, or
            ``(logits_qry, z_qry)`` when ``feat``.

            Otherwise: ``(logits_qry, logits_adv, logits_sprt)``, extended with
            ``z_qry, z_adv, z_sprt`` when ``feat``.

            Each entry is a single result for one view, a pair for two views,
            and ``None`` for an absent group.

        Raises:
            ValueError: if ``qry_num`` is not 1 or 2, or if ``adv_num`` /
                ``sprt_num`` is larger than 2.
        """
        if qry_num not in (1, 2):
            raise ValueError(f"qry_num must be 1 or 2, got {qry_num}")

        logits_qry, z_qry = self._forward_views(qry, qry_num, params, params2, inner_update_type)
        logits_adv, z_adv = self._forward_views(adv, adv_num, params, params2, inner_update_type)
        logits_sprt, z_sprt = self._forward_views(sprt, sprt_num, params, params2, inner_update_type)

        has_extra_groups = adv_num > 0 or sprt_num > 0
        if feat:
            if has_extra_groups:
                return logits_qry, logits_adv, logits_sprt, z_qry, z_adv, z_sprt
            return logits_qry, z_qry
        if has_extra_groups:
            return logits_qry, logits_adv, logits_sprt
        return logits_qry

def meta_resnet18_selfsup(num_classes):
    """Build a ResNet-18-layout ``MetaResNet`` for 3-channel inputs.

    Uses ``MetaBasicBlock`` with two blocks in each of the four stages.

    Args:
        num_classes: output size of the linear head.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return MetaResNet(
        MetaBasicBlock,
        blocks_per_stage,
        in_channel=3,
        num_classes=num_classes,
    )
