import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.cluster import KMeans
from utils.misc import route_plan

def np_proto(feat, label, way):
    """Return per-class mean features as a ``(way, feat.shape[1])`` array.

    Args:
        feat: 2-D NumPy array with one feature vector per row.
        label: 1-D array of class ids, one per row of ``feat``. The ids are
            used directly as row indices into the result.
        way: number of rows (classes) in the returned array.

    Returns:
        Array whose row ``c`` is the mean of all rows of ``feat`` labelled
        ``c``. Rows for class ids that never occur in ``label`` stay zero.
    """
    prototypes = np.zeros((way, feat.shape[1]))
    for cls in np.unique(label):
        member_rows = np.nonzero(label == cls)[0]
        prototypes[cls] = feat[member_rows].mean(axis=0)
    return prototypes

def updateproto(Xs, ys, cls_center, way):
    """Average each support prototype with its nearest cluster centre.

    Args:
        Xs: 2-D NumPy array of support features, one row per sample.
        ys: class ids of the rows of ``Xs``.
        cls_center: 2-D NumPy array of cluster centres, one row per cluster.
        way: number of classes.

    Returns:
        ``(way, Xs.shape[1])`` array; row ``c`` is the midpoint between the
        class-``c`` support prototype and the cluster centre closest to it in
        squared Euclidean distance. Several classes may pick the same centre.
    """
    support_proto = np_proto(Xs, ys, way)
    # Pairwise squared distances, shape (way, n_clusters).
    offsets = support_proto[:, None, :] - cls_center[None, :, :]
    sq_dist = np.square(offsets).sum(axis=2)
    nearest = sq_dist.argmin(axis=1)

    blended = np.zeros((way, Xs.shape[1]))
    blended[...] = (cls_center[nearest] + support_proto) / 2
    return blended


def updateproto_(Xs, ys, cls_center, way):
    """Average each support prototype with the cluster centre assigned by ``route_plan``.

    Args:
        Xs: 2-D NumPy array of support features, one row per sample.
        ys: class ids of the rows of ``Xs``.
        cls_center: 2-D NumPy array of cluster centres, one row per cluster.
        way: number of classes.

    Returns:
        ``(way, Xs.shape[1])`` array; row ``c`` is the midpoint between the
        class-``c`` support prototype and its assigned cluster centre.
    """
    support_proto = np_proto(Xs, ys, way)
    offsets = support_proto[:, None, :] - cls_center[None, :, :]
    sq_dist = np.square(offsets).sum(axis=2)

    # NOTE(review): presumably route_plan returns a (way, n_clusters) plan with
    # exactly one positive entry per row -- confirm against utils.misc.
    plan = route_plan(sq_dist)
    # np.where lists positive entries in row-major order; only the column
    # indices are used, and the k-th one is taken as the centre for class k.
    _rows, assigned = np.where(plan > 0)

    blended = np.zeros((way, Xs.shape[1]))
    for cls in range(way):
        blended[cls] = (support_proto[cls] + cls_center[assigned[cls]]) / 2
    return blended


def compactness_loss(gen_feat, gen_label, proto, supp_label):
    """Sum, over support classes, the MSE between generated features and the class prototype.

    Args:
        gen_feat: tensor of generated features, one row per sample.
        gen_label: class id of each row of ``gen_feat``.
        proto: tensor of prototypes, indexed by class id along dim 0.
        supp_label: support labels; their unique values are the classes summed over.

    Returns:
        Scalar tensor with the summed per-class mean squared error, or the
        Python int ``0`` when no class contributes.
    """
    loss = 0
    for lb in torch.unique(supp_label):
        idx = torch.where(gen_label == lb)[0]
        # The mean over an empty selection is NaN and would poison the whole
        # sum, so classes without generated samples are skipped.
        if idx.numel() == 0:
            continue
        # Broadcast explicitly: same value as implicit broadcasting, but
        # without the input/target size-mismatch warning.
        pred, target = torch.broadcast_tensors(gen_feat[idx], proto[lb])
        loss = loss + F.mse_loss(pred, target, reduction='mean')
    return loss

def euclidean_metric(a, b):
    """Return negative squared Euclidean distances between rows of ``a`` and ``b``.

    Args:
        a: 2-D tensor with ``n`` rows.
        b: 2-D tensor with ``m`` rows and the same row width as ``a``.

    Returns:
        ``(n, m)`` tensor whose entry ``[i, j]`` is ``-||a[i] - b[j]||^2``, so
        larger values mean closer rows.
    """
    rows_a, rows_b = a.shape[0], b.shape[0]
    a_tiled = a.unsqueeze(1).expand(rows_a, rows_b, -1)
    b_tiled = b.unsqueeze(0).expand(rows_a, rows_b, -1)
    sq_dist = torch.pow(a_tiled - b_tiled, 2).sum(dim=2)
    return -sq_dist

