import torch
from adaptation import LMMDLoss
from utils.clip_util import AverageMeter
import utils.clip_util as clu
import torch.nn as nn
import clip
from utils.loss_function import CrossEntropyLabelSmooth
from utils.clip_util import convert_models_to_fp32
from utils.clip_util import FocalLossWithSmoothing
from tqdm import tqdm
import random
import time

def totrain(model):
    """Put the wrapped backbone and the feature-attention head into training mode.

    Args:
        model: Wrapper object exposing ``model`` and ``fea_attn`` attributes,
            each of which provides a ``train()`` method.
    """
    # Attributes are resolved one at a time so ``model.model.train()`` still
    # runs before ``model.fea_attn`` is looked up.
    for attr_name in ("model", "fea_attn"):
        getattr(model, attr_name).train()


def clip_lossfn(_f, factor_f, device):
    """Symmetric contrastive loss between two row-aligned feature batches.

    Both inputs are L2-normalised along ``dim=1``. Row ``i`` of ``factor_f`` is
    treated as the match of row ``i`` of ``_f``, and similarities are scaled by
    a fixed factor of 100 before the cross-entropy.

    Args:
        _f: 2-D tensor of reference features, one sample per row.
        factor_f: 2-D tensor of features to align with ``_f``; same row count.
        device: Device on which the target index tensor is created.

    Returns:
        Scalar tensor: the mean of the ``factor_f -> _f`` and ``_f -> factor_f``
        cross-entropy losses.
    """
    unit_f = _f / _f.norm(dim=1, keepdim=True)
    unit_factor = factor_f / factor_f.norm(dim=1, keepdim=True)

    # Scale before the matmul, exactly as ``100 * a @ b`` evaluates.
    similarity = (100 * unit_factor) @ unit_f.t()
    # The i-th row of one batch is the positive for the i-th row of the other.
    targets = torch.arange(unit_f.shape[0], dtype=torch.long, device=device)

    criterion = torch.nn.CrossEntropyLoss()
    factor_to_f = criterion(similarity, targets)
    f_to_factor = criterion(similarity.t(), targets)
    return (factor_to_f + f_to_factor) / 2


def clip_lossfn_1(_f, factor_f, device):
    """Symmetric contrastive loss on raw (un-normalised, un-scaled) features.

    The raw dot products ``factor_f @ _f.T`` are used directly as logits. Row
    ``i`` of ``factor_f`` is treated as the match of row ``i`` of ``_f``.

    Args:
        _f: 2-D tensor of reference features, one sample per row.
        factor_f: 2-D tensor of features to align with ``_f``; same row count.
        device: Device on which the target index tensor is created.

    Returns:
        Scalar tensor: the mean of the cross-entropy over the logits and over
        their transpose.
    """
    similarity = torch.matmul(factor_f, _f.t())
    targets = torch.arange(_f.shape[0], dtype=torch.long, device=device)

    criterion = torch.nn.CrossEntropyLoss()
    row_term = criterion(similarity, targets)
    col_term = criterion(similarity.t(), targets)
    return (row_term + col_term) / 2

