import os
import copy
import argparse
import logging
from collections import defaultdict
from tqdm import tqdm

import numpy as np

from torch.utils.data import DataLoader
from baselines.PromptFL import *
from baselines.clients import LocalUpdate_PromptFL
from FedDTL_utils.utils import get_dataset_new, get_dataset_new_dir, get_dataset_path


def args_parser(argv=None):
    """Parse the command-line options for a PromptFL federated training run.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, which is how the script itself calls it.

    Returns:
        argparse.Namespace holding the parsed options.
    """

    def str2bool(value):
        """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool."""
        # argparse's ``type=bool`` calls bool("False") -> True, so any explicit
        # non-empty value used to switch the flag on. Parse the spellings instead.
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    # ---- federated run configuration ----
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--shot', type=str, default="few-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--trainer', type=str, default="PromptFL")
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--shot_num', type=int, default=16)

    # ---- dataset and client partitioning ----
    parser.add_argument('--dataname', type=str, default='Food101', help="name of dataset", choices=['CIFAR10', 'EuroSAT', 'CIFAR100', 'OxfordPet', 'Flower102', 'Food101', 'Caltech101', 'Caltech256', 'Tiny_ImageNet'])
    parser.add_argument('--IID', type=str, default="IID", choices=["IID", "Dirichlet", "Non-IID"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.1, help="Dir(alpha)")
    parser.add_argument('--temp', type=float, default=0.5)
    # Float default so args.mu is a float even when the flag is omitted
    # (argparse does not run ``type`` on non-string defaults).
    parser.add_argument('--mu', type=float, default=1.0)

    # ---- model and local training ----
    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_ep', type=int, default=2, help="the number of local epochs")
    parser.add_argument('--local_bs', type=int, default=32, help="local batch size")
    parser.add_argument('--unseen_flag', type=str2bool, default=False, help="unseen classes")

    return parser.parse_args(argv)


def _aggregation_weights(label_counts):
    """Return FedAvg weights proportional to each client's number of training samples.

    Args:
        label_counts: 2-D array-like with one row per client and one column per
            class, holding how many training samples of that class the client owns.

    Returns:
        1-D float ``np.ndarray`` of per-client weights summing to 1.

    Raises:
        ValueError: if the clients hold no samples at all.
    """
    # Normalise by the samples the clients actually own, not by the size of the
    # whole training set: if the clients' indices do not cover the full dataset
    # (e.g. a few-shot subsample), the latter makes the weights sum to < 1 and
    # scales the aggregated prompt down every round.
    sample_counts = np.asarray(label_counts, dtype=float).sum(axis=1)
    total = sample_counts.sum()
    if total <= 0:
        raise ValueError("clients hold no training samples; cannot compute aggregation weights")
    return sample_counts / total


def _percent(correct, total):
    """Return ``100 * correct / total``, or NaN when ``total`` is 0 (nothing was scored)."""
    return 100.0 * correct / total if total else float('nan')


def _harmonic_mean(a, b):
    """Return the harmonic mean of two accuracies; 0.0 when both are 0."""
    if a + b == 0:
        return 0.0
    return 2 * a * b / (a + b)


if __name__ == '__main__':
    args = args_parser()
    device = args.device

    log_dir = './logs/PromptFL_new'
    os.makedirs(log_dir, exist_ok=True)
    log_filename = f'{args.shot}_{args.IID}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}--shot_num: {}--model_type: {}--seed: {}".format(args.dataname, args.IID, args.alpha, args.shot_num, args.model_type, args.seed))

    encoder_ini, transform = clip.load(args.model_type, device=device)

    # Build the datasets and the client partition requested by --IID.
    if args.IID == "IID":
        (train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes,
         user_groups, user_base_labels) = get_dataset_new(args, transform)
    elif args.IID == "Non-IID":
        (train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes,
         user_groups, user_base_labels) = get_dataset_path(args, transform)
    elif args.IID == "Dirichlet":
        # The loader's per-user label list is discarded; it is rebuilt below
        # from the label counts each client actually holds.
        (train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes,
         user_groups, _) = get_dataset_new_dir(args, transform)
    else:
        raise ValueError("IID must be one of 'IID', 'Non-IID' or 'Dirichlet', got {!r}".format(args.IID))
    testloader_base = DataLoader(test_dataset_base, batch_size=args.local_bs, shuffle=False)
    testloader_new = DataLoader(test_dataset_new, batch_size=args.local_bs, shuffle=False)

    # CLIP image/text encoders are frozen; aggregation below only touches the
    # prompt learner's trainable parameters.
    global_image_encoder = encoder_ini.visual.to(device)
    for p in global_image_encoder.parameters():
        p.requires_grad = False
    global_text_encoder = text_encoder_PromptFL(encoder_ini.transformer, encoder_ini.positional_embedding,
                                                encoder_ini.ln_final, encoder_ini.text_projection).to(device)
    for p in global_text_encoder.parameters():
        p.requires_grad = False

    global_prompt_learner = PromptLearner(train_classes, encoder_ini, device).to(device)
    new_prompt_learner = PromptLearner(test_classes, encoder_ini, device).to(device)

    # Per-client, per-class training-sample counts (rows follow user_groups' order).
    train_targets = np.array(train_dataset.targets)
    user_label_counts_array = np.zeros((len(user_groups), len(train_classes)), dtype=int)
    for row, user_indices in enumerate(user_groups.values()):
        unique_labels, counts = np.unique(train_targets[list(user_indices)], return_counts=True)
        user_label_counts_array[row, unique_labels] = counts
    print("User label counts:")
    print(user_label_counts_array)

    if args.IID == "Dirichlet":
        # A client's base labels are the classes it holds at least one sample of.
        user_base_labels = {u: np.where(user_label_counts_array[u] > 0)[0].tolist()
                            for u in range(args.num_users)}
    logger.info("User base labels: {}".format(user_base_labels))

    weight_user = _aggregation_weights(user_label_counts_array)
    # NaN until the final-round evaluation runs (e.g. stays NaN when --epochs 0).
    global_accuracys_base, global_accuracys_new = float('nan'), float('nan')
    local_test_acc = defaultdict(list)

    clients = [LocalUpdate_PromptFL(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                                    device=device, prompt_learner=copy.deepcopy(global_prompt_learner),
                                    logger=logger, user_base_labels=user_base_labels)
               for idx in range(args.num_users)]

    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        global_weights = {}

        for idx, local_model in enumerate(clients):
            # Broadcast the current global prompt parameters to this client.
            model_params = {name: param.clone() for name, param in global_prompt_learner.named_parameters()
                            if param.requires_grad}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight = local_model.local_train(image_encoder=global_image_encoder, text_encoder=global_text_encoder)

            # Weighted FedAvg accumulation of the client's returned parameters.
            client_weight = float(weight_user[idx])
            for name, tensor in local_weight.items():
                contribution = tensor.detach() * client_weight
                if name in global_weights:
                    global_weights[name] += contribution
                else:
                    global_weights[name] = contribution

        for name, param in global_prompt_learner.named_parameters():
            if param.requires_grad:
                param.data.copy_(global_weights[name])

        # ======= Test Dataset for Global Performance (final round only) =======
        if global_epoch == args.epochs - 1:
            with torch.no_grad():
                # Share the aggregated prompt parameters with the learner built over test_classes.
                global_params = dict(global_prompt_learner.named_parameters())
                for name, param in new_prompt_learner.named_parameters():
                    if param.requires_grad and name in global_params:
                        param.copy_(global_params[name])

                global_image_encoder.eval()
                global_text_encoder.eval()

                # Fixed CLIP temperature; argmax predictions are unaffected by it.
                scale = 1 / 0.07

                # Text features depend only on the prompts, not on the batch: encode once.
                text_features = global_text_encoder(global_prompt_learner(), global_prompt_learner.tokenized_prompts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                user_label_tensors = {idx: torch.tensor(user_base_labels[idx], device=device)
                                      for idx in range(args.num_users)}

                user_correct, user_total = defaultdict(int), defaultdict(int)
                test_total, test_correct = 0, 0
                for images, labels in tqdm(testloader_base):
                    images, labels = images.to(device), labels.to(device)

                    image_features = global_image_encoder(images)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    logits = scale * image_features @ text_features.t()

                    hits = logits.argmax(dim=1) == labels
                    test_correct += hits.sum().item()
                    test_total += labels.size(0)

                    # Per-client accuracy over test samples whose label is one of the client's base labels.
                    for idx, label_tensor in user_label_tensors.items():
                        mask = torch.isin(labels, label_tensor)
                        if mask.any():
                            user_correct[idx] += hits[mask].sum().item()
                            user_total[idx] += mask.sum().item()

                global_accuracys_base = _percent(test_correct, test_total)
                print('Global Base Test Acc: {:.2f}%'.format(global_accuracys_base))
                for idx in range(args.num_users):
                    local_test_acc[idx].append(_percent(user_correct[idx], user_total[idx]))

                text_features_new = global_text_encoder(new_prompt_learner(), new_prompt_learner.tokenized_prompts)
                text_features_new = text_features_new / text_features_new.norm(dim=-1, keepdim=True)
                # Predictions index test_classes from 0; shifting by len(train_classes)
                # assumes unseen-class labels are numbered after the base classes
                # (NOTE(review): confirm against the dataset loaders).
                label_offset = int(len(train_classes))

                test_total1, test_correct1 = 0, 0
                for images, labels in tqdm(testloader_new):
                    images, labels = images.to(device), labels.to(device)

                    image_features = global_image_encoder(images)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    logits = scale * image_features @ text_features_new.t()

                    preds = logits.argmax(dim=1) + label_offset
                    test_correct1 += (preds == labels).sum().item()
                    test_total1 += labels.size(0)

                global_accuracys_new = _percent(test_correct1, test_total1)
                print('Global New Test Acc: {:.2f}%'.format(global_accuracys_new))

    for idx in range(args.num_users):
        logger.info('[Local] Global Model Test Accuracy for client {}: {}'.format(idx, local_test_acc[idx]))
    logger.info('[Base] Global Model Test Accuracy: {}'.format(round(global_accuracys_base, 2)))
    logger.info('[New] Global Model Test Accuracy: {}'.format(round(global_accuracys_new, 2)))
    HM = _harmonic_mean(global_accuracys_base, global_accuracys_new)
    logger.info('[HM] Global Model Test Accuracy: {}'.format(round(HM, 2)))
