import copy
from torch.autograd import Variable
from utils import *
from clusterop import obtain_label, get_probs
import torch.nn as nn
import torch.optim as optim
from functions import MSE, SIMSE, HLoss, EMLoss, SparseParam, OverlapMask
from network import *
from torch.nn import functional as F
import torch
import math

# Select the compute backend once at import time. The tensor-type aliases are
# used with ``Tensor.type(...)`` to cast a batch and place it on that backend.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda else torch.device("cpu")
if cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor


def client_train(args, dataloader, modelB, modelS, modelD, domain, rounds):
    """Run one round of local training for a single client.

    The three networks are optimised jointly with SGD on the sum of a
    classification term, two reconstruction terms and a "difference" term
    (feature overlap plus mask sparsity).

    Args:
        args: Run configuration. Reads ``lr``, ``epochs``, ``rounds``,
            ``alpha`` and ``beta``.
        dataloader: Mapping from split name to loader. ``dataloader[domain]``
            yields ``(imgs, labels, index)`` batches; ``'target_test'`` is
            additionally read when ``domain == 'uclient'``.
        modelB: Called as ``modelB(imgs)``; returns
            ``(probs, Ifeats, Bfeats, maskI)``.
        modelS: Called as ``modelS(Bfeats)``; returns ``(Sfeats, maskS)``.
        modelD: Called as ``modelD(Sfeats, Ifeats)``; returns the
            reconstruction that is compared against the input images.
        domain: ``'lclient'`` trains the classifier on the batch labels;
            ``'uclient'`` trains it on pseudo-labels computed once, before
            the first epoch, from ``dataloader['target_test']``.
        rounds: Index of the current communication round; forwarded to
            ``lr_schedule`` together with ``args.rounds``.

    Returns:
        The tuple ``(modelB, modelS, modelD)``; the networks are updated
        in place.
    """
    learning_rate = args.lr
    # One parameter group per tensor, covering all three networks.
    param_group = [
        {'params': param, 'lr': learning_rate}
        for net in (modelB, modelS, modelD)
        for param in net.parameters()
    ]
    optimizer = op_copy(optim.SGD(param_group))

    batches = dataloader[domain]
    loader = iter(batches)
    max_iter = len(batches)

    loss_classification = nn.CrossEntropyLoss()
    loss_entropy = HLoss()
    loss_recon1 = MSE()
    loss_recon2 = SIMSE()
    loss_overlap = OverlapMask()
    loss_sparse = SparseParam()

    modelB.train()
    modelS.train()
    modelD.train()

    if domain == 'uclient':
        # Pseudo-labels and their scores are computed once per round with the
        # backbone in eval mode, then looked up by sample index below.
        # Assumes the index yielded by dataloader['uclient'] addresses the
        # same ordering as dataloader['target_test'] -- TODO confirm.
        modelB.eval()
        mem_label = obtain_label(dataloader['target_test'], modelB, args)
        # .to(device) rather than .cuda() so CPU-only hosts keep working.
        mem_label = torch.from_numpy(mem_label).to(device)
        mem_values = get_probs(dataloader['target_test'], modelB)
        mem_values = torch.from_numpy(mem_values).to(device)
        modelB.train()

    for _ in range(args.epochs):
        for _ in range(max_iter):
            try:
                imgs, labels, index = next(loader)
            except StopIteration:
                # The iterator outlives a single epoch; restart it on demand.
                loader = iter(batches)
                imgs, labels, index = next(loader)

            # Depends only on the round number, so every step of this call
            # requests the same schedule position. Kept per-step because
            # lr_schedule is defined elsewhere and may not be idempotent.
            lr_schedule(optimizer, iter_num=rounds, max_iter=args.rounds)

            imgs_var = imgs.type(FloatTensor)

            probs, Ifeats, Bfeats, maskI = modelB(imgs_var)
            Sfeats, maskS = modelS(Bfeats)
            hat_x = modelD(Sfeats, Ifeats)

            total_loss = 0.0
            # classification loss
            if domain == 'lclient':
                lbls_var = labels.type(LongTensor)
                total_loss += loss_classification(probs, lbls_var)
            else:
                pred = mem_label[index].unsqueeze(1)
                value = mem_values[index]
                # Keep only samples whose stored score exceeds 0.1.
                mask = torch.gt(value, 0.1)
                probs = probs[mask]
                pred = pred[mask].squeeze(1)
                if pred.shape[0] != 0:
                    total_loss += loss_classification(probs, pred)
                    # NOTE(review): only the entropy term is added here; no
                    # batch-level diversity term is part of the objective.
                    total_loss += loss_entropy(probs)

            # reconstruction loss
            total_loss += args.alpha * loss_recon1(hat_x, imgs_var)
            total_loss += args.alpha * loss_recon2(hat_x, imgs_var)

            # difference loss
            total_loss += args.beta * (loss_overlap(Ifeats, Sfeats) + loss_sparse(maskI, maskS))

            # optimizer
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

    return modelB, modelS, modelD


