import torch.nn as nn
import torch
from torchmeta.modules import (MetaModule, MetaSequential, MetaConv2d,
                               MetaBatchNorm2d, MetaLinear)

# from https://github.com/kjunelee/MetaOptNet
# Embedding network used in Meta-learning with differentiable closed-form solvers
# (Bertinetto et al., in submission to NIPS 2018).
# They call the ridge regressor version the "Ridge Regression Differentiable Discriminator (R2D2)."
  
# Note that R2D2_conv_block (below) uses a peculiar ordering of operations, namely
# conv-BN-pooling-lrelu, as opposed to the conventional one (conv-BN-lrelu-pooling).
def get_subdict(adict, name):
    """Select the entries of ``adict`` that belong to sub-module ``name``.

    Keys of the form ``'<name>.<rest>'`` are returned re-keyed as ``'<rest>'``;
    every other key is dropped.

    Args:
        adict: mapping from dotted parameter names to values, or ``None``.
        name: sub-module prefix to select, without the trailing dot.

    Returns:
        ``None`` if ``adict`` is ``None``; otherwise a new dict (possibly empty).
    """
    if adict is None:
        return None
    # Match the full dotted prefix rather than a substring: 'conv1' must not
    # pick up 'block1.conv1.weight', and 'block1' must not pick up 'block10.*'.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in adict.items() if k.startswith(prefix)}