def _train_ours_batches(args, model, source_data, n_batches, optimizer,
                        optimizer_causal, device, lossfn, discr,
                        use_rephrased_text):
    """Run ``n_batches`` steps of the 'ours' objective and print averaged losses.

    Each step encodes the batch with ``model.model``, evaluates ``lossfn`` and
    adds the optional factor-gated clip losses selected by ``args.use_clips``,
    ``args.use_clipc`` and ``args.use_clipz``. It then steps both optimizers.

    Args:
        args: Namespace providing ``exp_weight``, ``use_clips``, ``use_clipc``
            and ``use_clipz``.
        model: Wrapper exposing ``model`` (the encoder) and ``labels``.
        source_data: Iterator yielding ``(image, text, label, r_text_enc)``.
        n_batches: Number of batches to draw from ``source_data``.
        optimizer: Optimizer stepped on every processed batch.
        optimizer_causal: Second optimizer stepped on every processed batch.
        device: Device that batches are moved to.
        lossfn: Callable invoked as
            ``lossfn(image_features, label, labels_features, args.exp_weight)``
            and returning ``(loss, (p_loss, ce_loss))``.
        discr: Object exposing ``s1x``, ``c1x`` and ``z1x`` callables. Their
            outputs gate the image features element-wise.
        use_rephrased_text: When true, ``r_text_enc`` is encoded and its
            embedding is the text side of the ``z`` clip loss. Otherwise the
            embedding of ``text`` is used.
    """
    encoder = model.model
    ce_meter = AverageMeter()
    p_meter = AverageMeter()
    clip_s_meter = AverageMeter()
    clip_c_meter = AverageMeter()
    clip_z_meter = AverageMeter()
    # Never updated; kept so the printed summary keeps its 'do loss' field.
    do_meter = AverageMeter()

    labels_features = clu.get_text_features_list(model.labels, encoder, device=device).float()
    for _ in range(0, n_batches):
        image, text, label, r_text_enc = next(source_data)
        # Batches with fewer than two samples are skipped -- presumably because
        # the contrastive terms need at least one negative pair (TODO confirm).
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        if use_rephrased_text:
            r_text_enc = r_text_enc.to(device)
        label = label.to(device)

        image_features = encoder.encode_image(image).float()
        text_features = encoder.encode_text(text).float()
        if use_rephrased_text:
            z_text_features = encoder.encode_text(r_text_enc).float()
        else:
            z_text_features = text_features

        loss, (p_loss, ce_loss) = lossfn(image_features, label, labels_features, args.exp_weight)
        ce_meter.update(ce_loss)
        p_meter.update(p_loss)

        if args.use_clips:
            clip_loss_s = clip_lossfn(text_features, torch.mul(discr.s1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_s
            clip_s_meter.update(clip_loss_s.item())

        if args.use_clipc:
            clip_loss_c = clip_lossfn(text_features, torch.mul(discr.c1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_c
            clip_c_meter.update(clip_loss_c.item())

        if args.use_clipz:
            clip_loss_z = clip_lossfn(z_text_features, torch.mul(discr.z1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_z
            clip_z_meter.update(clip_loss_z.item())

        optimizer.zero_grad()
        optimizer_causal.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_causal.step()

    print("ce loss:", ce_meter.avg, 'exp loss:', p_meter.avg, 'z clip loss:', clip_z_meter.avg,
          's clip loss:', clip_s_meter.avg, 'c clip loss:', clip_c_meter.avg, 'do loss:', do_meter.avg)


def train(args, model, data_loader, optimizer, device, testloader, mmd_loss, server_model, previous_nets, lossfn=None, discr=None, gen=None, frame=None, optimizer_causal=None):
    """Run one local training pass of the 'ours' method on ``model``.

    The model is switched to training mode and an iterator over ``data_loader``
    is created unconditionally. Training only happens when ``args.method`` is
    ``'ours'`` and ``args.dataset`` is one of the supported names; otherwise
    the function returns without doing anything else.

    For ``'BrainTumor'`` exactly ``args.n_iter`` batches are drawn. For the
    other supported datasets one pass of ``len(data_loader)`` batches is made,
    and the ``z`` clip loss uses the embedding of ``r_text_enc``.

    Args:
        args: Namespace with ``method``, ``dataset``, ``n_iter``,
            ``exp_weight``, ``use_clips``, ``use_clipc`` and ``use_clipz``.
        model: Wrapper exposing ``model``, ``fea_attn`` and ``labels``.
        data_loader: Iterable of ``(image, text, label, r_text_enc)`` batches.
            NOTE(review): for 'BrainTumor' it must be able to yield
            ``args.n_iter`` batches from a single iterator -- verify.
        optimizer: Optimizer stepped after every processed batch.
        device: Device that batches are moved to.
        testloader: Not used by this function.
        mmd_loss: Not used by this function.
        server_model: Not used by this function.
        previous_nets: Not used by this function.
        lossfn: Loss callable; required when training runs.
        discr: Object exposing ``s1x``/``c1x``/``z1x``; required when any of
            the ``use_clip*`` flags is set.
        gen: Not used by this function.
        frame: Not used by this function.
        optimizer_causal: Second optimizer; required when training runs.
    """
    totrain(model)
    source_data = iter(data_loader)
    if args.method != 'ours':
        return

    if args.dataset == 'BrainTumor':
        n_batches = args.n_iter
        use_rephrased_text = False
    elif args.dataset in ('RealSkin', 'Dermnet', 'havior', 'OfficeHome', 'ModernOffice31', 'PACS'):
        n_batches = len(data_loader)
        use_rephrased_text = True
    else:
        return

    _train_ours_batches(args, model, source_data, n_batches, optimizer,
                        optimizer_causal, device, lossfn, discr,
                        use_rephrased_text)

_TRAIN_K_METHODS = ('ours', 'faaclip', 'fedprox', 'moon', 'fedclip', 'fedavg')
_TRAIN_K_EPOCH_DATASETS = ('RealSkin', 'Dermnet', 'havior', 'OfficeHome', 'ModernOffice31', 'PACS')


def _unit_rows(features):
    """Return ``features`` with every row divided by its L2 norm."""
    return features / features.norm(dim=1, keepdim=True)


def _clip_contrastive_loss(image_features, text_features, logit_scale, device):
    """Mean of the image->text and text->image cross-entropies.

    Row ``i`` of ``image_features`` is the positive for row ``i`` of
    ``text_features``. Callers normalise the features beforehand.
    """
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    ground_truth = torch.arange(len(image_features), dtype=torch.long, device=device)
    criterion = nn.CrossEntropyLoss()
    return (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2


def _optimizer_step_fp32(model, optimizer):
    """Run ``optimizer.step()`` between ``convert_models_to_fp32`` and ``clip.model.convert_weights``."""
    convert_models_to_fp32(model)
    optimizer.step()
    clip.model.convert_weights(model)


def _train_k_ours(args, model, encoder, source_data, n_batches, optimizer,
                  optimizer_causal, device, lossfn, discr, use_rephrased_text):
    """Run ``n_batches`` steps of the 'ours' objective and print averaged losses.

    ``encoder`` produces image and text embeddings. The label-prompt embeddings
    always come from ``model.model``. ``use_rephrased_text`` selects whether
    the ``z`` clip loss is computed against the embedding of ``r_text_enc`` or
    of ``text``.
    """
    ce_meter = AverageMeter()
    p_meter = AverageMeter()
    clip_s_meter = AverageMeter()
    clip_c_meter = AverageMeter()
    clip_z_meter = AverageMeter()
    # Never updated; kept so the printed summary keeps its 'do loss' field.
    do_meter = AverageMeter()

    labels_features = clu.get_text_features_list(model.labels, model.model, device=device).float()
    for _ in range(0, n_batches):
        image, text, label, r_text_enc = next(source_data)
        # Batches with fewer than two samples are skipped -- presumably because
        # the contrastive terms need at least one negative pair (TODO confirm).
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        if use_rephrased_text:
            r_text_enc = r_text_enc.to(device)
        label = label.to(device)

        image_features = encoder.encode_image(image).float()
        text_features = encoder.encode_text(text).float()
        if use_rephrased_text:
            z_text_features = encoder.encode_text(r_text_enc).float()
        else:
            z_text_features = text_features

        loss, (p_loss, ce_loss) = lossfn(image_features, label, labels_features, args.exp_weight)
        ce_meter.update(ce_loss)
        p_meter.update(p_loss)

        if args.use_clips:
            clip_loss_s = clip_lossfn(text_features, torch.mul(discr.s1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_s
            clip_s_meter.update(clip_loss_s.item())

        if args.use_clipc:
            clip_loss_c = clip_lossfn(text_features, torch.mul(discr.c1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_c
            clip_c_meter.update(clip_loss_c.item())

        if args.use_clipz:
            clip_loss_z = clip_lossfn(z_text_features, torch.mul(discr.z1x(image_features), image_features),
                                      device=device)
            loss = loss + clip_loss_z
            clip_z_meter.update(clip_loss_z.item())

        optimizer.zero_grad()
        optimizer_causal.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_causal.step()

    print("ce loss:", ce_meter.avg, 'exp loss:', p_meter.avg, 'z clip loss:', clip_z_meter.avg,
          's clip loss:', clip_s_meter.avg, 'c clip loss:', clip_c_meter.avg, 'do loss:', do_meter.avg)


def _train_k_faaclip(model, server_model, source_data, target_data, n_batches,
                     optimizer, device, mmd_loss):
    """Train ``model.fea_attn`` on ``server_model`` features with a contrastive + 0.5 * ``mmd_loss`` objective.

    One source batch and one target batch are drawn per step. The target
    features are gated by ``model.fea_attn`` under ``torch.no_grad()``.
    """
    clf_meter = AverageMeter()
    transfer_meter = AverageMeter()
    encoder = server_model.model
    for _ in range(0, n_batches):
        image, text, label = next(source_data)
        image_t, _, _ = next(target_data)
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        image_t = image_t.to(device)

        image_features = encoder.encode_image(image).float()
        text_features = encoder.encode_text(text).float()
        image_features = torch.mul(model.fea_attn(image_features), image_features)
        test_features = encoder.encode_image(image_t).float()
        with torch.no_grad():
            test_features = torch.mul(model.fea_attn(test_features), test_features)

        image_features = _unit_rows(image_features)
        text_features = _unit_rows(text_features)
        test_features = _unit_rows(test_features)

        loss_m = mmd_loss(image_features, test_features)
        cla_loss = _clip_contrastive_loss(image_features, text_features,
                                          encoder.logit_scale.exp(), device)
        loss = cla_loss + 0.5 * loss_m

        clf_meter.update(cla_loss.item())
        transfer_meter.update(0.5 * loss_m.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("cla loss: ", clf_meter.avg, 'trans loss:', transfer_meter.avg)


def _train_k_fedprox(args, model, server_model, source_data, n_batches,
                     optimizer, device, is_brain_tumor):
    """Contrastive training of ``model.model`` plus a weight-distance penalty when ``args.step > 0``.

    The penalty is ``1e-2 / 2`` times the square root of the summed squared
    parameter differences between the server and the local model.
    """
    clf_meter = AverageMeter()
    transfer_meter = AverageMeter()
    encoder = model.model
    for _ in tqdm(range(0, n_batches)):
        image, text, label = next(source_data)
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        image_features = _unit_rows(encoder.encode_image(image).float())
        text_features = _unit_rows(encoder.encode_text(text).float())
        loss = _clip_contrastive_loss(image_features, text_features,
                                      encoder.logit_scale.exp(), device)
        clf_meter.update(loss.item())

        if args.step > 0:
            # NOTE(review): 'BrainTumor' pairs server_model.model.parameters()
            # with model.parameters(), while the other datasets pair
            # server_model.parameters() with model.parameters(). Both pairings
            # are preserved as-is -- confirm which one is intended.
            if is_brain_tumor:
                server_params = server_model.model.parameters()
            else:
                server_params = server_model.parameters()
            w_diff = torch.tensor(1e-10, device=device)
            for w, w_t in zip(server_params, model.parameters()):
                w_diff += torch.pow(torch.norm(w - w_t), 2).float()  # model difference
            w_diff = torch.sqrt(w_diff)
            prox_penalty = 1e-2 / 2. * w_diff
            transfer_meter.update(prox_penalty.item())
            loss = loss + prox_penalty

        optimizer.zero_grad()
        loss.backward()
        _optimizer_step_fp32(model, optimizer)
    print("cla loss: ", clf_meter.avg, 'w_diff loss: ', transfer_meter.avg)


def _train_k_moon(args, model, server_model, previous_nets, source_data,
                  n_batches, optimizer, device):
    """Contrastive training of ``model.model`` plus a model-contrastive term when ``args.step > 0``.

    The extra term is a cross-entropy over two cosine similarities divided by
    ``args.temp``: local-vs-server image features (class 0, the target) and
    local-vs-``previous_nets`` image features (class 1).
    """
    clf_meter = AverageMeter()
    transfer_meter = AverageMeter()
    cos = torch.nn.CosineSimilarity(dim=-1)
    criterion = nn.CrossEntropyLoss()
    mu = 1
    for _ in tqdm(range(0, n_batches)):
        image, text, label = next(source_data)
        optimizer.zero_grad()
        image = image.to(device)
        text = text.to(device)

        image_features = model.model.encode_image(image).float()
        image_features_glo = server_model.model.encode_image(image).float()
        text_features = model.model.encode_text(text).float()
        image_features = _unit_rows(image_features)
        text_features = _unit_rows(text_features)
        image_features_glo = _unit_rows(image_features_glo)

        loss = _clip_contrastive_loss(image_features, text_features,
                                      model.model.logit_scale.exp(), device)
        clf_meter.update(loss.item())

        # Similarity to the server model's features is the positive logit.
        posi = cos(image_features, image_features_glo)
        logits = posi.reshape(-1, 1)
        if args.step > 0:
            image_features_pre = _unit_rows(previous_nets.model.encode_image(image).float())
            nega = cos(image_features, image_features_pre)
            logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
            logits /= args.temp
            # Created on ``device`` rather than via ``.cuda()`` so the labels
            # always live next to ``logits``, on CPU or any GPU index.
            labels = torch.zeros(image.size(0), dtype=torch.long, device=device)
            # Evaluated once; the meter gets a plain float so it does not keep
            # the autograd graph of every step alive.
            moon_loss = mu * criterion(logits, labels)
            loss = loss + moon_loss
            transfer_meter.update(moon_loss.item())

        loss.backward()
        _optimizer_step_fp32(model, optimizer)
    print("cla loss: ", clf_meter.avg, 'MOON loss: ', transfer_meter.avg)


def _train_k_fedclip(model, server_model, source_data, n_batches, optimizer,
                     device, is_brain_tumor):
    """Contrastive training with image features gated by ``model.fea_attn``.

    NOTE(review): 'BrainTumor' encodes with ``model.model`` and shows a tqdm
    bar; the other datasets encode with ``server_model.model`` and show no bar.
    Both behaviours are preserved as-is -- confirm the asymmetry is intended.
    """
    clf_meter = AverageMeter()
    encoder = model.model if is_brain_tumor else server_model.model
    steps = range(0, n_batches)
    if is_brain_tumor:
        steps = tqdm(steps)
    for _ in steps:
        image, text, label = next(source_data)
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        image_features = encoder.encode_image(image).float()
        text_features = encoder.encode_text(text).float()
        image_features = torch.mul(model.fea_attn(image_features), image_features)
        image_features = _unit_rows(image_features)
        text_features = _unit_rows(text_features)

        cla_loss = _clip_contrastive_loss(image_features, text_features,
                                          encoder.logit_scale.exp(), device)
        clf_meter.update(cla_loss.item())
        optimizer.zero_grad()
        cla_loss.backward()
        optimizer.step()
    print("cla loss: ", clf_meter.avg)


def _train_k_fedavg(model, source_data, n_batches, optimizer, device):
    """Plain contrastive training of ``model.model`` on image/text pairs."""
    clf_meter = AverageMeter()
    encoder = model.model
    for _ in tqdm(range(0, n_batches)):
        image, text, label = next(source_data)
        if len(text) <= 1:
            continue

        image = image.to(device)
        text = text.to(device)
        image_features = _unit_rows(encoder.encode_image(image).float())
        text_features = _unit_rows(encoder.encode_text(text).float())

        cla_loss = _clip_contrastive_loss(image_features, text_features,
                                          encoder.logit_scale.exp(), device)
        clf_meter.update(cla_loss.item())
        optimizer.zero_grad()
        cla_loss.backward()
        _optimizer_step_fp32(model, optimizer)
    print("cla loss: ", clf_meter.avg)


def train_K(args, model, data_loader, optimizer, device, testloader, mmd_loss, server_model, previous_nets, lossfn=None, discr=None, gen=None, frame=None, optimizer_causal=None):
    """Run one local training pass for the method named by ``args.method``.

    Supported methods are ``'ours'``, ``'faaclip'``, ``'fedprox'``, ``'moon'``,
    ``'fedclip'`` and ``'fedavg'``. For the ``'BrainTumor'`` dataset exactly
    ``args.n_iter`` batches are drawn; for the other supported datasets one
    pass of ``len(data_loader)`` batches is made. An unsupported method or
    dataset performs only the setup (train mode, label embeddings, loader
    iterators) and returns.

    Args:
        args: Namespace with ``method``, ``dataset`` and ``n_iter``, plus the
            method-specific fields ``exp_weight``, ``use_clips``, ``use_clipc``
            and ``use_clipz`` ('ours'), ``step`` ('fedprox', 'moon') and
            ``temp`` ('moon').
        model: Local wrapper exposing ``model``, ``fea_attn`` and ``labels``.
        data_loader: Source batches. ``(image, text, label, r_text_enc)`` for
            'ours', ``(image, text, label)`` for every other method.
        optimizer: Optimizer stepped after every processed batch.
        device: Device that batches and target tensors are placed on.
        testloader: Target-domain batches; only consumed by 'faaclip', but an
            iterator over it is created for every method.
        mmd_loss: Callable ``mmd_loss(source_features, target_features)``;
            used by 'faaclip'.
        server_model: Global model wrapper; used by 'ours' (non-BrainTumor),
            'faaclip', 'fedprox', 'moon' and 'fedclip' (non-BrainTumor).
        previous_nets: Wrapper of the previous local model; used by 'moon'
            when ``args.step > 0``.
        lossfn: Loss callable for 'ours', invoked as
            ``lossfn(image_features, label, labels_features, args.exp_weight)``.
        discr: Object exposing ``s1x``/``c1x``/``z1x`` for 'ours'.
        gen: Not used by this function.
        frame: Not used by this function.
        optimizer_causal: Second optimizer stepped by 'ours'.
    """
    totrain(model)
    texts = model.labels
    # NOTE(review): the result is never read. The call is kept because
    # clu.get_text_features_list is defined elsewhere and may have side
    # effects -- TODO confirm and drop it.
    t_features = clu.get_text_features_list(texts, model.model, device=device).float()  # update each round based on the model
    source_data = iter(data_loader)
    target_data = iter(testloader)

    if args.method not in _TRAIN_K_METHODS:
        return
    if args.dataset == 'BrainTumor':
        is_brain_tumor = True
        n_batches = args.n_iter
    elif args.dataset in _TRAIN_K_EPOCH_DATASETS:
        is_brain_tumor = False
        n_batches = len(data_loader)
    else:
        return

    if args.method == 'ours':
        # NOTE(review): 'BrainTumor' encodes with the local model, the other
        # datasets with server_model. Preserved as-is -- confirm it is intended.
        encoder = model.model if is_brain_tumor else server_model.model
        _train_k_ours(args, model, encoder, source_data, n_batches, optimizer,
                      optimizer_causal, device, lossfn, discr,
                      use_rephrased_text=not is_brain_tumor)
    elif args.method == 'faaclip':
        _train_k_faaclip(model, server_model, source_data, target_data,
                         n_batches, optimizer, device, mmd_loss)
    elif args.method == 'fedprox':
        _train_k_fedprox(args, model, server_model, source_data, n_batches,
                         optimizer, device, is_brain_tumor)
    elif args.method == 'moon':
        _train_k_moon(args, model, server_model, previous_nets, source_data,
                      n_batches, optimizer, device)
    elif args.method == 'fedclip':
        _train_k_fedclip(model, server_model, source_data, n_batches,
                         optimizer, device, is_brain_tumor)
    else:  # 'fedavg'
        _train_k_fedavg(model, source_data, n_batches, optimizer, device)
