import torch
import copy
import numpy as np
from torchvision import datasets
from collections import defaultdict
from dataset.food101 import Food101Dataset
from dataset.Flower102 import Flowers102Dataset
from dataset.OxfordPet import OxfordPetDataset
from dataset.EuroSAT import EuroSATDataset
from dataset.Office_Caltech_10 import OfficeCaltechDataset
from dataset.DomainNet import DomainNetDataset_10
from dataset.tiny_imagenet import Tiny_ImageNetDataset
from dataset.caltech101 import Caltech101Dataset
from dataset.caltech256 import Caltech256Dataset


# Dataset names whose `data` / `targets` the get_dataset_* loaders in this file filter
# with list comprehensions; any dataset not listed here is filtered with array-style
# indexing instead.
# NOTE(review): presumably these wrappers keep `data` as a plain Python list -- confirm.
others = ['EuroSAT', 'Flower102', 'OxfordPet', 'Tiny_ImageNet', 'Food101', 'Caltech101', 'Caltech256']

def dataset_iid(dataset, num_users):
    """Split sample indices uniformly at random into equally sized user shards.

    NumPy's global RNG is re-seeded with 1234 on every call, so the partition
    is reproducible. Each user receives ``int(len(dataset) / num_users)``
    indices drawn without replacement; any remainder is left unassigned.

    Args:
        dataset: Any object supporting ``len()``.
        num_users: Number of shards to create.

    Returns:
        dict mapping user id (0..num_users-1) to a set of sample indices.
    """
    np.random.seed(1234)
    total = len(dataset)
    shard_size = int(total / num_users)

    remaining = list(range(total))
    partition = {}
    for uid in range(num_users):
        picked = set(np.random.choice(remaining, shard_size, replace=False))
        partition[uid] = picked
        # Drop what this user took so later users cannot draw it again.
        remaining = list(set(remaining) - picked)
    return partition


def dataset_by_label(dataset, user_base_labels):
    """Assign every sample to the user that owns its label.

    A label listed by several users belongs to the first of them in
    ``user_base_labels`` iteration order. Samples whose label is owned by
    nobody are left unassigned.

    Args:
        dataset: Object exposing ``targets``, an iterable of hashable labels.
        user_base_labels: Mapping ``user_id -> iterable of labels``.

    Returns:
        dict mapping each user id to the set of sample indices it received
        (possibly empty).
    """
    # Resolve label ownership once instead of scanning every user's label
    # list for every sample; setdefault keeps the earliest owner.
    owner_of = {}
    for user_id, label_list in user_base_labels.items():
        for label in label_list:
            owner_of.setdefault(label, user_id)

    dict_users = {user_id: set() for user_id in user_base_labels}
    for idx, target in enumerate(dataset.targets):
        if target in owner_of:
            dict_users[owner_of[target]].add(idx)

    return dict_users


def dataset_domain(dataset, num_users, domains):
    """Give each user all samples of exactly one domain.

    Args:
        dataset: Object exposing ``domain_ids``; each entry is used directly
            as the user id that receives the sample.
        num_users: Number of users; must equal ``len(domains)``.
        domains: Domain names (only the length is inspected here).

    Returns:
        dict mapping user id (0..num_users-1) to a set of sample indices.

    Raises:
        AssertionError: If ``num_users`` differs from ``len(domains)``.
        KeyError: If a domain id falls outside ``range(num_users)``.
    """
    assert len(domains) == num_users, "user number == domain number"

    partition = {uid: set() for uid in range(num_users)}
    for sample_idx, owner in enumerate(dataset.domain_ids):
        partition[owner].add(sample_idx)
    return partition