class R2D2_conv_block(MetaModule):
    """One R2D2 embedding stage.

    The stage applies a 3x3 convolution (padding 1), batch norm and a 2x2 max-pool.
    It then optionally applies LeakyReLU(0.1) and optionally dropout.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        retain_activation: if True, apply LeakyReLU(0.1) after pooling.
        keep_prob: keep probability; dropout with p = 1 - keep_prob is added
            only when keep_prob < 1.0.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, retain_activation=True, keep_prob=1.0):
        super().__init__()
        # Attribute names are part of the parameter keys ('conv1.*', 'bn1.*')
        # that forward() slices out of `params`, so they must not change.
        self.conv1 = MetaConv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = MetaBatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d(2)
        self.retain_activation = retain_activation
        self.keep_prob = keep_prob
        if self.retain_activation:
            self.leakyrelu = nn.LeakyReLU(0.1)
        if self.keep_prob < 1.0:
            self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)

    def forward(self, x, params=None):
        """Apply conv -> BN -> pool, then the optional activation and dropout.

        Args:
            x: input batch for the convolution.
            params: optional dict keyed by dotted names such as 'conv1.weight'.
                The per-module slices are taken with ``get_subdict``. ``None``
                is passed through unchanged to the sub-modules.
        """
        conv_params = get_subdict(params, 'conv1')
        bn_params = get_subdict(params, 'bn1')
        h = self.maxpool(self.bn1(self.conv1(x, params=conv_params), params=bn_params))
        if self.retain_activation:
            h = self.leakyrelu(h)
        return self.dropout(h) if self.keep_prob < 1.0 else h

def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0):
    """
    Fits the support set with ridge regression and
    returns the classification score on the query set.

    This model is the classification head described in:
    Meta-learning with differentiable closed-form solvers
    (Bertinetto et al., in submission to NIPS 2018).

    Parameters:
      query:  a (tasks_per_batch, n_query, d) Tensor.
      support:  a (tasks_per_batch, n_support, d) Tensor.
      support_labels: a (tasks_per_batch, n_support) Tensor of integer class
        indices in [0, n_way).
      n_way: a scalar. Represents the number of classes in a few-shot classification task.
      n_shot: a scalar. Represents the number of support examples given per class.
      l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization.
    Returns: a (tasks_per_batch, n_query, n_way) Tensor on the device and with
      the dtype of ``support``.
    """
    tasks_per_batch = query.size(0)
    n_support = support.size(1)

    assert query.dim() == 3
    assert support.dim() == 3
    assert query.size(0) == support.size(0) and query.size(2) == support.size(2)
    assert n_support == n_way * n_shot  # n_support must equal n_way * n_shot

    # Y: one-hot targets of shape (tasks_per_batch, n_support, n_way). They are
    # cast to the feature dtype so they can enter the matrix products below.
    support_labels_one_hot = torch.nn.functional.one_hot(
        support_labels.reshape(tasks_per_batch, n_support).long(), n_way
    ).to(dtype=support.dtype)

    # Build the identity where the features live instead of forcing CUDA, so
    # the head works on any device.
    id_matrix = torch.eye(n_support, dtype=support.dtype, device=support.device)
    id_matrix = id_matrix.expand(tasks_per_batch, n_support, n_support)

    # Dual form of ridge regression:
    #   W = X^T (X X^T + lambda * I)^(-1) Y
    # Solve (X X^T + lambda * I) A = Y rather than forming the inverse explicitly.
    gram = torch.bmm(support, support.transpose(1, 2))
    dual_coef = torch.linalg.solve(gram + l2_regularizer_lambda * id_matrix,
                                   support_labels_one_hot)
    ridge_sol = torch.bmm(support.transpose(1, 2), dual_coef)  # (tasks, d, n_way)

    # Classification score: query @ W.
    logits = torch.bmm(query, ridge_sol)

    return logits

class R2D2Embedding_barlow(MetaModule):
    """R2D2 embedding (four R2D2_conv_blocks) followed by a linear classifier.

    Each input batch may be given as one view or a pair of views. Parameters
    can be overridden per call through ``params`` / ``params2``.

    Args:
        x_dim: input channels.
        h1_dim, h2_dim, h3_dim: output channels of blocks 1-3.
        z_dim: output channels of block 4.
        retain_last_activation: whether block 4 applies its LeakyReLU.
        out_features: number of classifier outputs.
    """

    def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512,
                 retain_last_activation=False, out_features=5):
        super(R2D2Embedding_barlow, self).__init__()

        self.block1 = R2D2_conv_block(x_dim, h1_dim)
        self.block2 = R2D2_conv_block(h1_dim, h2_dim)
        self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9)
        # In the last conv block, we disable activation function to boost the classification accuracy.
        # This trick was proposed by Gidaris et al. (CVPR 2018).
        # With this trick, the accuracy goes up from 50% to 51%.
        # Although the authors of R2D2 did not mention this trick in the paper,
        # we were unable to reproduce the result of Bertinetto et al. without resorting to this trick.
        self.block4 = R2D2_conv_block(h3_dim, z_dim, retain_activation=retain_last_activation, keep_prob=0.7)

        # The classifier reads the concatenated, flattened outputs of blocks 3
        # and 4. 8192 = 384*4*4 + 512*2*2, which matches the default widths
        # with 32x32 inputs (each block's max-pool halves the spatial size).
        # Other input sizes or widths need a different value here.
        self.classifier = MetaLinear(8192, out_features)

    def _forward_all(self, x, params, inner_update_type):
        """Run one batch through the four blocks and the classifier.

        Args:
            x: input batch.
            params: optional dict of overrides keyed by dotted names
                (e.g. 'block1.conv1.weight', 'classifier.weight'), or None.
            inner_update_type: 'linear_only' passes None to the conv blocks and
                applies ``params`` to the classifier only; any other value
                applies ``params`` to the blocks and the classifier.

        Returns:
            (logits, features). ``features`` is the concatenation of the
            flattened block-3 and block-4 outputs.
        """
        if inner_update_type == 'linear_only':
            params_feature = [None] * 4
        else:
            params_feature = [get_subdict(params, f'block{i + 1}') for i in range(4)]

        b1 = self.block1(x, params=params_feature[0])
        b2 = self.block2(b1, params=params_feature[1])
        b3 = self.block3(b2, params=params_feature[2])
        b4 = self.block4(b3, params=params_feature[3])
        # Flatten and concatenate the output of the 3rd and 4th conv blocks as proposed in R2D2 paper.
        features = torch.cat((b3.flatten(1), b4.flatten(1)), 1)

        # The classifier must receive its slice of `params`, otherwise overrides
        # (including the whole point of 'linear_only') never reach it. An empty
        # slice becomes None, i.e. the same call as `self.classifier(features)`,
        # so callers that supply no 'classifier.*' entries keep working.
        classifier_params = get_subdict(params, 'classifier') or None
        logits = self.classifier(features, params=classifier_params)

        return logits, features

    def _forward_views(self, views, num_views, params, params2, inner_update_type):
        """Forward one view, or a pair of views, of a batch.

        With ``num_views == 1``, ``views`` is a single batch run with ``params``.
        With ``num_views == 2``, it is a pair ``(x1, x2)``. ``x1`` uses ``params``
        and ``x2`` uses ``params2``, and logits and features come back as pairs.

        Raises:
            ValueError: if ``num_views`` is neither 1 nor 2.
        """
        if num_views == 1:
            return self._forward_all(views, params, inner_update_type)
        if num_views == 2:
            x1, x2 = views
            logits1, z1 = self._forward_all(x1, params, inner_update_type)
            logits2, z2 = self._forward_all(x2, params2, inner_update_type)
            return (logits1, logits2), (z1, z2)
        raise ValueError(f'expected 1 or 2 views, got {num_views}')

    def forward(self, qry, adv=None, sprt=None, qry_num=1, adv_num=0, sprt_num=0,
                params=None, params2=None, feat=False, inner_update_type='both'):
        """Forward the query batch and, optionally, the ``adv`` and ``sprt`` batches.

        Each of ``qry``/``adv``/``sprt`` is a single batch when its ``*_num`` is 1,
        or a pair of batches when it is 2. The first view uses ``params`` and
        the second uses ``params2``. ``adv``/``sprt`` are skipped (their outputs
        are None) when their ``*_num`` is 0. Batches are processed in the order
        qry, adv, sprt.

        Returns:
            - ``feat=False``: ``logits_qry``, or ``(logits_qry, logits_adv,
              logits_sprt)`` when ``adv_num > 0`` or ``sprt_num > 0``.
            - ``feat=True``: ``(logits_qry, z_qry)``, or the 6-tuple
              ``(logits_qry, logits_adv, logits_sprt, z_qry, z_adv, z_sprt)``
              when ``adv_num > 0`` or ``sprt_num > 0``.

        Raises:
            ValueError: if ``qry_num`` is not 1 or 2, or ``adv_num`` /
                ``sprt_num`` is greater than 2.
        """
        logits_qry, z_qry = self._forward_views(qry, qry_num, params, params2, inner_update_type)

        if adv_num >= 1:
            logits_adv, z_adv = self._forward_views(adv, adv_num, params, params2, inner_update_type)
        else:
            logits_adv, z_adv = None, None

        if sprt_num >= 1:
            logits_sprt, z_sprt = self._forward_views(sprt, sprt_num, params, params2, inner_update_type)
        else:
            logits_sprt, z_sprt = None, None

        has_extra = adv_num > 0 or sprt_num > 0
        if feat:
            if has_extra:
                return logits_qry, logits_adv, logits_sprt, z_qry, z_adv, z_sprt
            return logits_qry, z_qry
        if has_extra:
            return logits_qry, logits_adv, logits_sprt
        return logits_qry
    
    
