from models import multi_concept_model
import random
import numpy as np
import csv
import tensorflow as tf
import trainer
import ewc

# When True, Client.create_model() compiles the client model through
# trainer.compile_model() with an extra EWC regulariser (ewc.ewc_loss) built
# from the current concept's training data; when False it compiles with plain
# "adam" + "categorical_crossentropy".
EWC = False
# EWC requires adjusting the loss function after training on each concept. However, recompiling the model with the optimizer (Adam) also causes problems in training, so the optimizer state would have to be saved and reloaded.

class LogItems:
    """Container for one client's logged results.

    ``Client``'s methods append entries to the lists below, and
    ``Client.log_result`` writes them out as CSV rows under ``column_names``.
    """

    def __init__(self, cid):
        # CSV header, in the same order Client.log_result writes each row.
        self.column_names = [
            "round",
            "data_concept_id",
            "model_id",
            "concept_accuracy",
            "concept_loss",
            "acc",
            "loss",
        ]

        # Id of the client these logs belong to.
        self.cid = cid

        # History lists; entry i is meant to correspond to round i.
        self.cur_concept_list = []       # data concept drawn in set_new_round
        self.min_loss_concept_list = []  # server model chosen in concept_test
        self.concept_accuracy_list = []  # per-round list: accuracy of every server model
        self.concept_loss_list = []      # per-round list: loss of every server model
        self.test_accuracy = []          # local test accuracy after fine-tuning
        self.test_loss = []              # local test loss after fine-tuning

