from multiprocessing import current_process
import os
import tensorflow as tf
from typing import Dict, Callable, Optional, Tuple, List
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from CLFLClient import Client
from CLFLServer import Server
#import tensorflow_model_optimization as tfmot
# prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
from sklearn.metrics.pairwise import cosine_similarity
from clustering import flatten
from sklearn.decomposition import PCA


import sys
import argparse

# GPU selection and memory policy for this experiment.
#   CUDA_DEVICE_ORDER         - enumerate devices by PCI bus id.
#   CUDA_VISIBLE_DEVICES      - the GPU id to use, usually either "0" or "1"
#                               (a list such as "0,1,2,3" exposes several).
#   TF_FORCE_GPU_ALLOW_GROWTH - ask TensorFlow to grow GPU memory on demand.
# An optional TF_GPU_ALLOCATOR="cuda_malloc_async" setting is left disabled.
# NOTE(review): these are set after `import tensorflow`; presumably they still
# take effect because no GPU has been initialised yet - confirm.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    }
)

# Passed to the Server constructor; presumably the number of rounds between
# model backups - TODO confirm against CLFLServer.
BACKUP_FREQ = 1
# log_results() writes logs whenever (round + 1) is a multiple of this value.
LOG_FREQ = 1
# Seed applied to TensorFlow's global random generator below.
sim_seed = 13
# Passed to every Client; presumably the mini-batch size - TODO confirm.
batch = 64
# Not referenced anywhere else in this script.
start = 0

# Seed TensorFlow's global RNG from sim_seed before any model is created.
tf.random.set_seed(sim_seed)

def _or_default(value, fallback):
    """Return ``value`` unless it is None, in which case return ``fallback``."""
    if value is None:
        return fallback
    return value


# Command-line options. Every option may be omitted; an omitted option is
# parsed as None and replaced by the script default further down.
parser = argparse.ArgumentParser()
for _flag, _kind, _help in (
    ('--cluster', str, 'Enter the clustering algorithm'),
    ('--distance', str, 'Enter the distance method'),
    ('--users', int, 'Enter the number of users'),
    ('--epochs', int, 'Enter the number of local epochs'),
    ('--rounds', int, 'Enter the number of rounds'),
    ('--concepts', int, 'Enter the number of concepts'),
    ('--eid', str, 'Enter EID'),
    ('--dir', str, 'Enter dir'),
):
    parser.add_argument(_flag, type=_kind, help=_help)


args = parser.parse_args()

# Effective settings: the command-line value when given, else the default.
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# later code in this script reads it.
dir = _or_default(args.dir, '5_20')
n_concept = _or_default(args.concepts, 5)
EID = _or_default(args.eid, 'e')
cluster_algo = _or_default(args.cluster, "kmeans")
distance_method = _or_default(args.distance, "m")
local_epoch = _or_default(args.epochs, 30)
n_round = _or_default(args.rounds, 100)
n_user = _or_default(args.users, 20)

# Location of the generated task data for the selected `dir` configuration.
fedweit_data_dir = "./FedWeIT_6_Tasks/generated_task_data/%s"%dir
# Directory handed to the Server constructor for model backups.
MODEL_BACKUP_PATH = './'+EID+'/server_model'
# Concept (task) names handed to every Client and to the Server.
# Alternative sets used previously:
# concepts = ['cifar_0', 'cifar_1', 'face_scrub_0', 'face_scrub_1', 'mnist_0', 'traffic_sign_0']
concepts = ['svhn_0', 'face_scrub_0', 'face_scrub_1', 'mnist_0', 'traffic_sign_0']
# concepts = ['svhn_0']

# Directory passed to the clients' and server's log_result() calls.
log_dir = "./%s/logs/"%EID

# Create the experiment root, the model backup directory and the log directory.
# exist_ok=True replaces the former "check, then create" sequence, which could
# fail if another process created the directory between the check and the
# makedirs call. A path that exists but is not a directory now raises
# FileExistsError here instead of being silently accepted.
for _out_dir in ("./%s/"%EID, MODEL_BACKUP_PATH, log_dir):
    os.makedirs(_out_dir, exist_ok=True)

def _make_client(client_id):
    """Construct the client with index ``client_id`` and create its model."""
    client = Client(client_id, concepts, local_epoch, batch)
    client.create_model(fedweit_data_dir)
    return client


# Not used further in this script.
user_clusters = []

# One client per simulated user, built in index order 0 .. n_user - 1.
client_list = [_make_client(client_id) for client_id in range(n_user)]

server = Server(n_user, concepts, cluster_algo, distance_method, BACKUP_FREQ, MODEL_BACKUP_PATH)

# With exactly one concept the server creates a single model; otherwise it
# creates its full set of models.
if n_concept == 1:
    server.create_single_model()
else:
    server.create_models()

# pca = PCA(n_components=30)    

def log_results(r_cur):
    """Write client and server logs for round ``r_cur`` when logging is due.

    Logging is due when ``r_cur + 1`` is a multiple of ``LOG_FREQ``. In that
    case each of the first ``n_user`` clients logs to ``log_dir`` in index
    order, followed by the server. Otherwise nothing happens.

    Args:
        r_cur: Zero-based index of the round that has just finished.
    """
    rounds_done = r_cur + 1
    if rounds_done % LOG_FREQ != 0:
        return

    for client_idx in range(n_user):
        client = client_list[client_idx]
        client.log_result(log_dir)
    server.log_result(log_dir)

# True when more than one concept is in play. In that mode every client first
# runs concept_test() against the server's model list; with a single concept
# the client calls use_model0() instead.
multi_concept = n_concept != 1

for r_cur in range(n_round):
    print('round %s'%r_cur)

    # ------------------------------ Clients ------------------------------
    # Per-round results, collected in client index order.
    user_concepts = []
    client_weights_list = []
    for cid in range(n_user):
        cur_client = client_list[cid]
        cur_client.set_new_round()
        if multi_concept:
            cur_client.concept_test(server.get_model_list(), fedweit_data_dir)
        else:
            cur_client.use_model0()

        # Local training happens in both modes, then results are collected.
        cur_client.fine_tune(server.get_model_list(), fedweit_data_dir)
        client_weights_list.append(cur_client.get_weights())
        user_concepts.append(cur_client.get_data_concept())
    # Alternative clusterings of client_weights_list tried previously:
    # clustering.spectral, clustering.birch, clustering.pca_kmean
    # (each called with client_weights_list and n_concept).

    # ------------------------------ Server -------------------------------
    if multi_concept:
        server.process_new_round(client_weights_list, user_concepts)
        log_results(r_cur)
    else:
        # NOTE(review): log_results() is not called in single-concept mode;
        # confirm whether that is intentional.
        server.vanilla_new_round(client_weights_list)

