from copy import deepcopy
import torch
import numpy as np
import sys

class Benchmark_Fed_Pme():
    """Federated benchmark with shared (theta) and personal (phi) models.

    Every node owns a shared model theta, kept in sync across nodes by
    ``Fed_Avg``, and a personal model phi that is trained on
    ``loss(phi) + lambda/2 * ||phi - theta||^2`` and pulls theta towards it.
    The personal models of the "near" and "far" center nodes are evaluated on
    the validation and test loaders of both centers; results accumulate in
    ``self.metric_dict``.

    Methods rely on attributes set by earlier ones, so call them in the order
    ``Initialize_Results`` -> ``Initialize_Variables`` ->
    ``Update_Shared_and_Personal_Models_SGD`` -> ``Output_Results``.
    """

    # Metric columns kept for each evaluated personal model ("near" / "far").
    # Each list receives one value per call to ``Evaluation``.
    _METRIC_KEYS = (
        "num_points", "num_syn",
        "near_val_loss", "near_top1_val_acc", "near_top2_val_acc",
        "near_test_loss", "near_top1_test_acc", "near_top2_test_acc",
        "far_val_loss", "far_top1_val_acc", "far_top2_val_acc",
        "far_test_loss", "far_top1_test_acc", "far_top2_test_acc",
    )

    def __init__(self, num_nodes, near_center_idx, far_center_idx, loss_function, train_size, val_size, test_size):
        """Store the experiment configuration.

        Args:
            num_nodes: number of federated nodes, indexed ``0 .. num_nodes-1``.
            near_center_idx: node whose personal model is reported as "near";
                its shared model also serves as the FedAvg accumulator.
            far_center_idx: node whose personal model is reported as "far".
            loss_function: callable ``loss_function(outputs, labels)`` that
                returns a scalar tensor.
            train_size, val_size, test_size: stored as attributes; no method
                of this class reads them.
        """
        self.num_nodes = num_nodes
        self.near_center_idx = near_center_idx
        self.far_center_idx = far_center_idx
        self.loss_function = loss_function
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size

    def Initialize_Results(self):
        """Reset the metric history and the point / synchronization counters."""
        # Each evaluated personal model is scored on BOTH centers' data, hence
        # the near_* and far_* columns inside each of "near" and "far".
        self.metric_dict = {
            evaluate_type: {key: [] for key in self._METRIC_KEYS}
            for evaluate_type in ("near", "far")
        }
        self.evaluate_dict = {"near": self.near_center_idx, "far": self.far_center_idx}

        self.current_num_points = 0
        self.current_num_syn = 0

    def Initialize_Variables(self, model_shared_dict, model_personal_dict, train_loaders, val_loaders, test_loaders, num_batches):
        """Attach the models, run the initial evaluation and cache training batches.

        Args:
            model_shared_dict: ``{node_idx: module}`` of shared models (theta).
            model_personal_dict: ``{node_idx: module}`` of personal models (phi).
            train_loaders, val_loaders, test_loaders: indexable by node; each
                item yields ``(inputs, labels)`` batches.
            num_batches: number of mini-batches to cache per node; every train
                loader must yield at least this many batches (else IndexError).

        Sets ``train_size_dict[node]`` (samples per node) and
        ``mini_batch_tuple_dict[batch_idx][node]`` (cached batches).
        """
        self.model_shared_dict = model_shared_dict
        self.model_personal_dict = model_personal_dict
        self.current_num_syn += 1

        self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

        self.train_size_dict = {}
        self.mini_batch_tuple_dict = {batch_idx: {} for batch_idx in range(num_batches)}
        for node_idx in range(self.num_nodes):
            # Iterate each loader exactly once. Re-iterating a shuffling loader
            # per batch would draw a new permutation every time, so batches
            # taken from different passes could overlap or miss samples.
            node_batches = list(train_loaders[node_idx])
            self.train_size_dict[node_idx] = sum(
                int(batch_data[0].size(dim=0)) for batch_data in node_batches
            )
            for batch_idx in range(num_batches):
                self.mini_batch_tuple_dict[batch_idx][node_idx] = node_batches[batch_idx]

    def Grad_Vec_to_List(self, model, grad_vec):
        """Split a flat gradient vector into per-parameter tensors shaped like ``model``'s.

        The pieces are views into ``grad_vec`` (via ``.data``), ordered as
        ``model.parameters()``.
        """
        grad_list = []
        pointer = 0
        for model_param in model.parameters():
            num_param = model_param.numel()
            grad_list.append(grad_vec[pointer:pointer + num_param].view_as(model_param).data)
            pointer += num_param
        return grad_list

    def Gradient_Descent(self, node_idx, grad_list, lr, indicator):
        """Apply ``param -= lr * grad`` in place to one node's shared or personal model.

        Args:
            node_idx: node whose model is updated.
            grad_list: per-parameter gradients ordered like ``model.parameters()``.
            lr: step size.
            indicator: ``"shared"`` (theta) or ``"personal"`` (phi).

        Raises:
            ValueError: if ``indicator`` is neither ``"shared"`` nor ``"personal"``.
        """
        if indicator == "shared":
            model = self.model_shared_dict[node_idx]
        elif indicator == "personal":
            model = self.model_personal_dict[node_idx]
        else:
            raise ValueError(f"unknown indicator {indicator!r}; expected 'shared' or 'personal'")

        with torch.no_grad():
            for param, grad in zip(model.parameters(), grad_list):
                param.data -= lr * grad

    def Fed_Avg(self):
        """Replace every shared model with the plain average of all shared models.

        Parameters are summed into the near-center node's model, divided by
        ``num_nodes``, and that model's ``state_dict`` is loaded into all
        nodes. Buffers (anything in ``state_dict`` that is not a parameter)
        are not averaged; all nodes receive the near-center node's buffers.
        """
        pivot_model = self.model_shared_dict[self.near_center_idx]
        param_iters = [self.model_shared_dict[node_idx].parameters() for node_idx in range(self.num_nodes)]

        with torch.no_grad():
            # Walk the i-th parameter of every node together.
            for node_params in zip(*param_iters):
                accumulator = node_params[self.near_center_idx]
                for node_idx, param in enumerate(node_params):
                    if node_idx != self.near_center_idx:
                        accumulator.data += param.data
                accumulator.data /= self.num_nodes

            averaged_state = pivot_model.state_dict()
            for node_idx in range(self.num_nodes):
                self.model_shared_dict[node_idx].load_state_dict(averaged_state)
                self.model_shared_dict[node_idx].zero_grad()

    def _Local_Update(self, node_idx, batch_idx, num_steps_ub, delta, lr_inner, lambda_val):
        """One local step of ``node_idx``: optimize phi, then move theta towards phi."""
        model_shared = self.model_shared_dict[node_idx]
        model_shared.zero_grad()
        # Snapshot theta; it is not modified while phi is being optimized.
        shared_vec = torch.nn.utils.parameters_to_vector(model_shared.parameters()).detach()

        batch_data = self.mini_batch_tuple_dict[batch_idx][node_idx]
        mini_batch_size = batch_data[0].size(dim=0)
        inputs = batch_data[0].to(self.device)
        labels = batch_data[1].to(self.device)

        model_personal = self.model_personal_dict[node_idx]
        for _ in range(num_steps_ub):
            model_personal.zero_grad()
            personal_params = list(model_personal.parameters())
            personal_vec = torch.nn.utils.parameters_to_vector(personal_params).detach()

            loss = self.loss_function(model_personal(inputs), labels)
            mini_batch_grad = torch.nn.utils.parameters_to_vector(
                torch.autograd.grad(loss, personal_params, create_graph=False)
            )
            # Only node 0's processed samples are counted; a step that then
            # stops early (below) is still counted.
            if node_idx == 0:
                self.current_num_points += mini_batch_size

            # Gradient of loss(phi) + lambda/2 * ||phi - theta||^2 w.r.t. phi.
            personal_step_grad = mini_batch_grad + lambda_val * (personal_vec - shared_vec)
            if (torch.norm(input=personal_step_grad, p='fro')) ** 2 <= delta:
                break

            grad_list = self.Grad_Vec_to_List(model=model_personal, grad_vec=personal_step_grad)
            self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner, indicator="personal")

        model_personal.zero_grad()
        personal_vec = torch.nn.utils.parameters_to_vector(model_personal.parameters()).detach()

        # theta <- theta - lr_inner * lambda * (theta - phi)
        shared_batch_grad = shared_vec - personal_vec
        grad_list = self.Grad_Vec_to_List(model=model_shared, grad_vec=shared_batch_grad)
        self.Gradient_Descent(node_idx=node_idx, grad_list=grad_list, lr=lr_inner * lambda_val, indicator="shared")

    def Update_Shared_and_Personal_Models_SGD(self, val_loaders, test_loaders, num_epoch, num_batches, num_steps_ub, delta, communicate_period, lr_inner, lambda_val):
        """Run local SGD on (phi, theta) with a FedAvg of theta every ``communicate_period`` steps.

        On every local step each node takes up to ``num_steps_ub`` gradient
        steps on phi (stopping early once the squared gradient norm is
        ``<= delta``), then moves theta towards phi. After each round of
        ``communicate_period`` steps the shared models are averaged. The
        personal models are evaluated every ``5 * communicate_period`` steps
        and always after the final round.

        Args:
            val_loaders, test_loaders: passed to ``Evaluation``.
            num_epoch, num_batches: the run has
                ``int(num_epoch * num_batches / communicate_period)`` rounds.
            num_steps_ub: maximum inner steps on phi per local step.
            delta: early-stopping threshold on the squared gradient norm.
            communicate_period: local steps between FedAvg rounds.
            lr_inner: step size for phi; theta uses ``lr_inner * lambda_val``.
            lambda_val: proximal coefficient lambda.
        """
        num_rounds = int(num_epoch * num_batches / communicate_period)
        step_count = 0

        for round_idx in range(num_rounds):
            for _ in range(communicate_period):
                step_count += 1
                # step_count starts at 1, so batch 0 is the last one of each pass.
                batch_idx = int(step_count % num_batches)
                for node_idx in range(self.num_nodes):
                    self._Local_Update(node_idx, batch_idx, num_steps_ub, delta, lr_inner, lambda_val)

            # Communicate after every round.
            self.Fed_Avg()
            self.current_num_syn += 1

            # Check the round index rather than a step count, so the model
            # produced by the final round is always evaluated.
            is_last_round = round_idx == num_rounds - 1
            if step_count % (5 * communicate_period) == 0 or is_last_round:
                self.Evaluation(val_loaders=val_loaders, test_loaders=test_loaders)

    def _Evaluate_Loader(self, model, loader):
        """Return ``(mean loss, top-1 accuracy, top-2 accuracy)`` of ``model`` on ``loader``.

        ``self.loss_function`` is assumed to return a per-batch mean, so it is
        re-weighted by batch size before averaging over the loader.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        model.zero_grad()
        set_size = 0
        total_loss = 0
        top1_hits = 0
        top2_hits = 0

        # Nothing is trained here, so do not build autograd graphs.
        with torch.no_grad():
            for batch_data in loader:
                mini_batch_size = int(batch_data[0].size(dim=0))
                inputs = batch_data[0].to(self.device)
                labels = batch_data[1].to(self.device)
                outputs = model(inputs)

                total_loss += self.loss_function(outputs, labels).item() * mini_batch_size

                _, top1_predicts = torch.max(outputs, dim=1)
                top1_hits += int((top1_predicts == labels).sum().item())

                _, top2_predicts = torch.topk(outputs, k=2, dim=1)
                top2_hits += int((top2_predicts == labels.unsqueeze(1)).any(dim=1).sum().item())

                set_size += mini_batch_size

        if set_size == 0:
            raise ValueError("cannot evaluate on an empty data loader")
        return total_loss / set_size, top1_hits / set_size, top2_hits / set_size

    def Evaluation(self, val_loaders, test_loaders):
        """Evaluate the near- and far-center personal models on both centers' data.

        For each entry of ``evaluate_dict`` the personal model (phi) of that
        node is scored on the validation and test loaders of the near and far
        centers. Loss and top-1/top-2 accuracy are appended to
        ``metric_dict`` and printed, together with the current
        ``num_points`` / ``num_syn`` counters.
        """
        centers = (("near", self.near_center_idx), ("far", self.far_center_idx))
        splits = (("val", val_loaders), ("test", test_loaders))

        for evaluate_type, node_idx in self.evaluate_dict.items():
            model = self.model_personal_dict[node_idx]
            metrics = self.metric_dict[evaluate_type]

            for split, loaders in splits:
                for center, center_idx in centers:
                    loss, top1_acc, top2_acc = self._Evaluate_Loader(model, loaders[center_idx])
                    metrics[f"{center}_{split}_loss"].append(loss)
                    metrics[f"{center}_top1_{split}_acc"].append(top1_acc)
                    metrics[f"{center}_top2_{split}_acc"].append(top2_acc)
                    print(f"{center}_full_{split}_loss: ", loss)
                    print(f"{center}_full_top1_{split}_acc: ", top1_acc)
                    print(f"{center}_full_top2_{split}_acc: ", top2_acc)
                    print("-" * 30)

            metrics["num_points"].append(self.current_num_points)
            metrics["num_syn"].append(self.current_num_syn)

    def Output_Results(self):
        """Return the accumulated metric dictionary."""
        return self.metric_dict