# This code is modified from https://github.com/jakesnell/prototypical-networks 

import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate
from methods.min_norm_solvers import MinNormSolver, gradient_normalizers

class ProtoNet_MOML(MetaTemplate):
    """Prototypical network (after jakesnell/prototypical-networks) on top of ``MetaTemplate``.

    Support images of an episode are embedded with ``self.feature`` and averaged
    per class into prototypes; query images are scored by the negative squared
    Euclidean distance to every prototype.
    """

    def __init__(self, model_func, n_way, n_support):
        super(ProtoNet_MOML, self).__init__(model_func, n_way, n_support)
        self.loss_fn = nn.CrossEntropyLoss()
        # Not read anywhere in this class; kept for external callers / subclasses.
        self.weighting_mode = None

    def clamp(self, X, lower_limit, upper_limit):
        """Element-wise clamp of ``X`` into ``[lower_limit, upper_limit]``.

        The limits are passed straight to the binary ``torch.min``/``torch.max``,
        so presumably they are tensors broadcastable against ``X``.
        """
        return torch.max(torch.min(X, upper_limit), lower_limit)

    def compute_proto(self, x_s):
        """Return one prototype per class from the support images.

        Args:
            x_s: support images holding ``n_way * n_support`` images of shape
                ``x_s.size()[-3:]``, ordered class-major (all shots of class 0
                first, then class 1, ...).

        Returns:
            Tensor of shape ``(n_way, feature_dim)``: the mean embedding of each
            class's support images.
        """
        x_s = x_s.cuda()
        x_s = x_s.reshape(self.n_way * self.n_support, *x_s.size()[-3:])
        z_support = self.feature.forward(x_s)
        # Class-major layout: rows [c*n_support, (c+1)*n_support) belong to class c.
        z_proto = z_support.reshape(self.n_way, self.n_support, -1).mean(1)
        return z_proto

    def compute_q_loss(self, x_q, z_proto):
        """Score query images against the prototypes and compute the episode loss.

        Query images must be ordered class-major (``n_query`` images of class 0,
        then class 1, ...) because the target labels are generated from that layout.

        Returns:
            ``(scores, loss)``: ``scores`` is ``(n_way * n_query, n_way)`` negative
            squared distances (higher means closer) and ``loss`` is the
            cross-entropy of those scores against the implied labels.
        """
        x_q = x_q.cuda()
        x_q = x_q.reshape(self.n_way * self.n_query, *x_q.size()[-3:])
        z_query = self.feature.forward(x_q)
        z_query = z_query.reshape(self.n_way * self.n_query, -1)
        scores = -euclidean_dist(z_query, z_proto)
        # Build labels directly as an int64 tensor on the scores' device:
        # CrossEntropyLoss needs Long targets, and a numpy-derived array may be
        # int32 depending on the platform.
        y_query = torch.arange(self.n_way, device=scores.device).repeat_interleave(self.n_query)
        return scores, self.loss_fn(scores, y_query)

    def train_loop(self, epoch, train_loader, optimizer):
        """Run one epoch of episodic training.

        Each batch ``x`` is one episode laid out as
        ``(n_way, n_support + n_query, *image_shape)``; the loader's labels are
        ignored because episode labels are implied by the class-major layout.
        """
        print_freq = 10
        avg_loss = 0  # running sum; divided by the batch count when printed
        for i, (x, _) in enumerate(train_loader):
            # Re-derive the episode sizes from each batch's shape.
            self.n_query = x.size(1) - self.n_support
            if self.change_way:
                self.n_way = x.size(0)
            optimizer.zero_grad()
            x_s = x[:, :self.n_support].reshape(self.n_way * self.n_support, *x.size()[2:])
            x_q = x[:, self.n_support:].reshape(self.n_way * self.n_query, *x.size()[2:])
            z_proto = self.compute_proto(x_s)
            _, q_loss = self.compute_q_loss(x_q, z_proto)
            q_loss.backward()
            optimizer.step()
            avg_loss = avg_loss + q_loss.item()

            if i % print_freq == 0:
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} '.format(epoch, i, len(train_loader), avg_loss/float(i+1)))

    def correct(self, x_q, z_proto):
        """Return ``(num_top1_correct, num_queries)`` for one episode's queries."""
        with torch.no_grad():
            scores, _ = self.compute_q_loss(x_q, z_proto)
            y_query = np.repeat(range(self.n_way), self.n_query)

            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()
            top1_correct = np.sum(topk_ind[:, 0] == y_query)
            return float(top1_correct), len(y_query)

    def test_loop(self, test_loader, record=None, return_std=False):
        """Evaluate on every episode of ``test_loader`` and report accuracy.

        Args:
            test_loader: iterable of ``(x, _)`` episodes with the same layout as
                in ``train_loop``.
            record: not used by this implementation.
            return_std: if True, also return the standard deviation of the
                per-episode accuracies.

        Returns:
            Mean per-episode accuracy in percent, or ``(mean, std)`` when
            ``return_std`` is True.
        """
        acc_all = []

        iter_num = len(test_loader)
        for x, _ in test_loader:
            self.n_query = x.size(1) - self.n_support
            if self.change_way:
                self.n_way = x.size(0)
            x_s = x[:, :self.n_support].reshape(self.n_way * self.n_support, *x.size()[2:])
            x_q = x[:, self.n_support:].reshape(self.n_way * self.n_query, *x.size()[2:])
            # Prototypes are only used for scoring here; don't build an autograd graph.
            with torch.no_grad():
                z_proto = self.compute_proto(x_s)
            correct_this, count_this = self.correct(x_q, z_proto)
            acc_all.append(correct_this / count_this * 100)

        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        # 1.96 * std / sqrt(N): 95% confidence interval of the mean accuracy.
        print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

        if return_std:
            return acc_mean, acc_std
        return acc_mean

def euclidean_dist(x, y):
    """Pairwise squared Euclidean distances between the rows of two matrices.

    Args:
        x: tensor of shape ``(N, D)``.
        y: tensor of shape ``(M, D)``.

    Returns:
        Tensor of shape ``(N, M)`` whose ``[i, j]`` entry is
        ``sum_k (x[i, k] - y[j, k]) ** 2`` (no square root is taken).
    """
    num_x, num_y = x.size(0), y.size(0)
    feat_dim = x.size(1)
    assert feat_dim == y.size(1)

    grid = (num_x, num_y, feat_dim)
    # Lay both operands out on an (N, M, D) grid, then reduce over features.
    diff = x.unsqueeze(1).expand(*grid) - y.unsqueeze(0).expand(*grid)
    return diff.pow(2).sum(dim=2)

