import os
import copy
import argparse
import logging
from collections import defaultdict

import numpy as np

from torch.utils.data import DataLoader
from baselines.FedPGP import *
from baselines.clients import LocalUpdate_FedPGP_domain
from FedDTL_utils.utils import get_dataset_domain


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options need an
    explicit parser.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def args_parser(argv=None):
    """Build and parse the command-line options for the FedPGP domain experiment.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training 25")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--shot', type=str, default="all-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--trainer', type=str, default="FedPGP")
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--num_users_per_domain', type=int, default=3)

    parser.add_argument('--dataname', type=str, default='Office_Caltech10', help="name of dataset", choices=['Office_Caltech10', 'DomainNet'])
    parser.add_argument('--IID', type=str, default="Dirichlet_domain", choices=["IID", "Dirichlet_domain", "IID_domain"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.1, help="Dir(alpha)")
    parser.add_argument('--temp', type=float, default=0.5)
    # Float default so args.mu has the same type whether or not it is passed on the CLI.
    parser.add_argument('--mu', type=float, default=1.0)

    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs")
    parser.add_argument('--local_bs', type=int, default=32, help="local batch size")
    # type=bool would turn "--unseen_flag False" into True; parse the text explicitly.
    # A bare "--unseen_flag" (no value) means True.
    parser.add_argument('--unseen_flag', type=_str2bool, nargs='?', const=True, default=False, help="unseen classes")

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse options, set up file logging, build frozen CLIP
    # encoders and per-client FedPGP prompt learners, run federated rounds in
    # which only the prompt learner's 'sigma' parameter is aggregated, then
    # evaluate every client after the last round and log per-client and
    # per-domain accuracies.
    args = args_parser()
    device = args.device

    # The log directory depends on the data-partition scheme.
    if args.IID == "Dirichlet_domain":
        log_dir = './logs/domain/FedPGP/Dir_multi'
    elif args.IID == "IID_domain":
        log_dir = './logs/domain/FedPGP/IID_multi'
    else:
        log_dir = './logs/domain/FedPGP'

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    log_filename = f'{args.dataname}_{args.shot}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}--shot: {}--model_type: {}--seed: {}--method: {}".format(args.dataname, args.IID, args.alpha, args.shot, args.model_type, args.seed, args.trainer))

    design_details = {"trainer": 'FedPGP',
                      "vision_depth": 0,
                      "language_depth": 0, "vision_ctx": 0,
                      "language_ctx": 0}
    encoder_ini, transform = clip.load(args.model_type, device=device, design_details=design_details)

    train_dataset, test_dataset, classes, user_groups, _ = get_dataset_domain(args, transform)
    testloader = DataLoader(test_dataset, batch_size=args.local_bs, shuffle=False)

    # Both CLIP encoders are frozen; only the prompt learner's parameters are trained.
    global_image_encoder = encoder_ini.visual.to(args.device)
    for p in global_image_encoder.parameters():
        p.requires_grad = False
    global_text_encoder = text_encoder_FedPGP(encoder_ini.transformer, encoder_ini.positional_embedding,
                                              encoder_ini.ln_final, encoder_ini.text_projection).to(args.device)
    for p in global_text_encoder.parameters():
        p.requires_grad = False

    global_prompt_learner = PromptLearner(classes, encoder_ini, args.device).to(args.device)

    # Per-client label histogram: row i is the i-th entry of user_groups,
    # column c counts training samples of class c held by that client.
    # The targets are converted to an array once, not once per client.
    all_targets = np.asarray(train_dataset.targets)
    user_label_counts_array = np.zeros((len(user_groups), len(classes)), dtype=int)
    for i, user_indices in enumerate(user_groups.values()):
        user_labels = all_targets[list(user_indices)]
        unique_labels, counts = np.unique(user_labels, return_counts=True)
        for label, count in zip(unique_labels, counts):
            user_label_counts_array[i, label] = count
    print("User label counts:")
    print(user_label_counts_array)

    if args.IID in ("Dirichlet_domain", "IID_domain"):
        # Domain partitions decide the number of clients themselves.
        args.num_users = len(user_groups)

    # Classes present in each client's local training split.
    # NOTE(review): assumes row u of user_label_counts_array is client u, i.e.
    # user_groups iterates its keys as 0..num_users-1 - confirm in get_dataset_domain.
    user_base_labels = {u: np.where(user_label_counts_array[u] > 0)[0].tolist()
                        for u in range(args.num_users)}
    logger.info("User base labels: {}".format(user_base_labels))

    # FedAvg aggregation weights: each participating client's share of the
    # samples held by all participating clients. Dividing by the clients' own
    # total (instead of len(train_dataset)) keeps the weights summing to 1 even
    # when the partition does not cover every training sample; otherwise the
    # aggregated 'sigma' would be scaled down every round.
    client_sizes = user_label_counts_array[:args.num_users].sum(axis=1)
    total_client_samples = client_sizes.sum()
    if total_client_samples == 0:
        raise ValueError("user_groups assign no training samples to the participating clients")
    weight_user = client_sizes / total_client_samples

    global_accuracys = []
    domain_accs = defaultdict(dict)
    local_accs = defaultdict(dict)

    clients = [LocalUpdate_FedPGP_domain(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                                         device=device, prompt_learner=copy.deepcopy(global_prompt_learner), logger=logger)
               for idx in range(args.num_users)]

    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        global_weights = {}
        global_accuracys_per = []

        for idx, local_model in enumerate(clients):
            # Send the current global 'sigma' to the client (a fresh clone per client).
            model_params = {name: param.clone() for name, param in global_prompt_learner.named_parameters() if name == 'sigma'}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight = local_model.local_train(image_encoder=global_image_encoder, text_encoder=global_text_encoder)

            # Weighted running sum of the clients' 'sigma' (FedAvg aggregation).
            weighted_sigma = local_weight['sigma'].clone().data * weight_user[idx]
            if 'sigma' not in global_weights:
                global_weights['sigma'] = weighted_sigma
            else:
                global_weights['sigma'].data += weighted_sigma

        # Install the aggregated 'sigma' into the global prompt learner.
        for name, p in global_prompt_learner.named_parameters():
            if name == 'sigma':
                p.data.copy_(global_weights[name].data)

        # Evaluate only after the final round.
        if global_epoch == args.epochs - 1:
            global_image_encoder.eval()
            global_text_encoder.eval()
            for idx, client in enumerate(clients):
                global_accuracys_i, domain_test_acc, local_acc = client.local_test(testloader, global_image_encoder, global_text_encoder, global_prompt_learner, user_base_labels)
                global_accuracys_per.append(global_accuracys_i)
                domain_accs[idx] = domain_test_acc
                local_accs[idx] = local_acc
            mean_global_acc = sum(global_accuracys_per) / args.num_users
            global_accuracys.append(round(mean_global_acc, 2))
            logger.info(f"{args.trainer} Global Test Accuracy for all clients = {global_accuracys_per}")

    local_print_test_accs = list(local_accs.values())
    logger.info(f"{args.trainer} Local Model Test Accuracy for clients = {local_print_test_accs}")
    logger.info(f"{args.trainer} Global Model Test Accuracy for average clients = {global_accuracys}")
    # With --epochs 0 no evaluation runs and the list is empty; avoid dividing by zero.
    if local_print_test_accs:
        logger.info(f"{args.trainer} Local Model Test Accuracy for clients [average] = {round(sum(local_print_test_accs) / len(local_print_test_accs), 2)}")

    # Collect each domain's accuracy from every client, then average per domain.
    domain_avg_acc = defaultdict(list)
    for per_domain_acc in domain_accs.values():
        for domain, acc in per_domain_acc.items():
            domain_avg_acc[domain].append(acc)

    logger.info("======== Averaged Per-Domain Accuracy Across All Clients ========")
    domain_avg_final = {}
    logger.info(f"{args.trainer}_global_per_domain = ")
    for domain in sorted(domain_avg_acc):
        avg_acc = round(float(np.mean(domain_avg_acc[domain])), 2)
        domain_avg_final[domain] = avg_acc
        logger.info(f"  \"{domain}\": [{avg_acc}],")
