"""Run Experiment

This script runs a single federated learning experiment; the experiment name, the method, and the
number of clients/tasks must be specified along with the hyper-parameters of the experiment.

The results of the experiment (i.e., training logs) are written to the ./logs/ folder by default,
or to the directory given by the `logs_dir` argument when it is provided.

This file can also be imported as a module and contains the following functions:

    * init_clients - builds the clients (one per data split) from the data folders
    * run_experiment - runs one experiment given its arguments
"""
from sklearn import cluster
from utils.utils import *
from utils.constants import *
from utils.args import *

from torch.utils.tensorboard import SummaryWriter
import os

# Cap CPU parallelism: OpenMP, MKL and PyTorch intra-op work all use `cpu_num` threads.
cpu_num = 4
os.environ.update(dict.fromkeys(("OMP_NUM_THREADS", "MKL_NUM_THREADS"), str(cpu_num)))
torch.set_num_threads(cpu_num)


def init_clients(args_, root_path, logs_dir):
    """
    initialize clients from data folders

    Builds one client per data split returned by the loaders; splits whose train or test
    iterator is None are skipped.

    :param args_: parsed experiment arguments (reads experiment, method, bz, validation,
        n_learners, device, optimizer, lr_scheduler, lr, input_dimension, output_dimension,
        n_rounds, seed, mu, q, local_steps, locally_tune_clients)
    :param root_path: path to directory containing data folders; for the CIFAR-C loaders it is
        only inspected for the substring 'test' to select the test split
    :param logs_dir: path to logs root; a `task_<id>` sub-folder with its own SummaryWriter is
        created for every client
    :return: List[Client]
    """
    print("===> Building data iterators..")
    loader_type = LOADER_TYPE[args_.experiment]

    # Corrupted-CIFAR loader types -> (train data root, test data root).
    cifar_c_roots = {
        'cifar10-c': ('./data/cifar10-c', './data/cifar10-c'),
        'cifar100-c': ('./data/cifar100-c', './data/cifar100-c'),
        # NOTE(review): the training split reads './data/cifar100-c' while the test split reads
        # './data/cifar100-c-10'; kept as in the original code -- confirm this is intended.
        'cifar100-c-10': ('./data/cifar100-c', './data/cifar100-c-10'),
    }

    # Number of label classes handed to every client. Loader types without an entry used to leave
    # `class_number` unbound (crash at get_client); fall back to the model output dimension, which
    # presumably equals the number of classes -- TODO confirm for 'cifar100-c-10'.
    known_class_numbers = {'cifar10-c': 10, 'cifar100-c': 100}
    class_number = known_class_numbers.get(loader_type, args_.output_dimension)

    if loader_type in cifar_c_roots:
        train_root, test_root = cifar_c_roots[loader_type]
        is_test = 'test' in root_path
        # `test=True` is passed only when loading the test split, exactly as before.
        extra_kwargs = {'test': True} if is_test else {}
        train_iterators, val_iterators, test_iterators, client_types = \
            get_cifar10C_loaders(
                root_path=test_root if is_test else train_root,
                batch_size=args_.bz,
                is_validation=args_.validation,
                **extra_kwargs
            )
    else:
        train_iterators, val_iterators, test_iterators = \
            get_loaders(
                type_=loader_type,
                root_path=root_path,
                batch_size=args_.bz,
                is_validation=args_.validation
            )
        # generic loaders do not report a data type; tag every client with 0
        client_types = [0] * len(train_iterators)

    # Training clients of the *_tune methods always tune locally; otherwise
    # `args_.locally_tune_clients` decides.
    force_tune = (CLIENT_TYPE[args_.method] == "conceptEM_tune" or args_.method == "FedAvg_tune") \
        and "train" in logs_dir
    tune_locally = True if force_tune else args_.locally_tune_clients

    print("===> Initializing clients..")
    clients_ = []
    splits = zip(train_iterators, val_iterators, test_iterators)
    for task_id, (train_iterator, val_iterator, test_iterator) in \
            enumerate(tqdm(splits, total=len(train_iterators))):

        if train_iterator is None or test_iterator is None:
            continue

        learners_ensemble = \
            get_learners_ensemble(
                n_learners=args_.n_learners,
                name=args_.experiment,
                device=args_.device,
                optimizer_name=args_.optimizer,
                scheduler_name=args_.lr_scheduler,
                initial_lr=args_.lr,
                input_dim=args_.input_dimension,
                output_dim=args_.output_dimension,
                n_rounds=args_.n_rounds,
                seed=args_.seed,
                mu=args_.mu
            )

        logs_path = os.path.join(logs_dir, "task_{}".format(task_id))
        os.makedirs(logs_path, exist_ok=True)
        logger = SummaryWriter(logs_path)

        client = get_client(
            client_type=CLIENT_TYPE[args_.method],
            learners_ensemble=learners_ensemble,
            q=args_.q,
            train_iterator=train_iterator,
            val_iterator=val_iterator,
            test_iterator=test_iterator,
            logger=logger,
            local_steps=args_.local_steps,
            tune_locally=tune_locally,
            data_type=client_types[task_id],
            class_number=class_number
        )

        clients_.append(client)

    return clients_


