#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib


from Algorithm.Phoenix_util import *

matplotlib.use('Agg')
import copy
import wandb
import torch
import torch.multiprocessing as mp

from utils.options import args_parser
from utils.set_seed import set_random_seed
from models.Update import *
from models.Nets import *
from models.Fed import Aggregation,AggregationMut
from models.test import *
from models.resnetcifar import *

from utils.get_dataset import * 
from utils.utils import save_result,save_model
from Algorithm.Training_FedGen import FedGen

from Algorithm.Training_FedMut import FedMut

from torch.autograd import Variable
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

import sys
import io   

def FedPhoenix(net_glob, dataset_train, dataset_test, dict_users):
    """Run FedPhoenix training: FedAvg-style rounds in which every client's copy
    of the global model is passed through ``reset_kernels_and_neurons_stair``
    before local training.

    Configuration is read from the module-level ``args`` (``epochs``, ``frac``,
    ``num_users``, ``device``, ``FP_conv``, ``reset``, ...).

    Args:
        net_glob: Global model; its weights are replaced by the aggregated
            client weights at the end of every round.
        dataset_train: Training dataset handed to every client's local updater.
        dataset_test: Dataset used for the evaluation every 5th round.
        dict_users: Indexable by client id; ``dict_users[idx]`` is passed as the
            local updater's ``idxs`` and its length is used as the client's
            aggregation size (presumably that client's sample indices).

    Side effects:
        Sets ``args.density_local``; saves the accuracy history via
        ``save_FedPhoenix_result`` and the final model via ``save_model``.
    """
    net_glob.train()
    acc = []
    test_loss = []

    # NOTE(review): this unconditionally overrides any density_local value
    # provided on the command line -- confirm that is intended.
    args.density_local = 0.01

    # Snapshot the round count so the output file names below are well-defined
    # even when no round runs (the old code referenced the loop variable here
    # and raised NameError for epochs == 0).
    num_rounds = args.epochs
    for round_idx in range(num_rounds):
        # Fall back to 0 whenever density_local leaves the [0, 1] range.
        if args.density_local > 1 or args.density_local < 0:
            args.density_local = 0

        print('*' * 80)
        print('Round {:3d}'.format(round_idx))

        w_locals = []
        lens = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            # Drop the previous client's copy before allocating the next one.
            net_local = None
            net_local = copy.deepcopy(net_glob).to(args.device)
            # Reset part of the local copy; the returned counts are not used here.
            _layer_count, _stopped_count = reset_kernels_and_neurons_stair(
                net_local, round_idx, args.FP_conv,
                conv_reset_ratio=args.reset, fc_reset_ratio=0,
                scale_factor=1, init_method='uniform')
            local = LocalUpdate_FedAvg(args=args, dataset=dataset_train,
                                       idxs=dict_users[idx], dataset_test=dataset_test)
            w = local.train(net=net_local)
            w_locals.append(copy.deepcopy(w))
            lens.append(len(dict_users[idx]))

        # Combine client weights (each client's size in ``lens``) into the global model.
        w_glob = Aggregation(w_locals, lens)
        net_glob.load_state_dict(w_glob)

        # Evaluate on rounds 4, 9, 14, ...
        if round_idx % 5 == 4:
            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
            print(f'model:{args.model},Algorithm:{args.algorithm},dataset:{args.dataset}')

            test_loss.append(item_loss)
            acc.append(item_acc)

    last_round = num_rounds - 1
    save_FedPhoenix_result(acc, f"acc_{last_round}Epoch", args)
    calculate_max_average(acc)
    save_model(net_glob.state_dict(), f"model_{last_round}Epoch", args)