def dataset_per_domain_dirichlet(dataset, num_users_per_domain, alpha):
    """Partition each domain among its own users with a per-class Dirichlet split.

    NumPy's global RNG is re-seeded with 42 on every call. Domains are
    processed in order of first appearance in ``dataset.domain_ids``; within a
    domain, every class is shuffled and divided among the domain's users
    according to a fresh ``Dirichlet(alpha)`` draw. User ids are numbered
    consecutively across domains.

    Args:
        dataset: Object exposing parallel ``targets`` and ``domain_ids``.
        num_users_per_domain: Number of users created for every domain.
        alpha: Dirichlet concentration parameter.

    Returns:
        dict mapping global user id to a set of sample indices.
    """
    targets = dataset.targets
    sample_domains = dataset.domain_ids
    np.random.seed(42)

    by_domain = defaultdict(list)
    for pos, dom in enumerate(sample_domains):
        by_domain[dom].append(pos)

    user_groups = {}
    next_uid = 0
    for members in by_domain.values():
        by_class = defaultdict(list)
        for pos in members:
            by_class[targets[pos]].append(pos)

        buckets = [[] for _ in range(num_users_per_domain)]
        for class_members in by_class.values():
            shuffled = np.array(class_members)
            np.random.shuffle(shuffled)
            total = len(shuffled)

            shares = np.random.dirichlet([alpha] * num_users_per_domain)
            counts = (shares * total).astype(int)

            # Flooring can leave the counts off the class size; the current
            # largest share absorbs the difference one sample at a time.
            while counts.sum() < total:
                counts[np.argmax(counts)] += 1
            while counts.sum() > total:
                counts[np.argmax(counts)] -= 1

            # Consecutive slices of the shuffled class, one per user.
            stops = np.cumsum(counts)
            starts = stops - counts
            for bucket, lo, hi in zip(buckets, starts, stops):
                bucket.extend(shuffled[lo:hi].tolist())

        for bucket in buckets:
            user_groups[next_uid] = set(bucket)
            next_uid += 1

    return user_groups


def dataset_per_domain_iid(dataset, num_users_per_domain):
    """Partition each domain among its own users with a class-balanced split.

    Domains are processed in order of first appearance in
    ``dataset.domain_ids``. Within a domain, every class is shuffled and cut
    into ``num_users_per_domain`` nearly equal parts with ``np.array_split``.
    User ids are numbered consecutively across domains.

    NOTE: this function does not seed NumPy's global RNG, so the result
    depends on the RNG state at call time.

    Args:
        dataset: Object exposing parallel ``targets`` and ``domain_ids``.
        num_users_per_domain: Number of users created for every domain.

    Returns:
        dict mapping global user id to a set of sample indices.
    """
    targets = np.array(dataset.targets)
    sample_domains = np.array(dataset.domain_ids)

    by_domain = defaultdict(list)
    for pos, dom in enumerate(sample_domains):
        by_domain[dom].append(pos)

    user_groups = {}
    next_uid = 0
    for members in by_domain.values():
        # Within this domain: class -> sample indices.
        by_class = defaultdict(list)
        for pos in np.array(members):
            by_class[targets[pos]].append(pos)

        buckets = [[] for _ in range(num_users_per_domain)]
        for class_members in by_class.values():
            shuffled = np.array(class_members)
            np.random.shuffle(shuffled)
            parts = np.array_split(shuffled, num_users_per_domain)
            for bucket, part in zip(buckets, parts):
                bucket.extend(part.tolist())

        for bucket in buckets:
            user_groups[next_uid] = set(bucket)
            next_uid += 1

    return user_groups


def dataset_dir_new(dataset, num_users, alpha):
    """Partition samples among users with a per-class Dirichlet split.

    NumPy's global RNG is re-seeded with 2025 on every call. For each class id
    in ``range(len(dataset.classes))`` the matching sample indices are
    shuffled and cut at the cumulative proportions of a fresh
    ``Dirichlet(alpha)`` draw, one piece per user.

    Args:
        dataset: Object exposing ``classes`` and ``targets``.
        num_users: Number of users to split between.
        alpha: Dirichlet concentration parameter.

    Returns:
        dict mapping user id (0..num_users-1) to a set of sample indices.
    """
    np.random.seed(2025)
    num_classes = len(dataset.classes)
    targets = np.array(dataset.targets)

    buckets = [[] for _ in range(num_users)]
    for label in range(num_classes):
        members = np.where(targets == label)[0]
        np.random.shuffle(members)

        shares = np.random.dirichlet([alpha] * num_users)
        # Cumulative shares scaled to the class size give the interior cut
        # points; the last one (the class size itself) is dropped.
        cut_points = (np.cumsum(shares) * len(members)).astype(int)[:-1]

        for bucket, part in zip(buckets, np.split(members, cut_points)):
            bucket.extend(part)

    return {uid: set(bucket) for uid, bucket in enumerate(buckets)}


