import torch
import time
import copy
import re
import random
import numpy as np

from sim.algorithms.FL_base import SFLClient, SFLServer
from sim.data.datasets import build_dataset
from sim.data.partition import build_partition
from sim.models.build_models import build_model
from sim.utils.record_utils import logconfig, add_log, record_exp_result2
from sim.utils.utils import setup_seed, replace_with_samples, label_distribution, distance, relu, client_probabilities, initialize_state_probabilities, compute_alpha, compute_weights
from sim.utils.optim_utils import OptimKit, LrUpdater
from torch.utils.data import Dataset, DataLoader, Subset
from sim.utils.options import args_parser
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Command-line configuration, parsed at import time and read as a module-level
# global by main() and aggregate_with_clustering().
args = args_parser()

# Cap torch's intra-op CPU thread pool at 4 threads.
torch.set_num_threads(4)
# Seed the project's RNGs (presumably Python/NumPy/torch; see sim.utils.utils.setup_seed).
setup_seed(args.seed)
# Use the requested CUDA device when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu")
# For the 'exdir' partition the first alpha entry is cast to int; its meaning is
# defined by build_partition (NOTE(review): confirm). Other schemes keep alpha as parsed.
args.alpha = [int(args.alpha[0]), args.alpha[1]] if args.partition == 'exdir' else args.alpha
# Shared simulation state (module globals). global_models is indexed by cluster id;
# the other lists are indexed by client id. They start empty and must hold one
# entry per cluster / client before the training loop in main() indexes them.
global_models = []
cluster_identity_list = []
prev_train_set_list = []
train_set_list = []
param_list = []

def customize_record_name(args):
    """Build the experiment identifier used for logs, results and checkpoints.

    The partition tag encodes ``args.alpha`` where the scheme uses it:
    ``exdir`` appends both alpha entries (comma separated), ``dir`` appends the
    first one, and ``iid`` is tagged by name only. Any other partition scheme
    is also tagged by its bare name rather than failing with an unbound local.

    Args:
        args: Parsed options; reads ``partition``, ``alpha``, ``N``, ``M``,
            ``S``, ``R``, ``E``, ``K``, ``m``, ``d``, ``seed``, ``capacity``,
            ``distance`` and ``set``.

    Returns:
        The record name string.
    """
    if args.partition == 'exdir':
        partition = f'{args.partition}{args.alpha[0]},{args.alpha[1]}'
    elif args.partition == 'dir':
        partition = f'{args.partition}{args.alpha[0]}'
    else:
        # 'iid' and any scheme without an alpha tag are identified by name only.
        partition = f'{args.partition}'
    return (
        f'FedDrift_N{args.N}_M{args.M}_S{args.S}_R{args.R}_E{args.E}_K{args.K}_{args.m}_{args.d}_{partition}'
        f'_seed{args.seed}_{args.capacity}_{args.distance}_set{args.set}'
    )
# Experiment identifier, passed to logconfig() and record_exp_result2() and used
# in the checkpoint path by main().
record_name = customize_record_name(args)

