import torch
import copy
import numpy as np
import random
from collections import defaultdict, Counter
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch.nn import functional as F


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Each item is returned as an ``(image, label)`` pair of tensors.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise indices (e.g. numpy integers) to plain Python ints.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # torch.as_tensor reuses values that are already tensors; torch.tensor()
        # would copy-construct them and emit a UserWarning for every sample.
        return torch.as_tensor(image), torch.as_tensor(label)


class DatasetSplit_domain(Dataset):
    """View of a domain-labelled ``dataset`` restricted to the positions in ``idxs``.

    Each item is returned as an ``(image, label, domain)`` triple of tensors.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalise indices (e.g. numpy integers) to plain Python ints.
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label, domain = self.dataset[self.idxs[item]]
        # torch.as_tensor avoids copy-constructing values that are already
        # tensors (torch.tensor() warns about that on every sample).
        return torch.as_tensor(image), torch.as_tensor(label), torch.as_tensor(domain)


class LocalUpdate_FedMaPLe(object):
    """Client-side trainer for FedMaPLe.

    Every parameter of ``prompt_learner`` with ``requires_grad=True`` is trained
    locally, returned by :meth:`local_train`, and overwritten by :meth:`update_model`.
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger, user_base_labels):
        self.args = args
        self.testloader_cache = None  # not read in this class; kept for external callers
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.logger = logger
        self.user_base_labels = user_base_labels

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0 when nothing was counted (empty loader)."""
        return 100 * correct / total if total else 0

    def update_model(self, model_params):
        """Copy every trainable prompt parameter from ``model_params`` (keyed by name); no-op on None."""
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if param.requires_grad:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled DataLoader over every sample listed in ``idxs``."""
        train_dataset = DatasetSplit(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _few_shot_indices(self, train_data, idxs):
        """Select up to ``args.shot_num`` indices per label from ``idxs``.

        Reseeds the module-level ``random`` generator with ``args.seed`` so the
        selection is reproducible.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label = all_labels[idx]
            # Tensor elements hash by identity; without .item() every sample would
            # get its own bucket and no per-class subsampling would happen.
            if torch.is_tensor(label):
                label = label.item()
            label_to_indices[label].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        return selected_indices

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader over at most ``args.shot_num`` samples per class."""
        train_dataset = DatasetSplit(train_data, self._few_shot_indices(train_data, idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def local_train(self, image_encoder, text_encoder):
        """Run ``args.local_ep`` epochs of SGD on the trainable prompt parameters.

        Returns:
            dict mapping parameter name -> trainable ``Parameter`` of ``prompt_learner``.
            NOTE(review): these are the live parameters, not copies; confirm the
            caller copies them before this learner is trained again.
        """
        # Fixed CLIP temperature exp(log(1 / 0.07)); it is not trained here.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # Plain SGD keeps no state, so rebuilding it each epoch is harmless.
            optimizer = torch.optim.SGD(
                [p for p in self.prompt_learner.parameters() if p.requires_grad], lr=0.0035)

            for images, labels in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                prompts, shared_ctx, deep_prompts_text, deep_prompts_vision = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts, deep_prompts_text)
                image_features = image_encoder(images, shared_ctx, deep_prompts_vision)

                # Scaled cosine similarity between L2-normalised image and text features.
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits = scale * image_features @ text_features.t()

                loss = self.criterion_CE(logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[FedMaPLe] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index,
                                                                         self._percent(correct, total)))

        return {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}


class LocalUpdate_PromptFL(object):
    """Client-side trainer for PromptFL (text-side prompt learning only).

    Every parameter of ``prompt_learner`` with ``requires_grad=True`` is trained
    locally, returned by :meth:`local_train`, and overwritten by :meth:`update_model`.
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger, user_base_labels):
        self.args = args
        self.testloader_cache = None  # not read in this class; kept for external callers
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.logger = logger
        self.user_base_labels = user_base_labels

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0 when nothing was counted (empty loader)."""
        return 100 * correct / total if total else 0

    def update_model(self, model_params):
        """Copy every trainable prompt parameter from ``model_params`` (keyed by name); no-op on None."""
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if param.requires_grad:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled DataLoader over every sample listed in ``idxs``."""
        train_dataset = DatasetSplit(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _few_shot_indices(self, train_data, idxs):
        """Select up to ``args.shot_num`` indices per label from ``idxs``.

        Reseeds the module-level ``random`` generator with ``args.seed`` so the
        selection is reproducible.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label = all_labels[idx]
            # Tensor elements hash by identity; without .item() every sample would
            # get its own bucket and no per-class subsampling would happen.
            if torch.is_tensor(label):
                label = label.item()
            label_to_indices[label].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        return selected_indices

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader over at most ``args.shot_num`` samples per class."""
        train_dataset = DatasetSplit(train_data, self._few_shot_indices(train_data, idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def local_train(self, image_encoder, text_encoder):
        """Run ``args.local_ep`` epochs of SGD on the trainable prompt parameters.

        Returns:
            dict mapping parameter name -> trainable ``Parameter`` of ``prompt_learner``.
            NOTE(review): these are the live parameters, not copies; confirm the
            caller copies them before this learner is trained again.
        """
        # Fixed CLIP temperature exp(log(1 / 0.07)); it is not trained here.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # Plain SGD keeps no state, so rebuilding it each epoch is harmless.
            optimizer = torch.optim.SGD(
                [p for p in self.prompt_learner.parameters() if p.requires_grad], lr=0.001)

            for images, labels in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                image_features = image_encoder(images)
                prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts)

                # Scaled cosine similarity between L2-normalised image and text features.
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                logits = scale * image_features @ text_features.t()

                loss = self.criterion_CE(logits, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[PromptFL] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index,
                                                                         self._percent(correct, total)))

        return {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}


class LocalUpdate_FedPGP(object):
    """Client-side trainer for FedPGP.

    Only the ``sigma`` prompt parameter is loaded from the server by
    :meth:`update_model`; the ``U``/``V`` tensors of ``prompt_learner`` are copied
    into the evaluation learners in :meth:`local_test` as the client's personal part.
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger, user_base_labels):
        self.args = args
        self.testloader_cache = None  # not read in this class; kept for external callers
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.cos = nn.CosineSimilarity(dim=-1)
        self.prompt_learner = prompt_learner
        self.logger = logger
        self.user_base_labels = user_base_labels

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0 when nothing was counted (empty loader)."""
        return 100 * correct / total if total else 0

    def update_model(self, model_params):
        """Load the server's ``sigma`` prompt from ``model_params``; no-op on None."""
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if name == 'sigma':
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled DataLoader over every sample listed in ``idxs``."""
        train_dataset = DatasetSplit(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _few_shot_indices(self, train_data, idxs):
        """Select up to ``args.shot_num`` indices per label from ``idxs``.

        Reseeds the module-level ``random`` generator with ``args.seed`` so the
        selection is reproducible.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label = all_labels[idx]
            # Tensor elements hash by identity; without .item() every sample would
            # get its own bucket and no per-class subsampling would happen.
            if torch.is_tensor(label):
                label = label.item()
            label_to_indices[label].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        return selected_indices

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader over at most ``args.shot_num`` samples per class."""
        train_dataset = DatasetSplit(train_data, self._few_shot_indices(train_data, idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def local_train(self, image_encoder, text_encoder):
        """Run ``args.local_ep`` epochs of SGD with CE plus an ``args.mu``-weighted contrastive term.

        Returns:
            dict mapping parameter name -> trainable ``Parameter`` of ``prompt_learner``.
            NOTE(review): these are the live parameters, not copies; confirm the
            caller copies them before this learner is trained again.
        """
        # Fixed CLIP temperature exp(log(1 / 0.07)); it is not trained here.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # Plain SGD keeps no state, so rebuilding it each epoch is harmless.
            optimizer = torch.optim.SGD(
                [p for p in self.prompt_learner.parameters() if p.requires_grad], lr=0.001)

            for images, labels in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                image_features = image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                embedding, prompts_sigma, _, prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts

                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                text_features_0 = text_encoder(embedding, tokenized_prompts)
                text_features_sigma = text_encoder(prompts_sigma, tokenized_prompts)
                text_features_0 = text_features_0 / text_features_0.norm(dim=-1, keepdim=True)
                text_features_sigma = text_features_sigma / text_features_sigma.norm(dim=-1, keepdim=True)

                posi = self.cos(text_features_0, text_features_sigma)
                nega = self.cos(text_features_sigma, text_features)
                pred_logits = scale * image_features @ text_features.t()
                loss = self.criterion_CE(pred_logits, labels)

                # Target class 0 selects `posi`: pull sigma features toward the
                # `embedding` features and away from the full-prompt features.
                contrast_logits = torch.cat((posi.reshape(-1, 1), nega.reshape(-1, 1)), dim=1) / self.args.temp
                target = torch.zeros(contrast_logits.size(0), dtype=torch.long, device=self.device)
                loss = loss + self.args.mu * self.criterion_CE(contrast_logits, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (pred_logits.argmax(dim=1) == labels).sum().item()
            print('[FedPGP] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index,
                                                                       self._percent(correct, total)))

        return {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}

    def _user_label_tensor(self, user_idx):
        """Return ``user_base_labels[user_idx]`` as a tensor on ``self.device``, or None if absent."""
        for user_id, label_list in self.user_base_labels.items():
            if user_id == user_idx:
                return torch.tensor(label_list, device=self.device)
        return None

    def _copy_personal_factors(self, target):
        """Copy this client's ``U`` and ``V`` tensors into ``target`` in place."""
        with torch.no_grad():
            target.U.copy_(self.prompt_learner.U)
            target.V.copy_(self.prompt_learner.V)

    @staticmethod
    def _logits(images, image_encoder, text_encoder, prompt_learner, scale):
        """Scaled cosine-similarity logits of ``images`` against ``prompt_learner``'s class prompts."""
        image_features = image_encoder(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        _, _, _, prompts = prompt_learner()
        text_features = text_encoder(prompts, prompt_learner.tokenized_prompts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return scale * image_features @ text_features.t()

    def local_test(self, idx, testloader_base, testloader_new, image_encoder, text_encoder, train_classes,
                   global_prompt_learner, new_prompt_learner):
        """Evaluate on base and new classes with this client's ``U``/``V`` copied in.

        Returns:
            ``(base_acc, new_acc, local_acc)`` in percent. ``local_acc`` is restricted
            to this user's base labels when ``args.IID`` is "Non-IID" or "Dirichlet",
            and is 0 if no such samples were seen.
        """
        per_user = self.args.IID in ("Non-IID", "Dirichlet")
        with torch.no_grad():
            scale = (torch.ones([]) * np.log(1 / 0.07)).exp()
            user_label_tensor = self._user_label_tensor(idx) if per_user else None
            user_correct, user_total = 0, 0
            test_total, test_correct = 0.0, 0.0

            self._copy_personal_factors(global_prompt_learner)
            for images, labels in tqdm(testloader_base):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, image_encoder, text_encoder, global_prompt_learner, scale)
                preds = logits.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)
                if user_label_tensor is not None:
                    mask = torch.isin(labels, user_label_tensor)
                    if mask.any():
                        user_correct += (preds[mask] == labels[mask]).sum().item()
                        user_total += mask.sum().item()
            print('Global Base Test Acc: {:.2f}%'.format(self._percent(test_correct, test_total)))
            global_accuracys_base = self._percent(test_correct, test_total)
            local_test_acc = self._percent(user_correct, user_total) if per_user else global_accuracys_base

            test_total1, test_correct1 = 0.0, 0.0
            self._copy_personal_factors(new_prompt_learner)
            # New-class logits are shifted by the number of base classes before
            # comparison; presumably new-class labels follow the base labels.
            label_offset = int(len(train_classes))
            for images, labels in tqdm(testloader_new):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, image_encoder, text_encoder, new_prompt_learner, scale)
                preds = logits.argmax(dim=1) + label_offset
                test_correct1 += (preds == labels).sum().item()
                test_total1 += labels.size(0)
            print('Global New Test Acc: {:.2f}%'.format(self._percent(test_correct1, test_total1)))
            global_accuracys_new = self._percent(test_correct1, test_total1)
        return global_accuracys_base, global_accuracys_new, local_test_acc


