import os,pickle
# Make TensorFlow logs less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from models import multi_concept_model
import numpy as np
from numpy import genfromtxt
import csv
import argparse

# ---- Command-line configuration --------------------------------------------
# Defaults live in argparse itself so they show up in `--help`.
parser = argparse.ArgumentParser(
    description="Evaluate saved global (server) models on each concept's pooled test set."
)
parser.add_argument('--users', type=int, default=20, help='Enter the number of users')
parser.add_argument('--rounds', type=int, default=100, help='Enter the number of rounds')
parser.add_argument('--start', type=int, default=0, help='Enter start round')
parser.add_argument('--eid', type=str, default='concepts6', help='Enter EID')
parser.add_argument('--concepts', type=int, default=5, help='Enter the number of concepts')
parser.add_argument('--data-dir', type=str,
                    default="./FedWeIT_6_Tasks/generated_task_data/5_20",
                    help='Directory holding the per-client cid_<id>_<concept>_test.npy files')
parser.add_argument('--log-dir', type=str, default=None,
                    help='Directory for global_model_results.csv '
                         '(default: /l/data/xj8/CLFL-main/<EID>/logs/)')
args = parser.parse_args()

# Concept labels, indexed by conceptID; each label is part of the per-client
# test file names (see get_test_data).
concepts = ['svhn_0', 'face_scrub_0', 'face_scrub_1', 'mnist_0', 'traffic_sign_0']
# For a single-concept debug run: concepts = ['svhn_0']

n_concepts = args.concepts
EID = args.eid
n_round = args.rounds
start = args.start
n_user = args.users

# Fail fast with a clear message instead of an IndexError from
# get_concept_name() deep inside the evaluation loop.
if n_concepts > len(concepts):
    parser.error("--concepts must be at most %d (the number of known concepts)"
                 % len(concepts))

# GPU selection. NOTE(review): these are set after `models` (and presumably
# TensorFlow) has already been imported; they only take effect if TF has not
# initialised CUDA yet. Confirm, or move them next to TF_CPP_MIN_LOG_LEVEL.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

fedweit_data_dir = args.data_dir
MODEL_BACKUP_PATH = './' + EID + '/server_model'
log_dir = (args.log_dir if args.log_dir is not None
           else "/l/data/xj8/CLFL-main/%s/logs/" % EID)

model = multi_concept_model()
model.compile("adam", "categorical_crossentropy", metrics=["accuracy"])

def get_concept_name(conceptID):
    """Return the concept label stored at position ``conceptID`` of ``concepts``."""
    name = concepts[conceptID]
    return name

def get_test_data(conceptID):
    """Pool one concept's test split across all ``n_user`` clients.

    For each client id in ``range(n_user)``, load
    ``<fedweit_data_dir>/cid_<id>_<concept>_test.npy``. The file is read with
    ``allow_pickle=True`` and ``.item()``, and the result is indexed by
    ``'x_test'`` / ``'y_test'``. The per-client arrays are concatenated along
    axis 0.

    Returns:
        ``(x, y)``: the concatenated test inputs and labels.
    """
    concept_xs = []
    concept_ys = []
    for client_id in range(n_user):
        npy_path = "%s/cid_%s_%s_test.npy" % (
            fedweit_data_dir, client_id, get_concept_name(conceptID))
        record = np.load(npy_path, allow_pickle=True).item()
        x_part, y_part = record['x_test'], record['y_test']
        concept_xs.append(x_part)
        concept_ys.append(y_part)
    return np.concatenate(concept_xs, axis=0), np.concatenate(concept_ys, axis=0)

file_path = os.path.join(log_dir, "global_model_results.csv")
os.makedirs(log_dir, exist_ok=True)

# The results file is opened in append mode (so runs resumed with --start
# extend it). Only emit the header for a brand-new or empty file; otherwise
# every re-run would insert a duplicate header row in the middle of the data.
write_header = not os.path.exists(file_path) or os.path.getsize(file_path) == 0

# A concept's pooled test set is the same for every round and every global
# model. Load it once per concept instead of re-reading all n_user client
# files n_round * n_models times. This trades peak memory (all concepts held
# at once) for a large reduction in disk I/O.
test_data = [get_test_data(conceptID) for conceptID in range(n_concepts)]

with open(file_path, "a", newline="") as file:
    writer = csv.writer(file)
    if write_header:
        column_names = ["round", "globalModelID", "conceptID", "loss", "accuracy", "n_sample"]
        writer.writerow(column_names)

    for r in range(start, n_round):
        print("round %s" % r)
        # NOTE: pickle.load can execute arbitrary code. Only point this at
        # server checkpoints produced by this project's own training runs.
        with open(os.path.join(MODEL_BACKUP_PATH, "server_weights_%s" % r), 'rb') as f:
            server_weights_list = pickle.load(f)
        # One row per (global model, concept) pair for this round.
        for globalModelID, weights in enumerate(server_weights_list):
            model.set_weights(weights)
            for conceptID, (x_test, y_test) in enumerate(test_data):
                loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
                print("Model %s on data %s: loss %s, accuracy %s"
                      % (globalModelID, conceptID, loss, accuracy))
                writer.writerow([r, globalModelID, conceptID, loss, accuracy, len(y_test)])
        # Persist each finished round so an interrupted run keeps its results.
        file.flush()
