""" Model for meta-transfer learning. """
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from models.backbones import Res12, WRN28
from utils.misc import emb_loss, count_acc


def euclidean_metric(query, proto):
    '''
    Negative squared Euclidean distance between every query and every prototype.

    :param query: query features, first dimension is num_samples
    :param proto: class prototypes, first dimension is way
    :return: logits of shape (num_samples, way); larger means closer
    '''
    num_samples, way = query.shape[0], proto.shape[0]
    pair_shape = (num_samples, way, -1)
    # Tile both operands to (num_samples, way, dim) so they can be compared pairwise.
    tiled_query = query.unsqueeze(1).expand(*pair_shape)
    tiled_proto = proto.unsqueeze(0).expand(*pair_shape)
    squared_diff = (tiled_query - tiled_proto) ** 2
    return -squared_diff.sum(dim=2)  # (num_samples, way)


def cosine_metric(query, proto):
    '''
    Cosine similarity between every query and every prototype.

    :param query:  (bs, dim)
    :param proto:  (way, dim)
    :return: (bs, way)
    '''
    bs, way = query.shape[0], proto.shape[0]
    # Tile both operands to (bs, way, dim) and reduce over the feature axis.
    return torch.cosine_similarity(
        query.unsqueeze(1).expand(bs, way, -1),
        proto.unsqueeze(0).expand(bs, way, -1),
        dim=2,
    )


class BaseLearner(nn.Module):
    """The class for inner loop.

    A single linear layer mapping ``z_dim`` features to ``args.way`` logits.
    The weight and bias are also collected in ``self.vars`` so that
    ``forward`` can be evaluated with an alternative pair of tensors.
    """

    def __init__(self, args, z_dim):
        """Create the linear classifier.
        Args:
          args: configuration object; only ``args.way`` is read here.
          z_dim: dimensionality of the input features.
        """
        super().__init__()
        self.args = args
        self.z_dim = z_dim
        self.vars = nn.ParameterList()
        # Weight is (way, z_dim); the initial ones are overwritten by Kaiming init.
        self.fc1_w = nn.Parameter(torch.ones([self.args.way, self.z_dim]))
        torch.nn.init.kaiming_normal_(self.fc1_w)
        self.vars.append(self.fc1_w)
        self.fc1_b = nn.Parameter(torch.zeros(self.args.way))
        self.vars.append(self.fc1_b)

    def forward(self, input_x, the_vars=None):
        """Apply the linear layer.
        Args:
          input_x: input features whose last dimension is ``z_dim``.
          the_vars: optional indexable holding the weight at index 0 and the
            bias at index 1; defaults to the module's own parameters.
        Returns:
          the logits produced by ``F.linear``.
        """
        if the_vars is None:
            the_vars = self.vars
        weight, bias = the_vars[0], the_vars[1]
        return F.linear(input_x, weight, bias)

    def parameters(self, recurse=True):
        """Return the weight and bias as a ``ParameterList``.

        Unlike ``nn.Module.parameters`` this returns the list itself rather
        than an iterator, so callers may index it. ``recurse`` is accepted
        only for signature compatibility with ``nn.Module.parameters``; the
        module has no nested sub-modules with further parameters, so the
        result is the same either way.
        """
        return self.vars


