import copy
from tqdm import tqdm
import argparse
import logging
from collections import defaultdict

import numpy as np

from torch.utils.data import DataLoader
from baselines.PromptFL import *
from baselines.clients import LocalUpdate_PromptFL_domain
from FedDTL_utils.utils import get_dataset_domain


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--unseen_flag False`` would silently enable the flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def args_parser():
    """Build the command-line parser for federated PromptFL training and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--shot', type=str, default="few-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--trainer', type=str, default="PromptFL")
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--num_users_per_domain', type=int, default=3)
    parser.add_argument('--shot_num', type=int, default=16)

    parser.add_argument('--dataname', type=str, default='Office_Caltech10', help="name of dataset", choices=['Office_Caltech10', 'DomainNet'])
    parser.add_argument('--IID', type=str, default="IID_domain", choices=["IID", "Dirichlet_domain", "IID_domain"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.1, help="Dir(alpha)")
    parser.add_argument('--temp', type=float, default=0.5)
    parser.add_argument('--mu', type=float, default=1)

    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs")
    parser.add_argument('--local_bs', type=int, default=32, help="local batch size")
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument('--unseen_flag', type=_str2bool, default=False, help="unseen classes")

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Federated PromptFL on domain-partitioned data. Each round, every client
    # trains the shared prompt learner locally (the CLIP encoders stay frozen)
    # and the server takes a sample-weighted average of the returned prompt
    # parameters. After the final round, the global prompt is evaluated on the
    # test set three ways: overall, per client (the client's domain restricted to
    # the classes it saw in training) and per domain. Summary numbers go to the log file.
    args = args_parser()
    device = args.device

    if args.IID == "Dirichlet_domain":
        log_dir = './logs/domain/PromptFL/Dir_multi'
    elif args.IID == "IID_domain":
        log_dir = './logs/domain/PromptFL/IID_multi'
    else:
        log_dir = './logs/domain/PromptFL'

    # exist_ok=True replaces the racy "check exists, then makedirs" pattern.
    os.makedirs(log_dir, exist_ok=True)
    log_filename = f'{args.dataname}_{args.shot}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}--shot: {}--model_type: {}--seed: {}--method: {}".format(args.dataname, args.IID, args.alpha, args.shot_num, args.model_type, args.seed, args.trainer))

    encoder_ini, transform = clip.load(args.model_type, device=device)

    train_dataset, test_dataset, classes, user_groups, _ = get_dataset_domain(args, transform)
    testloader = DataLoader(test_dataset, batch_size=args.local_bs, shuffle=False)

    # The CLIP image and text towers are frozen; only the prompt learner is trained.
    global_image_encoder = encoder_ini.visual.to(args.device)
    for p in global_image_encoder.parameters():
        p.requires_grad = False
    global_text_encoder = text_encoder_PromptFL(encoder_ini.transformer, encoder_ini.positional_embedding,
                                                encoder_ini.ln_final, encoder_ini.text_projection).to(args.device)
    for p in global_text_encoder.parameters():
        p.requires_grad = False

    global_prompt_learner = PromptLearner(classes, encoder_ini, args.device).to(args.device)

    # Per-client label histogram. Row i is the i-th entry of user_groups; column c
    # holds the number of that client's training samples with label c.
    # NOTE(review): later code indexes these rows by client id, which assumes the
    # user_groups keys are 0..K-1 in insertion order. Confirm in get_dataset_domain.
    all_targets = np.array(train_dataset.targets)
    user_label_counts_array = np.zeros((len(user_groups), len(classes)), dtype=int)
    for i, user_indices in enumerate(user_groups.values()):
        unique_labels, counts = np.unique(all_targets[list(user_indices)], return_counts=True)
        user_label_counts_array[i, unique_labels] = counts
    print("User label counts:")
    print(user_label_counts_array)

    if args.IID in ("Dirichlet_domain", "IID_domain"):
        args.num_users = len(user_groups)

    # Classes each client saw during training; these restrict its local test accuracy.
    user_base_labels = {}
    for u in range(args.num_users):
        user_base_labels[u] = np.where(user_label_counts_array[u] > 0)[0].tolist()
    logger.info("User base labels: {}".format(user_base_labels))

    # FedAvg weights: each participating client's share of all participating
    # clients' samples. The previous denominator, len(train_dataset), only gives
    # weights summing to 1 when the clients exactly partition the training set.
    # If the clients hold only a subset of it (or overlap), the aggregated prompt
    # was rescaled every round.
    samples_per_user = user_label_counts_array[:args.num_users].sum(axis=1)
    total_samples = samples_per_user.sum()
    if total_samples == 0:
        raise ValueError("No training samples are assigned to any client; cannot compute FedAvg weights.")
    weight_user = samples_per_user / total_samples

    global_accuracys = []
    local_test_acc = defaultdict(list)
    domain_test_acc = {}

    if args.dataname == "Office_Caltech10":
        domain_names = ['amazon', 'dslr', 'webcam', 'caltech']
    else:
        domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

    clients = [LocalUpdate_PromptFL_domain(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                                           device=device, prompt_learner=copy.deepcopy(global_prompt_learner), logger=logger)
               for idx in range(args.num_users)]

    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        global_weights = {}

        for idx in range(args.num_users):
            local_model = clients[idx]
            # Every client starts the round from its own copy of the current global prompt.
            model_params = {name: param.clone() for name, param in global_prompt_learner.named_parameters() if param.requires_grad}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight = local_model.local_train(image_encoder=global_image_encoder, text_encoder=global_text_encoder)

            # Running weighted sum per parameter name. Only the values are
            # needed, so detach from any autograd history before accumulating.
            client_weight = float(weight_user[idx])
            for name, tensor in local_weight.items():
                weighted = tensor.detach() * client_weight
                if name in global_weights:
                    global_weights[name] += weighted
                else:
                    global_weights[name] = weighted

        for name, p in global_prompt_learner.named_parameters():
            if p.requires_grad:
                p.data.copy_(global_weights[name])

        # ======= Test Dataset for Global Performance (after the final round) =======
        if global_epoch == args.epochs - 1:
            global_image_encoder.eval()
            global_text_encoder.eval()
            with torch.no_grad():
                global_correct, global_total = 0, 0
                domain_correct, domain_total = defaultdict(int), defaultdict(int)
                user_correct, user_total = defaultdict(int), defaultdict(int)

                # Fixed logit scale exp(log(1/0.07)) = 1/0.07. Because it is
                # positive, it rescales the logits without changing the argmax.
                scale = 1.0 / 0.07

                # The prompt learner is called without inputs, so the normalized
                # class text embeddings are computed once rather than per batch.
                # NOTE(review): assumes its forward is deterministic (no dropout); confirm.
                prompts = global_prompt_learner()
                text_features = global_text_encoder(prompts, global_prompt_learner.tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                # Loop-invariant per-client data, kept on the CPU alongside the
                # `domains` batch tensor: each client's domain id, plus a boolean
                # lookup over class ids marking its base labels.
                # (Assumes test labels are class ids in [0, len(classes)).)
                client_domains = [clients[idx].domain_idx for idx in range(args.num_users)]
                base_label_lookup = []
                for idx in range(args.num_users):
                    lookup = torch.zeros(len(classes), dtype=torch.bool)
                    lookup[torch.tensor(user_base_labels[idx], dtype=torch.long)] = True
                    base_label_lookup.append(lookup)

                for images, labels, domains in tqdm(testloader):
                    image_features = global_image_encoder(images.to(device))
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    logits = scale * image_features @ text_features.t()

                    # Per-sample correctness flags, on the CPU so every mask below shares a device.
                    labels = labels.cpu()
                    hits = logits.argmax(dim=1).cpu() == labels
                    global_correct += hits.sum().item()
                    global_total += labels.size(0)

                    for idx in range(args.num_users):
                        # Samples from this client's domain whose label is one of its base labels.
                        local_mask = (domains == client_domains[idx]) & base_label_lookup[idx][labels]
                        n_local = local_mask.sum().item()
                        if n_local:
                            user_correct[idx] += hits[local_mask].sum().item()
                            user_total[idx] += n_local

                    for d in torch.unique(domains).tolist():
                        mask = domains == d
                        domain_correct[d] += hits[mask].sum().item()
                        domain_total[d] += mask.sum().item()

                global_acc = 100 * global_correct / max(1, global_total)
                print('Global Test Acc: {:.2f}%'.format(global_acc))
                global_accuracys = round(global_acc, 2)

                for idx in range(args.num_users):
                    # max(1, ...) mirrors the per-domain computation below. A client
                    # with no matching test samples reports 0 instead of crashing.
                    local_test_acc[idx].append(100 * user_correct[idx] / max(1, user_total[idx]))

                for d in sorted(domain_total.keys()):
                    acc = domain_correct[d] / max(1, domain_total[d])
                    domain_test_acc[domain_names[d]] = round(acc * 100, 2)

    if not local_test_acc:
        # Evaluation never ran (e.g. --epochs 0 or no clients), so there is nothing to report.
        logger.warning("No evaluation was run (epochs=%d, num_users=%d); skipping accuracy report.",
                       args.epochs, args.num_users)
    else:
        local_print_test_accs = [local_test_acc[k][-1] for k in range(args.num_users)]
        logger.info(f"{args.trainer} Local Model Test Accuracy for clients [average] = {round(sum(local_print_test_accs) / len(local_print_test_accs), 2)}")
        logger.info(f"{args.trainer} Global Model Test Accuracy for average clients = {global_accuracys}")

        logger.info("======== Averaged Per-Domain Accuracy Across All Clients ========")
        logger.info(f"{args.trainer}_global_per_domain = ")
        for domain in sorted(domain_test_acc.keys()):
            logger.info(f"  \"{domain}\": [{domain_test_acc[domain]}],")
