import torch.nn as nn

from aggregators.encoder import NeighborhoodEncoder
from metric_learning.matching_model import Matcher
from metric_learning.prototype_model import Protonet


class DyFewShot(nn.Module):
    """Few-shot model pairing a neighborhood encoder with a metric-learning head.

    The encoder is a ``NeighborhoodEncoder`` configured from ``args``. The head
    (``self.meta_learner``) is selected by ``args.meta_type`` and wraps that
    same encoder instance. ``loss`` and ``forward`` simply delegate to the head.
    """

    # Accepted values of ``args.meta_type``. They are checked before any
    # sub-module is built, so a typo fails fast with a clear message. Without
    # this check it would surface later as a missing ``meta_learner``
    # attribute.
    _META_TYPES = ('protonet', 'matcher', 'proto_match', 'protohatt_match')

    def __init__(self, args, ent_emb, rel_emb, device):
        """Build the encoder and the meta-learner head.

        Args:
            args: Configuration namespace. It supplies:
                - ``meta_type``, which selects the head;
                - the encoder settings: ``snap_encoder``, ``seq_encoder``,
                  ``mask``, ``h_dim``, ``out_dim``, ``n_head``,
                  ``enc_dropout``, ``finetune``, ``rel_num``, ``ent_num``,
                  ``emb_dim`` and ``hist_len``;
                - ``meta_dropout`` and ``steps``, read by the non-protonet
                  heads;
                - ``shots``, read only by the ``proto_match`` and
                  ``protohatt_match`` heads.
            ent_emb: Entity embeddings, forwarded to the encoder as
                ``ent_embds``. They are presumably pretrained weights; confirm
                against ``NeighborhoodEncoder``.
            rel_emb: Relation embeddings, forwarded to the encoder as
                ``rel_embds``.
            device: Forwarded unchanged to the encoder and to the head.

        Raises:
            ValueError: If ``args.meta_type`` is not one of ``_META_TYPES``.
        """
        super().__init__()
        meta_type = args.meta_type
        # Fail before building the (potentially large) encoder.
        if meta_type not in self._META_TYPES:
            raise ValueError(
                f"Unknown meta_type {meta_type!r}; expected one of {self._META_TYPES}"
            )

        self.encoder = NeighborhoodEncoder(snap_encoder=args.snap_encoder,
                                           seq_encoder=args.seq_encoder,
                                           mask=args.mask,
                                           h_dim=args.h_dim,
                                           out_dim=args.out_dim,
                                           n_head=args.n_head,
                                           dropout=args.enc_dropout,
                                           finetune=args.finetune,
                                           ent_embds=ent_emb,
                                           rel_embds=rel_emb,
                                           rel_num=args.rel_num,
                                           ent_num=args.ent_num,
                                           emb_dim=args.emb_dim,
                                           seq_len=args.hist_len,
                                           device=device)

        # Every head wraps the same encoder instance built above.
        if meta_type == 'protonet':
            self.meta_learner = Protonet(self.encoder, device=device)
        elif meta_type == 'matcher':
            self.meta_learner = Matcher(self.encoder, dropout=args.meta_dropout, steps=args.steps,
                                        h_dim=args.h_dim, out_dim=args.out_dim, device=device)
        elif meta_type == 'proto_match':
            # NOTE(review): ProtoMatcher is not imported in this file, so this
            # branch raises NameError. Add the import from its defining module.
            self.meta_learner = ProtoMatcher(self.encoder, dropout=args.meta_dropout, steps=args.steps,
                                             h_dim=args.h_dim, out_dim=args.out_dim, n_shots=args.shots,
                                             device=device)
        else:  # 'protohatt_match' (the only remaining value after validation)
            # NOTE(review): ProtoHattMatcher is not imported in this file, so
            # this branch raises NameError. Add the import from its defining
            # module.
            self.meta_learner = ProtoHattMatcher(self.encoder, dropout=args.meta_dropout, steps=args.steps,
                                                 h_dim=args.h_dim, out_dim=args.out_dim, n_shots=args.shots,
                                                 device=device)

    def loss(self, sample):
        """Return the meta-learner's loss for ``sample``."""
        return self.meta_learner.loss(sample)

    def forward(self, sample):
        """Run the meta-learner on ``sample`` and return its output."""
        return self.meta_learner(sample)