import sys
sys.path.append('..')
from util import framework
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F
import random

class ProtoParamModule(nn.Module):
    '''
    Registry of learnable prototype parameters keyed by tag.

    Each tag owns a center vector and a scalar radius, stored as
    ``nn.Parameter`` objects in ``center_d`` / ``radius_d``
    (``nn.ParameterDict``), so they belong to the module's state and can be
    optimised. A tag's parameters are created the first time it is requested,
    from the given support embeddings.
    '''

    def __init__(self, dot=False):
        nn.Module.__init__(self)
        self.center_d = nn.ParameterDict({})
        self.radius_d = nn.ParameterDict({})
        # True -> __dist__ is a dot product instead of squared Euclidean distance
        self.dot = dot

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance of x and y, reduced over ``dim``.'''
        if self.dot:
            return (x * y).sum(dim)
        return torch.pow(x - y, 2).sum(dim)

    def __init_proto_value__(self, support_embeddings):
        '''
        Build initial (center, radius) parameters from support embeddings.

        support_embeddings: expected [num_samples, embed_dim] (rows are samples).
        The center is the row mean. The radius is half the mean __dist__ from the
        rows to the center. If that is exactly 0 (e.g. a single sample), the
        radius falls back to 10.0, on the same device/dtype as the embeddings.
        Both values are detached before being wrapped in nn.Parameter.

        Raises ValueError if the radius is negative (possible with dot=True)
        or NaN (e.g. an empty support set).
        '''
        proto_center = torch.mean(support_embeddings, 0).detach()
        mean_dist_to_c = torch.mean(self.__dist__(support_embeddings, proto_center, dim=1)).detach()
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # full_like keeps device/dtype so the radius can later be stacked
            # with radii computed on the GPU.
            proto_radius = torch.full_like(proto_radius, 10.00)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        elif not proto_radius > 0:
            # only NaN reaches here
            raise ValueError('proto_radius is NaN (empty support set?)')
        return nn.Parameter(proto_center), nn.Parameter(proto_radius)

    def forward(self, tag, support_embeddings, save=True):
        '''
        given a tag and support sample embeddings
        return the proto parameter (center, radius)

        If ``tag`` is unknown, new parameters are built from
        ``support_embeddings`` and, when ``save`` is True, registered under
        ``tag``. If ``tag`` is already known, the stored parameters are
        returned and ``support_embeddings`` is ignored (it may be None).
        '''
        if tag in self.center_d:
            if not save:
                print('[WARNING] in eval mode, but not using mean support embedding!')
            return self.center_d[tag], self.radius_d[tag]
        c, r = self.__init_proto_value__(support_embeddings)
        if save:
            self.center_d[tag] = c
            self.radius_d[tag] = r
        return c, r



class BigProto_deprecated(framework.FewShotNERModel):
    '''
    Prototypical NER model with learnable per-tag prototypes (superseded by BigProto).

    Every tag owns a learnable center and radius held in a ProtoParamModule.
    They are created lazily from the first support set that contains the tag.
    Token logits are ``-(dist(token, center) - radius)``. With ``eval=True``,
    prototypes are recomputed from each episode's support set instead.
    '''

    def __init__(self, word_encoder, dot=False, ignore_index=-1, label2tag=None):
        framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index)
        self.drop = nn.Dropout()
        self.dot = dot
        # label id -> tag, used by init_proto / forward_full_supervised
        self.label2tag = label2tag
        # learnable (center, radius) per tag
        self.proto_param = ProtoParamModule(dot=dot)

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance, reduced over ``dim``.'''
        if self.dot:
            return (x * y).sum(dim)
        return torch.pow(x - y, 2).sum(dim)

    def __batch_dist__(self, C, R, Q, q_mask):
        '''
        Score every unmasked query token against every prototype.

        C: [num_cls, embed_dim] centers; R: [num_cls] radii;
        Q: [num_sent, num_tokens, embed_dim]; q_mask: [num_sent, num_tokens]
        (tokens with q_mask == 1 are kept).
        Returns [num_kept_tokens, num_cls] logits equal to -(dist - R).
        NOTE(review): with dot=True, __dist__ is a similarity, so negating it
        favours the least similar prototype -- confirm this is intended.
        '''
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask==1].view(-1, Q.size(-1))  # [num_kept_tokens, embed_dim]
        dist_to_proto_ct = self.__dist__(C.unsqueeze(0), Q.unsqueeze(1), 2)
        return -(dist_to_proto_ct - R)

    def __get_avg_proto__(self, embeddings, labels, cls):
        '''
        Support-average (center, radius) for label ``cls``.

        embeddings are already flattened to [num_tokens, embed_dim] and aligned
        with ``labels``. The radius is half the mean __dist__ to the center, with
        a fallback of 10.0 (same device/dtype) when that is exactly 0.
        Raises ValueError for a negative or NaN (e.g. label absent) radius.
        '''
        cur_samples = embeddings[labels==cls]
        proto_center = torch.mean(cur_samples, 0)
        mean_dist_to_c = torch.mean(self.__dist__(cur_samples, proto_center, dim=1))
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # keep device/dtype so the radii can be stacked
            proto_radius = torch.full_like(proto_radius, 10.00)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        elif not proto_radius > 0:
            # only NaN reaches here
            raise ValueError('proto_radius is NaN (label absent from support set?)')
        return proto_center, proto_radius

    def __get_all_avg_protos__(self, embedding, tag, mask):
        '''
        Support-average prototypes for label ids 0..max(tag) of one episode.

        Returns stacked centers [num_labels, embed_dim] and detached radii
        [num_labels].
        '''
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        proto_centers = []
        proto_radii = []
        for label in range(torch.max(tag)+1):
            c, r = self.__get_avg_proto__(embedding, tag, label)
            proto_centers.append(c)
            proto_radii.append(r)
        # torch.stack keeps device/dtype; the previous torch.tensor(list)
        # round-tripped through Python floats. detach() keeps radii out of
        # the graph, as before.
        return torch.stack(proto_centers), torch.stack(proto_radii).detach()

    def __get_all_protos__(self, cur_episode_emb, cur_episode_mask, cur_episode_lab, label_set, label2tag, save=True):
        '''
        Learnable (centers, radii) for each label in ``label_set`` via self.proto_param.

        Tags not yet known to proto_param are initialised from this episode's
        support tokens carrying that label (and stored when ``save``).
        Returns [len(label_set), embed_dim] centers and [len(label_set)] radii.
        '''
        valid_support_emb = cur_episode_emb[cur_episode_mask==1].view(-1, cur_episode_emb.size(-1))
        support_labels_ts = torch.cat(cur_episode_lab, 0)
        proto_centers = []
        proto_radii = []
        for label in label_set:
            support_embeddings = valid_support_emb[support_labels_ts == label]
            c, r = self.proto_param(label2tag[label], support_embeddings, save=save)
            proto_centers.append(c)
            proto_radii.append(r)
        return torch.stack(proto_centers), torch.stack(proto_radii)

    def init_proto(self, num_labels=37, embedding_dim=768):
        '''
        Create (or fetch) learnable prototypes for the training labels and set self.train_labels.

        If self.support_embedding is None, labels 0..num_labels-1 are
        initialised from a single all-zero embedding of size embedding_dim
        (zero center, fallback radius). Otherwise each label present in
        self.support_labels (except ignore_index) is initialised from the
        matching rows of self.support_embedding.
        '''
        if self.support_embedding is None:
            dummy = torch.zeros(1, embedding_dim)
            for label in range(num_labels):
                self.proto_param(self.label2tag[label], dummy)
            self.train_labels = list(range(num_labels))
        else:
            unique_labels = sorted(torch.unique(self.support_labels.cpu()).tolist())
            # guard: the padding label is not necessarily present
            if self.ignore_index in unique_labels:
                unique_labels.remove(self.ignore_index)
            self.train_labels = unique_labels
            for label in unique_labels:
                support_embed = self.support_embedding[self.support_labels == label]
                self.proto_param(self.label2tag[label], support_embed.view(-1, support_embed.size(-1)))

    def forward_full_supervised(self, batch):
        '''
        Classify tokens of ``batch`` against the stored prototypes of self.train_labels.

        Requires those tags to already exist in self.proto_param (e.g. after
        init_proto), because no support embeddings are passed. Returns
        (logits, pred) over tokens with batch['text_mask'] == 1.
        '''
        embeddings = self.word_encoder(batch['input_ids'], batch['attention_mask'])
        params = [self.proto_param(self.label2tag[label], None) for label in self.train_labels]
        device = embeddings.device
        proto_centers = torch.stack([c for c, _ in params]).to(device)
        proto_radii = torch.stack([r for _, r in params]).to(device)

        logits = self.__batch_dist__(proto_centers, proto_radii, embeddings, batch['text_mask'])
        _, pred = torch.max(logits, 1)
        return logits, pred

    def forward(self, support=None, query=None, eval=False):
        '''
        support: Inputs of the support set.
        query: Inputs of the query set.
        eval: if True, prototypes are per-episode support averages; otherwise
              the learnable per-tag prototypes are used (created for new tags).

        Both dicts hold a batch of episodes: the ``sentence_num[i]`` sentences
        of episode i are stored consecutively in ``word``/``mask``/``label``/
        ``text_mask``. Returns (logits, pred) for all unmasked query tokens,
        concatenated in episode order.
        '''
        # encoder/dropout call order is kept as before (affects RNG stream)
        support_emb = self.word_encoder(support['word'], support['mask'])  # [num_sent, num_tokens, embed_dim]
        support_emb = self.drop(support_emb)
        assert support_emb.size()[:2] == support['mask'].size()
        query_emb = self.word_encoder(query['word'], query['mask'])
        query_emb = self.drop(query_emb)

        logits = []
        current_support_num = 0
        current_query_num = 0
        for i, sent_support_num in enumerate(support['sentence_num']):
            sent_query_num = query['sentence_num'][i]
            sup = slice(current_support_num, current_support_num + sent_support_num)
            qry = slice(current_query_num, current_query_num + sent_query_num)
            cur_episode_emb = support_emb[sup]
            cur_episode_lab = support['label'][sup]
            cur_episode_mask = support['text_mask'][sup]

            if eval:
                proto_centers, proto_radii = self.__get_all_avg_protos__(
                    cur_episode_emb, cur_episode_lab, cur_episode_mask)
            else:
                lab_set = set(torch.cat(cur_episode_lab, 0).tolist())
                lab_set.discard(self.ignore_index)
                # sorted so logits column k is the k-th smallest label id
                proto_centers, proto_radii = self.__get_all_protos__(
                    cur_episode_emb, cur_episode_mask, cur_episode_lab,
                    sorted(lab_set), query['label2tag'][i])

            cur_query_emb = query_emb[qry]
            logits.append(self.__batch_dist__(proto_centers.to(cur_query_emb.device),
                                              proto_radii.to(cur_query_emb.device),
                                              cur_query_emb,
                                              query['text_mask'][qry]))
            current_query_num += sent_query_num
            current_support_num += sent_support_num

        logits = torch.cat(logits, 0)
        _, pred = torch.max(logits, 1)
        return logits, pred