def _log_cluster_assignments(clients, n_learner):
    """
    Dump per-sample learner weights and return the total weight received by each learner.

    Writes one line per training sample (client data type, label, that sample's learner weights)
    to a clusters file that is overwritten on every call, and appends the per-learner totals to a
    second file.

    :param clients: training clients; each exposes `data_type`, `train_iterator.dataset.targets`
        and `samples_weights` indexed as [learner][sample]
    :param n_learner: number of learners currently in the ensemble
    :return: list of length `n_learner`; entry j sums `samples_weights[j][i]` over all clients
        and samples
    """
    clusters_path = './logs/cifar10-c/conceptEM/save_clusters-conceptem-adam-adapt-0025.txt'
    numcluster_path = './logs/cifar10-c/conceptEM/numcluster-conceptem-adam-adapt-0025.txt'
    # NOTE(review): these paths are hard-coded to cifar10-c/conceptEM whatever the experiment;
    # create the folder so that other experiments do not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(clusters_path), exist_ok=True)

    cluster_weights = [0 for _ in range(n_learner)]
    with open(clusters_path, 'w') as f:
        for client in clients:
            targets = client.train_iterator.dataset.targets
            for i in range(len(targets)):
                f.write('{},{},{}\n'.format(client.data_type, targets[i], client.samples_weights.T[i]))
                for j in range(n_learner):
                    cluster_weights[j] += client.samples_weights[j][i]
            f.write('\n')
    with open(numcluster_path, 'a+') as f:
        f.write('{}'.format(cluster_weights))
        f.write('\n')
    return cluster_weights


def _prune_small_learners(clients, test_clients, aggregator, cluster_weights, ratio):
    """
    Remove every learner whose total weight is at most `ratio` of the overall weight.

    The learner is removed from every training client, every test client and the aggregator.
    At least one learner is always kept.

    :param cluster_weights: total weight per learner, as returned by `_log_cluster_assignments`
    :param ratio: pruning threshold as a fraction of `sum(cluster_weights)`
    :return: number of learners removed
    """
    n_learner = len(cluster_weights)
    threshold = sum(cluster_weights) * ratio
    removed = 0
    for i in range(n_learner):
        # never remove the last remaining learner
        if n_learner - removed <= 1:
            break
        if cluster_weights[i] <= threshold:
            # positions shift left by one for every learner already removed
            position = i - removed
            for client in clients:
                client.remove_learner(position)
            for client in test_clients:
                client.remove_learner(position)
            aggregator.remove_learner(position)
            removed += 1
    return removed


def _normalized_label_weights(clients, n_rows, num_classes):
    """
    Sum `labels_learner_weights` over clients and normalize each learner's row to sum to 1.

    :param clients: training clients exposing `labels_learner_weights` indexed as [learner][class]
    :param n_rows: number of learners currently in the ensemble
    :param num_classes: number of label classes
    :return: list of `n_rows` lists of `num_classes` weights; a row whose total is zero is
        returned as-is (all zeros) instead of raising ZeroDivisionError
    """
    totals = [[0] * num_classes for _ in range(n_rows)]
    for client in clients:
        client_weights = client.labels_learner_weights
        for j in range(n_rows):
            for k in range(num_classes):
                totals[j][k] += client_weights[j][k]

    normalized = []
    for row in totals:
        row_sum = sum(row)
        normalized.append([w / row_sum for w in row] if row_sum else row)
    return normalized