class LocalUpdate_pFedMMA(object):
    """Client-side trainer for pFedMMA (multi-modal adapters).

    Parameters whose name contains ``shared_adapter`` are exchanged with the
    server. The ``visual_adapter``/``text_adapter`` parameters are copied into the
    evaluation learners in :meth:`local_test` as the client's personal part.
    """

    def __init__(self, args, train_data, idxs, client_index, device, adapter_learner, logger, user_base_labels):
        self.args = args
        self.testloader_cache = None  # not read in this class; kept for external callers
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.adapter_learner = adapter_learner
        self.logger = logger
        self.user_base_labels = user_base_labels

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0 when nothing was counted (empty loader)."""
        return 100 * correct / total if total else 0

    def update_model(self, model_params):
        """Load every ``shared_adapter`` parameter from ``model_params`` (keyed by name); no-op on None."""
        if model_params is None:
            return
        for name, param in self.adapter_learner.named_parameters():
            if 'shared_adapter' in name:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled DataLoader over every sample listed in ``idxs``."""
        train_dataset = DatasetSplit(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _few_shot_indices(self, train_data, idxs):
        """Select up to ``args.shot_num`` indices per label from ``idxs``.

        Reseeds the module-level ``random`` generator with ``args.seed`` so the
        selection is reproducible.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label = all_labels[idx]
            # Tensor elements hash by identity; without .item() every sample would
            # get its own bucket and no per-class subsampling would happen.
            if torch.is_tensor(label):
                label = label.item()
            label_to_indices[label].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        return selected_indices

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader over at most ``args.shot_num`` samples per class."""
        train_dataset = DatasetSplit(train_data, self._few_shot_indices(train_data, idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    @staticmethod
    def _logits(images, image_encoder, text_encoder, adapter_learner, scale):
        """Scaled cosine-similarity logits using the adapters produced by ``adapter_learner``."""
        token_embedding, adapter_bank = adapter_learner()
        text_features = text_encoder(token_embedding, adapter_learner.tokenized_prompts, adapter_bank, mode="text")
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = image_encoder(images, adapter_bank, mode="visual")
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return scale * image_features @ text_features.t()

    def local_train(self, image_encoder, text_encoder):
        """Run ``args.local_ep`` epochs of SGD on all trainable adapter parameters.

        Returns:
            dict name -> ``Parameter`` for the trainable ``shared_adapter`` parameters only.
            NOTE(review): these are the live parameters, not copies; confirm the
            caller copies them before this learner is trained again.
        """
        # Fixed CLIP temperature exp(log(1 / 0.07)); it is not trained here.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # Plain SGD keeps no state, so rebuilding it each epoch is harmless.
            optimizer = torch.optim.SGD(
                [p for p in self.adapter_learner.parameters() if p.requires_grad], lr=0.001)

            for images, labels in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, image_encoder, text_encoder, self.adapter_learner, scale)
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[pFedMMA] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index,
                                                                        self._percent(correct, total)))

        return {name: param for name, param in self.adapter_learner.named_parameters()
                if param.requires_grad and 'shared_adapter' in name}

    def _user_label_tensor(self, user_idx):
        """Return ``user_base_labels[user_idx]`` as a tensor on ``self.device``, or None if absent."""
        for user_id, label_list in self.user_base_labels.items():
            if user_id == user_idx:
                return torch.tensor(label_list, device=self.device)
        return None

    def _copy_personal_adapters(self, target):
        """Copy this client's ``visual_adapter``/``text_adapter`` parameters into ``target`` in place."""
        personal = dict(self.adapter_learner.named_parameters())
        with torch.no_grad():
            for name, param in target.named_parameters():
                if "visual_adapter" in name or "text_adapter" in name:
                    param.copy_(personal[name])

    def local_test(self, idx, testloader_base, testloader_new, image_encoder, text_encoder, train_classes,
                   global_adapter_learner, new_adapter_learner):
        """Evaluate on base and new classes with this client's personal adapters copied in.

        Returns:
            ``(base_acc, new_acc, local_acc)`` in percent. ``local_acc`` is restricted
            to this user's base labels when ``args.IID`` is "Non-IID" or "Dirichlet",
            and is 0 if no such samples were seen.
        """
        per_user = self.args.IID in ("Non-IID", "Dirichlet")
        with torch.no_grad():
            scale = (torch.ones([]) * np.log(1 / 0.07)).exp()
            user_label_tensor = self._user_label_tensor(idx) if per_user else None
            user_correct, user_total = 0, 0
            test_total, test_correct = 0.0, 0.0

            self._copy_personal_adapters(global_adapter_learner)
            for images, labels in tqdm(testloader_base):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, image_encoder, text_encoder, global_adapter_learner, scale)
                preds = logits.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)
                if user_label_tensor is not None:
                    mask = torch.isin(labels, user_label_tensor)
                    if mask.any():
                        user_correct += (preds[mask] == labels[mask]).sum().item()
                        user_total += mask.sum().item()
            print('Global Base Test Acc: {:.2f}%'.format(self._percent(test_correct, test_total)))
            global_accuracys_base = self._percent(test_correct, test_total)
            local_test_acc = self._percent(user_correct, user_total) if per_user else global_accuracys_base

            test_total1, test_correct1 = 0.0, 0.0
            self._copy_personal_adapters(new_adapter_learner)
            # New-class logits are shifted by the number of base classes before
            # comparison; presumably new-class labels follow the base labels.
            label_offset = int(len(train_classes))
            for images, labels in tqdm(testloader_new):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, image_encoder, text_encoder, new_adapter_learner, scale)
                preds = logits.argmax(dim=1) + label_offset
                test_correct1 += (preds == labels).sum().item()
                test_total1 += labels.size(0)
            print('Global New Test Acc: {:.2f}%'.format(self._percent(test_correct1, test_total1)))
            global_accuracys_new = self._percent(test_correct1, test_total1)

        return global_accuracys_base, global_accuracys_new, local_test_acc