def FedAvg(net_glob, dataset_train, dataset_test, dict_users):
    """Run plain FedAvg training.

    Each round samples ``max(int(args.frac * args.num_users), 1)`` clients
    without replacement. Every sampled client trains a fresh copy of the global
    model, and the resulting weights are combined with ``Aggregation``.
    Configuration is read from the module-level ``args``.

    Args:
        net_glob: Global model; its weights are replaced by the aggregated
            client weights at the end of every round.
        dataset_train: Training dataset handed to every client's local updater.
        dataset_test: Dataset used for the evaluation every 5th round and the
            final evaluation.
        dict_users: Indexable by client id; ``dict_users[idx]`` is passed as the
            local updater's ``idxs`` and its length is used as the client's
            aggregation size.

    Side effects:
        Saves the accuracy history via ``save_result`` and the final model
        via ``save_model`` under the name ``'test_model'``.
    """
    net_glob.train()
    acc = []
    test_loss = []

    # Snapshot the round count so the result file name is well-defined even
    # when no round runs (previously a NameError for epochs == 0).
    num_rounds = args.epochs
    for round_idx in range(num_rounds):
        print('*' * 80)
        print('Round {:3d}'.format(round_idx))

        w_locals = []
        lens = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            # Drop the previous client's copy before allocating the next one.
            net_local = None
            net_local = copy.deepcopy(net_glob).to(args.device)
            local = LocalUpdate_FedAvg(args=args, dataset=dataset_train,
                                       idxs=dict_users[idx], dataset_test=dataset_test)
            w = local.train(net=net_local)

            w_locals.append(copy.deepcopy(w))
            lens.append(len(dict_users[idx]))

        # Combine client weights (each client's size in ``lens``) into the global model.
        w_glob = Aggregation(w_locals, lens)
        net_glob.load_state_dict(w_glob)

        # Evaluate on rounds 4, 9, 14, ...
        if round_idx % 5 == 4:
            print(f'model:{args.model},Algorithm:{args.algorithm},dataset:{args.dataset}')

            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
            test_loss.append(item_loss)
            acc.append(item_acc)

    # Final evaluation; ``test`` prints the accuracy, its return value is not needed.
    test(net_glob, dataset_test, args)
    save_result(acc, f"acc_{num_rounds - 1}Epoch", args)

    # NOTE(review): the other drivers save as f"model_{last_round}Epoch";
    # the 'test_model' name is kept here because downstream tooling may rely on it.
    save_model(net_glob.state_dict(), 'test_model', args)

def FedProx(net_glob, dataset_train, dataset_test, dict_users):
    """Run FedProx training.

    The training loop is the same as FedAvg, except that local training uses
    ``LocalUpdate_FedProx``, which also receives the current global model
    (``glob_model``). Configuration is read from the module-level ``args``.

    Args:
        net_glob: Global model; its weights are replaced by the aggregated
            client weights at the end of every round.
        dataset_train: Training dataset handed to every client's local updater.
        dataset_test: Dataset used for the evaluation every 5th round.
        dict_users: Indexable by client id; ``dict_users[idx]`` is passed as the
            local updater's ``idxs`` and its length is used as the client's
            aggregation size.

    Side effects:
        Saves the accuracy history via ``save_result`` and the final model
        via ``save_model``.
    """
    net_glob.train()

    acc = []
    test_loss = []

    # Snapshot the round count so the output file names are well-defined even
    # when no round runs (previously a NameError for epochs == 0).
    num_rounds = args.epochs
    for round_idx in range(num_rounds):

        print('*' * 80)
        print('Round {:3d}'.format(round_idx))

        w_locals = []
        lens = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate_FedProx(args=args, glob_model=net_glob,
                                        dataset=dataset_train, idxs=dict_users[idx])
            w = local.train(net=copy.deepcopy(net_glob).to(args.device))

            w_locals.append(copy.deepcopy(w))
            lens.append(len(dict_users[idx]))

        # Combine client weights (each client's size in ``lens``) into the global model.
        w_glob = Aggregation(w_locals, lens)
        net_glob.load_state_dict(w_glob)

        # Evaluate on rounds 4, 9, 14, ...
        if round_idx % 5 == 4:
            print(f'model:{args.model},Algorithm:{args.algorithm},dataset:{args.dataset}')

            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)

            test_loss.append(item_loss)
            acc.append(item_acc)

    last_round = num_rounds - 1
    save_result(acc, f"acc_{last_round}Epoch", args)
    save_model(net_glob.state_dict(), f"model_{last_round}Epoch", args)

from utils.clustering import *
from scipy.cluster.hierarchy import linkage