class BigProto(framework.FewShotNERModel):
    '''
    Prototypical NER model with one learnable prototype per global class.

    ``label2tag`` maps every tag to a global class index (glb_tag2idx). Each
    global class has a row in ``proto_centers`` (nn.Embedding) and an entry in
    ``proto_radii``. A class's prototype is initialised from the support set of
    the first training episode that contains it. Token logits are
    ``-(dist(token, center) - radius)``. With ``eval=True``, prototypes are
    recomputed from each episode's support set instead.
    '''

    def __init__(self, word_encoder, label2tag=None, embedding_dim=768,
                 dot=False, ignore_index=-1):
        # label2tag: glb_tag2idx (tag -> global class index)
        framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index)
        self.proto_param = None
        tot_num_of_cls = len(label2tag)
        self.glb_tag2ind = label2tag
        # learnable centers and radii for all global classes
        self.proto_centers = nn.Embedding(tot_num_of_cls, embedding_dim)
        self.proto_radii = nn.Parameter(torch.zeros(tot_num_of_cls))
        # 1 once a class's prototype has been initialised from data.
        # NOTE(review): plain list, not a buffer, so it is not in state_dict;
        # after reloading a checkpoint, prototypes are re-initialised on first sight.
        self.initialized = [0] * tot_num_of_cls
        self.drop = nn.Dropout()
        self.dot = dot

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance, reduced over ``dim``.'''
        if self.dot:
            return (x * y).sum(dim)
        return torch.pow(x - y, 2).sum(dim)

    def __batch_dist__(self, C, R, Q, q_mask):
        '''
        Score every unmasked query token against every prototype.

        C: [num_cls, embed_dim] centers; R: [num_cls] radii;
        Q: [num_sent, num_tokens, embed_dim]; q_mask: [num_sent, num_tokens]
        (tokens with q_mask == 1 are kept).
        Returns [num_kept_tokens, num_cls] logits equal to -(dist - R).
        NOTE(review): with dot=True, __dist__ is a similarity, so negating it
        favours the least similar prototype -- confirm this is intended.
        '''
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask==1].view(-1, Q.size(-1))  # [num_kept_tokens, embed_dim]
        dist_to_proto_ct = self.__dist__(C.unsqueeze(0), Q.unsqueeze(1), 2)
        return -(dist_to_proto_ct - R)

    def __get_proto__(self, embedding, tag, mask):
        '''Mean embedding for label ids 0..max(tag) over tokens with mask == 1 (not called by this class).'''
        proto = []
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        assert tag.size(0) == embedding.size(0)
        for label in range(torch.max(tag)+1):
            proto.append(torch.mean(embedding[tag==label], 0))
        proto = torch.stack(proto)
        return proto

    def __get_avg_proto__(self, embeddings, labels, cls):
        '''
        Support-average (center, radius) for label ``cls``.

        embeddings are already flattened to [num_tokens, embed_dim] and aligned
        with ``labels``. The radius is half the mean __dist__ to the center, with
        a fallback of 10.0 (same device/dtype) when that is exactly 0.
        Raises ValueError for a negative or NaN (e.g. label absent) radius.
        '''
        cur_samples = embeddings[labels==cls]
        proto_center = torch.mean(cur_samples, 0)
        mean_dist_to_c = torch.mean(self.__dist__(cur_samples, proto_center, dim=1))
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # keep device/dtype so the radii can be stacked
            proto_radius = torch.full_like(proto_radius, 10.00)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        elif not proto_radius > 0:
            # only NaN reaches here
            raise ValueError('proto_radius is NaN (label absent from support set?)')
        return proto_center, proto_radius

    def __get_all_protos__(self, embedding, tag, mask):
        '''
        Support-average prototypes for label ids 0..max(tag) of one episode.

        Returns stacked centers [num_labels, embed_dim] and detached radii
        [num_labels].
        '''
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        proto_centers = []
        proto_radii = []
        for label in range(torch.max(tag)+1):
            c, r = self.__get_avg_proto__(embedding, tag, label)
            proto_centers.append(c)
            proto_radii.append(r)
        # torch.stack keeps device/dtype; the previous torch.tensor(list)
        # round-tripped through Python floats. detach() keeps radii out of
        # the graph, as before.
        return torch.stack(proto_centers), torch.stack(proto_radii).detach()

    def _learned_protos(self, emb, lab, mask, label2tag):
        '''
        Return the learnable (centers, radii) for the labels of one training episode.

        emb/lab/mask are the episode's support embeddings, per-sentence label
        tensors and text masks; label2tag maps the episode's label ids to tags.
        Rows follow the episode's label ids in ascending order (ignore_index
        excluded). Classes not yet initialised are first overwritten in place
        with their support-average prototype.
        '''
        all_labels = torch.cat(lab, 0)
        lab_set = set(all_labels.tolist())
        lab_set.discard(self.ignore_index)
        # sorted so logits column k is the k-th smallest label id
        local_labels = sorted(lab_set)
        global_inds = [self.glb_tag2ind[label2tag[label]] for label in local_labels]

        need_init = [(label, ind) for label, ind in zip(local_labels, global_inds)
                     if not self.initialized[ind]]
        if need_init:
            valid_emb = emb[mask==1].view(-1, emb.size(-1))
            # write through .data: initialisation must not enter the autograd graph
            centers = self.proto_centers.weight.data
            radii = self.proto_radii.data
            with torch.no_grad():
                for label, ind in need_init:
                    c, r = self.__get_avg_proto__(valid_emb, all_labels, label)
                    centers[ind] = c
                    radii[ind] = r
                    self.initialized[ind] = 1

        # build indices on the parameters' device (not blindly on CUDA)
        indices = torch.tensor(global_inds, dtype=torch.long,
                               device=self.proto_centers.weight.device)
        return self.proto_centers(indices), self.proto_radii[indices]

    def forward(self, support, query, eval=False):
        '''
        support: Inputs of the support set.
        query: Inputs of the query set.
        eval: if True, prototypes are per-episode support averages; otherwise
              the learnable global prototypes are used (initialised on first sight).

        Both dicts hold a batch of episodes: the ``sentence_num[i]`` sentences
        of episode i are stored consecutively in ``word``/``mask``/``label``/
        ``text_mask``. Returns (logits, pred) for all unmasked query tokens,
        concatenated in episode order.
        '''
        # encoder/dropout call order is kept as before (affects RNG stream)
        support_emb = self.word_encoder(support['word'], support['mask'])  # [num_sent, num_tokens, embed_dim]
        query_emb = self.word_encoder(query['word'], query['mask'])  # [num_sent, num_tokens, embed_dim]
        support_emb = self.drop(support_emb)
        query_emb = self.drop(query_emb)
        assert support_emb.size()[:2] == support['mask'].size()
        assert query_emb.size()[:2] == query['mask'].size()

        logits = []
        current_support_num = 0
        current_query_num = 0
        for i, sent_support_num in enumerate(support['sentence_num']):
            sent_query_num = query['sentence_num'][i]
            sup = slice(current_support_num, current_support_num + sent_support_num)
            qry = slice(current_query_num, current_query_num + sent_query_num)
            cur_batch_emb = support_emb[sup]
            cur_batch_lab = support['label'][sup]
            cur_batch_mask = support['text_mask'][sup]

            if eval:
                proto_centers, proto_radii = self.__get_all_protos__(
                    cur_batch_emb, cur_batch_lab, cur_batch_mask)
            else:
                proto_centers, proto_radii = self._learned_protos(
                    cur_batch_emb, cur_batch_lab, cur_batch_mask, query['label2tag'][i])

            cur_query_emb = query_emb[qry]
            logits.append(self.__batch_dist__(proto_centers.to(cur_query_emb.device),
                                              proto_radii.to(cur_query_emb.device),
                                              cur_query_emb,
                                              query['text_mask'][qry]))
            current_query_num += sent_query_num
            current_support_num += sent_support_num

        logits = torch.cat(logits, 0)
        _, pred = torch.max(logits, 1)
        return logits, pred