import sys
sys.path.append('..')
import fewshot_re_kit
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F

class ProtoParamModule(nn.Module):
    """Store one learnable prototype (center, radius) per tag.

    Prototypes are created lazily on the first ``forward`` call for a tag,
    initialised from that call's support embeddings, and kept in two
    ``nn.ParameterDict`` containers keyed by the tag.
    """

    def __init__(self, dot=False):
        """
        dot: if True, ``__dist__`` is a dot product; otherwise it is the
             squared Euclidean distance.
        """
        super().__init__()
        self.center_d = nn.ParameterDict({})
        self.radius_d = nn.ParameterDict({})
        self.dot = dot

    def __dist__(self, x, y, dim):
        """Reduce x and y over `dim`: dot product, or squared Euclidean distance (no sqrt)."""
        if self.dot:
            return (x * y).sum(dim)
        return torch.pow(x - y, 2).sum(dim)

    def __init_proto_value__(self, support_embeddings):
        """
        Build the initial (center, radius) parameters from support embeddings.

        support_embeddings: averaged over dim 0 for the center and reduced
            over dim 1 for the distances, i.e. used as (num_samples, hidden).

        The center is the mean embedding; the radius is half the mean
        ``__dist__`` between each sample and the center. Both are detached,
        so the new parameters are leaves with no graph back to the encoder.

        Raises ValueError if the radius is negative or NaN.
        '''
        """
        proto_center = torch.mean(support_embeddings, 0).detach()
        mean_dist_to_c = torch.mean(
            self.__dist__(support_embeddings, proto_center, dim=1)
        ).detach()
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # All samples coincide with the center: fall back to a fixed radius.
            # full_like keeps the dtype and device of the computed radius, so
            # the parameter stays on the same device as its center.
            proto_radius = torch.full_like(proto_radius, 10.0)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        elif not proto_radius > 0:
            # Only NaN reaches this branch (e.g. mean over an empty support
            # set). Raised explicitly so the check survives `python -O`.
            raise ValueError('proto_radius is not a positive number')
        return nn.Parameter(proto_center), nn.Parameter(proto_radius)

    def forward(self, tag, support_embeddings, save=True):
        '''
        given a tag and support sample embeddings
        return the proto parameter (center, radius)

        tag: key identifying the class; used as the ParameterDict key.
        support_embeddings: only used when `tag` has no stored prototype yet.
        save: when True, a newly created prototype is stored under `tag` and
              returned by later calls; when False it is returned but not kept.
        '''
        if tag in self.center_d:
            return self.center_d[tag], self.radius_d[tag]

        c, r = self.__init_proto_value__(support_embeddings)
        if save:
            # NOTE(review): parameters registered here after an optimizer was
            # constructed are not in its param groups unless the caller adds
            # them - confirm the training loop handles this.
            self.center_d[tag] = c
            self.radius_d[tag] = r
        return c, r

class BigProto(fewshot_re_kit.framework.FewShotREModel):
    """Prototype classifier whose per-class logit is a distance score plus a class radius."""

    def __init__(self, sentence_encoder, dot=False):
        """
        sentence_encoder: encoder handed to the FewShotREModel base class.
        dot: if True score with a dot product, otherwise with the negative
             squared Euclidean distance.
        """
        fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder)
        self.drop = nn.Dropout()
        self.dot = dot
        self.proto_param = ProtoParamModule(dot=dot)

    def __dist__(self, x, y, dim):
        """Similarity over `dim`: dot product, or negative squared Euclidean distance."""
        if self.dot:
            return (x * y).sum(dim)
        return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q):
        """Score every query against every prototype: S (B, N, D), Q (B, total_Q, D) -> (B, total_Q, N)."""
        return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3)

    def forward(self, support, query, N, K, total_Q, label2classname, eval=False):
        '''
        support: Inputs of the support set.
        query: Inputs of the query set.
        N: Num of classes
        K: Num of instances for each class in the support set
        total_Q: Num of instances in the query set
        label2classname: indexed as label2classname[batch][label]; the value
            is the tag under which that class's prototype is stored.
        eval: if False, use the stored learnable prototypes (created from the
            support set on first sight of a tag); if True, compute center and
            radius directly from the support set.

        Returns (logits, pred): logits is (B, total_Q, N + 1), pred is the
        argmax over the last dim, flattened to (B * total_Q,).
        '''
        support_emb = self.sentence_encoder(support)  # (B * N * K, D), where D is the hidden size
        query_emb = self.sentence_encoder(query)  # (B * total_Q, D)
        hidden_size = support_emb.size(-1)
        support = self.drop(support_emb).view(-1, N, K, hidden_size)  # (B, N, K, D)
        query = self.drop(query_emb).view(-1, total_Q, hidden_size)  # (B, total_Q, D)

        if not eval:
            # Prototypical Networks
            # Ignore NA policy
            centers = []
            radii = []
            for i, batch_support in enumerate(support):
                batch_centers = []
                batch_radii = []
                for label, class_support in enumerate(batch_support):
                    c, r = self.proto_param(label2classname[i][label], class_support)
                    batch_centers.append(c)
                    # Keep each radius on its center's device so hstack never
                    # mixes devices.
                    batch_radii.append(r.to(c.device))
                centers.append(torch.vstack(batch_centers))
                radii.append(torch.hstack(batch_radii))
            support_center = torch.stack(centers, dim=0)  # (B, N, D)
            support_radius = torch.stack(radii, dim=0)  # (B, N)
        else:
            support_center = torch.mean(support, 2)  # (B, N, D)
            # Half the mean (over the K samples) squared distance to the
            # center: sum over D first, then average over K. This matches
            # ProtoParamModule.__init_proto_value__; reducing over K and
            # averaging over D instead would rescale the radius by K / D.
            sq_dist = torch.pow(support - support_center.unsqueeze(2), 2).sum(-1)  # (B, N, K)
            support_radius = sq_dist.mean(-1) / 2  # (B, N)

        logits = self.__batch_dist__(support_center, query)  # (B, total_Q, N)
        # Follow the logits' device rather than assuming CUDA is in use just
        # because it is available on this machine.
        support_radius = support_radius.to(logits.device)
        logits = logits + support_radius.unsqueeze(1)  # broadcast over total_Q
        # Extra column is always strictly below the row minimum, so the
        # argmax can never select it.
        minn, _ = logits.min(-1)
        logits = torch.cat([logits, minn.unsqueeze(2) - 1], 2)  # (B, total_Q, N + 1)
        _, pred = torch.max(logits.view(-1, N + 1), 1)
        return logits, pred