def run_experiment(args_):
    """
    Run one federated learning experiment end-to-end.

    Builds the train/test clients and the aggregator from `args_`. It then repeatedly calls
    `aggregator.mix` until `aggregator.c_round` exceeds `args_.n_rounds`. After every mix it:
      * logs per-sample learner weights;
      * removes learners whose total weight is at most 2.5% of the overall weight;
      * pushes the normalized per-learner label distributions back to the training clients.
    If `args_` has a `save_dir`, the aggregator state is saved there at the end.

    :param args_: parsed experiment arguments
    """
    torch.manual_seed(args_.seed)

    data_dir = get_data_dir(args_.experiment)

    if "logs_dir" in args_:
        logs_dir = args_.logs_dir
    else:
        logs_dir = os.path.join("logs", args_to_string(args_))

    print("==> Clients initialization..")
    clients = init_clients(args_, root_path=os.path.join(data_dir, "train"), logs_dir=os.path.join(logs_dir, "train"))

    print("==> Test Clients initialization..")
    test_clients = init_clients(args_, root_path=os.path.join(data_dir, "test"),
                                logs_dir=os.path.join(logs_dir, "test"))

    logs_path = os.path.join(logs_dir, "train", "global")
    os.makedirs(logs_path, exist_ok=True)
    global_train_logger = SummaryWriter(logs_path)

    logs_path = os.path.join(logs_dir, "test", "global")
    os.makedirs(logs_path, exist_ok=True)
    global_test_logger = SummaryWriter(logs_path)

    global_learners_ensemble = \
        get_learners_ensemble(
            n_learners=args_.n_learners,
            name=args_.experiment,
            device=args_.device,
            optimizer_name=args_.optimizer,
            scheduler_name=args_.lr_scheduler,
            initial_lr=args_.lr,
            input_dim=args_.input_dimension,
            output_dim=args_.output_dimension,
            n_rounds=args_.n_rounds,
            seed=args_.seed,
            mu=args_.mu
        )

    if args_.decentralized:
        aggregator_type = 'decentralized'
    else:
        aggregator_type = AGGREGATOR_TYPE[args_.method]

    aggregator = \
        get_aggregator(
            aggregator_type=aggregator_type,
            clients=clients,
            global_learners_ensemble=global_learners_ensemble,
            lr_lambda=args_.lr_lambda,
            lr=args_.lr,
            q=args_.q,
            mu=args_.mu,
            communication_probability=args_.communication_probability,
            sampling_rate=args_.sampling_rate,
            log_freq=args_.log_freq,
            global_train_logger=global_train_logger,
            global_test_logger=global_test_logger,
            test_clients=test_clients,
            verbose=args_.verbose,
            seed=args_.seed
        )

    # Number of label classes used for the per-learner label histograms. Loader types other than
    # cifar10-c / cifar100-c previously hit a NameError; fall back to the model output dimension
    # (presumably the class count -- TODO confirm).
    loader_type = LOADER_TYPE[args_.experiment]
    if loader_type == 'cifar10-c':
        num_classes = 10
    elif loader_type == 'cifar100-c':
        num_classes = 100
    else:
        num_classes = args_.output_dimension

    # learners holding at most this fraction of the total sample weight are removed
    # (matches the "0025" suffix of the log file names)
    prune_ratio = 0.025

    print("Training..")
    pbar = tqdm(total=args_.n_rounds)
    current_round = 0
    while current_round <= args_.n_rounds:
        aggregator.mix(diverse=False)

        n_learner = aggregator.n_learners
        cluster_weights = _log_cluster_assignments(clients, n_learner)
        print(cluster_weights)

        n_removed = _prune_small_learners(clients, test_clients, aggregator, cluster_weights, prune_ratio)

        cluster_label_weights = _normalized_label_weights(clients, n_learner - n_removed, num_classes)
        for client in clients:
            client.update_labels_weights(cluster_label_weights)

        if aggregator.c_round != current_round:
            pbar.update(1)
            current_round = aggregator.c_round
    pbar.close()

    if "save_dir" in args_:
        save_dir = args_.save_dir

        os.makedirs(save_dir, exist_ok=True)
        aggregator.save_state(save_dir)


if __name__ == "__main__":
    # Ask cuDNN for deterministic algorithms and disable autotuning so that runs are reproducible.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    run_experiment(parse_args())