def ClusteredSampling(net_glob, dataset_train, dataset_test, dict_users):
    """Run federated training with clustered client sampling.

    Before training, one "gradient" per client is built with ``get_gradients``
    and turned into a client similarity matrix. Every round then:

    1. builds a Ward linkage over the similarity matrix;
    2. derives sampling clusters with ``get_clusters_with_alg2``;
    3. samples clients with ``sample_clients``;
    4. trains and aggregates the sampled clients;
    5. refreshes the sampled clients' gradients and the similarity matrix.

    The global model is kept on CPU except while it is being evaluated.
    Configuration is read from the module-level ``args``.

    Args:
        net_glob: Global model; its weights are replaced by the aggregated
            client weights at the end of every round.
        dataset_train: Training dataset handed to every client's local updater.
        dataset_test: Dataset used for the evaluation every 5th round.
        dict_users: Dict keyed by client id. ``len(dict_users[idx])`` gives the
            client's size. NOTE(review): ``gradients[idx]`` below assumes the
            keys are 0..N-1 in iteration order -- confirm against the data loader.

    Side effects:
        Saves the final model via ``save_model`` and the accuracy history
        via ``save_result``.
    """
    net_glob.to('cpu')

    # Per-client sampling weights, proportional to each client's size.
    n_samples = np.array([len(dict_users[idx]) for idx in dict_users.keys()])
    weights = n_samples / np.sum(n_samples)
    n_sampled = max(int(args.frac * args.num_users), 1)

    # Initial per-client gradients, computed from the untrained global model.
    gradients = get_gradients('', net_glob, [net_glob] * len(dict_users))

    net_glob.train()

    acc = []

    # Snapshot the round count so the output file names are well-defined even
    # when no round runs (previously a NameError for epochs == 0).
    num_rounds = args.epochs
    for round_idx in range(num_rounds):

        print('*' * 80)
        print('Round {:3d}'.format(round_idx))

        # Reference point for this round's client gradients.
        previous_global_model = copy.deepcopy(net_glob)
        clients_models = []
        sampled_clients_for_grad = []

        # The similarity matrix is built from scratch only once; later rounds
        # update it incrementally at the end of the loop body.
        if round_idx == 0:
            sim_matrix = get_matrix_similarity_from_grads(
                gradients, distance_type=args.sim_type
            )

        # Hierarchical clustering (Ward) of clients from the similarity matrix.
        linkage_matrix = linkage(sim_matrix, "ward")

        distri_clusters = get_clusters_with_alg2(
            linkage_matrix, n_sampled, weights
        )

        w_locals = []
        lens = []
        idxs_users = sample_clients(distri_clusters)
        for idx in idxs_users:
            local = LocalUpdate_ClientSampling(args=args, dataset=dataset_train, idxs=dict_users[idx])
            local_model = local.train(net=copy.deepcopy(net_glob).to(args.device))
            local_model.to('cpu')

            w_locals.append(copy.deepcopy(local_model.state_dict()))
            lens.append(len(dict_users[idx]))

            clients_models.append(copy.deepcopy(local_model))
            sampled_clients_for_grad.append(idx)

            del local_model

        # Combine client weights (each client's size in ``lens``) into the global model.
        w_glob = Aggregation(w_locals, lens)
        net_glob.load_state_dict(w_glob)

        # Refresh the gradients of the clients sampled this round only.
        gradients_i = get_gradients(
            '', previous_global_model, clients_models
        )
        for idx, gradient in zip(sampled_clients_for_grad, gradients_i):
            gradients[idx] = gradient

        sim_matrix = get_matrix_similarity_from_grads_new(
            gradients, distance_type=args.sim_type, idx=idxs_users, metric_matrix=sim_matrix
        )

        # Evaluation runs on the training device; the model returns to CPU afterwards.
        net_glob.to(args.device)
        if round_idx % 5 == 4:
            print(f'model:{args.model},Algorithm:{args.algorithm},dataset:{args.dataset}')

            item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)

            acc.append(item_acc)
        net_glob.to('cpu')

        del clients_models

    last_round = num_rounds - 1
    save_model(net_glob.state_dict(), f"model_{last_round}Epoch", args)
    save_result(acc, f"acc_{last_round}Epoch", args)




def test(net_glob, dataset_test, args):
    """Evaluate ``net_glob`` on ``dataset_test`` and print the accuracy.

    The loss reported by ``test_img`` is discarded.

    Returns:
        The accuracy converted with ``.item()``.
    """
    accuracy, _discarded_loss = test_img(net_glob, dataset_test, args)
    print("Testing accuracy: {:.2f}".format(accuracy))
    accuracy_value = accuracy.item()
    return accuracy_value

def test_with_loss(net_glob, dataset_test, args):
    """Evaluate ``net_glob`` on ``dataset_test``, printing loss then accuracy.

    Returns:
        ``(accuracy, loss)``. The accuracy is converted with ``.item()``; the
        loss is returned exactly as ``test_img`` produced it.
    """
    accuracy, loss = test_img(net_glob, dataset_test, args)

    loss_line = "Testing Loss: {:.2f}".format(loss)
    print(loss_line)

    accuracy_line = "Testing accuracy: {:.2f}".format(accuracy)
    print(accuracy_line)

    accuracy_value = accuracy.item()
    return accuracy_value, loss

