import os
import copy
import clip
import json
import argparse
import torch
import torch.nn as nn
import logging
import random
import numpy as np

from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import defaultdict, Counter

from FedDTL_utils.models_new import *
from FedDTL_utils.utils import get_dataset_new, get_dataset_domain, get_dataset_new_dir, get_dataset_path


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string.  This converter interprets the text instead.

    Args:
        value: The raw option value (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got {!r}".format(value))


def args_parser():
    """Define the command-line options of this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (dataset, partitioning scheme,
        CLIP backbone, LoRA-related settings, batch size, ...).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--CLIP_type', type=str, default='ori')
    parser.add_argument('--epochs', type=int, default=1, help="number of rounds of training")
    parser.add_argument('--num_users', type=int, default=5, help="number of users: K")
    parser.add_argument('--dataname', type=str, default='Caltech256', help="name of dataset",
                        choices=['CIFAR10', 'EuroSAT', 'CIFAR100', 'OxfordPet', 'Flower102', 'Food101', 'Tiny_ImageNet',
                                 'Caltech101', 'Caltech256', 'Office_Caltech10', 'DomainNet'])
    parser.add_argument('--IID', type=str, default="Non-IID",
                        choices=["IID", "Non-IID", "Dirichlet", "Dirichlet_domain", "IID_domain"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.5, help="Dir(alpha)")
    parser.add_argument('--num_users_per_domain', type=int, default=3, help="number of users per domain")
    parser.add_argument('--seed', type=int, default=2025, help="random seed")

    parser.add_argument("--text_input", type=str, default="simple")
    parser.add_argument("--r", type=int, default=4, help="number of LoRA Rank")
    parser.add_argument('--balance_layer', type=int, default=10, help="the start layer with lora")
    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users: C')
    parser.add_argument('--local_bs', type=int, default=64, help="local batch size: B")
    # ``nargs='?'`` + ``const=True`` also allows a bare ``--unseen_flag``;
    # ``--unseen_flag True`` / ``--unseen_flag False`` keep working and are now
    # interpreted correctly.
    parser.add_argument('--unseen_flag', type=_str2bool, nargs='?', const=True, default=False,
                        help="unseen classes")

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = args_parser()
    device = args.device

    log_dir = './logs/zero_shot_CLIP'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_filename = f'{args.CLIP_type}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}--model_type: {}--text_input: {}--seed: {}".format(args.dataname, args.IID, args.alpha, args.model_type, args.text_input, args.seed))

    if args.CLIP_type == 'ori':
        encoder_ini, transform = clip.load(args.model_type, device=device)
    else:
        raise ValueError('CLIP type not supported')

    if ((args.dataname == "Office_Caltech10") or (args.dataname == "DomainNet")):
        train_dataset, test_dataset, classes, user_groups, _ = get_dataset_domain(args, transform)

        domain_test_acc_print = defaultdict(list)
        trainloader = DataLoader(train_dataset, batch_size=args.local_bs, shuffle=True)
        testloader = DataLoader(test_dataset, batch_size=args.local_bs, shuffle=False)

        user_label_counts_array = np.zeros((len(user_groups), len(classes)), dtype=int)
        for i, (user_id, user_indices) in enumerate(user_groups.items()):
            user_indices = list(user_indices)
            user_labels = np.array(train_dataset.targets)[user_indices]
            unique_labels, counts = np.unique(user_labels, return_counts=True)
            for label, count in zip(unique_labels, counts):
                user_label_counts_array[i, label] = count

        if ((args.IID == "Dirichlet_domain") or (args.IID == "IID_domain")):
            args.num_users = len(user_groups)

        user_base_labels = {}
        for u in range(args.num_users):
            labels_u = np.where(user_label_counts_array[u] > 0)[0].tolist()
            user_base_labels[u] = labels_u
        logger.info("User base labels: {}".format(user_base_labels))

        if args.dataname == "DomainNet":
            client_domain = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]
        else:
            client_domain = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]

        global_image_encoder = image_encoder_new_0(encoder_ini.visual).to(device)
        global_text_encoder = text_encoder_0(encoder_ini.transformer, encoder_ini.token_embedding,
                                               encoder_ini.positional_embedding, encoder_ini.ln_final,
                                               encoder_ini.text_projection)
        for i, p in global_image_encoder.named_parameters():
            p.requires_grad = False
        for i, p in global_text_encoder.named_parameters():
            p.requires_grad = False

        template = ['a photo of a {}.']
        texts = [template[0].format(classname.replace('_', ' ')) for classname in classes]

        texts = clip.tokenize(texts).to(device)
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            scale = logit_scale.exp()

        global_accuracys, local_accuracys = [], []
        local_correct = defaultdict(int)
        local_total = defaultdict(int)

        if args.dataname == "Office_Caltech10":
            domain_names = ['amazon', 'dslr', 'webcam', 'caltech']
        else:
            domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

        with torch.no_grad():
            global_correct, global_total = 0, 0
            domain_correct = defaultdict(int)
            domain_total = defaultdict(int)

            for images, labels, domains in tqdm(testloader):
                images, labels = images.to(device), labels.to(device)
                _, image_features = global_image_encoder(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                text_features = global_text_encoder(texts)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                logits = scale * image_features @ text_features.t()

                preds = logits.argmax(dim=1)
                global_correct += (preds == labels).sum().item()
                global_total += labels.size(0)

                for d in torch.unique(domains):
                    d = d.item()
                    mask = (domains == d)
                    domain_correct[d] += (preds[mask] == labels[mask]).sum().item()
                    domain_total[d] += mask.sum().item()

                for idx in range(args.num_users):
                    domain_mask = (domains == client_domain[idx])
                    domain_mask = domain_mask.to(device)
                    label_mask = torch.tensor([l.item() in user_base_labels[idx] for l in labels], device=device)
                    local_mask = domain_mask & label_mask

                    if local_mask.any():
                        local_correct[idx] += (preds[local_mask] == labels[local_mask]).sum().item()
                        local_total[idx] += local_mask.sum().item()

            global_acc = 100 * global_correct / global_total
            print('Global Test Acc: {:.2f}%'.format(global_acc))

            domain_acc = {domain_names[d]: round(100 * domain_correct[d] / domain_total[d], 2) for d in domain_total}

            local_print_test_accs = []
            for idx in range(args.num_users):
                local_accuracys.append(round(100 * local_correct[idx] / local_total[idx], 2))
            logger.info(f"Local Model Test Accuracy for clients = {local_accuracys}")
            logger.info(f"Global Model Test Accuracy for average clients = {global_acc}")
            logger.info(f"Local Model Test Accuracy for clients [average] = {round(sum(local_accuracys) / len(local_accuracys), 2)}")

            logger.info("======== Averaged Per-Domain Accuracy Across All Clients ========")
            logger.info(f"global_per_domain = ")
            for domain in sorted(domain_acc.keys()):
                avg_acc = domain_acc[domain]
                logger.info(f"  \"{domain}\": [{avg_acc}],")

    else:
        if (args.IID == "IID"):
            train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, user_base_labels = get_dataset_new(args, transform)
        elif (args.IID == "Non-IID"):
            train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, user_base_labels = get_dataset_path(args, transform)
        elif (args.IID == "Dirichlet"):
            train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, _ = get_dataset_new_dir(args, transform)
        else:
            raise ValueError("IID of IID must be either IID or Non-IID")
        testloader_base = DataLoader(test_dataset_base, batch_size=args.local_bs, shuffle=False)
        testloader_new = DataLoader(test_dataset_new, batch_size=args.local_bs, shuffle=False)

        user_label_counts_array = np.zeros((len(user_groups), len(train_classes)), dtype=int)
        for i, (user_id, user_indices) in enumerate(user_groups.items()):
            user_indices = list(user_indices)
            user_labels = np.array(train_dataset.targets)[user_indices]
            unique_labels, counts = np.unique(user_labels, return_counts=True)
            for label, count in zip(unique_labels, counts):
                user_label_counts_array[i, label] = count
        print("User label counts:")
        print(user_label_counts_array)

        if args.IID == "Dirichlet":
            user_base_labels = {}
            for u in range(args.num_users):
                labels_u = np.where(user_label_counts_array[u] > 0)[0].tolist()
                user_base_labels[u] = labels_u

        global_image_encoder = image_encoder_new_0(encoder_ini.visual).to(device)
        global_text_encoder = text_encoder_0(encoder_ini.transformer, encoder_ini.token_embedding,
                                               encoder_ini.positional_embedding, encoder_ini.ln_final,
                                               encoder_ini.text_projection)
        for i, p in global_image_encoder.named_parameters():
            p.requires_grad = False
        for i, p in global_text_encoder.named_parameters():
            p.requires_grad = False

        template = ['a photo of a {}.']
        train_texts = [template[0].format(classname.replace('_', ' ')) for classname in train_classes]
        test_texts = [template[0].format(classname.replace('_', ' ')) for classname in test_classes]

        train_texts = clip.tokenize(train_texts).to(device)
        test_texts = clip.tokenize(test_texts).to(device)
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            scale = logit_scale.exp()

        global_image_encoder.eval()
        global_text_encoder.eval()

        with torch.no_grad():
            user_correct, user_total = defaultdict(int), defaultdict(int)
            test_total, test_correct = 0.0, 0.0
            text_features = global_text_encoder(train_texts)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            for images, labels in tqdm(testloader_base):
                images, labels = images.to(device), labels.to(device)
                _, img_proj = global_image_encoder(images)
                img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
                logits = scale * img_proj @ text_features.t()

                preds = logits.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)

                for idx in range(args.num_users):
                    label_tensor = torch.tensor(user_base_labels[idx], device=device)
                    mask = torch.isin(labels, label_tensor)
                    if mask.any():
                        user_correct[idx] += (preds[mask] == labels[mask]).sum().item()
                        user_total[idx] += mask.sum().item()

            test_acc = test_correct / test_total
            print('Global Base Test Acc: {:.2f}%'.format(100 * test_correct / test_total))
            print('user_correct: {}'.format(user_correct))
            print('user_total: {}'.format(user_total))
            logger.info('[Base] Global Model Test Accuracy: {}'.format(round(100 * test_correct / test_total, 2)))

            local_acc_sum = 0
            for idx in range(args.num_users):
                print('[Local] Global Model Test Accuracy for Client {}: {}'.format(idx, round(100 * user_correct[idx] / user_total[idx], 2)))
                local_acc_sum += user_correct[idx] / user_total[idx]
                logger.info('[Local] Global Model Test Accuracy for Client {}: {}'.format(idx, round(100 * user_correct[idx] / user_total[idx], 2)))

            test_total1, test_correct1 = 0.0, 0.0
            text_features = global_text_encoder(test_texts)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            for images, labels in tqdm(testloader_new):
                images, labels = images.to(device), labels.to(device)
                _, img_proj = global_image_encoder(images)
                img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
                logits = scale * img_proj @ text_features.t()

                preds = logits.argmax(dim=1) + int(len(train_classes))
                test_correct1 += (preds == labels).sum().item()
                test_total1 += labels.size(0)
            print('Global New Test Acc: {:.2f}%'.format(100 * test_correct1 / test_total1))
            logger.info('[New] Global Model Test Accuracy: {}'.format(round(100 * test_correct1 / test_total1, 2)))

            logger.info('[Local] Local Model Test Accuracy for all clients: {}'.format(round(100 * local_acc_sum / args.num_users, 2)))
