import os
import copy
import argparse
import logging
from collections import defaultdict

import numpy as np

from torch.utils.data import DataLoader
from baselines.FedPGP import *
from baselines.clients import LocalUpdate_FedPGP
from FedDTL_utils.utils import get_dataset_new, get_dataset_new_dir, get_dataset_path


def args_parser(argv=None):
    """Build the command-line interface for the FedPGP experiment and parse it.

    Args:
        argv: optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers behave exactly as before.

    Returns:
        argparse.Namespace holding all experiment hyper-parameters.
    """
    def _str2bool(value):
        # ``type=bool`` is a classic argparse trap: bool("False") is True because
        # any non-empty string is truthy. Parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--shot', type=str, default="few-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--trainer', type=str, default="FedPGP")
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--shot_num', type=int, default=16)

    parser.add_argument('--dataname', type=str, default='Food101', help="name of dataset", choices=['CIFAR10', 'EuroSAT', 'CIFAR100', 'OxfordPet', 'Flower102', 'Food101', 'Caltech101', 'Caltech256', 'Tiny_ImageNet'])
    parser.add_argument('--IID', type=str, default="IID", choices=["IID", "Dirichlet", "Non-IID"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.1, help="Dir(alpha)")
    parser.add_argument('--temp', type=float, default=0.5)
    # Float default so ``args.mu`` has the declared type even when not passed
    # (argparse does not run ``type`` on non-string defaults).
    parser.add_argument('--mu', type=float, default=1.0)

    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs")
    parser.add_argument('--local_bs', type=int, default=32, help="local batch size")
    # Accepts "--unseen_flag True/False" (as before, but now parsed correctly)
    # and a bare "--unseen_flag", which means True.
    parser.add_argument('--unseen_flag', type=_str2bool, nargs='?', const=True, default=False, help="unseen classes")

    return parser.parse_args(argv)


def label_count_matrix(user_groups, targets, num_users, num_classes):
    """Count how many training samples of each class every client holds.

    Args:
        user_groups: mapping ``client_id -> iterable of sample indices``. Keys are
            looked up as ``0 .. num_users - 1``, the same way the training loop
            indexes ``user_groups`` when building the clients.
        targets: per-sample class labels of the training set, indexable by the
            sample indices stored in ``user_groups``.
        num_users: number of clients (rows of the result).
        num_classes: number of classes (columns of the result).

    Returns:
        ``np.ndarray`` of shape ``(num_users, num_classes)`` and dtype ``int``
        whose row ``u`` is client ``u``'s label histogram.
    """
    targets = np.asarray(targets)  # converted once, not once per client
    counts = np.zeros((num_users, num_classes), dtype=int)
    for uid in range(num_users):
        idxs = np.asarray(list(user_groups[uid]), dtype=int)
        labels, freq = np.unique(targets[idxs], return_counts=True)
        counts[uid, labels.astype(int)] = freq
    return counts


def aggregation_weights(label_counts):
    """Return per-client FedAvg weights proportional to local sample counts.

    The weights are normalised by the total number of samples held by all
    clients, so they sum to 1 even when the clients only cover part of the
    training set (e.g. few-shot sampling). If no client holds any sample,
    uniform weights are returned.
    """
    sizes = np.asarray(label_counts, dtype=float).sum(axis=1)
    total = sizes.sum()
    if total <= 0:
        return np.full(sizes.shape, 1.0 / max(len(sizes), 1))
    return sizes / total


def harmonic_means(base_accs, new_accs, ndigits=2):
    """Pairwise harmonic mean of base/new accuracies, rounded to ``ndigits``.

    Returns 0.0 for a pair whose sum is 0 instead of dividing by zero.
    """
    result = []
    for a_i, b_i in zip(base_accs, new_accs):
        denom = a_i + b_i
        result.append(round(2 * a_i * b_i / denom, ndigits) if denom else 0.0)
    return result


if __name__ == '__main__':
    args = args_parser()
    device = args.device

    # ----- logging -----
    log_dir = './logs/FedPGP'
    os.makedirs(log_dir, exist_ok=True)
    log_filename = f'{args.shot}_{args.IID}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}--shot_num: {}--model_type: {}--seed: {}".format(args.dataname, args.IID, args.alpha, args.shot_num, args.model_type, args.seed))

    # ----- backbone -----
    design_details = {"trainer": 'FedPGP',
                      "vision_depth": 0,
                      "language_depth": 0, "vision_ctx": 0,
                      "language_ctx": 0}
    encoder_ini, transform = clip.load(args.model_type, device=device, design_details=design_details)

    # ----- data partition -----
    if args.IID == "IID":
        load_partition = get_dataset_new
    elif args.IID == "Non-IID":
        load_partition = get_dataset_path
    elif args.IID == "Dirichlet":
        load_partition = get_dataset_new_dir
    else:
        raise ValueError(f"IID must be one of 'IID', 'Non-IID' or 'Dirichlet', got {args.IID!r}")
    (train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes,
     user_groups, user_base_labels) = load_partition(args, transform)
    testloader_base = DataLoader(test_dataset_base, batch_size=args.local_bs, shuffle=False)
    testloader_new = DataLoader(test_dataset_new, batch_size=args.local_bs, shuffle=False)

    # ----- frozen CLIP encoders -----
    global_image_encoder = encoder_ini.visual.to(args.device)
    global_image_encoder.requires_grad_(False)
    global_text_encoder = text_encoder_FedPGP(encoder_ini.transformer, encoder_ini.positional_embedding,
                                              encoder_ini.ln_final, encoder_ini.text_projection).to(args.device)
    global_text_encoder.requires_grad_(False)

    # Prompt learners for the base (train) classes and for the new (test) classes.
    global_prompt_learner = PromptLearner(train_classes, encoder_ini, args.device).to(args.device)
    new_prompt_learner = PromptLearner(test_classes, encoder_ini, args.device).to(args.device)

    # ----- client statistics -----
    user_label_counts_array = label_count_matrix(user_groups, train_dataset.targets,
                                                 args.num_users, len(train_classes))
    print("User label counts:")
    print(user_label_counts_array)

    if args.IID == "Dirichlet":
        # The base-label mapping returned by get_dataset_new_dir is not used;
        # each client's base labels are the classes it actually holds samples of.
        user_base_labels = {u: np.flatnonzero(user_label_counts_array[u]).tolist()
                            for u in range(args.num_users)}
    logger.info("User base labels: {}".format(user_base_labels))

    weight_user = aggregation_weights(user_label_counts_array)
    global_accuracys_base, global_accuracys_new = [], []
    local_test_acc = defaultdict(list)

    clients = [LocalUpdate_FedPGP(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                                  device=device, prompt_learner=copy.deepcopy(global_prompt_learner),
                                  logger=logger, user_base_labels=user_base_labels)
               for idx in range(args.num_users)]

    # ----- federated rounds -----
    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        aggregated_sigma = None
        global_accuracys_base_per, global_accuracys_new_per = [], []

        for idx in range(args.num_users):
            local_model = clients[idx]
            # Hand the client a clone of the current global 'sigma' parameter.
            model_params = {name: param.clone() for name, param in global_prompt_learner.named_parameters()
                            if name == 'sigma'}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight = local_model.local_train(image_encoder=global_image_encoder, text_encoder=global_text_encoder)

            # Weighted average of 'sigma' only; the weights sum to 1.
            contribution = local_weight['sigma'].detach().clone() * float(weight_user[idx])
            aggregated_sigma = contribution if aggregated_sigma is None else aggregated_sigma + contribution

        with torch.no_grad():
            if aggregated_sigma is not None:
                global_prompt_learner.sigma.copy_(aggregated_sigma)
            # ======= Test Dataset for Global Performance =======
            new_prompt_learner.sigma.copy_(global_prompt_learner.sigma)

        # Evaluate only after the final round.
        if global_epoch + 1 == args.epochs:
            global_image_encoder.eval()
            global_text_encoder.eval()
            for idx in range(args.num_users):
                global_accuracys_base_i, global_accuracys_new_i, local_test_accs = clients[idx].local_test(idx, testloader_base, testloader_new, global_image_encoder, global_text_encoder, train_classes, global_prompt_learner, new_prompt_learner)
                global_accuracys_base_per.append(global_accuracys_base_i)
                global_accuracys_new_per.append(global_accuracys_new_i)
                local_test_acc[idx].append(round(local_test_accs, 2))
            # Unweighted mean over clients.
            global_accuracys_base.append(sum(global_accuracys_base_per) / args.num_users)
            global_accuracys_new.append(sum(global_accuracys_new_per) / args.num_users)
            logger.info('[Base] Personal Model Test Accuracy: {}'.format(global_accuracys_base_per))
            logger.info('[New] Personal Model Test Accuracy: {}'.format(global_accuracys_new_per))

    # ----- final report -----
    for idx in range(args.num_users):
        logger.info('[Local] Global Model Test Accuracy for client {}: {}'.format(idx, local_test_acc[idx]))
    logger.info('[Base] Global Model Test Accuracy: {}'.format(global_accuracys_base))
    logger.info('[New] Global Model Test Accuracy: {}'.format(global_accuracys_new))

    HM = harmonic_means(global_accuracys_base, global_accuracys_new)
    a_new = [round(x, 2) for x in global_accuracys_base]
    logger.info(f"[Base] Global Model Test Accuracy: {a_new}")
    b_new = [round(x, 2) for x in global_accuracys_new]
    logger.info(f"[New] Global Model Test Accuracy:  {b_new}")
    logger.info(f"[HM] Global Model Test Accuracy:   {HM}")
