import torch.nn as nn
import torch
import copy
from torchvision import transforms
import numpy as np
from torch.nn import functional as F
from PIL import Image
import torch.optim as optim
from myNetwork import *
from iCIFAR100 import iCIFAR100
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment as linear_assignment
import random
from FedNovel_ours import * 


def setup_seed(seed):
    """Seed every random number generator used by this module.

    The same value is handed to PyTorch (CPU generator and all CUDA
    devices), NumPy and Python's ``random`` module, after which the
    ``torch.backends.cudnn.deterministic`` flag is switched on.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def model_to_device(model, parallel, device):
    """Place ``model`` on CUDA hardware and return it.

    Args:
        model: The network to move.
        parallel: When truthy, the model is wrapped in ``nn.DataParallel``
            and the wrapper is moved with ``.cuda()``.
        device: CUDA device index, used to build ``"cuda:<device>"`` when
            ``parallel`` is falsy.

    Returns:
        The ``nn.DataParallel`` wrapper when ``parallel`` is truthy,
        otherwise the very object that was passed in.
    """
    if not parallel:
        target = torch.device("cuda:{}".format(device))
        # The result of ``.to`` is discarded on purpose: this relies on the
        # call moving the model in place (as ``nn.Module.to`` does).
        model.to(target)
        return model

    return nn.DataParallel(model).cuda()

def local_train(args, clients, index, model_g, task_id, old_task_id, model_old, ep_g, data_ids):
    """Run one local round for the client at ``clients[index]``.

    The client first receives a deep copy of the global model and is
    prepared through ``beforeTrain(task_id, data_ids)``. Then one of two
    things happens:

    * ``task_id > 0`` and the task has just changed
      (``task_id != old_task_id``): the client only runs
      ``local_clustering_start()``; no training is done.
    * otherwise: the client trains via ``train(ep_g, model_old)``.

    Args:
        args: Run configuration; ``args.out_file`` is the log file path.
        clients: Indexable collection of client objects.
        index: Position of the client to run.
        model_g: Global model that is deep-copied into the client.
        task_id: Current task id.
        old_task_id: Task id to compare against to detect a task change.
        model_old: Passed through to the client's ``train``.
        ep_g: Passed through to the client's ``train``.
        data_ids: Passed through to the client's ``beforeTrain``.

    Returns:
        ``(local_model, local_centers)``. After a clustering round this is
        ``(None, <result of local_clustering_start()>)``; after a training
        round it is ``(<client model state_dict>, [])``.
    """
    client = clients[index]
    client.model = copy.deepcopy(model_g)

    client.beforeTrain(task_id, data_ids)
    log_print('client: {} current class: {}'.format(index, client.current_class), args.out_file)

    # Clustering replaces training only on the first round of a non-initial task.
    entering_new_task = task_id > 0 and task_id != old_task_id
    if entering_new_task:
        log_print('client: {} starts to find potential local centers'.format(index), args.out_file)
        local_centers = client.local_clustering_start()
        local_model = None
    else:
        client.train(ep_g, model_old)
        local_model = client.model.state_dict()
        local_centers = []

    log_print('*'*100, args.out_file)
    return local_model, local_centers

def FedAvg(models, task_id, local_center_pool=None, model_g=None, device=None, best_k=0, args=None):
    """Aggregate client results into averaged weights and global centers.

    Args:
        models: Sequence of client state dicts that share the same keys.
            If the first entry is ``None``, no averaging is performed.
        task_id: Current task id; global centers are only computed when it
            is greater than 0.
        local_center_pool: Per-client collections of local cluster centers,
            forwarded to ``estimate_best_k`` / ``get_glob_centers``.
        model_g: Forwarded to ``estimate_best_k``.
        device: Forwarded to ``estimate_best_k``.
        best_k: Number of global clusters. ``0`` means "estimate it" with
            ``estimate_best_k`` (only when ``task_id > 0``).
        args: Run configuration; ``args.task_classes`` and ``args.out_file``
            are read when ``best_k`` has to be estimated.

    Returns:
        ``(w_avg, global_centers)``. ``w_avg`` is the element-wise mean of
        the client state dicts, or ``None`` when ``models[0]`` is ``None``.
        ``global_centers`` is the output of ``get_glob_centers`` when
        ``task_id > 0`` and an empty list otherwise.
    """
    # Identity test: ``None`` is a singleton, and ``!=`` would go through
    # the mapping's rich-comparison machinery for no benefit.
    if models[0] is not None:
        # Accumulate into a deep copy so the clients' tensors stay untouched.
        w_avg = copy.deepcopy(models[0])
        for k in w_avg.keys():
            for client_state in models[1:]:
                w_avg[k] += client_state[k]
            w_avg[k] = torch.div(w_avg[k], len(models))
    else:
        w_avg = None

    if task_id > 0:
        if best_k == 0:
            best_k = estimate_best_k(local_center_pool, model_g, device, args)
            log_print('true_k{}, best_k{}'.format(args.task_classes[task_id], best_k), args.out_file)
        global_centers = get_glob_centers(local_center_pool, best_k)
    else:
        global_centers = []

    return w_avg, global_centers

def estimate_best_k(local_center_pool, model_g, device, args):
    """Estimate the number of global clusters from all clients' local centers.

    Every client's centers are stacked into one array, then DBSCAN is run
    50 times with a radius that grows linearly from the minimum pairwise
    distance towards the mean pairwise distance. The largest number of
    distinct labels seen in any run is printed and returned.

    ``pdist`` and ``DBSCAN`` are not imported by name in this file;
    presumably they are SciPy's / scikit-learn's, provided by one of the
    star imports -- TODO confirm.

    Args:
        local_center_pool: Non-empty sequence with one collection of
            centers per client.
        model_g: Unused.
        device: Unused.
        args: Unused.

    Returns:
        The maximum ``len(np.unique(labels))`` over the 50 DBSCAN runs.
    """
    # Stack the centers client by client.
    all_centers = np.array(local_center_pool[0])
    for client_centers in local_center_pool[1:]:
        all_centers = np.concatenate((all_centers, np.array(client_centers)), axis=0)

    pairwise = pdist(all_centers)
    eps_start, eps_stop = pairwise.min(), pairwise.mean()

    # Rising-radius DBSCAN sweep.
    eps_step = (eps_stop - eps_start) / 50
    # min_samples is fixed at 2 for every dataset (a cifar100-specific
    # value of 3 exists only as disabled code in earlier revisions).
    min_samples = 2

    best = 0
    for step in range(50):
        radius = eps_start + step * eps_step
        labels = DBSCAN(eps=radius, min_samples=min_samples).fit(all_centers).labels_
        # NOTE(review): every distinct label is counted; if this is
        # scikit-learn's DBSCAN, that includes the noise label -1 -- confirm
        # this is intended.
        best = max(best, len(np.unique(labels)))

    print(best)
    return best
    
def get_glob_centers(local_center_pool, best_k):
    """Cluster all clients' local centers into ``best_k`` global centers.

    Args:
        local_center_pool: Sequence with one collection of centers per
            client.
        best_k: Number of clusters requested from ``KMeans``.

    Returns:
        An empty list when the pool is empty; otherwise the
        ``cluster_centers_`` of a ``KMeans`` model fitted on the stacked
        centers.
    """
    if len(local_center_pool) == 0:
        return []

    # Stack the centers client by client before clustering.
    stacked = np.array(local_center_pool[0])
    for client_centers in local_center_pool[1:]:
        stacked = np.concatenate((stacked, np.array(client_centers)), axis=0)

    return KMeans(n_clusters=best_k).fit(stacked).cluster_centers_

def model_global_eval(model_g, test_dataset, task_id, args):
    """Evaluate the global model on known and novel classes.

    The test set is restricted to the classes of tasks ``0..task_id``.
    Samples whose label is below ``args.task_classes[0]`` count as "known"
    and are scored by plain arg-max accuracy over the first
    ``args.task_classes[0]`` outputs. The remaining samples are "novel" and
    are scored with ``cluster_acc`` on the arg-max over the remaining
    outputs.

    The model is put into eval mode for the forward passes and switched to
    train mode afterwards.

    Args:
        model_g: Global model; moved to ``cuda:<args.device>`` first.
        test_dataset: Dataset exposing ``getTestData([low, high])`` and
            yielding ``(index, image, label)`` triples.
        task_id: Current task id.
        args: Run configuration; ``args.device`` and ``args.task_classes``
            are read.

    Returns:
        ``(known_acc, novel_acc, all_acc)`` as percentages. For
        ``task_id == 0`` the novel accuracy is ``0.`` and ``all_acc`` is
        ``known_acc``. ``known_acc`` and ``all_acc`` result from tensor
        arithmetic.
    """
    model_to_device(model_g, False, args.device)
    model_g.eval()
    test_dataset.getTestData([0, sum(args.task_classes[:task_id+1])])
    test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=128)

    num_known_classes = args.task_classes[0]
    known_outputs, known_labels = [], []
    novel_outputs, novel_labels = [], []
    for _, imgs, labels in test_loader:
        imgs, labels = imgs.cuda(args.device), labels.cuda(args.device)
        is_known = labels < num_known_classes
        is_novel = ~is_known
        with torch.no_grad():
            outputs = model_g(imgs)
        # Only collect non-empty slices of the batch.
        if is_known.any():
            known_outputs.append(outputs[is_known])
            known_labels.append(labels[is_known])
        if is_novel.any():
            novel_outputs.append(outputs[is_novel])
            novel_labels.append(labels[is_novel])
    model_g.train()

    if task_id > 0:
        novel_outputs = torch.cat(novel_outputs)
        novel_labels = torch.cat(novel_labels)
    known_outputs = torch.cat(known_outputs)
    known_labels = torch.cat(known_labels)

    # Known-class accuracy: arg-max restricted to the known-class outputs.
    num_known = known_outputs.size(0)
    known_pred = known_outputs[:, :num_known_classes].max(dim=1)[1]
    known_correct = (known_pred.cpu() == known_labels.cpu()).sum()
    known_acc = 100 * known_correct / num_known

    if task_id <= 0:
        return known_acc, 0., known_acc

    # Novel-class accuracy: arg-max over the remaining outputs, matched to
    # the labels by cluster_acc.
    num_novel = novel_outputs.size(0)
    novel_pred = novel_outputs[:, num_known_classes:].max(dim=1)[1]
    # NOTE(review): the labels go into cluster_acc's first parameter
    # (named y_pred) and the predictions into its second (named y_true) --
    # confirm this ordering is intended.
    novel_acc = cluster_acc(novel_labels.cpu().numpy(), novel_pred.cpu().numpy())
    all_acc = 100 * (known_correct + novel_acc * num_novel) / (num_known + num_novel)

    return known_acc, novel_acc * 100, all_acc

def log_print(message, file, p=True, l=True):
    """Print a message and/or append it to a log file.

    Args:
        message: Text to emit; must be a ``str`` when logging to file,
            because a newline is concatenated to it.
        file: Path of the log file, opened in append mode.
        p: When truthy, print ``message`` to stdout.
        l: When truthy, append ``message`` plus a newline to ``file``.
    """
    if p:
        print(message)
    if l:
        # The context manager closes the handle even if the write fails.
        with open(file, "a") as log_file:
            log_file.write(message + '\n')

def cluster_acc(y_pred, y_true, return_ind=False):
    """Compute clustering accuracy under the best one-to-one label matching.

    A square contingency matrix ``w`` is built with
    ``w[p, t] = number of samples with y_pred == p and y_true == t``. The
    assignment between predicted and true ids that maximises the matched
    count is then found with SciPy's ``linear_sum_assignment``.

    Args:
        y_pred: Predicted cluster ids, numpy array of shape
            ``(n_samples,)``. Values are used as matrix indices and are
            cast to ``int64``.
        y_true: Ground-truth labels, numpy array of shape ``(n_samples,)``.
            Values are used as matrix indices and are cast to ``int64``.
        return_ind: When true, also return the matched index pairs.

    Returns:
        The accuracy in ``[0, 1]``; with ``return_ind`` set, a tuple
        ``(accuracy, ind)`` where ``ind`` is an array of
        ``(pred_id, true_id)`` pairs, one row per matrix row.

    Raises:
        AssertionError: If the two arrays differ in size.
    """
    # Both arrays index into ``w``, so both must be integer typed.
    y_pred = y_pred.astype(np.int64)
    y_true = y_true.astype(np.int64)
    assert y_pred.size == y_true.size

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered accumulation: repeated (pred, true) pairs are each counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_assignment minimises cost, so turn counts into costs.
    ind_arr, jnd_arr = linear_assignment(w.max() - w)
    accuracy = w[ind_arr, jnd_arr].sum() * 1.0 / y_pred.size

    if return_ind:
        return accuracy, np.column_stack((ind_arr, jnd_arr))
    return accuracy