def compute_proto(feat, label, way):
    """Return per-class mean features as a ``(way, feat.size(1))`` tensor.

    The result is allocated with the dtype and device of ``feat``, so it does
    not depend on whether CUDA happens to be available on the host.

    Args:
        feat: 2-D tensor with one feature vector per row.
        label: class id of each row of ``feat``. The ids are used directly as
            row indices into the result.
        way: number of rows (classes) in the returned tensor.

    Returns:
        Tensor whose row ``c`` is the mean of the rows of ``feat`` labelled
        ``c``. Rows for class ids absent from ``label`` stay zero.
    """
    feat_proto = torch.zeros(way, feat.size(1), dtype=feat.dtype, device=feat.device)
    for lb in torch.unique(label):
        members = feat[torch.where(label == lb)[0]]
        feat_proto[lb] = torch.mean(members, dim=0)
    return feat_proto


class Classifier(nn.Module):
    """Linear classification head for the inner loop.

    The weight and bias are kept in ``self.vars`` (in that order) so a caller
    can run the same layer with substitute "fast" weights via ``the_vars``.
    """

    def __init__(self, way, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.way = way

        # Weight of shape (way, z_dim) with Kaiming-normal init; zero bias.
        weight = nn.Parameter(torch.ones([self.way, self.z_dim]))
        nn.init.kaiming_normal_(weight)
        bias = nn.Parameter(torch.zeros(self.way))

        self.fc1_w = weight
        self.fc1_b = bias
        self.vars = nn.ParameterList([weight, bias])

    def forward(self, input_x, the_vars=None):
        """Apply the linear layer using ``the_vars`` (``[weight, bias]``) or the stored parameters."""
        weights = self.vars if the_vars is None else the_vars
        return F.linear(input_x, weights[0], weights[1])

    def parameters(self):
        """Return the ``ParameterList`` holding ``[weight, bias]``."""
        return self.vars


class FClayer(nn.Module):
    """Bias-free linear layer for the inner loop.

    The single weight lives in ``self.vars`` so a caller can run the layer
    with a substitute weight via ``the_vars``.
    """

    def __init__(self, z_out, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.z_out = z_out

        # Weight of shape (z_out, z_dim) with Kaiming-normal init; no bias term.
        weight = nn.Parameter(torch.ones([self.z_out, self.z_dim]))
        nn.init.kaiming_normal_(weight)

        self.fc1_w = weight
        self.vars = nn.ParameterList([weight])

    def forward(self, input_x, the_vars=None):
        """Apply the linear map using ``the_vars[0]`` or the stored weight."""
        weights = self.vars if the_vars is None else the_vars
        return F.linear(input_x, weights[0])

    def parameters(self):
        """Return the ``ParameterList`` holding the weight."""
        return self.vars

class DCLearner(nn.Module):
    """The class for outer loop.

    ``mode='st'`` builds an encoder (feature -> semantic), a semantic
    transform and a decoder (semantic -> feature) plus a classifier head;
    ``mode='dc'`` builds the classifier head only.
    """

    def __init__(self, args, mode='st'):
        super().__init__()
        self.args = args
        self.mode = mode
        z_dim = 640
        z_sem = 312 if self.args.dataset == 'cub' else 300

        if mode == 'st':
            # Construction order is kept as-is: it fixes both the order of
            # random initialisation and the order of registered sub-modules.
            self.fc_en = FClayer(z_sem, z_dim)
            self.activefun1 = nn.ReLU()
            self.trans = nn.Linear(z_sem, z_sem)
            self.activefun2 = nn.Sigmoid()
            self.fc_de = FClayer(z_dim, z_sem)
            self.classifyer = Classifier(self.args.way, z_dim)
        elif mode == 'dc':
            self.classifyer = Classifier(self.args.way, z_dim)

    def forward(self, inp):
        """Dispatch to the forward pass selected by ``self.mode``.

        Args:
          inp: for mode ``'st'`` the 10-tuple ``(feat_b, sem_b, sem_b1,
            label_b, feat_ns, sem_ns, label_ns, sem_n1, label_n1, feat_nq)``;
            for mode ``'dc'`` the 3-tuple ``(feat_s, label_s, feat_q)``.
        Returns:
          the result of ``st_forward`` or ``dc_forward``.
        Raises:
          ValueError: if ``self.mode`` is neither ``'st'`` nor ``'dc'``.
        """
        if self.mode == 'st':
            feat_b, sem_b, sem_b1, label_b, feat_ns, sem_ns, label_ns, sem_n1, label_n1, feat_nq = inp
            return self.st_forward(feat_b, sem_b, sem_b1, label_b, feat_ns, sem_ns, label_ns, sem_n1, label_n1, feat_nq)
        elif self.mode == 'dc':
            feat_s, label_s, feat_q = inp
            return self.dc_forward(feat_s, label_s, feat_q)
        else:
            raise ValueError('Please set the correct mode.')

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor`` and return it as a NumPy array (works on CPU and GPU)."""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _transposed_linear(inp, weight):
        """Apply the transpose of ``weight`` as a bias-free linear map.

        The weight is detached, matching the previous behaviour of wrapping
        the transposed weight in a fresh ``nn.Parameter``: gradients reach the
        layers only through ``inp``.
        NOTE(review): confirm whether fully tied (non-detached) weights were intended.
        """
        return F.linear(inp, weight.detach().transpose(1, 0))

    def _gradient_classify(self, feat_train, label_train, feat_query, step_size, n_steps=100):
        """Adapt copies of the classifier weights by gradient descent, then score the queries.

        The stored classifier parameters are not modified; only the derived
        "fast" weights change.
        """
        fast_weights = list(self.classifyer.parameters())
        for _ in range(n_steps):
            logits = self.classifyer(feat_train, fast_weights)
            loss = F.cross_entropy(logits, label_train)
            grads = torch.autograd.grad(loss, fast_weights)
            fast_weights = [w - step_size * g for g, w in zip(grads, fast_weights)]
        return self.classifyer(feat_query, fast_weights)

    def dc_forward(self, feat_s, label_s, feat_q):
        """Classify ``feat_q`` from the support set ``(feat_s, label_s)``.

        Raises:
          ValueError: if ``args.classifiermethod`` is not ``'gradient'`` or ``'metric'``.
        """
        method = self.args.classifiermethod
        if method == 'gradient':
            return self._gradient_classify(feat_s, label_s, feat_q, step_size=0.01)
        if method == 'metric':
            protos = compute_proto(feat_s, label_s, self.args.way)
            return euclidean_metric(feat_q, protos)
        raise ValueError('Please set the correct method.')

    def _transductive_proto(self, feat_ns, label_ns, feat_nq):
        """Build normalised prototypes from support data and k-means centres of the queries."""
        Xq = self._to_numpy(feat_nq)
        Xs = self._to_numpy(feat_ns)
        ys = self._to_numpy(label_ns)
        if self.args.shot == 1:
            km = KMeans(n_clusters=self.args.way, max_iter=1000, random_state=100)
        else:
            # With more than one shot, start k-means from the support prototypes.
            p_np = np_proto(Xs, ys, self.args.way)
            km = KMeans(n_clusters=self.args.way, init=p_np, max_iter=1000, random_state=100)
        clus_center = km.fit(Xq).cluster_centers_
        proto = updateproto_(Xs, ys, clus_center, self.args.way)
        proto = torch.as_tensor(proto).to(device=feat_ns.device, dtype=feat_ns.dtype)
        return F.normalize(proto, dim=1)

    def _total_loss(self, loss1, loss2, loss3, loss5, loss6, loss7):
        """Combine the individual losses according to ``args.Ablation``."""
        ablation = self.args.Ablation
        if ablation == 'no':
            return loss1 + loss2 + loss3 + loss5 + loss6 + loss7
        if ablation == 'enc_recon':
            return loss1 + loss3 + loss5 + loss6 + loss7
        if ablation == 'dec_recon':
            return loss1 + loss2 + loss3 + loss5 + loss7
        if ablation == 'cpt':
            return loss1 + loss2 + loss3 + loss5 + loss6
        if ablation == 'all':
            return loss1 + loss3 + loss5
        raise ValueError('Please set the correct Ablation.')

    def _classify(self, feat, labels, feat_nq):
        """Fit the configured classifier on ``(feat, labels)`` and apply it to ``feat_nq``."""
        method = self.args.classifiermethod
        if method == 'gradient':
            return self._gradient_classify(feat, labels, feat_nq, step_size=self.args.gradlr)
        if method == 'metric':
            # prototypical method
            protos = compute_proto(feat, labels, self.args.way)
            return euclidean_metric(feat_nq, protos)
        if method == 'nonparam':
            # LR, SVM, KNN
            X_aug = self._to_numpy(feat)
            Y_aug = self._to_numpy(labels)
            data_query1 = self._to_numpy(feat_nq)
            if self.args.cls == 'lr':
                classifier = LogisticRegression(max_iter=1000)
            elif self.args.cls == 'svm':
                classifier = SVC(C=10, gamma='auto', kernel='linear', probability=True)
            elif self.args.cls == 'knn':
                classifier = KNeighborsClassifier(n_neighbors=1)
            else:
                raise ValueError('Please set the correct cls.')
            # Returns predicted labels (NumPy array), not logits.
            return classifier.fit(X=X_aug, y=Y_aug).predict(data_query1)
        raise ValueError('Please set the correct method.')

    def st_forward(self, feat_b, sem_b, sem_b1, label_b, feat_ns, sem_ns, label_ns, sem_n1, label_n1, feat_nq):
        '''
        Train encoder / transform / decoder for 50 steps, generate novel-class
        features from the base data and classify the queries.

        Widths below follow the layers built in ``__init__``: z_dim = 640 for
        features, z_sem = 300 (312 for 'cub') for semantics.

        :param feat_b: base features (way*N, z_dim), N: the selected samples per class
        :param sem_b: base semantic feature (way*N, z_sem)
        :param sem_b1: base semantic feature (way, z_sem)
        :param label_b: base labels (way*N, )
        :param feat_ns: support features (way*shot, z_dim)
        :param sem_ns: support semantic features (way*shot, z_sem)
        :param label_ns: support labels (way*shot, )
        :param sem_n1: support semantic features for each class (way, z_sem)
        :param label_n1: support labels of each class (way, ); not used here
        :param feat_nq: query features
        :return: (logits_q, sem_n_1_1, feat_n_1_1): query predictions, the
                 transformed semantics and the generated features of the last
                 training step
        :raises ValueError: on an unknown setting, Ablation, classifiermethod or cls
        '''
        # Fail before the expensive training instead of after it.
        method = self.args.classifiermethod
        if method not in ('gradient', 'metric', 'nonparam'):
            raise ValueError('Please set the correct method.')
        if method == 'nonparam' and self.args.cls not in ('lr', 'svm', 'knn'):
            raise ValueError('Please set the correct cls.')

        # Only the prototype variant that is actually used gets computed.
        if self.args.setting == 'tran':
            # transductive
            proto = self._transductive_proto(feat_ns, label_ns, feat_nq)
        elif self.args.setting == 'in':
            # inductive
            proto = F.normalize(compute_proto(feat_ns, label_ns, self.args.way), dim=1)
        else:
            raise ValueError('Please set the correct setting.')

        loss_fn = torch.nn.MSELoss(reduction='mean')
        optimizer = torch.optim.Adam([{'params': self.fc_en.parameters(), 'lr': self.args.lr},
                                      {'params': self.trans.parameters(), 'lr': self.args.lr},
                                      {'params': self.fc_de.parameters(), 'lr': self.args.lr}], lr=self.args.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

        for _ in range(50):
            # NOTE(review): PyTorch >= 1.1 expects optimizer.step() before
            # lr_scheduler.step(). The original order is kept so the learning
            # rate schedule (and hence the trained weights) stays unchanged.
            lr_scheduler.step()
            self.fc_en.train()
            self.trans.train()
            self.fc_de.train()
            optimizer.zero_grad()

            # ----- constraints on encoder -----
            # mapping constraint
            sem_b_1 = self.fc_en(feat_b)
            loss1 = loss_fn(sem_b, sem_b_1)
            # reconstruction constraint
            feat_b_1 = self._transposed_linear(sem_b_1, self.fc_en.fc1_w)
            loss2 = loss_fn(feat_b, feat_b_1)

            # ----- trans: mapping constraint -----
            sem_n_1 = self.trans(sem_b1)
            loss3 = loss_fn(sem_n1, sem_n_1)

            # ----- constraints on decoder -----
            # mapping constraint
            feat_ns_1 = self.fc_de(sem_ns)
            loss5 = loss_fn(feat_ns, feat_ns_1)
            # reconstruction constraint
            sem_ns_1 = self._transposed_linear(feat_ns_1, self.fc_de.fc1_w)
            loss6 = loss_fn(sem_ns, sem_ns_1)

            # ----- compactness constraint -----
            # transform the base data to novel data
            sem_n_1_1 = self.trans(sem_b_1)
            feat_n_1_1 = self.fc_de(sem_n_1_1)
            loss7 = compactness_loss(feat_n_1_1, label_b, proto, label_ns)

            loss = self._total_loss(loss1, loss2, loss3, loss5, loss6, loss7)
            loss.backward(retain_graph=True)
            optimizer.step()

        # ----- generate novel data -----
        # Uses the features produced in the last training step (computed
        # before that step's optimizer update). Classifier fitting used to run
        # in every iteration, but only the last result was ever returned and
        # fitting does not alter any stored parameter.
        feat = torch.cat((feat_n_1_1, feat_ns), dim=0)
        labels = torch.cat((label_b, label_ns), dim=0)
        feat = F.normalize(feat, dim=1)

        # ----- train classifier with the augmented data -----
        logits_q = self._classify(feat, labels, feat_nq)

        return logits_q, sem_n_1_1, feat_n_1_1