import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

from protonets.models import register_model
import torchvision.models as models

from .utils import euclidean_dist

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. permuted tensors) are flattened rather than raising; for
        inputs that ``view`` can handle the result is still a view.
        """
        return x.reshape(x.size(0), -1)

class Protonet(nn.Module):
    """Prototypical-network wrapper around an embedding ``encoder``.

    Each class prototype is the mean of that class's support embeddings;
    queries are scored with a log-softmax over the negated
    ``euclidean_dist`` to every prototype.
    """

    def __init__(self, encoder):
        """Store ``encoder``, the module used to embed support and query inputs."""
        super(Protonet, self).__init__()

        self.encoder = encoder

    def loss(self, sample, eval=False):
        """Compute the loss and accuracy for one episode.

        Args:
            sample: mapping holding the tensors ``'xs'`` (support) and
                ``'xq'`` (query). Dim 0 indexes the classes and must have
                the same size in both; dim 1 indexes the examples of each
                class; the remaining dims are handed to the encoder as-is.
            eval: accepted for interface compatibility; currently unused.

        Returns:
            A ``(loss_val, stats)`` tuple. ``loss_val`` is the scalar tensor
            holding the mean negative log-probability of the true class
            over all queries; ``stats`` is a dict with the Python floats
            ``'loss'`` and ``'acc'``.

        Raises:
            AssertionError: if ``'xs'`` and ``'xq'`` disagree on dim 0.
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = xs.size(0)
        assert xq.size(0) == n_class, \
            'support and query sets must contain the same number of classes'
        n_support = xs.size(1)
        n_query = xq.size(1)

        # Embed support and query examples in a single encoder pass; both
        # blocks are laid out class-major.
        x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),
                       xq.view(n_class * n_query, *xq.size()[2:])], 0)

        # Call the module (not .forward) so registered hooks are honoured.
        z = self.encoder(x)
        z_dim = z.size(-1)

        # The view below requires one z_dim-sized vector per input example.
        z_proto = z[:n_class * n_support].view(n_class, n_support, z_dim).mean(1)
        zq = z[n_class * n_support:]

        # NOTE(review): euclidean_dist presumably returns a
        # (n_class * n_query, n_class) matrix -- the view/gather below rely
        # on that; confirm against .utils.
        dists = euclidean_dist(zq, z_proto)

        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)

        # Query [c, q] belongs to class c. Build the targets directly on the
        # device of the log-probabilities so gather never sees mixed devices.
        target_inds = torch.arange(n_class, dtype=torch.long,
                                   device=log_p_y.device)
        target_inds = target_inds.view(n_class, 1, 1).expand(n_class, n_query, 1)

        loss_val = -log_p_y.gather(2, target_inds).mean()

        _, y_hat = log_p_y.max(2)
        # Squeeze only the trailing singleton dim: a bare squeeze() would
        # also drop the class/query dim when it has size 1, and torch.eq
        # would then broadcast to a matrix and report a wrong accuracy.
        acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item()
        }

@register_model('protonet_conv')
def load_protonet_conv(**kwargs):
    """Build a ``Protonet`` whose encoder is four conv blocks plus ``Flatten``.

    Required keyword arguments:
        x_dim: indexable; ``x_dim[0]`` is used as the input channel count.
        hid_dim: output channels of the first three blocks.
        z_dim: output channels of the last block.
    """
    x_dim = kwargs['x_dim']
    hid_dim = kwargs['hid_dim']
    z_dim = kwargs['z_dim']

    def conv_block(n_in, n_out):
        # 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool.
        layers = [
            nn.Conv2d(n_in, n_out, 3, padding=1),
            nn.BatchNorm2d(n_out),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    # Channel widths at each block boundary; consecutive pairs define a block.
    widths = [x_dim[0], hid_dim, hid_dim, hid_dim, z_dim]
    stages = [conv_block(n_in, n_out)
              for n_in, n_out in zip(widths[:-1], widths[1:])]
    stages.append(Flatten())

    return Protonet(nn.Sequential(*stages))

@register_model('protonet_resnet50')
def load_protonet_resnet50(**kwargs):
    """Build a ``Protonet`` on a torchvision ResNet-50 (``pretrained=True``).

    The network's ``fc`` layer is replaced by a
    Linear -> ReLU -> Linear -> ``Flatten`` head and the result is wrapped
    in ``nn.DataParallel``.

    Required keyword arguments:
        x_dim: must be present, but is not used by this builder.
        hid_dim: width of the hidden layer of the new head.
        z_dim: output size of the new head.
    """
    x_dim = kwargs['x_dim']  # required key; value unused here
    hid_dim, z_dim = kwargs['hid_dim'], kwargs['z_dim']

    backbone = models.resnet50(pretrained=True)
    feat_dim = backbone.fc.in_features

    # Swap the stock classifier for the projection head.
    head_layers = [
        nn.Linear(feat_dim, hid_dim),
        nn.ReLU(),
        nn.Linear(hid_dim, z_dim),
        Flatten(),
    ]
    backbone.fc = nn.Sequential(*head_layers)

    return Protonet(nn.DataParallel(backbone))

@register_model('protonet_resnet18')
def load_protonet_resnet18(**kwargs):
    """Build a ``Protonet`` on a torchvision ResNet-18 (``pretrained=True``).

    The network's ``fc`` layer is replaced by a
    Linear -> ReLU -> Linear -> ``Flatten`` head and the result is wrapped
    in ``nn.DataParallel``.

    Required keyword arguments:
        x_dim: must be present, but is not used by this builder.
        hid_dim: width of the hidden layer of the new head.
        z_dim: output size of the new head.
    """
    x_dim = kwargs['x_dim']  # required key; value unused here
    hid_dim, z_dim = kwargs['hid_dim'], kwargs['z_dim']

    backbone = models.resnet18(pretrained=True)
    feat_dim = backbone.fc.in_features

    # Swap the stock classifier for the projection head.
    head_layers = [
        nn.Linear(feat_dim, hid_dim),
        nn.ReLU(),
        nn.Linear(hid_dim, z_dim),
        Flatten(),
    ]
    backbone.fc = nn.Sequential(*head_layers)

    return Protonet(nn.DataParallel(backbone))