def _load_train_test(dataname, transform):
    """Instantiate the (train, test) dataset pair registered under ``dataname``.

    Raises:
        ValueError: If ``dataname`` is not one of the supported datasets.
    """
    # The lookup tables are built inside the function so the dataset classes
    # are only resolved when a loader is actually called.
    torchvision_sets = {
        'CIFAR10': (datasets.CIFAR10, './cifar10'),
        'CIFAR100': (datasets.CIFAR100, './cifar100'),
    }
    # name -> (dataset class, root directory, name of the evaluation split)
    split_sets = {
        'OxfordPet': (OxfordPetDataset, './oxfordpet/', 'test'),
        'Flower102': (Flowers102Dataset, './flower102/', 'test'),
        'Tiny_ImageNet': (Tiny_ImageNetDataset, './tiny_imagenet/tiny-imagenet-200/', 'val'),
        'Food101': (Food101Dataset, './food101/food-101', 'test'),
        'EuroSAT': (EuroSATDataset, './eurosat', 'test'),
        'Caltech101': (Caltech101Dataset, './caltech101', 'test'),
        'Caltech256': (Caltech256Dataset, './caltech256', 'test'),
    }

    if dataname in torchvision_sets:
        dataset_cls, root = torchvision_sets[dataname]
        train_set = dataset_cls(root=root, train=True, download=False, transform=transform)
        test_set = dataset_cls(root=root, train=False, download=False, transform=transform)
    elif dataname in split_sets:
        dataset_cls, root, test_split = split_sets[dataname]
        train_set = dataset_cls(root=root, split='train', transform=transform)
        test_set = dataset_cls(root=root, split=test_split, transform=transform)
    else:
        raise ValueError("choose data_name from [CIFAR10, CIFAR100, OxfordPet, Flower102, Tiny_ImageNet, Food101, EuroSAT, Caltech101, Caltech256]")
    return train_set, test_set


def _select_labels(source, keep_labels, as_list, into):
    """Store in ``into`` the samples of ``source`` whose label is in ``keep_labels``.

    ``as_list`` selects element-wise list filtering; otherwise ``data`` is
    filtered with array-style indexing. ``into`` may be ``source`` itself for
    an in-place restriction.
    """
    keep = [i for i, y in enumerate(source.targets) if y in keep_labels]
    if as_list:
        into.data = [source.data[i] for i in keep]
        into.targets = [source.targets[i] for i in keep]
    else:
        into.data = source.data[keep]
        into.targets = np.array(source.targets)[keep].tolist()


def _split_seen_unseen(args, train_set, test_set, all_classes, train_labels, unseen_labels):
    """Restrict the train set to seen labels and derive the two test subsets.

    ``train_set`` is modified in place. ``test_set`` is left untouched; each
    test subset is a deep copy with replaced ``data`` / ``targets``.

    Returns:
        (train_set, test_set_A, test_set_B, train_classes, test_classes) where
        A holds the seen labels and B the unseen ones.
    """
    as_list = args.dataname in others

    _select_labels(train_set, train_labels, as_list, into=train_set)

    test_set_A = copy.deepcopy(test_set)
    _select_labels(test_set, train_labels, as_list, into=test_set_A)

    test_set_B = copy.deepcopy(test_set)
    _select_labels(test_set, unseen_labels, as_list, into=test_set_B)

    train_classes = [all_classes[i] for i in sorted(train_labels)]
    test_classes = [all_classes[i] for i in sorted(unseen_labels)]
    return train_set, test_set_A, test_set_B, train_classes, test_classes