def server_train(args, dataloader, model, lmodelB, umodelB, rounds):
    """Aggregate both client backbones into the server model and evaluate it.

    Every entry of the server state dict is replaced by
    ``(labeled + lam * unlabeled) / (1 + lam)``. ``lam`` is
    ``2 / (1 + exp(-10 * rounds / args.rounds)) - 1``, so it is 0 at round 0
    and grows towards 1 as ``rounds`` approaches ``args.rounds``.

    Args:
        args: Run configuration. Reads ``rounds`` (total number of rounds,
            must be non-zero) and ``server_model`` (checkpoint path).
        dataloader: Mapping providing the ``'all_test'``, ``'source_test'``
            and ``'target_test'`` loaders.
        model: Server model; its weights are overwritten in place.
        lmodelB: Backbone trained by the labeled client.
        umodelB: Backbone trained by the unlabeled client.
        rounds: Index of the current communication round.

    Returns:
        ``model`` with the aggregated weights loaded, left in eval mode.
    """
    # state_dict() assembles a fresh mapping on every call, so snapshot each
    # one once instead of rebuilding it for every parameter name.
    model_dict = model.state_dict()
    labeled_state = lmodelB.state_dict()
    unlabeled_state = umodelB.state_dict()

    lam = 2 / (1 + math.exp(-1 * 10 * rounds / args.rounds)) - 1
    for name in list(model_dict):
        model_dict[name] = (labeled_state[name] + lam * unlabeled_state[name]) / (1.0 + lam)

    model.load_state_dict(model_dict)

    model.eval()
    # NOTE(review): all_acc is computed but never reported below -- either
    # add it to the log line or drop this evaluation pass.
    all_acc = evaluation(dataloader['all_test'], model)
    source_acc = evaluation(dataloader['source_test'], model)
    target_acc = evaluation(dataloader['target_test'], model)
    log_str = 'At Round {:.0f}, source acc = {:.4f}%, target acc = {:.4f}%'.format(rounds, source_acc, target_acc)
    print(log_str + '\n')
    torch.save(model.state_dict(), args.server_model)

    return model


def train(args, dataloader, model, lmodelB, lmodelS, lmodelD, umodelB, umodelS, umodelD):
    """Run the full federated training loop for ``args.rounds`` rounds.

    Each round copies the server weights into both client backbones, trains
    the labeled client (``'lclient'``) and then the unlabeled client
    (``'uclient'``), and finally lets the server aggregate the two backbones.

    Args:
        args: Run configuration; ``rounds`` sets the number of rounds and the
            object is forwarded unchanged to the client and server steps.
        dataloader: Mapping of split name to loader, forwarded unchanged.
        model: Server model.
        lmodelB, lmodelS, lmodelD: Networks of the labeled client.
        umodelB, umodelS, umodelD: Networks of the unlabeled client.

    Returns:
        The server model after the last round (the input ``model`` when
        ``args.rounds`` is 0).
    """
    for rounds in range(args.rounds):
        # One snapshot of the server weights per round serves both clients;
        # state_dict() would otherwise be rebuilt for every parameter name.
        global_state = model.state_dict()
        for backbone in (lmodelB, umodelB):
            client_state = backbone.state_dict()
            # Copy only the keys the client owns, so extra server-side
            # entries never reach load_state_dict.
            for name in list(client_state):
                client_state[name] = global_state[name]
            backbone.load_state_dict(client_state)

        lmodelB, lmodelS, lmodelD = client_train(args, dataloader, lmodelB, lmodelS, lmodelD, 'lclient', rounds)
        umodelB, umodelS, umodelD = client_train(args, dataloader, umodelB, umodelS, umodelD, 'uclient', rounds)

        model = server_train(args, dataloader, model, lmodelB, umodelB, rounds)

    return model