class Client:
    """A federated-learning client whose local data concept can change per round.

    The methods suggest this per-round flow (the driver is not in this file,
    so confirm against the caller): ``set_new_round`` draws a data concept,
    ``concept_test`` (or ``use_model0``) selects a server model,
    ``fine_tune`` trains it on local data, and ``log_result`` writes the
    collected history to CSV.

    Local data is read from ``<data_dir>/cid_<cid>_<concept_name>_<split>.npy``
    with ``split`` in ``train``/``valid``/``test``; each file holds a pickled
    dict with keys ``x_<split>`` and ``y_<split>``.
    """

    def __init__(self, cid, concepts, local_epoch, batch_size):
        """Create a client.

        Args:
            cid: Client identifier; used in data and log file names.
            concepts: Sequence of concept names; a concept id indexes into it.
            local_epoch: Maximum number of epochs per fine-tuning call
                (early stopping may end training sooner).
            batch_size: Mini-batch size used for fine-tuning.
        """
        self.cid = cid
        self.concepts = concepts
        self.n_concept = len(self.concepts)

        self.client_model = None
        self.client_weights = None

        # -1 means "not chosen yet".
        self.current_concept = -1   # index of this round's data concept
        self.min_loss_concept = -1  # index of the server model selected this round
        self.cur_round = 0

        self.local_epoch = local_epoch
        self.batch_size = batch_size

        self.log_items = LogItems(self.cid)

    def get_cid(self):
        """Return the client id."""
        return self.cid

    def get_concept_name(self, r):
        """Return the name of concept index ``r``."""
        return self.concepts[r]

    def _split_path(self, data_dir, split):
        """Return the .npy path of ``split`` for the current concept."""
        concept_name = self.get_concept_name(self.current_concept)
        return f"{data_dir}/cid_{self.cid}_{concept_name}_{split}.npy"

    def _load_split(self, data_dir, split):
        """Load ``(x, y)`` of ``split`` ('train', 'valid' or 'test') for the current concept."""
        # allow_pickle is required because each file stores a dict; unpickling
        # can execute arbitrary code, so only point data_dir at trusted files.
        data = np.load(self._split_path(data_dir, split), allow_pickle=True).item()
        return data["x_" + split], data["y_" + split]

    def create_model(self, data_dir):
        """Build and compile a fresh model and cache its initial weights.

        With the module flag ``EWC`` set, the model is compiled through
        ``trainer.compile_model`` with an EWC regulariser built from the
        current concept's training data; otherwise it is compiled with
        "adam" / "categorical_crossentropy" / accuracy.
        """
        self.client_model = multi_concept_model()
        if EWC:
            # NOTE(review): current_concept is still -1 unless set_new_round ran
            # first, in which case this loads the *last* concept's data.
            x_train, y_train = self._load_split(data_dir, "train")
            print('Training with EWC.')
            regularisers = [ewc.ewc_loss(self.client_model, (x_train, y_train))]
            trainer.compile_model(self.client_model, 0.001, extra_losses=regularisers)
        else:
            self.client_model.compile("adam", "categorical_crossentropy", metrics=["accuracy"])
        self.client_weights = self.client_model.get_weights()

    def create_model_backup(self):
        """Build a fresh model compiled with adam/categorical cross-entropy; cache its weights."""
        self.client_model = multi_concept_model()
        self.client_model.compile("adam", "categorical_crossentropy", metrics=["accuracy"])
        self.client_weights = self.client_model.get_weights()

    def set_new_round(self):
        """Advance the round counter and draw this round's data concept uniformly at random."""
        self.cur_round += 1
        # randrange(n) consumes the RNG exactly like randint(1, n) - 1.
        self.current_concept = random.randrange(self.n_concept)
        self.log_items.cur_concept_list.append(self.current_concept)

    def use_model0(self):
        """Select server model 0 for this round without evaluating any model.

        NOTE(review): unlike ``concept_test`` this appends nothing to
        ``log_items``, so for rounds that skip ``concept_test``
        ``log_result`` will read misaligned entries or raise IndexError —
        confirm how callers use it.
        """
        self.min_loss_concept = 0

    def concept_test(self, server_model_list, data_dir):
        """Select the server model with the lowest loss on the current concept.

        Every model in ``server_model_list[:n_concept]`` is evaluated on this
        client's *training* split of the current concept; the index of the
        first model with the strictly smallest loss becomes
        ``min_loss_concept``. Per-model accuracies and losses are logged.
        Assumes ``server_model_list`` holds at least ``n_concept`` models.
        """
        x_probe, y_probe = self._load_split(data_dir, "train")

        accuracies = [0] * self.n_concept
        losses = [0] * self.n_concept
        loss_min = float('inf')
        for i in range(self.n_concept):
            loss, accuracy = server_model_list[i].evaluate(x_probe, y_probe, verbose=0)
            accuracies[i] = accuracy
            losses[i] = loss
            if loss < loss_min:
                loss_min = loss
                self.min_loss_concept = i

        self.log_items.min_loss_concept_list.append(self.min_loss_concept)
        self.log_items.concept_accuracy_list.append(accuracies)
        self.log_items.concept_loss_list.append(losses)

    @staticmethod
    def _subsample(x, y, n_samples):
        """Return up to ``n_samples`` rows of ``(x, y)`` drawn without replacement.

        ``n_samples=None`` returns the inputs unchanged. When fewer than
        ``n_samples`` rows exist, all rows are returned (shuffled) instead of
        raising, which ``np.random.choice(..., replace=False)`` would do.
        """
        if n_samples is None:
            return x, y
        count = min(n_samples, x.shape[0])
        index = np.random.choice(x.shape[0], count, replace=False)
        return x[index], y[index]

    def _fine_tune(self, server_model_list, data_dir, n_samples):
        """Shared body of ``fine_tune`` / ``fine_tune_backup``.

        Fits the selected server model on (optionally subsampled) local
        training data with early stopping on validation loss, evaluates it on
        the local test split, logs test loss/accuracy and caches the weights.
        """
        print('Client %s trains model %s with data %s'%(self.cid, self.min_loss_concept, self.current_concept))
        x_train, y_train = self._load_split(data_dir, "train")
        x_train, y_train = self._subsample(x_train, y_train, n_samples)
        x_val, y_val = self._load_split(data_dir, "valid")
        callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
        # NOTE(review): this is the server's model object itself, not a copy, so
        # fit() below changes the entry in server_model_list in place — confirm
        # that sharing is intended.
        self.client_model = server_model_list[self.min_loss_concept]
        self.client_model.fit(
            x_train, y_train,
            validation_data=(x_val, y_val),
            batch_size=self.batch_size,
            epochs=self.local_epoch,
            callbacks=[callback],
            verbose=0,
        )
        x_test, y_test = self._load_split(data_dir, "test")
        loss, accuracy = self.client_model.evaluate(x_test, y_test, verbose=0)
        print("Client %s on local data %s: loss %s, accuracy %s"%(self.cid, self.current_concept, loss, accuracy))
        self.log_items.test_accuracy.append(accuracy)
        self.log_items.test_loss.append(loss)
        self.client_weights = self.client_model.get_weights()

    def fine_tune(self, server_model_list, data_dir, n_samples=320):
        """Fine-tune the selected server model on a random subset of local training data.

        Args:
            server_model_list: Server models; entry ``min_loss_concept`` is trained.
            data_dir: Directory containing this client's .npy data files.
            n_samples: Number of training rows to sample (capped at the number
                available); ``None`` trains on all rows.
        """
        self._fine_tune(server_model_list, data_dir, n_samples)

    def fine_tune_backup(self, server_model_list, data_dir):
        """Fine-tune the selected server model on all local training data."""
        self._fine_tune(server_model_list, data_dir, None)

    def get_min_loss_concept(self):
        """Return the index of the server model selected for this round."""
        return self.min_loss_concept

    def get_data_concept(self):
        """Return the index of this round's data concept."""
        return self.current_concept

    def get_weights(self):
        """Return the most recently cached model weights."""
        return self.client_weights

    def log_result(self, log_dir):
        """Write one CSV row per round to ``<log_dir>/cid_<cid>_results.csv``.

        Raises:
            IndexError: if any per-round list in ``log_items`` has fewer than
                ``cur_round`` entries (e.g. a round skipped ``concept_test``).
        """
        items = self.log_items
        file_path = log_dir + "/" + "cid_" + str(self.cid) + "_results.csv"
        with open(file_path, "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(items.column_names)
            for i in range(self.cur_round):
                writer.writerow([
                    i,
                    items.cur_concept_list[i],
                    items.min_loss_concept_list[i],
                    items.concept_accuracy_list[i],
                    items.concept_loss_list[i],
                    items.test_accuracy[i],
                    items.test_loss[i],
                ])