def get_dataset_new(args, transform):
    """Load a dataset, split its classes into seen/unseen halves, partition IID.

    The first ``num_classes // 2`` class ids are the seen (base) classes; the
    rest are unseen. Training samples of seen classes are divided among users
    by ``dataset_iid``.

    Args:
        args: Namespace providing ``dataname`` and ``num_users``.
        transform: Transform handed to the dataset constructors.

    Returns:
        Tuple ``(train_set, test_set_A, test_set_B, train_classes,
        test_classes, user_groups, user_base_labels)``: the train set
        restricted to seen classes, the test subsets of seen (A) and unseen
        (B) classes, the corresponding class-name lists, the per-user sample
        index sets, and a dict giving every user the full list of seen labels.

    Raises:
        ValueError: If ``args.dataname`` is not supported.
    """
    train_set, test_set = _load_train_test(args.dataname, transform)

    all_classes = [f'{label}' for label in train_set.classes]
    num_classes = len(all_classes)
    mid = num_classes // 2
    train_labels = set(range(mid))
    unseen_labels = set(range(mid, num_classes))
    user_base_labels = {u: list(range(mid)) for u in range(args.num_users)}

    train_set, test_set_A, test_set_B, train_classes, test_classes = _split_seen_unseen(
        args, train_set, test_set, all_classes, train_labels, unseen_labels)

    user_groups = dataset_iid(train_set, args.num_users)

    return train_set, test_set_A, test_set_B, train_classes, test_classes, user_groups, user_base_labels


def get_dataset_new_dir(args, transform):
    """Load a dataset, split its classes into seen/unseen halves, partition by Dirichlet.

    Identical to ``get_dataset_new`` except that the training samples of the
    seen classes are divided among users by ``dataset_dir_new`` using
    ``args.alpha``.

    Args:
        args: Namespace providing ``dataname``, ``num_users`` and ``alpha``.
        transform: Transform handed to the dataset constructors.

    Returns:
        Tuple ``(train_set, test_set_A, test_set_B, train_classes,
        test_classes, user_groups, user_base_labels)`` with the same meaning
        as in ``get_dataset_new``.

    Raises:
        ValueError: If ``args.dataname`` is not supported.
    """
    train_set, test_set = _load_train_test(args.dataname, transform)

    all_classes = [f'{label}' for label in train_set.classes]
    num_classes = len(all_classes)
    mid = num_classes // 2
    train_labels = set(range(mid))
    unseen_labels = set(range(mid, num_classes))
    user_base_labels = {u: list(range(mid)) for u in range(args.num_users)}

    train_set, test_set_A, test_set_B, train_classes, test_classes = _split_seen_unseen(
        args, train_set, test_set, all_classes, train_labels, unseen_labels)

    user_groups = dataset_dir_new(train_set, args.num_users, args.alpha)

    return train_set, test_set_A, test_set_B, train_classes, test_classes, user_groups, user_base_labels


def get_dataset_path(args, transform):
    """Load a dataset and give every user a disjoint slice of the base classes.

    The first ``num_classes // 2`` class ids form the base pool. Each user
    receives ``mid // num_users`` consecutive base classes. Base classes left
    over by the integer division are not trained on and are treated as unseen
    together with the second half of the classes. If ``num_users`` exceeds the
    number of base classes, every user gets an empty label list and all
    classes end up unseen.

    Args:
        args: Namespace providing ``dataname`` and ``num_users``.
        transform: Transform handed to the dataset constructors.

    Returns:
        Tuple ``(train_set, test_set_A, test_set_B, train_classes,
        test_classes, user_groups, user_base_labels)``: the train set
        restricted to assigned base classes, the test subsets of assigned (A)
        and unseen (B) classes, the corresponding class-name lists, the
        per-user sample index sets from ``dataset_by_label``, and the per-user
        base label lists.

    Raises:
        ValueError: If ``args.dataname`` is not supported.
        ZeroDivisionError: If ``args.num_users`` is 0.
    """
    train_set, test_set = _load_train_test(args.dataname, transform)

    all_classes = [f'{label}' for label in train_set.classes]
    num_classes = len(all_classes)
    mid = num_classes // 2
    base_pool = list(range(mid))
    new_pool = list(range(mid, num_classes))

    num_users = args.num_users
    num_base_per_user = mid // num_users

    # Every slice holds exactly num_base_per_user labels because
    # num_users * num_base_per_user <= mid, so no per-user top-up is needed.
    user_base_labels = {}
    assigned_base = set()
    for u in range(num_users):
        start = u * num_base_per_user
        user_base_labels[u] = base_pool[start:start + num_base_per_user]
        assigned_base.update(user_base_labels[u])

    # Base classes nobody received are evaluated as unseen classes.
    new_pool.extend(c for c in base_pool if c not in assigned_base)

    train_labels = set(assigned_base)
    unseen_labels = set(new_pool)

    train_set, test_set_A, test_set_B, train_classes, test_classes = _split_seen_unseen(
        args, train_set, test_set, all_classes, train_labels, unseen_labels)

    user_groups = dataset_by_label(train_set, user_base_labels)

    return train_set, test_set_A, test_set_B, train_classes, test_classes, user_groups, user_base_labels


