import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np
import time
import copy

from    learner import Learner
from    copy import deepcopy
from functions import *

class Meta(nn.Module):
    """
    Meta learner for federated, prototype-based few-shot learning.

    An episode is split across simulated users: support and query data arrive
    as dicts keyed 0..num_user-1. In every communication round each user
    adapts the shared network on its own support set (a local prototype loss
    plus a loss against the previous round's global prototypes). The per-user
    weights and per-class prototypes are then averaged. Finally, the averaged
    network is scored on the pooled query set against the averaged prototypes.
    """

    def __init__(self, args, config):
        """
        Args:
            args: hyper-parameter namespace. Reads update_lr, meta_lr, n_way,
                n_spt, n_qry, update_step, update_step_test, device, imgc and
                imgsz here, and `round` in forward/finetunning.
            config: layer configuration forwarded to Learner.
        """
        super(Meta, self).__init__()

        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.n_spt = args.n_spt
        self.n_qry = args.n_qry
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.args = args
        self.device = args.device
        self.net = Learner(config, args.imgc, args.imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)

    def forward(self, step, x_spt, y_spt, x_qry, y_qry):
        """
        Run one meta-training step and return the query accuracy.

        Args:
            step: global training step; drives the step-wise LR decay.
            x_spt, y_spt: dicts keyed 0..num_user-1 with each user's support
                images / labels (presumably tensors).
            x_qry, y_qry: dicts with the same keys holding query images / labels.

        Returns:
            Query accuracy (Python float) of the averaged weights against the
            global prototypes. It is measured on the logits computed before
            the meta-optimizer step.
        """
        # Step-wise LR decay: x0.1 at every multiple of 5000 except 0 and 20000.
        if step != 0 and step != 20000 and step % 5000 == 0:
            for group in self.meta_optim.param_groups:
                group['lr'] = 0.1 * group['lr']

        num_user = len(x_spt.keys())
        # Materialise once: the same container is reused as `vars` for every
        # user in the first round and is zipped with the gradients below.
        params = list(self.net.parameters())
        avg_weight = params
        global_proto = None

        for _ in range(1, self.args.round + 1):
            weights = []
            prototypes = {cls: [] for cls in range(self.n_way)}
            for i in range(num_user):
                _x_spt = x_spt[i]
                _y_spt = y_spt[i]
                sup_feat = self.net(_x_spt, vars=avg_weight, bn_training=True).squeeze()
                # Pool once; the same pooled features feed prototypes and logits.
                pooled = F.avg_pool2d(sup_feat, 6, 1, 0).squeeze()
                prototype = self.NIID_make_prototype(pooled, _y_spt, self.device)
                prob = NIID_PN_pred(prototype, pooled, self.device)
                Lloss = self.NIID_IC_Loss(prob, _y_spt)
                Gloss = self.GlobalLoss(global_proto, sup_feat, _y_spt)
                loss = Lloss + 0.2 * Gloss  # 0.2: hard-coded global-loss weight
                # The graph is retained because the meta loss later backprops
                # through avg_weight and global_proto, which share this graph.
                # NOTE(review): the SGD step starts from `params`, not from
                # `avg_weight`, so later rounds do not accumulate updates the
                # way `finetunning` does -- confirm this is intended.
                grad = torch.autograd.grad(loss, params, retain_graph=True)
                fast_weights = [p - self.update_lr * g for g, p in zip(grad, params)]
                weights.append(fast_weights)
                self.NIID_upload_proto(prototypes, prototype)
            avg_weight = average_weights(weights)
            global_proto = self.NIID_average_prototypes(prototypes, self.device)

        qry_input, qry_label = self.make_qry_input(x_qry, y_qry, self.args.device)
        qry_feat = self.net(qry_input, vars=avg_weight, bn_training=True).squeeze()
        prob = PN_pred(global_proto, F.avg_pool2d(qry_feat, 6, 1, 0).squeeze())
        meta_loss = F.cross_entropy(prob, qry_label)

        self.meta_optim.zero_grad()
        meta_loss.backward()
        self.meta_optim.step()

        with torch.no_grad():
            acc = (prob.argmax(dim=1) == qry_label).float().mean()
        return acc.item()

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """
        Adapt a copy of self.net on one episode and return its query accuracy.

        self.net is never modified: all updates happen on a deep copy, which
        also preserves self.net's batch-norm running statistics.

        Each round, every user runs one SGD step on a copy of the current
        network. The resulting state dicts are averaged with FedAvg, and the
        averaged per-class prototypes become the (detached) global prototypes
        used by the next round's GlobalLoss and by the final query scoring.

        Returns:
            Query accuracy as a Python float.
        """
        # x_spt : [10,3,75,3,84,84], x_qry : [10,3,25,3,84,84] (original shape notes)
        net = deepcopy(self.net)
        num_user = len(x_spt.keys())
        global_proto = None

        for _ in range(1, self.args.round + 1):
            weights = []
            prototypes = {cls: [] for cls in range(self.n_way)}
            for i in range(num_user):
                _x_spt = x_spt[i]
                _y_spt = y_spt[i]
                local_net = deepcopy(net)
                optimizer = torch.optim.SGD(local_net.parameters(), lr=self.update_lr)
                # `net` supplies the forward function, `local_net` the weights being trained.
                sup_feat = net(_x_spt, vars=local_net.parameters(), bn_training=True).squeeze()
                pooled = F.avg_pool2d(sup_feat, 6, 1, 0).squeeze()
                prototype = self.NIID_make_prototype(pooled, _y_spt, self.device)
                prob = NIID_PN_pred(prototype, pooled, self.device)
                Lloss = self.NIID_IC_Loss(prob, _y_spt)
                Gloss = self.GlobalLoss(global_proto, sup_feat, _y_spt)
                loss = Lloss + 0.2 * Gloss  # same weighting as forward()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                weights.append(local_net.state_dict())
                self.NIID_upload_proto(prototypes, prototype)
            net.load_state_dict(FedAvg(weights))
            # Previously global_proto stayed None here, so Gloss was always 0.
            # It is detached: the next round uses it as a fixed target and must
            # not backprop into this round's discarded local networks.
            global_proto = self.NIID_average_prototypes(prototypes, self.device).detach()

        with torch.no_grad():
            qry_input, qry_label = self.make_qry_input(x_qry, y_qry, self.args.device)
            qry_feat = net(qry_input, vars=net.parameters(), bn_training=True).squeeze()
            prob = PN_pred(global_proto, F.avg_pool2d(qry_feat, 6, 1, 0).squeeze())
            acc = (prob.argmax(dim=1) == qry_label).float().mean()
        return acc.item()

    def make_qry_input(self, x_qry, y_qry, device):
        """
        Concatenate every user's query set into a single batch.

        Args:
            x_qry: dict keyed 0..num_user-1 of query image tensors (batch dim first).
            y_qry: dict with the same keys of query label tensors.
            device: target device.

        Returns:
            (qry_input, qry_label). The inputs are cast to float32 and the
            labels to int64. Both are concatenated in user order and placed
            on `device`. The per-sample input shape is taken from the data
            rather than being fixed to 3x84x84.
        """
        with torch.no_grad():
            users = range(len(x_qry.keys()))
            qry_input = torch.cat([x_qry[u] for u in users], dim=0)
            qry_label = torch.cat([y_qry[u] for u in users], dim=0)
            return (qry_input.to(device=device, dtype=torch.float32),
                    qry_label.to(device=device, dtype=torch.int64))

    def make_prototype(self, feats, y_spt, device):
        """
        Build class-mean prototypes for labels 0..n_way-1.

        Args:
            feats: (N, C) feature tensor.
            y_spt: (N,) labels.
            device: device of the returned tensor.

        Returns:
            (n_way, C) float64 tensor. Row `label` is the mean feature of that
            label. Labels absent from y_spt keep a row of ones.
        """
        C = feats.size(1)
        prototypes = torch.ones((self.n_way, C), dtype=torch.float64, device=device)
        for label in range(self.n_way):
            mask = (y_spt == label)
            if not mask.any():  # label absent from this support set
                continue
            prototypes[label] = feats[mask].mean(dim=0)
        return prototypes

    def GlobalLoss(self, global_proto, sup_feat, _y_spt):
        """
        Pixel-wise prototype loss against the global prototypes.

        Every spatial position of every support feature map is classified by
        negative squared distance to each global prototype. The result is
        scored with cross-entropy against that sample's label.

        Args:
            global_proto: (n_cls, C) tensor, or None before the first round.
            sup_feat: (N, C, H, W) feature maps (e.g. [25, 32, 6, 6] per the
                original shape notes).
            _y_spt: (N,) labels indexing rows of global_proto.

        Returns:
            Scalar loss tensor, or the int 0 when global_proto is None.
        """
        if global_proto is None:
            return 0
        # (N, 1, C, H, W) - (1, n_cls, C, 1, 1) -> summed over C -> (N, n_cls, H, W)
        protos = global_proto[None, :, :, None, None]
        dist = -(sup_feat.unsqueeze(dim=1) - protos).pow(2).sum(dim=2)
        n_cls, height, width = dist.size(1), dist.size(2), dist.size(3)
        # Every pixel inherits its sample's label; row order matches `pred` below.
        label = _y_spt.to(device=dist.device, dtype=torch.long)
        label = label.view(-1, 1, 1).expand(-1, height, width).reshape(-1)
        pred = dist.permute(0, 2, 3, 1).reshape(-1, n_cls)  # (N*H*W, n_cls)
        return F.cross_entropy(pred, label)

    def NIID_make_prototype(self, feats, y_spt, device):
        """
        Build class-mean prototypes for the labels present in y_spt.

        Args:
            feats: (N, C) feature tensor.
            y_spt: (N,) labels.
            device: unused; kept for interface compatibility.

        Returns:
            dict {label (int): (C,) mean feature}. Keys follow torch.unique
            (ascending) order.
        """
        return {
            label.item(): feats[y_spt == label].mean(dim=0)
            for label in torch.unique(y_spt)
        }

    def NIID_IC_Loss(self, prob, _y_qry):
        """
        Cross-entropy over the subset of classes present in _y_qry.

        Each label is remapped to its index among the sorted unique labels,
        matching the column order NIID_PN_pred produces for the dict returned
        by NIID_make_prototype.

        Args:
            prob: (N, k) logits, where k is the number of distinct labels.
            _y_qry: (N,) raw labels.
        """
        _, truth = torch.unique(_y_qry, sorted=True, return_inverse=True)
        return F.cross_entropy(prob, truth)

    def NIID_upload_proto(self, local_protos, local_proto):
        """
        Append one user's prototypes to the per-class stacks, in place.

        Args:
            local_protos: dict class -> [] (nothing uploaded yet) or a (k, C)
                tensor of prototypes uploaded so far.
            local_proto: dict class -> (C,) prototype from one user.
        """
        for cls, proto in local_proto.items():
            row = proto.unsqueeze(dim=0)
            stacked = local_protos[cls]
            if isinstance(stacked, torch.Tensor):
                local_protos[cls] = torch.cat((stacked, row), dim=0)
            else:
                local_protos[cls] = row

    def NIID_average_prototypes(self, prototypes, device):
        """
        Average the prototypes uploaded for each class.

        Args:
            prototypes: dict keyed 0..n-1; each value is a (k, C) tensor of
                uploaded prototypes.
            device: device of the returned tensor.

        Returns:
            (n, C) float32 tensor. Row i is the mean of class i's uploads.

        Raises:
            ValueError: if some class received no prototype.
        """
        nb_cls = len(prototypes)
        missing = [cls for cls in range(nb_cls)
                   if not isinstance(prototypes[cls], torch.Tensor)]
        if missing:
            raise ValueError(f"no prototypes were uploaded for classes {missing}")
        C = prototypes[0].size(1)
        out = torch.ones((nb_cls, C), device=device)
        for cls in range(nb_cls):
            out[cls] = prototypes[cls].mean(dim=0)
        return out