def get_loss(model, dataset, device, batch_size=64):
    """Return the mean cross-entropy loss of ``model`` over ``dataset``.

    Each batch's mean loss is re-weighted by the batch size, so the result is
    the per-sample average independent of ``batch_size``. The pass runs under
    ``torch.no_grad()`` with the model in eval mode; the model's previous
    train/eval mode is restored afterwards, so callers that later train a copy
    of the same module are not silently left in eval mode.

    Args:
        model: Module mapping a batch of inputs to class logits.
        dataset: Dataset yielding ``(input, label)`` pairs for a DataLoader.
        device: Device the batches are moved to before the forward pass.
        batch_size: Evaluation batch size (default 64).

    Returns:
        The average loss as a Python float.

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    data_loader = DataLoader(dataset, batch_size=batch_size)

    was_training = model.training
    total_loss = 0.0
    total_samples = 0
    try:
        model.eval()
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device=device), labels.to(device=device)
                avg_batch_loss = criterion(model(inputs), labels)
                total_loss += avg_batch_loss.item() * labels.size(0)
                total_samples += labels.size(0)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if total_samples == 0:
        raise ValueError('get_loss() received an empty dataset')
    return total_loss / total_samples

def clustering(global_models, c_id, train_dataset, threshold=10.0):
    """Assign client ``c_id`` to a cluster, creating a new cluster on detected drift.

    Every model in ``global_models`` is evaluated on the client's previous-round
    data (``prev_train_set_list[c_id]``) and on its current data
    (``train_set_list[c_id]``).

    If the best current loss exceeds the best previous loss by more than
    ``threshold``, drift is declared. A copy of the model that fit the previous
    data best is appended to ``global_models``, and this client alone is
    assigned to the new cluster. Clusters spawned by several drifted clients
    may later be combined by merge_clusters().

    Otherwise the client joins the existing cluster whose model has the lowest
    loss on its current data.

    Args:
        global_models: List of cluster models; may be appended to.
        c_id: Client index into the module-level per-client lists.
        train_dataset: Dataset the stored index lists refer to.
        threshold: Loss increase above which drift is declared (default 10.0).

    Side effects:
        Writes ``cluster_identity_list[c_id]`` (a plain int) and may append to
        ``global_models``. Returns None.
    """
    prev_subset = Subset(train_dataset, prev_train_set_list[c_id])
    curr_subset = Subset(train_dataset, train_set_list[c_id])

    # Loss of each global model on the client's previous-round data, then on its current data.
    prev_loss_list = [get_loss(model, prev_subset, device) for model in global_models]
    loss_list = [get_loss(model, curr_subset, device) for model in global_models]

    if min(loss_list) > min(prev_loss_list) + threshold:
        # Drift detected: start a new cluster for this client, seeded from the
        # model that fit its previous data best.
        best_model_idx = int(np.argmin(prev_loss_list))
        cluster_identity_list[c_id] = len(global_models)
        global_models.append(copy.deepcopy(global_models[best_model_idx]))
    else:
        # No drift: join the existing cluster that fits the current data best.
        cluster_identity_list[c_id] = int(np.argmin(loss_list))

def aggregate_with_clustering(clients):
    """Average the uploaded client models into their clusters' global models.

    Each selected client in ``clients`` is grouped by
    ``cluster_identity_list[c_id]``. For every cluster with at least one member,
    the cluster's model parameters are replaced by the weighted mean of its
    members' ``param_list[c_id]`` parameters. Every member gets the same
    weight, ``args.capacity``, so this is effectively a plain mean. Clusters
    with no member this round are left unchanged.

    Args:
        clients: Iterable of selected client indices.
    """
    client_groups = {identity: [] for identity in range(len(global_models))}
    for c_id in clients:
        client_groups[cluster_identity_list[c_id]].append(c_id)

    # multiple-model aggregation
    for identity, global_model in enumerate(global_models):
        members = client_groups[identity]
        if not members:
            # no client updates this global model
            continue

        total_size = 0
        # Allocated like the global model's flattened parameters, hence already
        # on the model's device and dtype.
        new_params = torch.zeros_like(parameters_to_vector(global_model.parameters()))
        for c_id in members:
            client_size = args.capacity
            total_size += client_size
            new_params += client_size * parameters_to_vector(param_list[c_id].parameters())

        new_params /= total_size
        vector_to_parameters(new_params, global_model.parameters())

def _cluster_loss_matrix(clients, train_dataset):
    """Return an (n, n) loss matrix over the n = len(global_models) clusters.

    Entry ``[i][j]`` is the loss of global model ``i`` on the current-round data
    of those ``clients`` assigned to cluster ``j``. Each client's loss is
    weighted by its sample count. The entry is ``-1`` when none of ``clients``
    belongs to cluster ``j``.
    """
    n = len(global_models)
    loss_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            total_data_size = 0
            for c_id in clients:
                if cluster_identity_list[c_id] != j:
                    continue
                data = train_set_list[c_id]
                total_data_size += len(data)
                loss = get_loss(global_models[i], Subset(train_dataset, data), device)
                loss_matrix[i][j] += loss * len(data)
            if total_data_size != 0:
                loss_matrix[i][j] /= total_data_size
            else:
                loss_matrix[i][j] = -1
    return loss_matrix


def _cluster_distances(loss_matrix):
    """Return the symmetric cluster-distance matrix derived from ``loss_matrix``.

    ``dist(i, j) = max(L[i][j] - L[i][i], L[j][i] - L[j][j], 0)``. This is how
    much worse either model does on the other cluster's data than on its own.
    ``-1`` marks pairs where either cluster has no data this round.
    """
    n = loss_matrix.shape[0]
    cluster_distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            if loss_matrix[i][j] == -1 or loss_matrix[j][i] == -1:
                dist = -1
            else:
                dist = max(loss_matrix[i][j] - loss_matrix[i][i], loss_matrix[j][i] - loss_matrix[j][j], 0)
            cluster_distances[i][j] = dist
            cluster_distances[j][i] = dist
    return cluster_distances


def _closest_cluster_pair(cluster_distances, threshold):
    """Return the pair (i, j), i < j, with the smallest distance strictly below ``threshold``.

    Entries equal to ``-1`` are ignored, and ties keep the first pair found.
    Returns None when no pair qualifies.
    """
    n = cluster_distances.shape[0]
    best_pair = None
    best_distance = threshold
    for i in range(n):
        for j in range(i + 1, n):
            dist = cluster_distances[i][j]
            if dist == -1:
                continue
            if dist < best_distance:
                best_pair = (i, j)
                best_distance = dist
    return best_pair


def merge_clusters(clients, train_dataset, threshold=10.0):
    """Greedily merge clusters whose models perform similarly on each other's data.

    Distances are computed from the current-round data of ``clients`` (the
    selected clients). While some pair of clusters is closer than
    ``threshold``, the closest pair (i, j) is merged:

    - model i becomes the sample-size-weighted average of models i and j;
    - cluster j's members move to cluster i;
    - i's distance to every other cluster becomes the max of the two old
      distances.

    Finally the merged-away models are deleted from ``global_models``, and
    cluster ids above each deleted index are shifted down.

    Cluster reassignment and reindexing cover every client in
    ``cluster_identity_list``, not only ``clients``. Unselected clients also
    hold cluster ids, and main() indexes ``global_models`` with all of them
    after merging. Left untouched, they would point at the wrong model or past
    the end of the list.

    Args:
        clients: Selected client indices whose data drives the distances.
        train_dataset: Dataset the stored index lists refer to.
        threshold: Distances must be strictly below this to merge (default 10.0).
    """
    global_model_num = len(global_models)
    cluster_distances = _cluster_distances(_cluster_loss_matrix(clients, train_dataset))
    all_clients = range(len(cluster_identity_list))

    deleted_models = []
    while True:
        pair = _closest_cluster_pair(cluster_distances, threshold)
        if pair is None:
            break
        cluster_i, cluster_j = pair

        # number of current-round samples (from `clients`) in each cluster
        cluster_data_size = np.zeros(global_model_num)
        for c_id in clients:
            if cluster_identity_list[c_id] is not None:
                cluster_data_size[cluster_identity_list[c_id]] += len(train_set_list[c_id])

        # Plain Python floats, so scalar * tensor keeps the model's parameter dtype.
        size_i = float(cluster_data_size[cluster_i])
        size_j = float(cluster_data_size[cluster_j])
        model_i_params = parameters_to_vector(global_models[cluster_i].parameters())
        model_j_params = parameters_to_vector(global_models[cluster_j].parameters())
        merged_model_params = (size_i * model_i_params + size_j * model_j_params) / (size_i + size_j)
        print(f"\033[34mMerge cluster {cluster_i} and cluster {cluster_j}\033[0m")

        # make model i as the new model (i.e., model k in the paper)
        vector_to_parameters(merged_model_params, global_models[cluster_i].parameters())
        deleted_models.append(cluster_j)
        for c_id in all_clients:
            if cluster_identity_list[c_id] == cluster_j:
                cluster_identity_list[c_id] = cluster_i

        for l in range(global_model_num):
            if l == cluster_i or l == cluster_j:
                continue
            dist = max(cluster_distances[cluster_i][l], cluster_distances[cluster_j][l])
            cluster_distances[cluster_i][l] = dist
            cluster_distances[l][cluster_i] = dist

        # cluster j no longer exists: exclude it from further merges
        cluster_distances[:, cluster_j] = -1
        cluster_distances[cluster_j, :] = -1

    # Delete from the highest index down so earlier indices stay valid.
    deleted_models.sort(reverse=True)
    for i in deleted_models:
        for c_id in all_clients:
            if cluster_identity_list[c_id] is not None and cluster_identity_list[c_id] > i:
                cluster_identity_list[c_id] -= 1
        del global_models[i]

def main():
    """Run the FedDrift simulation configured by the module-level ``args``.

    Builds the data partition and the model, and initialises the per-client and
    per-cluster state. Then, for each of ``args.R`` rounds:

    - sample clients;
    - let each selected client choose (or spawn) a cluster via clustering();
    - train that cluster's model locally over ``args.E`` sampled data states;
    - merge similar clusters every 5 rounds when more than one exists;
    - aggregate per cluster;
    - log the best test accuracy and lowest test loss over all cluster models.
    """
    global args, record_name, device, global_models, cluster_identity_list, prev_train_set_list, train_set_list, param_list
    import os  # only used to create the optional checkpoint directory

    logconfig(name=record_name, flag=args.log)
    add_log('{}'.format(args), flag=args.log)
    add_log('record_name: {}'.format(record_name), flag=args.log)

    client = SFLClient()
    server = SFLServer()

    net_dataidx_map = build_partition(args.d, args.M, args.partition, args.alpha)

    train_dataset, test_dataset = build_dataset(args.d)

    client.setup_test_dataset(test_dataset)
    train_labels = torch.tensor([label for _, label in train_dataset])
    uniform_distribution = np.ones(args.num_classes) / args.num_classes
    class_probabilities_list = [label_distribution(train_labels, net_dataidx_map[i]) for i in range(args.M)]
    distance_list = [distance(args, class_probabilities_list[i], uniform_distribution) for i in range(args.M)]
    add_log('distance_list: {}'.format(distance_list), flag=args.log)

    global_model = build_model(model=args.m, dataset=args.d)
    server.setup_model(global_model.to(device))
    add_log('{}'.format(global_model), flag=args.log)

    # The module-level state lists start empty but are indexed by client id
    # (or cluster id) below, so size them before the first round. Cluster 0 is
    # the initial model object itself, so server.global_model (saved at the
    # end) tracks cluster 0's updates. This assumes setup_model keeps the
    # reference; confirm in SFLServer.
    global_models[:] = [global_model]
    cluster_identity_list[:] = [0] * args.N
    prev_train_set_list[:] = [None] * args.N  # None => no previous-round data yet
    train_set_list[:] = [None] * args.N
    param_list[:] = [copy.deepcopy(global_model) for _ in range(args.N)]

    # construct optim kit
    client_optim_kit = OptimKit(optim_name=args.optim, batch_size=args.batch_size, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    client_optim_kit.setup_lr_updater(LrUpdater.exponential_lr_updater, mul=args.lr_decay)
    client.setup_optim_kit(client_optim_kit)
    client.setup_criterion(torch.nn.CrossEntropyLoss())
    server.setup_optim_settings(lr=args.global_lr)

    pi_list, data_map_list, nonzero_indices_list = initialize_state_probabilities(args.N, args.M, args.S, args.set)
    add_log('pi_list: {}'.format(pi_list), flag=args.log)

    # Result unused below; the call is kept because it may advance RNG state
    # (unverified) and removing it could change the sampled trajectory.
    client_probs = client_probabilities(args.N)

    record_exp_result2(record_name, {'round':0})
    for rnd in range(0, args.R):
        selected_clients = server.select_clients_prob(args.N, args.P, args.probs)
        add_log('selected clients: {}'.format(selected_clients), flag=args.log)

        for c_id in selected_clients:
            server.setup_temp_model(copy.deepcopy(param_list[c_id].to(device)))
            # Draw args.E data states for this client from pi_list[c_id]; each
            # state gets args.K local update steps below.
            state_list = np.random.choice(args.M, args.E, p=pi_list[c_id], replace=True)
            add_log('state_list: {}'.format(state_list), flag=args.log)

            for step, state in enumerate(state_list):
                # Use at most args.capacity sample indices from the chosen state.
                if len(net_dataidx_map[state]) <= args.capacity:
                    data_map_list[c_id] = np.array(net_dataidx_map[state])
                else:
                    data_map_list[c_id] = np.random.choice(net_dataidx_map[state], args.capacity, replace=False)

                if step == 0:
                    # The first state's data is compared against last round's
                    # data for drift detection.
                    train_set_list[c_id] = copy.deepcopy(data_map_list[c_id])
                    if prev_train_set_list[c_id] is None:
                        prev_train_set_list[c_id] = copy.deepcopy(train_set_list[c_id])
                    clustering(global_models, c_id, train_dataset)
                    # Train the model of the cluster just assigned. Otherwise a
                    # reassigned client would train its old cluster's model and
                    # have that update averaged into the new cluster.
                    assigned_model = global_models[cluster_identity_list[c_id]]
                    server.setup_temp_model(copy.deepcopy(assigned_model.to(device)))
                if step == len(state_list) - 1:
                    prev_train_set_list[c_id] = copy.deepcopy(data_map_list[c_id])
                local_distribution = label_distribution(train_labels, data_map_list[c_id])
                add_log('Client {}\'s local distribution: {}'.format(c_id, local_distribution), flag=args.log)

                local_distance = distance(args, local_distribution, uniform_distribution)
                add_log('Client {}\'s local distance: {:.4f}'.format(c_id, local_distance), flag=args.log)

                local_dataset = Subset(train_dataset, data_map_list[c_id])
                local_param, _ = client.local_update_step(local_dataset=local_dataset, model=copy.deepcopy(server.temp_model), num_steps=args.K, device=device, clip=args.clip)
                torch.nn.utils.vector_to_parameters(local_param, server.temp_model.parameters())

            param_list[c_id] = copy.deepcopy(server.temp_model)
        if len(global_models) > 1 and rnd % 5 == 0:
            merge_clusters(selected_clients, train_dataset)

        aggregate_with_clustering(selected_clients)
        # Every client, selected or not, now starts from its cluster's global model.
        for c_id in range(args.N):
            param_list[c_id] = global_models[cluster_identity_list[c_id]]

        client.optim_kit.update_lr()
        test_losses_list = []
        test_top1_list = []
        test_top5_list = []
        for model in global_models:
            # Use the resolved torch.device (CPU fallback), as training does.
            test_losses, test_top1, test_top5 = client.evaluate_dataset(model=model, dataset=client.test_dataset, device=device)
            test_losses_list.append(test_losses.avg)
            test_top1_list.append(test_top1.avg)
            test_top5_list.append(test_top5.avg)
        print(test_top1_list)
        add_log("Round {}'s server  test  acc: {:6.2f}%, test  loss: {:.4f}".format(rnd+1, max(test_top1_list), min(test_losses_list)), 'red', flag=args.log)

        record_exp_result2(record_name, {'round':rnd+1, 'test_loss'  : min(test_losses_list), 'test_top1'  : max(test_top1_list)})

    if args.save_model == 1:
        os.makedirs('./save_model', exist_ok=True)
        torch.save({'model': torch.nn.utils.parameters_to_vector(server.global_model.parameters())}, './save_model/{}.pt'.format(record_name))

# Script entry point. Importing this module already parses CLI arguments and
# calls setup_seed() at module level; only the simulation itself is guarded.
if __name__ == '__main__':
    main()