import os
import copy
import clip
import json
import yaml
import argparse
import torch
import logging
import random
import numpy as np

from collections import Counter
from tqdm import tqdm
from torch.utils.data import DataLoader
from collections import defaultdict
from new_loraclip import load

from FedDTL_utils.models_new import *
from FedDTL_utils.clients import LocalUpdate
from FedDTL_utils.utils import get_dataset_new, SFTConvergenceMonitor, get_dataset_path, get_dataset_new_dir


def _str_to_bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (so ``--unseen_flag False`` would enable the
    flag).  This converter interprets the usual textual spellings instead.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, matching what bool("") used to yield.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false, yes/no, 1/0), got {!r}".format(value))


def args_parser():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options.  Every option has a default,
        so the script can be launched without arguments.
    """
    parser = argparse.ArgumentParser()

    # ---- federated setup / data split ----
    parser.add_argument('--epochs', type=int, default=20, help="number of rounds of training")
    parser.add_argument('--data_num', type=str, default="few-shot", choices=["few-shot", "all-shot"])
    parser.add_argument('--dataname', type=str, default='Food101', help="name of dataset", choices=['CIFAR10', 'CIFAR100', 'Flower102', 'Caltech256', 'OxfordPet', "Caltech101", "Food101", "EuroSAT", "Tiny_ImageNet"])
    parser.add_argument('--IID', type=str, default="IID", choices=["IID", "Non-IID", "Dirichlet"])
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument('--alpha', type=float, default=0.1, help="Dir(alpha)")
    parser.add_argument('--local_ep', type=int, default=2)
    parser.add_argument('--num_users', type=int, default=5)
    parser.add_argument('--shot_num', type=int, default=16)

    # ---- training-style switches (consumed by code outside this function) ----
    parser.add_argument('--RL_style', type=str, default="GRPO")
    parser.add_argument('--load_style', type=str, default="mix", choices=["latest", "final_SFT", "mix"])
    parser.add_argument('--IE_style', type=str, default="SFT_RL")
    parser.add_argument('--split_style', type=str, default="split")
    parser.add_argument('--train_style', type=str, default="IE_TE")
    parser.add_argument('--moniter_style', type=str, default="final")

    # ---- sampling / GRPO hyper-parameters ----
    parser.add_argument('--sample_style', type=str, default="random")
    parser.add_argument('--random_label_ratio', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument("--proj_lora", type=int, default=4)
    parser.add_argument('--clip_eps', type=float, default=0.2)
    parser.add_argument('--kl_coef', type=float, default=0.5)
    parser.add_argument('--noise_std', type=float, default=0.1)
    parser.add_argument('--GRPO_epochs', type=int, default=3)
    parser.add_argument('--GRPO_sampling_epochs', type=int, default=3)

    # ---- model / LoRA configuration ----
    parser.add_argument("--r", type=int, default=4)
    parser.add_argument('--balance_layer', type=int, default=10, help="the start layer with lora")
    parser.add_argument('--model_type', type=str, default='ViT-B/16', choices=['ViT-B/32', 'ViT-B/16'])
    parser.add_argument("--lora_mode", type=str, default="vision+text")
    parser.add_argument('--frac', type=float, default=1.0, help='the fraction of users')
    parser.add_argument('--local_bs', type=int, default=64, help="local batch size")
    # type=bool would treat "False"/"0" as True; parse the text explicitly.
    # nargs='?' + const=True additionally allows a bare `--unseen_flag`.
    parser.add_argument('--unseen_flag', type=_str_to_bool, nargs='?', const=True, default=False, help="unseen classes")
    parser.add_argument('--num_TE_epochs', type=int, default=1)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry point.  Each global round:
    #   1. every client trains locally starting from the global image encoder,
    #   2. the returned trainable weights are averaged with per-client weights,
    #   3. the image projections uploaded by the clients are pooled, shuffled
    #      and used to train the global text encoder,
    #   4. the global model is evaluated on the base and new test sets.
    args = args_parser()
    device = args.device

    # Per-dataset settings for the convergence monitor come from a YAML file;
    # the "Dirichlet" split reuses the "datasets_IID" section.
    with open("dataset/dataset_setting.yaml") as f:
        config = yaml.safe_load(f)
    if (args.IID == "IID") or (args.IID == "Dirichlet"):
        dataset_config = config["datasets_IID"].get(args.dataname)
    elif (args.IID == "Non-IID"):
        dataset_config = config["datasets_Non-IID"].get(args.dataname)
    else:
        raise ValueError("IID dataset not supported")
    if dataset_config is None:
        # .get() returns None for a dataset missing from the YAML; fail with a
        # clear message instead of a TypeError on the subscript below.
        raise ValueError("No settings for dataset '{}' (split '{}') in dataset/dataset_setting.yaml".format(args.dataname, args.IID))
    config_setting = {"acc_threshold": dataset_config["acc_threshold"],
                      "patience": dataset_config["patience"]}

    # File logging: one log file per dataset name under ./logs.
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)  # no check-then-create race
    log_filename = f'{args.dataname}.txt'
    logging.basicConfig(level=logging.INFO, format='%(message)s', filename=os.path.join(log_dir, log_filename))
    logger = logging.getLogger(__name__)
    # NOTE(review): args.seed is only logged in this script; no RNG is seeded
    # here -- confirm that the dataset/client helpers seed from args.
    logger.info("dataname: {}--IID: {}--alpha: {}--seed: {}--shot_num: {}".format(args.dataname, args.IID, args.alpha, args.seed, args.shot_num))
    logger.info("data_num: {}".format(args.data_num))

    encoder_ini, transform = load(args.model_type, device=device, r=args.r, lora_mode=args.lora_mode, balance_layer=args.balance_layer)

    # Each split style has its own dataset builder; the Dirichlet builder's
    # last return value is discarded and user_base_labels is rebuilt below.
    if args.IID == "IID":
        train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, user_base_labels = get_dataset_new(args, transform)
    elif args.IID == "Non-IID":
        train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, user_base_labels = get_dataset_path(args, transform)
    elif args.IID == "Dirichlet":
        train_dataset, test_dataset_base, test_dataset_new, train_classes, test_classes, user_groups, _ = get_dataset_new_dir(args, transform)
    else:
        raise ValueError("Dataset split not supported")
    testloader_base = DataLoader(test_dataset_base, batch_size=args.local_bs, shuffle=False)
    testloader_new = DataLoader(test_dataset_new, batch_size=args.local_bs, shuffle=False)

    # Split the loaded CLIP-style model into separate image / text encoders.
    global_image_encoder = image_encoder(encoder_ini.visual).to(device)
    global_text_encoder = text_encoder(encoder_ini.transformer, encoder_ini.token_embedding, encoder_ini.positional_embedding, encoder_ini.ln_final, encoder_ini.text_projection).to(device)

    # user_label_counts_array[i, c] = number of samples of class c held by the
    # i-th entry of user_groups.
    user_label_counts_array = np.zeros((len(user_groups), len(train_classes)), dtype=int)
    all_targets = np.array(train_dataset.targets)  # converted once, not per client
    for i, (user_id, user_indices) in enumerate(user_groups.items()):
        user_indices = list(user_indices)
        user_labels = all_targets[user_indices]
        unique_labels, counts = np.unique(user_labels, return_counts=True)
        for label, count in zip(unique_labels, counts):
            user_label_counts_array[i, label] = count
    print("User label counts:")
    print(user_label_counts_array)

    if args.IID == "Dirichlet":
        # For the Dirichlet split a client's base labels are the classes it
        # holds at least one sample of.
        user_base_labels = {}
        for u in range(args.num_users):
            labels_u = np.where(user_label_counts_array[u] > 0)[0].tolist()
            user_base_labels[u] = labels_u
    logger.info("User base labels: {}".format(user_base_labels))

    # Aggregation weight of each client = its sample count / len(train_dataset).
    # NOTE(review): these weights sum to 1 only if user_groups partitions the
    # whole train_dataset -- confirm for the few-shot setting.
    weight_user = np.sum(user_label_counts_array, axis=1) / len(train_dataset)
    global_accuracys_base, global_accuracys_new = [], []
    global_uploaded = []
    criterion_CE = nn.CrossEntropyLoss()

    # One prompt per class; underscores in class names become spaces.
    template = ['a photo of a {}.']
    train_texts = [template[0].format(classname.replace('_', ' ')) for classname in train_classes]
    test_texts = [template[0].format(classname.replace('_', ' ')) for classname in test_classes]

    train_texts = clip.tokenize(train_texts).to(device)
    test_texts = clip.tokenize(test_texts).to(device)
    with torch.no_grad():
        # Fixed logit scale exp(log(1 / 0.07)); computed under no_grad and
        # never handed to an optimizer, so it stays constant.
        logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        scale = logit_scale.exp()

    # One LocalUpdate per client, each with its own deep copy of the encoder.
    clients = [LocalUpdate(args=args, train_data=train_dataset, idxs=user_groups[idx], client_index=idx,
                           device=device, model=copy.deepcopy(global_image_encoder), logger=logger,
                           user_base_labels=user_base_labels[idx])
               for idx in range(args.num_users)]
    monitor = SFTConvergenceMonitor(**config_setting, device=device)
    # `flag` turns True once the monitor reports convergence and is passed to
    # every later local_train_split call (presumably switching the clients'
    # training stage -- confirm in LocalUpdate).
    flag = False
    print_train_acc = []
    local_test_acc = defaultdict(list)

    for global_epoch in range(args.epochs):
        print(f'\n | Global Training Round : {global_epoch + 1} |\n')
        global_weights = {}
        merge_img_proj_dict, merge_labels_dict = dict(), dict()
        train_acc = 0.0
        # Text features of the training prompts for this round, L2-normalised.
        with torch.no_grad():
            text_features = global_text_encoder(train_texts)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # ======= Local training + weighted aggregation =======
        for idx in range(args.num_users):
            local_model = clients[idx]
            # Push the current trainable global parameters to the client.
            model_params = {name: param.clone() for name, param in global_image_encoder.named_parameters() if param.requires_grad}
            local_model.update_model(model_params=model_params)

            print('--------------client {}-------------'.format(idx))
            local_weight, local_img_proj_dict, label_dict, train_acc_local = local_model.local_train_split(text_features=text_features, flag=flag)

            train_acc += train_acc_local
            # Running weighted sum of the clients' returned weights.
            if not global_weights:
                for i in local_weight.keys():
                    global_weights[i] = local_weight[i].clone() * weight_user[idx]
            else:
                for i in local_weight.keys():
                    global_weights[i].data += local_weight[i].clone().data * weight_user[idx]

            merge_img_proj_dict[idx], merge_labels_dict[idx] = local_img_proj_dict, label_dict

        # Write the aggregate back into the global image encoder; parameters
        # no client returned are left untouched.
        for i, p in global_image_encoder.named_parameters():
            if i in global_weights.keys():
                p.data.copy_(global_weights[i].data)

        # build mix_global_image_features dataloader
        all_img_proj, all_labels = [], []
        new_merge_img_proj_dict, new_merge_labels_dict = defaultdict(list), defaultdict(list)

        # Flatten every client's uploaded (projection, label) pairs.
        for client_index in range(args.num_users):
            for batch_idx in merge_labels_dict[client_index].keys():
                img_proj = merge_img_proj_dict[client_index][batch_idx]
                labels = merge_labels_dict[client_index][batch_idx]
                for g, l in zip(img_proj, labels):
                    all_img_proj.append(torch.tensor(g))
                    all_labels.append(torch.tensor(l))

        # Shuffle across clients, then re-chunk into batches of local_bs
        # keyed by the new batch number.
        combined = list(zip(all_img_proj, all_labels))
        random.shuffle(combined)
        for idx in range(0, len(combined), args.local_bs):
            batch = combined[idx:idx + args.local_bs]
            if not batch: continue
            img_proj_batch, labels_batch = zip(*batch)
            new_merge_img_proj_dict[idx // args.local_bs] = torch.stack(img_proj_batch).numpy()
            new_merge_labels_dict[idx // args.local_bs] = torch.tensor(labels_batch).tolist()

        # Per-label counts of the uploaded samples, listed in sorted label order.
        global_label_counter = Counter()
        uploaded_labels, uploaded_label_count = [], []
        for client_data in new_merge_labels_dict.values():
            global_label_counter.update(client_data)

        for label, count in sorted(global_label_counter.items()):
            uploaded_label_count.append(count)
            uploaded_labels.append(label)
        print(f"Uploaded label distribution: {uploaded_label_count}")
        print(F"The sum of uploaded sample number is: {sum(uploaded_label_count)}")
        if (global_epoch +1) == 1:
            # The distribution is written to the log file for the first round only.
            logger.info(f"Uploaded label distribution: {uploaded_label_count}")
            logger.info(f"The sum of uploaded sample number is: {sum(uploaded_label_count)}")
        global_uploaded.append(sum(uploaded_label_count))

        # ======= Global_text_encoder_training =======
        # A fresh Adam optimizer is created every round, so optimizer state
        # does not carry over between rounds.
        text_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, global_text_encoder.parameters()), lr=1e-3)
        for _ in range(args.num_TE_epochs):
            for batch_idx in new_merge_labels_dict.keys():
                global_img_proj = torch.tensor(new_merge_img_proj_dict[batch_idx]).to(device)
                label = torch.tensor(new_merge_labels_dict[batch_idx]).to(device)

                text_features = global_text_encoder(train_texts)
                text_features = text_features / text_features.norm(dim=1, keepdim=True)

                # The uploaded projections are used as-is (no normalisation here).
                text_logits = scale * global_img_proj @ text_features.t()
                loss = criterion_CE(text_logits, label)

                text_optimizer.zero_grad()
                # NOTE(review): text_features is recomputed for every batch, so
                # retain_graph=True looks unnecessary and only holds memory --
                # confirm before removing.
                loss.backward(retain_graph=True)
                text_optimizer.step()

        # Drop this round's uploaded features and labels.
        for client_index in range(args.num_users):
            merge_img_proj_dict[client_index].clear()
            merge_labels_dict[client_index].clear()
        merge_img_proj_dict.clear()
        merge_labels_dict.clear()
        new_merge_img_proj_dict.clear()
        new_merge_labels_dict.clear()

        # ======= Test Dataset for Global Performance =======
        with torch.no_grad():
            test_total, test_correct = 0.0, 0.0
            user_correct, user_total = defaultdict(int), defaultdict(int)

            # Per-client label tensors are the same for every test batch, so
            # build them once per round instead of once per batch.
            user_label_tensors = [torch.tensor(user_base_labels[idx], dtype=torch.long, device=device)
                                  for idx in range(args.num_users)]

            text_features = global_text_encoder(train_texts)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            for images, labels in tqdm(testloader_base, disable=True):
                images, labels = images.to(device), labels.to(device)
                _, img_proj = global_image_encoder(images)
                img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
                logits = scale * img_proj @ text_features.t()

                preds = logits.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)

                # Per-client accuracy: only samples whose label is in that
                # client's base-label set are counted.
                for idx in range(args.num_users):
                    mask = torch.isin(labels, user_label_tensors[idx])
                    if mask.any():
                        user_correct[idx] += (preds[mask] == labels[mask]).sum().item()
                        user_total[idx] += mask.sum().item()

            base_acc = 100 * test_correct / test_total
            print('Global Base Test Acc: {:.2f}%'.format(base_acc))
            global_accuracys_base.append(round(base_acc, 2))

            for idx in range(args.num_users):
                if user_total[idx] > 0:
                    local_test_acc[idx].append(100 * user_correct[idx] / user_total[idx])
                else:
                    # No base-test sample matches this client's labels (e.g. an
                    # empty label set): accuracy is undefined, so record NaN
                    # rather than crash with ZeroDivisionError.
                    local_test_acc[idx].append(float("nan"))

            test_total1, test_correct1 = 0.0, 0.0
            text_features_new = global_text_encoder(test_texts)
            text_features_new = text_features_new / text_features_new.norm(dim=1, keepdim=True)
            for images, labels in tqdm(testloader_new, disable=True):
                images, labels = images.to(device), labels.to(device)
                _, img_proj = global_image_encoder(images)
                img_proj = img_proj / img_proj.norm(dim=1, keepdim=True)
                logits = scale * img_proj @ text_features_new.t()

                # Predictions index into test_classes; shifting by
                # len(train_classes) puts them in the label space the new test
                # set is compared against.
                preds = logits.argmax(dim=1) + int(len(train_classes))
                test_correct1 += (preds == labels).sum().item()
                test_total1 += labels.size(0)
            print('Global New Test Acc: {:.2f}%'.format(100 * test_correct1 / test_total1))
            global_accuracys_new.append(round(100 * test_correct1 / test_total1, 2))

            # Until convergence is reported (and only for IE_style "SFT_RL"),
            # feed the mean client training accuracy to the monitor.
            if (not flag) and (args.IE_style == "SFT_RL"):
                train_acc /= args.num_users
                converged = monitor.check_convergence(acc=train_acc)
                print(f"acc: {train_acc:.4f}, converged: {converged}")
                print_train_acc.append(round(train_acc, 4))
                if converged:
                    flag = True
                    print('[Converged]: Acc stable for {} epochs'.format(global_epoch + 1))
                    logger.info('[Converged]: Acc stable for {} epochs'.format(global_epoch + 1))
                    # Save this round's aggregated weights as a checkpoint.
                    save_name = f"{args.dataname}_{args.IID}_SFT_final1.pt"
                    save_path = os.path.join("./", save_name)
                    torch.save(global_weights, save_path)

    # Final summary of the whole run.
    logger.info('Global uploaded: {}'.format(global_uploaded))
    logger.info('[Base] Global Model Test Accuracy: {}'.format(global_accuracys_base))
    logger.info('[New] Global Model Test Accuracy: {}'.format(global_accuracys_new))
    logger.info('train_acc: {}'.format(print_train_acc))
    for idx in range(args.num_users):
        logger.info('[Local] Global Model Test Accuracy for client {}: {}'.format(idx, local_test_acc[idx]))