def FedAvg(w):
    """
    Average a list of state dicts key by key, with equal client weights.

    Args:
        w: non-empty sequence of dict-like state dicts sharing the same keys.

    Returns:
        A deep copy of w[0] (so its mapping type is kept) whose values are
        replaced by the element-wise mean over all clients.
    """
    n_clients = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged.keys():
        total = averaged[key]
        for client_state in w[1:]:
            total += client_state[key]  # in place on the deep-copied tensor
        averaged[key] = torch.div(total, n_clients)
    return averaged

def average_weights(weights):
    """
    Average several parameter lists position by position.

    Args:
        weights: non-empty sequence of equally long lists of tensors (one list
            per user, e.g. fast weights).

    Returns:
        A new list whose i-th entry is the mean of the users' i-th tensors.
        The inputs are not modified; the previous version overwrote the
        entries of weights[0]. The arithmetic stays out-of-place, so autograd
        graphs through the inputs are preserved.

    Raises:
        ValueError: if `weights` is empty.
    """
    if len(weights) == 0:
        raise ValueError("average_weights() requires at least one weight list")
    n = len(weights)
    averaged = []
    for per_param in zip(*weights):
        total = per_param[0]
        for other in per_param[1:]:
            total = total + other
        averaged.append(total / n)
    return averaged

