import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import *
from torch.autograd import Variable
from torch.nn import Parameter
from torch_optimizer import RAdam
from tensorboardX import SummaryWriter
from tqdm import tqdm
import hashlib
import json
from dataset import *
import os.path
import utils
from utils import *


class Callback:
    """No-op base class for training-loop hooks.

    The trainer invokes every hook with its current local variables as keyword
    arguments; subclasses override only the hooks they need and ignore the
    keywords they do not use.
    """

    def on_training_begin(self, **kwargs):
        """Called once before the first epoch."""

    def on_training_end(self, **kwargs):
        """Called once when training finishes or is interrupted."""

    def on_epoch_begin(self, **kwargs):
        """Called at the start of each split of each epoch."""

    def on_epoch_end(self, **kwargs):
        """Called at the end of each split of each epoch."""

    def on_batch_begin(self, **kwargs):
        """Called before a batch is processed."""

    def on_batch_end(self, **kwargs):
        """Called after a batch, including when processing it raised."""


class CallbackList(Callback):
    """Composite callback: relays every hook, in registration order, to each
    callback stored in ``self.callbacks``."""

    def __init__(self):
        super().__init__()
        self.callbacks = []

    def _relay_to_callbacks(self, hook_name, kwargs):
        # invoke the same-named hook on every registered callback
        for cb in self.callbacks:
            getattr(cb, hook_name)(**kwargs)

    def on_training_begin(self, **kwargs):
        self._relay_to_callbacks("on_training_begin", kwargs)

    def on_training_end(self, **kwargs):
        self._relay_to_callbacks("on_training_end", kwargs)

    def on_epoch_begin(self, **kwargs):
        self._relay_to_callbacks("on_epoch_begin", kwargs)

    def on_epoch_end(self, **kwargs):
        self._relay_to_callbacks("on_epoch_end", kwargs)

    def on_batch_begin(self, **kwargs):
        self._relay_to_callbacks("on_batch_begin", kwargs)

    def on_batch_end(self, **kwargs):
        self._relay_to_callbacks("on_batch_end", kwargs)


class SummaryWriterCallback(Callback):
    """Log the loss and user-defined metrics to TensorBoard (tensorboardX).

    :param path: log directory handed to ``SummaryWriter(logdir=...)``.
    :param per_batch: emit per-batch scalars every ``per_batch`` batches;
        ``None`` disables per-batch logging.
    :param metrics: mapping name -> callable. Each callable is called with the
        trainer's locals as keyword arguments and must return a scalar. The
        dict is kept by reference, so entries the caller adds later are logged.
    """
    def __init__(self, path, per_batch=1000, metrics=None):
        self.path = path
        self.per_batch = per_batch
        # fresh dict per instance; a `{}` default would be shared by every instance
        self.metrics = {} if metrics is None else metrics
        print("write logs into", path)
    def on_training_begin(self,**locals):
        self.writer = SummaryWriter(logdir=self.path,flush_secs=10)
    def on_training_end(self,**locals):
        self.writer.close()
    def on_epoch_end(self,**locals):
        b           = locals["b"]
        split       = locals["split"]
        epoch       = locals["epoch"]
        total_loss  = locals["total_loss"]
        # Average over the batches of *this* split. The trainer's "num_batches"
        # counts training batches only, which mis-scales the validation loss.
        data = locals.get("data")
        split_batches = len(data) if data is not None else locals["num_batches"]
        for k, v in self.metrics.items():
            self.writer.add_scalar(split+'/'+k,v(**locals), epoch)
        if split_batches:
            self.writer.add_scalar(split+'/loss',total_loss/(b*split_batches), epoch)
    def on_batch_end(self,**locals):
        b=locals["b"]
        split=locals["split"]
        loss=locals["loss"]
        batch_n=locals["batch_n"]
        num_batches=locals["num_batches"]
        epoch=locals["epoch"]
        if self.per_batch is not None:
            if (batch_n % self.per_batch) == 0:
                # global step: batch index offset by the epoch
                step = batch_n + epoch * num_batches
                for k, v in self.metrics.items():
                    self.writer.add_scalar(split+'-batch/'+k,v(**locals), step)
                self.writer.add_scalar(split+'-batch/loss', loss.item()/b, step)


class TQDMCallback(Callback):
    """Progress bar over the batches of each split, labelled with the loss
    divided by the batch size."""

    @staticmethod
    def _label(epoch, split, value):
        # e.g. "epoch 3 train/loss 0.123456"
        return "epoch "+str(epoch)+" "+split+"/loss {0:.6f}".format(value)

    def on_epoch_begin(self, data=None, **locals):
        # one tick per batch of the current split
        self.bar = tqdm(total=len(data), mininterval=1.0)

    def on_batch_end(self, b=None, split=None, epoch=None, loss=None, **locals):
        self.bar.update(1)
        # refresh=False: do not force a redraw on every batch
        self.bar.set_description(self._label(epoch, split, loss.item() / b), refresh=False)

    def on_epoch_end(self, b=None, split=None, epoch=None, total_loss=None, data=None, **locals):
        epoch_loss = total_loss / (b * len(data))
        self.bar.set_description(self._label(epoch, split, epoch_loss))
        self.bar.close()


class ShuffleCallback(Callback):
    """Shuffle the training rows in place at the start of every training split."""

    def on_epoch_begin(self, split=None, train=None, **locals):
        if split != "train":
            return
        print("shuffling...", end="")
        # In-place row shuffle. NOTE(review): this only affects the batches the
        # trainer iterates if those are a view of `train` (as a reshape of a
        # contiguous slice is) -- keep that in mind if the batching changes.
        np.random.shuffle(train)
        print("Done!")


