import os
import copy
import clip
import json
import yaml
import argparse
import torch
import logging
import random
import numpy as np

from collections import Counter
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import defaultdict
from new_loraclip import load

from FedDTL_utils.models_new import *
from FedDTL_utils.clients import LocalUpdate_domain
from FedDTL_utils.utils import get_dataset_domain, SFTConvergenceMonitor


def args_parser():
    """Build the command-line parser for the training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options (federated setup, RL/GRPO
        hyper-parameters, and LoRA / local-training settings).
    """

    def _str2bool(value):
        """Convert a command-line string into a real boolean.

        ``type=bool`` must not be used with argparse: it calls ``bool(str)``,
        which is True for every non-empty string, so ``--unseen_flag False``
        would silently enable the flag.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept falsy, as it was under ``bool("")``.
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    # ---- Federated / dataset setup ----
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training")
    parser.add_argument('--data_num', type=str, default="few-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--num_users', type=int, default=6, help="number of users: K")
    parser.add_argument('--dataname', type=str, default='Office_Caltech10', help="name of dataset", choices=['Office_Caltech10', 'DomainNet'])
    parser.add_argument('--IID', type=str, default="IID", choices=["IID", "Dirichlet_domain", "IID_domain"])
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument('--alpha', type=float, default=0.5, help="Dir(alpha)")
    parser.add_argument('--num_users_per_domain', type=int, default=3)
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--shot_num', type=int, default=16)

    # ---- Training / sampling style switches ----
    parser.add_argument('--RL_style', type=str, default="GRPO")
    parser.add_argument('--load_style', type=str, default="mix", choices=["latest", "final_SFT", "mix"])
    parser.add_argument('--IE_style', type=str)
    parser.add_argument('--sample_style', type=str)
    parser.add_argument('--random_label_ratio', type=float, default=1.0)
    parser.add_argument('--split_style', type=str, default="split")
    parser.add_argument('--train_style', type=str, default="IE_TE")

    # ---- GRPO hyper-parameters ----
    parser.add_argument('--clip_eps', type=float, default=0.2)
    parser.add_argument('--kl_coef', type=float, default=0.5)
    parser.add_argument('--noise_std', type=float, default=0.1)
    parser.add_argument('--GRPO_epochs', type=int, default=3)
    parser.add_argument('--GRPO_sampling_epochs', type=int, default=3)

    # ---- Model / LoRA / local-training settings ----
    parser.add_argument("--proj_lora", type=int, default=4)
    parser.add_argument('--local_ep', type=int, default=2)
    parser.add_argument("--r", type=int, default=4)
    parser.add_argument('--balance_layer', type=int, default=10, help="the start layer with lora")
    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_bs', type=int, default=64, help="local batch size")
    parser.add_argument('--unseen_flag', type=_str2bool, default=False)
    parser.add_argument('--num_TE_epochs', type=int, default=1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry point: each global round trains every client's image
    # encoder, averages the returned weights into the global image encoder,
    # trains the global text encoder on the image projections the clients
    # returned, then evaluates both encoders on the shared test set.
    args = args_parser()
    device = args.device

    # ---- Per-dataset convergence settings ----
    with open("dataset/dataset_setting.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if args.IID in ("IID", "Dirichlet_domain", "IID_domain"):
        dataset_config = config["datasets_IID"].get(args.dataname)
    else:
        raise ValueError("IID dataset not supported")
    if dataset_config is None:
        # Fail with a clear message instead of "'NoneType' is not subscriptable".
        raise ValueError(f"No 'datasets_IID' entry for dataset '{args.dataname}' in dataset/dataset_setting.yaml")
    config_setting = {"acc_threshold": dataset_config["acc_threshold"],
                      "patience": dataset_config["patience"]}

    # ---- Logging ----
    if args.IID == "Dirichlet_domain":
        log_dir = './logs/domain/ours/Dir_multi'
    elif args.IID == "IID_domain":
        log_dir = './logs/domain/ours/IID_multi'
    else:
        log_dir = './logs/domain/ours'

    os.makedirs(log_dir, exist_ok=True)
    log_filename = f'{args.dataname}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    logger.info("dataname: {}--IID: {}--alpha: {}".format(args.dataname, args.IID, args.alpha))
    logger.info("data_num: {}--seed: {}--shot_num: {}".format(args.data_num, args.seed, args.shot_num))

    # ---- Model and data ----
    encoder_ini, transform = load(args.model_type, device=device, r=args.r, lora_mode=args.lora_mode, balance_layer=args.balance_layer)

    train_dataset, test_dataset, classes, user_groups, _ = get_dataset_domain(args, transform)
    testloader = DataLoader(test_dataset, batch_size=args.local_bs, shuffle=False)

    global_image_encoder = image_encoder(encoder_ini.visual).to(device)
    global_text_encoder = text_encoder(encoder_ini.transformer, encoder_ini.token_embedding,
                                       encoder_ini.positional_embedding, encoder_ini.ln_final,
                                       encoder_ini.text_projection).to(device)

    # Rows follow the iteration order of user_groups; columns are class ids.
    # NOTE(review): later code indexes rows by range(num_users), which
    # presumably matches the key order of user_groups — confirm.
    user_label_counts_array = np.zeros((len(user_groups), len(classes)), dtype=int)
    for i, (user_id, user_indices) in enumerate(user_groups.items()):
        user_indices = list(user_indices)
        user_labels = np.array(train_dataset.targets)[user_indices]
        unique_labels, counts = np.unique(user_labels, return_counts=True)
        for label, count in zip(unique_labels, counts):
            user_label_counts_array[i, label] = count
    print("User label counts:")
    print(user_label_counts_array)

    # In the domain settings the number of clients is dictated by the split.
    if args.IID in ("Dirichlet_domain", "IID_domain"):
        args.num_users = len(user_groups)

    # Classes that each client actually holds at least one training sample of.
    user_base_labels = {}
    for u in range(args.num_users):
        labels_u = np.where(user_label_counts_array[u] > 0)[0].tolist()
        user_base_labels[u] = labels_u
    logger.info("User base labels: {}".format(user_base_labels))
    # Set form of the same information, built once for fast membership tests
    # in the evaluation loop.
    user_label_sets = {u: set(labels_u) for u, labels_u in user_base_labels.items()}

    # Aggregation weight of each client = its sample count / len(train_dataset).
    # NOTE(review): these only sum to 1 if user_groups cover the whole
    # training set — confirm for the few-shot setting.
    weight_user = np.sum(user_label_counts_array, axis=1) / len(train_dataset)
    global_accuracys_base = []
    global_uploaded = []
    criterion_CE = nn.CrossEntropyLoss()

    template = ['a photo of a {}.']
    texts = [template[0].format(classname.replace('_', ' ')) for classname in classes]

    texts = clip.tokenize(texts).to(device)
    with torch.no_grad():
        # Fixed logit scale exp(log(1 / 0.07)); computed under no_grad, so it
        # is a constant and never trained.
        logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        scale = logit_scale.exp()

    clients = [LocalUpdate_domain(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                                  device=device, model=copy.deepcopy(global_image_encoder), logger=logger)
               for idx in range(args.num_users)]
    monitor = SFTConvergenceMonitor(**config_setting, device=device)
    flag = False  # becomes True once the monitor reports convergence
    print_train_acc = []
    local_test_acc = defaultdict(list)
    # Defined up front so the final report cannot hit a NameError when
    # args.epochs == 0.
    domain_test_acc = {}

    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        global_weights = {}
        merge_img_proj_dict, merge_labels_dict = dict(), dict()
        train_acc = 0.0
        with torch.no_grad():
            text_features = global_text_encoder(texts)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # ======= Local training and weighted aggregation =======
        for idx in range(args.num_users):
            local_model = clients[idx]
            # Send the current trainable global parameters to the client.
            model_params = {name: param.clone() for name, param in global_image_encoder.named_parameters() if param.requires_grad}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight, local_img_proj_dict, label_dict, train_acc_local = local_model.local_train_split(text_features=text_features, flag=flag)

            train_acc += train_acc_local
            if not global_weights:
                for i in local_weight.keys():
                    global_weights[i] = local_weight[i].clone() * weight_user[idx]
            else:
                for i in local_weight.keys():
                    global_weights[i].data += local_weight[i].clone().data * weight_user[idx]

            merge_img_proj_dict[idx], merge_labels_dict[idx] = local_img_proj_dict, label_dict

        # Write the aggregated weights back into the global image encoder.
        for i, p in global_image_encoder.named_parameters():
            if i in global_weights.keys():
                p.data.copy_(global_weights[i].data)

        # build mix_global_image_features dataloader
        all_img_proj, all_labels = [], []
        new_merge_img_proj_dict, new_merge_labels_dict = defaultdict(list), defaultdict(list)

        for client_index in range(args.num_users):
            for batch_idx in merge_labels_dict[client_index].keys():
                img_proj = merge_img_proj_dict[client_index][batch_idx]
                labels = merge_labels_dict[client_index][batch_idx]
                for g, l in zip(img_proj, labels):
                    all_img_proj.append(torch.tensor(g))
                    all_labels.append(torch.tensor(l))

        # Shuffle the samples of all clients together and re-batch them.
        # NOTE(review): args.seed is not applied anywhere in this script, so
        # this shuffle uses the global RNG state; presumably seeded inside
        # get_dataset_domain — confirm.
        combined = list(zip(all_img_proj, all_labels))
        random.shuffle(combined)
        for start in range(0, len(combined), args.local_bs):
            batch = combined[start:start + args.local_bs]
            img_proj_batch, labels_batch = zip(*batch)
            new_merge_img_proj_dict[start // args.local_bs] = torch.stack(img_proj_batch).numpy()
            new_merge_labels_dict[start // args.local_bs] = torch.tensor(labels_batch).tolist()

        # Per-class count of the samples uploaded this round.
        global_label_counter = Counter()
        uploaded_label_count = []
        for client_data in new_merge_labels_dict.values():
            global_label_counter.update(client_data)

        for label, count in sorted(global_label_counter.items()):
            uploaded_label_count.append(count)
        print(f"Uploaded label distribution: {uploaded_label_count}")
        print(f"The sum of uploaded sample number is: {sum(uploaded_label_count)}")
        global_uploaded.append(sum(uploaded_label_count))

        # ======= Global_text_encoder_training =======
        # A fresh optimizer is created every round, so Adam state does not
        # carry over between rounds.
        text_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, global_text_encoder.parameters()), lr=1e-3)
        global_text_encoder.train()
        for _ in range(args.num_TE_epochs):
            for batch_idx in new_merge_labels_dict.keys():
                global_img_proj = torch.tensor(new_merge_img_proj_dict[batch_idx]).to(device)
                label = torch.tensor(new_merge_labels_dict[batch_idx]).to(device)

                text_features = global_text_encoder(texts)
                text_features = text_features / text_features.norm(dim=1, keepdim=True)

                text_logits = scale * global_img_proj @ text_features.t()
                loss = criterion_CE(text_logits, label)

                text_optimizer.zero_grad()
                # NOTE(review): text_features is recomputed on every step, so
                # retain_graph=True looks unnecessary here; left unchanged
                # until a run without it has been verified.
                loss.backward(retain_graph=True)
                text_optimizer.step()

        # Release the uploaded features before evaluation.
        for client_index in range(args.num_users):
            merge_img_proj_dict[client_index].clear()
            merge_labels_dict[client_index].clear()
        merge_img_proj_dict.clear()
        merge_labels_dict.clear()
        new_merge_img_proj_dict.clear()
        new_merge_labels_dict.clear()

        # ======= Test Dataset for Global Performance =======
        # The modulus is the evaluation interval; with 1 this runs every round.
        if ((global_epoch + 1) % 1 == 0):
            domain_correct = defaultdict(int)
            domain_total = defaultdict(int)
            user_correct, user_total = defaultdict(int), defaultdict(int)

            with torch.no_grad():
                test_total, test_correct = 0.0, 0.0
                text_features = global_text_encoder(texts)
                text_features = text_features / text_features.norm(dim=1, keepdim=True)
                for images, labels, domains in tqdm(testloader):
                    images, labels = images.to(device), labels.to(device)
                    domains = domains.to(device)
                    _, img_proj = global_image_encoder(images)
                    img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
                    logits = scale * img_proj @ text_features.t()

                    preds = logits.argmax(dim=1)
                    test_correct += (preds == labels).sum().item()
                    test_total += labels.size(0)

                    # One host transfer per batch instead of one .item() call
                    # per label per client.
                    batch_labels = labels.tolist()
                    for idx in range(args.num_users):
                        # A client's "local" test set: samples from its own
                        # domain whose label it has seen in training.
                        client_domain = clients[idx].domain_idx
                        domain_mask = (domains == client_domain)
                        label_mask = torch.tensor([l in user_label_sets[idx] for l in batch_labels], device=device)
                        local_mask = domain_mask & label_mask

                        if local_mask.any():
                            user_correct[idx] += (preds[local_mask] == labels[local_mask]).sum().item()
                            user_total[idx] += local_mask.sum().item()

                    for d in torch.unique(domains):
                        d = d.item()
                        mask = (domains == d)

                        domain_correct[d] += (preds[mask] == labels[mask]).sum().item()
                        domain_total[d] += mask.sum().item()

                print('Global Test Acc: {:.2f}%'.format(100 * test_correct / test_total))
                global_accuracys_base.append(round(100 * test_correct / test_total, 2))

                for idx in range(args.num_users):
                    # A client with no matching test sample has user_total == 0;
                    # guard the division the same way the per-domain accuracy
                    # below does, so it records 0.0 instead of crashing.
                    local_test_acc[idx].append(100 * user_correct[idx] / max(1, user_total[idx]))

                if args.dataname == "Office_Caltech10":
                    domain_names = ['amazon', 'dslr', 'webcam', 'caltech']
                else:
                    domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

                # Rebuilt every round; the final report uses the last round's values.
                domain_test_acc = {}
                for d in sorted(domain_total.keys()):
                    acc = domain_correct[d] / max(1, domain_total[d])
                    acc = round(acc * 100, 2)
                    domain_test_acc[domain_names[d]] = acc

        # ======= Convergence check (only until first convergence) =======
        with torch.no_grad():
            if not flag:
                train_acc /= args.num_users
                converged = monitor.check_convergence(acc=train_acc)
                print(f"acc: {train_acc:.4f}, converged: {converged}")
                print_train_acc.append(round(train_acc, 4))
                if converged:
                    flag = True
                    print('[Converged]: Acc stable for {} epochs'.format(global_epoch + 1))
                    logger.info('[Converged]: Acc stable for {} epochs'.format(global_epoch + 1))
                    # Snapshot the aggregated image-encoder weights of this round.
                    save_name = f"{args.dataname}_SFT_final1.pt"
                    save_path = os.path.join("./", save_name)
                    torch.save(global_weights, save_path)

    # ======= Final report =======
    # Last-round local accuracy per client; clients without any recorded
    # round (only possible when no round ran) are skipped.
    local_print_test_accs = [local_test_acc[k][-1] for k in range(args.num_users) if local_test_acc[k]]

    if local_print_test_accs:
        logger.info(f"FedDTL Local Model Test Accuracy for clients [average] = {round(sum(local_print_test_accs) / len(local_print_test_accs), 2)}")
    logger.info(f"FedDTL Global Model Test Accuracy for average clients = {global_accuracys_base}")

    logger.info("======== Averaged Per-Domain Accuracy Across All Clients ========")
    logger.info("FedDTL_global_per_domain = ")
    for domain in sorted(domain_test_acc.keys()):
        logger.info(f"  \"{domain}\": [{domain_test_acc[domain]}],")
    logger.info('Global uploaded: {}'.format(global_uploaded))