def make_prototype(feats, y_spt, device):
    """
    Build class-mean prototypes for the labels present in y_spt.

    Args:
        feats: (N, C) feature tensor.
        y_spt: (N,) labels.
        device: device of the returned tensor.

    Returns:
        (k, C) float64 tensor, where k is the number of distinct labels. Row i
        holds the mean feature of the i-th smallest label present
        (torch.unique order). For labels 0..k-1 this is the same as indexing
        by label. Non-contiguous labels (e.g. {0, 2}) previously raised
        IndexError.
    """
    labels = torch.unique(y_spt)
    prototypes = torch.ones((len(labels), feats.size(1)), dtype=torch.float64, device=device)
    for row, label in enumerate(labels):
        prototypes[row] = feats[y_spt == label].mean(dim=0)
    return prototypes

def PN_pred(prototype, qry_feat):
    """
    Prototypical-network logits: negative squared Euclidean distances.

    Args:
        prototype: (K, C) class prototypes (e.g. [5, 32]).
        qry_feat: (N, C) query features (e.g. [25, 32]).

    Returns:
        (N, K) tensor whose [n, k] entry is -||qry_feat[n] - prototype[k]||^2.
    """
    diff = qry_feat[:, None] - prototype  # (N, K, C)
    sq_dist = (diff ** 2).sum(dim=2)
    return -sq_dist