class PeriodicSaveCallback(Callback):
    """Save the model at most once per ``interval`` seconds, checked after each training split.

    :param interval: minimum wall-clock seconds between saves.
    :param trainer: object providing ``local(subpath)`` and ``save()``. This is
        required in practice because ``Trainer`` removes itself from the keyword
        arguments it forwards to callbacks, so the ``self=`` keyword that
        ``on_epoch_end`` accepts is normally never supplied.
    """
    def __init__(self, interval=3600, trainer=None):
        self.interval = interval
        self.trainer = trainer
    def on_training_begin(self,**locals):
        import time
        self.timestamp = time.time()
    def on_epoch_end(callback,split=None,epoch=None,self=None,**locals):
        # `callback` is this object; the keyword `self` (kept for backward
        # compatibility) is the trainer when a caller passes it explicitly.
        if split != "train":
            return
        import time
        now = time.time()
        elapsed = now - callback.timestamp
        if elapsed < callback.interval:
            return
        trainer = self if self is not None else callback.trainer
        if trainer is None:
            raise ValueError("PeriodicSaveCallback has no trainer to save; "
                             "construct it with PeriodicSaveCallback(..., trainer=model)")
        with open(trainer.local("timestamp.log"), 'a') as f:
            json.dump([time.asctime(), epoch], f)
            f.write("\n")
        trainer.save()
        # Skip every interval that has already elapsed so that one slow epoch
        # does not trigger a burst of back-to-back "catch-up" saves.
        if callback.interval > 0:
            callback.timestamp += (elapsed // callback.interval) * callback.interval
        else:
            callback.timestamp = now


class USR2SaveCallback(Callback):
    """Stop training gracefully when SIGUSR2 arrives (e.g. from a job scheduler).

    The signal handler only sets a flag; the next ``on_batch_end`` then raises
    ``KeyboardInterrupt`` so the training loop can unwind through its normal
    interruption path. The previously installed SIGUSR2 handler is restored
    when training ends.
    """
    def __init__(self):
        self.stop_scheduled = False
        self._previous_handler = None
    def on_training_begin(self,**locals):
        import signal
        # start each run with a clean flag so a signal from an earlier run
        # does not abort this one on its first batch
        self.stop_scheduled = False
        def sig_handler(sig,frame):
            print("received", sig)
            self.stop_scheduled = True
        self._previous_handler = signal.signal(signal.SIGUSR2, sig_handler)
        print("signal handler installed:",signal.SIGUSR2)
    def on_training_end(self,**locals):
        import signal
        # signal.signal returns None when the old handler was not installed
        # from Python; in that case there is nothing we can restore
        if self._previous_handler is not None:
            signal.signal(signal.SIGUSR2, self._previous_handler)
            self._previous_handler = None
    def on_batch_end(self,**locals):
        if self.stop_scheduled:
            raise KeyboardInterrupt()


class LambdaCallback(Callback):
    """Callback assembled from plain functions passed as keyword arguments.

    Example: ``LambdaCallback(on_epoch_end=lambda **kw: print(kw["epoch"]))``.
    Hooks whose name was not supplied do nothing.
    """

    def __init__(self, **kwargs):
        self.hooks = kwargs

    def _call_hook(self, name, kwargs):
        # forward to the user-supplied function only when one was registered
        if name not in self.hooks:
            return
        self.hooks[name](**kwargs)

    def on_training_begin(self, **locals):
        self._call_hook("on_training_begin", locals)

    def on_training_end(self, **locals):
        self._call_hook("on_training_end", locals)

    def on_epoch_begin(self, **locals):
        self._call_hook("on_epoch_begin", locals)

    def on_epoch_end(self, **locals):
        self._call_hook("on_epoch_end", locals)

    def on_batch_begin(self, **locals):
        self._call_hook("on_batch_begin", locals)

    def on_batch_end(self, **locals):
        self._call_hook("on_batch_end", locals)


class Trainer(CallbackList):
    """Training/evaluation driver meant to be mixed into a model class.

    The methods rely on the host class providing ``parameters()``,
    ``state_dict()``, ``load_state_dict()``, ``train()``/``eval()`` and
    ``training`` (as on ``nn.Module``), a forward call ``self(context)`` and
    ``self.loss(output, target)`` -- NOTE(review): the concrete model class is
    defined outside this section.

    Every callback hook receives the calling method's local variables as
    keyword arguments, minus the trainer itself (passing it as ``self=`` would
    collide with the hooks' own ``self`` parameter). Callbacks therefore read
    values such as ``b``, ``split``, ``epoch``, ``data``, ``loss`` and
    ``total_loss`` by name.

    :param hyper: hyperparameter dict; keys read here include "path", "model",
        "shuffle", "batch_size", "optimizer", "lr", "start_epoch", "epochs".
    :param data: accepted for interface compatibility; unused here.
    """
    def __init__(self, hyper, data):
        super().__init__()
        self.hyper = hyper
        import pprint
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(hyper)
        os.makedirs(self._path(),exist_ok=True)
        with open(self.local("hyper.json"), 'w') as f:
            json.dump(self.hyper, f)
        # shared by reference with SummaryWriterCallback, so metrics added
        # later are still logged
        self.metrics = {}
        self.callbacks.append(SummaryWriterCallback(self._path(), metrics=self.metrics))
        if self.hyper["shuffle"]:
            self.callbacks.append(ShuffleCallback())
        self.callbacks.append(TQDMCallback())
        self.callbacks.append(USR2SaveCallback())

    def md5(self):
        """Return an MD5 hex digest identifying this run's hyperparameters.

        Keys that do not affect the learned weights ("epochs", "start_epoch",
        "mode", "path") are dropped (missing ones are tolerated), and "path" is
        pinned to "results/" so the digest does not depend on the output dir.
        """
        permanent_hyper = self.hyper.copy()
        for h in ("epochs", "start_epoch", "mode", "path"):
            permanent_hyper.pop(h, None)
        permanent_hyper["path"] = "results/"
        return hashlib.md5(json.dumps(permanent_hyper,sort_keys=True).encode('utf-8')).hexdigest()
    def _path(self):
        """Run directory: ``<hyper["path"]>/<hyper["model"]>-<md5>``."""
        return os.path.join(self.hyper["path"],
                            self.hyper["model"]+"-"+str(self.md5()))
    def local(self,subpath):
        """Join ``subpath`` onto the run directory."""
        return os.path.join(self._path(), subpath)

    def save(self):
        """Write ``state_dict()`` to ``<run dir>/model.model``."""
        print("saving",self.local("model.model"))
        torch.save(self.state_dict(), self.local("model.model"))
    def load(self):
        """Load ``<run dir>/model.model`` onto the CPU and switch to eval mode."""
        print("loading",self.local("model.model"))
        self.load_state_dict(torch.load(self.local("model.model"), map_location=lambda storage, loc: storage))
        self.eval()

    def loop(self,train,valid):
        """Train for epochs ``[start_epoch, epochs)``, validating after each one.

        :param train, valid: dataset objects whose ``.data`` array holds one
            example per row: the last column is the target index, the others
            are the context. Trailing rows that do not fill a batch are dropped.

        The train split runs in training mode with gradients enabled; the valid
        split runs in eval mode without building an autograd graph. The model's
        original mode is restored afterwards. A ``KeyboardInterrupt`` ends
        training early, and ``on_training_end`` hooks always run.
        """
        b = self.hyper["batch_size"]
        train = train.data
        valid = valid.data
        train = train[:(len(train)//b)*b]
        valid = valid[:(len(valid)//b)*b]
        train_b = train.reshape((len(train)//b),b,-1)
        valid_b = valid.reshape((len(valid)//b),b,-1)
        num_batches = train_b.shape[0]
        # Move the model to its device *before* constructing the optimizer, as
        # the torch.optim documentation recommends.
        self = opt_cuda(self)
        # NOTE(review): eval() executes the "optimizer" hyperparameter as Python
        # (e.g. "Adam", "RAdam"); only load hyper configs from trusted sources.
        optimizer = \
            eval(self.hyper["optimizer"])(self.parameters(), lr=self.hyper["lr"])
        was_training = self.training
        try:
            print("Started training...")
            self.on_training_begin(**{k:v for k,v in locals().items() if v is not self})
            for epoch in range(self.hyper["start_epoch"],self.hyper["epochs"]):
                for (split, data) in (("train",train_b), ("valid",valid_b)):
                    is_train = split == "train"
                    try:
                        # dropout and similar layers active only on the train split
                        self.train(is_train)
                        self.on_epoch_begin(**{k:v for k,v in locals().items() if v is not self})
                        total_loss = 0
                        for batch_n, batch in enumerate(data):
                            try:
                                self.on_batch_begin(**{k:v for k,v in locals().items() if v is not self})
                                batch = opt_cuda(torch.tensor(batch.astype(int)))
                                context = batch[:,:-1]
                                target  = batch[:,-1]
                                # record the autograd graph only when we will backpropagate
                                with torch.set_grad_enabled(is_train):
                                    output = self(context) # forward
                                    loss = self.loss(output, target)
                                if is_train:
                                    optimizer.zero_grad()
                                    loss.backward()
                                    optimizer.step()
                                total_loss += loss.item()
                            finally:
                                self.on_batch_end(**{k:v for k,v in locals().items() if v is not self})
                    finally:
                        self.on_epoch_end(**{k:v for k,v in locals().items() if v is not self})
        except KeyboardInterrupt:
            print("interrupted by keyboard")
        finally:
            self.train(was_training)
            self.on_training_end(**{k:v for k,v in locals().items() if v is not self})

    def evaluate(self,data,fn):
        """Evaluate ``fn(output, target)`` over ``data`` in eval mode without gradients.

        :param data: dataset object whose ``.data`` array is laid out as for
            ``loop``; trailing rows that do not fill a batch are dropped.
        :param fn: per-batch criterion; its outputs are summed over batches and
            divided by the number of evaluated rows.
        :return: that average (numpy scalar/array), which is also printed.
        """
        b = self.hyper["batch_size"]
        data = data.data
        data = data[:(len(data)//b)*b]
        data_b = data.reshape((len(data)//b),b,-1)
        split = "valid"

        self = opt_cuda(self)
        # evaluate with dropout etc. disabled; the previous mode is restored below
        was_training = self.training
        self.eval()

        total_loss = []
        bar = tqdm(data_b,mininterval=1.0)
        try:
            with torch.no_grad():
                for batch in bar:
                    self.on_batch_begin(**{k:v for k,v in locals().items() if v is not self})
                    batch = torch.tensor(batch.astype(int))
                    batch = opt_cuda(batch)
                    context = batch[:,:-1]
                    target  = batch[:,-1]
                    output = self(context) # forward
                    loss = fn(output, target)
                    total_loss.append(loss.cpu().detach().numpy())
        finally:
            bar.close()
            self.train(was_training)
        total_loss = np.asarray(total_loss).sum(0) / len(data)
        print('result',total_loss)
        return total_loss




class TopKMixin:
    """Helpers for ranking vocabulary items by a [B, V] similarity matrix."""

    def topk_accuracy(self, max_k, similarity, target, reduction="mean"):
        """Top-k hit rates for every k in 1..max_k.

        :param max_k: number of ranks to consider.
        :param similarity: [B, V] scores, larger is better. Left unmodified.
        :param target: [B] index of the correct item per row; must be on the
            same device as ``similarity``.
        :param reduction: "mean" -> float tensor [max_k] of hit rates;
            "sum" -> integer tensor [max_k] of hit counts; anything else ->
            bool tensor [B, max_k] where entry (i, k) tells whether row i's
            target is among its top-(k+1) items.
        """
        assert similarity.device == target.device

        topk = self.topk_indices(max_k, similarity)   # [B, max_k]
        target = target.detach().view(-1, 1)          # [B, 1]

        match = (topk.long() == target)               # [B, max_k]
        # cumulative "any": a hit at rank j counts for every k >= j
        # match_k[:,0] : whether top-1 results match
        # match_k[:,4] : whether top-5 results match
        match_k = match.cumsum(dim=1) > 0             # [B, max_k], bool

        if reduction == "mean":
            # bool tensors cannot be averaged directly
            return match_k.float().mean(0)
        elif reduction == "sum":
            return match_k.sum(0)
        else:
            return match_k

    def topk_indices(self, max_k, similarity):
        """Column indices of the max_k best scores of each row, best first.

        Ties resolve to the lower index (``torch.argmax`` returns the first
        maximum). Returns an int32 tensor [B, max_k] on ``similarity``'s
        device; ``similarity`` itself is not modified.
        """
        B, V = similarity.shape   # also enforces a 2-D input
        device = similarity.device

        # Work on a private copy: detach() alone shares storage, so masking it
        # in place would overwrite the caller's tensor.
        scores = similarity.detach().clone()

        topk = torch.zeros([B, max_k], dtype=torch.int, device=device)
        rows = torch.arange(B, device=device)
        for i in range(max_k):
            best = torch.argmax(scores, dim=1)   # [B]
            topk[:, i] = best
            # mask the winner so the next pass finds the runner-up
            scores[rows, best] = -float('inf')

        return topk


class ContinuousEmbeddingMixin(TopKMixin):
    def accuracy(self, max_k, pred_emb, target, reduction="mean"):
        """Top-k accuracy of predicted embeddings against the whole vocabulary.

        Each row of ``pred_emb`` ([B, E]) is compared by cosine similarity with
        every row of the normalized embedding matrix ([V, E]); the resulting
        [B, V] scores go to ``topk_accuracy``, which reports every k <= max_k.
        """
        unit_pred = F.normalize(pred_emb)                  # [B, E]
        cosine = unit_pred @ self.normalized_matrix.T      # [B, V]
        return self.topk_accuracy(max_k, cosine, target, reduction=reduction)

    def predict(self, max_k, pred_emb):
        """Indices of the max_k vocabulary entries closest (cosine) to each predicted embedding."""
        unit_pred = F.normalize(pred_emb)                  # [B, E]
        cosine = unit_pred @ self.normalized_matrix.T      # [B, V]
        return self.topk_indices(max_k, cosine)

    def embed_sentence(self, indices, reduction="mean"):
        """Embed a sentence given as a sequence of word indices.

        Looks each index up in ``self.embeddings2`` and pools along dim 0.

        :param indices: index tensor whose dim 0 runs over the words.
        :param reduction: "mean" or "sum".
        :raises ValueError: for any other reduction.
        """
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unknown reduction {reduction!r}; expected 'mean' or 'sum'")
        vectors = self.embeddings2(indices)
        if reduction == "mean":
            return vectors.mean(dim=0)
        return vectors.sum(dim=0)

    @property
    def normalized_matrix(self):
        """Row-wise L2-normalized ``self.embeddings2.weight`` ([vocab, embedding]).

        The result is cached and recomputed whenever the weight's storage,
        device or in-place version counter changes (in-place parameter updates,
        state-dict loading, moving the model to another device). Previously the
        first result was cached forever and went stale.
        """
        weight = self.embeddings2.weight # [vocab, embedding]
        key = (weight.data_ptr(), weight.device, weight._version)
        if getattr(self, "_normalized_matrix_key", None) != key:
            self._normalized_matrix = F.normalize(weight)   # [vocab, embedding]
            self._normalized_matrix_key = key
        return self._normalized_matrix

    @property
    def matrix(self):
        """The raw (unnormalized) embedding weight, ``self.embeddings2.weight``."""
        weight = self.embeddings2.weight
        return weight


    def analogy_common(self, max_k, a1, a2, b1, order=0):
        """Unnormalized 3CosAdd query b1 - a1 + a2, built from normalized rows.

        ``max_k`` and ``order`` are accepted for signature compatibility and unused.
        """
        unit = self.normalized_matrix
        return unit[b1] - unit[a1] + unit[a2]

    def analogy_only_b_common(self, max_k, a1, a2, b1, order=0):
        """ONLY-B query: just the normalized row(s) of b1 (other arguments unused)."""
        return self.normalized_matrix[b1]

    def analogy_ignore_a_common(self, max_k, a1, a2, b1, order=0):
        """IGNORE-A query b1 + a2 from normalized rows (a1, max_k, order unused)."""
        unit = self.normalized_matrix
        return unit[b1] + unit[a2]

    def analogy_add_opposite_common(self, max_k, a1, a2, b1, order=0):
        """ADD-OPPOSITE query b1 + a1 - a2 from normalized rows (max_k, order unused)."""
        unit = self.normalized_matrix
        return unit[b1] + unit[a1] - unit[a2]

    # Assume a1:a2 = b1:b2, where a1,a2,b1,b2 are the word index/indices.
    # b2 is the target word index/indices.
    def accuracy_3cosadd(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """3CosAdd analogy accuracy (Levy & Goldberg 2014, Eq. 1).

        Ranks the vocabulary by cosine similarity to b1 - a1 + a2 and scores
        the rank of b2 with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a1, a2 and b1 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix # normalized
        b2_emb = self.analogy_common(max_k, a1, a2, b1)
        # Normalize each query vector along the embedding dim. dim=0 would
        # normalize across the batch, making every result depend on the others.
        b2_emb = F.normalize(b2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_pairdirection(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """PairDirection analogy accuracy (Levy & Goldberg 2014, Eq. 2).

        Ranks each candidate v by cos(v - b1, a2 - a1) and scores the rank of
        b2 with ``topk_accuracy``. The [B, V, E] difference tensor is built on
        the CPU, without autograd, to bound memory use.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a1, a2 and b1 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix
        device = matrix.device
        # no autograd: topk_accuracy detaches anyway, and a recorded graph would
        # keep the large B,V,E tensor alive despite the `del`s below
        with torch.no_grad():
            a1_emb = matrix[a1]
            a2_emb = matrix[a2]
            b1_emb = matrix[b1]
            V, E = matrix.shape
            B, E = b1_emb.shape
            # normalize each offset vector along the embedding dim (dim=0 would
            # normalize across the batch and mix unrelated analogies)
            da_emb = F.normalize(a2_emb - a1_emb, dim=-1).view(B,E,1) # B,E,1
            del a1_emb, a2_emb
            # the next operation requires a larger memory, thus move them to cpu
            matrix = matrix.cpu()
            b1_emb = b1_emb.cpu()
            db_emb = F.normalize(matrix - b1_emb.view(B,1,E), dim=2) # B,V,E
            del b1_emb, matrix
            da_emb = da_emb.cpu()
            similarity = torch.bmm(db_emb, da_emb) # B,V,E x B,E,1 = B,V,1
            del db_emb, da_emb
            similarity = similarity.view(B,V).to(device)
            if exclude_original:
                idx = torch.arange(len(a1)).to(device)
                similarity[idx,a1] = -float("inf")
                similarity[idx,a2] = -float("inf")
                similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_3cosmul(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """3CosMul analogy accuracy (Levy & Goldberg 2014, Eq. 4).

        Ranks each candidate v by cos'(v,b1) * cos'(v,a2) / (cos'(v,a1) + 0.001),
        with cos' = (cos + 1) / 2, and scores the rank of b2 with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a1, a2 and b1 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix
        a1_emb = matrix[a1]
        a2_emb = matrix[a2]
        b1_emb = matrix[b1]

        # Cosines lie in [-1, 1]; the paper shifts them to [0, 1] via (x+1)/2 so
        # negative values cannot flip the sign of the product or drive the
        # denominator to (or below) zero.
        cos_b2a1 = (torch.matmul(a1_emb, matrix.T) + 1) / 2
        cos_b2a2 = (torch.matmul(a2_emb, matrix.T) + 1) / 2
        cos_b2b1 = (torch.matmul(b1_emb, matrix.T) + 1) / 2
        # [B, V]
        similarity = cos_b2b1 * cos_b2a2 / (cos_b2a1 + 0.001) # epsilon is from the paper
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """ONLY-B analogy accuracy (Linzen 2016, Eq. 5).

        Ranks the vocabulary by cosine similarity to b1 alone and scores the
        rank of b2 with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, b1 is masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        b2_emb = self.analogy_only_b_common(max_k, a1, a2, b1)
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_ignore_a(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """IGNORE-A analogy accuracy (Linzen 2016, Eq. 6).

        Ranks the vocabulary by cosine similarity to b1 + a2 and scores the
        rank of b2 with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a2 and b1 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        b2_emb = self.analogy_ignore_a_common(max_k, a1, a2, b1)
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)
    
    def accuracy_add_opposite(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """ADD-OPPOSITE analogy accuracy (Linzen 2016, Eq. 7).

        Ranks the vocabulary by cosine similarity to b1 + a1 - a2 and scores
        the rank of b2 with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a1, a2 and b1 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        b2_emb = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=0)
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_reverse(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """REVERSE(ADD) analogy accuracy (Linzen 2016, "Analogy Functions").

        Predicts b1 from b2 - a2 + a1 and scores its rank with ``topk_accuracy``.

        :param a1, a2, b1, b2: index tensors of shape [B].
        :param exclude_original: if True, a1, a2 and b2 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        a1_emb = matrix[a1]             # normalized
        a2_emb = matrix[a2]             # normalized
        b2_emb = matrix[b2]             # normalized
        # per-vector normalization (dim=0 would normalize across the batch)
        b1_emb = F.normalize(b2_emb - a2_emb + a1_emb, dim=-1)
        similarity = torch.matmul(b1_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b2] = -float("inf")
        return self.topk_accuracy(max_k, similarity, b1, reduction=reduction)

    def accuracy_reverse_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """REVERSE(ONLY-B) analogy accuracy (Linzen 2016, "Analogy Functions").

        Predicts b1 as the nearest neighbour of b2 alone and scores its rank
        with ``topk_accuracy``; a1 and a2 only matter for masking. With
        ``exclude_original``, a1, a2 and b2 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        unit = self.normalized_matrix
        scores = torch.matmul(unit[b2], unit.T)   # [B, V] cosine to b2
        if exclude_original:
            rows = torch.arange(len(a1)).to(unit.device)
            for masked in (a1, a2, b2):
                scores[rows, masked] = -float("inf")
        return self.topk_accuracy(max_k, scores, b1, reduction=reduction)

    # return the topk indices, not accuracy
    def predict_3cosadd(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k 3CosAdd predictions for b2 (Levy & Goldberg 2014, Eq. 1).

        :param a1, a2, b1, b2: index tensors of shape [B] (b2 is unused).
        :param exclude_original: if True, a1, a2 and b1 are masked out.
        :return: [B, max_k] indices from ``topk_indices``; the ranking is done on the CPU.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix # normalized
        a1_emb = matrix[a1]             # normalized
        a2_emb = matrix[a2]             # normalized
        b1_emb = matrix[b1]             # normalized
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b1_emb - a1_emb + a2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T).cpu() # [B, V] cosine, on the CPU
        if exclude_original:
            # index on similarity's device: CUDA index tensors cannot index a CPU tensor
            dev = similarity.device
            idx = torch.arange(len(a1), device=dev)
            similarity[idx, torch.as_tensor(a1, device=dev)] = -float("inf")
            similarity[idx, torch.as_tensor(a2, device=dev)] = -float("inf")
            similarity[idx, torch.as_tensor(b1, device=dev)] = -float("inf")
        return self.topk_indices(max_k, similarity)

    def predict_pairdirection(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k PairDirection predictions for b2 (Levy & Goldberg 2014, Eq. 2).

        Ranks each candidate v by cos(v - b1, a2 - a1). The [B, V, E]
        difference tensor is built on the CPU, without autograd, to bound memory.

        :param a1, a2, b1, b2: index tensors of shape [B] (b2 is unused).
        :param exclude_original: if True, a1, a2 and b1 are masked out.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix
        device = matrix.device
        # no autograd: only indices are returned, and a recorded graph would
        # keep the large B,V,E tensor alive
        with torch.no_grad():
            a1_emb = matrix[a1]
            a2_emb = matrix[a2]
            b1_emb = matrix[b1]
            V, E = matrix.shape
            B, E = b1_emb.shape
            # normalize each offset vector along the embedding dim (dim=0 would
            # normalize across the batch and mix unrelated analogies)
            da_emb = F.normalize(a2_emb - a1_emb, dim=-1).view(B,E,1) # B,E,1
            # the next operation requires a larger memory, thus move them to cpu
            da_emb = da_emb.cpu()
            matrix = matrix.cpu()
            b1_emb = b1_emb.cpu()
            db_emb = F.normalize(matrix - b1_emb.view(B,1,E), dim=2) # B,V,E
            similarity = torch.bmm(db_emb, da_emb) # B,V,E x B,E,1 = B,V,1
            del db_emb, da_emb
            similarity = similarity.view(B,V).to(device)
            if exclude_original:
                idx = torch.arange(len(a1)).to(device)
                similarity[idx,a1] = -float("inf")
                similarity[idx,a2] = -float("inf")
                similarity[idx,b1] = -float("inf")
        return self.topk_indices(max_k, similarity)

    def predict_3cosmul(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k 3CosMul predictions for b2 (Levy & Goldberg 2014, Eq. 4).

        Ranks each candidate v by cos'(v,b1) * cos'(v,a2) / (cos'(v,a1) + 0.001),
        with cos' = (cos + 1) / 2.

        :param a1, a2, b1, b2: index tensors of shape [B] (b2 is unused).
        :param exclude_original: if True, a1, a2 and b1 are masked out.
        """
        # https://www.aclweb.org/anthology/W14-1618.pdf
        matrix = self.normalized_matrix
        a1_emb = matrix[a1]
        a2_emb = matrix[a2]
        b1_emb = matrix[b1]

        # Cosines lie in [-1, 1]; the paper shifts them to [0, 1] via (x+1)/2 so
        # negative values cannot flip the sign of the product or drive the
        # denominator to (or below) zero.
        cos_b2a1 = (torch.matmul(a1_emb, matrix.T) + 1) / 2
        cos_b2a2 = (torch.matmul(a2_emb, matrix.T) + 1) / 2
        cos_b2b1 = (torch.matmul(b1_emb, matrix.T) + 1) / 2
        # [B, V]
        similarity = cos_b2b1 * cos_b2a2 / (cos_b2a1 + 0.001) # epsilon is from the paper
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_indices(max_k, similarity)

    def predict_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k ONLY-B predictions for b2 (Linzen 2016, Eq. 5).

        Returns the nearest neighbours of b1 alone; with ``exclude_original``
        b1 itself is masked out.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        unit = self.normalized_matrix
        scores = torch.matmul(unit[b1], unit.T)   # [B, V] cosine to b1
        if exclude_original:
            rows = torch.arange(len(a1)).to(unit.device)
            scores[rows, b1] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_ignore_a(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k IGNORE-A predictions for b2 (Linzen 2016, Eq. 6): nearest to b1 + a2.

        :param a1, a2, b1, b2: index tensors of shape [B] (b2 is unused).
        :param exclude_original: if True, a2 and b1 are masked out.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        a2_emb = matrix[a2]             # normalized
        b1_emb = matrix[b1]             # normalized
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b1_emb + a2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_indices(max_k, similarity)
    
    def predict_add_opposite(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k ADD-OPPOSITE predictions for b2 (Linzen 2016, Eq. 7): nearest to b1 + a1 - a2.

        :param a1, a2, b1, b2: index tensors of shape [B] (b2 is unused).
        :param exclude_original: if True, a1, a2 and b1 are masked out.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        a1_emb = matrix[a1]             # normalized
        a2_emb = matrix[a2]             # normalized
        b1_emb = matrix[b1]             # normalized
        # per-vector normalization (dim=0 would normalize across the batch)
        b2_emb = F.normalize(b1_emb + a1_emb - a2_emb, dim=-1)
        similarity = torch.matmul(b2_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b1] = -float("inf")
        return self.topk_indices(max_k, similarity)

    def predict_reverse(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k REVERSE(ADD) predictions for b1 (Linzen 2016): nearest to b2 - a2 + a1.

        :param a1, a2, b1, b2: index tensors of shape [B] (b1 is unused).
        :param exclude_original: if True, a1, a2 and b2 are masked out.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        matrix = self.normalized_matrix # normalized
        a1_emb = matrix[a1]             # normalized
        a2_emb = matrix[a2]             # normalized
        b2_emb = matrix[b2]             # normalized
        # per-vector normalization (dim=0 would normalize across the batch)
        b1_emb = F.normalize(b2_emb - a2_emb + a1_emb, dim=-1)
        similarity = torch.matmul(b1_emb, matrix.T) # [B, V] cosine
        if exclude_original:
            idx = torch.arange(len(a1)).to(matrix.device)
            similarity[idx,a1] = -float("inf")
            similarity[idx,a2] = -float("inf")
            similarity[idx,b2] = -float("inf")
        return self.topk_indices(max_k, similarity)

    def predict_reverse_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-k REVERSE(ONLY-B) predictions for b1 (Linzen 2016, "Analogy Functions").

        Returns the nearest neighbours of b2 alone; with ``exclude_original``,
        a1, a2 and b2 are masked out of the ranking.
        """
        # https://www.aclweb.org/anthology/W16-2503/
        unit = self.normalized_matrix
        scores = torch.matmul(unit[b2], unit.T)   # [B, V] cosine to b2
        if exclude_original:
            rows = torch.arange(len(a1)).to(unit.device)
            for masked in (a1, a2, b2):
                scores[rows, masked] = -float("inf")
        return self.topk_indices(max_k, scores)

    def print_knn(self, words, max_k, data):
        """Print the max_k nearest vocabulary words of each word in ``words``.

        :param words: words to look up; converted with ``data.word_to_idx``.
        :param max_k: number of neighbours printed per word.
        :param data: dataset object providing ``word_to_idx`` / ``idx_to_word``.
        :return: the ``stat`` dict that each result is passed to via
            ``register`` (helper defined outside this section) together with
            the embedding index, method name, word and neighbour list.
        """
        from colors import yellow

        idxs = opt_cuda(torch.tensor(list(map(data.word_to_idx, words)),dtype=int))

        self = opt_cuda(self)

        stat = {}
        for i,embedding in enumerate(self.embeddings):
            if len(self.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            # NOTE(review): the loop variable `embedding` is overwritten here and
            # target_embedding(idxs) does not depend on `i`, so every iteration
            # appears to print the same neighbours -- confirm whether a
            # per-embedding lookup was intended.
            embedding = self.target_embedding(idxs)
            for method in [self.predict]:

                # [len(words), max_k] neighbour indices
                topk = method(max_k, embedding)

                for word, topk_per_word in zip(words, topk):
                    print("======================================================")
                    print(f"{max_k}-NN of", yellow(word.ljust(15)), "in", method.__name__, ":",end="")
                    for j in topk_per_word:
                        print(data.idx_to_word(j).ljust(15),end="")
                    print("")
                    register(stat, i, method.__name__, word, list(map(data.idx_to_word,topk_per_word.tolist())))
        return stat

    def print_analogy(self, a1_words, a2_words, b1_words, b2_words, max_k, data):
        """Print top-k analogy predictions of every method and collect them.

        For each quadruple a1:a2 :: b1:b2, the forward methods predict b2
        (shown as "a2-a1+b1(=b2)") and the reverse methods predict b1 (shown as
        "a1-a2+b2(=b1)"). Every query is shown twice: with the query words
        allowed in the ranking and with them excluded.

        :param a1_words, a2_words, b1_words, b2_words: equal-length word lists.
        :param max_k: number of predictions shown per query.
        :param data: dataset object providing ``word_to_idx`` / ``idx_to_word``.
        :return: the ``stat`` dict that results are passed to via ``register``
            (helper defined outside this section), each under the same analogy
            string that is printed, suffixed "(include)" or "(exclude)".
        """
        from colors import yellow, magenta, red

        a1 = opt_cuda(torch.tensor(list(map(data.word_to_idx, a1_words)),dtype=int))
        a2 = opt_cuda(torch.tensor(list(map(data.word_to_idx, a2_words)),dtype=int))
        b1 = opt_cuda(torch.tensor(list(map(data.word_to_idx, b1_words)),dtype=int))
        b2 = opt_cuda(torch.tensor(list(map(data.word_to_idx, b2_words)),dtype=int))

        self = opt_cuda(self)

        stat = {}

        def report(i, method, pos, neg, base, answer, topk_include, topk_exclude):
            # Print one query "pos-neg+base(=answer)" and register both result
            # lists under the matching key. Sharing this between the forward and
            # reverse loops keeps the printed label and the stat key in sync.
            print("======================================================")
            print(f"{max_k}-NN of", f"{yellow(pos)}-{yellow(neg)}+{yellow(base)}(={yellow(answer)})", "in", method.__name__, ":")
            for j in topk_include:
                print(magenta(data.idx_to_word(j).ljust(15)),end="")
            print("<-Include original words")
            for j in topk_exclude:
                print(red(data.idx_to_word(j).ljust(15)),end="")
            print("<-Exclude original words")
            print("")
            key = f"{pos}-{neg}+{base}={answer}"
            register(stat, i, method.__name__, f"{key}(include)", list(map(data.idx_to_word,topk_include.tolist())))
            register(stat, i, method.__name__, f"{key}(exclude)", list(map(data.idx_to_word,topk_exclude.tolist())))

        for i,embedding in enumerate(self.embeddings):
            if len(self.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            # forward methods: predict b2
            for method in [self.predict_3cosadd,
                           self.predict_pairdirection,
                           self.predict_3cosmul,
                           self.predict_only_b,
                           self.predict_ignore_a,
                           self.predict_add_opposite]:

                b2_topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                b2_topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, topk_include, topk_exclude in zip(a1_words, a2_words, b1_words, b2_words, b2_topk_include, b2_topk_exclude):
                    report(i, method, a2_word, a1_word, b1_word, b2_word, topk_include, topk_exclude)

            # reverse methods: predict b1. The "(exclude)" entry used to be
            # registered under the forward key a2-a1+b1=b2; it now matches the
            # "(include)" entry's a1-a2+b2=b1.
            for method in [self.predict_reverse,
                           self.predict_reverse_only_b]:

                b1_topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                b1_topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, topk_include, topk_exclude in zip(a1_words, a2_words, b1_words, b2_words, b1_topk_include, b1_topk_exclude):
                    report(i, method, a1_word, a2_word, b2_word, b1_word, topk_include, topk_exclude)
        return stat

    def evaluate_analogy_category(self, category_data, data):
        """
        Evaluate every analogy method on one category of analogy questions.

        :param self: pytorch self with embedding matrix
        :param category_data: list of four-word analogy rows (a1, a2, b1, b2)
        :param data: dataset object with word to index information
        :return stat: top-k hit counts and accuracies of the analogy data, keyed by
            embedding, exclusion mode, method name and k; ``stat["total"]`` is the
            number of rows before <unk> filtering
        """
        stat = {}
        stat["total"] = len(category_data)
        total = len(category_data)
        category_data = np.array([ list(map(data.word_to_idx,row))
                                  for row in category_data ])
        category_data = torch.tensor(category_data)
        # discard examples where any word is <unk> (index 0); they are counted as
        # wrong because the accuracies below divide by the unfiltered `total`
        category_data = category_data[(category_data!=0).all(1),:]
        after_total = category_data.shape[0]

        print(f"{total-after_total}/{total} ({100-after_total/total*100:.1f}%) removed (with <unk> in any of 4 words)")
        k_list = [1, 5, 10]

        b = 500 # batch size
        # torch.split yields every full batch of `b` rows followed by the remainder,
        # so each surviving row is evaluated exactly once -- including the case
        # after_total == b, which a "> b" reshape-based split would drop entirely.
        batches = [chunk for chunk in torch.split(category_data, b) if chunk.shape[0] > 0]
        del category_data

        # NOTE(review): `embedding` is not passed to `method`, so every iteration
        # evaluates the same methods; confirm this is intended for models with
        # several embeddings.
        for i,embedding in enumerate(self.embeddings):
            # limiting the number of tracks to reduce memory usage
            for exclusion in ["exclude_original"]: # "include_original",
                for method in [self.accuracy_3cosadd,
                               # self.accuracy_pairdirection,
                               self.accuracy_3cosmul,
                               self.accuracy_only_b,
                               self.accuracy_ignore_a,
                               self.accuracy_add_opposite,
                               self.accuracy_reverse,
                               self.accuracy_reverse_only_b
                               ]:
                    # counts[k-1] accumulates the number of hits within the top-k
                    counts = torch.zeros(10)
                    for batch in batches:
                        counts += method(10,batch[:,0],batch[:,1],batch[:,2],batch[:,3],reduction="sum",exclude_original=exclusion=="exclude_original")

                    for k in k_list:
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "top"+str(k),
                                 (counts[k-1]).detach().item())
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "accuracy_top"+str(k),
                                 (counts[k-1]/float(total)).detach().item())
        return stat

    def word_similarity(self, w1, w2, normalize=False):
        """Row-wise dot product between the embeddings of ``w1`` and ``w2``.

        Uses ``self.normalized_matrix`` when ``normalize`` is true (cosine
        similarity if its rows are unit length), otherwise ``self.matrix``.
        Returns a detached numpy array with one score per index pair.
        """
        table = self.normalized_matrix if normalize else self.matrix
        dots = (table[w1] * table[w2]).sum(dim=1)
        return dots.detach().numpy()


class EffectEmbeddingMixin(TopKMixin):
    @property
    def word_effects(self):
        if hasattr(self, "_word_effects"):
            return self._word_effects

        V, E = self.embeddings[0].weight.shape
        word_effects = torch.zeros((2, V, E), requires_grad=False, device=self.embeddings[0].weight.device)

        if False:
            B = self.hyper["batch_size"]
            ones  = torch.ones ([B,E], device=self.embeddings[0].weight.device)
            zeros = torch.zeros([B,E], device=self.embeddings[0].weight.device)
            for i in tqdm(range(0,V-B,B)):
                # if the bits are turned 0, this is a delete effect
                word_effects[0,i:i+B,:] = 1-self.discrete_add(ones , self.embeddings[0].weight[i:i+B], 1e-10, noise=False).round()
                # if the bits are turned 1, this is an add effect
                word_effects[1,i:i+B,:] = self.discrete_add(zeros, self.embeddings[0].weight[i:i+B], 1e-10, noise=False).round()
            # if the bits are turned 0, this is a delete effect
            word_effects[0,i+B:V,:] = 1-self.discrete_add(ones[:V-(i+B)] , self.embeddings[0].weight[i+B:V], 1e-10, noise=False).round()
            # if the bits are turned 1, this is an add effect
            word_effects[1,i+B:V,:] = self.discrete_add(zeros[:V-(i+B)], self.embeddings[0].weight[i+B:V], 1e-10, noise=False).round()
        else:
            ones  = torch.ones ([V,E], device=self.embeddings[0].weight.device)
            zeros = torch.zeros([V,E], device=self.embeddings[0].weight.device)
            # if the bits are turned 0, this is a delete effect
            word_effects[0] = 1-self.discrete_add(ones, self.embeddings[0].weight, 1e-10, noise=False).round()
            # if the bits are turned 1, this is an add effect
            word_effects[1] = self.discrete_add(zeros, self.embeddings[0].weight, 1e-10, noise=False).round()
        self._word_effects = word_effects
        return word_effects

    @property
    def posneg_effects(self):
        if hasattr(self, "_posneg_effects"):
            return self._posneg_effects

        dels = self.word_effects[0]
        adds = self.word_effects[1]
        if False:
            effects = torch.zeros_like(dels)
            B = self.hyper["batch_size"]
            V, E = effects.shape
            for i in tqdm(range(0,V-B,B)):
                effects[i:i+B,:] = adds[i:i+B] - dels[i:i+B]
            effects[i+B:V,:] = adds[i+B:V] - dels[i+B:V]
        else:
            effects = adds - dels
        self._posneg_effects = effects
        return effects

    @property
    def normalized_posneg_effects(self):
        if hasattr(self, "_normalized_posneg_effects"):
            return self._normalized_posneg_effects
        effects = F.normalize(self.posneg_effects)
        self._normalized_posneg_effects = effects
        return effects

    def effect_add(self, del1, add1, del2, add2):
        # computing the results of \ del1 ∪ add1 \ del2 ∪ add2
        del3 = del1
        del3 = torch.max(del3, del2)
        del3 = torch.min(del3, 1-add2)
        add3 = add1
        add3 = torch.min(add3, 1-del2)
        add3 = torch.max(add3, add2)
        return del3, add3

    def effect_sub(self, del1, add1, del2, add2):
        # computing the results of \ del1 ∪ add1 \ add2 ∪ del2
        del3 = del1
        del3 = torch.max(del3, add2)
        del3 = torch.min(del3, 1-del2)
        add3 = add1
        add3 = torch.min(add3, 1-add2)
        add3 = torch.max(add3, del2)
        return del3, add3

    def embed_sentence(self, indices, reduction="mean"):
        """
        sentences represented as a list of indices 
        """
        E = self.embeddings[0].weight.shape[1]
        V = self.word_effects.shape[0]
        device = self.embeddings[0].weight.device

        del_effect = torch.zeros(E).to(device)
        add_effect = torch.zeros(E).to(device)
        for w_idx in indices:
            del_effect, add_effect = \
                self.effect_add(del_effect,
                                add_effect,
                                self.word_effects[0][w_idx],
                                self.word_effects[1][w_idx])
        return add_effect - del_effect

    def accuracy(self, max_k, pred_emb, target, reduction="mean"):
        """Top-``max_k`` accuracy of matching observed effects to vocabulary effects.

        The observed effect is ``pred_emb - self.initial_state`` (values 0, -1, 1).
        It is incomplete and depends on the random initial state: add bits are
        invisible where the initial state is already 1, and delete bits are
        invisible where it is already 0. It is cosine-matched against every
        word's complete signed effect (``normalized_posneg_effects``) and scored
        with ``topk_accuracy`` against ``target``.

        Why the incomplete effect: during ``evaluate`` a random initial state is
        generated and the forward pass computes ``pred_emb``. Computing the true
        add/delete effects would require a significant change to ``evaluate``
        (e.g. one forward pass from ones and another from zeros), which would be
        incompatible with the continuous code. Computing the [B, V, E] successor
        state of every vocabulary word for each random state is infeasible in
        memory. So the topk nearest complete effects to the incomplete one are
        ranked instead.

        This method cannot be used in BTLx, which should use the code for the
        continuous embedding.
        """
        observed = pred_emb - self.initial_state  # [B, E]
        scores = torch.matmul(F.normalize(observed), self.normalized_posneg_effects.T)  # [B, V]
        return self.topk_accuracy(max_k, scores, target, reduction=reduction)

    def accuracy_l1(self, max_k, pred_emb, target, reduction="mean"):
        """Variant of ``accuracy`` that scores with an unnormalized dot product.

        Cosine similarity (``accuracy``) is skewed by missing effects, because a
        missing bit changes the vector's direction. Here neither side is
        normalized, so the score counts matching effect bits, penalizes opposite
        ones, and ignores bits missing from the incomplete effect (they are 0).

        Example: comparing the complete effect [1, 1] with the incomplete effect
        [0, 1]. Cosine first maps [1, 1] to about [0.71, 0.71]. This method
        simply takes dot([1, 1], [0, 1]). The 0 in the first position is
        ambiguous: the add effect may be present but did nothing because the
        initial state was [1, 0], or it may be missing because the initial state
        was [0, 0]. Either way, that bit does not contribute.
        """
        observed = pred_emb - self.initial_state  # [B, E]
        scores = torch.matmul(observed, self.posneg_effects.T)  # [B, V]
        return self.topk_accuracy(max_k, scores, target, reduction=reduction)
        
    def predict(self, max_k, pred_emb):
        """Return the top-``max_k`` vocabulary indices whose effects best match ``pred_emb``.

        The observed (incomplete) effect ``pred_emb - self.initial_state`` is
        cosine-matched against every word's normalized signed effect.

        If ``self.initial_state`` has not been set, a zero initial state is
        installed. It is a single [1, E] row so that it broadcasts against any
        batch size. A cached [B, E] tensor would break later calls whose batch
        size differs from the first one.
        """
        _, E = pred_emb.shape
        if not hasattr(self, 'initial_state'):
            self.initial_state = torch.zeros(1, E).to(pred_emb.device)
        incomplete_effects = pred_emb - self.initial_state # [B,E]
        incomplete_effects = F.normalize(incomplete_effects)
        complete_effects   = self.normalized_posneg_effects
        similarity = torch.matmul(incomplete_effects, complete_effects.T) # [B, V]
        return self.topk_indices(max_k, similarity)

    def predict_l1(self, max_k, pred_emb):
        """Return top-``max_k`` vocabulary indices by unnormalized dot product.

        Like ``predict`` but without normalization. The score counts matching
        effect bits, penalizes opposite ones, and ignores missing ones.

        If ``self.initial_state`` has not been set, a zero [1, E] initial state
        (broadcastable to any batch size) is installed, as ``predict`` does. This
        avoids failing with AttributeError.
        """
        _, E = pred_emb.shape
        if not hasattr(self, 'initial_state'):
            self.initial_state = torch.zeros(1, E).to(pred_emb.device)
        incomplete_effects = pred_emb - self.initial_state # [B,E]
        complete_effects   = self.posneg_effects
        similarity = torch.matmul(incomplete_effects, complete_effects.T) # [B, V]
        return self.topk_indices(max_k, similarity)

    def analogy_common(self, max_k, a1, a2, b1, order=0):

        # Note: add and del are prime to each other (cf. IJCAI20 paper)

        dels = self.word_effects[0]
        adds = self.word_effects[1]
        if order == 0:
            # b1 - a1 + a2
            b2_dels = dels[b1]
            b2_adds = adds[b1]
            # undo a1: / add(a1) + del(a1)
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a1], adds[a1])
            # do a2: / del(a2) + add(a2)
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a2], adds[a2])
        elif order == 1:
            # b1 + a2 - a1
            b2_dels = dels[b1]
            b2_adds = adds[b1]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a2], adds[a2])
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a1], adds[a1])
        elif order == 2:
            # a2 - a1 + b1
            b2_dels = dels[a2]
            b2_adds = adds[a2]
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a1], adds[a1])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
        elif order == 3:
            # a2 + b1 - a1
            b2_dels = dels[a2]
            b2_adds = adds[a2]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a1], adds[a1])
        elif order == 4:
            # -a1 + b1 + a2
            b2_dels = adds[a1]
            b2_adds = dels[a1]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a2], adds[a2])
        elif order == 5:
            # -a1 + a2 + b1
            b2_dels = adds[a1]
            b2_adds = dels[a1]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a2], adds[a2])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])

        # 0,-1,1
        b2_effect = b2_adds - b2_dels
        return b2_effect

    def analogy_only_b_common(self, max_k, a1, a2, b1):
        # parallel methods to Linzen
        # https://www.aclweb.org/anthology/W16-2503/
        # Equation (5)
        dels = self.word_effects[0]
        adds = self.word_effects[1]

        # b1 - a1 + a2
        b2_dels = dels[b1]
        b2_adds = adds[b1]

        # 0,-1,1
        b2_effect = b2_adds - b2_dels
        return b2_effect

    def analogy_ignore_a_common(self, max_k, a1, a2, b1, order=0):

        # Note: add and del are prime to each other (cf. IJCAI20 paper)
        dels = self.word_effects[0]
        adds = self.word_effects[1]

        if order == 0:
            # b1 + a2
            b2_dels = dels[b1]
            b2_adds = adds[b1]

            # do a2: / del(a2) + add(a2)
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a2], adds[a2])
        elif order == 1:
            # a2 + b1
            b2_dels = dels[a2]
            b2_adds = adds[a2]

            # do a2: / del(a2) + add(a2)
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])

        # 0,-1,1
        b2_effect = b2_adds - b2_dels
        return b2_effect

    def analogy_add_opposite_common(self, max_k, a1, a2, b1, order=0):

        dels = self.word_effects[0]
        adds = self.word_effects[1]

        if order == 0:
            # b1 - a2 + a1
            b2_dels = dels[b1]
            b2_adds = adds[b1]

            # undo a2: / add(a2) + del(a2)
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a2], adds[a2])

            # do a1: / del(a1) + add(a1)
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a1], adds[a1])
        elif order == 1:
            # b1 + a1 - a2
            b2_dels = dels[b1]
            b2_adds = adds[b1]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a1], adds[a1])
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a2], adds[a2])
        elif order == 2:
            # a1 - a2 + b1
            b2_dels = dels[a1]
            b2_adds = adds[a1]
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a2], adds[a2])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
        elif order == 3:
            # a1 + b1 - a2
            b2_dels = dels[a1]
            b2_adds = adds[a1]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
            b2_dels, b2_adds = self.effect_sub(b2_dels, b2_adds, dels[a2], adds[a2])
        elif order == 4:
            # -a2 + b1 + a1
            b2_dels = adds[a2]
            b2_adds = dels[a2]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a1], adds[a1])
        elif order == 5:
            # -a2 + a1 + b1
            b2_dels = adds[a2]
            b2_adds = dels[a2]
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[a1], adds[a1])
            b2_dels, b2_adds = self.effect_add(b2_dels, b2_adds, dels[b1], adds[b1])

        # 0,-1,1
        b2_effect = b2_adds - b2_dels
        return b2_effect

    def accuracy_seqadd_base(self, b2_effect, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of ``b2`` under cosine similarity to ``b2_effect``.

        ``b2_effect`` is normalized row-wise and compared with every word's
        normalized signed effect. With ``exclude_original``, the question words
        a1, a2 and b1 are masked out (score -inf) before ranking.
        """
        query = F.normalize(b2_effect)
        scores = torch.matmul(query, self.normalized_posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_accuracy(max_k, scores, b2, reduction=reduction)

    def accuracy_seqadd(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition b1 - a1 + a2."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=0)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition b1 + a2 - a1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=1)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_2(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition a2 - a1 + b1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=2)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_3(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition a2 + b1 - a1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=3)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_4(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition -a1 + b1 + a2."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=4)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_5(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b2 from the effect composition -a1 + a2 + b1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=5)
        return self.accuracy_seqadd_base(effect, max_k, a1, a2, b1, b2,
                                         reduction=reduction,
                                         exclude_original=exclude_original)

    def accuracy_seqadd_l1_base(self,b2_effect, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of ``b2`` under an unnormalized dot-product score.

        ``b2_effect`` is compared with every word's signed effect without
        normalization. With ``exclude_original``, a1, a2 and b1 are masked out
        (score -inf) before ranking.
        """
        scores = torch.matmul(b2_effect, self.posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=b2_effect.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_accuracy(max_k, scores, b2, reduction=reduction)

    def accuracy_seqadd_l1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition b1 - a1 + a2."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=0)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_l1_1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition b1 + a2 - a1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=1)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_l1_2(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition a2 - a1 + b1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=2)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_l1_3(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition a2 + b1 - a1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=3)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_l1_4(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition -a1 + b1 + a2."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=4)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_l1_5(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Dot-product top-k accuracy of b2 for the composition -a1 + a2 + b1."""
        effect = self.analogy_common(max_k, a1, a2, b1, order=5)
        return self.accuracy_seqadd_l1_base(effect, max_k, a1, a2, b1, b2,
                                            reduction=reduction,
                                            exclude_original=exclude_original)

    def accuracy_seqadd_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of the "only b" baseline: predict b2 from b1's effect alone.

        The cosine scoring, a1/a2/b1 masking and top-k accuracy are identical to
        ``accuracy_seqadd_base``, so this delegates to it instead of keeping a
        separate copy of that logic.
        """
        b2_effect = self.analogy_only_b_common(max_k, a1, a2, b1)
        return self.accuracy_seqadd_base(b2_effect, max_k, a1, a2, b1, b2, reduction, exclude_original)

    def accuracy_seqadd_ignore_a(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of the "ignore a" baseline b1 + a2 (a1 is not composed).

        Scoring and masking are identical to ``accuracy_seqadd_base``, so this
        delegates to it. a1 is still masked when ``exclude_original`` is set.
        """
        b2_effect = self.analogy_ignore_a_common(max_k, a1, a2, b1, order=0)
        return self.accuracy_seqadd_base(b2_effect, max_k, a1, a2, b1, b2, reduction, exclude_original)

    def accuracy_seqadd_ignore_a_1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of the "ignore a" baseline a2 + b1 (a1 is not composed).

        Scoring and masking are identical to ``accuracy_seqadd_base``, so this
        delegates to it. a1 is still masked when ``exclude_original`` is set.
        """
        b2_effect = self.analogy_ignore_a_common(max_k, a1, a2, b1, order=1)
        return self.accuracy_seqadd_base(b2_effect, max_k, a1, a2, b1, b2, reduction, exclude_original)

    def accuracy_seqadd_add_opposite_base(self, b2_effect, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Score ``b2_effect`` for the add-opposite baselines.

        The cosine scoring, a1/a2/b1 masking and top-k accuracy are identical to
        ``accuracy_seqadd_base``, so this delegates to it rather than keeping a
        second copy of the same code.
        """
        return self.accuracy_seqadd_base(b2_effect, max_k, a1, a2, b1, b2, reduction, exclude_original)

    def accuracy_seqadd_add_opposite(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition b1 - a2 + a1."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=0)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_add_opposite_1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition b1 + a1 - a2."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=1)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_add_opposite_2(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition a1 - a2 + b1."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=2)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_add_opposite_3(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition a1 + b1 - a2."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=3)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_add_opposite_4(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition -a2 + b1 + a1."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=4)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_add_opposite_5(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of b2 for the add-opposite composition -a2 + a1 + b1."""
        effect = self.analogy_add_opposite_common(max_k, a1, a2, b1, order=5)
        return self.accuracy_seqadd_add_opposite_base(effect, max_k, a1, a2, b1, b2,
                                                      reduction=reduction,
                                                      exclude_original=exclude_original)

    def accuracy_seqadd_reverse_base(self, b1_effect, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of ``b1`` (reverse direction) under cosine similarity.

        ``b1_effect`` is normalized row-wise and compared with every word's
        normalized signed effect. With ``exclude_original``, a1, a2 and b2 are
        masked out (score -inf) before ranking.
        """
        query = F.normalize(b1_effect)
        scores = torch.matmul(query, self.normalized_posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b2):
                scores[rows, excluded] = -float("inf")
        return self.topk_accuracy(max_k, scores, b1, reduction=reduction)

    def accuracy_seqadd_reverse(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition b2 - a2 + a1."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=0)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)

    def accuracy_seqadd_reverse_1(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition b2 + a1 - a2."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=1)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)
   
    def accuracy_seqadd_reverse_2(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition a1 - a2 + b2."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=2)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)
   
    def accuracy_seqadd_reverse_3(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition a1 + b2 - a2."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=3)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)
   
    def accuracy_seqadd_reverse_4(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition -a2 + b2 + a1."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=4)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)
   
    def accuracy_seqadd_reverse_5(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Top-k accuracy of predicting b1 from the composition -a2 + a1 + b2."""
        effect = self.analogy_common(max_k, a2, a1, b2, order=5)
        return self.accuracy_seqadd_reverse_base(effect, max_k, a1, a2, b1, b2,
                                                 reduction=reduction,
                                                 exclude_original=exclude_original)
    
    def accuracy_seqadd_reverse_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        """Reverse-direction "only b" baseline: predict b1 from b2's effect alone.

        The cosine scoring, a1/a2/b2 masking and top-k accuracy against b1 are
        identical to ``accuracy_seqadd_reverse_base``, so this delegates to it.
        """
        b1_effect = self.analogy_only_b_common(max_k, a1, a2, b2)
        return self.accuracy_seqadd_reverse_base(b1_effect, max_k, a1, a2, b1, b2, reduction, exclude_original)

    def predict_seqadd(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Indices of the ``max_k`` words closest (cosine) to the effect b1 - a1 + a2.

        With ``exclude_original``, a1, a2 and b1 themselves are ruled out.
        ``b2`` is accepted for signature parity and is not used.
        """
        query = F.normalize(self.analogy_common(max_k, a1, a2, b1))
        scores = torch.matmul(query, self.normalized_posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_l1(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Indices of the ``max_k`` words with the highest unnormalized dot product
        to the effect b1 - a1 + a2.

        With ``exclude_original``, a1, a2 and b1 themselves are ruled out.
        ``b2`` is accepted for signature parity and is not used.
        """
        query = self.analogy_common(max_k, a1, a2, b1)
        scores = torch.matmul(query, self.posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Indices of the ``max_k`` words closest (cosine) to b1's own effect.

        "Only b" baseline. With ``exclude_original``, a1, a2 and b1 are ruled out.
        ``b2`` is accepted for signature parity and is not used.
        """
        query = F.normalize(self.analogy_only_b_common(max_k, a1, a2, b1))
        scores = torch.matmul(query, self.normalized_posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_ignore_a(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Indices of the ``max_k`` words closest (cosine) to the effect b1 + a2.

        "Ignore a" baseline. With ``exclude_original``, a1, a2 and b1 are ruled
        out. ``b2`` is accepted for signature parity and is not used.
        """
        query = F.normalize(self.analogy_ignore_a_common(max_k, a1, a2, b1))
        scores = torch.matmul(query, self.normalized_posneg_effects.T)  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1), device=query.device)
            for excluded in (a1, a2, b1):
                scores[rows, excluded] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_add_opposite(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-``max_k`` predictions for b2 from the ``analogy_add_opposite_common`` query.

        The query is L2-normalized per row and compared with
        ``self.normalized_posneg_effects``. ``b2`` is unused. With
        ``exclude_original`` the words a1, a2 and b1 are masked with ``-inf``.
        """
        query = F.normalize(self.analogy_add_opposite_common(max_k, a1, a2, b1))
        scores = query @ self.normalized_posneg_effects.T  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1)).to(query.device)
            for cols in (a1, a2, b1):
                scores[rows, cols] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_reverse(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-``max_k`` predictions for b1 (the reverse direction of a1:a2 = b1:b2).

        The query is ``self.analogy_common(max_k, a2, a1, b2)``, i.e. the forward
        query with the roles of a1/a2 swapped and b2 as the base word. It is
        L2-normalized per row and compared with ``self.normalized_posneg_effects``.
        ``b1`` is unused. With ``exclude_original`` the words a1, a2 and b2 are
        masked with ``-inf``.
        """
        query = F.normalize(self.analogy_common(max_k, a2, a1, b2))
        scores = query @ self.normalized_posneg_effects.T  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1)).to(query.device)
            for cols in (a1, a2, b2):
                scores[rows, cols] = -float("inf")
        return self.topk_indices(max_k, scores)

    def predict_seqadd_reverse_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        """Top-``max_k`` predictions for b1 from the ``analogy_only_b_common`` query on b2.

        The query is L2-normalized per row and compared with
        ``self.normalized_posneg_effects``. ``b1`` is unused. With
        ``exclude_original`` the words a1, a2 and b2 are masked with ``-inf``.
        """
        # NOTE(review): predict_seqadd_reverse swaps a1/a2 for the reverse
        # direction, and HybridEmbeddingMixin passes (a2, a1, b2) here; this
        # passes (a1, a2, b2). Confirm whether analogy_only_b_common reads a1/a2.
        query = F.normalize(self.analogy_only_b_common(max_k, a1, a2, b2))
        scores = query @ self.normalized_posneg_effects.T  # [B, V]
        if exclude_original:
            rows = torch.arange(len(a1)).to(query.device)
            for cols in (a1, a2, b2):
                scores[rows, cols] = -float("inf")
        return self.topk_indices(max_k, scores)

    def print_knn(self, words, max_k, data):
        """Print (and collect) the ``max_k`` nearest neighbours of each word.

        Words are mapped with ``data.word_to_idx``, embedded with
        ``self.target_embedding`` and ranked by ``self.predict`` and
        ``self.predict_l1``. The report is produced once per entry of
        ``self.embeddings``; the entry index only labels the output (the query
        always comes from ``self.target_embedding``).

        :return: dict filled via ``register(stat, i, method_name, word, neighbours)``
        """
        from colors import yellow

        idxs = opt_cuda(torch.tensor([data.word_to_idx(w) for w in words], dtype=int))

        self.set_initial_state(len(idxs), idxs.device)
        self = opt_cuda(self)

        stat = {}
        for i, _unused in enumerate(self.embeddings):
            if len(self.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            query = self.target_embedding(idxs)
            for method in (self.predict, self.predict_l1):
                name = method.__name__
                for word, neighbours in zip(words, method(max_k, query)):
                    print("======================================================")
                    print(f"{max_k}-NN of", yellow(word.ljust(15)), "in", name, ":", end="")
                    print("".join(data.idx_to_word(j).ljust(15) for j in neighbours))
                    register(stat, i, name, word,
                             [data.idx_to_word(j) for j in neighbours.tolist()])
        return stat

    def print_analogy(self, a1_words, a2_words, b1_words, b2_words, max_k, data):
        """Print (and collect) the predictions of every ``predict_seqadd*`` variant.

        For a1:a2 = b1:b2 the forward predictors guess b2 from (a1, a2, b1), and
        the reverse predictors guess b1 from (a1, a2, b2). Each predictor runs
        twice, with and without masking the given words.

        :return: dict filled via ``register(stat, i, method_name, label, words)``,
            where ``label`` is ``"<plus>-<minus>+<base>=<answer>(include)"`` or
            ``"...(exclude)"`` and uses the same word order as the printed query.
        """
        from colors import yellow, magenta, red

        def to_indices(words):
            return opt_cuda(torch.tensor(list(map(data.word_to_idx, words)), dtype=int))

        a1 = to_indices(a1_words)
        a2 = to_indices(a2_words)
        b1 = to_indices(b1_words)
        b2 = to_indices(b2_words)

        self.set_initial_state(len(a1), a1.device)
        self = opt_cuda(self)

        def report(stat, i, name, plus, minus, base, answer, include_row, exclude_row):
            # Print one analogy as plus-minus+base(=answer) and register both
            # top-k lists under one label, so include/exclude keys always agree.
            print("======================================================")
            print(f"{max_k}-NN of", f"{yellow(plus)}-{yellow(minus)}+{yellow(base)}(={yellow(answer)})", "in", name, ":")
            for j in include_row:
                print(magenta(data.idx_to_word(j).ljust(15)),end="")
            print("<-Include original words")
            for j in exclude_row:
                print(red(data.idx_to_word(j).ljust(15)),end="")
            print("<-Exclude original words")
            print("")
            label = f"{plus}-{minus}+{base}={answer}"
            register(stat, i, name, label + "(include)", list(map(data.idx_to_word, include_row.tolist())))
            register(stat, i, name, label + "(exclude)", list(map(data.idx_to_word, exclude_row.tolist())))

        stat = {}
        for i, _unused in enumerate(self.embeddings):
            if len(self.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            # forward: predict b2, displayed as a2-a1+b1(=b2)
            for method in [self.predict_seqadd,
                           self.predict_seqadd_l1,
                           self.predict_seqadd_only_b,
                           self.predict_seqadd_ignore_a,
                           self.predict_seqadd_add_opposite]:
                topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, row_include, row_exclude in zip(
                        a1_words, a2_words, b1_words, b2_words, topk_include, topk_exclude):
                    report(stat, i, method.__name__, a2_word, a1_word, b1_word, b2_word,
                           row_include, row_exclude)

            # reverse: predict b1, displayed as a1-a2+b2(=b1)
            for method in [self.predict_seqadd_reverse,
                           self.predict_seqadd_reverse_only_b]:
                topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, row_include, row_exclude in zip(
                        a1_words, a2_words, b1_words, b2_words, topk_include, topk_exclude):
                    report(stat, i, method.__name__, a1_word, a2_word, b2_word, b1_word,
                           row_include, row_exclude)
        return stat

    def evaluate_analogy_category(self, category_data, data):
        """
        Evaluate every ``accuracy_seqadd*`` variant on one analogy category.

        :param self: pytorch model with embedding matrix
        :param category_data: list of four-word analogy rows (a1, a2, b1, b2)
        :param data: dataset object with word to index information (``word_to_idx``)
        :return stat: nested dict with ``"total"`` plus, per embedding, exclusion
            mode and method, the top-1/5/10 hit counts and accuracies. Accuracies
            are divided by the row count *before* <unk> filtering, so rows with
            <unk> count as wrong. An empty category returns ``{"total": 0}``.
        """
        self.eval()
        total = len(category_data)
        stat = {"total": total}
        if total == 0:
            # nothing to score; the code below would fail on a 1-D empty array
            # and divide by zero in the report
            return stat
        category_data = np.array([ list(map(data.word_to_idx,row))
                                  for row in category_data ])
        category_data = torch.tensor(category_data)
        # discard examples where any word is <unk> (index 0), automatically wrong
        category_data = category_data[(category_data!=0).all(1),:]
        after_total = category_data.shape[0]
        print(f"{total-after_total}/{total} ({100-after_total/total*100:.1f}%) removed (with <unk> in any of 4 words)")
        a1, a2, b1, b2 = (category_data[:, c] for c in range(4))
        k_list = [1, 5, 10]
        for i, _unused in enumerate(self.embeddings):
            for exclusion in ["include_original", "exclude_original"]:
                exclude = exclusion == "exclude_original"
                for method in [self.accuracy_seqadd,
                               self.accuracy_seqadd_1,
                               self.accuracy_seqadd_2,
                               self.accuracy_seqadd_3,
                               self.accuracy_seqadd_4,
                               self.accuracy_seqadd_5,
                               self.accuracy_seqadd_l1,
                               self.accuracy_seqadd_l1_1,
                               self.accuracy_seqadd_l1_2,
                               self.accuracy_seqadd_l1_3,
                               self.accuracy_seqadd_l1_4,
                               self.accuracy_seqadd_l1_5,
                               self.accuracy_seqadd_only_b,
                               self.accuracy_seqadd_ignore_a,
                               self.accuracy_seqadd_ignore_a_1,
                               self.accuracy_seqadd_add_opposite,
                               self.accuracy_seqadd_add_opposite_1,
                               self.accuracy_seqadd_add_opposite_2,
                               self.accuracy_seqadd_add_opposite_3,
                               self.accuracy_seqadd_add_opposite_4,
                               self.accuracy_seqadd_add_opposite_5,
                               self.accuracy_seqadd_reverse,
                               self.accuracy_seqadd_reverse_1,
                               self.accuracy_seqadd_reverse_2,
                               self.accuracy_seqadd_reverse_3,
                               self.accuracy_seqadd_reverse_4,
                               self.accuracy_seqadd_reverse_5,
                               self.accuracy_seqadd_reverse_only_b]:
                    # counts[k-1] is the number of hits within the top k
                    counts = method(10, a1, a2, b1, b2, reduction="sum", exclude_original=exclude)
                    for k in k_list:
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "top"+str(k),
                                 counts[k-1].detach().item())
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "accuracy_top"+str(k),
                                 (counts[k-1]/float(total)).detach().item())
        return stat

    def word_similarity(self, w1, w2, normalize=False):
        """Dot product between the effect vectors of paired word indices.

        :param w1: word indices (sequence or 1-D tensor); paired element-wise with w2
        :param w2: word indices of the same length as w1
        :param normalize: use ``normalized_posneg_effects`` (cosine similarity)
            instead of the raw ``posneg_effects``
        :return: numpy array with one similarity per (w1[i], w2[i]) pair
        """
        effects = self.normalized_posneg_effects if normalize else self.posneg_effects
        # .cpu() so the numpy conversion also works when the model is on a GPU
        return (effects[w1] * effects[w2]).sum(dim=1).detach().cpu().numpy()


class HybridEmbeddingMixin(TopKMixin):
    """Evaluate two sub-models jointly by concatenating their embeddings.

    The concrete class must provide ``self.model1`` and ``self.model2``. The
    vocabulary matrix is ``cat(model1.embeddings2.weight, model2.posneg_effects)``
    along dim 1. Every analogy query is the concatenation (dim 1) of the two
    sub-models' query vectors, scored by cosine similarity.

    Analogy convention: a1:a2 = b1:b2, where a1, a2, b1, b2 are 1-D word index
    tensors of one batch. Forward methods predict b2; ``*_reverse*`` methods
    predict b1.
    """

    def accuracy(self, max_k, pred_emb, target, reduction="mean"):
        """Top-k hit statistics of ``pred_emb`` against ``target`` for each k up to max_k.

        B: batch, E: embedding, V: vocab.
        """
        pred_emb = F.normalize(pred_emb)  # [B, E]
        similarity = torch.matmul(pred_emb, self.normalized_matrix.T)  # [B, V]
        return self.topk_accuracy(max_k, similarity, target, reduction=reduction)

    def predict(self, max_k, pred_emb):
        """Top-``max_k`` vocabulary indices closest (cosine) to each row of ``pred_emb``."""
        pred_emb = F.normalize(pred_emb)  # [B, E]
        similarity = torch.matmul(pred_emb, self.normalized_matrix.T)  # [B, V]
        return self.topk_indices(max_k, similarity)

    def embed_sentence(self, indices, reduction="mean"):
        """
        sentences represented as a list of indices
        """
        # NOTE(review): concatenates along dim 0, which is the feature axis only
        # if embed_sentence returns a 1-D vector per sentence -- confirm.
        emb1 = self.model1.embed_sentence(indices, reduction)
        emb2 = self.model2.embed_sentence(indices, reduction)
        return torch.cat((emb1,emb2), 0)

    @property
    def normalized_matrix(self):
        """Row-L2-normalized hybrid vocabulary matrix, cached on first access."""
        # TODO this may need to change, normalize after cat or before cat
        # NOTE(review): the cache is never invalidated, so later weight updates
        # are not reflected.
        if hasattr(self, "_normalized_matrix"):
            return self._normalized_matrix

        matrix = torch.cat((self.model1.embeddings2.weight,
                            self.model2.posneg_effects), 1) # [vocab, embeddingX2]
        matrix = F.normalize(matrix)   # [vocab, embedding]
        self._normalized_matrix = matrix
        return matrix

    @property
    def matrix(self):
        """Unnormalized hybrid vocabulary matrix (recomputed on every access)."""
        return torch.cat((self.model1.embeddings2.weight, self.model2.posneg_effects), 1)

    # ---- shared analogy scoring ------------------------------------------

    def _analogy_similarity(self, matrix, emb_1, emb_2, masked=()):
        """Cosine similarity between concatenated analogy queries and the vocabulary.

        :param matrix: row-normalized vocabulary matrix (``self.normalized_matrix``)
        :param emb_1: queries from ``self.model1``, one row per example
        :param emb_2: queries from ``self.model2``, one row per example
        :param masked: index tensors (one entry per example) whose column is set
            to -inf in the corresponding row, e.g. to exclude the input words
        :return: similarity matrix [B, V]
        """
        query = torch.cat((emb_1, emb_2), 1)
        # Normalize each query *row*, matching the row-normalized vocabulary.
        # dim=0 would normalize each feature across the batch, making every
        # prediction depend on the other examples and on the batch size.
        query = F.normalize(query, dim=1)
        similarity = torch.matmul(query, matrix.T)  # [B, V]
        if masked:
            idx = torch.arange(similarity.shape[0]).to(matrix.device)
            for words in masked:
                similarity[idx, words] = -float("inf")
        return similarity

    def _similarity_3cosadd(self, max_k, a1, a2, b1, exclude_original):
        # Levy and Goldberg, https://www.aclweb.org/anthology/W14-1618.pdf, Eq. (1)
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_common(max_k, a1, a2, b1)
        emb_2 = self.model2.analogy_common(max_k, a1, a2, b1, order=5)
        masked = (a1, a2, b1) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    def _similarity_only_b(self, max_k, a1, a2, b1, exclude_original):
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_only_b_common(max_k, a1, a2, b1)
        emb_2 = self.model2.analogy_only_b_common(max_k, a1, a2, b1)
        masked = (b1,) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    def _similarity_ignore_a(self, max_k, a1, a2, b1, exclude_original):
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_ignore_a_common(max_k, a1, a2, b1)
        emb_2 = self.model2.analogy_ignore_a_common(max_k, a1, a2, b1, order=1)
        masked = (a2, b1) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    def _similarity_add_opposite(self, max_k, a1, a2, b1, exclude_original):
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_add_opposite_common(max_k, a1, a2, b1)
        emb_2 = self.model2.analogy_add_opposite_common(max_k, a1, a2, b1, order=5)
        masked = (a1, a2, b1) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    def _similarity_reverse(self, max_k, a1, a2, b2, exclude_original):
        # reverse direction: swap a1/a2 and use b2 as the base word to predict b1
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_common(max_k, a2, a1, b2)
        emb_2 = self.model2.analogy_common(max_k, a2, a1, b2, order=5)
        masked = (a1, a2, b2) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    def _similarity_reverse_only_b(self, max_k, a1, a2, b2, exclude_original):
        matrix = self.normalized_matrix
        emb_1 = self.model1.analogy_only_b_common(max_k, a2, a1, b2)
        emb_2 = self.model2.analogy_only_b_common(max_k, a2, a1, b2)
        masked = (a1, a2, b2) if exclude_original else ()
        return self._analogy_similarity(matrix, emb_1, emb_2, masked)

    # ---- accuracy (top-k hit statistics) -----------------------------------
    # Assume a1:a2 = b1:b2, where a1,a2,b1,b2 are the word index/indices.
    # b2 is the target word index/indices (b1 for the *_reverse* methods).

    def accuracy_3cosadd(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_3cosadd(max_k, a1, a2, b1, exclude_original)
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_only_b(max_k, a1, a2, b1, exclude_original)
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_ignore_a(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_ignore_a(max_k, a1, a2, b1, exclude_original)
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_add_opposite(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_add_opposite(max_k, a1, a2, b1, exclude_original)
        return self.topk_accuracy(max_k, similarity, b2, reduction=reduction)

    def accuracy_reverse(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_reverse(max_k, a1, a2, b2, exclude_original)
        return self.topk_accuracy(max_k, similarity, b1, reduction=reduction)

    def accuracy_reverse_only_b(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        similarity = self._similarity_reverse_only_b(max_k, a1, a2, b2, exclude_original)
        return self.topk_accuracy(max_k, similarity, b1, reduction=reduction)

    # ---- predictions (top-k vocabulary indices) ----------------------------

    def predict_3cosadd(self, max_k, a1, a2, b1, b2, reduction="mean", exclude_original=False):
        # ``reduction`` is accepted for signature compatibility and unused
        similarity = self._similarity_3cosadd(max_k, a1, a2, b1, exclude_original)
        return self.topk_indices(max_k, similarity)

    def predict_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        # same query and masking as accuracy_only_b (forward direction, b2 is the target)
        similarity = self._similarity_only_b(max_k, a1, a2, b1, exclude_original)
        return self.topk_indices(max_k, similarity)

    def predict_ignore_a(self, max_k, a1, a2, b1, b2, exclude_original=False):
        similarity = self._similarity_ignore_a(max_k, a1, a2, b1, exclude_original)
        return self.topk_indices(max_k, similarity)

    def predict_add_opposite(self, max_k, a1, a2, b1, b2, exclude_original=False):
        similarity = self._similarity_add_opposite(max_k, a1, a2, b1, exclude_original)
        return self.topk_indices(max_k, similarity)

    def predict_reverse(self, max_k, a1, a2, b1, b2, exclude_original=False):
        similarity = self._similarity_reverse(max_k, a1, a2, b2, exclude_original)
        return self.topk_indices(max_k, similarity)

    def predict_reverse_only_b(self, max_k, a1, a2, b1, b2, exclude_original=False):
        similarity = self._similarity_reverse_only_b(max_k, a1, a2, b2, exclude_original)
        return self.topk_indices(max_k, similarity)

    # ---- reporting -----------------------------------------------------------

    def print_knn(self, words, max_k, data):
        """Print (and collect) the ``max_k`` nearest neighbours of each word via ``self.predict``.

        NOTE(review): the query comes from ``self.target_embedding``, which this
        mixin does not define; the concrete class must provide it.
        """
        from colors import yellow

        idxs = opt_cuda(torch.tensor(list(map(data.word_to_idx, words)),dtype=int))

        self.model2.set_initial_state(len(idxs),idxs.device)
        self = opt_cuda(self)

        stat = {}
        for i,embedding in enumerate(self.model1.embeddings):
            if len(self.model1.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            embedding = self.target_embedding(idxs)
            for method in [self.predict]:

                topk = method(max_k, embedding)

                for word, topk_per_word in zip(words, topk):
                    print("======================================================")
                    print(f"{max_k}-NN of", yellow(word.ljust(15)), "in", method.__name__, ":",end="")
                    for j in topk_per_word:
                        print(data.idx_to_word(j).ljust(15),end="")
                    print("")
                    register(stat, i, method.__name__, word, list(map(data.idx_to_word,topk_per_word.tolist())))
        return stat

    def print_analogy(self, a1_words, a2_words, b1_words, b2_words, max_k, data):
        """Print (and collect) the predictions of every hybrid analogy predictor.

        Forward predictors guess b2 from (a1, a2, b1); reverse predictors guess b1
        from (a1, a2, b2). Each runs with and without masking the given words.

        :return: dict filled via ``register(stat, i, method_name, label, words)``,
            with ``label`` = ``"<plus>-<minus>+<base>=<answer>(include|exclude)"``
            in the same word order as the printed query.
        """
        from colors import yellow, magenta, red

        def to_indices(words):
            return opt_cuda(torch.tensor(list(map(data.word_to_idx, words)), dtype=int))

        a1 = to_indices(a1_words)
        a2 = to_indices(a2_words)
        b1 = to_indices(b1_words)
        b2 = to_indices(b2_words)

        self.model2.set_initial_state(len(a1), a1.device)
        self = opt_cuda(self)

        def report(stat, i, name, plus, minus, base, answer, include_row, exclude_row):
            # one analogy; both top-k lists share one label so the keys agree
            print("======================================================")
            print(f"{max_k}-NN of", f"{yellow(plus)}-{yellow(minus)}+{yellow(base)}(={yellow(answer)})", "in", name, ":")
            for j in include_row:
                print(magenta(data.idx_to_word(j).ljust(15)),end="")
            print("<-Include original words")
            for j in exclude_row:
                print(red(data.idx_to_word(j).ljust(15)),end="")
            print("<-Exclude original words")
            print("")
            label = f"{plus}-{minus}+{base}={answer}"
            register(stat, i, name, label + "(include)", list(map(data.idx_to_word, include_row.tolist())))
            register(stat, i, name, label + "(exclude)", list(map(data.idx_to_word, exclude_row.tolist())))

        stat = {}
        for i, _unused in enumerate(self.model1.embeddings):
            if len(self.model1.embeddings) > 1:
                print("---- embedding",i,"--------------------")

            # forward: predict b2, displayed as a2-a1+b1(=b2)
            for method in [self.predict_3cosadd,
                           self.predict_only_b,
                           self.predict_ignore_a,
                           self.predict_add_opposite]:
                topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, row_include, row_exclude in zip(
                        a1_words, a2_words, b1_words, b2_words, topk_include, topk_exclude):
                    report(stat, i, method.__name__, a2_word, a1_word, b1_word, b2_word,
                           row_include, row_exclude)

            # reverse: predict b1, displayed as a1-a2+b2(=b1)
            for method in [self.predict_reverse,
                           self.predict_reverse_only_b]:
                topk_include = method(max_k, a1, a2, b1, b2, exclude_original=False)
                topk_exclude = method(max_k, a1, a2, b1, b2, exclude_original=True)
                for a1_word, a2_word, b1_word, b2_word, row_include, row_exclude in zip(
                        a1_words, a2_words, b1_words, b2_words, topk_include, topk_exclude):
                    report(stat, i, method.__name__, a1_word, a2_word, b2_word, b1_word,
                           row_include, row_exclude)
        return stat

    def evaluate_analogy_category(self, category_data, data):
        """
        Evaluate the hybrid analogy methods on one analogy category.

        :param self: pytorch model with embedding matrix
        :param category_data: list of four-word analogy rows (a1, a2, b1, b2)
        :param data: dataset object with word to index information (``word_to_idx``)
        :return stat: nested dict with ``"total"`` plus, per embedding, exclusion
            mode and method, the top-1/5/10 hit counts and accuracies. Accuracies
            are divided by the row count *before* <unk> filtering, so rows with
            <unk> count as wrong. An empty category returns ``{"total": 0}``.
        """
        total = len(category_data)
        stat = {"total": total}
        if total == 0:
            # nothing to score; avoids a crash on the empty array and a division by zero
            return stat
        category_data = np.array([ list(map(data.word_to_idx,row))
                                  for row in category_data ])
        category_data = torch.tensor(category_data)
        # discard examples where any word is <unk> (index 0), automatically wrong
        category_data = category_data[(category_data!=0).all(1),:]
        after_total = category_data.shape[0]

        print(f"{total-after_total}/{total} ({100-after_total/total*100:.1f}%) removed (with <unk> in any of 4 words)")
        k_list = [1, 5, 10]

        b = 500 # batch size
        # torch.split yields full batches plus the trailing partial one, so no row
        # is dropped (the old `after_total > b` test skipped exactly-b categories)
        batches = [batch for batch in torch.split(category_data, b) if batch.shape[0] > 0]

        for i, _unused in enumerate(self.model1.embeddings):
            # limiting the number of tracks to reduce memory usage
            for exclusion in ["exclude_original"]: # "include_original",
                exclude = exclusion == "exclude_original"
                for method in [self.accuracy_3cosadd,
                               self.accuracy_only_b,
                               self.accuracy_ignore_a,
                               self.accuracy_add_opposite,
                               self.accuracy_reverse,
                               self.accuracy_reverse_only_b
                               ]:
                    counts = torch.zeros(10)
                    for batch in batches:
                        batch_counts = method(10,batch[:,0],batch[:,1],batch[:,2],batch[:,3],reduction="sum",exclude_original=exclude)
                        # accumulate on CPU, detached, so GPU results can be summed and
                        # no autograd graph is kept alive across batches
                        counts = counts + batch_counts.detach().cpu()

                    for k in k_list:
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "top"+str(k),
                                 (counts[k-1]).detach().item())
                        register(stat, "embedding"+str(i), exclusion, method.__name__, "accuracy_top"+str(k),
                                 (counts[k-1]/float(total)).detach().item())
        return stat

    def word_similarity(self, w1, w2, normalize=False):
        """Dot product between hybrid-matrix rows of paired word indices (cosine if ``normalize``).

        :return: numpy array with one similarity per (w1[i], w2[i]) pair
        """
        matrix = self.normalized_matrix if normalize else self.matrix
        # .cpu() so the numpy conversion also works when the model is on a GPU
        return (matrix[w1] * matrix[w2]).sum(dim=1).detach().cpu().numpy()

class TemperatureAnnealingMixin:
    """Anneal ``self.temperature`` during training and choose a binary-concrete discretizer.

    Expects a base class whose ``__init__(hyper, data)`` sets ``self.hyper``,
    ``self.callbacks`` and ``self.metrics``. Reads the hyperparameters
    ``annealing_max``, ``annealing_min``, ``annealing_start``, ``annealing_end``,
    ``annealing_schedule``, ``straight_through`` and ``beta``.
    """

    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        hp = self.hyper
        self.temperature = hp["annealing_max"]
        # recompute the temperature at the start of every batch
        self.callbacks.append(LambdaCallback(on_batch_begin=self.annealing_step))
        # straight-through variant when requested, plain binary concrete otherwise
        discretizer_cls = utils.STBinConcrete if hp["straight_through"] else utils.BinConcrete
        self.discretizer = discretizer_cls(hp["beta"])
        self.metrics["temperature"] = lambda **_ignored: self.temperature

    def annealing_step(self, epoch=None, batch_n=0, num_batches=None, split=None, **_unused):
        """Set ``self.temperature`` for the upcoming batch.

        Outside the ``"train"`` split the temperature is pinned to 1e-10. During
        training it follows ``anneal_function`` evaluated at the fractional epoch
        ``epoch + batch_n / num_batches`` (annealing is assumed to end when
        training ends).
        """
        if split != "train":
            self.temperature = 1e-10
            return
        hp = self.hyper
        self.temperature = anneal_function(
            hp["annealing_schedule"],
            hp["annealing_max"],
            hp["annealing_min"],
            hp["annealing_start"],
            epoch + batch_n / num_batches,
            hp["annealing_end"],
        )




# Note: the embedding layers are stored in self.embeddings
# and are additionally exposed as self.embeddings1, self.embeddings2.
# This allows the evaluator to iterate over several embeddings
# (e.g. CBOW_NEG may use different embeddings for the input
# and the output).

class NegativeSampling(nn.Module):
    """Positive loss plus a negative-sampling term computed against random noise words.

    Args:
        num_samples: number K of noise words drawn per batch element; 0 disables
            the negative term.
        embeddings_out: callable mapping a LongTensor of word indices to embeddings
            (an ``nn.Embedding`` or a model's ``target_embedding`` method).
        weights: optional non-negative sampling weights, one per vocabulary entry.
            Without them noise words are drawn uniformly, which requires
            ``embeddings_out`` to expose a ``.weight`` table (e.g. ``nn.Embedding``).
        loss: key into ``utils.losses`` used for the positive term.
        reduction: reduction passed to every loss call.
        center: value subtracted from the embeddings before the losses are applied.
    """

    # Losses treated as directional (geodesic): the positive term is computed on
    # centered vectors, and the negative term pulls predictions toward -noise.
    _DIRECTIONAL_LOSSES = ("cosine_distance", "logsigmoiddot")

    def __init__(self, num_samples, embeddings_out, weights=None, loss="logsigmoiddot", reduction='mean', center=0.0):
        super().__init__()
        self.num_samples = num_samples
        self.embeddings_out = embeddings_out
        self.weights = None
        if weights is not None:
            weights = torch.as_tensor(weights).float()
            assert bool((weights >= 0).all()), "Each weight should be >= 0"
            assert float(weights.sum()) > 0, "Weight sum should be > 0"
            self.weights = weights
        self.reduction = reduction
        self.lossname = loss
        self.loss = utils.losses[loss]
        self.center = center

    def sample(self, pred_emb):
        """Draw K*B noise word indices (B = pred_emb.shape[0]) on ``pred_emb``'s device."""
        B, E, K = [*pred_emb.shape, self.num_samples]
        if self.weights is not None:
            # with replacement, proportional to self.weights
            noise = torch.multinomial(self.weights, K * B, replacement=True)
        else:
            table = getattr(self.embeddings_out, "weight", None)
            if table is None:
                raise ValueError(
                    "uniform negative sampling needs embeddings_out to have a `.weight` "
                    "table (e.g. nn.Embedding); pass `weights` otherwise")
            # torch.randint takes the output size as a tuple
            noise = torch.randint(0, table.shape[0], (K * B,))
        return noise.to(pred_emb.device)

    def forward(self, pred_emb, true):
        """Compute the negative-sampling loss.

        Args:
            pred_emb: predicted embeddings, shape [B, E].
            true: target word indices, passed to ``embeddings_out``.

        Returns:
            ``(loss, loss_pos, loss_neg, pred_emb, true_emb)`` with
            ``loss = loss_pos + loss_neg``; the returned embeddings are centered
            (``- self.center``). The same 5-tuple is returned when
            ``num_samples == 0``, in which case ``loss_neg`` is 0.
        """
        assert pred_emb.device == true.device
        device = true.device
        true_emb = self.embeddings_out(true)   # [B, E]

        directional = self.lossname in self._DIRECTIONAL_LOSSES
        if directional:
            # directional loss requires the centered vectors
            pred_emb = pred_emb - self.center
            true_emb = true_emb - self.center

        loss_pos = self.loss(true_emb, pred_emb, reduction=self.reduction)
        loss_neg = torch.tensor(0.0, device=device)

        if not directional:
            # non-directional losses see raw vectors in the positive term, but the
            # negative term and the returned embeddings use centered vectors
            pred_emb = pred_emb - self.center
            true_emb = true_emb - self.center

        B, E, K = [*pred_emb.shape, self.num_samples]
        if K == 0:
            # same 5-tuple as the main path: callers unpack `l, p, n, x, y`
            return loss_pos + loss_neg, loss_pos, loss_neg, pred_emb, true_emb

        noise     = self.sample(pred_emb)
        noise_emb = self.embeddings_out(noise)  # [K*B, E]
        noise_emb = noise_emb.view(K, B, E)     # [K, B, E]
        noise_emb = noise_emb - self.center

        # kl / js use logsigmoiddot for the negative term
        negative_lossname = "logsigmoiddot" if self.lossname in ("kl", "js") else self.lossname
        negative_loss = utils.losses[negative_lossname]

        if negative_lossname in self._DIRECTIONAL_LOSSES:
            # directional loss (geodesic distance): minimize the loss toward the opposite vector
            noise_emb = -noise_emb
            for k in range(K):
                loss_neg = loss_neg + negative_loss(noise_emb[k], pred_emb, reduction=self.reduction)
        elif negative_lossname in ("l1", "l2"):
            # Minimize the inverse of the loss. The negative loss is very large when
            # the noise vector is close and small once it is "far enough", similar to
            # the landscape of logsigmoiddot (like repelling charges).
            for k in range(K):
                loss_neg = loss_neg + 1.0 / negative_loss(noise_emb[k], pred_emb, reduction=self.reduction)
        # any other loss name contributes no negative term
        return loss_pos + loss_neg, loss_pos, loss_neg, pred_emb, true_emb


class CBOW_NEG(ContinuousEmbeddingMixin, Trainer, nn.Module):
    """Continuous bag-of-words model trained with negative sampling.

    One ``nn.Embedding`` serves both as the input (context) table and as the
    output (target / noise) table.
    """
    datasetclass = WikiText2DataSetCBOW

    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        shared = nn.Embedding(data.vocab_size, self.hyper["embedding"])
        self.embeddings = [shared]
        self.embeddings1 = shared
        self.embeddings2 = shared
        # unigram frequency ** 0.75, following the original word2vec paper
        noise_weights = torch.pow(torch.tensor(data._idx_to_freq), 0.75)
        neg_sampling = NegativeSampling(self.hyper["negative_sample"],
                                        self.embeddings1,
                                        loss=self.hyper["loss"],
                                        weights=noise_weights,
                                        reduction='sum')

        def compute_loss(pred_emb, true):
            total, positive, negative, pred_c, true_c = neg_sampling(pred_emb, true)
            # cache detached statistics for the metric callbacks below
            self.loss_main = total.detach()
            self.loss_pos = positive.detach()
            self.loss_neg = negative.detach()
            self.loss_cos = utils.losses["cosine_distance"](pred_c, true_c, reduction="sum").detach()
            self.mean = true_c.mean().detach().cpu()
            self.variance = true_c.var().detach().cpu()
            return total

        self.loss = compute_loss

        def per_sample(attr):
            # losses are summed over the batch (reduction='sum'); report them per element
            return lambda **locals: getattr(self, attr) / self.hyper["batch_size"]

        for attr in ("loss_main", "loss_pos", "loss_neg", "loss_cos"):
            self.metrics[attr] = per_sample(attr)
        self.metrics["stat_mean"] = lambda **locals: self.mean
        self.metrics["stat_var"] = lambda **locals: self.variance

    def target_embedding(self, x):
        """Embed target/noise word indices with the output table."""
        return self.embeddings2(x)

    def set_initial_state(self,**locals):
        """No recurrent state for this model: intentionally a no-op."""

    def forward(self, x):
        """Sum the context-word embeddings over dim 1."""
        return self.embeddings1(x).sum(dim=1)

# To accommodate the trainer loop structure, each sample's "context" is the center word
# and its target is one of the surrounding context words; the context words are fused into the batch dimension.
class SkipGram_NEG(ContinuousEmbeddingMixin, Trainer, nn.Module):
    """Skip-gram model trained with negative sampling.

    One ``nn.Embedding`` serves both as the input (center word) table and as the
    output (target / noise) table.
    """
    datasetclass = WikiText2DataSetSkipGram

    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        self.embeddings = [nn.Embedding(data.vocab_size, self.hyper["embedding"])]
        self.embeddings1 = self.embeddings[0]
        self.embeddings2 = self.embeddings[0]
        loss = NegativeSampling(self.hyper["negative_sample"],
                                self.embeddings1,
                                loss=self.hyper["loss"],
                                # unigram frequency ** 0.75, following the original paper
                                weights=torch.pow(torch.tensor(data._idx_to_freq),0.75),
                                reduction='sum')
        def compute_loss(pred_emb, true):
            l,p,n,x,y = loss(pred_emb, true)
            # cache detached statistics for the metric callbacks below
            self.loss_main = l.detach()
            self.loss_pos  = p.detach()
            self.loss_neg  = n.detach()
            self.loss_cos = utils.losses["cosine_distance"](x,y,reduction="sum").detach()
            self.mean     = y.mean().detach().cpu()
            self.variance = y.var().detach().cpu()
            return l
        self.loss = compute_loss
        # losses are summed over the batch (reduction='sum'); report them per element
        self.metrics["loss_main"] = lambda **locals: self.loss_main/self.hyper["batch_size"]
        self.metrics["loss_pos"]  = lambda **locals: self.loss_pos/self.hyper["batch_size"]
        self.metrics["loss_neg"]  = lambda **locals: self.loss_neg/self.hyper["batch_size"]
        self.metrics["loss_cos"]  = lambda **locals: self.loss_cos/self.hyper["batch_size"]
        self.metrics["stat_mean"] = lambda **locals: self.mean
        self.metrics["stat_var"]  = lambda **locals: self.variance

    def target_embedding(self, x):
        """Embed target/noise word indices with the output table."""
        return self.embeddings2(x)

    def set_initial_state(self,**locals):
        """No recurrent state for this model: intentionally a no-op."""
        pass

    def forward(self, x):
        """Embed the center word and drop the singleton context axis.

        The input carries a singleton context axis for shape compatibility with
        CBOW, so the embedding is [B, 1, E] and the result is [B, E]. Only axis 1
        is squeezed; a bare ``squeeze()`` would also drop the batch axis when
        B == 1 (e.g. a final partial batch) or the embedding axis when E == 1.
        """
        x = self.embeddings1(x)
        return x.squeeze(1)

class BTL_Mixin(EffectEmbeddingMixin, TemperatureAnnealingMixin, Trainer, nn.Module):
    """Common base for the discretized ("BTL") embedding models.

    Holds one shared ``nn.Embedding`` (exposed as ``embeddings1`` and ``embeddings2``),
    the batch-norm layers used by ``discrete_add`` and the per-batch metrics.
    ``self.discretizer`` and ``self.temperature`` are not set here; they are expected
    from ``TemperatureAnnealingMixin`` in the MRO.
    """
    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        self.embeddings = [nn.Embedding(data.vocab_size, self.hyper["embedding"])]
        # both aliases refer to the same module: input and output embeddings are shared
        self.embeddings1 = self.embeddings[0]
        self.embeddings2 = self.embeddings[0]
        if self.hyper["initialization"] == "logistic":
            # Standard-logistic init by inverse CDF:
            # log(u*M) - log(M - u*M) == log(u / (1 - u)).
            # Computed in float64; the clamp keeps u == 0 away from log(0) = -inf.
            u = torch.rand(data.vocab_size, self.hyper["embedding"], dtype=torch.float64)
            M, eps = torch.finfo(torch.float64).max, torch.finfo(torch.float64).eps
            Mu = u * M
            Mu = torch.clamp(Mu, eps, M-eps)
            self.embeddings2.weight = torch.nn.Parameter((torch.log(Mu)-torch.log(M-Mu)).float())
        self.bn1 = nn.BatchNorm1d(self.hyper["embedding"], affine=self.hyper["affine"])
        # bn2 is registered but discrete_add does not use it (that line is commented out)
        self.bn2 = nn.BatchNorm1d(self.hyper["embedding"], affine=self.hyper["affine"])
        # clear the discretizer at the start of every batch (presumably resets its `losses`)
        self.callbacks.append(
            LambdaCallback(
                on_batch_begin=lambda *args,**kwargs: self.discretizer.clear()))
        # summed losses are reported per batch element
        self.metrics["loss_kl"]   = lambda **locals: self.discretizer.losses/self.hyper["batch_size"]
        self.metrics["loss_main"] = lambda **locals: self.loss_main/self.hyper["batch_size"]
        self.metrics["loss_pos"]  = lambda **locals: self.loss_pos/self.hyper["batch_size"]
        self.metrics["loss_neg"]  = lambda **locals: self.loss_neg/self.hyper["batch_size"]
        self.metrics["loss_cos"]  = lambda **locals: self.loss_cos/self.hyper["batch_size"]
        self.metrics["stat_mean"] = lambda **locals: self.mean
        self.metrics["stat_var"]  = lambda **locals: self.variance

    def set_initial_state(self,B,device):
        """Create ``self.initial_state``, a [B, E] float tensor on ``device``.

        ``hyper["initial_state"]`` selects the content:
        "random" (entries drawn uniformly from {0, 1}), "zeros", "ones" or "half" (0.5).

        Raises:
            ValueError: for any other mode, rather than silently keeping a stale
                (or missing) state.
        """
        E = self.hyper["embedding"]
        mode = self.hyper["initial_state"]
        if mode == "random":
            self.initial_state = torch.randint(0,2,(B,E),dtype=torch.float,device=device)
        elif mode == "zeros":
            self.initial_state = torch.zeros((B,E),device=device)
        elif mode == "ones":
            self.initial_state = torch.ones((B,E),device=device)
        elif mode == "half":
            self.initial_state = torch.full((B,E),0.5,device=device)
        else:
            raise ValueError(
                f"unknown initial_state {mode!r}; expected 'random', 'zeros', 'ones' or 'half'")

    def discrete_add(self, s, xi, temperature, noise, discretize=True):
        """Batch-normalize the state ``s``, add ``xi`` and optionally discretize the sum.

        Args:
            s: current state, [N, E].
            xi: embedding to add; same shape as ``s``.
            temperature: forwarded to ``self.discretizer``.
            noise: forwarded to ``self.discretizer``.
            discretize: when False, return the continuous sum without discretizing.
        """
        s = self.bn1(s)
        s += xi
        # s += self.bn2(xi)
        if discretize:
            s = self.discretizer(s, temperature, noise)
        return s


class CBOW_BTL_Mixin(BTL_Mixin):
    """BTL base class that trains on ``WikiText2DataSetCBOW``."""

    datasetclass = WikiText2DataSetCBOW

class SG_BTL_Mixin(BTL_Mixin):
    """BTL base class that trains on ``WikiText2DataSetSkipGram``."""

    datasetclass = WikiText2DataSetSkipGram

class SG_BTL_Sequential(SG_BTL_Mixin):
    """Skip-gram BTL model: one discretized state update with the single input word."""

    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        # unigram frequency ** 0.75, following the original word2vec paper
        noise_weights = torch.pow(torch.tensor(data._idx_to_freq), 0.75)
        neg_sampling = NegativeSampling(self.hyper["negative_sample"],
                                        self.target_embedding,
                                        loss=self.hyper["loss"],
                                        weights=noise_weights,
                                        reduction='sum',
                                        center=0.5)

        def compute_loss(pred_emb, true):
            total, positive, negative, pred_c, true_c = neg_sampling(pred_emb, true)
            self.loss_main = total.detach()
            self.loss_pos = positive.detach()
            self.loss_neg = negative.detach()
            self.loss_cos = utils.losses["cosine_distance"](pred_c, true_c, reduction="sum").detach()
            self.mean = true_c.mean().detach().cpu()
            self.variance = true_c.var().detach().cpu()
            # read only after neg_sampling has run target_embedding, so it reflects
            # this batch's discretizer calls (presumably accumulated since clear())
            return total + self.discretizer.losses

        self.loss = compute_loss

    def target_embedding(self, x):
        """Map target/noise word indices to discretized states built on ``self.initial_state``.

        ``self.initial_state`` (set by the latest ``forward``) has B rows. When ``x``
        holds K*B noise indices (K = hyper["negative_sample"]), the state is tiled
        K times so every noise row has a matching state row.
        """
        emb = self.embeddings1(x)
        state = self.initial_state
        n_rows, _ = emb.shape
        n_state, _ = state.shape
        k = self.hyper["negative_sample"]
        if n_rows == n_state * k:
            state = state.repeat(k, 1)
        else:
            assert n_rows == n_state
        return self.discrete_add(state, emb, self.temperature,
                                 self.training and self.hyper["noise"])

    def forward(self, x):
        """Update the initial state once with the input word (C == 1 for skip-gram)."""
        emb = self.embeddings1(x)   # [B, C, E]
        batch_size, _, _ = emb.shape
        self.set_initial_state(batch_size, emb.device)
        return self.discrete_add(self.initial_state, emb[:, 0, :], self.temperature,
                                 self.training and self.hyper["noise"])


class CBOW_BTL_Sequential(CBOW_BTL_Mixin):
    """CBOW BTL model: the state is updated sequentially with each context word."""

    def __init__(self, hyper, data):
        super().__init__(hyper, data)
        # unigram frequency ** 0.75, following the original word2vec paper
        noise_weights = torch.pow(torch.tensor(data._idx_to_freq), 0.75)
        neg_sampling = NegativeSampling(self.hyper["negative_sample"],
                                        self.target_embedding,
                                        loss=self.hyper["loss"],
                                        weights=noise_weights,
                                        reduction='sum',
                                        center=0.5)

        def compute_loss(pred_emb, true):
            total, positive, negative, pred_c, true_c = neg_sampling(pred_emb, true)
            self.loss_main = total.detach()
            self.loss_pos = positive.detach()
            self.loss_neg = negative.detach()
            self.loss_cos = utils.losses["cosine_distance"](pred_c, true_c, reduction="sum").detach()
            self.mean = true_c.mean().detach().cpu()
            self.variance = true_c.var().detach().cpu()
            # read only after neg_sampling has run target_embedding, so it reflects
            # this batch's discretizer calls (presumably accumulated since clear())
            return total + self.discretizer.losses

        self.loss = compute_loss

    def target_embedding(self, x):
        """Map target/noise word indices to discretized states built on ``self.initial_state``.

        ``self.initial_state`` (set by the latest ``forward``) has B rows. When ``x``
        holds K*B noise indices (K = hyper["negative_sample"]), the state is tiled
        K times so every noise row has a matching state row.
        """
        emb = self.embeddings1(x)
        state = self.initial_state
        n_rows, _ = emb.shape
        n_state, _ = state.shape
        k = self.hyper["negative_sample"]
        if n_rows == n_state * k:
            state = state.repeat(k, 1)
        else:
            assert n_rows == n_state
        return self.discrete_add(state, emb, self.temperature,
                                 self.training and self.hyper["noise"])

    def forward(self, x):
        """Fold the context words, left to right, into the initial state."""
        emb = self.embeddings1(x)   # [B, C, E]
        batch_size, _, _ = emb.shape
        self.set_initial_state(batch_size, emb.device)
        state = self.initial_state
        for word_emb in emb.unbind(dim=1):
            state = self.discrete_add(state, word_emb, self.temperature,
                                      self.training and self.hyper["noise"])
        return state


class CBOW_HybridModel(HybridEmbeddingMixin, Trainer, nn.Module):
    """Hybrid CBOW model whose embedding concatenates the outputs of two sub-models.

    ``model1`` (default ``CBOW_NEG``, built with ``initialization="gaussian"``) yields
    the first half of each embedding and ``model2`` (default ``CBOW_BTL_Sequential``)
    the second half. Each half is scored by its own sub-model's loss (see ``loss_func``).
    """
    datasetclass = WikiText2DataSetCBOW
    def __init__(self, hyper, data, model1=CBOW_NEG, model2=CBOW_BTL_Sequential):
        super().__init__(hyper, data)
        from copy import deepcopy
        CBOW_hyper = deepcopy(hyper)
        CBOW_hyper["initialization"] = "gaussian"
        model1 = model1(CBOW_hyper, data)
        model2 = model2(hyper, data)
        self.model1 = model1
        self.model2 = model2
        self.hyper = self.model2.hyper
        # NOTE(review): model2's own callbacks (e.g. temperature annealing) live in
        # self.model2.callbacks; confirm the mixin/Trainer also runs them.
        self.callbacks.append(
            LambdaCallback(
                on_batch_begin=lambda *args,**kwargs: self.model2.discretizer.clear()))
        # these attributes will never be used for calculation
        self.embeddings = self.model1.embeddings

        def compute_loss(pred_emb, true):
            l,p,n,c,m,v = self.loss_func(pred_emb, true)
            self.loss_main = l.detach()
            self.loss_pos  = p.detach()
            self.loss_neg  = n.detach()
            self.loss_cos  = c.detach()
            self.mean      = m.detach()
            self.variance  = v.detach()
            return l
        self.loss = compute_loss

    def target_embedding(self, x):
        """Concatenate both sub-models' target embeddings of ``x`` along dim 1."""
        target1 = self.model1.target_embedding(x)
        target2 = self.model2.target_embedding(x)
        return torch.cat((target1, target2), 1)

    def forward(self, x):
        """Embed the context ``x`` with both sub-models and concatenate along dim 1.

        ``x`` holds context word indices. Both sub-models see the full context window;
        each one produces its own half of the embedding.
        """
        target1 = self.model1(x)
        target2 = self.model2(x)
        return torch.cat((target1, target2), 1)

    def loss_func(self, pred_emb, true):
        """Score each half of ``pred_emb`` with its sub-model's loss.

        Returns:
            ``(loss, loss_pos, loss_neg, loss_cos, mean, variance)``. Each entry is the
            sum of the two sub-models' values; ``mean``/``variance`` are therefore sums
            of per-half statistics, not statistics of the concatenated embedding.
        """
        B, E = pred_emb.shape
        half = E // 2
        l1 = self.model1.loss(pred_emb[:, :half], true)
        l2 = self.model2.loss(pred_emb[:, half:], true)
        l = l1 + l2
        # the sub-models' compute_loss calls above refreshed these cached statistics
        p  = self.model1.loss_pos + self.model2.loss_pos
        n  = self.model1.loss_neg + self.model2.loss_neg
        loss_cos  = self.model1.loss_cos + self.model2.loss_cos
        mean      = self.model1.mean + self.model2.mean
        variance  = self.model1.variance + self.model2.variance
        return l, p, n, loss_cos, mean, variance

class CBOW_HybridModelX(HybridEmbeddingMixin, Trainer, nn.Module):
    """Hybrid CBOW model trained with one negative-sampling loss over the concatenation.

    ``model1`` (default ``CBOW_NEG``, built with ``initialization="gaussian"``) yields
    the first half of each embedding and ``model2`` (default ``CBOW_BTL_Sequential``)
    the second half. A single ``NegativeSampling`` (center=0.5) scores the
    concatenated vectors, and ``model2``'s discretizer losses are added.
    """
    datasetclass = WikiText2DataSetCBOW
    def __init__(self, hyper, data, model1=CBOW_NEG, model2=CBOW_BTL_Sequential):
        super().__init__(hyper, data)
        from copy import deepcopy
        CBOW_hyper = deepcopy(hyper)
        CBOW_hyper["initialization"] = "gaussian"
        model1 = model1(CBOW_hyper, data)
        model2 = model2(hyper, data)
        self.model1 = model1
        self.model2 = model2
        self.hyper = self.model2.hyper
        # NOTE(review): model2's own callbacks (e.g. temperature annealing) live in
        # self.model2.callbacks; confirm the mixin/Trainer also runs them.
        self.callbacks.append(
            LambdaCallback(
                on_batch_begin=lambda *args,**kwargs: self.model2.discretizer.clear()))
        # these attributes will never be used for calculation, only used in evaluation
        # scripts to determine the device location (cpu or gpu)
        self.embeddings = self.model1.embeddings
        loss = NegativeSampling(self.hyper["negative_sample"],
                                self.target_embedding,
                                loss=self.hyper["loss"],
                                # unigram frequency ** 0.75, following the original paper
                                weights=torch.pow(torch.tensor(data._idx_to_freq),0.75),
                                reduction='sum',
                                center=0.5)
        def compute_loss(pred_emb, true):
            l,p,n,x,y = loss(pred_emb, true)
            self.loss_main = l.detach()
            self.loss_pos  = p.detach()
            self.loss_neg  = n.detach()
            self.loss_cos  = utils.losses["cosine_distance"](x,y,reduction="sum").detach()
            self.mean      = y.mean().detach().cpu()
            self.variance  = y.var().detach().cpu()
            # read after `loss` has run target_embedding for this batch
            return l+self.model2.discretizer.losses
        self.loss = compute_loss

    def target_embedding(self, x):
        """Concatenate both sub-models' target embeddings of ``x`` along dim 1."""
        target1 = self.model1.target_embedding(x)
        target2 = self.model2.target_embedding(x)
        return torch.cat((target1, target2), 1)

    def forward(self, x):
        """Embed the context ``x`` with both sub-models and concatenate along dim 1.

        ``x`` holds context word indices. Both sub-models see the full context window;
        each one produces its own half of the embedding.
        """
        target1 = self.model1(x)
        target2 = self.model2(x)
        return torch.cat((target1, target2), 1)