class LocalUpdate_pFedDC(object):
    """Client-side trainer for pFedDC.

    Two contexts are trained locally: a visual one (``transformer.ctx_learner.ctx``
    inside ``image_encoder``) and a text one (``ctx_learner.ctx`` inside
    ``prompt_learner``). Only a leading slice of each is exchanged with the server:
    ``[:, :SHARED_VISUAL_CTX, :]`` of the visual context and ``[:SHARED_TEXT_CTX, :]``
    of the text context. The remainder stays client-specific.
    """

    # Sizes of the server-shared leading slices (overridable per subclass).
    SHARED_VISUAL_CTX = 5
    SHARED_TEXT_CTX = 16

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, image_encoder, logger,
                 user_base_labels):
        self.args = args
        self.testloader_cache = None  # not read in this class; kept for external callers
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.image_encoder = image_encoder
        self.logger = logger
        self.user_base_labels = user_base_labels

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total``, or 0 when nothing was counted."""
        return 100 * correct / total if total else 0

    def update_model(self, model_params):
        """Overwrite the shared leading slices of both contexts from ``model_params``; no-op on None."""
        if model_params is None:
            return
        n_visual, n_text = self.SHARED_VISUAL_CTX, self.SHARED_TEXT_CTX
        for name, param in self.image_encoder.named_parameters():
            if name == 'transformer.ctx_learner.ctx':
                with torch.no_grad():
                    param[:, :n_visual, :].copy_(model_params[name][:, :n_visual, :])
        for name, param in self.prompt_learner.named_parameters():
            if name == 'ctx_learner.ctx':
                with torch.no_grad():
                    param.data[:n_text, :].copy_(model_params[name][:n_text, :])

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled DataLoader over every sample listed in ``idxs``."""
        train_dataset = DatasetSplit(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _few_shot_indices(self, train_data, idxs):
        """Select up to ``args.shot_num`` indices per label from ``idxs``.

        Reseeds the module-level ``random`` generator with ``args.seed`` so the
        selection is reproducible.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label = all_labels[idx]
            # Tensor elements hash by identity; without .item() every sample would
            # get its own bucket and no per-class subsampling would happen.
            if torch.is_tensor(label):
                label = label.item()
            label_to_indices[label].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))
        return selected_indices

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled DataLoader over at most ``args.shot_num`` samples per class."""
        train_dataset = DatasetSplit(train_data, self._few_shot_indices(train_data, idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def _logits(self, images, text_encoder, prompt_learner, scale):
        """Scaled cosine-similarity logits of ``images`` (via ``self.image_encoder``) against ``prompt_learner``."""
        prompts = prompt_learner()
        text_features = text_encoder(prompts, prompt_learner.tokenized_prompts)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = self.image_encoder(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return scale * image_features @ text_features.t()

    def local_train(self, text_encoder):
        """Train visual and text contexts for ``args.local_ep`` epochs.

        Returns:
            dict holding detached copies of the server-shared context slices only.
        """
        # Fixed CLIP temperature exp(log(1 / 0.07)); it is not trained here.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # NOTE(review): both optimizers are rebuilt every epoch, which resets
            # their momentum buffers each epoch -- confirm this is intended.
            optimizer = torch.optim.SGD([p for p in self.image_encoder.parameters() if p.requires_grad],
                                        lr=0.01, weight_decay=0.05, momentum=0.9)
            optimizer1 = torch.optim.SGD([p for p in self.prompt_learner.parameters() if p.requires_grad],
                                         lr=0.01, weight_decay=0.05, momentum=0.9)

            for images, labels in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, text_encoder, self.prompt_learner, scale)
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                optimizer1.zero_grad()
                loss.backward()
                optimizer.step()
                optimizer1.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[pFedDC] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index,
                                                                       self._percent(correct, total)))

        updated_params = {}
        for name, param in self.image_encoder.named_parameters():
            if name == 'transformer.ctx_learner.ctx':
                updated_params[name] = param[:, :self.SHARED_VISUAL_CTX, :].clone().detach()
        for name, param in self.prompt_learner.named_parameters():
            if name == 'ctx_learner.ctx':
                updated_params[name] = param[:self.SHARED_TEXT_CTX, :].clone().detach()
        return updated_params

    def _user_label_tensor(self, user_idx):
        """Return ``user_base_labels[user_idx]`` as a tensor on ``self.device``, or None if absent."""
        for user_id, label_list in self.user_base_labels.items():
            if user_id == user_idx:
                return torch.tensor(label_list, device=self.device)
        return None

    def _copy_personal_prompts(self, target):
        """Copy this client's prompt parameters into ``target``, skipping tokenizer/prefix/suffix entries."""
        personal = dict(self.prompt_learner.named_parameters())
        skipped = ("tokenized_prompts", "token_prefix", "token_suffix")
        with torch.no_grad():
            for name, param in target.named_parameters():
                if not any(key in name for key in skipped):
                    param.copy_(personal[name])

    def local_test(self, client_idx, testloader_base, testloader_new, text_encoder, train_classes, new_prompt_learner):
        """Evaluate the personalised model on base classes and on new classes.

        Returns:
            ``(base_acc, new_acc, local_acc)`` in percent. ``local_acc`` is restricted
            to this user's base labels when ``args.IID`` is "Non-IID" or "Dirichlet".
            It is 0 when no such samples were seen; this case previously raised
            ZeroDivisionError.
        """
        per_user = self.args.IID in ("Non-IID", "Dirichlet")
        with torch.no_grad():
            scale = (torch.ones([]) * np.log(1 / 0.07)).exp()
            user_label_tensor = self._user_label_tensor(client_idx) if per_user else None
            user_correct, user_total = 0, 0
            test_total, test_correct = 0.0, 0.0

            for images, labels in tqdm(testloader_base):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, text_encoder, self.prompt_learner, scale)
                preds = logits.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)
                if user_label_tensor is not None:
                    mask = torch.isin(labels, user_label_tensor)
                    if mask.any():
                        user_correct += (preds[mask] == labels[mask]).sum().item()
                        user_total += mask.sum().item()
            print('Global Base Test Acc: {:.2f}%'.format(self._percent(test_correct, test_total)))
            global_accuracys_base = self._percent(test_correct, test_total)
            local_test_acc = self._percent(user_correct, user_total) if per_user else global_accuracys_base

            test_total1, test_correct1 = 0.0, 0.0
            self._copy_personal_prompts(new_prompt_learner)
            # New-class logits are shifted by the number of base classes before
            # comparison; presumably new-class labels follow the base labels.
            label_offset = int(len(train_classes))
            for images, labels in tqdm(testloader_new):
                images, labels = images.to(self.device), labels.to(self.device)
                logits = self._logits(images, text_encoder, new_prompt_learner, scale)
                preds = logits.argmax(dim=1) + label_offset
                test_correct1 += (preds == labels).sum().item()
                test_total1 += labels.size(0)
            print('Global New Test Acc: {:.2f}%'.format(self._percent(test_correct1, test_total1)))
            global_accuracys_new = self._percent(test_correct1, test_total1)

        return global_accuracys_base, global_accuracys_new, local_test_acc