def NIID_PN_pred(prototype, qry_feat, device):
    """
    Prototypical-network logits for a dict of per-class prototypes.

    Args:
        prototype: dict class -> (C,) prototype. Columns of the result follow
            the dict's iteration order.
        qry_feat: (N, C) query features.
        device: device on which the stacked prototype matrix is allocated.

    Returns:
        (N, K) tensor of negative squared Euclidean distances, where K is the
        number of dict entries (1, 2 or 5 in the original usage notes).
    """
    classes = list(prototype.keys())
    proto_mat = torch.ones((len(classes), qry_feat.size(1))).to(device)
    for row, cls in enumerate(classes):
        proto_mat[row] = prototype[cls]

    sq_dist = (qry_feat.unsqueeze(dim=1) - proto_mat).pow(2).sum(dim=2)  # (N, K)
    return -sq_dist

def average_prototypes(prototypes):
    """
    Element-wise mean of an indexable collection of prototype tensors.

    Args:
        prototypes: sequence (or tensor) indexed 0..len-1. The original note
            gives the layout as (num_user, n_class, C), e.g. items of [5, 32].

    Returns:
        Sum of all items divided by their count.
    """
    count = len(prototypes)
    total = sum((prototypes[i] for i in range(1, count)), prototypes[0])
    return total / count

#
# def Local_qry(x_qry, y_qry, idx, n_user):
#     n_classes = y_qry.unique()
#     lenAll = len(x_qry)
#     _x_qry = torch.ones_like(x_qry)[:int(lenAll/n_user)]
#     _y_qry = torch.ones_like(y_qry)[:int(lenAll/n_user)]
#     for clsidx, cls in enumerate(n_classes):
#         loc = [i[0].item() for i in (y_qry==cls).nonzero()]
#         n = int(len(loc) / n_user)
#         user_loc = loc[idx*n:(idx+1)*n]
#         _x_qry[clsidx*n:(clsidx+1)*n] = x_qry[user_loc]
#         _y_qry[clsidx*n:(clsidx+1)*n] = y_qry[user_loc]
#
#     return _x_qry, _y_qry


def main():
    """Script entry point; currently a no-op placeholder."""
    return None


if __name__ == '__main__':
    main()
