import utils, clustering
import numpy as np
import pickle
import random
import csv
from models import multi_concept_model
from sklearn.metrics.pairwise import cosine_similarity

class LogItems:
    """Container for the per-round records the server accumulates.

    Every series starts out as its own empty list. ``column_names`` holds the
    header labels that ``Server.log_result`` writes as the first CSV row.
    """

    def __init__(self):
        # Client-level series: one entry appended per processed round.
        self.user_concepts_list, self.user_clusters_list = [], []
        # Matching-level series: concept assignments and measured distances.
        self.cluster_match, self.concept_distance_list = [], []

        # Header labels, in the order the row fields are written.
        self.column_names = [
            "round",
            "user_concepts",
            "user_clusters",
            "cluster_match",
            "concept_distance",
        ]

class Server:
    def __init__(self, n_user, concepts, cluster_algo, distance_method, backup_freq, model_backup_path):
        """Store the run configuration and initialise empty per-concept state.

        Args:
            n_user: Number of clients; round processing iterates
                ``range(n_user)`` over the client weights.
            concepts: Sized collection of concepts; its length determines how
                many concept models the server tracks.
            cluster_algo: Algorithm identifier forwarded to ``clustering.run``.
            distance_method: ``"m"`` (Manhattan), ``"e"`` (Euclidean) or
                ``"c"`` (Chebyshev); see ``calculate_distance``.
            backup_freq: Weights are pickled whenever
                ``cur_round % backup_freq == 0``.
            model_backup_path: Directory prefix for the pickled weight files.
        """
        # Problem size.
        self.concepts = concepts
        self.n_concept = len(concepts)
        self.n_user = n_user

        # Clustering / matching configuration.
        self.cluster_algo = cluster_algo
        self.distance_method = distance_method

        # Filled by create_models() / create_single_model().
        self.server_weights_list = list()
        self.server_model_list = list()

        # Round counter and backup settings.
        self.cur_round = 0
        self.backup_freq = backup_freq
        self.model_backup_path = model_backup_path

        # Per-concept distance threshold (starts unbounded) and update counter.
        self.dist = [float('inf') for _ in range(self.n_concept)]
        self.clock = [0 for _ in range(self.n_concept)]

        self.log_items = LogItems()

    def create_models(self):
        """Create one compiled model per concept.

        Each model (and its initial weights) is appended to
        ``server_model_list`` / ``server_weights_list`` by
        ``create_single_model``, which performs exactly the per-model steps
        (random seed, build, compile, record) needed here.
        """
        for _ in range(self.n_concept):
            self.create_single_model()

    def create_single_model(self):
        """Build, compile and register a single model.

        The model is constructed with a random integer seed in ``[0, 100]``,
        compiled with the ``"adam"`` / ``"categorical_crossentropy"`` /
        accuracy settings, and appended (together with its initial weights)
        to the server's model and weight lists.
        """
        seed = random.randint(0, 100)
        model = multi_concept_model(model_seed=seed)
        model.compile("adam", "categorical_crossentropy", metrics=["accuracy"])

        # Keep the model and its weights at the same index in both lists.
        self.server_model_list.append(model)
        self.server_weights_list.append(model.get_weights())
    
    def get_model_list(self):
        """Return the server's list of model objects (the live list, not a copy)."""
        models = self.server_model_list
        return models
    
    def get_model_weights_list(self):
        """Return the server's list of per-model weights (the live list, not a copy)."""
        weights = self.server_weights_list
        return weights
    
    def chebyshev(self, a, b):
        """Return the Chebyshev (L-infinity) distance between ``a`` and ``b``.

        The distance is the largest absolute element-wise difference. The
        inputs are paired with ``zip``, so extra trailing elements of the
        longer input are ignored.

        Returns 0 for empty input, consistent with ``manhattan`` and
        ``euclidean`` (a bare ``max()`` would raise ``ValueError`` instead).
        """
        return max((abs(x - y) for x, y in zip(a, b)), default=0)

    def manhattan(self, a, b):
        """Return the Manhattan (L1) distance between ``a`` and ``b``.

        Elements are paired with ``zip`` (extra trailing elements of the
        longer input are ignored) and the absolute differences are summed.
        """
        gaps = [abs(left - right) for left, right in zip(a, b)]
        return sum(gaps)
    
    def euclidean(self, a, b):
        """Return the Euclidean (L2) distance between ``a`` and ``b``.

        Elements are paired with ``zip`` (extra trailing elements of the
        longer input are ignored); the squared differences are summed and the
        square root of that sum is returned.
        """
        squared_gaps = [(left - right) ** 2 for left, right in zip(a, b)]
        return sum(squared_gaps) ** 0.5
    
    def calculate_distance(self, a, b):
        if self.distance_method == "m":
            return self.manhattan(np.concatenate(a),np.concatenate(b))
        elif self.distance_method == "e":
            return self.euclidean(np.concatenate(a),np.concatenate(b))
        elif self.distance_method == "c":
            # return cosine_similarity(np.concatenate(a).reshape(1,-1),np.concatenate(b).reshape(1,-1))
            return self.chebyshev(np.concatenate(a),np.concatenate(b))

    def distance_match_backup(self, new_concept_weights, server_weights_list,dist):
        # dist is a list of distance from previous round. The values in it must be smaller over rounds.
        l1 = []
        test_dist = []
        # print(new_concept_weights[len( new_concept_weights)-1].shape)
        # l1.append(new_concept_weights[len( new_concept_weights)-1].ravel())
        for l in range(len( new_concept_weights)):
            l1.append(new_concept_weights[l].ravel())
        max_s =  float('inf')
        match = 0
        found = False
        for k in range(self.n_concept):
            l2=[]
            # for l in range():
            # l2.append(server_weights_list[k][len(new_concept_weights)-1].ravel())
            for l in range(len( new_concept_weights)):
                l2.append(server_weights_list[k][l].ravel())
    #                 print(new_concept_weights[l].shape)
    #                     print(np.concatenate(l1).shape) 
    #                     pca1=  pca.fit_transform(l1)
    #                     pca2=  pca.fit_transform(l2) 
            tmp = self.calculate_distance(l1, l2)

    #                     tmp = cosine_similarity(pca1, pca2)
            print('Concept %s %s distance %s.'%(k, self.distance_method, tmp))
            if tmp < max_s and tmp<dist[k]:
                max_s = tmp
                match = k
                found = True
            test_dist.append(tmp)
        if found == False:
            return None, test_dist
        return match, test_dist

    
    def distance_match(self, new_concept_weights, server_weights_list,dist):
        # dist is a list of distance from previous round. The values in it must be smaller over rounds.
        print(dist)
        l1 = []
        # print(new_concept_weights[len( new_concept_weights)-1].shape)
        # l1.append(new_concept_weights[len( new_concept_weights)-1].ravel())
        for l in range(len( new_concept_weights)):
            l1.append(new_concept_weights[l].ravel())
        min_s =  float('inf')
        match = 0
        found = False
        for k in range(self.n_concept):
            l2=[]
            # for l in range():
            # l2.append(server_weights_list[k][len(new_concept_weights)-1].ravel())
            for l in range(len( new_concept_weights)):
                l2.append(server_weights_list[k][l].ravel())
    #                 print(new_concept_weights[l].shape)
    #                     print(np.concatenate(l1).shape) 
    #                     pca1=  pca.fit_transform(l1)
    #                     pca2=  pca.fit_transform(l2) 
            tmp = self.calculate_distance(l1, l2)

    #                     tmp = cosine_similarity(pca1, pca2)
            print('Concept %s %s distance %s.'%(k, self.distance_method, tmp))
            if  tmp < min_s and tmp<dist[k]:
                min_s = tmp
                match = k
                found = True
        if found == False:
            return None, min_s 
        return match, min_s 

    def cosine_match(self, new_concept_weights, server_weights_list):
        l1 = []
        # print(new_concept_weights[len( new_concept_weights)-1].shape)
        # l1.append(new_concept_weights[len( new_concept_weights)-1].ravel())
        for l in range(len( new_concept_weights)):
            l1.append(new_concept_weights[l].ravel())
        max_s = -1
        match = 0
        for k in range(self.n_concept):
            l2=[]
            # for l in range():
            # l2.append(server_weights_list[k][len(new_concept_weights)-1].ravel())
            for l in range(len( new_concept_weights)):
                l2.append(server_weights_list[k][l].ravel())
    #                 print(new_concept_weights[l].shape)
    #                     print(np.concatenate(l1).shape) 
    #                     pca1=  pca.fit_transform(l1)
    #                     pca2=  pca.fit_transform(l2)        
            tmp = self.cosine_similarity(np.concatenate(l1).reshape(1,-1), np.concatenate(l2).reshape(1,-1))
    #                     tmp = cosine_similarity(pca1, pca2)
            print('Concept %s similairty %s.'%(k,tmp))
            if tmp > max_s:
                max_s = tmp
                match = k
        return match

    def process_new_round(self, client_weights_list, user_concepts):
        """Cluster the client models, match clusters to concepts and update the server.

        Steps:
          1. Cluster the clients with ``clustering.run``.
          2. For every cluster label in ``range(self.n_concept)`` that has
             clients, average their weights with ``utils.fed_avg``. Clients
             carrying any other label are not aggregated this round.
          3. In round 0 the averaged weights of cluster ``i`` are stored
             directly as concept ``i``. In later rounds each cluster is
             matched to a concept with ``distance_match_backup``; the concept
             weights and the per-concept distance thresholds are then updated.
          4. Every ``backup_freq`` rounds the server weights are pickled.

        Args:
            client_weights_list: Client weights, indexable by client id.
            user_concepts: Concept labels of the clients; only logged here.
        """
        user_clusters = clustering.run(client_weights_list, self.n_concept, self.cluster_algo)

        # For dbscan/optics a -1 label may be present (presumably the noise
        # label); shift every label up by one so that it becomes cluster 0.
        if self.cluster_algo in ('dbscan', 'optics') and -1 in user_clusters:
            user_clusters = [x + 1 for x in user_clusters]

        # Wrap in a 1-tuple so a tuple argument cannot break %-formatting.
        print('User concepts %s.' % (user_concepts,))
        print('User clusters %s.' % (user_clusters,))

        self.log_items.user_clusters_list.append(user_clusters)
        self.log_items.user_concepts_list.append(user_concepts)

        is_first_round = self.cur_round == 0
        match_list = []
        new_dist_list = []
        # Cluster models assigned to each concept during this round.
        cluster_weights_dict = {i: [] for i in range(self.n_concept)}

        for i in range(self.n_concept):
            cluster_client_ids = [j for j in range(self.n_user) if user_clusters[j] == i]
            if not cluster_client_ids:
                continue
            new_concept_weights = utils.fed_avg(client_weights_list, cluster_client_ids)

            if is_first_round:
                # No reference models yet: cluster i simply becomes concept i.
                self.server_weights_list[i] = new_concept_weights
                self.server_model_list[i].set_weights(self.server_weights_list[i])
                continue

            match, new_dist = self.distance_match_backup(
                new_concept_weights, self.server_weights_list, self.dist
            )
            new_dist_list.append(new_dist)
            # Failsafe, but never happened in experiments. If no matching was
            # found, the cluster model is assigned to the least updated model.
            if match is None:
                print("No match concept found.")
                match = self.clock.index(min(self.clock))
            self.clock[match] += 1
            match_list.append(match)
            print('Assign with concept %s'%(match))
            cluster_weights_dict[match].append(new_concept_weights)

        self.log_items.cluster_match.append(match_list)
        self.log_items.concept_distance_list.append(new_dist_list)

        if not is_first_round:
            for i in range(self.n_concept):
                if cluster_weights_dict[i]:
                    # Failsafe, in case multiple cluster models match a concept
                    # model, average them, but never happened in experiments.
                    avg_weights = utils.weight_list_avg(cluster_weights_dict[i])
                    self.server_weights_list[i] = avg_weights
                    self.server_model_list[i].set_weights(self.server_weights_list[i])

                # New threshold for concept i: the largest positive distance any
                # cluster measured to it this round (unchanged if there is none).
                largest = 0
                for distances in new_dist_list:
                    if distances[i] > largest:
                        largest = distances[i]
                        self.dist[i] = largest

        if self.cur_round % self.backup_freq == 0:
            backup_path = '%s/server_weights_%s'%(self.model_backup_path,self.cur_round)
            # The context manager closes the file even if pickling fails.
            with open(backup_path, 'wb') as backup_file:
                pickle.dump(self.server_weights_list, backup_file)

        self.cur_round = self.cur_round + 1

    def process_new_round_backup(self, client_weights_list, user_concepts):
        """Alternative round update that overwrites the matched concept immediately.

        Clients are clustered with ``clustering.run``. For every cluster label
        in ``range(self.n_concept)`` that has clients, their weights are
        averaged with ``utils.fed_avg``. In round 0 cluster ``i`` becomes
        concept ``i``. In later rounds the cluster is matched with
        ``distance_match`` and the matched concept's weights and distance
        threshold are replaced right away, so later clusters in the same round
        are compared against the updated state.

        Args:
            client_weights_list: Client weights, indexable by client id.
            user_concepts: Concept labels of the clients; only logged here.
        """
        user_clusters = clustering.run(client_weights_list, self.n_concept, self.cluster_algo)

        # Wrap in a 1-tuple so a tuple argument cannot break %-formatting.
        print('User concepts %s.' % (user_concepts,))
        print('User clusters %s.' % (user_clusters,))

        self.log_items.user_clusters_list.append(user_clusters)
        self.log_items.user_concepts_list.append(user_concepts)

        is_first_round = self.cur_round == 0
        new_dist_list = []
        match_list = []
        for i in range(self.n_concept):
            cluster_client_ids = [j for j in range(self.n_user) if user_clusters[j] == i]
            if not cluster_client_ids:
                continue
            new_concept_weights = utils.fed_avg(client_weights_list, cluster_client_ids)

            if is_first_round:
                # No reference models yet: cluster i simply becomes concept i.
                self.server_weights_list[i] = new_concept_weights
                self.server_model_list[i].set_weights(self.server_weights_list[i])
                continue

            match, new_dist = self.distance_match(
                new_concept_weights, self.server_weights_list, self.dist
            )
            # If no matching was found, the cluster model is assigned to the
            # least updated model.
            if match is None:
                print("No match concept found.")
                match = self.clock.index(min(self.clock))
            self.clock[match] += 1
            match_list.append(match)
            print('Assign with concept %s'%(match))
            new_dist_list.append(new_dist)
            self.server_weights_list[match] = new_concept_weights
            self.server_model_list[match].set_weights(self.server_weights_list[match])
            # When the fallback above was used, new_dist is inf (that is what
            # distance_match returns), which resets the concept's threshold.
            self.dist[match] = new_dist

        self.log_items.cluster_match.append(match_list)
        self.log_items.concept_distance_list.append(new_dist_list)

        if self.cur_round % self.backup_freq == 0:
            backup_path = '%s/server_weights_%s'%(self.model_backup_path,self.cur_round)
            # The context manager closes the file even if pickling fails.
            with open(backup_path, 'wb') as backup_file:
                pickle.dump(self.server_weights_list, backup_file)

        self.cur_round = self.cur_round + 1

    def vanilla_new_round(self, client_weights_list):
        client_ids = []
        agg_error = False
        for j in range(self.n_user):
            client_ids.append(j)
        if client_ids:
            new_weights = utils.fed_avg(client_weights_list, client_ids)
            for layer_weights in new_weights:
                if utils.nan_in_list(layer_weights):
                    agg_error = True
            if agg_error == False:
                self.server_weights_list[0]=new_weights
                self.server_model_list[0].set_weights(self.server_weights_list[0])    
            else:
                print('Nan in aggregated global weights.')                
        if self.cur_round % self.backup_freq == 0:
            geeky_file = open('%s/server_weights_%s'%(self.model_backup_path,self.cur_round) , 'wb')
            pickle.dump(self.server_weights_list, geeky_file)
            geeky_file.close()
        
        self.cur_round = self.cur_round + 1

    def log_result(self, log_dir):
        file_path = log_dir + "/" + "server_results.csv"
        with open(file_path, "w", newline="") as file:
            writer = csv.writer(file)

            writer.writerow(self.log_items.column_names)
            for i in range(self.cur_round):
                row = []
                row.append(i)
                row.append(self.log_items.user_concepts_list[i])
                row.append(self.log_items.user_clusters_list[i])
                row.append(self.log_items.cluster_match[i])
                row.append(self.log_items.concept_distance_list[i])
                writer.writerow(row)