class LocalUpdate_FedPGP_domain(object):
    """Client-side training and evaluation for FedPGP on multi-domain data.

    The training dataset must yield ``(image, label, domain)`` triples. Each client
    is expected to hold samples from exactly one domain: the domain index is
    inferred from the training loader at construction time, and a ``ValueError``
    is raised otherwise.
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger):
        self.args = args
        self.testloader_cache = None
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.cos = nn.CosineSimilarity(dim=-1)
        self.prompt_learner = prompt_learner
        self.logger = logger
        # Walks the whole train loader once (loading every sample) to find the domain.
        self.domain_idx = self.infer_client_domain()

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale ``exp(log(1 / 0.07))`` (~14.29) as a 0-dim tensor without grad."""
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            return logit_scale.exp()

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total`` rounded to 2 decimals, or 0.0 when ``total`` is 0."""
        return round(100 * correct / max(1, total), 2)

    @staticmethod
    def _sample_few_shot_indices(all_labels, idxs, shots):
        """Group ``idxs`` by ``all_labels[idx]`` and draw up to ``shots`` indices per label.

        Draws from the global ``random`` module, so the result depends on its
        current state. Tensor labels are converted with ``.item()`` first: tensors
        hash by identity, so otherwise every index would form its own group and
        nothing would be subsampled.
        """
        label_to_indices = defaultdict(list)
        for idx in idxs:
            label = all_labels[idx]
            if isinstance(label, torch.Tensor):
                label = label.item()
            label_to_indices[label].append(idx)
        selected = []
        for indices in label_to_indices.values():
            selected.extend(random.sample(indices, min(shots, len(indices))))
        return selected

    def _domain_names(self):
        """Return the domain-index -> domain-name table for ``args.dataname``."""
        if self.args.dataname == "Office_Caltech10":
            return ['amazon', 'dslr', 'webcam', 'caltech']
        return ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

    def update_model(self, model_params):
        """Load the aggregated ``sigma`` parameter into the local prompt learner.

        Only the parameter named exactly ``'sigma'`` is overwritten. All other
        parameters keep their client-local values. ``None`` is a no-op.
        """
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if name == 'sigma':
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over every index in ``idxs``."""
        train_dataset = DatasetSplit_domain(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over a per-label subsample of ``idxs``.

        Reseeds the global ``random`` module with ``args.seed``, then keeps up to
        ``args.shot_num`` indices per label (16 when ``args`` has no ``shot_num``).

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")
        shots = getattr(self.args, 'shot_num', 16)
        selected_indices = self._sample_few_shot_indices(all_labels, list(idxs), shots)
        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def infer_client_domain(self):
        """Return the single domain index present in ``self.trainloader``.

        Raises:
            ValueError: if the loader yields zero or several distinct domain indices.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def local_train(self, image_encoder, text_encoder):
        """Train the prompt learner for ``args.local_ep`` epochs with SGD (lr=0.001).

        The loss has two parts:
        - cross-entropy on the scaled image/text similarity logits;
        - ``args.mu`` times a two-way contrastive term. ``cos(text_0, text_sigma)``
          is the positive logit and ``cos(text_sigma, text)`` the negative; both
          are divided by ``args.temp``, and index 0 (the positive) is the target.

        Returns:
            dict of name -> live ``nn.Parameter`` (not a copy) for every parameter
            of ``self.prompt_learner`` that requires grad.
        """
        scale = self._logit_scale()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.prompt_learner.parameters()), lr=0.001)

            for images, labels, _domains in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                image_features = image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                embedding, prompts_sigma, _, prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts

                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                text_features_0 = text_encoder(embedding, tokenized_prompts)
                text_features_sigma = text_encoder(prompts_sigma, tokenized_prompts)
                text_features_0 = text_features_0 / text_features_0.norm(dim=-1, keepdim=True)
                text_features_sigma = text_features_sigma / text_features_sigma.norm(dim=-1, keepdim=True)

                posi = self.cos(text_features_0, text_features_sigma)
                nega = self.cos(text_features_sigma, text_features)
                pred_logits = scale * image_features @ text_features.t()

                loss = self.criterion_CE(pred_logits, labels)
                contrast_logits = torch.cat((posi.reshape(-1, 1), nega.reshape(-1, 1)), dim=1)
                contrast_logits /= self.args.temp
                target = torch.zeros(contrast_logits.size(0)).to(self.device).long()
                loss += self.args.mu * self.criterion_CE(contrast_logits, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (pred_logits.argmax(dim=1) == labels).sum().item()
            print('[FedPGP] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, 100 * correct / total))

        return {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}

    def local_test(self, testloader, image_encoder, text_encoder, global_prompt_learner, user_base_labels):
        """Evaluate ``global_prompt_learner`` carrying this client's ``U``/``V`` on ``testloader``.

        ``global_prompt_learner.U`` and ``.V`` are overwritten in place with this
        client's values before evaluation.

        Args:
            testloader: iterable of ``(images, labels, domains)`` tensor batches.
            image_encoder, text_encoder: encoders, called as in ``local_train``.
            global_prompt_learner: prompt learner whose ``U``/``V`` are replaced.
            user_base_labels: ``user_base_labels[self.client_index]`` is the
                collection of labels counted toward this client's local accuracy.

        Returns:
            ``(global_acc, domain_acc, local_acc)``: percentages rounded to 2
            decimals. ``domain_acc`` maps domain name to accuracy. ``local_acc``
            counts only samples from this client's domain whose label is in its
            label set, and is 0.0 when there are none.
        """
        client_domain = self.domain_idx
        local_labels = user_base_labels[self.client_index]
        global_correct, global_total = 0, 0
        local_correct, local_total = 0, 0
        domain_correct, domain_total = defaultdict(int), defaultdict(int)

        with torch.no_grad():
            scale = self._logit_scale()
            global_prompt_learner.U.copy_(self.prompt_learner.U)
            global_prompt_learner.V.copy_(self.prompt_learner.V)

            for images, labels, domains in tqdm(testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                domains = domains.to(self.device)
                image_features = image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                _, _, _, prompts = global_prompt_learner()
                tokenized_prompts = global_prompt_learner.tokenized_prompts

                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                logits = scale * image_features @ text_features.t()

                hits = logits.argmax(dim=1) == labels
                global_correct += hits.sum().item()
                global_total += labels.size(0)

                for d in torch.unique(domains).tolist():
                    mask = domains == d
                    domain_correct[d] += hits[mask].sum().item()
                    domain_total[d] += mask.sum().item()

                label_mask = torch.tensor([l.item() in local_labels for l in labels],
                                          dtype=torch.bool, device=self.device)
                local_mask = (domains == client_domain) & label_mask
                local_correct += hits[local_mask].sum().item()
                local_total += local_mask.sum().item()

        print('Global Test Acc: {:.2f}%'.format(100 * global_correct / max(1, global_total)))
        global_accuracys = self._percent(global_correct, global_total)

        domain_names = self._domain_names()
        domain_test_acc = {}
        for d in sorted(domain_total):
            acc = domain_correct[d] / max(1, domain_total[d])
            domain_test_acc[domain_names[d]] = round(acc * 100, 2)

        # Guarded like the per-domain accuracies: 0.0 instead of ZeroDivisionError
        # when the test set has no sample from this client's domain/label set.
        local_acc = self._percent(local_correct, local_total)

        return global_accuracys, domain_test_acc, local_acc


class LocalUpdate_pFedMMA_domain(object):
    """Client-side training and evaluation for pFedMMA on multi-domain data.

    Only adapter parameters whose names contain ``shared_adapter`` are exchanged
    with the server. The training dataset must yield ``(image, label, domain)``
    triples, and each client is expected to hold exactly one domain (inferred at
    construction time; ``ValueError`` otherwise).
    """

    def __init__(self, args, train_data, idxs, client_index, device, adapter_learner, logger):
        self.args = args
        self.testloader_cache = None
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.adapter_learner = adapter_learner
        self.logger = logger
        # Walks the whole train loader once (loading every sample) to find the domain.
        self.domain_idx = self.infer_client_domain()

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale ``exp(log(1 / 0.07))`` (~14.29) as a 0-dim tensor without grad."""
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            return logit_scale.exp()

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total`` rounded to 2 decimals, or 0.0 when ``total`` is 0."""
        return round(100 * correct / max(1, total), 2)

    @staticmethod
    def _sample_few_shot_indices(all_labels, idxs, shots):
        """Group ``idxs`` by ``all_labels[idx]`` and draw up to ``shots`` indices per label.

        Draws from the global ``random`` module, so the result depends on its
        current state. Tensor labels are converted with ``.item()`` first: tensors
        hash by identity, so otherwise every index would form its own group and
        nothing would be subsampled.
        """
        label_to_indices = defaultdict(list)
        for idx in idxs:
            label = all_labels[idx]
            if isinstance(label, torch.Tensor):
                label = label.item()
            label_to_indices[label].append(idx)
        selected = []
        for indices in label_to_indices.values():
            selected.extend(random.sample(indices, min(shots, len(indices))))
        return selected

    def _domain_names(self):
        """Return the domain-index -> domain-name table for ``args.dataname``."""
        if self.args.dataname == "Office_Caltech10":
            return ['amazon', 'dslr', 'webcam', 'caltech']
        return ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

    def update_model(self, model_params):
        """Load aggregated ``shared_adapter`` parameters into the local adapter learner.

        Every parameter whose name contains ``'shared_adapter'`` is overwritten from
        ``model_params``. Other parameters stay client-local. ``None`` is a no-op.
        """
        if model_params is None:
            return
        for name, param in self.adapter_learner.named_parameters():
            if 'shared_adapter' in name:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over every index in ``idxs``."""
        train_dataset = DatasetSplit_domain(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over a per-label subsample of ``idxs``.

        Reseeds the global ``random`` module with ``args.seed``, then keeps up to
        ``args.shot_num`` indices per label (16 when ``args`` has no ``shot_num``).

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")
        shots = getattr(self.args, 'shot_num', 16)
        selected_indices = self._sample_few_shot_indices(all_labels, list(idxs), shots)
        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def infer_client_domain(self):
        """Return the single domain index present in ``self.trainloader``.

        Raises:
            ValueError: if the loader yields zero or several distinct domain indices.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def local_train(self, image_encoder, text_encoder):
        """Train the adapter learner for ``args.local_ep`` epochs with SGD (lr=0.001).

        The loss is cross-entropy on scaled image/text similarity logits. Both
        encoders receive the ``adapter_bank`` produced by ``self.adapter_learner()``.

        Returns:
            dict of name -> live ``nn.Parameter`` (not a copy) for every trainable
            parameter whose name contains ``'shared_adapter'``.
        """
        scale = self._logit_scale()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.adapter_learner.parameters()), lr=0.001)

            for images, labels, _domains in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                token_embedding, adapter_bank = self.adapter_learner()
                tokenized_prompts = self.adapter_learner.tokenized_prompts

                text_features = text_encoder(token_embedding, tokenized_prompts, adapter_bank, mode="text")
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                image_features = image_encoder(images, adapter_bank, mode="visual")
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[pFedMMA] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, 100 * correct / total))

        return {name: param for name, param in self.adapter_learner.named_parameters()
                if param.requires_grad and 'shared_adapter' in name}

    def local_test(self, testloader, image_encoder, text_encoder, global_adapter_learner, user_base_labels):
        """Evaluate ``global_adapter_learner`` carrying this client's private adapters.

        Before evaluation, every parameter of ``global_adapter_learner`` whose name
        contains ``visual_adapter`` or ``text_adapter`` is overwritten in place
        with this client's value.

        Args:
            testloader: iterable of ``(images, labels, domains)`` tensor batches.
            image_encoder, text_encoder: encoders, called as in ``local_train``.
            global_adapter_learner: adapter learner to evaluate (mutated).
            user_base_labels: ``user_base_labels[self.client_index]`` is the
                collection of labels counted toward this client's local accuracy.

        Returns:
            ``(global_acc, domain_acc, local_acc)``: percentages rounded to 2
            decimals. ``local_acc`` is 0.0 when no test sample comes from this
            client's domain with one of its labels.
        """
        client_domain = self.domain_idx
        local_labels = user_base_labels[self.client_index]
        global_correct, global_total = 0, 0
        local_correct, local_total = 0, 0
        domain_correct, domain_total = defaultdict(int), defaultdict(int)

        with torch.no_grad():
            scale = self._logit_scale()
            personalized_params = dict(self.adapter_learner.named_parameters())
            for name, param in global_adapter_learner.named_parameters():
                if "visual_adapter" in name or "text_adapter" in name:
                    param.copy_(personalized_params[name])

            for images, labels, domains in tqdm(testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                domains = domains.to(self.device)
                token_embedding, adapter_bank = global_adapter_learner()
                tokenized_prompts = global_adapter_learner.tokenized_prompts

                text_features = text_encoder(token_embedding, tokenized_prompts, adapter_bank, mode="text")
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                image_features = image_encoder(images, adapter_bank, mode="visual")
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                hits = logits.argmax(dim=1) == labels
                global_correct += hits.sum().item()
                global_total += labels.size(0)

                for d in torch.unique(domains).tolist():
                    mask = domains == d
                    domain_correct[d] += hits[mask].sum().item()
                    domain_total[d] += mask.sum().item()

                label_mask = torch.tensor([l.item() in local_labels for l in labels],
                                          dtype=torch.bool, device=self.device)
                local_mask = (domains == client_domain) & label_mask
                local_correct += hits[local_mask].sum().item()
                local_total += local_mask.sum().item()

        print('Global Test Acc: {:.2f}%'.format(100 * global_correct / max(1, global_total)))
        global_accuracys = self._percent(global_correct, global_total)

        domain_names = self._domain_names()
        domain_test_acc = {}
        for d in sorted(domain_total):
            acc = domain_correct[d] / max(1, domain_total[d])
            domain_test_acc[domain_names[d]] = round(acc * 100, 2)

        # Guarded like the per-domain accuracies: 0.0 instead of ZeroDivisionError.
        local_acc = self._percent(local_correct, local_total)

        return global_accuracys, domain_test_acc, local_acc


class LocalUpdate_pFedDC_domain(object):
    """Client-side training and evaluation for pFedDC on multi-domain data.

    Two context tensors are partially shared with the server:
    - the first ``_SHARED_VISUAL_CTX`` entries along dim 1 of the image encoder's
      ``transformer.ctx_learner.ctx``;
    - the first ``_SHARED_TEXT_CTX`` rows of the prompt learner's ``ctx_learner.ctx``.

    The remaining entries stay client-local. The training dataset must yield
    ``(image, label, domain)`` triples, and each client is expected to hold
    exactly one domain.
    """

    # Number of leading context entries exchanged with the server.
    _SHARED_VISUAL_CTX = 5
    _SHARED_TEXT_CTX = 16

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, image_encoder, logger):
        self.args = args
        self.testloader_cache = None
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.image_encoder = image_encoder
        self.logger = logger
        # Walks the whole train loader once (loading every sample) to find the domain.
        self.domain_idx = self.infer_client_domain()

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale ``exp(log(1 / 0.07))`` (~14.29) as a 0-dim tensor without grad."""
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            return logit_scale.exp()

    @staticmethod
    def _percent(correct, total):
        """Return ``100 * correct / total`` rounded to 2 decimals, or 0.0 when ``total`` is 0."""
        return round(100 * correct / max(1, total), 2)

    @staticmethod
    def _sample_few_shot_indices(all_labels, idxs, shots):
        """Group ``idxs`` by ``all_labels[idx]`` and draw up to ``shots`` indices per label.

        Draws from the global ``random`` module, so the result depends on its
        current state. Tensor labels are converted with ``.item()`` first: tensors
        hash by identity, so otherwise every index would form its own group and
        nothing would be subsampled.
        """
        label_to_indices = defaultdict(list)
        for idx in idxs:
            label = all_labels[idx]
            if isinstance(label, torch.Tensor):
                label = label.item()
            label_to_indices[label].append(idx)
        selected = []
        for indices in label_to_indices.values():
            selected.extend(random.sample(indices, min(shots, len(indices))))
        return selected

    def _domain_names(self):
        """Return the domain-index -> domain-name table for ``args.dataname``."""
        if self.args.dataname == "Office_Caltech10":
            return ['amazon', 'dslr', 'webcam', 'caltech']
        return ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

    def update_model(self, model_params):
        """Overwrite the shared leading slices of both context tensors from ``model_params``.

        Only ``ctx[:, :_SHARED_VISUAL_CTX, :]`` of ``transformer.ctx_learner.ctx``
        (image encoder) and ``ctx[:_SHARED_TEXT_CTX, :]`` of ``ctx_learner.ctx``
        (prompt learner) are written. ``None`` is a no-op.
        """
        if model_params is None:
            return
        nv, nt = self._SHARED_VISUAL_CTX, self._SHARED_TEXT_CTX
        for name, param in self.image_encoder.named_parameters():
            if name == 'transformer.ctx_learner.ctx':
                with torch.no_grad():
                    param[:, :nv, :].copy_(model_params[name][:, :nv, :])
        for name, param in self.prompt_learner.named_parameters():
            if name == 'ctx_learner.ctx':
                with torch.no_grad():
                    param.data[:nt, :].copy_(model_params[name][:nt, :])

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over every index in ``idxs``."""
        train_dataset = DatasetSplit_domain(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over a per-label subsample of ``idxs``.

        Reseeds the global ``random`` module with ``args.seed``, then keeps up to
        ``args.shot_num`` indices per label (16 when ``args`` has no ``shot_num``).

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")
        shots = getattr(self.args, 'shot_num', 16)
        selected_indices = self._sample_few_shot_indices(all_labels, list(idxs), shots)
        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def infer_client_domain(self):
        """Return the single domain index present in ``self.trainloader``.

        Raises:
            ValueError: if the loader yields zero or several distinct domain indices.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def local_train(self, text_encoder):
        """Jointly train the image encoder's and the prompt learner's trainable parameters.

        Runs ``args.local_ep`` epochs. Each side has its own SGD optimizer
        (lr=0.01, weight_decay=0.05, momentum=0.9). The loss is cross-entropy on
        the scaled image/text similarity logits.

        Returns:
            dict with detached clones of the shared leading slices of
            ``transformer.ctx_learner.ctx`` and ``ctx_learner.ctx``.
        """
        scale = self._logit_scale()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            # NOTE(review): both optimizers (and their momentum buffers) are rebuilt
            # at the start of every local epoch; confirm this reset is intended.
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.image_encoder.parameters()), lr=0.01,
                                        weight_decay=0.05, momentum=0.9)
            optimizer1 = torch.optim.SGD(filter(lambda p: p.requires_grad, self.prompt_learner.parameters()), lr=0.01,
                                         weight_decay=0.05, momentum=0.9)

            for images, labels, _domains in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts

                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                image_features = self.image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                optimizer1.zero_grad()
                loss.backward()
                optimizer.step()
                optimizer1.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[pFedDC] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, 100 * correct / total))

        nv, nt = self._SHARED_VISUAL_CTX, self._SHARED_TEXT_CTX
        updated_params = {}
        for name, param in self.image_encoder.named_parameters():
            if name == 'transformer.ctx_learner.ctx':
                updated_params[name] = param[:, :nv, :].clone().detach()
        for name, param in self.prompt_learner.named_parameters():
            if name == 'ctx_learner.ctx':
                updated_params[name] = param[:nt, :].clone().detach()
        return updated_params

    def local_test(self, testloader, text_encoder, user_base_labels):
        """Evaluate this client's own prompt learner and image encoder on ``testloader``.

        Args:
            testloader: iterable of ``(images, labels, domains)`` tensor batches.
            text_encoder: encoder called as ``text_encoder(prompts, tokenized_prompts)``.
            user_base_labels: ``user_base_labels[self.client_index]`` is the
                collection of labels counted toward this client's local accuracy.

        Returns:
            ``(global_acc, domain_acc, local_acc)``: percentages rounded to 2
            decimals. ``local_acc`` is 0.0 when no test sample comes from this
            client's domain with one of its labels.
        """
        client_domain = self.domain_idx
        local_labels = user_base_labels[self.client_index]
        global_correct, global_total = 0, 0
        local_correct, local_total = 0, 0
        domain_correct, domain_total = defaultdict(int), defaultdict(int)

        with torch.no_grad():
            scale = self._logit_scale()

            for images, labels, domains in tqdm(testloader):
                images, labels = images.to(self.device), labels.to(self.device)
                domains = domains.to(self.device)
                prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts

                text_features = text_encoder(prompts, tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                image_features = self.image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                hits = logits.argmax(dim=1) == labels
                global_correct += hits.sum().item()
                global_total += labels.size(0)

                for d in torch.unique(domains).tolist():
                    mask = domains == d
                    domain_correct[d] += hits[mask].sum().item()
                    domain_total[d] += mask.sum().item()

                label_mask = torch.tensor([l.item() in local_labels for l in labels],
                                          dtype=torch.bool, device=self.device)
                local_mask = (domains == client_domain) & label_mask
                local_correct += hits[local_mask].sum().item()
                local_total += local_mask.sum().item()

        print('Global Test Acc: {:.2f}%'.format(100 * global_correct / max(1, global_total)))
        global_accuracys = self._percent(global_correct, global_total)

        domain_names = self._domain_names()
        domain_test_acc = {}
        for d in sorted(domain_total):
            acc = domain_correct[d] / max(1, domain_total[d])
            domain_test_acc[domain_names[d]] = round(acc * 100, 2)

        # Guarded like the per-domain accuracies: 0.0 instead of ZeroDivisionError.
        local_acc = self._percent(local_correct, local_total)

        return global_accuracys, domain_test_acc, local_acc


class LocalUpdate_FedMaPLe_domain(object):
    """Client-side training for FedMaPLe on multi-domain data.

    All trainable prompt-learner parameters are exchanged with the server. The
    training dataset must yield ``(image, label, domain)`` triples, and each
    client is expected to hold exactly one domain (inferred at construction time;
    ``ValueError`` otherwise).
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger):
        self.args = args
        self.testloader_cache = None
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.logger = logger
        # Walks the whole train loader once (loading every sample) to find the domain.
        self.domain_idx = self.infer_client_domain()

    @staticmethod
    def _logit_scale():
        """Return the fixed logit scale ``exp(log(1 / 0.07))`` (~14.29) as a 0-dim tensor without grad."""
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            return logit_scale.exp()

    @staticmethod
    def _sample_few_shot_indices(all_labels, idxs, shots):
        """Group ``idxs`` by ``all_labels[idx]`` and draw up to ``shots`` indices per label.

        Draws from the global ``random`` module, so the result depends on its
        current state. Tensor labels are converted with ``.item()`` first: tensors
        hash by identity, so otherwise every index would form its own group and
        nothing would be subsampled.
        """
        label_to_indices = defaultdict(list)
        for idx in idxs:
            label = all_labels[idx]
            if isinstance(label, torch.Tensor):
                label = label.item()
            label_to_indices[label].append(idx)
        selected = []
        for indices in label_to_indices.values():
            selected.extend(random.sample(indices, min(shots, len(indices))))
        return selected

    def update_model(self, model_params):
        """Copy every trainable prompt-learner parameter from ``model_params``; ``None`` is a no-op."""
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if param.requires_grad:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over every index in ``idxs``."""
        train_dataset = DatasetSplit_domain(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled ``DataLoader`` over up to ``args.shot_num`` indices per label.

        Reseeds the global ``random`` module with ``args.seed`` before sampling.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")
        selected_indices = self._sample_few_shot_indices(all_labels, list(idxs), self.args.shot_num)
        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def infer_client_domain(self):
        """Return the single domain index present in ``self.trainloader``.

        Raises:
            ValueError: if the loader yields zero or several distinct domain indices.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def local_train(self, image_encoder, text_encoder):
        """Train the prompt learner for ``args.local_ep`` epochs with SGD (lr=0.0035).

        ``self.prompt_learner()`` supplies text prompts, a shared context, and deep
        prompts for each encoder. The loss is cross-entropy on the scaled
        image/text similarity logits.

        Returns:
            dict of name -> live ``nn.Parameter`` (not a copy) for every parameter
            of ``self.prompt_learner`` that requires grad.
        """
        scale = self._logit_scale()

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.prompt_learner.parameters()), lr=0.0035)

            for images, labels, _ in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)

                prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
                image_features = image_encoder(images, shared_ctx, deep_compound_prompts_vision)

                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                correct += (logits.argmax(dim=1) == labels).sum().item()
            print('[FedMaPLe] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, 100 * correct / total))

        return {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}


class LocalUpdate_PromptFL_domain(object):
    """PromptFL client trainer for data whose samples carry a domain index.

    Wraps this client's slice of ``train_data`` in a ``DatasetSplit_domain`` loader
    that yields ``(image, label, domain)`` batches. At construction it checks that
    all of the client's samples share exactly one domain index, stored in
    ``self.domain_idx``. ``local_train`` optimises the trainable parameters of
    ``prompt_learner`` with a CLIP-style cross-entropy objective.
    """

    def __init__(self, args, train_data, idxs, client_index, device, prompt_learner, logger):
        """Build the client's training loader and infer its domain.

        Args:
            args: run configuration; this class reads ``shot`` ("all-shot" or
                "few-shot"), ``local_bs``, ``seed``, ``shot_num`` and ``local_ep``.
            train_data: dataset indexed by sample index and returning
                ``(image, label, domain)`` (as consumed by ``DatasetSplit_domain``).
            idxs: iterable of sample indices assigned to this client.
            client_index: identifier used in log messages.
            device: device that batches are moved to during training.
            prompt_learner: module whose ``requires_grad`` parameters are trained.
            logger: stored on the instance; not used by this class's methods.

        Raises:
            ValueError: if ``args.shot`` is neither "all-shot" nor "few-shot", or if
                the client's samples do not span exactly one domain.
            AttributeError: (few-shot only) if ``train_data`` has neither a
                ``targets`` nor a ``labels`` attribute.
        """
        self.args = args
        self.testloader_cache = None
        if self.args.shot == "all-shot":
            self.trainloader = self.get_loader_all(train_data, list(idxs))
        elif self.args.shot == "few-shot":
            self.trainloader = self.get_loader_few(train_data, list(idxs))
        else:
            raise ValueError("Shot must be either few-shot or all-shot")
        self.device = device
        self.client_index = client_index
        self.criterion_CE = nn.CrossEntropyLoss()
        self.prompt_learner = prompt_learner
        self.logger = logger
        # Iterates the whole train loader once (loading every sample).
        self.domain_idx = self.infer_client_domain()

    def update_model(self, model_params):
        """Copy values from ``model_params`` into the trainable prompt parameters.

        Args:
            model_params: mapping from parameter name to a tensor/parameter with a
                compatible shape, or ``None`` to leave the model unchanged. Every
                ``requires_grad`` parameter name must be present (``KeyError``
                otherwise).
        """
        if model_params is None:
            return
        for name, param in self.prompt_learner.named_parameters():
            if param.requires_grad:
                param.data.copy_(model_params[name].data)

    def get_loader_all(self, train_data, idxs):
        """Return a shuffled loader over all of the client's samples."""
        train_dataset = DatasetSplit_domain(train_data, list(idxs))
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    @staticmethod
    def _label_key(label):
        """Return a value-hashable key for ``label``.

        ``torch.Tensor`` hashes by object identity, so 0-d tensor labels (e.g. read
        from a tensor ``targets`` attribute) must be converted to Python scalars
        before they can be used to group samples. Other labels are returned as-is.
        """
        if isinstance(label, torch.Tensor):
            return label.item()
        return label

    def get_loader_few(self, train_data, idxs):
        """Return a shuffled loader with at most ``args.shot_num`` samples per label.

        Labels are read from ``train_data.targets`` (preferred) or
        ``train_data.labels``. Sampling within each label uses the global ``random``
        module after re-seeding it with ``args.seed``. Note that this also resets
        the global RNG for any later caller.

        Raises:
            AttributeError: if ``train_data`` has neither ``targets`` nor ``labels``.
        """
        random.seed(self.args.seed)
        if hasattr(train_data, 'targets'):
            all_labels = train_data.targets
        elif hasattr(train_data, 'labels'):
            all_labels = train_data.labels
        else:
            raise AttributeError("train_data must have 'targets' or 'labels' attribute")

        # Group the client's indices by label value (not by tensor identity).
        label_to_indices = defaultdict(list)
        for idx in list(idxs):
            label_to_indices[self._label_key(all_labels[idx])].append(idx)

        selected_indices = []
        for indices in label_to_indices.values():
            selected_indices.extend(random.sample(indices, min(self.args.shot_num, len(indices))))

        train_dataset = DatasetSplit_domain(train_data, selected_indices)
        return DataLoader(train_dataset, batch_size=self.args.local_bs, shuffle=True)

    def infer_client_domain(self):
        """Return the single domain index present in this client's training data.

        Raises:
            ValueError: if the loader yields zero or more than one distinct domain.
        """
        domain_set = set()
        for _, _, domains in self.trainloader:
            domain_set.update(domains.tolist())
        if len(domain_set) != 1:
            raise ValueError(f"[Client {self.client_index}] expected exactly one domain, but got {domain_set}")
        domain_idx = next(iter(domain_set))
        print(f"client {self.client_index} domain idx: {domain_idx}")
        return int(domain_idx)

    def local_train(self, image_encoder, text_encoder, lr=0.001):
        """Run ``self.args.local_ep`` epochs of SGD on the trainable prompt parameters.

        Args:
            image_encoder: callable mapping a batch of images to image features.
            text_encoder: callable ``(prompts, tokenized_prompts)`` returning text
                features whose rows act as class embeddings (row ``k`` becomes logit
                column ``k``, which ``labels`` index).
            lr: SGD learning rate (default is the previously hard-coded 0.001).

        Returns:
            dict mapping parameter name -> parameter for every parameter of
            ``self.prompt_learner`` with ``requires_grad=True``.
        """
        # Fixed logit scale exp(log(1/0.07)) = 1/0.07 (~14.29), computed with the same
        # float32 expression as before but without an unused nn.Parameter.
        # NOTE(review): pretrained CLIP checkpoints usually carry a learned logit scale
        # (often ~100); confirm this fixed initial value is intended.
        scale = (torch.ones([]) * np.log(1 / 0.07)).exp()

        # Plain SGD keeps no per-parameter state, so one optimizer for all epochs is
        # equivalent to rebuilding it every epoch.
        optimizer = torch.optim.SGD(
            [p for p in self.prompt_learner.parameters() if p.requires_grad], lr=lr
        )

        for _epoch in range(self.args.local_ep):
            total, correct = 0.0, 0.0

            for images, labels, _ in tqdm(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                image_features = image_encoder(images)

                prompts = self.prompt_learner()
                tokenized_prompts = self.prompt_learner.tokenized_prompts
                text_features = text_encoder(prompts, tokenized_prompts)

                # Cosine similarity via L2-normalised features.
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                logits = scale * image_features @ text_features.t()
                loss = self.criterion_CE(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += labels.size(0)
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()

            # Guard against an empty loader instead of raising ZeroDivisionError.
            acc = 100 * correct / total if total else 0.0
            print('[PromptFL] Client:{}, Train ing Acc: {:.2f}%'.format(self.client_index, acc))

        # NOTE(review): these are the live Parameter objects, not copies; if several
        # clients share one prompt_learner, the caller must copy them before the next
        # client trains.
        updated_params = {name: param for name, param in self.prompt_learner.named_parameters() if param.requires_grad}
        return updated_params