def get_dataset_domain(args, transform):
    """Load a multi-domain dataset and partition its training split among users.

    The partition strategy is chosen by ``args.IID``:
    ``"IID"`` gives one user per domain (``dataset_domain``),
    ``"Dirichlet_domain"`` uses ``dataset_per_domain_dirichlet``,
    ``"IID_domain"`` uses ``dataset_per_domain_iid``, and any other value
    falls back to ``dataset_dir_new``.

    Args:
        args: Namespace providing ``dataname``, ``IID`` and, depending on the
            strategy, ``num_users``, ``num_users_per_domain`` and ``alpha``.
        transform: Transform handed to the dataset constructors.

    Returns:
        Tuple ``(train_set, test_set, classes, user_groups,
        user_base_labels)``; ``user_base_labels`` gives every user the full
        list of class ids.

    Raises:
        ValueError: If ``args.dataname`` is not a supported multi-domain dataset.
    """
    if args.dataname == 'Office_Caltech10':
        dataset_cls, root = OfficeCaltechDataset, './office_caltech_10/'
        domains = ['amazon', 'dslr', 'webcam', 'caltech']
    elif args.dataname == 'DomainNet':
        dataset_cls, root = DomainNetDataset_10, './domainnet/'
        domains = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
    else:
        raise ValueError("choose data_name from [Office_Caltech10, DomainNet]")

    # Each constructor gets its own copy of the domain list.
    train_set = dataset_cls(root=root, domains=list(domains), split='train', transform=transform)
    test_set = dataset_cls(root=root, domains=list(domains), split='test', transform=transform)

    strategy = args.IID
    per_domain = strategy in ("Dirichlet_domain", "IID_domain")
    if strategy == "IID":
        user_groups = dataset_domain(train_set, args.num_users, domains)
    elif strategy == "Dirichlet_domain":
        user_groups = dataset_per_domain_dirichlet(train_set, args.num_users_per_domain, args.alpha)
    elif strategy == "IID_domain":
        user_groups = dataset_per_domain_iid(train_set, args.num_users_per_domain)
    else:
        user_groups = dataset_dir_new(train_set, args.num_users, args.alpha)

    classes = [f'{label}' for label in train_set.classes]
    num_classes = len(classes)

    # Per-domain strategies create their own number of users, so the count is
    # read from the partition; otherwise it is the configured num_users.
    total_users = len(user_groups) if per_domain else args.num_users
    user_base_labels = {u: list(range(num_classes)) for u in range(total_users)}

    return train_set, test_set, classes, user_groups, user_base_labels


class SFTConvergenceMonitor:
    """Detect a plateau in an accuracy curve.

    Convergence is reported once the absolute change between consecutive
    accuracy observations has stayed below ``acc_threshold`` for ``patience``
    checks in a row.
    """

    def __init__(self, acc_threshold=1e-3, patience=3, device="cuda"):
        """Configure the monitor.

        Args:
            acc_threshold: Changes strictly smaller than this count as "no change".
            patience: Number of consecutive small changes required.
            device: Stored on the instance; not used by the monitor itself.
        """
        self.acc_threshold = acc_threshold
        self.patience = patience
        self.device = device
        self.history = []  # one {"acc": value} entry per observation
        self.counter = 0   # current streak of below-threshold changes

    @torch.no_grad()
    def check_convergence(self, acc):
        """Record ``acc`` and report whether the curve has converged.

        The first observation only seeds the history and always yields False.

        Args:
            acc: Latest accuracy measurement.

        Returns:
            True once ``patience`` consecutive changes were below the
            threshold, otherwise False.
        """
        converged = False
        if self.history:
            delta_acc = abs(acc - self.history[-1]["acc"])
            if delta_acc < self.acc_threshold:
                self.counter += 1
            else:
                self.counter = 0
            converged = self.counter >= self.patience

        # Record every observation, including the converging one, so the next
        # check compares against the latest accuracy, not a stale value.
        self.history.append({"acc": acc})
        return converged