class TasLearner(nn.Module):
    """The class for outer loop.

    Wraps a backbone encoder and, depending on ``mode``, the heads needed for
    pre-training ('pre'), prototype classification ('proto' / 'preval') or
    classification that combines visual features with class embeddings
    ('concat' / 'fusion').

    NOTE: the 'concat' and 'fusion' forward passes are not pure: they freeze
    the encoder parameters (``requires_grad = False``, never restored) and
    update ``word_learner`` in place with their own optimizer. Because they
    call ``loss.backward()`` they must run with autograd enabled.
    """

    # Settings for the in-forward fit of ``word_learner``.
    _WORD_FIT_STEPS = 100
    _WORD_FIT_LR = 0.0001
    _WORD_FIT_WEIGHT_DECAY = 1e-5
    _WORD_FIT_LR_STEP = 50
    _WORD_FIT_LR_GAMMA = 0.2

    def __init__(self, args, mode='meta', num_cls=64):
        """Build the encoder and the mode-specific heads.
        Args:
          args: configuration object; reads base_lr, update_step, dataset and
            model_type here (shot, way and lamda are read in the forward passes).
          mode: one of 'pre', 'proto', 'concat', 'fusion'; any other value
            (including the default 'meta') creates no extra head.
          num_cls: number of output classes of the pre-training classifier.
        Raises:
          ValueError: if args.model_type is neither 'res12' nor 'wrn28'.
        """
        super().__init__()
        self.args = args
        self.mode = mode
        self.update_lr = args.base_lr
        self.update_step = args.update_step
        z_dim = 640
        # Output size of the projection heads; presumably the size of the class
        # embeddings for the dataset (312 for 'cub', 300 otherwise) - TODO confirm.
        out_dim = 312 if self.args.dataset == 'cub' else 300

        if self.args.model_type == 'res12':
            self.encoder = Res12()
        elif self.args.model_type == 'wrn28':
            self.encoder = WRN28()
        else:
            # Fail here instead of with an AttributeError on the first forward.
            raise ValueError(
                "Unsupported model_type %r; expected 'res12' or 'wrn28'."
                % (self.args.model_type,))

        if self.mode == 'pre':
            self.pre_fc = nn.Sequential(nn.Linear(z_dim, num_cls))
        elif self.mode == 'proto':
            # Not used by proto_forward; kept so existing state dicts still load.
            self.fea_learner = nn.Sequential(nn.Linear(z_dim, out_dim))
        elif self.mode in ('concat', 'fusion'):
            self.word_learner = nn.Linear(z_dim, out_dim)
            self.feat_learner = nn.Linear(z_dim, out_dim)

    def forward(self, inp):
        """The function to forward the model.
        Args:
          inp: input images for mode 'pre'; otherwise a 3-tuple
            (data_shot, label_shot or emb_s, data_query).
        Returns:
          the outputs of the forward function selected by self.mode.
        Raises:
          ValueError: if self.mode is not a supported mode.
        """
        if self.mode == 'pre':
            return self.pretrain_forward(inp)
        elif self.mode == 'proto':
            data_shot, label_shot, data_query = inp
            return self.proto_forward(data_shot, label_shot, data_query)
        elif self.mode == 'concat':
            data_shot, emb_s, data_query = inp
            return self.concat_forward(data_shot, emb_s, data_query)
        elif self.mode == 'fusion':
            data_shot, emb_s, data_query = inp
            return self.fusion_forward(data_shot, emb_s, data_query)
        elif self.mode == 'preval':
            data_shot, label_shot, data_query = inp
            return self.preval_forward(data_shot, label_shot, data_query)
        else:
            raise ValueError('Please set the correct mode.')

    def pretrain_forward(self, inp):
        """The function to forward pretrain phase.
        Args:
          inp: input images.
        Returns:
          the outputs of pretrain model.
        """
        return self.pre_fc(self.encoder(inp, map=False))

    def _episode_prototypes(self, data_shot, data_query):
        """Encode one episode and average the support features per class.

        The support batch is viewed as (shot, way, dim), i.e. it is expected
        to be ordered shot-major.
        Returns:
          (query features, prototypes with one row per class).
        """
        # The query batch is encoded before the support batch, as in the
        # original code, so any stateful encoder layers see the same order.
        query = self.encoder(data_query, map=False)
        embedding_shot = self.encoder(data_shot, map=False)
        support = embedding_shot.view(self.args.shot, self.args.way, -1).transpose(1, 0)  # (way, shot, dim)
        return query, torch.mean(support, dim=1)

    def preval_forward(self, data_shot, label_shot, data_query):
        """The function to forward meta-validation during pretrain phase.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task (unused)
          data_query: test images for the task.
        Returns:
          logits_dist: negative squared Euclidean distances to the prototypes.
          logits_sim: dot products of the (un-normalised) query features with
            the L2-normalised prototypes.
        """
        query, proto = self._episode_prototypes(data_shot, data_query)
        logits_dist = euclidean_metric(query, proto)
        logits_sim = torch.mm(query, F.normalize(proto, p=2, dim=-1).t())
        return logits_dist, logits_sim

    def proto_forward(self, data_shot, label_shot, data_query):
        """The function to forward the prototype classifier.
        Args:
          data_shot: train images for the task
          label_shot: train labels for the task (unused)
          data_query: test images for the task.
        Returns:
          logits_dist: negative squared Euclidean distances between the query
            features and the class prototypes.
        """
        query, proto = self._episode_prototypes(data_shot, data_query)
        return euclidean_metric(query, proto)

    def _fit_word_learner(self, embedding_shot, emb_s):
        """Freeze the encoder and fit word_learner on the support set.

        Runs a fixed number of Adam steps that minimise
        emb_loss(word_learner(embedding_shot), emb_s, args). The encoder stays
        frozen afterwards and word_learner is left in eval mode.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        # The original code also passed the encoder parameters that still
        # require grad as a second param group; after the freeze above that
        # group is always empty, so only word_learner is optimised.
        optimizer = torch.optim.Adam(self.word_learner.parameters(),
                                     lr=self._WORD_FIT_LR,
                                     weight_decay=self._WORD_FIT_WEIGHT_DECAY)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self._WORD_FIT_LR_STEP, gamma=self._WORD_FIT_LR_GAMMA)
        self.word_learner.train()
        for _ in range(self._WORD_FIT_STEPS):
            support_emb = self.word_learner(embedding_shot)
            loss = emb_loss(support_emb, emb_s, self.args)
            optimizer.zero_grad()
            # embedding_shot is reused by every iteration, so any autograd
            # graph attached to it must survive each backward pass.
            loss.backward(retain_graph=True)
            optimizer.step()
            lr_scheduler.step()
        self.word_learner.eval()

    def _visual_and_word_features(self, data_shot, emb_s, data_query):
        """Shared front end of concat_forward and fusion_forward.
        Returns:
          visual_proto: per-class mean of feat_learner over the support set.
          word_proto: the first `way` rows of emb_s, cast to the visual dtype
            (assumes those rows hold one embedding per class - TODO confirm).
          vis_query: feat_learner applied to the query features.
          query_emb: word_learner applied to the query features.
        """
        embedding_query = self.encoder(data_query, map=False)
        embedding_shot = self.encoder(data_shot, map=False)

        self._fit_word_learner(embedding_shot, emb_s)
        query_emb = self.word_learner(embedding_query)

        vis_query = self.feat_learner(embedding_query)
        vis_support = self.feat_learner(embedding_shot)
        vis_support = vis_support.view(self.args.shot, self.args.way, -1).transpose(1, 0)
        visual_proto = torch.mean(vis_support, dim=1)

        word_proto = emb_s[:self.args.way].type(vis_support.dtype)
        return visual_proto, word_proto, vis_query, query_emb

    def _episode_logits(self, feat_q, proto):
        """Score queries against prototypes: cosine for 'cub', Euclidean otherwise."""
        if self.args.dataset == 'cub':
            return cosine_metric(feat_q, proto)
        return euclidean_metric(feat_q, proto)

    def concat_forward(self, data_shot, emb_s, data_query):
        """The function to forward the concatenation model.
        Args:
          data_shot: train images for the task
          emb_s: class embeddings for the support samples
          data_query: test images for the task.
        Returns:
          logits_dist: the predictions for the test samples.
          proto: visual prototypes concatenated with the class embeddings.
          feat_q: visual query features concatenated with the predicted
            query embeddings.
        """
        visual_proto, word_proto, vis_query, query_emb = \
            self._visual_and_word_features(data_shot, emb_s, data_query)

        proto = torch.cat((visual_proto, word_proto), dim=1)
        feat_q = torch.cat((vis_query, query_emb), dim=1)

        logits_dist = self._episode_logits(feat_q, proto)
        return logits_dist, proto, feat_q

    def fusion_forward(self, data_shot, emb_s, data_query):
        """The function to forward the fusion model.

        Visual and embedding features are blended as
        lamda * visual + (1 - lamda) * embedding, with lamda = args.lamda.
        Args:
          data_shot: train images for the task
          emb_s: class embeddings for the support samples
          data_query: test images for the task.
        Returns:
          logits_dist, proto, feat_q, visual_proto, vis_query, query_emb.
        """
        visual_proto, word_proto, vis_query, query_emb = \
            self._visual_and_word_features(data_shot, emb_s, data_query)

        lamda = self.args.lamda
        proto = lamda * visual_proto + (1-lamda)*word_proto
        feat_q = lamda * vis_query + (1-lamda)*query_emb

        logits_dist = self._episode_logits(feat_q, proto)
        return logits_dist, proto, feat_q, visual_proto, vis_query, query_emb