def tSNE(net_glob,dataset,args):
    """Switch the model to eval mode, modify it with ``zero_out_model_params``,
    then extract features on ``dataset`` and plot them.

    NOTE(review): the meaning of the ``0.1`` passed to ``zero_out_model_params``
    (fraction vs. threshold) is defined outside this file -- confirm.
    The model is modified in place.
    """
    net_glob.eval()
    zero_out_model_params(net_glob, 0.1)
    feats, targets = get_features(net_glob, dataset, args)
    visualize_decision_boundary(feats, targets, net_glob, args, epoch=0)


def get_model(dataset, model_name, args=None, factory=None):
    """
    Get the corresponding model instance based on the dataset and model name.

    Each factory key is matched against ``dataset`` by substring (for example,
    ``'cifar10'`` matches the ``'cifar'`` entry), in the factory's iteration
    order. The first key that both matches and provides ``model_name`` wins.

    Args:
        dataset (str): Dataset name
        model_name (str): Model name
        args: Model initialization parameters (optional). Forwarded to the
            constructor whenever it is not ``None``, even if it is falsy.
        factory (dict, optional): ``{dataset_key: {model_name: constructor}}``.
            Defaults to the module-level ``MODEL_FACTORY``, which is only
            defined when this file runs as a script.

    Returns:
        Model instance

    Raises:
        ValueError: If no factory entry matches both ``dataset`` and ``model_name``.
    """
    if factory is None:
        factory = MODEL_FACTORY
    for key, model_dict in factory.items():
        if key in dataset and model_name in model_dict:
            constructor = model_dict[model_name]
            # Use an identity check so falsy-but-present args are still forwarded.
            return constructor(args) if args is not None else constructor()
    raise ValueError(f"Unsupported dataset {dataset} or model {model_name}")


if __name__ == '__main__':
    # Parse the command line. ``args`` becomes a module-level global that the
    # training drivers above (FedAvg, FedProx, ...) read directly.
    args = args_parser()
    device_index = args.gpu
    torch.cuda.set_device(device_index)
    args.device = torch.device('cuda')

    # Tiny-ImageNet has a dedicated loader; all other datasets go through get_dataset.
    if 'timage' in args.dataset:
        dataset_train, dataset_test, dict_users = get_tiny_imagenet_data(args)
    else:
        dataset_train, dataset_test, dict_users = get_dataset(args)

    # Registry read by get_model(): {dataset substring: {model name: constructor}}.
    MODEL_FACTORY = {
        'timage': {
            'resnet18': lambda args=None: ResNetTinyImageNet(BasicBlock, [2, 2, 2, 2]),
            'resnet_drop': lambda args=None: ResNetTinyImageNet_Drop(BasicBlock, [2, 2, 2, 2]),
            'resnet_drop25': lambda args=None: ResNetTinyImageNet_Drop25(BasicBlock, [2, 2, 2, 2]),
            'vgg': lambda args: VGG16_timage(args),
            'vggdrop': lambda args: VGG16_timage_Drop(args),
            'mobnet': lambda args: MobileNet(args),
            'mobnet_drop': lambda args: MobileNet_Drop(args),
        },
        'cifar': {
            'cnn': lambda args: CNNCifar(args),
            'cnndrop': lambda args: CNNCifarDrop(args),
            'resnet18': lambda args: ResNet18_cifar10(num_classes=args.num_classes),
            'resnet_drop': lambda args: ResNet18_cifar10_drop(num_classes=args.num_classes),
            'mobnet': lambda args: MobileNet(args),
            'mobnet_drop': lambda args: MobileNet_Drop(args),
            'vgg': lambda args: VGG16(args),
            'vggdrop': lambda args: VGG16_Drop(args),
        },
    }

    try:
        net_glob = get_model(args.dataset, args.model, args)
    except ValueError as e:
        # Without a model there is nothing to run. Previously execution fell
        # through to net_glob.to(...) and crashed with a NameError.
        print(e)
        sys.exit(1)
    print(net_glob)
    net_glob.to(args.device)

    if args.algorithm == 'FedAvg':
        FedAvg(net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'FedProx':
        FedProx(net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'ClusteredSampling':
        ClusteredSampling(net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'FedGen':
        FedGen(args, net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'FedMut':
        FedMut(args, net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'FedPhoenix':
        FedPhoenix(net_glob, dataset_train, dataset_test, dict_users)
    elif args.algorithm == 'test':
        test(net_glob, dataset_test, args)
    elif args.algorithm == 'tSNE_train':
        test(net_glob, dataset_test, args)
        tSNE(net_glob, dataset_train, args)
    elif args.algorithm == 'tSNE_test':
        test(net_glob, dataset_test, args)
        tSNE(net_glob, dataset_test, args)
    else:
        # Previously an unrecognised algorithm name exited silently with success.
        print(f"Unsupported algorithm {args.algorithm}")
        sys.exit(1)
